tweaks to work for server too, catch and disconnect if ConnectionResetError
This commit is contained in:
parent
741ec33f1b
commit
a5b82cb8a6
1 changed files with 48 additions and 17 deletions
|
@ -5,11 +5,11 @@ import zlib
|
|||
import logging
|
||||
from asyncio import StreamReader, StreamWriter, Queue, Task
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Optional, AsyncIterator
|
||||
from typing import List, Dict, Optional, AsyncIterator, Type
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers import CipherContext
|
||||
|
||||
from .mc import proto
|
||||
from .mc import proto as minecraft_protocol
|
||||
from .mc.types import VarInt
|
||||
from .mc.packet import Packet
|
||||
from .mc.definitions import ConnectionState
|
||||
|
@ -24,14 +24,9 @@ class InvalidState(Exception):
|
|||
class ConnectionError(Exception):
|
||||
pass
|
||||
|
||||
_STATE_REGS = {
|
||||
ConnectionState.HANDSHAKING : proto.handshaking.clientbound.REGISTRY,
|
||||
ConnectionState.STATUS : proto.status.clientbound.REGISTRY,
|
||||
ConnectionState.LOGIN : proto.login.clientbound.REGISTRY,
|
||||
ConnectionState.PLAY : proto.play.clientbound.REGISTRY,
|
||||
}
|
||||
|
||||
class Dispatcher:
|
||||
_is_server : bool # True when receiving packets from clients
|
||||
|
||||
_down : StreamReader
|
||||
_reader : Optional[Task]
|
||||
_decryptor : CipherContext
|
||||
|
@ -55,9 +50,16 @@ class Dispatcher:
|
|||
|
||||
_logger : logging.Logger
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, server:bool = False):
|
||||
self._is_server = server
|
||||
self._host = "localhost"
|
||||
self._port = 25565
|
||||
self._prepare()
|
||||
|
||||
@property
|
||||
def is_server(self) -> bool:
|
||||
return self._is_server
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
return self._dispatching
|
||||
|
@ -146,6 +148,32 @@ class Dispatcher:
|
|||
self._logger.debug("Socket closed")
|
||||
self._logger.info("Disconnected")
|
||||
|
||||
def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
|
||||
# TODO de-jank this, language server gets kinda mad
|
||||
reg = None
|
||||
if self.state == ConnectionState.HANDSHAKING:
|
||||
reg = minecraft_protocol.handshaking
|
||||
elif self.state == ConnectionState.STATUS:
|
||||
reg = minecraft_protocol.status
|
||||
elif self.state == ConnectionState.LOGIN:
|
||||
reg = minecraft_protocol.login
|
||||
elif self.state == ConnectionState.PLAY:
|
||||
reg = minecraft_protocol.play
|
||||
else:
|
||||
raise InvalidState("Cannot access registries from invalid state")
|
||||
|
||||
if self.is_server:
|
||||
reg = reg.serverbound.REGISTRY
|
||||
else:
|
||||
reg = reg.clientbound.REGISTRY
|
||||
|
||||
if not self.proto:
|
||||
raise InvalidState("Cannot access registries from invalid protocol")
|
||||
|
||||
reg = reg[self.proto]
|
||||
|
||||
return reg[packet_id]
|
||||
|
||||
async def _read_varint(self) -> int:
|
||||
numRead = 0
|
||||
result = 0
|
||||
|
@ -165,6 +193,7 @@ class Dispatcher:
|
|||
async def _down_worker(self):
|
||||
while self._dispatching:
|
||||
try: # these 2 will timeout or raise EOFError if client gets disconnected
|
||||
self._logger.debug("Reading packet")
|
||||
length = await self._read_varint()
|
||||
data = await self._down.readexactly(length)
|
||||
|
||||
|
@ -183,16 +212,18 @@ class Dispatcher:
|
|||
buffer = io.BytesIO(decompressed_data)
|
||||
|
||||
packet_id = VarInt.read(buffer)
|
||||
cls = _STATE_REGS[self.state][self.proto][packet_id]
|
||||
cls = self._packet_type_from_registry(packet_id)
|
||||
self._logger.debug("Deserializing packet %s | %s", str(cls), cls._state)
|
||||
packet = cls.deserialize(self.proto, buffer)
|
||||
self._logger.debug("[<--] Received | %s", str(packet))
|
||||
await self._incoming.put(packet)
|
||||
if self.state == ConnectionState.LOGIN:
|
||||
await self._incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression
|
||||
except AttributeError:
|
||||
self._logger.debug("Unimplemented packet %d", packet_id)
|
||||
except asyncio.IncompleteReadError:
|
||||
self._logger.error("EOF from server")
|
||||
if self.state != ConnectionState.PLAY:
|
||||
await self._incoming.join() # During play we can pre-process packets
|
||||
except ConnectionResetError:
|
||||
self._logger.error("Connection reset while reading packet")
|
||||
await self.disconnect(block=False)
|
||||
except (asyncio.IncompleteReadError, EOFError):
|
||||
self._logger.error("Received EOF while reading packet")
|
||||
await self.disconnect(block=False)
|
||||
except Exception:
|
||||
self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue())
|
||||
|
|
Loading…
Reference in a new issue