aiocraft/aiocraft/dispatcher.py
alemi d6e5e7a6f6
fix: use definitions from previous protocols
It's not really elegant but it works. If a definition is missing, it's
likely because it's unchanged since last revision
2023-02-16 18:42:36 +01:00

353 lines
11 KiB
Python

import io
import asyncio
import contextlib
import zlib
import logging
from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum
from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union
from types import ModuleType
from cryptography.hazmat.primitives.ciphers import CipherContext
from .mc import proto as minecraft_protocol
from .mc.types import VarInt, Context
from .mc.packet import Packet
from .mc.definitions import ConnectionState
from .util import encryption
# Module-level logger; Dispatcher instances derive per-connection child loggers from it.
LOGGER = logging.getLogger(__name__)
class InvalidState(Exception):
    """Raised when an operation is attempted in a state that does not allow it."""
class ConnectionError(Exception):
    """Package-specific connection failure.

    NOTE: this name shadows the builtin ``ConnectionError`` inside this
    module; unlike the builtin it derives directly from ``Exception``.
    """
class Dispatcher:
    """Packet pump for a single Minecraft protocol connection.

    Once :meth:`connect` has been awaited, two background tasks run:

    * the *down* worker reads length-prefixed frames from the stream,
      decrypts/decompresses them, decodes them into ``Packet`` objects and
      pushes them on the incoming queue (consumed through :meth:`packets`);
    * the *up* worker pops packets queued with :meth:`write`, serializes,
      optionally compresses/encrypts them and writes them to the stream.

    Most configuration methods return ``self`` so they can be chained.
    """

    _is_server : bool # True when receiving packets from clients

    _down : Optional[StreamReader]
    _reader : Optional[Task]
    _decryptor : CipherContext

    _up : Optional[StreamWriter]
    _writer : Optional[Task]
    _encryptor : CipherContext

    _dispatching : bool

    _incoming : Queue
    _outgoing : Queue

    _packet_whitelist : Optional[Set[Type[Packet]]]
    _packet_id_whitelist : Optional[Set[int]]

    _log_ignored_packets : bool

    _host : str
    _port : int

    _proto : int

    _encryption : bool
    _compression : Optional[int]

    state : ConnectionState # TODO make getter/setter ?
    logger : logging.Logger

    def __init__(self, server:bool = False):
        """Create a disconnected dispatcher.

        Args:
            server: True when this end receives packets from clients, i.e.
                incoming packets are looked up in the serverbound registry.
        """
        self._proto = 757
        self._is_server = server
        self._host = "localhost"
        self._port = 25565
        self._dispatching = False
        self._packet_whitelist = None
        self._packet_id_whitelist = None
        self._log_ignored_packets = False
        # Everything below used to be assigned only by connect()/set_*(),
        # so touching the dispatcher earlier raised AttributeError.
        self._encryption = False
        self._compression = None
        self._down = None
        self._up = None
        self._reader = None
        self._writer = None
        self.state = ConnectionState.HANDSHAKING
        self.logger = LOGGER.getChild(f"on({self._host}:{self._port})")

    @property
    def proto(self) -> int:
        """Protocol number used to encode and decode packets."""
        return self._proto

    @property
    def host(self) -> str:
        """Remote host name used by :meth:`connect`."""
        return self._host

    @property
    def port(self) -> int:
        """Remote TCP port used by :meth:`connect`."""
        return self._port

    @property
    def encryption(self) -> bool:
        """Whether traffic is currently encrypted."""
        return self._encryption

    @property
    def compression(self) -> Optional[int]:
        """Compression threshold, or None when compression is disabled."""
        return self._compression

    @property
    def is_server(self) -> bool:
        """True when this dispatcher receives packets from clients."""
        return self._is_server

    @property
    def connected(self) -> bool:
        """True while the network workers are supposed to be running."""
        return self._dispatching

    async def write(self, packet:Packet, wait:bool=False) -> int:
        """Queue a packet for sending.

        Args:
            packet: packet to send.
            wait: when True, also wait until the up worker has handled the
                packet (its ``processed`` event gets set).

        Returns:
            The size of the outgoing queue after queueing.
        """
        await self._outgoing.put(packet)
        if wait:
            await packet.processed.wait()
        return self._outgoing.qsize()

    async def packets(self, timeout=1) -> AsyncIterator[Packet]:
        """Yield received packets until disconnected and the queue is drained.

        Args:
            timeout: seconds to wait for a packet before re-checking the
                connection status.
        """
        while self.connected or self._incoming.qsize(): # Finish processing packets on disconnect
            try: # TODO replace this timed busy-wait with an event which resolves upon disconnection, and await both
                packet = await asyncio.wait_for(self._incoming.get(), timeout=timeout)
                try:
                    yield packet
                finally:
                    # the down worker may be blocked on join() waiting for this
                    self._incoming.task_done()
            except asyncio.TimeoutError:
                pass # so we recheck self.connected

    def encrypt(self, secret:Optional[bytes]=None) -> 'Dispatcher':
        """Enable encryption with the given shared secret, or disable it.

        Args:
            secret: shared secret for the cipher; None turns encryption off.
        """
        if secret is not None:
            cipher = encryption.create_AES_cipher(secret)
            self._encryptor = cipher.encryptor()
            self._decryptor = cipher.decryptor()
            self._encryption = True
            self.logger.info("Encryption enabled")
        else:
            self._encryption = False
            self.logger.info("Encryption disabled")
        return self

    def whitelist(self, ids:Optional[List[Type[Packet]]]) -> 'Dispatcher':
        """Restrict which packet types get decoded while in PLAY state.

        KeepAlive and KickDisconnect are always added to a non-empty
        whitelist. None (or an empty list) removes the filter.
        """
        self._packet_whitelist = set(ids) if ids is not None else None
        if self._packet_whitelist:
            self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive)
            self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect)
            # ids depend on the protocol number, so resolve them for the current one
            self._packet_id_whitelist = {P(self._proto).id for P in self._packet_whitelist}
        else:
            self._packet_id_whitelist = None
        return self

    def set_host(
        self,
        host:Optional[str]="",
        port:Optional[int]=0,
    ) -> 'Dispatcher':
        """Set remote host and/or port; falsy values keep the current ones.

        Also rebinds ``self.logger`` to a child logger named after the target.
        """
        self._host = host or self._host
        self._port = port or self._port
        self.logger = LOGGER.getChild(f"on({self._host}:{self._port})")
        return self

    def set_proto(self, proto:Optional[int]=757) -> 'Dispatcher':
        """Set the protocol number; a falsy value keeps the current one.

        Whitelisted packet ids are recomputed for the new protocol.
        """
        self._proto = proto or self._proto
        if self._packet_whitelist:
            self._packet_id_whitelist = {P(self._proto).id for P in self._packet_whitelist}
        return self

    def set_compression(self, threshold:Optional[int] = None) -> 'Dispatcher':
        """Set the compression threshold; None disables compression."""
        self._compression = threshold
        return self

    def set_state(self, state:Optional[ConnectionState]=ConnectionState.HANDSHAKING) -> 'Dispatcher':
        """Set the connection state; a falsy value keeps the current one."""
        self.state = state or self.state
        return self

    def log_ignored_packets(self, log:bool) -> 'Dispatcher':
        """Choose whether packets dropped by the whitelist are logged."""
        self._log_ignored_packets = log
        return self

    async def connect(self,
        reader : Optional[StreamReader] = None,
        writer : Optional[StreamWriter] = None,
        queue_size : int = 100,
    ) -> 'Dispatcher':
        """Start dispatching, opening a TCP connection unless streams are given.

        Args:
            reader: already-open stream to read from (used with ``writer``).
            writer: already-open stream to write to (used with ``reader``).
            queue_size: maximum size of the incoming and outgoing queues.

        Raises:
            InvalidState: if the dispatcher is already connected.
        """
        if self.connected:
            raise InvalidState("Dispatcher already connected")

        # every connection starts unencrypted, uncompressed and with fresh queues
        self._encryption = False
        self._compression = None
        self._incoming = Queue(queue_size)
        self._outgoing = Queue(queue_size)
        self._dispatching = False
        self._reader = None
        self._writer = None

        if reader and writer:
            self._down, self._up = reader, writer
        else: # TODO put a timeout here and throw exception
            self.logger.debug("Attempting to connect to %s:%d", self._host, self._port)
            self._down, self._up = await asyncio.open_connection(
                host=self._host,
                port=self._port,
            )

        self._dispatching = True
        self._reader = asyncio.create_task(self._down_worker())
        self._writer = asyncio.create_task(self._up_worker())
        self.logger.info("Connected")
        return self

    async def disconnect(self, block:bool=True) -> 'Dispatcher':
        """Stop dispatching and close the socket.

        Args:
            block: when True, wait for both workers to finish and for the
                socket to be fully closed. The workers themselves call this
                with False, since they cannot wait on their own task.
        """
        self._dispatching = False
        if block and self._writer and self._reader:
            await asyncio.gather(self._writer, self._reader)
            self.logger.debug("Net workers stopped")
        if self._up:
            if isinstance(self._up, StreamWriter) and not self._up.is_closing() and self._up.can_write_eof():
                try:
                    self._up.write_eof()
                    self.logger.debug("Wrote EOF on socket")
                except OSError as e:
                    self.logger.error("Could not write EOF : %s", str(e))
            self._up.close()
            if block and isinstance(self._up, StreamWriter):
                try:
                    await self._up.wait_closed()
                except OSError as e:
                    # wait_closed() re-raises the error which broke the connection
                    # (e.g. a reset); the socket is gone either way
                    self.logger.debug("Socket closed with error : %s", str(e))
                self.logger.debug("Socket closed")
        if block:
            self.logger.info("Disconnected")
        return self

    def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
        """Resolve the packet class for an id in the current state and protocol.

        When the registry has no entry for the exact protocol number, the
        closest earlier protocol is used (a missing definition is taken to
        be unchanged since the previous revision).

        Raises:
            InvalidState: on unknown state, unset protocol, or when no
                registry exists for this protocol or any earlier one.
            KeyError: if the packet id is not in the selected registry.
        """
        m : ModuleType

        if self.state == ConnectionState.HANDSHAKING:
            m = minecraft_protocol.handshaking
        elif self.state == ConnectionState.STATUS:
            m = minecraft_protocol.status
        elif self.state == ConnectionState.LOGIN:
            m = minecraft_protocol.login
        elif self.state == ConnectionState.PLAY:
            m = minecraft_protocol.play
        else:
            raise InvalidState("Cannot access registries from invalid state")

        if self.is_server:
            reg = m.serverbound.REGISTRY
        else:
            reg = m.clientbound.REGISTRY

        if not self._proto:
            raise InvalidState("Cannot access registries from invalid protocol")

        proto = self._proto
        while proto not in reg:
            proto -= 1
            if proto < 0: # nothing to fall back to: stop instead of looping forever
                raise InvalidState(f"No packet registry for protocol {self._proto} or any earlier one")

        return reg[proto][packet_id]

    async def _read_varint_from_stream(self) -> int:
        """Read (and decrypt, if enabled) one VarInt directly from the stream.

        Raises:
            InvalidState: if there is no stream reader to read from.
            ValueError: if the VarInt is longer than 5 bytes.
        """
        if not isinstance(self._down, StreamReader):
            raise InvalidState("Requires a TCP connection")
        num_read = 0
        result = 0
        while True:
            data = await self._down.readexactly(1)
            if self._encryption:
                data = self._decryptor.update(data)
            buf = int.from_bytes(data, 'little')
            result |= (buf & 0b01111111) << (7 * num_read)
            num_read += 1
            if buf & 0b10000000 == 0:
                return result
            if num_read >= 5: # continuation bit on the 5th byte: do not consume any more
                raise ValueError("VarInt is too big")

    async def _read_packet(self) -> bytes:
        """Read one length-prefixed frame; the payload is returned still encrypted."""
        length = await self._read_varint_from_stream()
        return await self._down.readexactly(length)

    async def _write_packet(self, data:bytes):
        """Write raw bytes to the stream and wait for the buffer to drain."""
        self._up.write(data)
        await self._up.drain() # TODO maybe no need to call drain?

    async def _down_worker(self, timeout:float=30):
        """Receive, decode and enqueue packets until dispatching stops.

        Any failure (timeout, reset, EOF, malformed packet) is logged and
        triggers a non-blocking disconnect, which ends the loop.

        Args:
            timeout: seconds without a complete packet before the connection
                is considered timed out.
        """
        while self._dispatching:
            # reset every iteration, so the generic handler below never reads
            # unbound names or values left over from the previous packet
            packet_id : Optional[int] = None
            buffer : Optional[io.BytesIO] = None
            try: # Will timeout or raise EOFError if client gets disconnected
                data = await asyncio.wait_for(self._read_packet(), timeout=timeout)
                if not data:
                    continue
                if self._encryption:
                    data = self._decryptor.update(data)
                buffer = io.BytesIO(data)

                if self._compression is not None:
                    decompressed_size = VarInt.read(buffer, Context(_proto=self._proto))
                    if decompressed_size > 0: # 0 means the payload was sent uncompressed
                        decompressor = zlib.decompressobj()
                        decompressed_data = decompressor.decompress(buffer.read())
                        if len(decompressed_data) != decompressed_size:
                            raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
                        buffer = io.BytesIO(decompressed_data)

                packet_id = VarInt.read(buffer, Context(_proto=self._proto))
                if self.state == ConnectionState.PLAY and self._packet_id_whitelist \
                and packet_id not in self._packet_id_whitelist:
                    if self._log_ignored_packets:
                        self.logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)
                    continue # ignore this packet, we rarely need them all, should improve performance

                cls = self._packet_type_from_registry(packet_id)
                packet = cls.deserialize(self._proto, buffer)
                self.logger.debug("[<--] Received | %s", repr(packet))
                await self._incoming.put(packet)
                if self.state != ConnectionState.PLAY:
                    await self._incoming.join() # During play we can pre-process packets
            except (asyncio.TimeoutError, TimeoutError):
                self.logger.error("Connection timed out")
                await self.disconnect(block=False)
            except (ConnectionResetError, BrokenPipeError):
                self.logger.error("Connection reset while reading packet")
                await self.disconnect(block=False)
            except (asyncio.IncompleteReadError, EOFError):
                self.logger.error("Received EOF while reading packet")
                await self.disconnect(block=False)
            except Exception:
                if packet_id is None: # failed before the id could be decoded
                    self.logger.exception("Exception parsing packet (id not decoded)")
                else:
                    self.logger.exception("Exception parsing packet %d", packet_id)
                if buffer is not None:
                    self.logger.debug("%s", buffer.getvalue())
                await self.disconnect(block=False)

    async def _up_worker(self, timeout=1):
        """Serialize and send queued packets until dispatching stops.

        Every dequeued packet gets its ``processed`` event set, whether it
        was sent, failed, or dropped because of a disconnect, so that
        ``write(..., wait=True)`` never waits forever.

        Args:
            timeout: seconds to wait for a packet before re-checking whether
                dispatching is still active.
        """
        while self._dispatching:
            try:
                packet : Packet = await asyncio.wait_for(self._outgoing.get(), timeout=timeout)
            except asyncio.TimeoutError:
                continue # check again self._dispatching

            if not self._dispatching: # uglier than 'while self._dispatching' but I need to check it again after unblocking
                packet.processed.set() # release anybody waiting on this packet
                return

            try:
                # getvalue() does not depend on the buffer cursor, so the
                # length prefix and the payload always agree
                payload = packet.serialize().getvalue()

                if self._compression is not None:
                    header = io.BytesIO()
                    # protocol: compress when size is >= threshold
                    if len(payload) >= self._compression:
                        VarInt.write(len(payload), header, Context(_proto=self._proto))
                        payload = header.getvalue() + zlib.compress(payload)
                    else:
                        VarInt.write(0, header, Context(_proto=self._proto)) # 0 = not compressed
                        payload = header.getvalue() + payload

                data = VarInt.serialize(len(payload)) + payload
                if self._encryption:
                    data = self._encryptor.update(data)

                await self._write_packet(data)
                self.logger.debug("[-->] Sent | %s", repr(packet))
            except Exception:
                self.logger.exception("Exception dispatching packet %s", str(packet))

            packet.processed.set() # Notify that packet has been processed