diff --git a/src/treepuncher/traits/callbacks.py b/src/treepuncher/traits/callbacks.py index 98221f6..f5d644c 100644 --- a/src/treepuncher/traits/callbacks.py +++ b/src/treepuncher/traits/callbacks.py @@ -10,7 +10,7 @@ from ..events.base import BaseEvent class CallbacksHolder: _callbacks : Dict[Any, List[Callable]] - _tasks : Dict[uuid.UUID, asyncio.Event] + _tasks : Dict[uuid.UUID, asyncio.Task] def __init__(self): super().__init__() @@ -34,23 +34,20 @@ class CallbacksHolder: def _wrap(self, cb:Callable, uid:uuid.UUID) -> Callable: async def wrapper(*args): try: - ret = await cb(*args) + return await cb(*args) except Exception: logging.exception("Exception processing callback") - ret = None - self._tasks[uid].set() - self._tasks.pop(uid) - return ret + return None + finally: + self._tasks.pop(uid) return wrapper def run_callbacks(self, key:Any, *args) -> None: for cb in self.trigger(key): task_id = uuid.uuid4() - self._tasks[task_id] = asyncio.Event() - - asyncio.get_event_loop().create_task(self._wrap(cb, task_id)(*args)) + self._tasks[task_id] = asyncio.get_event_loop().create_task(self._wrap(cb, task_id)(*args)) async def join_callbacks(self): - await asyncio.gather(*list(t.wait() for t in self._tasks.values())) + await asyncio.gather(*list(self._tasks.values())) self._tasks.clear() diff --git a/src/treepuncher/treepuncher.py b/src/treepuncher/treepuncher.py index 6a9a029..aa1d25d 100644 --- a/src/treepuncher/treepuncher.py +++ b/src/treepuncher/treepuncher.py @@ -221,9 +221,9 @@ class Treepuncher( await self.dispatcher.disconnect(block=not force) if not force: await self._worker + for m in self.modules: + await m.cleanup() await self.join_callbacks() - for m in self.modules: - await m.cleanup() await super().stop() self.logger.info("Treepuncher stopped")