tweaks to work for server too, catch and disconnect if ConnectionResetError
This commit is contained in:
parent
741ec33f1b
commit
a5b82cb8a6
1 changed files with 48 additions and 17 deletions
|
@ -5,11 +5,11 @@ import zlib
|
||||||
import logging
|
import logging
|
||||||
from asyncio import StreamReader, StreamWriter, Queue, Task
|
from asyncio import StreamReader, StreamWriter, Queue, Task
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Dict, Optional, AsyncIterator
|
from typing import List, Dict, Optional, AsyncIterator, Type
|
||||||
|
|
||||||
from cryptography.hazmat.primitives.ciphers import CipherContext
|
from cryptography.hazmat.primitives.ciphers import CipherContext
|
||||||
|
|
||||||
from .mc import proto
|
from .mc import proto as minecraft_protocol
|
||||||
from .mc.types import VarInt
|
from .mc.types import VarInt
|
||||||
from .mc.packet import Packet
|
from .mc.packet import Packet
|
||||||
from .mc.definitions import ConnectionState
|
from .mc.definitions import ConnectionState
|
||||||
|
@ -24,14 +24,9 @@ class InvalidState(Exception):
|
||||||
class ConnectionError(Exception):
|
class ConnectionError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
_STATE_REGS = {
|
|
||||||
ConnectionState.HANDSHAKING : proto.handshaking.clientbound.REGISTRY,
|
|
||||||
ConnectionState.STATUS : proto.status.clientbound.REGISTRY,
|
|
||||||
ConnectionState.LOGIN : proto.login.clientbound.REGISTRY,
|
|
||||||
ConnectionState.PLAY : proto.play.clientbound.REGISTRY,
|
|
||||||
}
|
|
||||||
|
|
||||||
class Dispatcher:
|
class Dispatcher:
|
||||||
|
_is_server : bool # True when receiving packets from clients
|
||||||
|
|
||||||
_down : StreamReader
|
_down : StreamReader
|
||||||
_reader : Optional[Task]
|
_reader : Optional[Task]
|
||||||
_decryptor : CipherContext
|
_decryptor : CipherContext
|
||||||
|
@ -55,9 +50,16 @@ class Dispatcher:
|
||||||
|
|
||||||
_logger : logging.Logger
|
_logger : logging.Logger
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, server:bool = False):
|
||||||
|
self._is_server = server
|
||||||
|
self._host = "localhost"
|
||||||
|
self._port = 25565
|
||||||
self._prepare()
|
self._prepare()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_server(self) -> bool:
|
||||||
|
return self._is_server
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def connected(self) -> bool:
|
def connected(self) -> bool:
|
||||||
return self._dispatching
|
return self._dispatching
|
||||||
|
@ -146,6 +148,32 @@ class Dispatcher:
|
||||||
self._logger.debug("Socket closed")
|
self._logger.debug("Socket closed")
|
||||||
self._logger.info("Disconnected")
|
self._logger.info("Disconnected")
|
||||||
|
|
||||||
|
def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
|
||||||
|
# TODO de-jank this, language server gets kinda mad
|
||||||
|
reg = None
|
||||||
|
if self.state == ConnectionState.HANDSHAKING:
|
||||||
|
reg = minecraft_protocol.handshaking
|
||||||
|
elif self.state == ConnectionState.STATUS:
|
||||||
|
reg = minecraft_protocol.status
|
||||||
|
elif self.state == ConnectionState.LOGIN:
|
||||||
|
reg = minecraft_protocol.login
|
||||||
|
elif self.state == ConnectionState.PLAY:
|
||||||
|
reg = minecraft_protocol.play
|
||||||
|
else:
|
||||||
|
raise InvalidState("Cannot access registries from invalid state")
|
||||||
|
|
||||||
|
if self.is_server:
|
||||||
|
reg = reg.serverbound.REGISTRY
|
||||||
|
else:
|
||||||
|
reg = reg.clientbound.REGISTRY
|
||||||
|
|
||||||
|
if not self.proto:
|
||||||
|
raise InvalidState("Cannot access registries from invalid protocol")
|
||||||
|
|
||||||
|
reg = reg[self.proto]
|
||||||
|
|
||||||
|
return reg[packet_id]
|
||||||
|
|
||||||
async def _read_varint(self) -> int:
|
async def _read_varint(self) -> int:
|
||||||
numRead = 0
|
numRead = 0
|
||||||
result = 0
|
result = 0
|
||||||
|
@ -165,6 +193,7 @@ class Dispatcher:
|
||||||
async def _down_worker(self):
|
async def _down_worker(self):
|
||||||
while self._dispatching:
|
while self._dispatching:
|
||||||
try: # these 2 will timeout or raise EOFError if client gets disconnected
|
try: # these 2 will timeout or raise EOFError if client gets disconnected
|
||||||
|
self._logger.debug("Reading packet")
|
||||||
length = await self._read_varint()
|
length = await self._read_varint()
|
||||||
data = await self._down.readexactly(length)
|
data = await self._down.readexactly(length)
|
||||||
|
|
||||||
|
@ -183,16 +212,18 @@ class Dispatcher:
|
||||||
buffer = io.BytesIO(decompressed_data)
|
buffer = io.BytesIO(decompressed_data)
|
||||||
|
|
||||||
packet_id = VarInt.read(buffer)
|
packet_id = VarInt.read(buffer)
|
||||||
cls = _STATE_REGS[self.state][self.proto][packet_id]
|
cls = self._packet_type_from_registry(packet_id)
|
||||||
|
self._logger.debug("Deserializing packet %s | %s", str(cls), cls._state)
|
||||||
packet = cls.deserialize(self.proto, buffer)
|
packet = cls.deserialize(self.proto, buffer)
|
||||||
self._logger.debug("[<--] Received | %s", str(packet))
|
self._logger.debug("[<--] Received | %s", str(packet))
|
||||||
await self._incoming.put(packet)
|
await self._incoming.put(packet)
|
||||||
if self.state == ConnectionState.LOGIN:
|
if self.state != ConnectionState.PLAY:
|
||||||
await self._incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression
|
await self._incoming.join() # During play we can pre-process packets
|
||||||
except AttributeError:
|
except ConnectionResetError:
|
||||||
self._logger.debug("Unimplemented packet %d", packet_id)
|
self._logger.error("Connection reset while reading packet")
|
||||||
except asyncio.IncompleteReadError:
|
await self.disconnect(block=False)
|
||||||
self._logger.error("EOF from server")
|
except (asyncio.IncompleteReadError, EOFError):
|
||||||
|
self._logger.error("Received EOF while reading packet")
|
||||||
await self.disconnect(block=False)
|
await self.disconnect(block=False)
|
||||||
except Exception:
|
except Exception:
|
||||||
self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue())
|
self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue())
|
||||||
|
|
Loading…
Reference in a new issue