aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/Lib/asyncio/locks.py
diff options
context:
space:
mode:
authorDuprat <yduprat@gmail.com>2022-03-25 23:01:21 +0100
committerGitHub <noreply@github.com>2022-03-26 00:01:21 +0200
commitd03acd7270d66ddb8e987f9743405147ecc15087 (patch)
treecffe25f0c26d55aef28c910dcf825747da99a6d4 /Lib/asyncio/locks.py
parent20e6e5636a06fe5e1472062918d0a302d82a71c3 (diff)
downloadcpython-d03acd7270d66ddb8e987f9743405147ecc15087.tar.gz
cpython-d03acd7270d66ddb8e987f9743405147ecc15087.zip
bpo-43352: Add a Barrier object in asyncio lib (GH-24903)
Co-authored-by: Yury Selivanov <yury@edgedb.com> Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
Diffstat (limited to 'Lib/asyncio/locks.py')
-rw-r--r--Lib/asyncio/locks.py157
1 files changed, 155 insertions, 2 deletions
diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py
index 9b4612197de..e71130274dd 100644
--- a/Lib/asyncio/locks.py
+++ b/Lib/asyncio/locks.py
@@ -1,14 +1,15 @@
"""Synchronization primitives."""
-__all__ = ('Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore')
+__all__ = ('Lock', 'Event', 'Condition', 'Semaphore',
+ 'BoundedSemaphore', 'Barrier')
import collections
+import enum
from . import exceptions
from . import mixins
from . import tasks
-
class _ContextManagerMixin:
async def __aenter__(self):
await self.acquire()
@@ -416,3 +417,155 @@ class BoundedSemaphore(Semaphore):
if self._value >= self._bound_value:
raise ValueError('BoundedSemaphore released too many times')
super().release()
+
+
+
+class _BarrierState(enum.Enum):
+ FILLING = 'filling'
+ DRAINING = 'draining'
+ RESETTING = 'resetting'
+ BROKEN = 'broken'
+
+
+class Barrier(mixins._LoopBoundMixin):
+ """Asyncio equivalent to threading.Barrier
+
+ Implements a Barrier primitive.
+ Useful for synchronizing a fixed number of tasks at known synchronization
+ points. Tasks block on 'wait()' and are simultaneously awoken once they
+ have all made their call.
+ """
+
+ def __init__(self, parties):
+ """Create a barrier, initialised to 'parties' tasks."""
+ if parties < 1:
+ raise ValueError('parties must be > 0')
+
+ self._cond = Condition() # notify all tasks when state changes
+
+ self._parties = parties
+ self._state = _BarrierState.FILLING
+ self._count = 0 # count tasks in Barrier
+
+ def __repr__(self):
+ res = super().__repr__()
+ extra = f'{self._state.value}'
+ if not self.broken:
+ extra += f', waiters:{self.n_waiting}/{self.parties}'
+ return f'<{res[1:-1]} [{extra}]>'
+
+ async def __aenter__(self):
+ # wait for the barrier reaches the parties number
+ # when start draining release and return index of waited task
+ return await self.wait()
+
+ async def __aexit__(self, *args):
+ pass
+
+ async def wait(self):
+ """Wait for the barrier.
+
+ When the specified number of tasks have started waiting, they are all
+ simultaneously awoken.
+ Returns an unique and individual index number from 0 to 'parties-1'.
+ """
+ async with self._cond:
+ await self._block() # Block while the barrier drains or resets.
+ try:
+ index = self._count
+ self._count += 1
+ if index + 1 == self._parties:
+ # We release the barrier
+ await self._release()
+ else:
+ await self._wait()
+ return index
+ finally:
+ self._count -= 1
+ # Wake up any tasks waiting for barrier to drain.
+ self._exit()
+
+ async def _block(self):
+ # Block until the barrier is ready for us,
+ # or raise an exception if it is broken.
+ #
+ # It is draining or resetting, wait until done
+ # unless a CancelledError occurs
+ await self._cond.wait_for(
+ lambda: self._state not in (
+ _BarrierState.DRAINING, _BarrierState.RESETTING
+ )
+ )
+
+ # see if the barrier is in a broken state
+ if self._state is _BarrierState.BROKEN:
+ raise exceptions.BrokenBarrierError("Barrier aborted")
+
+ async def _release(self):
+ # Release the tasks waiting in the barrier.
+
+ # Enter draining state.
+ # Next waiting tasks will be blocked until the end of draining.
+ self._state = _BarrierState.DRAINING
+ self._cond.notify_all()
+
+ async def _wait(self):
+ # Wait in the barrier until we are released. Raise an exception
+ # if the barrier is reset or broken.
+
+ # wait for end of filling
+ # unless a CancelledError occurs
+ await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)
+
+ if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
+ raise exceptions.BrokenBarrierError("Abort or reset of barrier")
+
+ def _exit(self):
+ # If we are the last tasks to exit the barrier, signal any tasks
+ # waiting for the barrier to drain.
+ if self._count == 0:
+ if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
+ self._state = _BarrierState.FILLING
+ self._cond.notify_all()
+
+ async def reset(self):
+ """Reset the barrier to the initial state.
+
+ Any tasks currently waiting will get the BrokenBarrier exception
+ raised.
+ """
+ async with self._cond:
+ if self._count > 0:
+ if self._state is not _BarrierState.RESETTING:
+ #reset the barrier, waking up tasks
+ self._state = _BarrierState.RESETTING
+ else:
+ self._state = _BarrierState.FILLING
+ self._cond.notify_all()
+
+ async def abort(self):
+ """Place the barrier into a 'broken' state.
+
+ Useful in case of error. Any currently waiting tasks and tasks
+ attempting to 'wait()' will have BrokenBarrierError raised.
+ """
+ async with self._cond:
+ self._state = _BarrierState.BROKEN
+ self._cond.notify_all()
+
+ @property
+ def parties(self):
+ """Return the number of tasks required to trip the barrier."""
+ return self._parties
+
+ @property
+ def n_waiting(self):
+ """Return the number of tasks currently waiting at the barrier."""
+ if self._state is _BarrierState.FILLING:
+ return self._count
+ return 0
+
+ @property
+ def broken(self):
+ """Return True if the barrier is in a broken state."""
+ return self._state is _BarrierState.BROKEN