tweaks to work for server too, catch and disconnect if ConnectionResetError

This commit is contained in:
əlemi 2021-11-21 23:55:59 +01:00
parent 741ec33f1b
commit a5b82cb8a6

View file

@ -5,11 +5,11 @@ import zlib
import logging
from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum
from typing import List, Dict, Optional, AsyncIterator
from typing import List, Dict, Optional, AsyncIterator, Type
from cryptography.hazmat.primitives.ciphers import CipherContext
from .mc import proto
from .mc import proto as minecraft_protocol
from .mc.types import VarInt
from .mc.packet import Packet
from .mc.definitions import ConnectionState
@ -24,14 +24,9 @@ class InvalidState(Exception):
class ConnectionError(Exception):
pass
_STATE_REGS = {
ConnectionState.HANDSHAKING : proto.handshaking.clientbound.REGISTRY,
ConnectionState.STATUS : proto.status.clientbound.REGISTRY,
ConnectionState.LOGIN : proto.login.clientbound.REGISTRY,
ConnectionState.PLAY : proto.play.clientbound.REGISTRY,
}
class Dispatcher:
_is_server : bool # True when receiving packets from clients
_down : StreamReader
_reader : Optional[Task]
_decryptor : CipherContext
@ -55,9 +50,16 @@ class Dispatcher:
_logger : logging.Logger
def __init__(self):
def __init__(self, server:bool = False):
self._is_server = server
self._host = "localhost"
self._port = 25565
self._prepare()
@property
def is_server(self) -> bool:
return self._is_server
@property
def connected(self) -> bool:
return self._dispatching
@ -146,6 +148,32 @@ class Dispatcher:
self._logger.debug("Socket closed")
self._logger.info("Disconnected")
def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
# TODO de-jank this, language server gets kinda mad
reg = None
if self.state == ConnectionState.HANDSHAKING:
reg = minecraft_protocol.handshaking
elif self.state == ConnectionState.STATUS:
reg = minecraft_protocol.status
elif self.state == ConnectionState.LOGIN:
reg = minecraft_protocol.login
elif self.state == ConnectionState.PLAY:
reg = minecraft_protocol.play
else:
raise InvalidState("Cannot access registries from invalid state")
if self.is_server:
reg = reg.serverbound.REGISTRY
else:
reg = reg.clientbound.REGISTRY
if not self.proto:
raise InvalidState("Cannot access registries from invalid protocol")
reg = reg[self.proto]
return reg[packet_id]
async def _read_varint(self) -> int:
numRead = 0
result = 0
@ -165,6 +193,7 @@ class Dispatcher:
async def _down_worker(self):
while self._dispatching:
try: # these 2 will timeout or raise EOFError if client gets disconnected
self._logger.debug("Reading packet")
length = await self._read_varint()
data = await self._down.readexactly(length)
@ -183,16 +212,18 @@ class Dispatcher:
buffer = io.BytesIO(decompressed_data)
packet_id = VarInt.read(buffer)
cls = _STATE_REGS[self.state][self.proto][packet_id]
cls = self._packet_type_from_registry(packet_id)
self._logger.debug("Deserializing packet %s | %s", str(cls), cls._state)
packet = cls.deserialize(self.proto, buffer)
self._logger.debug("[<--] Received | %s", str(packet))
await self._incoming.put(packet)
if self.state == ConnectionState.LOGIN:
await self._incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression
except AttributeError:
self._logger.debug("Unimplemented packet %d", packet_id)
except asyncio.IncompleteReadError:
self._logger.error("EOF from server")
if self.state != ConnectionState.PLAY:
await self._incoming.join() # During play we can pre-process packets
except ConnectionResetError:
self._logger.error("Connection reset while reading packet")
await self.disconnect(block=False)
except (asyncio.IncompleteReadError, EOFError):
self._logger.error("Received EOF while reading packet")
await self.disconnect(block=False)
except Exception:
self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue())