diff options
Diffstat (limited to 'Lib/multiprocessing/forkserver.py')
-rw-r--r-- | Lib/multiprocessing/forkserver.py | 72 |
1 files changed, 64 insertions, 8 deletions
diff --git a/Lib/multiprocessing/forkserver.py b/Lib/multiprocessing/forkserver.py index bff7fb91d97..df9b9be9d18 100644 --- a/Lib/multiprocessing/forkserver.py +++ b/Lib/multiprocessing/forkserver.py @@ -9,6 +9,7 @@ import sys import threading import warnings +from . import AuthenticationError from . import connection from . import process from .context import reduction @@ -25,6 +26,7 @@ __all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process', MAXFDS_TO_SEND = 256 SIGNED_STRUCT = struct.Struct('q') # large enough for pid_t +_AUTHKEY_LEN = 32 # <= PIPEBUF so it fits a single write to an empty pipe. # # Forkserver class @@ -33,6 +35,7 @@ SIGNED_STRUCT = struct.Struct('q') # large enough for pid_t class ForkServer(object): def __init__(self): + self._forkserver_authkey = None self._forkserver_address = None self._forkserver_alive_fd = None self._forkserver_pid = None @@ -59,6 +62,7 @@ class ForkServer(object): if not util.is_abstract_socket_namespace(self._forkserver_address): os.unlink(self._forkserver_address) self._forkserver_address = None + self._forkserver_authkey = None def set_forkserver_preload(self, modules_names): '''Set list of module names to try to load in forkserver process.''' @@ -83,6 +87,7 @@ class ForkServer(object): process data. ''' self.ensure_running() + assert self._forkserver_authkey if len(fds) + 4 >= MAXFDS_TO_SEND: raise ValueError('too many fds') with socket.socket(socket.AF_UNIX) as client: @@ -93,6 +98,18 @@ class ForkServer(object): resource_tracker.getfd()] allfds += fds try: + client.setblocking(True) + wrapped_client = connection.Connection(client.fileno()) + # The other side of this exchange happens in the child as + # implemented in main(). + try: + connection.answer_challenge( + wrapped_client, self._forkserver_authkey) + connection.deliver_challenge( + wrapped_client, self._forkserver_authkey) + finally: + wrapped_client._detach() + del wrapped_client reduction.sendfds(client, allfds) return parent_r, parent_w except: @@ -120,6 +137,7 @@ class ForkServer(object): return # dead, launch it again os.close(self._forkserver_alive_fd) + self._forkserver_authkey = None self._forkserver_address = None self._forkserver_alive_fd = None self._forkserver_pid = None @@ -130,9 +148,9 @@ class ForkServer(object): if self._preload_modules: desired_keys = {'main_path', 'sys_path'} data = spawn.get_preparation_data('ignore') - data = {x: y for x, y in data.items() if x in desired_keys} + main_kws = {x: y for x, y in data.items() if x in desired_keys} else: - data = {} + main_kws = {} with socket.socket(socket.AF_UNIX) as listener: address = connection.arbitrary_address('AF_UNIX') @@ -144,19 +162,31 @@ class ForkServer(object): # all client processes own the write end of the "alive" pipe; # when they all terminate the read end becomes ready. alive_r, alive_w = os.pipe() + # A short lived pipe to initialize the forkserver authkey. + authkey_r, authkey_w = os.pipe() try: - fds_to_pass = [listener.fileno(), alive_r] + fds_to_pass = [listener.fileno(), alive_r, authkey_r] + main_kws['authkey_r'] = authkey_r cmd %= (listener.fileno(), alive_r, self._preload_modules, - data) + main_kws) exe = spawn.get_executable() args = [exe] + util._args_from_interpreter_flags() args += ['-c', cmd] pid = util.spawnv_passfds(exe, args, fds_to_pass) except: os.close(alive_w) + os.close(authkey_w) raise finally: os.close(alive_r) + os.close(authkey_r) + # Authenticate our control socket to prevent access from + # processes we have not shared this key with. + try: + self._forkserver_authkey = os.urandom(_AUTHKEY_LEN) + os.write(authkey_w, self._forkserver_authkey) + finally: + os.close(authkey_w) self._forkserver_address = address self._forkserver_alive_fd = alive_w self._forkserver_pid = pid @@ -165,8 +195,18 @@ class ForkServer(object): # # -def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): - '''Run forkserver.''' +def main(listener_fd, alive_r, preload, main_path=None, sys_path=None, + *, authkey_r=None): + """Run forkserver.""" + if authkey_r is not None: + try: + authkey = os.read(authkey_r, _AUTHKEY_LEN) + assert len(authkey) == _AUTHKEY_LEN, f'{len(authkey)} < {_AUTHKEY_LEN}' + finally: + os.close(authkey_r) + else: + authkey = b'' + if preload: if sys_path is not None: sys.path[:] = sys_path @@ -257,8 +297,24 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None): if listener in rfds: # Incoming fork request with listener.accept()[0] as s: - # Receive fds from client - fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1) + try: + if authkey: + wrapped_s = connection.Connection(s.fileno()) + # The other side of this exchange happens in + # in connect_to_new_process(). + try: + connection.deliver_challenge( + wrapped_s, authkey) + connection.answer_challenge( + wrapped_s, authkey) + finally: + wrapped_s._detach() + del wrapped_s + # Receive fds from client + fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1) + except (EOFError, BrokenPipeError, AuthenticationError): + s.close() + continue if len(fds) > MAXFDS_TO_SEND: raise RuntimeError( "Too many ({0:n}) fds to send".format( |