diff --git a/aiocraft/dispatcher.py b/aiocraft/dispatcher.py index ac66402..4b135ab 100644 --- a/aiocraft/dispatcher.py +++ b/aiocraft/dispatcher.py @@ -5,11 +5,11 @@ import zlib import logging from asyncio import StreamReader, StreamWriter, Queue, Task from enum import Enum -from typing import List, Dict, Optional, AsyncIterator +from typing import List, Dict, Optional, AsyncIterator, Type from cryptography.hazmat.primitives.ciphers import CipherContext -from .mc import proto +from .mc import proto as minecraft_protocol from .mc.types import VarInt from .mc.packet import Packet from .mc.definitions import ConnectionState @@ -24,14 +24,9 @@ class InvalidState(Exception): class ConnectionError(Exception): pass -_STATE_REGS = { - ConnectionState.HANDSHAKING : proto.handshaking.clientbound.REGISTRY, - ConnectionState.STATUS : proto.status.clientbound.REGISTRY, - ConnectionState.LOGIN : proto.login.clientbound.REGISTRY, - ConnectionState.PLAY : proto.play.clientbound.REGISTRY, -} - class Dispatcher: + _is_server : bool # True when receiving packets from clients + _down : StreamReader _reader : Optional[Task] _decryptor : CipherContext @@ -55,9 +50,16 @@ class Dispatcher: _logger : logging.Logger - def __init__(self): + def __init__(self, server:bool = False): + self._is_server = server + self._host = "localhost" + self._port = 25565 self._prepare() + @property + def is_server(self) -> bool: + return self._is_server + @property def connected(self) -> bool: return self._dispatching @@ -146,6 +148,32 @@ class Dispatcher: self._logger.debug("Socket closed") self._logger.info("Disconnected") + def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]: + # TODO de-jank this, language server gets kinda mad + reg = None + if self.state == ConnectionState.HANDSHAKING: + reg = minecraft_protocol.handshaking + elif self.state == ConnectionState.STATUS: + reg = minecraft_protocol.status + elif self.state == ConnectionState.LOGIN: + reg = minecraft_protocol.login + elif self.state == ConnectionState.PLAY: + reg = minecraft_protocol.play + else: + raise InvalidState("Cannot access registries from invalid state") + + if self.is_server: + reg = reg.serverbound.REGISTRY + else: + reg = reg.clientbound.REGISTRY + + if not self.proto: + raise InvalidState("Cannot access registries from invalid protocol") + + reg = reg[self.proto] + + return reg[packet_id] + async def _read_varint(self) -> int: numRead = 0 result = 0 @@ -165,6 +193,7 @@ class Dispatcher: async def _down_worker(self): while self._dispatching: try: # these 2 will timeout or raise EOFError if client gets disconnected + self._logger.debug("Reading packet") length = await self._read_varint() data = await self._down.readexactly(length) @@ -183,16 +212,18 @@ class Dispatcher: buffer = io.BytesIO(decompressed_data) packet_id = VarInt.read(buffer) - cls = _STATE_REGS[self.state][self.proto][packet_id] + cls = self._packet_type_from_registry(packet_id) + self._logger.debug("Deserializing packet %s | %s", str(cls), cls._state) packet = cls.deserialize(self.proto, buffer) self._logger.debug("[<--] Received | %s", str(packet)) await self._incoming.put(packet) - if self.state == ConnectionState.LOGIN: - await self._incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression - except AttributeError: - self._logger.debug("Unimplemented packet %d", packet_id) - except asyncio.IncompleteReadError: - self._logger.error("EOF from server") + if self.state != ConnectionState.PLAY: + await self._incoming.join() # During play we can pre-process packets + except ConnectionResetError: + self._logger.error("Connection reset while reading packet") + await self.disconnect(block=False) + except (asyncio.IncompleteReadError, EOFError): + self._logger.error("Received EOF while reading packet") await self.disconnect(block=False) except Exception: self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue())