diff options
Diffstat (limited to 'Lib/test/test_socket.py')
-rw-r--r-- | Lib/test/test_socket.py | 1085 |
1 files changed, 779 insertions, 306 deletions
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index fec62efe038..e40b21e83cb 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -1,21 +1,27 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 import unittest -from test import test_support +from test import support import errno +import io import socket import select import time import traceback -import Queue +import queue import sys import os import array +import platform import contextlib from weakref import proxy import signal import math +try: + import fcntl +except ImportError: + fcntl = False def try_address(host, port=0, family=socket.AF_INET): """Try to bind a socket on the given host:port and return True @@ -29,25 +35,30 @@ def try_address(host, port=0, family=socket.AF_INET): sock.close() return True -HOST = test_support.HOST -MSG = b'Michael Gilfix was here\n' +def linux_version(): + try: + # platform.release() is something like '2.6.33.7-desktop-2mnb' + version_string = platform.release().split('-')[0] + return tuple(map(int, version_string.split('.'))) + except ValueError: + return 0, 0, 0 + +HOST = support.HOST +MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf8') ## test unicode string and carriage return SUPPORTS_IPV6 = socket.has_ipv6 and try_address('::1', family=socket.AF_INET6) try: - import thread + import _thread as thread import threading except ImportError: thread = None threading = None -HOST = test_support.HOST -MSG = 'Michael Gilfix was here\n' - class SocketTCPTest(unittest.TestCase): def setUp(self): self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = test_support.bind_port(self.serv) + self.port = support.bind_port(self.serv) self.serv.listen(1) def tearDown(self): @@ -58,7 +69,7 @@ class SocketUDPTest(unittest.TestCase): def setUp(self): self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self.port = test_support.bind_port(self.serv) + self.port = support.bind_port(self.serv) def tearDown(self): self.serv.close() @@ -120,7 +131,7 @@ class ThreadableTest: self.server_ready = threading.Event() self.client_ready = threading.Event() self.done = threading.Event() - self.queue = Queue.Queue(1) + self.queue = queue.Queue(1) # Do some munging to start the client test. methodname = self.id() @@ -139,21 +150,22 @@ class ThreadableTest: self.__tearDown() self.done.wait() - if not self.queue.empty(): - msg = self.queue.get() - self.fail(msg) + if self.queue.qsize(): + exc = self.queue.get() + raise exc def clientRun(self, test_func): self.server_ready.wait() self.clientSetUp() self.client_ready.set() - if not callable(test_func): - raise TypeError("test_func must be a callable function.") + if not hasattr(test_func, '__call__'): + raise TypeError("test_func must be a callable function") try: test_func() - except Exception, strerror: - self.queue.put(strerror) - self.clientTearDown() + except BaseException as e: + self.queue.put(e) + finally: + self.clientTearDown() def clientSetUp(self): raise NotImplementedError("clientSetUp must be implemented.") @@ -191,6 +203,11 @@ class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest): ThreadableTest.clientTearDown(self) class SocketConnectedTest(ThreadedTCPSocketTest): + """Socket tests for client-server connection. + + self.cli_conn is a client socket connected to the server. The + setUp() method guarantees that it is connected to the server. + """ def __init__(self, methodName='runTest'): ThreadedTCPSocketTest.__init__(self, methodName=methodName) @@ -245,6 +262,11 @@ class SocketPairTest(unittest.TestCase, ThreadableTest): class GeneralModuleTests(unittest.TestCase): + def test_repr(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.addCleanup(s.close) + self.assertTrue(repr(s).startswith("<socket.socket object")) + def test_weakref(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) p = proxy(s) @@ -280,38 +302,43 @@ class GeneralModuleTests(unittest.TestCase): s.bind(('', 0)) sockname = s.getsockname() # 2 args - with self.assertRaises(UnicodeEncodeError): - s.sendto(u'\u2620', sockname) + with self.assertRaises(TypeError) as cm: + s.sendto('\u2620', sockname) + self.assertEqual(str(cm.exception), + "'str' does not support the buffer interface") with self.assertRaises(TypeError) as cm: s.sendto(5j, sockname) - self.assertIn('not complex', str(cm.exception)) + self.assertEqual(str(cm.exception), + "'complex' does not support the buffer interface") with self.assertRaises(TypeError) as cm: - s.sendto('foo', None) - self.assertIn('not NoneType', str(cm.exception)) + s.sendto(b'foo', None) + self.assertIn('not NoneType',str(cm.exception)) # 3 args - with self.assertRaises(UnicodeEncodeError): - s.sendto(u'\u2620', 0, sockname) + with self.assertRaises(TypeError) as cm: + s.sendto('\u2620', 0, sockname) + self.assertEqual(str(cm.exception), + "'str' does not support the buffer interface") with self.assertRaises(TypeError) as cm: s.sendto(5j, 0, sockname) - self.assertIn('not complex', str(cm.exception)) + self.assertEqual(str(cm.exception), + "'complex' does not support the buffer interface") with self.assertRaises(TypeError) as cm: - s.sendto('foo', 0, None) + s.sendto(b'foo', 0, None) self.assertIn('not NoneType', str(cm.exception)) with self.assertRaises(TypeError) as cm: - s.sendto('foo', 'bar', sockname) + s.sendto(b'foo', 'bar', sockname) self.assertIn('an integer is required', str(cm.exception)) with self.assertRaises(TypeError) as cm: - s.sendto('foo', None, None) + s.sendto(b'foo', None, None) self.assertIn('an integer is required', str(cm.exception)) # wrong number of args with self.assertRaises(TypeError) as cm: - s.sendto('foo') + s.sendto(b'foo') self.assertIn('(1 given)', str(cm.exception)) with self.assertRaises(TypeError) as cm: - s.sendto('foo', 0, sockname, 4) + s.sendto(b'foo', 0, sockname, 4) self.assertIn('(4 given)', str(cm.exception)) - def testCrucialConstants(self): # Testing for mission critical constants socket.AF_INET @@ -350,8 +377,8 @@ class GeneralModuleTests(unittest.TestCase): orig = sys.getrefcount(__name__) socket.getnameinfo(__name__,0) except TypeError: - self.assertEqual(sys.getrefcount(__name__), orig, - "socket.getnameinfo loses a reference") + if sys.getrefcount(__name__) != orig: + self.fail("socket.getnameinfo loses a reference") def testInterpreterCrash(self): # Making sure getnameinfo doesn't crash the interpreter @@ -367,17 +394,17 @@ class GeneralModuleTests(unittest.TestCase): sizes = {socket.htonl: 32, socket.ntohl: 32, socket.htons: 16, socket.ntohs: 16} for func, size in sizes.items(): - mask = (1L<<size) - 1 + mask = (1<<size) - 1 for i in (0, 1, 0xffff, ~0xffff, 2, 0x01234567, 0x76543210): self.assertEqual(i & mask, func(func(i&mask)) & mask) swapped = func(mask) self.assertEqual(swapped & mask, mask) - self.assertRaises(OverflowError, func, 1L<<34) + self.assertRaises(OverflowError, func, 1<<34) def testNtoHErrors(self): - good_values = [ 1, 2, 3, 1L, 2L, 3L ] - bad_values = [ -1, -2, -3, -1L, -2L, -3L ] + good_values = [ 1, 2, 3, 1, 2, 3 ] + bad_values = [ -1, -2, -3, -1, -2, -3 ] for k in good_values: socket.ntohl(k) socket.ntohs(k) @@ -463,8 +490,8 @@ class GeneralModuleTests(unittest.TestCase): return # No inet_aton, nothing to check # Test that issue1008086 and issue767150 are fixed. # It must return 4 bytes. - self.assertEqual('\x00'*4, socket.inet_aton('0.0.0.0')) - self.assertEqual('\xff'*4, socket.inet_aton('255.255.255.255')) + self.assertEqual(b'\x00'*4, socket.inet_aton('0.0.0.0')) + self.assertEqual(b'\xff'*4, socket.inet_aton('255.255.255.255')) def testIPv4toString(self): if not hasattr(socket, 'inet_pton'): @@ -472,16 +499,30 @@ class GeneralModuleTests(unittest.TestCase): from socket import inet_aton as f, inet_pton, AF_INET g = lambda a: inet_pton(AF_INET, a) - self.assertEqual('\x00\x00\x00\x00', f('0.0.0.0')) - self.assertEqual('\xff\x00\xff\x00', f('255.0.255.0')) - self.assertEqual('\xaa\xaa\xaa\xaa', f('170.170.170.170')) - self.assertEqual('\x01\x02\x03\x04', f('1.2.3.4')) - self.assertEqual('\xff\xff\xff\xff', f('255.255.255.255')) + assertInvalid = lambda func,a: self.assertRaises( + (socket.error, ValueError), func, a + ) - self.assertEqual('\x00\x00\x00\x00', g('0.0.0.0')) - self.assertEqual('\xff\x00\xff\x00', g('255.0.255.0')) - self.assertEqual('\xaa\xaa\xaa\xaa', g('170.170.170.170')) - self.assertEqual('\xff\xff\xff\xff', g('255.255.255.255')) + self.assertEqual(b'\x00\x00\x00\x00', f('0.0.0.0')) + self.assertEqual(b'\xff\x00\xff\x00', f('255.0.255.0')) + self.assertEqual(b'\xaa\xaa\xaa\xaa', f('170.170.170.170')) + self.assertEqual(b'\x01\x02\x03\x04', f('1.2.3.4')) + self.assertEqual(b'\xff\xff\xff\xff', f('255.255.255.255')) + assertInvalid(f, '0.0.0.') + assertInvalid(f, '300.0.0.0') + assertInvalid(f, 'a.0.0.0') + assertInvalid(f, '1.2.3.4.5') + assertInvalid(f, '::1') + + self.assertEqual(b'\x00\x00\x00\x00', g('0.0.0.0')) + self.assertEqual(b'\xff\x00\xff\x00', g('255.0.255.0')) + self.assertEqual(b'\xaa\xaa\xaa\xaa', g('170.170.170.170')) + self.assertEqual(b'\xff\xff\xff\xff', g('255.255.255.255')) + assertInvalid(g, '0.0.0.') + assertInvalid(g, '300.0.0.0') + assertInvalid(g, 'a.0.0.0') + assertInvalid(g, '1.2.3.4.5') + assertInvalid(g, '::1') def testIPv6toString(self): if not hasattr(socket, 'inet_pton'): @@ -493,29 +534,73 @@ class GeneralModuleTests(unittest.TestCase): except ImportError: return f = lambda a: inet_pton(AF_INET6, a) + assertInvalid = lambda a: self.assertRaises( + (socket.error, ValueError), f, a + ) - self.assertEqual('\x00' * 16, f('::')) - self.assertEqual('\x00' * 16, f('0::0')) - self.assertEqual('\x00\x01' + '\x00' * 14, f('1::')) + self.assertEqual(b'\x00' * 16, f('::')) + self.assertEqual(b'\x00' * 16, f('0::0')) + self.assertEqual(b'\x00\x01' + b'\x00' * 14, f('1::')) self.assertEqual( - '\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae', + b'\x45\xef\x76\xcb\x00\x1a\x56\xef\xaf\xeb\x0b\xac\x19\x24\xae\xae', f('45ef:76cb:1a:56ef:afeb:bac:1924:aeae') ) + self.assertEqual( + b'\xad\x42\x0a\xbc' + b'\x00' * 4 + b'\x01\x27\x00\x00\x02\x54\x00\x02', + f('ad42:abc::127:0:254:2') + ) + self.assertEqual(b'\x00\x12\x00\x0a' + b'\x00' * 12, f('12:a::')) + assertInvalid('0x20::') + assertInvalid(':::') + assertInvalid('::0::') + assertInvalid('1::abc::') + assertInvalid('1::abc::def') + assertInvalid('1:2:3:4:5:6:') + assertInvalid('1:2:3:4:5:6') + assertInvalid('1:2:3:4:5:6:7:8:') + assertInvalid('1:2:3:4:5:6:7:8:0') + + self.assertEqual(b'\x00' * 12 + b'\xfe\x2a\x17\x40', + f('::254.42.23.64') + ) + self.assertEqual( + b'\x00\x42' + b'\x00' * 8 + b'\xa2\x9b\xfe\x2a\x17\x40', + f('42::a29b:254.42.23.64') + ) + self.assertEqual( + b'\x00\x42\xa8\xb9\x00\x00\x00\x02\xff\xff\xa2\x9b\xfe\x2a\x17\x40', + f('42:a8b9:0:2:ffff:a29b:254.42.23.64') + ) + assertInvalid('255.254.253.252') + assertInvalid('1::260.2.3.0') + assertInvalid('1::0.be.e.0') + assertInvalid('1:2:3:4:5:6:7:1.2.3.4') + assertInvalid('::1.2.3.4:0') + assertInvalid('0.100.200.0:3:4:5:6:7:8') def testStringToIPv4(self): if not hasattr(socket, 'inet_ntop'): return # No inet_ntop() on this platform from socket import inet_ntoa as f, inet_ntop, AF_INET g = lambda a: inet_ntop(AF_INET, a) + assertInvalid = lambda func,a: self.assertRaises( + (socket.error, ValueError), func, a + ) - self.assertEqual('1.0.1.0', f('\x01\x00\x01\x00')) - self.assertEqual('170.85.170.85', f('\xaa\x55\xaa\x55')) - self.assertEqual('255.255.255.255', f('\xff\xff\xff\xff')) - self.assertEqual('1.2.3.4', f('\x01\x02\x03\x04')) - - self.assertEqual('1.0.1.0', g('\x01\x00\x01\x00')) - self.assertEqual('170.85.170.85', g('\xaa\x55\xaa\x55')) - self.assertEqual('255.255.255.255', g('\xff\xff\xff\xff')) + self.assertEqual('1.0.1.0', f(b'\x01\x00\x01\x00')) + self.assertEqual('170.85.170.85', f(b'\xaa\x55\xaa\x55')) + self.assertEqual('255.255.255.255', f(b'\xff\xff\xff\xff')) + self.assertEqual('1.2.3.4', f(b'\x01\x02\x03\x04')) + assertInvalid(f, b'\x00' * 3) + assertInvalid(f, b'\x00' * 5) + assertInvalid(f, b'\x00' * 16) + + self.assertEqual('1.0.1.0', g(b'\x01\x00\x01\x00')) + self.assertEqual('170.85.170.85', g(b'\xaa\x55\xaa\x55')) + self.assertEqual('255.255.255.255', g(b'\xff\xff\xff\xff')) + assertInvalid(g, b'\x00' * 3) + assertInvalid(g, b'\x00' * 5) + assertInvalid(g, b'\x00' * 16) def testStringToIPv6(self): if not hasattr(socket, 'inet_ntop'): @@ -527,33 +612,26 @@ class GeneralModuleTests(unittest.TestCase): except ImportError: return f = lambda a: inet_ntop(AF_INET6, a) + assertInvalid = lambda a: self.assertRaises( + (socket.error, ValueError), f, a + ) - self.assertEqual('::', f('\x00' * 16)) - self.assertEqual('::1', f('\x00' * 15 + '\x01')) + self.assertEqual('::', f(b'\x00' * 16)) + self.assertEqual('::1', f(b'\x00' * 15 + b'\x01')) self.assertEqual( 'aef:b01:506:1001:ffff:9997:55:170', - f('\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70') + f(b'\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70') ) - # XXX The following don't test module-level functionality... - - def _get_unused_port(self, bind_address='0.0.0.0'): - """Use a temporary socket to elicit an unused ephemeral port. + assertInvalid(b'\x12' * 15) + assertInvalid(b'\x12' * 17) + assertInvalid(b'\x12' * 4) - Args: - bind_address: Hostname or IP address to search for a port on. - - Returns: A most likely to be unused port. - """ - tempsock = socket.socket() - tempsock.bind((bind_address, 0)) - host, port = tempsock.getsockname() - tempsock.close() - return port + # XXX The following don't test module-level functionality... def testSockName(self): # Testing getsockname() - port = self._get_unused_port() + port = support.find_unused_port() sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.addCleanup(sock.close) sock.bind(("0.0.0.0", port)) @@ -590,7 +668,7 @@ class GeneralModuleTests(unittest.TestCase): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(1) sock.close() - self.assertRaises(socket.error, sock.send, "spam") + self.assertRaises(socket.error, sock.send, b"spam") def testNewAttributes(self): # testing .family, .type and .protocol @@ -602,7 +680,7 @@ class GeneralModuleTests(unittest.TestCase): def test_getsockaddrarg(self): host = '0.0.0.0' - port = self._get_unused_port(bind_address=host) + port = support.find_unused_port() big_port = port + 65536 neg_port = port - 65536 sock = socket.socket() @@ -644,10 +722,9 @@ class GeneralModuleTests(unittest.TestCase): if SUPPORTS_IPV6: socket.getaddrinfo('::1', 80) # port can be a string service name such as "http", a numeric - # port number (int or long), or None + # port number or None socket.getaddrinfo(HOST, "http") socket.getaddrinfo(HOST, 80) - socket.getaddrinfo(HOST, 80L) socket.getaddrinfo(HOST, None) # test family and socktype filters infos = socket.getaddrinfo(HOST, None, socket.AF_INET) @@ -663,7 +740,46 @@ class GeneralModuleTests(unittest.TestCase): # usually do this socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE) - + # test keyword arguments + a = socket.getaddrinfo(HOST, None) + b = socket.getaddrinfo(host=HOST, port=None) + self.assertEqual(a, b) + a = socket.getaddrinfo(HOST, None, socket.AF_INET) + b = socket.getaddrinfo(HOST, None, family=socket.AF_INET) + self.assertEqual(a, b) + a = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM) + b = socket.getaddrinfo(HOST, None, type=socket.SOCK_STREAM) + self.assertEqual(a, b) + a = socket.getaddrinfo(HOST, None, 0, 0, socket.SOL_TCP) + b = socket.getaddrinfo(HOST, None, proto=socket.SOL_TCP) + self.assertEqual(a, b) + a = socket.getaddrinfo(HOST, None, 0, 0, 0, socket.AI_PASSIVE) + b = socket.getaddrinfo(HOST, None, flags=socket.AI_PASSIVE) + self.assertEqual(a, b) + a = socket.getaddrinfo(None, 0, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, + socket.AI_PASSIVE) + b = socket.getaddrinfo(host=None, port=0, family=socket.AF_UNSPEC, + type=socket.SOCK_STREAM, proto=0, + flags=socket.AI_PASSIVE) + self.assertEqual(a, b) + # Issue #6697. + self.assertRaises(UnicodeEncodeError, socket.getaddrinfo, 'localhost', '\uD800') + + def test_getnameinfo(self): + # only IP addresses are allowed + self.assertRaises(socket.error, socket.getnameinfo, ('mail.python.org',0), 0) + + @unittest.skipUnless(support.is_resource_enabled('network'), + 'network is not enabled') + def test_idna(self): + support.requires('network') + # these should all be successful + socket.gethostbyname('испытание.python.org') + socket.gethostbyname_ex('испытание.python.org') + socket.getaddrinfo('испытание.python.org',0,socket.AF_UNSPEC,socket.SOCK_STREAM) + # this may not work if the forward lookup choses the IPv6 address, as that doesn't + # have a reverse entry yet + # socket.gethostbyaddr('испытание.python.org') def check_sendall_interrupted(self, with_timeout): # socketpair() is not stricly required, but it makes things easier. @@ -700,6 +816,40 @@ class GeneralModuleTests(unittest.TestCase): def test_sendall_interrupted_with_timeout(self): self.check_sendall_interrupted(True) + def test_dealloc_warn(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + r = repr(sock) + with self.assertWarns(ResourceWarning) as cm: + sock = None + support.gc_collect() + self.assertIn(r, str(cm.warning.args[0])) + # An open socket file object gets dereferenced after the socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + f = sock.makefile('rb') + r = repr(sock) + sock = None + support.gc_collect() + with self.assertWarns(ResourceWarning): + f = None + support.gc_collect() + + def test_name_closed_socketio(self): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + fp = sock.makefile("rb") + fp.close() + self.assertEqual(repr(fp), "<_io.BufferedReader name=-1>") + + def test_unusable_closed_socketio(self): + with socket.socket() as sock: + fp = sock.makefile("rb", buffering=0) + self.assertTrue(fp.readable()) + self.assertFalse(fp.writable()) + self.assertFalse(fp.seekable()) + fp.close() + self.assertRaises(ValueError, fp.readable) + self.assertRaises(ValueError, fp.writable) + self.assertRaises(ValueError, fp.seekable) + def testListenBacklog0(self): srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) srv.bind((HOST, 0)) @@ -711,11 +861,8 @@ class GeneralModuleTests(unittest.TestCase): def test_flowinfo(self): self.assertRaises(OverflowError, socket.getnameinfo, ('::1',0, 0xffffffff), 0) - s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - try: + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: self.assertRaises(OverflowError, s.bind, ('::1', 0, -10)) - finally: - s.close() @unittest.skipUnless(thread, 'Threading required for this test.') @@ -762,25 +909,24 @@ class BasicTCPTest(SocketConnectedTest): def testSendAll(self): # Testing sendall() with a 2048 byte string over TCP - msg = '' + msg = b'' while 1: read = self.cli_conn.recv(1024) if not read: break msg += read - self.assertEqual(msg, 'f' * 2048) + self.assertEqual(msg, b'f' * 2048) def _testSendAll(self): - big_chunk = 'f' * 2048 + big_chunk = b'f' * 2048 self.serv_conn.sendall(big_chunk) def testFromFd(self): # Testing fromfd() - if not hasattr(socket, "fromfd"): - return # On Windows, this doesn't exist fd = self.cli_conn.fileno() sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) self.addCleanup(sock.close) + self.assertIsInstance(sock, socket.socket) msg = sock.recv(1024) self.assertEqual(msg, MSG) @@ -810,6 +956,25 @@ class BasicTCPTest(SocketConnectedTest): self.serv_conn.send(MSG) self.serv_conn.shutdown(2) + def testDetach(self): + # Testing detach() + fileno = self.cli_conn.fileno() + f = self.cli_conn.detach() + self.assertEqual(f, fileno) + # cli_conn cannot be used anymore... + self.assertTrue(self.cli_conn._closed) + self.assertRaises(socket.error, self.cli_conn.recv, 1024) + self.cli_conn.close() + # ...but we can create another socket using the (still open) + # file descriptor + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=f) + self.addCleanup(sock.close) + msg = sock.recv(1024) + self.assertEqual(msg, MSG) + + def _testDetach(self): + self.serv_conn.send(MSG) + @unittest.skipUnless(thread, 'Threading required for this test.') class BasicUDPTest(ThreadedUDPSocketTest): @@ -849,7 +1014,11 @@ class TCPCloserTest(ThreadedTCPSocketTest): sd = self.cli read, write, err = select.select([sd], [], [], 1.0) self.assertEqual(read, [sd]) - self.assertEqual(sd.recv(1), '') + self.assertEqual(sd.recv(1), b'') + + # Calling close() many times should be safe. + conn.close() + conn.close() def _testClose(self): self.cli.connect((HOST, self.port)) @@ -861,6 +1030,21 @@ class BasicSocketPairTest(SocketPairTest): def __init__(self, methodName='runTest'): SocketPairTest.__init__(self, methodName=methodName) + def _check_defaults(self, sock): + self.assertIsInstance(sock, socket.socket) + if hasattr(socket, 'AF_UNIX'): + self.assertEqual(sock.family, socket.AF_UNIX) + else: + self.assertEqual(sock.family, socket.AF_INET) + self.assertEqual(sock.type, socket.SOCK_STREAM) + self.assertEqual(sock.proto, 0) + + def _testDefaults(self): + self._check_defaults(self.cli) + + def testDefaults(self): + self._check_defaults(self.serv) + def testRecv(self): msg = self.serv.recv(1024) self.assertEqual(msg, MSG) @@ -895,6 +1079,47 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): def _testSetBlocking(self): pass + if hasattr(socket, "SOCK_NONBLOCK"): + def testInitNonBlocking(self): + v = linux_version() + if v < (2, 6, 28): + self.skipTest("Linux kernel 2.6.28 or higher required, not %s" + % ".".join(map(str, v))) + # reinit server socket + self.serv.close() + self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM | + socket.SOCK_NONBLOCK) + self.port = support.bind_port(self.serv) + self.serv.listen(1) + # actual testing + start = time.time() + try: + self.serv.accept() + except socket.error: + pass + end = time.time() + self.assertTrue((end - start) < 1.0, "Error creating with non-blocking mode.") + + def _testInitNonBlocking(self): + pass + + def testInheritFlags(self): + # Issue #7995: when calling accept() on a listening socket with a + # timeout, the resulting socket should not be non-blocking. + self.serv.settimeout(10) + try: + conn, addr = self.serv.accept() + message = conn.recv(len(MSG)) + finally: + conn.close() + self.serv.settimeout(None) + + def _testInheritFlags(self): + time.sleep(0.1) + self.cli.connect((HOST, self.port)) + time.sleep(0.5) + self.cli.send(MSG) + def testAccept(self): # Testing non-blocking accept self.serv.setblocking(0) @@ -949,107 +1174,166 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): @unittest.skipUnless(thread, 'Threading required for this test.') class FileObjectClassTestCase(SocketConnectedTest): + """Unit tests for the object returned by socket.makefile() + + self.read_file is the io object returned by makefile() on + the client connection. You can read from this file to + get output from the server. + + self.write_file is the io object returned by makefile() on the + server connection. You can write to this file to send output + to the client. + """ bufsize = -1 # Use default buffer size + encoding = 'utf8' + errors = 'strict' + newline = None + + read_mode = 'rb' + read_msg = MSG + write_mode = 'wb' + write_msg = MSG def __init__(self, methodName='runTest'): SocketConnectedTest.__init__(self, methodName=methodName) def setUp(self): + self.evt1, self.evt2, self.serv_finished, self.cli_finished = [ + threading.Event() for i in range(4)] SocketConnectedTest.setUp(self) - self.serv_file = self.cli_conn.makefile('rb', self.bufsize) + self.read_file = self.cli_conn.makefile( + self.read_mode, self.bufsize, + encoding = self.encoding, + errors = self.errors, + newline = self.newline) def tearDown(self): - self.serv_file.close() - self.assertTrue(self.serv_file.closed) + self.serv_finished.set() + self.read_file.close() + self.assertTrue(self.read_file.closed) + self.read_file = None SocketConnectedTest.tearDown(self) - self.serv_file = None def clientSetUp(self): SocketConnectedTest.clientSetUp(self) - self.cli_file = self.serv_conn.makefile('wb') + self.write_file = self.serv_conn.makefile( + self.write_mode, self.bufsize, + encoding = self.encoding, + errors = self.errors, + newline = self.newline) def clientTearDown(self): - self.cli_file.close() - self.assertTrue(self.cli_file.closed) - self.cli_file = None + self.cli_finished.set() + self.write_file.close() + self.assertTrue(self.write_file.closed) + self.write_file = None SocketConnectedTest.clientTearDown(self) + def testReadAfterTimeout(self): + # Issue #7322: A file object must disallow further reads + # after a timeout has occurred. + self.cli_conn.settimeout(1) + self.read_file.read(3) + # First read raises a timeout + self.assertRaises(socket.timeout, self.read_file.read, 1) + # Second read is disallowed + with self.assertRaises(IOError) as ctx: + self.read_file.read(1) + self.assertIn("cannot read from timed out object", str(ctx.exception)) + + def _testReadAfterTimeout(self): + self.write_file.write(self.write_msg[0:3]) + self.write_file.flush() + self.serv_finished.wait() + def testSmallRead(self): # Performing small file read test - first_seg = self.serv_file.read(len(MSG)-3) - second_seg = self.serv_file.read(3) + first_seg = self.read_file.read(len(self.read_msg)-3) + second_seg = self.read_file.read(3) msg = first_seg + second_seg - self.assertEqual(msg, MSG) + self.assertEqual(msg, self.read_msg) def _testSmallRead(self): - self.cli_file.write(MSG) - self.cli_file.flush() + self.write_file.write(self.write_msg) + self.write_file.flush() def testFullRead(self): # read until EOF - msg = self.serv_file.read() - self.assertEqual(msg, MSG) + msg = self.read_file.read() + self.assertEqual(msg, self.read_msg) def _testFullRead(self): - self.cli_file.write(MSG) - self.cli_file.close() + self.write_file.write(self.write_msg) + self.write_file.close() def testUnbufferedRead(self): # Performing unbuffered file read test - buf = '' + buf = type(self.read_msg)() while 1: - char = self.serv_file.read(1) + char = self.read_file.read(1) if not char: break buf += char - self.assertEqual(buf, MSG) + self.assertEqual(buf, self.read_msg) def _testUnbufferedRead(self): - self.cli_file.write(MSG) - self.cli_file.flush() + self.write_file.write(self.write_msg) + self.write_file.flush() def testReadline(self): # Performing file readline test - line = self.serv_file.readline() - self.assertEqual(line, MSG) + line = self.read_file.readline() + self.assertEqual(line, self.read_msg) def _testReadline(self): - self.cli_file.write(MSG) - self.cli_file.flush() - - def testReadlineAfterRead(self): - a_baloo_is = self.serv_file.read(len("A baloo is")) - self.assertEqual("A baloo is", a_baloo_is) - _a_bear = self.serv_file.read(len(" a bear")) - self.assertEqual(" a bear", _a_bear) - line = self.serv_file.readline() - self.assertEqual("\n", line) - line = self.serv_file.readline() - self.assertEqual("A BALOO IS A BEAR.\n", line) - line = self.serv_file.readline() - self.assertEqual(MSG, line) - - def _testReadlineAfterRead(self): - self.cli_file.write("A baloo is a bear\n") - self.cli_file.write("A BALOO IS A BEAR.\n") - self.cli_file.write(MSG) - self.cli_file.flush() - - def testReadlineAfterReadNoNewline(self): - end_of_ = self.serv_file.read(len("End Of ")) - self.assertEqual("End Of ", end_of_) - line = self.serv_file.readline() - self.assertEqual("Line", line) - - def _testReadlineAfterReadNoNewline(self): - self.cli_file.write("End Of Line") + self.write_file.write(self.write_msg) + self.write_file.flush() + + def testCloseAfterMakefile(self): + # The file returned by makefile should keep the socket open. + self.cli_conn.close() + # read until EOF + msg = self.read_file.read() + self.assertEqual(msg, self.read_msg) + + def _testCloseAfterMakefile(self): + self.write_file.write(self.write_msg) + self.write_file.flush() + + def testMakefileAfterMakefileClose(self): + self.read_file.close() + msg = self.cli_conn.recv(len(MSG)) + if isinstance(self.read_msg, str): + msg = msg.decode() + self.assertEqual(msg, self.read_msg) + + def _testMakefileAfterMakefileClose(self): + self.write_file.write(self.write_msg) + self.write_file.flush() def testClosedAttr(self): - self.assertTrue(not self.serv_file.closed) + self.assertTrue(not self.read_file.closed) def _testClosedAttr(self): - self.assertTrue(not self.cli_file.closed) + self.assertTrue(not self.write_file.closed) + + def testAttributes(self): + self.assertEqual(self.read_file.mode, self.read_mode) + self.assertEqual(self.read_file.name, self.cli_conn.fileno()) + + def _testAttributes(self): + self.assertEqual(self.write_file.mode, self.write_mode) + self.assertEqual(self.write_file.name, self.serv_conn.fileno()) + + def testRealClose(self): + self.read_file.close() + self.assertRaises(ValueError, self.read_file.fileno) + self.cli_conn.close() + self.assertRaises(socket.error, self.cli_conn.getsockname) + + def _testRealClose(self): + pass class FileObjectInterruptedTestCase(unittest.TestCase): @@ -1061,34 +1345,75 @@ class FileObjectInterruptedTestCase(unittest.TestCase): # call to recv(). self._recv_step = iter(recv_funcs) - def recv(self, size): - return self._recv_step.next()() + def recv_into(self, buffer): + data = next(self._recv_step)() + assert len(buffer) >= len(data) + buffer[:len(data)] = data + return len(data) + + def _decref_socketios(self): + pass + + def _textiowrap_for_test(self, buffering=-1): + raw = socket.SocketIO(self, "r") + if buffering < 0: + buffering = io.DEFAULT_BUFFER_SIZE + if buffering == 0: + return raw + buffer = io.BufferedReader(raw, buffering) + text = io.TextIOWrapper(buffer, None, None) + text.mode = "rb" + return text @staticmethod def _raise_eintr(): raise socket.error(errno.EINTR) - def _test_readline(self, size=-1, **kwargs): + def _textiowrap_mock_socket(self, mock, buffering=-1): + raw = socket.SocketIO(mock, "r") + if buffering < 0: + buffering = io.DEFAULT_BUFFER_SIZE + if buffering == 0: + return raw + buffer = io.BufferedReader(raw, buffering) + text = io.TextIOWrapper(buffer, None, None) + text.mode = "rb" + return text + + def _test_readline(self, size=-1, buffering=-1): mock_sock = self.MockSocket(recv_funcs=[ - lambda : "This is the first line\nAnd the sec", + lambda : b"This is the first line\nAnd the sec", self._raise_eintr, - lambda : "ond line is here\n", - lambda : "", + lambda : b"ond line is here\n", + lambda : b"", + lambda : b"", # XXX(gps): io library does an extra EOF read ]) - fo = socket._fileobject(mock_sock, **kwargs) + fo = mock_sock._textiowrap_for_test(buffering=buffering) self.assertEqual(fo.readline(size), "This is the first line\n") self.assertEqual(fo.readline(size), "And the second line is here\n") - def _test_read(self, size=-1, **kwargs): + def _test_read(self, size=-1, buffering=-1): mock_sock = self.MockSocket(recv_funcs=[ - lambda : "This is the first line\nAnd the sec", + lambda : b"This is the first line\nAnd the sec", self._raise_eintr, - lambda : "ond line is here\n", - lambda : "", + lambda : b"ond line is here\n", + lambda : b"", + lambda : b"", # XXX(gps): io library does an extra EOF read ]) - fo = socket._fileobject(mock_sock, **kwargs) - self.assertEqual(fo.read(size), "This is the first line\n" - "And the second line is here\n") + expecting = (b"This is the first line\n" + b"And the second line is here\n") + fo = mock_sock._textiowrap_for_test(buffering=buffering) + if buffering == 0: + data = b'' + else: + data = '' + expecting = expecting.decode('utf8') + while len(data) != len(expecting): + part = fo.read(size) + if not part: + break + data += part + self.assertEqual(data, expecting) def test_default(self): self._test_readline() @@ -1097,29 +1422,29 @@ class FileObjectInterruptedTestCase(unittest.TestCase): self._test_read(size=100) def test_with_1k_buffer(self): - self._test_readline(bufsize=1024) - self._test_readline(size=100, bufsize=1024) - self._test_read(bufsize=1024) - self._test_read(size=100, bufsize=1024) + self._test_readline(buffering=1024) + self._test_readline(size=100, buffering=1024) + self._test_read(buffering=1024) + self._test_read(size=100, buffering=1024) def _test_readline_no_buffer(self, size=-1): mock_sock = self.MockSocket(recv_funcs=[ - lambda : "aa", - lambda : "\n", - lambda : "BB", + lambda : b"a", + lambda : b"\n", + lambda : b"B", self._raise_eintr, - lambda : "bb", - lambda : "", + lambda : b"b", + lambda : b"", ]) - fo = socket._fileobject(mock_sock, bufsize=0) - self.assertEqual(fo.readline(size), "aa\n") - self.assertEqual(fo.readline(size), "BBbb") + fo = mock_sock._textiowrap_for_test(buffering=0) + self.assertEqual(fo.readline(size), b"a\n") + self.assertEqual(fo.readline(size), b"Bb") def test_no_buffer(self): self._test_readline_no_buffer() self._test_readline_no_buffer(size=4) - self._test_read(bufsize=0) - self._test_read(size=100, bufsize=0) + self._test_read(buffering=0) + self._test_read(size=100, buffering=0) class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): @@ -1129,94 +1454,150 @@ class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): In this case (and in this case only), it should be possible to create a file object, read a line from it, create another file object, read another line from it, without loss of data in the - first file object's buffer. Note that httplib relies on this + first file object's buffer. Note that http.client relies on this when reading multiple requests from the same socket.""" bufsize = 0 # Use unbuffered mode def testUnbufferedReadline(self): # Read a line, create a new file object, read another line with it - line = self.serv_file.readline() # first line - self.assertEqual(line, "A. " + MSG) # first line - self.serv_file = self.cli_conn.makefile('rb', 0) - line = self.serv_file.readline() # second line - self.assertEqual(line, "B. " + MSG) # second line + line = self.read_file.readline() # first line + self.assertEqual(line, b"A. " + self.write_msg) # first line + self.read_file = self.cli_conn.makefile('rb', 0) + line = self.read_file.readline() # second line + self.assertEqual(line, b"B. " + self.write_msg) # second line def _testUnbufferedReadline(self): - self.cli_file.write("A. " + MSG) - self.cli_file.write("B. " + MSG) - self.cli_file.flush() + self.write_file.write(b"A. " + self.write_msg) + self.write_file.write(b"B. " + self.write_msg) + self.write_file.flush() + + def testMakefileClose(self): + # The file returned by makefile should keep the socket open... + self.cli_conn.close() + msg = self.cli_conn.recv(1024) + self.assertEqual(msg, self.read_msg) + # ...until the file is itself closed + self.read_file.close() + self.assertRaises(socket.error, self.cli_conn.recv, 1024) + + def _testMakefileClose(self): + self.write_file.write(self.write_msg) + self.write_file.flush() + + def testMakefileCloseSocketDestroy(self): + refcount_before = sys.getrefcount(self.cli_conn) + self.read_file.close() + refcount_after = sys.getrefcount(self.cli_conn) + self.assertEqual(refcount_before - 1, refcount_after) + + def _testMakefileCloseSocketDestroy(self): + pass + + # Non-blocking ops + # NOTE: to set `read_file` as non-blocking, we must call + # `cli_conn.setblocking` and vice-versa (see setUp / clientSetUp). + + def testSmallReadNonBlocking(self): + self.cli_conn.setblocking(False) + self.assertEqual(self.read_file.readinto(bytearray(10)), None) + self.assertEqual(self.read_file.read(len(self.read_msg) - 3), None) + self.evt1.set() + self.evt2.wait(1.0) + first_seg = self.read_file.read(len(self.read_msg) - 3) + if first_seg is None: + # Data not arrived (can happen under Windows), wait a bit + time.sleep(0.5) + first_seg = self.read_file.read(len(self.read_msg) - 3) + buf = bytearray(10) + n = self.read_file.readinto(buf) + self.assertEqual(n, 3) + msg = first_seg + buf[:n] + self.assertEqual(msg, self.read_msg) + self.assertEqual(self.read_file.readinto(bytearray(16)), None) + self.assertEqual(self.read_file.read(1), None) + + def _testSmallReadNonBlocking(self): + self.evt1.wait(1.0) + self.write_file.write(self.write_msg) + self.write_file.flush() + self.evt2.set() + # Avoid cloding the socket before the server test has finished, + # otherwise system recv() will return 0 instead of EWOULDBLOCK. + self.serv_finished.wait(5.0) + + def testWriteNonBlocking(self): + self.cli_finished.wait(5.0) + # The client thread can't skip directly - the SkipTest exception + # would appear as a failure. + if self.serv_skipped: + self.skipTest(self.serv_skipped) + + def _testWriteNonBlocking(self): + self.serv_skipped = None + self.serv_conn.setblocking(False) + # Try to saturate the socket buffer pipe with repeated large writes. + BIG = b"x" * (1024 ** 2) + LIMIT = 10 + # The first write() succeeds since a chunk of data can be buffered + n = self.write_file.write(BIG) + self.assertGreater(n, 0) + for i in range(LIMIT): + n = self.write_file.write(BIG) + if n is None: + # Succeeded + break + self.assertGreater(n, 0) + else: + # Let us know that this test didn't manage to establish + # the expected conditions. This is not a failure in itself but, + # if it happens repeatedly, the test should be fixed. + self.serv_skipped = "failed to saturate the socket buffer" + class LineBufferedFileObjectClassTestCase(FileObjectClassTestCase): bufsize = 1 # Default-buffered for reading; line-buffered for writing - class SocketMemo(object): - """A wrapper to keep track of sent data, needed to examine write behaviour""" - def __init__(self, sock): - self._sock = sock - self.sent = [] - def send(self, data, flags=0): - n = self._sock.send(data, flags) - self.sent.append(data[:n]) - return n +class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase): - def sendall(self, data, flags=0): - self._sock.sendall(data, flags) - self.sent.append(data) + bufsize = 2 # Exercise the buffering code - def __getattr__(self, attr): - return getattr(self._sock, attr) - def getsent(self): - return [e.tobytes() if isinstance(e, memoryview) else e for e in self.sent] +class UnicodeReadFileObjectClassTestCase(FileObjectClassTestCase): + """Tests for socket.makefile() in text mode (rather than binary)""" - def setUp(self): - FileObjectClassTestCase.setUp(self) - self.serv_file._sock = self.SocketMemo(self.serv_file._sock) - - def testLinebufferedWrite(self): - # Write two lines, in small chunks - msg = MSG.strip() - print >> self.serv_file, msg, - print >> self.serv_file, msg - - # second line: - print >> self.serv_file, msg, - print >> self.serv_file, msg, - print >> self.serv_file, msg - - # third line - print >> self.serv_file, '' - - self.serv_file.flush() - - msg1 = "%s %s\n"%(msg, msg) - msg2 = "%s %s %s\n"%(msg, msg, msg) - msg3 = "\n" - self.assertEqual(self.serv_file._sock.getsent(), [msg1, msg2, msg3]) - - def _testLinebufferedWrite(self): - msg = MSG.strip() - msg1 = "%s %s\n"%(msg, msg) - msg2 = "%s %s %s\n"%(msg, msg, msg) - msg3 = "\n" - l1 = self.cli_file.readline() - self.assertEqual(l1, msg1) - l2 = self.cli_file.readline() - self.assertEqual(l2, msg2) - l3 = self.cli_file.readline() - self.assertEqual(l3, msg3) + read_mode = 'r' + read_msg = MSG.decode('utf8') + write_mode = 'wb' + write_msg = MSG + newline = '' -class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase): +class UnicodeWriteFileObjectClassTestCase(FileObjectClassTestCase): + """Tests for socket.makefile() in text mode (rather than binary)""" + + read_mode = 'rb' + read_msg = MSG + write_mode = 'w' + write_msg = MSG.decode('utf8') + newline = '' - bufsize = 2 # Exercise the buffering code + +class UnicodeReadWriteFileObjectClassTestCase(FileObjectClassTestCase): + """Tests for socket.makefile() in text mode (rather than binary)""" + + read_mode = 'r' + read_msg = MSG.decode('utf8') + write_mode = 'w' + write_msg = MSG.decode('utf8') + newline = '' class NetworkConnectionTest(object): """Prove network connection.""" + def clientSetUp(self): # We're inherited below by BasicTCPTest2, which also inherits # BasicTCPTest, which defines self.port referenced below. @@ -1228,6 +1609,7 @@ class BasicTCPTest2(NetworkConnectionTest, BasicTCPTest): """ class NetworkConnectionNoServer(unittest.TestCase): + class MockSocket(socket.socket): def connect(self, *args): raise socket.timeout('timed out') @@ -1243,7 +1625,7 @@ class NetworkConnectionNoServer(unittest.TestCase): socket.socket = old_socket def test_connect(self): - port = test_support.find_unused_port() + port = support.find_unused_port() cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.addCleanup(cli.close) with self.assertRaises(socket.error) as cm: @@ -1253,7 +1635,7 @@ class NetworkConnectionNoServer(unittest.TestCase): def test_create_connection(self): # Issue #9792: errors raised by create_connection() should have # a proper errno attribute. - port = test_support.find_unused_port() + port = support.find_unused_port() with self.assertRaises(socket.error) as cm: socket.create_connection((HOST, port)) @@ -1293,7 +1675,7 @@ class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest): ThreadableTest.__init__(self) def clientSetUp(self): - self.source_port = test_support.find_unused_port() + self.source_port = support.find_unused_port() def clientTearDown(self): self.cli.close() @@ -1373,43 +1755,19 @@ class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest): conn, addr = self.serv.accept() self.addCleanup(conn.close) time.sleep(3) - conn.send("done!") + conn.send(b"done!") testOutsideTimeout = testInsideTimeout def _testInsideTimeout(self): self.cli = sock = socket.create_connection((HOST, self.port)) data = sock.recv(5) - self.assertEqual(data, "done!") + self.assertEqual(data, b"done!") def _testOutsideTimeout(self): self.cli = sock = socket.create_connection((HOST, self.port), timeout=1) self.assertRaises(socket.timeout, lambda: sock.recv(5)) -class Urllib2FileobjectTest(unittest.TestCase): - - # urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that - # it close the socket if the close c'tor argument is true - - def testClose(self): - class MockSocket: - closed = False - def flush(self): pass - def close(self): self.closed = True - - # must not close unless we request it: the original use of _fileobject - # by module socket requires that the underlying socket not be closed until - # the _socketobject that created the _fileobject is closed - s = MockSocket() - f = socket._fileobject(s) - f.close() - self.assertTrue(not s.closed) - - s = MockSocket() - f = socket._fileobject(s, close=True) - f.close() - self.assertTrue(s.closed) - class TCPTimeoutTest(SocketTCPTest): def testTCPTimeout(self): @@ -1503,26 +1861,26 @@ class TestLinuxAbstractNamespace(unittest.TestCase): UNIX_PATH_MAX = 108 def testLinuxAbstractNamespace(self): - address = "\x00python-test-hello\x00\xff" - s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - s1.bind(address) - s1.listen(1) - s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - s2.connect(s1.getsockname()) - s1.accept() - self.assertEqual(s1.getsockname(), address) - self.assertEqual(s2.getpeername(), address) + address = b"\x00python-test-hello\x00\xff" + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s1: + s1.bind(address) + s1.listen(1) + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s2: + s2.connect(s1.getsockname()) + with s1.accept()[0] as s3: + self.assertEqual(s1.getsockname(), address) + self.assertEqual(s2.getpeername(), address) def testMaxName(self): - address = "\x00" + "h" * (self.UNIX_PATH_MAX - 1) - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - s.bind(address) - self.assertEqual(s.getsockname(), address) + address = b"\x00" + b"h" * (self.UNIX_PATH_MAX - 1) + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + s.bind(address) + self.assertEqual(s.getsockname(), address) def testNameOverflow(self): address = "\x00" + "h" * self.UNIX_PATH_MAX - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.assertRaises(socket.error, s.bind, address) + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + self.assertRaises(socket.error, s.bind, address) @unittest.skipUnless(thread, 'Threading required for this test.') @@ -1534,15 +1892,14 @@ class BufferIOTest(SocketConnectedTest): SocketConnectedTest.__init__(self, methodName=methodName) def testRecvIntoArray(self): - buf = array.array('c', ' '*1024) + buf = bytearray(1024) nbytes = self.cli_conn.recv_into(buf) self.assertEqual(nbytes, len(MSG)) - msg = buf.tostring()[:len(MSG)] + msg = buf[:len(MSG)] self.assertEqual(msg, MSG) def _testRecvIntoArray(self): - with test_support.check_py3k_warnings(): - buf = buffer(MSG) + buf = bytes(MSG) self.serv_conn.send(buf) def testRecvIntoBytearray(self): @@ -1564,15 +1921,14 @@ class BufferIOTest(SocketConnectedTest): _testRecvIntoMemoryview = _testRecvIntoArray def testRecvFromIntoArray(self): - buf = array.array('c', ' '*1024) + buf = bytearray(1024) nbytes, addr = self.cli_conn.recvfrom_into(buf) self.assertEqual(nbytes, len(MSG)) - msg = buf.tostring()[:len(MSG)] + msg = buf[:len(MSG)] self.assertEqual(msg, MSG) def _testRecvFromIntoArray(self): - with test_support.check_py3k_warnings(): - buf = buffer(MSG) + buf = bytes(MSG) self.serv_conn.send(buf) def testRecvFromIntoBytearray(self): @@ -1612,14 +1968,16 @@ def isTipcAvailable(): for line in f: if line.startswith("tipc "): return True - if test_support.verbose: - print "TIPC module is not loaded, please 'sudo modprobe tipc'" + if support.verbose: + print("TIPC module is not loaded, please 'sudo modprobe tipc'") return False -class TIPCTest (unittest.TestCase): +class TIPCTest(unittest.TestCase): def testRDM(self): srv = socket.socket(socket.AF_TIPC, socket.SOCK_RDM) cli = socket.socket(socket.AF_TIPC, socket.SOCK_RDM) + self.addCleanup(srv.close) + self.addCleanup(cli.close) srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE, @@ -1627,7 +1985,7 @@ class TIPCTest (unittest.TestCase): srv.bind(srvaddr) sendaddr = (socket.TIPC_ADDR_NAME, TIPC_STYPE, - TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) / 2, 0) + TIPC_LOWER + int((TIPC_UPPER - TIPC_LOWER) / 2), 0) cli.sendto(MSG, sendaddr) msg, recvaddr = srv.recvfrom(1024) @@ -1636,13 +1994,14 @@ class TIPCTest (unittest.TestCase): self.assertEqual(msg, MSG) -class TIPCThreadableTest (unittest.TestCase, ThreadableTest): +class TIPCThreadableTest(unittest.TestCase, ThreadableTest): def __init__(self, methodName = 'runTest'): unittest.TestCase.__init__(self, methodName = methodName) ThreadableTest.__init__(self) def setUp(self): self.srv = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM) + self.addCleanup(self.srv.close) self.srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE, TIPC_LOWER, TIPC_UPPER) @@ -1650,6 +2009,7 @@ class TIPCThreadableTest (unittest.TestCase, ThreadableTest): self.srv.listen(5) self.serverExplicitReady() self.conn, self.connaddr = self.srv.accept() + self.addCleanup(self.conn.close) def clientSetUp(self): # The is a hittable race between serverExplicitReady() and the @@ -1657,8 +2017,9 @@ class TIPCThreadableTest (unittest.TestCase, ThreadableTest): # we could get an exception time.sleep(0.1) self.cli = socket.socket(socket.AF_TIPC, socket.SOCK_STREAM) + self.addCleanup(self.cli.close) addr = (socket.TIPC_ADDR_NAME, TIPC_STYPE, - TIPC_LOWER + (TIPC_UPPER - TIPC_LOWER) / 2, 0) + TIPC_LOWER + int((TIPC_UPPER - TIPC_LOWER) / 2), 0) self.cli.connect(addr) self.cliaddr = self.cli.getsockname() @@ -1672,10 +2033,117 @@ class TIPCThreadableTest (unittest.TestCase, ThreadableTest): self.cli.close() +@unittest.skipUnless(thread, 'Threading required for this test.') +class ContextManagersTest(ThreadedTCPSocketTest): + + def _testSocketClass(self): + # base test + with socket.socket() as sock: + self.assertFalse(sock._closed) + self.assertTrue(sock._closed) + # close inside with block + with socket.socket() as sock: + sock.close() + self.assertTrue(sock._closed) + # exception inside with block + with socket.socket() as sock: + self.assertRaises(socket.error, sock.sendall, b'foo') + self.assertTrue(sock._closed) + + def testCreateConnectionBase(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + data = conn.recv(1024) + conn.sendall(data) + + def _testCreateConnectionBase(self): + address = self.serv.getsockname() + with socket.create_connection(address) as sock: + self.assertFalse(sock._closed) + sock.sendall(b'foo') + self.assertEqual(sock.recv(1024), b'foo') + self.assertTrue(sock._closed) + + def testCreateConnectionClose(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + data = conn.recv(1024) + conn.sendall(data) + + def _testCreateConnectionClose(self): + address = self.serv.getsockname() + with socket.create_connection(address) as sock: + sock.close() + self.assertTrue(sock._closed) + self.assertRaises(socket.error, sock.sendall, b'foo') + + +@unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"), + "SOCK_CLOEXEC not defined") +@unittest.skipUnless(fcntl, "module fcntl not available") +class CloexecConstantTest(unittest.TestCase): + def test_SOCK_CLOEXEC(self): + v = linux_version() + if v < (2, 6, 28): + self.skipTest("Linux kernel 2.6.28 or higher required, not %s" + % ".".join(map(str, v))) + with socket.socket(socket.AF_INET, + socket.SOCK_STREAM | socket.SOCK_CLOEXEC) as s: + self.assertTrue(s.type & socket.SOCK_CLOEXEC) + self.assertTrue(fcntl.fcntl(s, fcntl.F_GETFD) & fcntl.FD_CLOEXEC) + + +@unittest.skipUnless(hasattr(socket, "SOCK_NONBLOCK"), + "SOCK_NONBLOCK not defined") +class NonblockConstantTest(unittest.TestCase): + def checkNonblock(self, s, nonblock=True, timeout=0.0): + if nonblock: + self.assertTrue(s.type & socket.SOCK_NONBLOCK) + self.assertEqual(s.gettimeout(), timeout) + else: + self.assertFalse(s.type & socket.SOCK_NONBLOCK) + self.assertEqual(s.gettimeout(), None) + + def test_SOCK_NONBLOCK(self): + v = linux_version() + if v < (2, 6, 28): + self.skipTest("Linux kernel 2.6.28 or higher required, not %s" + % ".".join(map(str, v))) + # a lot of it seems silly and redundant, but I wanted to test that + # changing back and forth worked ok + with socket.socket(socket.AF_INET, + socket.SOCK_STREAM | socket.SOCK_NONBLOCK) as s: + self.checkNonblock(s) + s.setblocking(1) + self.checkNonblock(s, False) + s.setblocking(0) + self.checkNonblock(s) + s.settimeout(None) + self.checkNonblock(s, False) + s.settimeout(2.0) + self.checkNonblock(s, timeout=2.0) + s.setblocking(1) + self.checkNonblock(s, False) + # defaulttimeout + t = socket.getdefaulttimeout() + socket.setdefaulttimeout(0.0) + with socket.socket() as s: + self.checkNonblock(s) + socket.setdefaulttimeout(None) + with socket.socket() as s: + self.checkNonblock(s, False) + socket.setdefaulttimeout(2.0) + with socket.socket() as s: + self.checkNonblock(s, timeout=2.0) + socket.setdefaulttimeout(None) + with socket.socket() as s: + self.checkNonblock(s, False) + socket.setdefaulttimeout(t) + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, - TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, - UDPTimeoutTest ] + TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] tests.extend([ NonBlockingTCPTests, @@ -1684,10 +2152,15 @@ def test_main(): UnbufferedFileObjectClassTestCase, LineBufferedFileObjectClassTestCase, SmallBufferedFileObjectClassTestCase, - Urllib2FileobjectTest, + UnicodeReadFileObjectClassTestCase, + UnicodeWriteFileObjectClassTestCase, + UnicodeReadWriteFileObjectClassTestCase, NetworkConnectionNoServer, NetworkConnectionAttributesTest, NetworkConnectionBehaviourTest, + ContextManagersTest, + CloexecConstantTest, + NonblockConstantTest ]) if hasattr(socket, "socketpair"): tests.append(BasicSocketPairTest) @@ -1697,9 +2170,9 @@ def test_main(): tests.append(TIPCTest) tests.append(TIPCThreadableTest) - thread_info = test_support.threading_setup() - test_support.run_unittest(*tests) - test_support.threading_cleanup(*thread_info) + thread_info = support.threading_setup() + support.run_unittest(*tests) + support.threading_cleanup(*thread_info) if __name__ == "__main__": test_main() |