of course it was a race condition

This commit is contained in:
əlemi 2021-11-22 03:05:15 +01:00
parent d82ba7a975
commit b4d93b3b2f

View file

@ -6,7 +6,7 @@ from typing import Dict, List, Any, Callable
class CallbacksHolder: class CallbacksHolder:
_callbacks : Dict[Any, List[Callable]] _callbacks : Dict[Any, List[Callable]]
_tasks : Dict[str, asyncio.Task] _tasks : Dict[str, asyncio.Event]
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -26,15 +26,16 @@ class CallbacksHolder:
def run_callbacks(self, key:Any, *args) -> None: def run_callbacks(self, key:Any, *args) -> None:
for cb in self.trigger(key): for cb in self.trigger(key):
task_id = str(uuid4()) task_id = str(uuid4())
self._tasks[task_id] = asyncio.Event()
async def wrapper(*args): async def wrapper(*args):
await cb(*args) await cb(*args)
self._tasks[task_id].set()
self._tasks.pop(task_id) self._tasks.pop(task_id)
loop = asyncio.get_event_loop() asyncio.get_event_loop().create_task(wrapper(*args))
self._tasks[task_id] = loop.create_task(wrapper(*args))
async def join_callbacks(self): async def join_callbacks(self):
await asyncio.gather(*list(self._tasks.values())) await asyncio.gather(*list(t.wait() for t in self._tasks.values()))
self._tasks.clear() self._tasks.clear()