diff --git a/src/aiocraft/client.py b/src/aiocraft/client.py index acddf52..fcf1a67 100644 --- a/src/aiocraft/client.py +++ b/src/aiocraft/client.py @@ -28,11 +28,8 @@ from .util import encryption, helpers LOGGER = logging.getLogger(__name__) class MinecraftClient: - host:str - port:int - username:str online_mode:bool - authenticator:Optional[AuthInterface] + authenticator:AuthInterface dispatcher : Dispatcher logger : logging.Logger _authenticated : bool @@ -42,28 +39,26 @@ class MinecraftClient: def __init__( self, server:str, + authenticator:AuthInterface, online_mode:bool = True, - authenticator:AuthInterface=None, - username:str = "", ): super().__init__() if ":" in server: _host, _port = server.split(":", 1) - self.host = _host.strip() - self.port = int(_port) + host = _host.strip() + port = int(_port) else: - self.host = server.strip() - self.port = 25565 + host = server.strip() + port = 25565 - self.username = username self.online_mode = online_mode self.authenticator = authenticator self._authenticated = False - self.dispatcher = Dispatcher() + self.dispatcher = Dispatcher().set_host(host, port) self._processing = False - self.logger = LOGGER.getChild(f"on({self.host}:{self.port})") + self.logger = LOGGER.getChild(f"on({server})") @property def connected(self) -> bool: @@ -88,27 +83,18 @@ class MinecraftClient: async def info(self, host:str="", port:int=0, proto:int=0, ping:bool=False) -> Dict[str, Any]: """Make a mini connection to asses server status and version""" - self.host = host or self.host - self.port = port or self.port try: - await self.dispatcher.connect(self.host, self.port) + await self.dispatcher.set_host(host, port).connect() await self._handshake(ConnectionState.STATUS) return await self._status(ping) finally: await self.dispatcher.disconnect() - async def join(self, host:str="", port:int=0, proto:int=0, packet_whitelist:Optional[Set[Type[Packet]]]=None): # jank packet_whitelist argument! TODO - self.host = host or self.host - self.port = port or self.port + async def join(self, host:str="", port:int=0, proto:int=0): if self.online_mode: await self.authenticate() try: - await self.dispatcher.connect( - host=self.host, - port=self.port, - proto=proto, - packet_whitelist=packet_whitelist - ) + await self.dispatcher.set_host(host, port).set_proto(proto).connect() await self._handshake(ConnectionState.LOGIN) if await self._login(): await self._play() @@ -120,8 +106,8 @@ class MinecraftClient: PacketSetProtocol( self.dispatcher.proto, protocolVersion=self.dispatcher.proto, - serverHost=self.host, - serverPort=self.port, + serverHost=self.dispatcher.host, + serverPort=self.dispatcher.port, nextState=state.value ) ) @@ -160,17 +146,14 @@ class MinecraftClient: await self.dispatcher.write( PacketLoginStart( self.dispatcher.proto, - username=self.authenticator.selectedProfile.name if self.online_mode and self.authenticator else self.username + username=self.authenticator.selectedProfile.name ) ) async for packet in self.dispatcher.packets(): if isinstance(packet, PacketEncryptionBegin): - if not self.online_mode: + if not self.online_mode or not self.authenticator or not self.authenticator.accessToken: # overkill to check authenticator and accessToken but whatever self.logger.error("Cannot answer Encryption Request in offline mode") return False - if not self.authenticator: - self.logger.error("No available token to enable encryption") - return False secret = encryption.generate_shared_secret() token, encrypted_secret = encryption.encrypt_token_and_secret( packet.publicKey, diff --git a/src/aiocraft/dispatcher.py b/src/aiocraft/dispatcher.py index 1700793..baac8c3 100644 --- a/src/aiocraft/dispatcher.py +++ b/src/aiocraft/dispatcher.py @@ -6,6 +6,7 @@ import logging from asyncio import StreamReader, StreamWriter, Queue, Task from enum import Enum from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union +from types import ModuleType from cryptography.hazmat.primitives.ciphers import CipherContext @@ -39,8 +40,8 @@ class Dispatcher: _incoming : Queue _outgoing : Queue - _packet_whitelist : Set[Type[Packet]] - _packet_id_whitelist : Set[int] + _packet_whitelist : Optional[Set[Type[Packet]]] + _packet_id_whitelist : Optional[Set[int]] host : str port : int @@ -50,14 +51,13 @@ class Dispatcher: encryption : bool compression : Optional[int] - _logger : logging.Logger + logger : logging.Logger def __init__(self, server:bool = False): self.proto = 757 self._is_server = server self.host = "localhost" self.port = 25565 - self._prepare() @property def is_server(self) -> bool: @@ -84,57 +84,58 @@ class Dispatcher: except asyncio.TimeoutError: pass # so we recheck self.connected - def encrypt(self, secret:bytes): - cipher = encryption.create_AES_cipher(secret) - self._encryptor = cipher.encryptor() - self._decryptor = cipher.decryptor() - self.encryption = True - self._logger.info("Encryption enabled") + def encrypt(self, secret:Optional[bytes]=None) -> 'Dispatcher': + if secret is not None: + cipher = encryption.create_AES_cipher(secret) + self._encryptor = cipher.encryptor() + self._decryptor = cipher.decryptor() + self.encryption = True + self.logger.info("Encryption enabled") + else: + self.encryption = False + self.logger.info("Encryption disabled") + return self - def _prepare(self, - host:Optional[str] = None, - port:Optional[int] = None, - proto:Optional[int] = None, - queue_size:int = 100, - packet_whitelist : Set[Type[Packet]] = None - ): - self.proto = proto or self.proto or 757 # TODO not hardcode this? - self.host = host or self.host or "localhost" - self.port = port or self.port or 25565 - self._logger = LOGGER.getChild(f"on({self.host}:{self.port})") - self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set() # just in case make new set + def whitelist(self, ids:Optional[List[Type[Packet]]]) -> 'Dispatcher': + self._packet_whitelist = set(ids) if ids is not None else None if self._packet_whitelist: self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive) self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect) + self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else None + return self + + def set_host(self, host:Optional[str]="localhost", port:Optional[int]=25565) -> 'Dispatcher': + self.host = host or self.host + self.port = port or self.port + self.logger = LOGGER.getChild(f"on({self.host}:{self.port})") + return self + + def set_proto(self, proto:Optional[int]=757) -> 'Dispatcher': + self.proto = proto or self.proto + if self._packet_id_whitelist: + self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else set() + return self + + def set_state(self, state:Optional[ConnectionState]=ConnectionState.HANDSHAKING) -> 'Dispatcher': + self.state = state or self.state + return self + + async def connect(self, + reader : Optional[StreamReader] = None, + writer : Optional[StreamWriter] = None, + queue_size : int = 100, + ) -> 'Dispatcher': + if self.connected: + raise InvalidState("Dispatcher already connected") self.encryption = False self.compression = None - self.state = ConnectionState.HANDSHAKING - - # This can only happen after we know the connection protocol - self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else set() - - # Make new queues, do set a max size to sorta propagate back pressure self._incoming = Queue(queue_size) self._outgoing = Queue(queue_size) self._dispatching = False self._reader = None self._writer = None - async def connect(self, - host : Optional[str] = None, - port : Optional[int] = None, - proto : Optional[int] = None, - reader : Optional[StreamReader] = None, - writer : Optional[StreamWriter] = None, - queue_size : int = 100, - packet_whitelist : Set[Type[Packet]] = None, - ): - if self.connected: - raise InvalidState("Dispatcher already connected") - - self._prepare(host, port, proto, queue_size, packet_whitelist) - if reader and writer: self._down, self._up = reader, writer else: @@ -146,29 +147,31 @@ class Dispatcher: self._dispatching = True self._reader = asyncio.get_event_loop().create_task(self._down_worker()) self._writer = asyncio.get_event_loop().create_task(self._up_worker()) - self._logger.info("Connected") + self.logger.info("Connected") + return self - async def disconnect(self, block:bool=True): + async def disconnect(self, block:bool=True) -> 'Dispatcher': self._dispatching = False if block and self._writer and self._reader: await asyncio.gather(self._writer, self._reader) - self._logger.debug("Net workers stopped") + self.logger.debug("Net workers stopped") if self._up: if not self._up.is_closing() and self._up.can_write_eof(): try: self._up.write_eof() except OSError as e: - self._logger.error("Could not write EOF : %s", str(e)) + self.logger.error("Could not write EOF : %s", str(e)) self._up.close() if block: await self._up.wait_closed() - self._logger.debug("Socket closed") + self.logger.debug("Socket closed") if block: - self._logger.info("Disconnected") + self.logger.info("Disconnected") + return self def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]: - # TODO de-jank this, language server gets kinda mad - # m : Module + m : ModuleType + if self.state == ConnectionState.HANDSHAKING: m = minecraft_protocol.handshaking elif self.state == ConnectionState.STATUS: @@ -235,26 +238,26 @@ class Dispatcher: packet_id = VarInt.read(buffer, Context(_proto=self.proto)) if self.state == ConnectionState.PLAY and self._packet_id_whitelist \ and packet_id not in self._packet_id_whitelist: - self._logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id) + self.logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id) continue # ignore this packet, we rarely need them all, should improve performance cls = self._packet_type_from_registry(packet_id) packet = cls.deserialize(self.proto, buffer) - self._logger.debug("[<--] Received | %s", repr(packet)) + self.logger.debug("[<--] Received | %s", repr(packet)) await self._incoming.put(packet) if self.state != ConnectionState.PLAY: await self._incoming.join() # During play we can pre-process packets except (asyncio.TimeoutError, TimeoutError): - self._logger.error("Connection timed out") + self.logger.error("Connection timed out") await self.disconnect(block=False) except (ConnectionResetError, BrokenPipeError): - self._logger.error("Connection reset while reading packet") + self.logger.error("Connection reset while reading packet") await self.disconnect(block=False) except (asyncio.IncompleteReadError, EOFError): - self._logger.error("Received EOF while reading packet") + self.logger.error("Received EOF while reading packet") await self.disconnect(block=False) except Exception: - self._logger.exception("Exception parsing packet %d", packet_id) - self._logger.debug("%s", buffer.getvalue()) + self.logger.exception("Exception parsing packet %d", packet_id) + self.logger.debug("%s", buffer.getvalue()) await self.disconnect(block=False) async def _up_worker(self, timeout=1): @@ -288,9 +291,9 @@ class Dispatcher: self._up.write(data) await self._up.drain() - self._logger.debug("[-->] Sent | %s", repr(packet)) + self.logger.debug("[-->] Sent | %s", repr(packet)) except Exception: - self._logger.exception("Exception dispatching packet %s", str(packet)) + self.logger.exception("Exception dispatching packet %s", str(packet)) packet.processed.set() # Notify that packet has been processed