diff options
Diffstat (limited to 'Lib/asyncio')
-rw-r--r-- | Lib/asyncio/base_events.py | 119 | ||||
-rw-r--r-- | Lib/asyncio/proactor_events.py | 3 | ||||
-rw-r--r-- | Lib/asyncio/selector_events.py | 3 | ||||
-rw-r--r-- | Lib/asyncio/test_utils.py | 9 |
4 files changed, 89 insertions, 45 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index c5ffad40807..4505732f9ac 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -16,8 +16,10 @@ to modify the meaning of the API call itself. import collections import concurrent.futures +import functools import heapq import inspect +import ipaddress import itertools import logging import os @@ -70,49 +72,83 @@ def _format_pipe(fd): return repr(fd) +# Linux's sock.type is a bitmask that can include extra info about socket. +_SOCKET_TYPE_MASK = 0 +if hasattr(socket, 'SOCK_NONBLOCK'): + _SOCKET_TYPE_MASK |= socket.SOCK_NONBLOCK +if hasattr(socket, 'SOCK_CLOEXEC'): + _SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC + + +@functools.lru_cache(maxsize=1024) +def _ipaddr_info(host, port, family, type, proto): + # Try to skip getaddrinfo if "host" is already an IP. Since getaddrinfo + # blocks on an exclusive lock on some platforms, users might handle name + # resolution in their own code and pass in resolved IPs. + if proto not in {0, socket.IPPROTO_TCP, socket.IPPROTO_UDP} or host is None: + return None + + type &= ~_SOCKET_TYPE_MASK + if type == socket.SOCK_STREAM: + proto = socket.IPPROTO_TCP + elif type == socket.SOCK_DGRAM: + proto = socket.IPPROTO_UDP + else: + return None + + if hasattr(socket, 'inet_pton'): + if family == socket.AF_UNSPEC: + afs = [socket.AF_INET, socket.AF_INET6] + else: + afs = [family] + + for af in afs: + # Linux's inet_pton doesn't accept an IPv6 zone index after host, + # like '::1%lo0', so strip it. If we happen to make an invalid + # address look valid, we fail later in sock.connect or sock.bind. + try: + if af == socket.AF_INET6: + socket.inet_pton(af, host.partition('%')[0]) + else: + socket.inet_pton(af, host) + return af, type, proto, '', (host, port) + except OSError: + pass + + # "host" is not an IP address. + return None + + # No inet_pton. (On Windows it's only available since Python 3.4.) + # Even though getaddrinfo with AI_NUMERICHOST would be non-blocking, it + # still requires a lock on some platforms, and waiting for that lock could + # block the event loop. Use ipaddress instead, it's just text parsing. + try: + addr = ipaddress.IPv4Address(host) + except ValueError: + try: + addr = ipaddress.IPv6Address(host.partition('%')[0]) + except ValueError: + return None + + af = socket.AF_INET if addr.version == 4 else socket.AF_INET6 + if family not in (socket.AF_UNSPEC, af): + # "host" is wrong IP version for "family". + return None + + return af, type, proto, '', (host, port) + + def _check_resolved_address(sock, address): # Ensure that the address is already resolved to avoid the trap of hanging # the entire event loop when the address requires doing a DNS lookup. - # - # getaddrinfo() is slow (around 10 us per call): this function should only - # be called in debug mode - family = sock.family - - if family == socket.AF_INET: - host, port = address - elif family == socket.AF_INET6: - host, port = address[:2] - else: + + if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX: return - # On Windows, socket.inet_pton() is only available since Python 3.4 - if hasattr(socket, 'inet_pton'): - # getaddrinfo() is slow and has known issue: prefer inet_pton() - # if available - try: - socket.inet_pton(family, host) - except OSError as exc: - raise ValueError("address must be resolved (IP address), " - "got host %r: %s" - % (host, exc)) - else: - # Use getaddrinfo(flags=AI_NUMERICHOST) to ensure that the address is - # already resolved. - type_mask = 0 - if hasattr(socket, 'SOCK_NONBLOCK'): - type_mask |= socket.SOCK_NONBLOCK - if hasattr(socket, 'SOCK_CLOEXEC'): - type_mask |= socket.SOCK_CLOEXEC - try: - socket.getaddrinfo(host, port, - family=family, - type=(sock.type & ~type_mask), - proto=sock.proto, - flags=socket.AI_NUMERICHOST) - except socket.gaierror as err: - raise ValueError("address must be resolved (IP address), " - "got host %r: %s" - % (host, err)) + host, port = address[:2] + if _ipaddr_info(host, port, sock.family, sock.type, sock.proto) is None: + raise ValueError("address must be resolved (IP address)," + " got host %r" % host) def _run_until_complete_cb(fut): @@ -535,7 +571,12 @@ class BaseEventLoop(events.AbstractEventLoop): def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): - if self._debug: + info = _ipaddr_info(host, port, family, type, proto) + if info is not None: + fut = futures.Future(loop=self) + fut.set_result([info]) + return fut + elif self._debug: return self.run_in_executor(None, self._getaddrinfo_debug, host, port, family, type, proto, flags) else: diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 7eac41eec02..14c0659ddee 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -441,8 +441,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): def sock_connect(self, sock, address): try: - if self._debug: - base_events._check_resolved_address(sock, address) + base_events._check_resolved_address(sock, address) except ValueError as err: fut = futures.Future(loop=self) fut.set_exception(err) diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index a05f81cd9de..5b26631d80d 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -397,8 +397,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) try: - if self._debug: - base_events._check_resolved_address(sock, address) + base_events._check_resolved_address(sock, address) except ValueError as err: fut.set_exception(err) else: diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py index 8170533188c..396e6aed567 100644 --- a/Lib/asyncio/test_utils.py +++ b/Lib/asyncio/test_utils.py @@ -446,9 +446,14 @@ def disable_logger(): finally: logger.setLevel(old_level) -def mock_nonblocking_socket(): + +def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM, + family=socket.AF_INET): """Create a mock of a non-blocking socket.""" - sock = mock.Mock(socket.socket) + sock = mock.MagicMock(socket.socket) + sock.proto = proto + sock.type = type + sock.family = family sock.gettimeout.return_value = 0.0 return sock |