From a485b251508e2b99f4caa72d98b55de53a0bf269 Mon Sep 17 00:00:00 2001 From: alemi Date: Thu, 28 Oct 2021 20:21:07 +0200 Subject: [PATCH] implemented some basic proto logic --- aiocraft/client.py | 7 +---- aiocraft/dispatcher.py | 69 +++++++++++++++++++++++++++++++++++++----- setup.py | 4 +-- 3 files changed, 64 insertions(+), 16 deletions(-) diff --git a/aiocraft/client.py b/aiocraft/client.py index f39fb36..c361944 100644 --- a/aiocraft/client.py +++ b/aiocraft/client.py @@ -4,16 +4,11 @@ from enum import Enum from typing import Dict -from .dispatcher import Dispatcher +from .dispatcher import Dispatcher, ConnectionState from .mc.mctypes import VarInt from .mc.packet import Packet from .mc import proto -class ConnectionState(Enum): - HANDSHAKING = 0 - STATUS = 1 - LOGIN = 2 - PLAY = 3 def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Packet]]: if state == ConnectionState.HANDSHAKING: diff --git a/aiocraft/dispatcher.py b/aiocraft/dispatcher.py index f1883e5..99ddd2d 100644 --- a/aiocraft/dispatcher.py +++ b/aiocraft/dispatcher.py @@ -1,8 +1,26 @@ import asyncio +import zlib from asyncio import StreamReader, StreamWriter, Queue, Task from enum import Enum from .mc import proto +from .mc.mctypes import VarInt + +class ConnectionState(Enum): + HANDSHAKING = 0 + STATUS = 1 + LOGIN = 2 + PLAY = 3 + +def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Packet]]: + if state == ConnectionState.HANDSHAKING: + return proto.handshaking.clientbound.REGISTRY + if state == ConnectionState.STATUS: + return proto.status.clientbound.REGISTRY + if state == ConnectionState.LOGIN: + return proto.login.clientbound.REGISTRY + if state == ConnectionState.PLAY: + return proto.play.clientbound.REGISTRY class InvalidState(Exception): pass @@ -20,10 +38,10 @@ async def read_varint(stream: asyncio.StreamReader) -> int: return buf _STATE_REGS = { - ConnectionStatus.HANDSHAKING : proto.handshaking, - ConnectionStatus.STATUS : proto.status, - ConnectionStatus.LOGIN : proto.login, - ConnectionStatus.PLAY : proto.play, + ConnectionState.HANDSHAKING : proto.handshaking, + ConnectionState.STATUS : proto.status, + ConnectionState.LOGIN : proto.login, + ConnectionState.PLAY : proto.play, } class Dispatcher: @@ -32,9 +50,15 @@ class Dispatcher: _reader : Task _writer : Task _dispatching : bool + incoming : Queue outgoing : Queue + connected : bool + state : ConnectionState + encryption : bool + compression : Optional[int] + host : str port : int @@ -49,6 +73,9 @@ class Dispatcher: host=self.host, port=self.port, ) + self.state = ConnectionState.HANDSHAKING + self.encryption = False + self.compression = None self.connected = True packet_handshake = proto.handshaking.serverbound.PacketSetProtocol( @@ -70,18 +97,44 @@ class Dispatcher: async def _down_worker(self): while self._dispatching: length = await read_varint(self._down) - buffer = await self._down.read(length) + buffer = io.BytesIO(await self._down.read(length)) + # TODO encryption - # TODO compression - await self.incoming.put(buffer) + + if self.compression is not None: + decompressed_size = VarInt.read(buffer) + if decompressed_size > 0: + decompressor = zlib.decompressobj() + decompressed_data = decompressor.decompress(buffer.read()) + if len(decompressed_data) != decompressed_size: + raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}") + buffer = io.BytesIO(decompressed_data) + + packet_id = VarInt.read(buffer) + cls = _registry_from_state(self.state)[self.proto][packet_id] + await self.incoming.put(cls.deserialize(self.proto, buffer)) async def _up_worker(self): while self._dispatching: packet = await self.outgoing.get() buffer = packet.serialize() length = len(buffer) - # TODO compression + + if self.compression is not None: + if length > self.compression: + new_buffer = io.BytesIO() + VarInt.write(length, new_buffer) + new_buffer.write(zlib.compress(buffer.read())) + buffer = new_buffer + else: + new_buffer = io.BytesIO() + VarInt.write(0, new_buffer) + new_buffer.write(buffer.read) + buffer = new_buffer + length = len(buffer) + # TODO encryption + await self._up.write(VarInt.serialize(length) + buffer) async def run(self): diff --git a/setup.py b/setup.py index 99641c7..9514732 100644 --- a/setup.py +++ b/setup.py @@ -6,8 +6,8 @@ compile() setup( name='aiocraft', - version='0.0.1', - description='asyncio-powered headless minecraft client', + version='0.0.2', + description='asyncio-powered headless minecraft client library', url='https://github.com/alemidev/aiocraft', author='alemi', author_email='me@alemi.dev',