From 540dbd5d90a787dd41545359119f0083ae8fcc26 Mon Sep 17 00:00:00 2001 From: alemidev Date: Mon, 22 Nov 2021 02:27:42 +0100 Subject: [PATCH] moved common code into traits, added callbacks, made it cleaner? --- aiocraft/__main__.py | 14 ++--- aiocraft/client.py | 114 +++++++++++++++++------------------ aiocraft/server.py | 104 +++++++++++++++----------------- aiocraft/traits/__init__.py | 2 + aiocraft/traits/callbacks.py | 38 ++++++++++++ aiocraft/traits/runnable.py | 35 +++++++++++ 6 files changed, 185 insertions(+), 122 deletions(-) create mode 100644 aiocraft/traits/__init__.py create mode 100644 aiocraft/traits/callbacks.py create mode 100644 aiocraft/traits/runnable.py diff --git a/aiocraft/__main__.py b/aiocraft/__main__.py index cc0b512..184eb77 100644 --- a/aiocraft/__main__.py +++ b/aiocraft/__main__.py @@ -5,19 +5,19 @@ import logging from .mc.proto.play.clientbound import PacketChat from .mc.token import Token from .dispatcher import ConnectionState -from .client import Client +from .client import MinecraftClient from .server import MinecraftServer from .util.helpers import parse_chat -async def idle(): - while True: - await asyncio.sleep(1) - if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) if sys.argv[1] == "--server": - serv = MinecraftServer("0.0.0.0", 25565) + host = sys.argv[2] if len(sys.argv) > 2 else "localhost" + port = sys.argv[3] if len(sys.argv) > 3 else 25565 + + serv = MinecraftServer(host, port) + serv.run() # will block and start asyncio event loop else: username = sys.argv[1] @@ -33,7 +33,7 @@ if __name__ == "__main__": host = server.strip() port = 25565 - client = Client(host, port, username=username, password=pwd) + client = MinecraftClient(host, port, username=username, password=pwd) @client.on_packet(PacketChat, ConnectionState.PLAY) async def print_chat(packet: PacketChat): diff --git a/aiocraft/client.py b/aiocraft/client.py index 1034b30..dd2b76f 100644 --- a/aiocraft/client.py +++ b/aiocraft/client.py @@ -2,12 +2,14 @@ import asyncio import logging import uuid +from dataclasses import dataclass from asyncio import Task from enum import Enum from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator from .dispatcher import Dispatcher +from .traits import CallbacksHolder, Runnable from .mc.packet import Packet from .mc.token import Token, AuthException from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState @@ -22,10 +24,21 @@ from .util import encryption LOGGER = logging.getLogger(__name__) -class Client: +@dataclass +class ClientOptions: + reconnect : bool + reconnect_delay : float + keep_alive : bool + poll_interval : float + +class ClientEvent(Enum): + CONNECTED = 0 + DISCONNECTED = 1 + +class MinecraftClient(CallbacksHolder, Runnable): host:str port:int - options:dict + options:ClientOptions username:Optional[str] password:Optional[str] @@ -37,27 +50,30 @@ class Client: _worker : Task _callbacks = Dict[str, Task] - _packet_callbacks : Dict[Type[Packet], List[Callable]] _logger : logging.Logger def __init__( self, host:str, port:int = 25565, - options:dict = None, username:Optional[str] = None, password:Optional[str] = None, token:Optional[Token] = None, + reconnect:bool = True, + reconnect_delay:float = 10.0, + keep_alive:bool = True, + poll_interval:float = 1.0, + ): self.host = host self.port = port - self.options = options or { - "reconnect" : True, - "rctime" : 5.0, - "keep-alive" : True, - "poll-timeout" : 1, - } + self.options = ClientOptions( + reconnect=reconnect, + reconnect_delay=reconnect_delay, + keep_alive=keep_alive, + poll_interval=poll_interval + ) self.token = token self.username = username @@ -67,9 +83,6 @@ class Client: self._processing = False self._authenticated = False - self._packet_callbacks = {} - self._callbacks = {} - self._logger = LOGGER.getChild(f"{self.host}:{self.port}") @property @@ -80,23 +93,21 @@ class Client: def connected(self) -> bool: return self.started and self.dispatcher.connected - def _run_async(self, func, pkt:Packet): - key = str(uuid.uuid4()) # ugly! - - async def wrapper(packet:Packet): - try: - await func(packet) - except Exception as e: - self._logger.error("Exception in callback %s for packet %s | %s", func.__name__, packet, str(e)) - self._callbacks.pop(key, None) - - self._callbacks[key] = asyncio.get_event_loop().create_task(wrapper(pkt)) - - def on_packet(self, packet:Type[Packet], *args) -> Callable: # receive *args for retro compatibility + def on_connected(self) -> Callable: def wrapper(fun): - if packet not in self._packet_callbacks: - self._packet_callbacks[packet] = [] - self._packet_callbacks[packet].append(fun) + self.register(ClientEvent.CONNECTED, fun) + return fun + return wrapper + + def on_disconnected(self) -> Callable: + def wrapper(fun): + self.register(ClientEvent.DISCONNECTED, fun) + return fun + return wrapper + + def on_packet(self, packet:Type[Packet]) -> Callable: + def wrapper(fun): + self.register(packet, fun) return fun return wrapper @@ -136,39 +147,22 @@ class Client: if restart: await self.start() - def run(self): - loop = asyncio.get_event_loop() - - loop.run_until_complete(self.start()) - - async def idle(): - while True: # TODO don't busywait even if it doesn't matter much - await asyncio.sleep(self.options["poll-timeout"]) - - try: - loop.run_until_complete(idle()) - except KeyboardInterrupt: - self._logger.info("Received SIGINT, stopping...") - try: - loop.run_until_complete(self.stop()) - except KeyboardInterrupt: - self._logger.info("Received SIGINT, stopping for real") - loop.run_until_complete(self.stop(wait_tasks=False)) - async def start(self): + if self.started: + return self._processing = True self._worker = asyncio.get_event_loop().create_task(self._client_worker()) self._logger.info("Minecraft client started") - async def stop(self, block=True, wait_tasks=True): + async def stop(self, force:bool=False): self._processing = False if self.dispatcher.connected: - await self.dispatcher.disconnect(block=block) - if block: + await self.dispatcher.disconnect(block=not force) + if not force: await self._worker self._logger.info("Minecraft client stopped") - if block and wait_tasks: - await asyncio.gather(*list(self._callbacks.values())) + if not force: + await self.join_callbacks() async def _client_worker(self): while self._processing: @@ -188,9 +182,9 @@ class Client: self._logger.exception("Exception in Client connection") if self.dispatcher.connected: await self.dispatcher.disconnect() - if not self.options["reconnect"]: + if not self.options.reconnect: break - await asyncio.sleep(self.options["rctime"]) + await asyncio.sleep(self.options.reconnect_delay) await self.stop(block=False) async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first @@ -255,20 +249,20 @@ class Client: async def _play(self) -> bool: self.dispatcher.state = ConnectionState.PLAY + self.run_callbacks(ClientEvent.CONNECTED) async for packet in self.dispatcher.packets(): self._logger.debug("[ * ] Processing | %s", str(packet)) if isinstance(packet, PacketSetCompression): self._logger.info("Compression updated") self.dispatcher.compression = packet.threshold elif isinstance(packet, PacketKeepAlive): - if self.options["keep-alive"]: + if self.options.keep_alive: keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId) await self.dispatcher.write(keep_alive_packet) elif isinstance(packet, PacketKickDisconnect): self._logger.error("Kicked while in game") break - for packet_type in (Packet, type(packet)): # check both callbacks for base class and instance class - if packet_type in self._packet_callbacks: - for cb in self._packet_callbacks[packet_type]: - self._run_async(cb, packet) + self.run_callbacks(Packet, packet) + self.run_callbacks(type(packet), packet) + self.run_callbacks(ClientEvent.DISCONNECTED) return False diff --git a/aiocraft/server.py b/aiocraft/server.py index c75beb1..bbc7312 100644 --- a/aiocraft/server.py +++ b/aiocraft/server.py @@ -2,6 +2,7 @@ import asyncio import logging import uuid +from dataclasses import dataclass from asyncio import Task, StreamReader, StreamWriter from asyncio.base_events import Server # just for typing from enum import Enum @@ -9,6 +10,7 @@ from enum import Enum from typing import Dict, List, Callable, Coroutine, Type, Optional, Tuple, AsyncIterator from .dispatcher import Dispatcher +from .traits import CallbacksHolder, Runnable from .mc.packet import Packet from .mc.token import Token, AuthException from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState @@ -23,89 +25,80 @@ from .util import encryption LOGGER = logging.getLogger(__name__) -class MinecraftServer: +@dataclass +class ServerOptions: + online_mode : bool + spawn_player : bool + poll_interval : float + +class ServerEvent(Enum): + CLIENT_CONNECTED = 0 + CLIENT_DISCONNECTED = 1 + +class MinecraftServer(CallbacksHolder, Runnable): host:str port:int - options:dict + options:ServerOptions _dispatcher_pool : List[Dispatcher] _processing : bool _server : Server _worker : Task - _callbacks = Dict[str, Task] _logger : logging.Logger - _disconnect_handlers : List[Callable] - _connect_handlers : List[Callable] - _packet_handlers : List[Callable] - def __init__( self, host:str, port:int = 25565, - options:dict = None, + online_mode:bool = False, + spawn_player:bool = True, + poll_interval:float = 1.0, ): + super().__init__() self.host = host self.port = port - self.options = options or { - "poll-timeout" : 1, - "online-mode" : False, - "spawn-player" : True, - } + self.options = ServerOptions( + online_mode=online_mode, + spawn_player=spawn_player, + poll_interval=poll_interval, + ) self._dispatcher_pool = [] self._processing = False self._logger = LOGGER.getChild(f"@({self.host}:{self.port})") - self._disconnect_handlers = [] - self._connect_handlers = [] - self._packet_handlers = [] - @property def started(self) -> bool: return self._processing - def on_client_connect(self, *args): + @property + def connected(self) -> int: + return len(self._dispatcher_pool) + + def on_connect(self): def wrapper(fun): - self._connect_handlers.append(fun) + self.register(ServerEvent.CLIENT_CONNECTED, fun) return fun return wrapper - def on_client_disconnect(self, *args): + def on_disconnect(self): def wrapper(fun): - self._disconnect_handlers.append(fun) + self.register(ServerEvent.CLIENT_DISCONNECTED, fun) return fun return wrapper - def on_client_packet(self, *args): + def on_packet(self, packet:Type[Packet]): def wrapper(fun): - self._packet_handlers.append(fun) + self.register(packet, fun) return fun return wrapper - def run(self): - loop = asyncio.get_event_loop() - - loop.run_until_complete(self.start()) - - async def idle(): - while True: # TODO don't busywait even if it doesn't matter much - await asyncio.sleep(self.options["poll-timeout"]) - - try: - loop.run_until_complete(idle()) - except KeyboardInterrupt: - self._logger.info("Received SIGINT, stopping...") - try: - loop.run_until_complete(self.stop()) - except KeyboardInterrupt: - self._logger.info("Received SIGINT, stopping for real") - loop.run_until_complete(self.stop(wait_tasks=False)) - async def start(self): + if self.started: + return self._server = await asyncio.start_server( self._server_worker, self.host, self.port ) @@ -114,14 +107,15 @@ class MinecraftServer: await self._server.start_serving() self._logger.info("Minecraft server started") - async def stop(self, block=True, wait_tasks=True): + async def stop(self, force:bool = False): self._processing = False self._server.close() - await asyncio.gather(*[d.disconnect(block=block) for d in self._dispatcher_pool]) - if block: - await self._server.wait_closed() - # if block and wait_tasks: # TODO wait for client workers - # await asyncio.gather(*list(self._callbacks.values())) + await asyncio.gather(*[d.disconnect(block=not force) for d in self._dispatcher_pool]) + if not force: + await asyncio.gather( + self._server.wait_closed(), + self.join_callbacks(), + ) async def _disconnect_client(self, dispatcher): if dispatcher.state == ConnectionState.LOGIN: @@ -178,7 +172,7 @@ class MinecraftServer: self._logger.info("Logging in player") async for packet in dispatcher.packets(): if isinstance(packet, PacketLoginStart): - if self.options["online-mode"]: + if self.options.online_mode: # await dispatcher.write( # PacketEncryptionBegin( # dispatcher.proto, @@ -207,7 +201,7 @@ class MinecraftServer: async def _play(self, dispatcher:Dispatcher) -> bool: self._logger.info("Player connected") - if self.options["spawn-player"]: + if self.options.spawn_player: await dispatcher.write( PacketLogin( dispatcher.proto, @@ -245,14 +239,14 @@ class MinecraftServer: ) ) - await asyncio.gather(*[cb(dispatcher) for cb in self._connect_handlers]) + self.run_callbacks(ServerEvent.CLIENT_CONNECTED) async for packet in dispatcher.packets(): - for cb in self._packet_handlers: - asyncio.get_event_loop().create_task(cb(dispatcher, packet)) - pass # TODO handle play + # TODO handle play + self.run_callbacks(Packet, packet) + self.run_callbacks(type(packet), packet) - await asyncio.gather(*[cb(dispatcher) for cb in self._disconnect_handlers]) + self.run_callbacks(ServerEvent.CLIENT_DISCONNECTED) return False diff --git a/aiocraft/traits/__init__.py b/aiocraft/traits/__init__.py new file mode 100644 index 0000000..950994c --- /dev/null +++ b/aiocraft/traits/__init__.py @@ -0,0 +1,2 @@ +from .callbacks import CallbacksHolder +from .runnable import Runnable diff --git a/aiocraft/traits/callbacks.py b/aiocraft/traits/callbacks.py new file mode 100644 index 0000000..5bbf606 --- /dev/null +++ b/aiocraft/traits/callbacks.py @@ -0,0 +1,38 @@ +import asyncio +from uuid import uuid4 + +from typing import Dict, List, Any, Callable + +class CallbacksHolder: + + _callbacks : Dict[Any, List[Callable]] + _tasks : Dict[str, asyncio.Task] + + def __init__(self): + self._callbacks = {} + + def register(self, key:Any, callback:Callable): + if key not in self._callbacks: + self._callbacks[key] = [] + self._callbacks[key].append(callback) + + def trigger(self, key:Any) -> List[Callable]: + if key not in self._callbacks: + return [] + return self._callbacks[key] + + def run_callbacks(self, key:Any, *args) -> None: + for cb in self.trigger(key): + task_id = str(uuid4()) + + async def wrapper(*args): + await cb(*args) + self._tasks.pop(task_id) + + loop = asyncio.get_event_loop() + self._tasks[task_id] = loop.create_task(wrapper(*args)) + + async def join_callbacks(self): + await asyncio.gather(*list(self._tasks.values())) + self._tasks.clear() + diff --git a/aiocraft/traits/runnable.py b/aiocraft/traits/runnable.py new file mode 100644 index 0000000..91e6239 --- /dev/null +++ b/aiocraft/traits/runnable.py @@ -0,0 +1,35 @@ +import asyncio +import logging + +class Runnable: + + async def start(self): + raise NotImplementedError + + async def stop(self, force:bool=False): + raise NotImplementedError + + def run(self): + loop = asyncio.get_event_loop() + + logging.info("Starting") + + loop.run_until_complete(self.start()) + + async def idle(): + never = asyncio.Event() + logging.info("Idling") + await never.wait() + + try: + loop.run_until_complete(idle()) + except KeyboardInterrupt: + logging.info("Received SIGINT, stopping...") + try: + loop.run_until_complete(self.stop()) + except KeyboardInterrupt: + logging.info("Received SIGINT, stopping for real") + loop.run_until_complete(self.stop(force=True)) + + logging.info("Done") +