diff options
Diffstat (limited to 'Lib/test/test_socket.py')
-rw-r--r-- | Lib/test/test_socket.py | 426 |
1 files changed, 292 insertions, 134 deletions
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index b412386c439..cf45b7346ca 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -20,6 +20,8 @@ import signal import math import pickle import struct +import random +import string try: import multiprocessing except ImportError: @@ -76,7 +78,7 @@ class SocketTCPTest(unittest.TestCase): def setUp(self): self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.port = support.bind_port(self.serv) - self.serv.listen(1) + self.serv.listen() def tearDown(self): self.serv.close() @@ -445,7 +447,7 @@ class SocketListeningTestMixin(SocketTestBase): def setUp(self): super().setUp() - self.serv.listen(1) + self.serv.listen() class ThreadedSocketTestMixin(ThreadSafeCleanupTestCase, SocketTestBase, @@ -716,11 +718,11 @@ class GeneralModuleTests(unittest.TestCase): with self.assertRaises(TypeError) as cm: s.sendto('\u2620', sockname) self.assertEqual(str(cm.exception), - "'str' does not support the buffer interface") + "a bytes-like object is required, not 'str'") with self.assertRaises(TypeError) as cm: s.sendto(5j, sockname) self.assertEqual(str(cm.exception), - "'complex' does not support the buffer interface") + "a bytes-like object is required, not 'complex'") with self.assertRaises(TypeError) as cm: s.sendto(b'foo', None) self.assertIn('not NoneType',str(cm.exception)) @@ -728,11 +730,11 @@ class GeneralModuleTests(unittest.TestCase): with self.assertRaises(TypeError) as cm: s.sendto('\u2620', 0, sockname) self.assertEqual(str(cm.exception), - "'str' does not support the buffer interface") + "a bytes-like object is required, not 'str'") with self.assertRaises(TypeError) as cm: s.sendto(5j, 0, sockname) self.assertEqual(str(cm.exception), - "'complex' does not support the buffer interface") + "a bytes-like object is required, not 'complex'") with self.assertRaises(TypeError) as cm: s.sendto(b'foo', 0, None) self.assertIn('not NoneType', str(cm.exception)) @@ -1378,10 +1380,13 @@ class GeneralModuleTests(unittest.TestCase): def test_listen_backlog(self): for backlog in 0, -1: - srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: + srv.bind((HOST, 0)) + srv.listen(backlog) + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: srv.bind((HOST, 0)) - srv.listen(backlog) - srv.close() + srv.listen() @support.cpython_only def test_listen_backlog_overflow(self): @@ -3585,7 +3590,7 @@ class InterruptedTimeoutBase(unittest.TestCase): def setUp(self): super().setUp() orig_alrm_handler = signal.signal(signal.SIGALRM, - lambda signum, frame: None) + lambda signum, frame: 1 / 0) self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler) self.addCleanup(self.setAlarm, 0) @@ -3622,13 +3627,11 @@ class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase): self.serv.settimeout(self.timeout) def checkInterruptedRecv(self, func, *args, **kwargs): - # Check that func(*args, **kwargs) raises OSError with an + # Check that func(*args, **kwargs) raises # errno of EINTR when interrupted by a signal. self.setAlarm(self.alarm_time) - with self.assertRaises(OSError) as cm: + with self.assertRaises(ZeroDivisionError) as cm: func(*args, **kwargs) - self.assertNotIsInstance(cm.exception, socket.timeout) - self.assertEqual(cm.exception.errno, errno.EINTR) def testInterruptedRecvTimeout(self): self.checkInterruptedRecv(self.serv.recv, 1024) @@ -3684,12 +3687,10 @@ class InterruptedSendTimeoutTest(InterruptedTimeoutBase, # Check that func(*args, **kwargs), run in a loop, raises # OSError with an errno of EINTR when interrupted by a # signal. - with self.assertRaises(OSError) as cm: + with self.assertRaises(ZeroDivisionError) as cm: while True: self.setAlarm(self.alarm_time) func(*args, **kwargs) - self.assertNotIsInstance(cm.exception, socket.timeout) - self.assertEqual(cm.exception.errno, errno.EINTR) # Issue #12958: The following tests have problems on OS X prior to 10.7 @support.requires_mac_ver(10, 7) @@ -3731,8 +3732,6 @@ class TCPCloserTest(ThreadedTCPSocketTest): self.cli.connect((HOST, self.port)) time.sleep(1.0) -@unittest.skipUnless(hasattr(socket, 'socketpair'), - 'test needs socket.socketpair()') @unittest.skipUnless(thread, 'Threading required for this test.') class BasicSocketPairTest(SocketPairTest): @@ -3813,7 +3812,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_NONBLOCK) self.port = support.bind_port(self.serv) - self.serv.listen(1) + self.serv.listen() # actual testing start = time.time() try: @@ -4059,117 +4058,6 @@ class FileObjectClassTestCase(SocketConnectedTest): pass -class FileObjectInterruptedTestCase(unittest.TestCase): - """Test that the file object correctly handles EINTR internally.""" - - class MockSocket(object): - def __init__(self, recv_funcs=()): - # A generator that returns callables that we'll call for each - # call to recv(). - self._recv_step = iter(recv_funcs) - - def recv_into(self, buffer): - data = next(self._recv_step)() - assert len(buffer) >= len(data) - buffer[:len(data)] = data - return len(data) - - def _decref_socketios(self): - pass - - def _textiowrap_for_test(self, buffering=-1): - raw = socket.SocketIO(self, "r") - if buffering < 0: - buffering = io.DEFAULT_BUFFER_SIZE - if buffering == 0: - return raw - buffer = io.BufferedReader(raw, buffering) - text = io.TextIOWrapper(buffer, None, None) - text.mode = "rb" - return text - - @staticmethod - def _raise_eintr(): - raise OSError(errno.EINTR, "interrupted") - - def _textiowrap_mock_socket(self, mock, buffering=-1): - raw = socket.SocketIO(mock, "r") - if buffering < 0: - buffering = io.DEFAULT_BUFFER_SIZE - if buffering == 0: - return raw - buffer = io.BufferedReader(raw, buffering) - text = io.TextIOWrapper(buffer, None, None) - text.mode = "rb" - return text - - def _test_readline(self, size=-1, buffering=-1): - mock_sock = self.MockSocket(recv_funcs=[ - lambda : b"This is the first line\nAnd the sec", - self._raise_eintr, - lambda : b"ond line is here\n", - lambda : b"", - lambda : b"", # XXX(gps): io library does an extra EOF read - ]) - fo = mock_sock._textiowrap_for_test(buffering=buffering) - self.assertEqual(fo.readline(size), "This is the first line\n") - self.assertEqual(fo.readline(size), "And the second line is here\n") - - def _test_read(self, size=-1, buffering=-1): - mock_sock = self.MockSocket(recv_funcs=[ - lambda : b"This is the first line\nAnd the sec", - self._raise_eintr, - lambda : b"ond line is here\n", - lambda : b"", - lambda : b"", # XXX(gps): io library does an extra EOF read - ]) - expecting = (b"This is the first line\n" - b"And the second line is here\n") - fo = mock_sock._textiowrap_for_test(buffering=buffering) - if buffering == 0: - data = b'' - else: - data = '' - expecting = expecting.decode('utf-8') - while len(data) != len(expecting): - part = fo.read(size) - if not part: - break - data += part - self.assertEqual(data, expecting) - - def test_default(self): - self._test_readline() - self._test_readline(size=100) - self._test_read() - self._test_read(size=100) - - def test_with_1k_buffer(self): - self._test_readline(buffering=1024) - self._test_readline(size=100, buffering=1024) - self._test_read(buffering=1024) - self._test_read(size=100, buffering=1024) - - def _test_readline_no_buffer(self, size=-1): - mock_sock = self.MockSocket(recv_funcs=[ - lambda : b"a", - lambda : b"\n", - lambda : b"B", - self._raise_eintr, - lambda : b"b", - lambda : b"", - ]) - fo = mock_sock._textiowrap_for_test(buffering=0) - self.assertEqual(fo.readline(size), b"a\n") - self.assertEqual(fo.readline(size), b"Bb") - - def test_no_buffer(self): - self._test_readline_no_buffer() - self._test_readline_no_buffer(size=4) - self._test_read(buffering=0) - self._test_read(size=100, buffering=0) - - class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): """Repeat the tests from FileObjectClassTestCase with bufsize==0. @@ -4588,7 +4476,7 @@ class TestLinuxAbstractNamespace(unittest.TestCase): address = b"\x00python-test-hello\x00\xff" with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s1: s1.bind(address) - s1.listen(1) + s1.listen() with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s2: s2.connect(s1.getsockname()) with s1.accept()[0] as s3: @@ -4820,7 +4708,7 @@ class TIPCThreadableTest(unittest.TestCase, ThreadableTest): srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE, TIPC_LOWER, TIPC_UPPER) self.srv.bind(srvaddr) - self.srv.listen(5) + self.srv.listen() self.serverExplicitReady() self.conn, self.connaddr = self.srv.accept() self.addCleanup(self.conn.close) @@ -5109,6 +4997,275 @@ class TestSocketSharing(SocketTCPTest): source.close() +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendfileUsingSendTest(ThreadedTCPSocketTest): + """ + Test the send() implementation of socket.sendfile(). + """ + + FILESIZE = (10 * 1024 * 1024) # 10MB + BUFSIZE = 8192 + FILEDATA = b"" + TIMEOUT = 2 + + @classmethod + def setUpClass(cls): + def chunks(total, step): + assert total >= step + while total > step: + yield step + total -= step + if total: + yield total + + chunk = b"".join([random.choice(string.ascii_letters).encode() + for i in range(cls.BUFSIZE)]) + with open(support.TESTFN, 'wb') as f: + for csize in chunks(cls.FILESIZE, cls.BUFSIZE): + f.write(chunk) + with open(support.TESTFN, 'rb') as f: + cls.FILEDATA = f.read() + assert len(cls.FILEDATA) == cls.FILESIZE + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + + def accept_conn(self): + self.serv.settimeout(self.TIMEOUT) + conn, addr = self.serv.accept() + conn.settimeout(self.TIMEOUT) + self.addCleanup(conn.close) + return conn + + def recv_data(self, conn): + received = [] + while True: + chunk = conn.recv(self.BUFSIZE) + if not chunk: + break + received.append(chunk) + return b''.join(received) + + def meth_from_sock(self, sock): + # Depending on the mixin class being run return either send() + # or sendfile() method implementation. + return getattr(sock, "_sendfile_use_send") + + # regular file + + def _testRegularFile(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, self.FILESIZE) + self.assertEqual(file.tell(), self.FILESIZE) + + def testRegularFile(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # non regular file + + def _testNonRegularFile(self): + address = self.serv.getsockname() + file = io.BytesIO(self.FILEDATA) + with socket.create_connection(address) as sock, file as file: + sent = sock.sendfile(file) + self.assertEqual(sent, self.FILESIZE) + self.assertEqual(file.tell(), self.FILESIZE) + self.assertRaises(socket._GiveupOnSendfile, + sock._sendfile_use_sendfile, file) + + def testNonRegularFile(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # empty file + + def _testEmptyFileSend(self): + address = self.serv.getsockname() + filename = support.TESTFN + "2" + with open(filename, 'wb'): + self.addCleanup(support.unlink, filename) + file = open(filename, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, 0) + self.assertEqual(file.tell(), 0) + + def testEmptyFileSend(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(data, b"") + + # offset + + def _testOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file, offset=5000) + self.assertEqual(sent, self.FILESIZE - 5000) + self.assertEqual(file.tell(), self.FILESIZE) + + def testOffset(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE - 5000) + self.assertEqual(data, self.FILEDATA[5000:]) + + # count + + def _testCount(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 5000007 + meth = self.meth_from_sock(sock) + sent = meth(file, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count) + + def testCount(self): + count = 5000007 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[:count]) + + # count small + + def _testCountSmall(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 1 + meth = self.meth_from_sock(sock) + sent = meth(file, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count) + + def testCountSmall(self): + count = 1 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[:count]) + + # count + offset + + def _testCountWithOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 100007 + meth = self.meth_from_sock(sock) + sent = meth(file, offset=2007, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count + 2007) + + def testCountWithOffset(self): + count = 100007 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[2007:count+2007]) + + # non blocking sockets are not supposed to work + + def _testNonBlocking(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + sock.setblocking(False) + meth = self.meth_from_sock(sock) + self.assertRaises(ValueError, meth, file) + self.assertRaises(ValueError, sock.sendfile, file) + + def testNonBlocking(self): + conn = self.accept_conn() + if conn.recv(8192): + self.fail('was not supposed to receive any data') + + # timeout (non-triggered) + + def _testWithTimeout(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, self.FILESIZE) + + def testWithTimeout(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # timeout (triggered) + + def _testWithTimeoutTriggeredSend(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=0.01) as sock, \ + file as file: + meth = self.meth_from_sock(sock) + self.assertRaises(socket.timeout, meth, file) + + def testWithTimeoutTriggeredSend(self): + conn = self.accept_conn() + conn.recv(88192) + + # errors + + def _test_errors(self): + pass + + def test_errors(self): + with open(support.TESTFN, 'rb') as file: + with socket.socket(type=socket.SOCK_DGRAM) as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex( + ValueError, "SOCK_STREAM", meth, file) + with open(support.TESTFN, 'rt') as file: + with socket.socket() as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex( + ValueError, "binary mode", meth, file) + with open(support.TESTFN, 'rb') as file: + with socket.socket() as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex(TypeError, "positive integer", + meth, file, count='2') + self.assertRaisesRegex(TypeError, "positive integer", + meth, file, count=0.1) + self.assertRaisesRegex(ValueError, "positive integer", + meth, file, count=0) + self.assertRaisesRegex(ValueError, "positive integer", + meth, file, count=-1) + + +@unittest.skipUnless(thread, 'Threading required for this test.') +@unittest.skipUnless(hasattr(os, "sendfile"), + 'os.sendfile() required for this test.') +class SendfileUsingSendfileTest(SendfileUsingSendTest): + """ + Test the sendfile() implementation of socket.sendfile(). + """ + def meth_from_sock(self, sock): + return getattr(sock, "_sendfile_use_sendfile") + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] @@ -5116,7 +5273,6 @@ def test_main(): tests.extend([ NonBlockingTCPTests, FileObjectClassTestCase, - FileObjectInterruptedTestCase, UnbufferedFileObjectClassTestCase, LineBufferedFileObjectClassTestCase, SmallBufferedFileObjectClassTestCase, @@ -5161,6 +5317,8 @@ def test_main(): InterruptedRecvTimeoutTest, InterruptedSendTimeoutTest, TestSocketSharing, + SendfileUsingSendTest, + SendfileUsingSendfileTest, ]) thread_info = support.threading_setup() |