diff options
Diffstat (limited to 'Lib/test/test_multiprocessing.py')
-rw-r--r-- | Lib/test/test_multiprocessing.py | 619 |
1 files changed, 572 insertions, 47 deletions
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py index b812e4840b8..65e7b0be6eb 100644 --- a/Lib/test/test_multiprocessing.py +++ b/Lib/test/test_multiprocessing.py @@ -8,6 +8,7 @@ import unittest import queue as pyqueue import time import io +import itertools import sys import os import gc @@ -83,6 +84,13 @@ HAVE_GETVALUE = not getattr(_multiprocessing, WIN32 = (sys.platform == "win32") +from multiprocessing.connection import wait + +def wait_for_handle(handle, timeout): + if timeout is not None and timeout < 0.0: + timeout = None + return wait([handle], timeout) + try: MAXFD = os.sysconf("SC_OPEN_MAX") except: @@ -196,6 +204,18 @@ class _TestProcess(BaseTestCase): self.assertEqual(current.ident, os.getpid()) self.assertEqual(current.exitcode, None) + def test_daemon_argument(self): + if self.TYPE == "threads": + return + + # By default uses the current process's daemon flag. + proc0 = self.Process(target=self._test) + self.assertEqual(proc0.daemon, self.current_process().daemon) + proc1 = self.Process(target=self._test, daemon=True) + self.assertTrue(proc1.daemon) + proc2 = self.Process(target=self._test, daemon=False) + self.assertFalse(proc2.daemon) + @classmethod def _test(cls, q, *args, **kwds): current = cls.current_process() @@ -261,9 +281,18 @@ class _TestProcess(BaseTestCase): self.assertIn(p, self.active_children()) self.assertEqual(p.exitcode, None) + join = TimingWrapper(p.join) + + self.assertEqual(join(0), None) + self.assertTimingAlmostEqual(join.elapsed, 0.0) + self.assertEqual(p.is_alive(), True) + + self.assertEqual(join(-1), None) + self.assertTimingAlmostEqual(join.elapsed, 0.0) + self.assertEqual(p.is_alive(), True) + p.terminate() - join = TimingWrapper(p.join) self.assertEqual(join(), None) self.assertTimingAlmostEqual(join.elapsed, 0.0) @@ -328,6 +357,26 @@ class _TestProcess(BaseTestCase): ] self.assertEqual(result, expected) + @classmethod + def _test_sentinel(cls, event): + event.wait(10.0) + + def test_sentinel(self): + if self.TYPE == "threads": + return + event = self.Event() + p = self.Process(target=self._test_sentinel, args=(event,)) + with self.assertRaises(ValueError): + p.sentinel + p.start() + self.addCleanup(p.join) + sentinel = p.sentinel + self.assertIsInstance(sentinel, int) + self.assertFalse(wait_for_handle(sentinel, timeout=0.0)) + event.set() + p.join() + self.assertTrue(wait_for_handle(sentinel, timeout=DELTA)) + # # # @@ -867,6 +916,104 @@ class _TestCondition(BaseTestCase): self.assertEqual(res, False) self.assertTimingAlmostEqual(wait.elapsed, TIMEOUT1) + @classmethod + def _test_waitfor_f(cls, cond, state): + with cond: + state.value = 0 + cond.notify() + result = cond.wait_for(lambda : state.value==4) + if not result or state.value != 4: + sys.exit(1) + + @unittest.skipUnless(HAS_SHAREDCTYPES, 'needs sharedctypes') + def test_waitfor(self): + # based on test in test/lock_tests.py + cond = self.Condition() + state = self.Value('i', -1) + + p = self.Process(target=self._test_waitfor_f, args=(cond, state)) + p.daemon = True + p.start() + + with cond: + result = cond.wait_for(lambda : state.value==0) + self.assertTrue(result) + self.assertEqual(state.value, 0) + + for i in range(4): + time.sleep(0.01) + with cond: + state.value += 1 + cond.notify() + + p.join(5) + self.assertFalse(p.is_alive()) + self.assertEqual(p.exitcode, 0) + + @classmethod + def _test_waitfor_timeout_f(cls, cond, state, success, sem): + sem.release() + with cond: + expected = 0.1 + dt = time.time() + result = cond.wait_for(lambda : state.value==4, timeout=expected) + dt = time.time() - dt + # borrow logic in assertTimeout() from test/lock_tests.py + if not result and expected * 0.6 < dt < expected * 10.0: + success.value = True + + @unittest.skipUnless(HAS_SHAREDCTYPES, 'needs sharedctypes') + def test_waitfor_timeout(self): + # based on test in test/lock_tests.py + cond = self.Condition() + state = self.Value('i', 0) + success = self.Value('i', False) + sem = self.Semaphore(0) + + p = self.Process(target=self._test_waitfor_timeout_f, + args=(cond, state, success, sem)) + p.daemon = True + p.start() + self.assertTrue(sem.acquire(timeout=10)) + + # Only increment 3 times, so state == 4 is never reached. + for i in range(3): + time.sleep(0.01) + with cond: + state.value += 1 + cond.notify() + + p.join(5) + self.assertTrue(success.value) + + @classmethod + def _test_wait_result(cls, c, pid): + with c: + c.notify() + time.sleep(1) + if pid is not None: + os.kill(pid, signal.SIGINT) + + def test_wait_result(self): + if isinstance(self, ProcessesMixin) and sys.platform != 'win32': + pid = os.getpid() + else: + pid = None + + c = self.Condition() + with c: + self.assertFalse(c.wait(0)) + self.assertFalse(c.wait(0.1)) + + p = self.Process(target=self._test_wait_result, args=(c, pid)) + p.start() + + self.assertTrue(c.wait(10)) + if pid is not None: + self.assertRaises(KeyboardInterrupt, c.wait, 10) + + p.join() + class _TestEvent(BaseTestCase): @@ -1129,6 +1276,9 @@ def sqr(x, wait=0.0): time.sleep(wait) return x*x +def mul(x, y): + return x*y + class _TestPool(BaseTestCase): def test_apply(self): @@ -1142,6 +1292,20 @@ class _TestPool(BaseTestCase): self.assertEqual(pmap(sqr, list(range(100)), chunksize=20), list(map(sqr, list(range(100))))) + def test_starmap(self): + psmap = self.pool.starmap + tuples = list(zip(range(10), range(9,-1, -1))) + self.assertEqual(psmap(mul, tuples), + list(itertools.starmap(mul, tuples))) + tuples = list(zip(range(100), range(99,-1, -1))) + self.assertEqual(psmap(mul, tuples, chunksize=20), + list(itertools.starmap(mul, tuples))) + + def test_starmap_async(self): + tuples = list(zip(range(100), range(99,-1, -1))) + self.assertEqual(self.pool.starmap_async(mul, tuples).get(), + list(itertools.starmap(mul, tuples))) + def test_map_chunksize(self): try: self.pool.map_async(sqr, [], chunksize=1).get(timeout=TIMEOUT1) @@ -1569,6 +1733,9 @@ class _TestConnection(BaseTestCase): self.assertEqual(poll(), False) self.assertTimingAlmostEqual(poll.elapsed, 0) + self.assertEqual(poll(-1), False) + self.assertTimingAlmostEqual(poll.elapsed, 0) + self.assertEqual(poll(TIMEOUT1), False) self.assertTimingAlmostEqual(poll.elapsed, TIMEOUT1) @@ -1754,6 +1921,17 @@ class _TestConnection(BaseTestCase): self.assertRaises(RuntimeError, reduction.recv_handle, conn) p.join() +class _TestListener(BaseTestCase): + + ALLOWED_TYPES = ('processes') + + def test_multiple_bind(self): + for family in self.connection.families: + l = self.connection.Listener(family=family) + self.addCleanup(l.close) + self.assertRaises(OSError, self.connection.Listener, + l.address, family) + class _TestListenerClient(BaseTestCase): ALLOWED_TYPES = ('processes', 'threads') @@ -1791,52 +1969,137 @@ class _TestListenerClient(BaseTestCase): p.join() l.close() +class _TestPoll(unittest.TestCase): + + ALLOWED_TYPES = ('processes', 'threads') + + def test_empty_string(self): + a, b = self.Pipe() + self.assertEqual(a.poll(), False) + b.send_bytes(b'') + self.assertEqual(a.poll(), True) + self.assertEqual(a.poll(), True) + + @classmethod + def _child_strings(cls, conn, strings): + for s in strings: + time.sleep(0.1) + conn.send_bytes(s) + conn.close() + + def test_strings(self): + strings = (b'hello', b'', b'a', b'b', b'', b'bye', b'', b'lop') + a, b = self.Pipe() + p = self.Process(target=self._child_strings, args=(b, strings)) + p.start() + + for s in strings: + for i in range(200): + if a.poll(0.01): + break + x = a.recv_bytes() + self.assertEqual(s, x) + + p.join() + + @classmethod + def _child_boundaries(cls, r): + # Polling may "pull" a message in to the child process, but we + # don't want it to pull only part of a message, as that would + # corrupt the pipe for any other processes which might later + # read from it. + r.poll(5) + + def test_boundaries(self): + r, w = self.Pipe(False) + p = self.Process(target=self._child_boundaries, args=(r,)) + p.start() + time.sleep(2) + L = [b"first", b"second"] + for obj in L: + w.send_bytes(obj) + w.close() + p.join() + self.assertIn(r.recv_bytes(), L) + + @classmethod + def _child_dont_merge(cls, b): + b.send_bytes(b'a') + b.send_bytes(b'b') + b.send_bytes(b'cd') + + def test_dont_merge(self): + a, b = self.Pipe() + self.assertEqual(a.poll(0.0), False) + self.assertEqual(a.poll(0.1), False) + + p = self.Process(target=self._child_dont_merge, args=(b,)) + p.start() + + self.assertEqual(a.recv_bytes(), b'a') + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.recv_bytes(), b'b') + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.poll(0.0), True) + self.assertEqual(a.recv_bytes(), b'cd') + + p.join() + # # Test of sending connection and socket objects between processes # -""" + +# Intermittent fails on Mac OS X -- see Issue14669 and Issue12958 +@unittest.skipIf(sys.platform == "darwin", "fd passing unreliable on Mac OS X") +@unittest.skipUnless(HAS_REDUCTION, "test needs multiprocessing.reduction") class _TestPicklingConnections(BaseTestCase): ALLOWED_TYPES = ('processes',) - def _listener(self, conn, families): + @classmethod + def tearDownClass(cls): + from multiprocessing.reduction import resource_sharer + resource_sharer.stop(timeout=5) + + @classmethod + def _listener(cls, conn, families): for fam in families: - l = self.connection.Listener(family=fam) + l = cls.connection.Listener(family=fam) conn.send(l.address) new_conn = l.accept() conn.send(new_conn) + new_conn.close() + l.close() - if self.TYPE == 'processes': - l = socket.socket() - l.bind(('localhost', 0)) - conn.send(l.getsockname()) - l.listen(1) - new_conn, addr = l.accept() - conn.send(new_conn) + l = socket.socket() + l.bind(('localhost', 0)) + l.listen(1) + conn.send(l.getsockname()) + new_conn, addr = l.accept() + conn.send(new_conn) + new_conn.close() + l.close() conn.recv() - def _remote(self, conn): + @classmethod + def _remote(cls, conn): for (address, msg) in iter(conn.recv, None): - client = self.connection.Client(address) + client = cls.connection.Client(address) client.send(msg.upper()) client.close() - if self.TYPE == 'processes': - address, msg = conn.recv() - client = socket.socket() - client.connect(address) - client.sendall(msg.upper()) - client.close() + address, msg = conn.recv() + client = socket.socket() + client.connect(address) + client.sendall(msg.upper()) + client.close() conn.close() def test_pickling(self): - try: - multiprocessing.allow_connection_pickling() - except ImportError: - return - families = self.connection.families lconn, lconn0 = self.Pipe() @@ -1860,16 +2123,19 @@ class _TestPicklingConnections(BaseTestCase): rconn.send(None) - if self.TYPE == 'processes': - msg = latin('This connection uses a normal socket') - address = lconn.recv() - rconn.send((address, msg)) - if hasattr(socket, 'fromfd'): - new_conn = lconn.recv() - self.assertEqual(new_conn.recv(100), msg.upper()) - else: - # XXX On Windows with Py2.6 need to backport fromfd() - discard = lconn.recv_bytes() + msg = latin('This connection uses a normal socket') + address = lconn.recv() + rconn.send((address, msg)) + new_conn = lconn.recv() + buf = [] + while True: + s = new_conn.recv(100) + if not s: + break + buf.append(s) + buf = b''.join(buf) + self.assertEqual(buf, msg.upper()) + new_conn.close() lconn.send(None) @@ -1878,7 +2144,46 @@ class _TestPicklingConnections(BaseTestCase): lp.join() rp.join() -""" + + @classmethod + def child_access(cls, conn): + w = conn.recv() + w.send('all is well') + w.close() + + r = conn.recv() + msg = r.recv() + conn.send(msg*2) + + conn.close() + + def test_access(self): + # On Windows, if we do not specify a destination pid when + # using DupHandle then we need to be careful to use the + # correct access flags for DuplicateHandle(), or else + # DupHandle.detach() will raise PermissionError. For example, + # for a read only pipe handle we should use + # access=FILE_GENERIC_READ. (Unfortunately + # DUPLICATE_SAME_ACCESS does not work.) + conn, child_conn = self.Pipe() + p = self.Process(target=self.child_access, args=(child_conn,)) + p.daemon = True + p.start() + child_conn.close() + + r, w = self.Pipe(duplex=False) + conn.send(w) + w.close() + self.assertEqual(r.recv(), 'all is well') + r.close() + + r, w = self.Pipe(duplex=False) + conn.send(r) + r.close() + w.send('foobar') + w.close() + self.assertEqual(conn.recv(), 'foobar'*2) + # # # @@ -2173,9 +2478,15 @@ class TestInvalidHandle(unittest.TestCase): @unittest.skipIf(WIN32, "skipped on Windows") def test_invalid_handles(self): - conn = _multiprocessing.Connection(44977608) - self.assertRaises(IOError, conn.poll) - self.assertRaises(IOError, _multiprocessing.Connection, -1) + conn = multiprocessing.connection.Connection(44977608) + try: + self.assertRaises((ValueError, IOError), conn.poll) + finally: + # Hack private attribute _handle to avoid printing an error + # in conn.__del__ + conn._handle = None + self.assertRaises((ValueError, IOError), + multiprocessing.connection.Connection, -1) # # Functions used to create test cases from the base ones in this module @@ -2296,6 +2607,7 @@ class TestInitializers(unittest.TestCase): def tearDown(self): self.mgr.shutdown() + self.mgr.join() def test_manager_initializer(self): m = multiprocessing.managers.SyncManager() @@ -2303,6 +2615,7 @@ class TestInitializers(unittest.TestCase): m.start(initializer, (self.ns,)) self.assertEqual(self.ns.test, 1) m.shutdown() + m.join() def test_pool_initializer(self): self.assertRaises(TypeError, multiprocessing.Pool, initializer=1) @@ -2335,6 +2648,8 @@ def _afunc(x): def pool_in_process(): pool = multiprocessing.Pool(processes=4) x = pool.map(_afunc, [1, 2, 3, 4, 5, 6, 7]) + pool.close() + pool.join() class _file_like(object): def __init__(self, delegate): @@ -2379,6 +2694,180 @@ class TestStdinBadfiledescriptor(unittest.TestCase): assert sio.getvalue() == 'foo' +class TestWait(unittest.TestCase): + + @classmethod + def _child_test_wait(cls, w, slow): + for i in range(10): + if slow: + time.sleep(random.random()*0.1) + w.send((i, os.getpid())) + w.close() + + def test_wait(self, slow=False): + from multiprocessing.connection import wait + readers = [] + procs = [] + messages = [] + + for i in range(4): + r, w = multiprocessing.Pipe(duplex=False) + p = multiprocessing.Process(target=self._child_test_wait, args=(w, slow)) + p.daemon = True + p.start() + w.close() + readers.append(r) + procs.append(p) + self.addCleanup(p.join) + + while readers: + for r in wait(readers): + try: + msg = r.recv() + except EOFError: + readers.remove(r) + r.close() + else: + messages.append(msg) + + messages.sort() + expected = sorted((i, p.pid) for i in range(10) for p in procs) + self.assertEqual(messages, expected) + + @classmethod + def _child_test_wait_socket(cls, address, slow): + s = socket.socket() + s.connect(address) + for i in range(10): + if slow: + time.sleep(random.random()*0.1) + s.sendall(('%s\n' % i).encode('ascii')) + s.close() + + def test_wait_socket(self, slow=False): + from multiprocessing.connection import wait + l = socket.socket() + l.bind(('', 0)) + l.listen(4) + addr = ('localhost', l.getsockname()[1]) + readers = [] + procs = [] + dic = {} + + for i in range(4): + p = multiprocessing.Process(target=self._child_test_wait_socket, + args=(addr, slow)) + p.daemon = True + p.start() + procs.append(p) + self.addCleanup(p.join) + + for i in range(4): + r, _ = l.accept() + readers.append(r) + dic[r] = [] + l.close() + + while readers: + for r in wait(readers): + msg = r.recv(32) + if not msg: + readers.remove(r) + r.close() + else: + dic[r].append(msg) + + expected = ''.join('%s\n' % i for i in range(10)).encode('ascii') + for v in dic.values(): + self.assertEqual(b''.join(v), expected) + + def test_wait_slow(self): + self.test_wait(True) + + def test_wait_socket_slow(self): + self.test_wait_socket(True) + + def test_wait_timeout(self): + from multiprocessing.connection import wait + + expected = 5 + a, b = multiprocessing.Pipe() + + start = time.time() + res = wait([a, b], expected) + delta = time.time() - start + + self.assertEqual(res, []) + self.assertLess(delta, expected * 2) + self.assertGreater(delta, expected * 0.5) + + b.send(None) + + start = time.time() + res = wait([a, b], 20) + delta = time.time() - start + + self.assertEqual(res, [a]) + self.assertLess(delta, 0.4) + + @classmethod + def signal_and_sleep(cls, sem, period): + sem.release() + time.sleep(period) + + def test_wait_integer(self): + from multiprocessing.connection import wait + + expected = 3 + sem = multiprocessing.Semaphore(0) + a, b = multiprocessing.Pipe() + p = multiprocessing.Process(target=self.signal_and_sleep, + args=(sem, expected)) + + p.start() + self.assertIsInstance(p.sentinel, int) + self.assertTrue(sem.acquire(timeout=20)) + + start = time.time() + res = wait([a, p.sentinel, b], expected + 20) + delta = time.time() - start + + self.assertEqual(res, [p.sentinel]) + self.assertLess(delta, expected + 2) + self.assertGreater(delta, expected - 2) + + a.send(None) + + start = time.time() + res = wait([a, p.sentinel, b], 20) + delta = time.time() - start + + self.assertEqual(res, [p.sentinel, b]) + self.assertLess(delta, 0.4) + + b.send(None) + + start = time.time() + res = wait([a, p.sentinel, b], 20) + delta = time.time() - start + + self.assertEqual(res, [a, p.sentinel, b]) + self.assertLess(delta, 0.4) + + p.terminate() + p.join() + + def test_neg_timeout(self): + from multiprocessing.connection import wait + a, b = multiprocessing.Pipe() + t = time.time() + res = wait([a], timeout=-1) + t = time.time() - t + self.assertEqual(res, []) + self.assertLess(t, 1) + a.close() + b.close() + # # Issue 14151: Test invalid family on invalid environment # @@ -2395,9 +2884,41 @@ class TestInvalidFamily(unittest.TestCase): with self.assertRaises(ValueError): multiprocessing.connection.Listener('/var/test.pipe') +# +# Issue 12098: check sys.flags of child matches that for parent +# + +class TestFlags(unittest.TestCase): + @classmethod + def run_in_grandchild(cls, conn): + conn.send(tuple(sys.flags)) + + @classmethod + def run_in_child(cls): + import json + r, w = multiprocessing.Pipe(duplex=False) + p = multiprocessing.Process(target=cls.run_in_grandchild, args=(w,)) + p.start() + grandchild_flags = r.recv() + p.join() + r.close() + w.close() + flags = (tuple(sys.flags), grandchild_flags) + print(json.dumps(flags)) + + def test_flags(self): + import json, subprocess + # start child process using unusual flags + prog = ('from test.test_multiprocessing import TestFlags; ' + + 'TestFlags.run_in_child()') + data = subprocess.check_output( + [sys.executable, '-E', '-S', '-O', '-c', prog]) + child_flags, grandchild_flags = json.loads(data.decode('ascii')) + self.assertEqual(child_flags, grandchild_flags) testcases_other = [OtherTest, TestInvalidHandle, TestInitializers, - TestStdinBadfiledescriptor, TestInvalidFamily] + TestStdinBadfiledescriptor, TestWait, TestInvalidFamily, + TestFlags] # # @@ -2434,14 +2955,18 @@ def test_main(run=None): loadTestsFromTestCase = unittest.defaultTestLoader.loadTestsFromTestCase suite = unittest.TestSuite(loadTestsFromTestCase(tc) for tc in testcases) - run(suite) - - ThreadsMixin.pool.terminate() - ProcessesMixin.pool.terminate() - ManagerMixin.pool.terminate() - ManagerMixin.manager.shutdown() - - del ProcessesMixin.pool, ThreadsMixin.pool, ManagerMixin.pool + try: + run(suite) + finally: + ThreadsMixin.pool.terminate() + ProcessesMixin.pool.terminate() + ManagerMixin.pool.terminate() + ManagerMixin.pool.join() + ManagerMixin.manager.shutdown() + ManagerMixin.manager.join() + ThreadsMixin.pool.join() + ProcessesMixin.pool.join() + del ProcessesMixin.pool, ThreadsMixin.pool, ManagerMixin.pool def main(): test_main(unittest.TextTestRunner(verbosity=2).run) |