moved common code into traits, added callbacks, made it cleaner?
This commit is contained in:
parent
c2e8e80806
commit
540dbd5d90
6 changed files with 185 additions and 122 deletions
|
@ -5,19 +5,19 @@ import logging
|
||||||
from .mc.proto.play.clientbound import PacketChat
|
from .mc.proto.play.clientbound import PacketChat
|
||||||
from .mc.token import Token
|
from .mc.token import Token
|
||||||
from .dispatcher import ConnectionState
|
from .dispatcher import ConnectionState
|
||||||
from .client import Client
|
from .client import MinecraftClient
|
||||||
from .server import MinecraftServer
|
from .server import MinecraftServer
|
||||||
from .util.helpers import parse_chat
|
from .util.helpers import parse_chat
|
||||||
|
|
||||||
async def idle():
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
if sys.argv[1] == "--server":
|
if sys.argv[1] == "--server":
|
||||||
serv = MinecraftServer("0.0.0.0", 25565)
|
host = sys.argv[2] if len(sys.argv) > 2 else "localhost"
|
||||||
|
port = sys.argv[3] if len(sys.argv) > 3 else 25565
|
||||||
|
|
||||||
|
serv = MinecraftServer(host, port)
|
||||||
|
|
||||||
serv.run() # will block and start asyncio event loop
|
serv.run() # will block and start asyncio event loop
|
||||||
else:
|
else:
|
||||||
username = sys.argv[1]
|
username = sys.argv[1]
|
||||||
|
@ -33,7 +33,7 @@ if __name__ == "__main__":
|
||||||
host = server.strip()
|
host = server.strip()
|
||||||
port = 25565
|
port = 25565
|
||||||
|
|
||||||
client = Client(host, port, username=username, password=pwd)
|
client = MinecraftClient(host, port, username=username, password=pwd)
|
||||||
|
|
||||||
@client.on_packet(PacketChat, ConnectionState.PLAY)
|
@client.on_packet(PacketChat, ConnectionState.PLAY)
|
||||||
async def print_chat(packet: PacketChat):
|
async def print_chat(packet: PacketChat):
|
||||||
|
|
|
@ -2,12 +2,14 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
from asyncio import Task
|
from asyncio import Task
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator
|
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator
|
||||||
|
|
||||||
from .dispatcher import Dispatcher
|
from .dispatcher import Dispatcher
|
||||||
|
from .traits import CallbacksHolder, Runnable
|
||||||
from .mc.packet import Packet
|
from .mc.packet import Packet
|
||||||
from .mc.token import Token, AuthException
|
from .mc.token import Token, AuthException
|
||||||
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
|
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
|
||||||
|
@ -22,10 +24,21 @@ from .util import encryption
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Client:
|
@dataclass
|
||||||
|
class ClientOptions:
|
||||||
|
reconnect : bool
|
||||||
|
reconnect_delay : float
|
||||||
|
keep_alive : bool
|
||||||
|
poll_interval : float
|
||||||
|
|
||||||
|
class ClientEvent(Enum):
|
||||||
|
CONNECTED = 0
|
||||||
|
DISCONNECTED = 1
|
||||||
|
|
||||||
|
class MinecraftClient(CallbacksHolder, Runnable):
|
||||||
host:str
|
host:str
|
||||||
port:int
|
port:int
|
||||||
options:dict
|
options:ClientOptions
|
||||||
|
|
||||||
username:Optional[str]
|
username:Optional[str]
|
||||||
password:Optional[str]
|
password:Optional[str]
|
||||||
|
@ -37,27 +50,30 @@ class Client:
|
||||||
_worker : Task
|
_worker : Task
|
||||||
_callbacks = Dict[str, Task]
|
_callbacks = Dict[str, Task]
|
||||||
|
|
||||||
_packet_callbacks : Dict[Type[Packet], List[Callable]]
|
|
||||||
_logger : logging.Logger
|
_logger : logging.Logger
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
host:str,
|
host:str,
|
||||||
port:int = 25565,
|
port:int = 25565,
|
||||||
options:dict = None,
|
|
||||||
username:Optional[str] = None,
|
username:Optional[str] = None,
|
||||||
password:Optional[str] = None,
|
password:Optional[str] = None,
|
||||||
token:Optional[Token] = None,
|
token:Optional[Token] = None,
|
||||||
|
reconnect:bool = True,
|
||||||
|
reconnect_delay:float = 10.0,
|
||||||
|
keep_alive:bool = True,
|
||||||
|
poll_interval:float = 1.0,
|
||||||
|
|
||||||
):
|
):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
|
||||||
self.options = options or {
|
self.options = ClientOptions(
|
||||||
"reconnect" : True,
|
reconnect=reconnect,
|
||||||
"rctime" : 5.0,
|
reconnect_delay=reconnect_delay,
|
||||||
"keep-alive" : True,
|
keep_alive=keep_alive,
|
||||||
"poll-timeout" : 1,
|
poll_interval=poll_interval
|
||||||
}
|
)
|
||||||
|
|
||||||
self.token = token
|
self.token = token
|
||||||
self.username = username
|
self.username = username
|
||||||
|
@ -67,9 +83,6 @@ class Client:
|
||||||
self._processing = False
|
self._processing = False
|
||||||
self._authenticated = False
|
self._authenticated = False
|
||||||
|
|
||||||
self._packet_callbacks = {}
|
|
||||||
self._callbacks = {}
|
|
||||||
|
|
||||||
self._logger = LOGGER.getChild(f"{self.host}:{self.port}")
|
self._logger = LOGGER.getChild(f"{self.host}:{self.port}")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -80,23 +93,21 @@ class Client:
|
||||||
def connected(self) -> bool:
|
def connected(self) -> bool:
|
||||||
return self.started and self.dispatcher.connected
|
return self.started and self.dispatcher.connected
|
||||||
|
|
||||||
def _run_async(self, func, pkt:Packet):
|
def on_connected(self) -> Callable:
|
||||||
key = str(uuid.uuid4()) # ugly!
|
|
||||||
|
|
||||||
async def wrapper(packet:Packet):
|
|
||||||
try:
|
|
||||||
await func(packet)
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.error("Exception in callback %s for packet %s | %s", func.__name__, packet, str(e))
|
|
||||||
self._callbacks.pop(key, None)
|
|
||||||
|
|
||||||
self._callbacks[key] = asyncio.get_event_loop().create_task(wrapper(pkt))
|
|
||||||
|
|
||||||
def on_packet(self, packet:Type[Packet], *args) -> Callable: # receive *args for retro compatibility
|
|
||||||
def wrapper(fun):
|
def wrapper(fun):
|
||||||
if packet not in self._packet_callbacks:
|
self.register(ClientEvent.CONNECTED, fun)
|
||||||
self._packet_callbacks[packet] = []
|
return fun
|
||||||
self._packet_callbacks[packet].append(fun)
|
return wrapper
|
||||||
|
|
||||||
|
def on_disconnected(self) -> Callable:
|
||||||
|
def wrapper(fun):
|
||||||
|
self.register(ClientEvent.DISCONNECTED, fun)
|
||||||
|
return fun
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def on_packet(self, packet:Type[Packet]) -> Callable:
|
||||||
|
def wrapper(fun):
|
||||||
|
self.register(packet, fun)
|
||||||
return fun
|
return fun
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
@ -136,39 +147,22 @@ class Client:
|
||||||
if restart:
|
if restart:
|
||||||
await self.start()
|
await self.start()
|
||||||
|
|
||||||
def run(self):
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
loop.run_until_complete(self.start())
|
|
||||||
|
|
||||||
async def idle():
|
|
||||||
while True: # TODO don't busywait even if it doesn't matter much
|
|
||||||
await asyncio.sleep(self.options["poll-timeout"])
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(idle())
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
self._logger.info("Received SIGINT, stopping...")
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(self.stop())
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
self._logger.info("Received SIGINT, stopping for real")
|
|
||||||
loop.run_until_complete(self.stop(wait_tasks=False))
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
if self.started:
|
||||||
|
return
|
||||||
self._processing = True
|
self._processing = True
|
||||||
self._worker = asyncio.get_event_loop().create_task(self._client_worker())
|
self._worker = asyncio.get_event_loop().create_task(self._client_worker())
|
||||||
self._logger.info("Minecraft client started")
|
self._logger.info("Minecraft client started")
|
||||||
|
|
||||||
async def stop(self, block=True, wait_tasks=True):
|
async def stop(self, force:bool=False):
|
||||||
self._processing = False
|
self._processing = False
|
||||||
if self.dispatcher.connected:
|
if self.dispatcher.connected:
|
||||||
await self.dispatcher.disconnect(block=block)
|
await self.dispatcher.disconnect(block=not force)
|
||||||
if block:
|
if not force:
|
||||||
await self._worker
|
await self._worker
|
||||||
self._logger.info("Minecraft client stopped")
|
self._logger.info("Minecraft client stopped")
|
||||||
if block and wait_tasks:
|
if not force:
|
||||||
await asyncio.gather(*list(self._callbacks.values()))
|
await self.join_callbacks()
|
||||||
|
|
||||||
async def _client_worker(self):
|
async def _client_worker(self):
|
||||||
while self._processing:
|
while self._processing:
|
||||||
|
@ -188,9 +182,9 @@ class Client:
|
||||||
self._logger.exception("Exception in Client connection")
|
self._logger.exception("Exception in Client connection")
|
||||||
if self.dispatcher.connected:
|
if self.dispatcher.connected:
|
||||||
await self.dispatcher.disconnect()
|
await self.dispatcher.disconnect()
|
||||||
if not self.options["reconnect"]:
|
if not self.options.reconnect:
|
||||||
break
|
break
|
||||||
await asyncio.sleep(self.options["rctime"])
|
await asyncio.sleep(self.options.reconnect_delay)
|
||||||
await self.stop(block=False)
|
await self.stop(block=False)
|
||||||
|
|
||||||
async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first
|
async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first
|
||||||
|
@ -255,20 +249,20 @@ class Client:
|
||||||
|
|
||||||
async def _play(self) -> bool:
|
async def _play(self) -> bool:
|
||||||
self.dispatcher.state = ConnectionState.PLAY
|
self.dispatcher.state = ConnectionState.PLAY
|
||||||
|
self.run_callbacks(ClientEvent.CONNECTED)
|
||||||
async for packet in self.dispatcher.packets():
|
async for packet in self.dispatcher.packets():
|
||||||
self._logger.debug("[ * ] Processing | %s", str(packet))
|
self._logger.debug("[ * ] Processing | %s", str(packet))
|
||||||
if isinstance(packet, PacketSetCompression):
|
if isinstance(packet, PacketSetCompression):
|
||||||
self._logger.info("Compression updated")
|
self._logger.info("Compression updated")
|
||||||
self.dispatcher.compression = packet.threshold
|
self.dispatcher.compression = packet.threshold
|
||||||
elif isinstance(packet, PacketKeepAlive):
|
elif isinstance(packet, PacketKeepAlive):
|
||||||
if self.options["keep-alive"]:
|
if self.options.keep_alive:
|
||||||
keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId)
|
keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId)
|
||||||
await self.dispatcher.write(keep_alive_packet)
|
await self.dispatcher.write(keep_alive_packet)
|
||||||
elif isinstance(packet, PacketKickDisconnect):
|
elif isinstance(packet, PacketKickDisconnect):
|
||||||
self._logger.error("Kicked while in game")
|
self._logger.error("Kicked while in game")
|
||||||
break
|
break
|
||||||
for packet_type in (Packet, type(packet)): # check both callbacks for base class and instance class
|
self.run_callbacks(Packet, packet)
|
||||||
if packet_type in self._packet_callbacks:
|
self.run_callbacks(type(packet), packet)
|
||||||
for cb in self._packet_callbacks[packet_type]:
|
self.run_callbacks(ClientEvent.DISCONNECTED)
|
||||||
self._run_async(cb, packet)
|
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -2,6 +2,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
from asyncio import Task, StreamReader, StreamWriter
|
from asyncio import Task, StreamReader, StreamWriter
|
||||||
from asyncio.base_events import Server # just for typing
|
from asyncio.base_events import Server # just for typing
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
@ -9,6 +10,7 @@ from enum import Enum
|
||||||
from typing import Dict, List, Callable, Coroutine, Type, Optional, Tuple, AsyncIterator
|
from typing import Dict, List, Callable, Coroutine, Type, Optional, Tuple, AsyncIterator
|
||||||
|
|
||||||
from .dispatcher import Dispatcher
|
from .dispatcher import Dispatcher
|
||||||
|
from .traits import CallbacksHolder, Runnable
|
||||||
from .mc.packet import Packet
|
from .mc.packet import Packet
|
||||||
from .mc.token import Token, AuthException
|
from .mc.token import Token, AuthException
|
||||||
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
|
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
|
||||||
|
@ -23,89 +25,80 @@ from .util import encryption
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
class MinecraftServer:
|
@dataclass
|
||||||
|
class ServerOptions:
|
||||||
|
online_mode : bool
|
||||||
|
spawn_player : bool
|
||||||
|
poll_interval : float
|
||||||
|
|
||||||
|
class ServerEvent(Enum):
|
||||||
|
CLIENT_CONNECTED = 0
|
||||||
|
CLIENT_DISCONNECTED = 1
|
||||||
|
|
||||||
|
class MinecraftServer(CallbacksHolder, Runnable):
|
||||||
host:str
|
host:str
|
||||||
port:int
|
port:int
|
||||||
options:dict
|
options:ServerOptions
|
||||||
|
|
||||||
_dispatcher_pool : List[Dispatcher]
|
_dispatcher_pool : List[Dispatcher]
|
||||||
_processing : bool
|
_processing : bool
|
||||||
_server : Server
|
_server : Server
|
||||||
_worker : Task
|
_worker : Task
|
||||||
_callbacks = Dict[str, Task]
|
|
||||||
|
|
||||||
_logger : logging.Logger
|
_logger : logging.Logger
|
||||||
|
|
||||||
_disconnect_handlers : List[Callable]
|
|
||||||
_connect_handlers : List[Callable]
|
|
||||||
_packet_handlers : List[Callable]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
host:str,
|
host:str,
|
||||||
port:int = 25565,
|
port:int = 25565,
|
||||||
options:dict = None,
|
online_mode:bool = False,
|
||||||
|
spawn_player:bool = True,
|
||||||
|
poll_interval:float = 1.0,
|
||||||
):
|
):
|
||||||
|
super().__init__()
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
|
||||||
self.options = options or {
|
self.options = ServerOptions(
|
||||||
"poll-timeout" : 1,
|
online_mode=online_mode,
|
||||||
"online-mode" : False,
|
spawn_player=spawn_player,
|
||||||
"spawn-player" : True,
|
poll_interval=poll_interval,
|
||||||
}
|
)
|
||||||
|
|
||||||
self._dispatcher_pool = []
|
self._dispatcher_pool = []
|
||||||
self._processing = False
|
self._processing = False
|
||||||
|
|
||||||
self._logger = LOGGER.getChild(f"@({self.host}:{self.port})")
|
self._logger = LOGGER.getChild(f"@({self.host}:{self.port})")
|
||||||
|
|
||||||
self._disconnect_handlers = []
|
|
||||||
self._connect_handlers = []
|
|
||||||
self._packet_handlers = []
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def started(self) -> bool:
|
def started(self) -> bool:
|
||||||
return self._processing
|
return self._processing
|
||||||
|
|
||||||
def on_client_connect(self, *args):
|
@property
|
||||||
|
def connected(self) -> int:
|
||||||
|
return len(self._dispatcher_pool)
|
||||||
|
|
||||||
|
def on_connect(self):
|
||||||
def wrapper(fun):
|
def wrapper(fun):
|
||||||
self._connect_handlers.append(fun)
|
self.register(ServerEvent.CLIENT_CONNECTED, fun)
|
||||||
return fun
|
return fun
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
def on_client_disconnect(self, *args):
|
def on_disconnect(self):
|
||||||
def wrapper(fun):
|
def wrapper(fun):
|
||||||
self._disconnect_handlers.append(fun)
|
self.register(ServerEvent.CLIENT_DISCONNECTED, fun)
|
||||||
return fun
|
return fun
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
def on_client_packet(self, *args):
|
def on_packet(self, packet:Type[Packet]):
|
||||||
def wrapper(fun):
|
def wrapper(fun):
|
||||||
self._packet_handlers.append(fun)
|
self.register(packet, fun)
|
||||||
return fun
|
return fun
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
def run(self):
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
loop.run_until_complete(self.start())
|
|
||||||
|
|
||||||
async def idle():
|
|
||||||
while True: # TODO don't busywait even if it doesn't matter much
|
|
||||||
await asyncio.sleep(self.options["poll-timeout"])
|
|
||||||
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(idle())
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
self._logger.info("Received SIGINT, stopping...")
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(self.stop())
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
self._logger.info("Received SIGINT, stopping for real")
|
|
||||||
loop.run_until_complete(self.stop(wait_tasks=False))
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
if self.started:
|
||||||
|
return
|
||||||
self._server = await asyncio.start_server(
|
self._server = await asyncio.start_server(
|
||||||
self._server_worker, self.host, self.port
|
self._server_worker, self.host, self.port
|
||||||
)
|
)
|
||||||
|
@ -114,14 +107,15 @@ class MinecraftServer:
|
||||||
await self._server.start_serving()
|
await self._server.start_serving()
|
||||||
self._logger.info("Minecraft server started")
|
self._logger.info("Minecraft server started")
|
||||||
|
|
||||||
async def stop(self, block=True, wait_tasks=True):
|
async def stop(self, force:bool = False):
|
||||||
self._processing = False
|
self._processing = False
|
||||||
self._server.close()
|
self._server.close()
|
||||||
await asyncio.gather(*[d.disconnect(block=block) for d in self._dispatcher_pool])
|
await asyncio.gather(*[d.disconnect(block=not force) for d in self._dispatcher_pool])
|
||||||
if block:
|
if not force:
|
||||||
await self._server.wait_closed()
|
await asyncio.gather(
|
||||||
# if block and wait_tasks: # TODO wait for client workers
|
self._server.wait_closed(),
|
||||||
# await asyncio.gather(*list(self._callbacks.values()))
|
self.join_callbacks(),
|
||||||
|
)
|
||||||
|
|
||||||
async def _disconnect_client(self, dispatcher):
|
async def _disconnect_client(self, dispatcher):
|
||||||
if dispatcher.state == ConnectionState.LOGIN:
|
if dispatcher.state == ConnectionState.LOGIN:
|
||||||
|
@ -178,7 +172,7 @@ class MinecraftServer:
|
||||||
self._logger.info("Logging in player")
|
self._logger.info("Logging in player")
|
||||||
async for packet in dispatcher.packets():
|
async for packet in dispatcher.packets():
|
||||||
if isinstance(packet, PacketLoginStart):
|
if isinstance(packet, PacketLoginStart):
|
||||||
if self.options["online-mode"]:
|
if self.options.online_mode:
|
||||||
# await dispatcher.write(
|
# await dispatcher.write(
|
||||||
# PacketEncryptionBegin(
|
# PacketEncryptionBegin(
|
||||||
# dispatcher.proto,
|
# dispatcher.proto,
|
||||||
|
@ -207,7 +201,7 @@ class MinecraftServer:
|
||||||
async def _play(self, dispatcher:Dispatcher) -> bool:
|
async def _play(self, dispatcher:Dispatcher) -> bool:
|
||||||
self._logger.info("Player connected")
|
self._logger.info("Player connected")
|
||||||
|
|
||||||
if self.options["spawn-player"]:
|
if self.options.spawn_player:
|
||||||
await dispatcher.write(
|
await dispatcher.write(
|
||||||
PacketLogin(
|
PacketLogin(
|
||||||
dispatcher.proto,
|
dispatcher.proto,
|
||||||
|
@ -245,14 +239,14 @@ class MinecraftServer:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
await asyncio.gather(*[cb(dispatcher) for cb in self._connect_handlers])
|
self.run_callbacks(ServerEvent.CLIENT_CONNECTED)
|
||||||
|
|
||||||
async for packet in dispatcher.packets():
|
async for packet in dispatcher.packets():
|
||||||
for cb in self._packet_handlers:
|
# TODO handle play
|
||||||
asyncio.get_event_loop().create_task(cb(dispatcher, packet))
|
self.run_callbacks(Packet, packet)
|
||||||
pass # TODO handle play
|
self.run_callbacks(type(packet), packet)
|
||||||
|
|
||||||
await asyncio.gather(*[cb(dispatcher) for cb in self._disconnect_handlers])
|
self.run_callbacks(ServerEvent.CLIENT_DISCONNECTED)
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
2
aiocraft/traits/__init__.py
Normal file
2
aiocraft/traits/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .callbacks import CallbacksHolder
|
||||||
|
from .runnable import Runnable
|
38
aiocraft/traits/callbacks.py
Normal file
38
aiocraft/traits/callbacks.py
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
import asyncio
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from typing import Dict, List, Any, Callable
|
||||||
|
|
||||||
|
class CallbacksHolder:
|
||||||
|
|
||||||
|
_callbacks : Dict[Any, List[Callable]]
|
||||||
|
_tasks : Dict[str, asyncio.Task]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._callbacks = {}
|
||||||
|
|
||||||
|
def register(self, key:Any, callback:Callable):
|
||||||
|
if key not in self._callbacks:
|
||||||
|
self._callbacks[key] = []
|
||||||
|
self._callbacks[key].append(callback)
|
||||||
|
|
||||||
|
def trigger(self, key:Any) -> List[Callable]:
|
||||||
|
if key not in self._callbacks:
|
||||||
|
return []
|
||||||
|
return self._callbacks[key]
|
||||||
|
|
||||||
|
def run_callbacks(self, key:Any, *args) -> None:
|
||||||
|
for cb in self.trigger(key):
|
||||||
|
task_id = str(uuid4())
|
||||||
|
|
||||||
|
async def wrapper(*args):
|
||||||
|
await cb(*args)
|
||||||
|
self._tasks.pop(task_id)
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
self._tasks[task_id] = loop.create_task(wrapper(*args))
|
||||||
|
|
||||||
|
async def join_callbacks(self):
|
||||||
|
await asyncio.gather(*list(self._tasks.values()))
|
||||||
|
self._tasks.clear()
|
||||||
|
|
35
aiocraft/traits/runnable.py
Normal file
35
aiocraft/traits/runnable.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
class Runnable:
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def stop(self, force:bool=False):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
logging.info("Starting")
|
||||||
|
|
||||||
|
loop.run_until_complete(self.start())
|
||||||
|
|
||||||
|
async def idle():
|
||||||
|
never = asyncio.Event()
|
||||||
|
logging.info("Idling")
|
||||||
|
await never.wait()
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(idle())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logging.info("Received SIGINT, stopping...")
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(self.stop())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logging.info("Received SIGINT, stopping for real")
|
||||||
|
loop.run_until_complete(self.stop(force=True))
|
||||||
|
|
||||||
|
logging.info("Done")
|
||||||
|
|
Loading…
Reference in a new issue