tweaks to work for server too, catch and disconnect if ConnectionResetError

This commit is contained in:
əlemi 2021-11-21 23:55:59 +01:00
parent 741ec33f1b
commit a5b82cb8a6

View file

@ -5,11 +5,11 @@ import zlib
import logging import logging
from asyncio import StreamReader, StreamWriter, Queue, Task from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum from enum import Enum
from typing import List, Dict, Optional, AsyncIterator from typing import List, Dict, Optional, AsyncIterator, Type
from cryptography.hazmat.primitives.ciphers import CipherContext from cryptography.hazmat.primitives.ciphers import CipherContext
from .mc import proto from .mc import proto as minecraft_protocol
from .mc.types import VarInt from .mc.types import VarInt
from .mc.packet import Packet from .mc.packet import Packet
from .mc.definitions import ConnectionState from .mc.definitions import ConnectionState
@ -24,14 +24,9 @@ class InvalidState(Exception):
class ConnectionError(Exception): class ConnectionError(Exception):
pass pass
_STATE_REGS = {
ConnectionState.HANDSHAKING : proto.handshaking.clientbound.REGISTRY,
ConnectionState.STATUS : proto.status.clientbound.REGISTRY,
ConnectionState.LOGIN : proto.login.clientbound.REGISTRY,
ConnectionState.PLAY : proto.play.clientbound.REGISTRY,
}
class Dispatcher: class Dispatcher:
_is_server : bool # True when receiving packets from clients
_down : StreamReader _down : StreamReader
_reader : Optional[Task] _reader : Optional[Task]
_decryptor : CipherContext _decryptor : CipherContext
@ -55,9 +50,16 @@ class Dispatcher:
_logger : logging.Logger _logger : logging.Logger
def __init__(self): def __init__(self, server:bool = False):
self._is_server = server
self._host = "localhost"
self._port = 25565
self._prepare() self._prepare()
@property
def is_server(self) -> bool:
return self._is_server
@property @property
def connected(self) -> bool: def connected(self) -> bool:
return self._dispatching return self._dispatching
@ -146,6 +148,32 @@ class Dispatcher:
self._logger.debug("Socket closed") self._logger.debug("Socket closed")
self._logger.info("Disconnected") self._logger.info("Disconnected")
def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
# TODO de-jank this, language server gets kinda mad
reg = None
if self.state == ConnectionState.HANDSHAKING:
reg = minecraft_protocol.handshaking
elif self.state == ConnectionState.STATUS:
reg = minecraft_protocol.status
elif self.state == ConnectionState.LOGIN:
reg = minecraft_protocol.login
elif self.state == ConnectionState.PLAY:
reg = minecraft_protocol.play
else:
raise InvalidState("Cannot access registries from invalid state")
if self.is_server:
reg = reg.serverbound.REGISTRY
else:
reg = reg.clientbound.REGISTRY
if not self.proto:
raise InvalidState("Cannot access registries from invalid protocol")
reg = reg[self.proto]
return reg[packet_id]
async def _read_varint(self) -> int: async def _read_varint(self) -> int:
numRead = 0 numRead = 0
result = 0 result = 0
@ -165,6 +193,7 @@ class Dispatcher:
async def _down_worker(self): async def _down_worker(self):
while self._dispatching: while self._dispatching:
try: # these 2 will timeout or raise EOFError if client gets disconnected try: # these 2 will timeout or raise EOFError if client gets disconnected
self._logger.debug("Reading packet")
length = await self._read_varint() length = await self._read_varint()
data = await self._down.readexactly(length) data = await self._down.readexactly(length)
@ -183,16 +212,18 @@ class Dispatcher:
buffer = io.BytesIO(decompressed_data) buffer = io.BytesIO(decompressed_data)
packet_id = VarInt.read(buffer) packet_id = VarInt.read(buffer)
cls = _STATE_REGS[self.state][self.proto][packet_id] cls = self._packet_type_from_registry(packet_id)
self._logger.debug("Deserializing packet %s | %s", str(cls), cls._state)
packet = cls.deserialize(self.proto, buffer) packet = cls.deserialize(self.proto, buffer)
self._logger.debug("[<--] Received | %s", str(packet)) self._logger.debug("[<--] Received | %s", str(packet))
await self._incoming.put(packet) await self._incoming.put(packet)
if self.state == ConnectionState.LOGIN: if self.state != ConnectionState.PLAY:
await self._incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression await self._incoming.join() # During play we can pre-process packets
except AttributeError: except ConnectionResetError:
self._logger.debug("Unimplemented packet %d", packet_id) self._logger.error("Connection reset while reading packet")
except asyncio.IncompleteReadError: await self.disconnect(block=False)
self._logger.error("EOF from server") except (asyncio.IncompleteReadError, EOFError):
self._logger.error("Received EOF while reading packet")
await self.disconnect(block=False) await self.disconnect(block=False)
except Exception: except Exception:
self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue()) self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue())