From d3587f65aed6fb2b4c49d64b83176ce2ae62ccbe Mon Sep 17 00:00:00 2001 From: alemidev Date: Tue, 8 Mar 2022 01:39:03 +0100 Subject: [PATCH] moved features out of aiocraft, cleaned up client states routines --- src/aiocraft/client.py | 190 +++++++------------------------ src/aiocraft/dispatcher.py | 35 +++--- src/aiocraft/traits/__init__.py | 2 - src/aiocraft/traits/callbacks.py | 56 --------- src/aiocraft/traits/runnable.py | 45 -------- 5 files changed, 59 insertions(+), 269 deletions(-) delete mode 100644 src/aiocraft/traits/__init__.py delete mode 100644 src/aiocraft/traits/callbacks.py delete mode 100644 src/aiocraft/traits/runnable.py diff --git a/src/aiocraft/client.py b/src/aiocraft/client.py index 425b2d2..65042af 100644 --- a/src/aiocraft/client.py +++ b/src/aiocraft/client.py @@ -8,10 +8,9 @@ from asyncio import Task from enum import Enum from time import time -from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator, Any +from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator, Any, Set from .dispatcher import Dispatcher -from .traits import CallbacksHolder, Runnable from .mc.packet import Packet from .mc.auth import AuthInterface, AuthException, MojangToken, MicrosoftAuthenticator from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState @@ -36,11 +35,7 @@ class ClientOptions: poll_interval : float = 1.0 use_packet_whitelist : bool = True -class ClientEvent(Enum): - CONNECTED = 0 - DISCONNECTED = 1 - -class MinecraftClient(CallbacksHolder, Runnable): +class MinecraftClient: host:str port:int options:ClientOptions @@ -53,7 +48,6 @@ class MinecraftClient(CallbacksHolder, Runnable): _processing : bool _authenticated : bool _worker : Task - _callbacks = Dict[str, Task] _logger : logging.Logger @@ -90,35 +84,13 @@ class MinecraftClient(CallbacksHolder, Runnable): self._logger = LOGGER.getChild(f"on({self.host}:{self.port})") - @property - def started(self) -> bool: - return self._processing - @property def connected(self) -> bool: - return self.started and self.dispatcher.connected + return self.dispatcher.connected async def write(self, packet:Packet, wait:bool=False): await self.dispatcher.write(packet, wait) - def on_connected(self) -> Callable: - def wrapper(fun): - self.register(ClientEvent.CONNECTED, fun) - return fun - return wrapper - - def on_disconnected(self) -> Callable: - def wrapper(fun): - self.register(ClientEvent.DISCONNECTED, fun) - return fun - return wrapper - - def on_packet(self, packet:Type[Packet]) -> Callable: - def wrapper(fun): - self.register(packet, fun) - return fun - return wrapper - async def authenticate(self): if self._authenticated: return # Don't spam Auth endpoint! @@ -136,60 +108,49 @@ class MinecraftClient(CallbacksHolder, Runnable): raise ValueError("No refreshable auth or code to login") self._authenticated = True - async def change_server(self, server:str): - restart = self.started - if restart: - await self.stop() - - if ":" in server: - _host, _port = server.split(":", 1) - self.host = _host.strip() - self.port = int(_port) - else: - self.host = server.strip() - self.port = 25565 - self._logger = LOGGER.getChild(f"{self.host}:{self.port}") - - if restart: - await self.start() - - async def start(self): - await super().start() - if self.started: - return - self._processing = True - self._worker = asyncio.get_event_loop().create_task(self._client_worker()) - self._logger.info("Minecraft client started") - - async def stop(self, force:bool=False): - self._processing = False - if self.dispatcher.connected: - await self.dispatcher.disconnect(block=not force) - if not force: - await self._worker - self._logger.info("Minecraft client stopped") - if not force: - await self.join_callbacks() - await super().stop(force) - - async def info(self, host:str="", port:int=0, ping:bool=False) -> Dict[str, Any]: + async def info(self, host:str="", port:int=0, proto:int=0, ping:bool=False) -> Dict[str, Any]: """Make a mini connection to asses server status and version""" - await self.dispatcher.connect( - host or self.host, - port or self.port, - ) - #Handshake + self.host = host or self.host + self.port = port or self.port + try: + await self.dispatcher.connect(self.host, self.port) + await self._handshake(ConnectionState.STATUS) + return await self._status(ping) + finally: + await self.dispatcher.disconnect() + + async def join(self, host:str="", port:int=0, proto:int=0, packet_whitelist:Optional[Set[Type[Packet]]]=None): # jank packet_whitelist argument! TODO + self.host = host or self.host + self.port = port or self.port + if self.online_mode: + await self.authenticate() + try: + await self.dispatcher.connect( + host=self.host, + port=self.port, + proto=proto, + queue_timeout=self.options.poll_interval, + packet_whitelist=packet_whitelist + ) + await self._handshake(ConnectionState.LOGIN) + if await self._login(): + await self._play() + finally: + await self.dispatcher.disconnect() + + async def _handshake(self, state:ConnectionState): await self.dispatcher.write( PacketSetProtocol( self.dispatcher.proto, protocolVersion=self.dispatcher.proto, - serverHost=host or self.host, - serverPort=port or self.port, - nextState=ConnectionState.STATUS.value, + serverHost=self.host, + serverPort=self.port, + nextState=state.value ) ) + + async def _status(self, ping:bool=False) -> Dict[str, Any]: self.dispatcher.state = ConnectionState.STATUS - #Request await self.dispatcher.write( PacketPingStart(self.dispatcher.proto) #empty packet ) @@ -215,80 +176,16 @@ class MinecraftClient(CallbacksHolder, Runnable): if packet.time == ping_id: data['ping'] = int(time() - ping_time) break - await self.dispatcher.disconnect() return data - async def _client_worker(self): - try: - self._logger.info("Pinging server") - server_data = await self.info() - self._logger.info( - "Connecting to: %s (%d/%d)", - server_data['version']['name'], - server_data['players']['online'], - server_data['players']['max'] - ) - except Exception: - self._logger.exception("Exception checking server stats") - return - while self._processing: - if self.online_mode: - try: - await self.authenticate() - except AuthException as e: - self._logger.error(str(e)) - break - except Exception as e: - self._logger.exception("Unexpected error while authenticating") - break - try: - packet_whitelist = self.callback_keys(filter=Packet) if self.options.use_packet_whitelist else set() - await self.dispatcher.connect( - host=self.host, - port=self.port, - proto=server_data['version']['protocol'], - queue_timeout=self.options.poll_interval, - packet_whitelist=packet_whitelist - ) - self.dispatcher.proto = server_data['version']['protocol'] # TODO maybe check if it's supported? - await self._handshake() - if await self._login(): - await self._play() - except ConnectionRefusedError: - self._logger.error("Server rejected connection") - except OSError as e: - self._logger.error("Connection error : %s", str(e)) - except Exception: - self._logger.exception("Exception in Client connection") - if self.dispatcher.connected: - await self.dispatcher.disconnect() - if not self.options.reconnect: - break - if self._processing: # if client was stopped exit immediately - await asyncio.sleep(self.options.reconnect_delay) - if self._processing: - await self.stop(force=True) - - async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first - await self.dispatcher.write( - PacketSetProtocol( - self.dispatcher.proto, - protocolVersion=self.dispatcher.proto, - serverHost=self.host, - serverPort=self.port, - nextState=2, # play - ) - ) + async def _login(self) -> bool: + self.dispatcher.state = ConnectionState.LOGIN await self.dispatcher.write( PacketLoginStart( self.dispatcher.proto, username=self._authenticator.selectedProfile.name if self.online_mode else self._username ) ) - return True - - async def _login(self) -> bool: - self.dispatcher.state = ConnectionState.LOGIN async for packet in self.dispatcher.packets(): if isinstance(packet, PacketEncryptionBegin): if not self.online_mode: @@ -338,9 +235,8 @@ class MinecraftClient(CallbacksHolder, Runnable): return False return False - async def _play(self) -> bool: + async def _play(self): self.dispatcher.state = ConnectionState.PLAY - self.run_callbacks(ClientEvent.CONNECTED) async for packet in self.dispatcher.packets(): self._logger.debug("[ * ] Processing %s", packet.__class__.__name__) if isinstance(packet, PacketSetCompression): @@ -353,7 +249,3 @@ class MinecraftClient(CallbacksHolder, Runnable): elif isinstance(packet, PacketKickDisconnect): self._logger.error("Kicked while in game : %s", helpers.parse_chat(packet.reason)) break - self.run_callbacks(Packet, packet) - self.run_callbacks(type(packet), packet) - self.run_callbacks(ClientEvent.DISCONNECTED) - return False diff --git a/src/aiocraft/dispatcher.py b/src/aiocraft/dispatcher.py index cde8ed4..b864296 100644 --- a/src/aiocraft/dispatcher.py +++ b/src/aiocraft/dispatcher.py @@ -5,7 +5,7 @@ import zlib import logging from asyncio import StreamReader, StreamWriter, Queue, Task from enum import Enum -from typing import List, Dict, Set, Optional, AsyncIterator, Type +from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union from cryptography.hazmat.primitives.ciphers import CipherContext @@ -39,7 +39,7 @@ class Dispatcher: _incoming : Queue _outgoing : Queue - _packet_whitelist : Set[Packet] + _packet_whitelist : Set[Type[Packet]] _packet_id_whitelist : Set[int] host : str @@ -95,15 +95,15 @@ class Dispatcher: host:Optional[str] = None, port:Optional[int] = None, proto:Optional[int] = None, - queue_timeout:int = 1, + queue_timeout:float = 1, queue_size:int = 100, - packet_whitelist : List[Packet] = None + packet_whitelist : Set[Type[Packet]] = None ): self.proto = proto or self.proto or 757 # TODO not hardcode this? self.host = host or self.host or "localhost" self.port = port or self.port or 25565 self._logger = LOGGER.getChild(f"on({self.host}:{self.port})") - self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set() + self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set() # just in case make new set if self._packet_whitelist: self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive) self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect) @@ -128,9 +128,9 @@ class Dispatcher: proto : Optional[int] = None, reader : Optional[StreamReader] = None, writer : Optional[StreamWriter] = None, - queue_timeout : int = 1, + queue_timeout : float = 1, queue_size : int = 100, - packet_whitelist : Set[Packet] = None, + packet_whitelist : Set[Type[Packet]] = None, ): if self.connected: raise InvalidState("Dispatcher already connected") @@ -165,33 +165,34 @@ class Dispatcher: if block: await self._up.wait_closed() self._logger.debug("Socket closed") - self._logger.info("Disconnected") + if block: + self._logger.info("Disconnected") def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]: # TODO de-jank this, language server gets kinda mad - reg = None + # m : Module if self.state == ConnectionState.HANDSHAKING: - reg = minecraft_protocol.handshaking + m = minecraft_protocol.handshaking elif self.state == ConnectionState.STATUS: - reg = minecraft_protocol.status + m = minecraft_protocol.status elif self.state == ConnectionState.LOGIN: - reg = minecraft_protocol.login + m = minecraft_protocol.login elif self.state == ConnectionState.PLAY: - reg = minecraft_protocol.play + m = minecraft_protocol.play else: raise InvalidState("Cannot access registries from invalid state") if self.is_server: - reg = reg.serverbound.REGISTRY + reg = m.serverbound.REGISTRY else: - reg = reg.clientbound.REGISTRY + reg = m.clientbound.REGISTRY if not self.proto: raise InvalidState("Cannot access registries from invalid protocol") - reg = reg[self.proto] + proto_reg = reg[self.proto] - return reg[packet_id] + return proto_reg[packet_id] async def _read_varint(self) -> int: numRead = 0 diff --git a/src/aiocraft/traits/__init__.py b/src/aiocraft/traits/__init__.py deleted file mode 100644 index 950994c..0000000 --- a/src/aiocraft/traits/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .callbacks import CallbacksHolder -from .runnable import Runnable diff --git a/src/aiocraft/traits/callbacks.py b/src/aiocraft/traits/callbacks.py deleted file mode 100644 index a3b4ed5..0000000 --- a/src/aiocraft/traits/callbacks.py +++ /dev/null @@ -1,56 +0,0 @@ -import asyncio -import uuid -import logging - -from inspect import isclass -from typing import Dict, List, Set, Any, Callable, Type - -class CallbacksHolder: - - _callbacks : Dict[Any, List[Callable]] - _tasks : Dict[uuid.UUID, asyncio.Event] - - _logger : logging.Logger - - def __init__(self): - super().__init__() - self._callbacks = {} - self._tasks = {} - - def callback_keys(self, filter:Type = None) -> Set[Any]: - return set(x for x in self._callbacks.keys() if not filter or (isclass(x) and issubclass(x, filter))) - - def register(self, key:Any, callback:Callable): - if key not in self._callbacks: - self._callbacks[key] = [] - self._callbacks[key].append(callback) - return callback - - def trigger(self, key:Any) -> List[Callable]: - if key not in self._callbacks: - return [] - return self._callbacks[key] - - def _wrap(self, cb:Callable, uid:uuid.UUID) -> Callable: - async def wrapper(*args): - try: - ret = await cb(*args) - except Exception: - logging.exception("Exception processing callback") - ret = None - self._tasks[uid].set() - self._tasks.pop(uid) - return ret - return wrapper - - def run_callbacks(self, key:Any, *args) -> None: - for cb in self.trigger(key): - task_id = uuid.uuid4() - self._tasks[task_id] = asyncio.Event() - - asyncio.get_event_loop().create_task(self._wrap(cb, task_id)(*args)) - - async def join_callbacks(self): - await asyncio.gather(*list(t.wait() for t in self._tasks.values())) - self._tasks.clear() - diff --git a/src/aiocraft/traits/runnable.py b/src/aiocraft/traits/runnable.py deleted file mode 100644 index eee0d29..0000000 --- a/src/aiocraft/traits/runnable.py +++ /dev/null @@ -1,45 +0,0 @@ -import asyncio -import logging - -from typing import Optional -from signal import signal, SIGINT, SIGTERM, SIGABRT - -class Runnable: - _is_running : bool - _stop_task : Optional[asyncio.Task] - - def __init__(self): - self._is_running = False - self._stop_task = None - - async def start(self): - self._is_running = True - - async def stop(self, force:bool=False): - self._is_running = False - - def run(self): - logging.info("Starting process") - - def signal_handler(signum, __): - if signum == SIGINT: - if self._stop_task: - self._stop_task.cancel() - logging.info("Received SIGINT, terminating") - else: - logging.info("Received SIGINT, stopping gracefully...") - self._stop_task = asyncio.get_event_loop().create_task(self.stop(force=self._stop_task is not None)) - - signal(SIGINT, signal_handler) - - loop = asyncio.get_event_loop() - - async def main(): - await self.start() - while self._is_running: - await asyncio.sleep(1) - - loop.run_until_complete(main()) - - logging.info("Process finished") -