diff options
author | Duprat <yduprat@gmail.com> | 2022-03-25 23:01:21 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-03-26 00:01:21 +0200 |
commit | d03acd7270d66ddb8e987f9743405147ecc15087 (patch) | |
tree | cffe25f0c26d55aef28c910dcf825747da99a6d4 /Lib/asyncio/locks.py | |
parent | 20e6e5636a06fe5e1472062918d0a302d82a71c3 (diff) | |
download | cpython-d03acd7270d66ddb8e987f9743405147ecc15087.tar.gz cpython-d03acd7270d66ddb8e987f9743405147ecc15087.zip |
bpo-43352: Add a Barrier object in asyncio lib (GH-24903)
Co-authored-by: Yury Selivanov <yury@edgedb.com>
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
Diffstat (limited to 'Lib/asyncio/locks.py')
-rw-r--r-- | Lib/asyncio/locks.py | 157 |
1 files changed, 155 insertions, 2 deletions
diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py index 9b4612197de..e71130274dd 100644 --- a/Lib/asyncio/locks.py +++ b/Lib/asyncio/locks.py @@ -1,14 +1,15 @@ """Synchronization primitives.""" -__all__ = ('Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore') +__all__ = ('Lock', 'Event', 'Condition', 'Semaphore', + 'BoundedSemaphore', 'Barrier') import collections +import enum from . import exceptions from . import mixins from . import tasks - class _ContextManagerMixin: async def __aenter__(self): await self.acquire() @@ -416,3 +417,155 @@ class BoundedSemaphore(Semaphore): if self._value >= self._bound_value: raise ValueError('BoundedSemaphore released too many times') super().release() + + + +class _BarrierState(enum.Enum): + FILLING = 'filling' + DRAINING = 'draining' + RESETTING = 'resetting' + BROKEN = 'broken' + + +class Barrier(mixins._LoopBoundMixin): + """Asyncio equivalent to threading.Barrier + + Implements a Barrier primitive. + Useful for synchronizing a fixed number of tasks at known synchronization + points. Tasks block on 'wait()' and are simultaneously awoken once they + have all made their call. + """ + + def __init__(self, parties): + """Create a barrier, initialised to 'parties' tasks.""" + if parties < 1: + raise ValueError('parties must be > 0') + + self._cond = Condition() # notify all tasks when state changes + + self._parties = parties + self._state = _BarrierState.FILLING + self._count = 0 # count tasks in Barrier + + def __repr__(self): + res = super().__repr__() + extra = f'{self._state.value}' + if not self.broken: + extra += f', waiters:{self.n_waiting}/{self.parties}' + return f'<{res[1:-1]} [{extra}]>' + + async def __aenter__(self): + # wait for the barrier reaches the parties number + # when start draining release and return index of waited task + return await self.wait() + + async def __aexit__(self, *args): + pass + + async def wait(self): + """Wait for the barrier. + + When the specified number of tasks have started waiting, they are all + simultaneously awoken. + Returns an unique and individual index number from 0 to 'parties-1'. + """ + async with self._cond: + await self._block() # Block while the barrier drains or resets. + try: + index = self._count + self._count += 1 + if index + 1 == self._parties: + # We release the barrier + await self._release() + else: + await self._wait() + return index + finally: + self._count -= 1 + # Wake up any tasks waiting for barrier to drain. + self._exit() + + async def _block(self): + # Block until the barrier is ready for us, + # or raise an exception if it is broken. + # + # It is draining or resetting, wait until done + # unless a CancelledError occurs + await self._cond.wait_for( + lambda: self._state not in ( + _BarrierState.DRAINING, _BarrierState.RESETTING + ) + ) + + # see if the barrier is in a broken state + if self._state is _BarrierState.BROKEN: + raise exceptions.BrokenBarrierError("Barrier aborted") + + async def _release(self): + # Release the tasks waiting in the barrier. + + # Enter draining state. + # Next waiting tasks will be blocked until the end of draining. + self._state = _BarrierState.DRAINING + self._cond.notify_all() + + async def _wait(self): + # Wait in the barrier until we are released. Raise an exception + # if the barrier is reset or broken. + + # wait for end of filling + # unless a CancelledError occurs + await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING) + + if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING): + raise exceptions.BrokenBarrierError("Abort or reset of barrier") + + def _exit(self): + # If we are the last tasks to exit the barrier, signal any tasks + # waiting for the barrier to drain. + if self._count == 0: + if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING): + self._state = _BarrierState.FILLING + self._cond.notify_all() + + async def reset(self): + """Reset the barrier to the initial state. + + Any tasks currently waiting will get the BrokenBarrier exception + raised. + """ + async with self._cond: + if self._count > 0: + if self._state is not _BarrierState.RESETTING: + #reset the barrier, waking up tasks + self._state = _BarrierState.RESETTING + else: + self._state = _BarrierState.FILLING + self._cond.notify_all() + + async def abort(self): + """Place the barrier into a 'broken' state. + + Useful in case of error. Any currently waiting tasks and tasks + attempting to 'wait()' will have BrokenBarrierError raised. + """ + async with self._cond: + self._state = _BarrierState.BROKEN + self._cond.notify_all() + + @property + def parties(self): + """Return the number of tasks required to trip the barrier.""" + return self._parties + + @property + def n_waiting(self): + """Return the number of tasks currently waiting at the barrier.""" + if self._state is _BarrierState.FILLING: + return self._count + return 0 + + @property + def broken(self): + """Return True if the barrier is in a broken state.""" + return self._state is _BarrierState.BROKEN |