aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/multiprocessing/forkserver.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/multiprocessing/forkserver.py')
-rw-r--r--Lib/multiprocessing/forkserver.py72
1 files changed, 64 insertions, 8 deletions
diff --git a/Lib/multiprocessing/forkserver.py b/Lib/multiprocessing/forkserver.py
index bff7fb91d97..df9b9be9d18 100644
--- a/Lib/multiprocessing/forkserver.py
+++ b/Lib/multiprocessing/forkserver.py
@@ -9,6 +9,7 @@ import sys
import threading
import warnings
+from . import AuthenticationError
from . import connection
from . import process
from .context import reduction
@@ -25,6 +26,7 @@ __all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
MAXFDS_TO_SEND = 256
SIGNED_STRUCT = struct.Struct('q') # large enough for pid_t
+_AUTHKEY_LEN = 32 # <= PIPEBUF so it fits a single write to an empty pipe.
#
# Forkserver class
@@ -33,6 +35,7 @@ SIGNED_STRUCT = struct.Struct('q') # large enough for pid_t
class ForkServer(object):
def __init__(self):
+ self._forkserver_authkey = None
self._forkserver_address = None
self._forkserver_alive_fd = None
self._forkserver_pid = None
@@ -59,6 +62,7 @@ class ForkServer(object):
if not util.is_abstract_socket_namespace(self._forkserver_address):
os.unlink(self._forkserver_address)
self._forkserver_address = None
+ self._forkserver_authkey = None
def set_forkserver_preload(self, modules_names):
'''Set list of module names to try to load in forkserver process.'''
@@ -83,6 +87,7 @@ class ForkServer(object):
process data.
'''
self.ensure_running()
+ assert self._forkserver_authkey
if len(fds) + 4 >= MAXFDS_TO_SEND:
raise ValueError('too many fds')
with socket.socket(socket.AF_UNIX) as client:
@@ -93,6 +98,18 @@ class ForkServer(object):
resource_tracker.getfd()]
allfds += fds
try:
+ client.setblocking(True)
+ wrapped_client = connection.Connection(client.fileno())
+ # The other side of this exchange happens in the child as
+ # implemented in main().
+ try:
+ connection.answer_challenge(
+ wrapped_client, self._forkserver_authkey)
+ connection.deliver_challenge(
+ wrapped_client, self._forkserver_authkey)
+ finally:
+ wrapped_client._detach()
+ del wrapped_client
reduction.sendfds(client, allfds)
return parent_r, parent_w
except:
@@ -120,6 +137,7 @@ class ForkServer(object):
return
# dead, launch it again
os.close(self._forkserver_alive_fd)
+ self._forkserver_authkey = None
self._forkserver_address = None
self._forkserver_alive_fd = None
self._forkserver_pid = None
@@ -130,9 +148,9 @@ class ForkServer(object):
if self._preload_modules:
desired_keys = {'main_path', 'sys_path'}
data = spawn.get_preparation_data('ignore')
- data = {x: y for x, y in data.items() if x in desired_keys}
+ main_kws = {x: y for x, y in data.items() if x in desired_keys}
else:
- data = {}
+ main_kws = {}
with socket.socket(socket.AF_UNIX) as listener:
address = connection.arbitrary_address('AF_UNIX')
@@ -144,19 +162,31 @@ class ForkServer(object):
# all client processes own the write end of the "alive" pipe;
# when they all terminate the read end becomes ready.
alive_r, alive_w = os.pipe()
+ # A short lived pipe to initialize the forkserver authkey.
+ authkey_r, authkey_w = os.pipe()
try:
- fds_to_pass = [listener.fileno(), alive_r]
+ fds_to_pass = [listener.fileno(), alive_r, authkey_r]
+ main_kws['authkey_r'] = authkey_r
cmd %= (listener.fileno(), alive_r, self._preload_modules,
- data)
+ main_kws)
exe = spawn.get_executable()
args = [exe] + util._args_from_interpreter_flags()
args += ['-c', cmd]
pid = util.spawnv_passfds(exe, args, fds_to_pass)
except:
os.close(alive_w)
+ os.close(authkey_w)
raise
finally:
os.close(alive_r)
+ os.close(authkey_r)
+ # Authenticate our control socket to prevent access from
+ # processes we have not shared this key with.
+ try:
+ self._forkserver_authkey = os.urandom(_AUTHKEY_LEN)
+ os.write(authkey_w, self._forkserver_authkey)
+ finally:
+ os.close(authkey_w)
self._forkserver_address = address
self._forkserver_alive_fd = alive_w
self._forkserver_pid = pid
@@ -165,8 +195,18 @@ class ForkServer(object):
#
#
-def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
- '''Run forkserver.'''
+def main(listener_fd, alive_r, preload, main_path=None, sys_path=None,
+ *, authkey_r=None):
+ """Run forkserver."""
+ if authkey_r is not None:
+ try:
+ authkey = os.read(authkey_r, _AUTHKEY_LEN)
+ assert len(authkey) == _AUTHKEY_LEN, f'{len(authkey)} < {_AUTHKEY_LEN}'
+ finally:
+ os.close(authkey_r)
+ else:
+ authkey = b''
+
if preload:
if sys_path is not None:
sys.path[:] = sys_path
@@ -257,8 +297,24 @@ def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
if listener in rfds:
# Incoming fork request
with listener.accept()[0] as s:
- # Receive fds from client
- fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
+ try:
+ if authkey:
+ wrapped_s = connection.Connection(s.fileno())
+ # The other side of this exchange happens in
+ # in connect_to_new_process().
+ try:
+ connection.deliver_challenge(
+ wrapped_s, authkey)
+ connection.answer_challenge(
+ wrapped_s, authkey)
+ finally:
+ wrapped_s._detach()
+ del wrapped_s
+ # Receive fds from client
+ fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
+ except (EOFError, BrokenPipeError, AuthenticationError):
+ s.close()
+ continue
if len(fds) > MAXFDS_TO_SEND:
raise RuntimeError(
"Too many ({0:n}) fds to send".format(