implemented some basic proto logic

This commit is contained in:
əlemi 2021-10-28 20:21:07 +02:00 committed by alemidev
parent aa92bfe3cd
commit a485b25150
3 changed files with 64 additions and 16 deletions

View file

@ -4,16 +4,11 @@ from enum import Enum
from typing import Dict from typing import Dict
from .dispatcher import Dispatcher from .dispatcher import Dispatcher, ConnectionState
from .mc.mctypes import VarInt from .mc.mctypes import VarInt
from .mc.packet import Packet from .mc.packet import Packet
from .mc import proto from .mc import proto
class ConnectionState(Enum):
HANDSHAKING = 0
STATUS = 1
LOGIN = 2
PLAY = 3
def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Packet]]: def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Packet]]:
if state == ConnectionState.HANDSHAKING: if state == ConnectionState.HANDSHAKING:

View file

@ -1,8 +1,26 @@
import asyncio import asyncio
import zlib
from asyncio import StreamReader, StreamWriter, Queue, Task from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum from enum import Enum
from .mc import proto from .mc import proto
from .mc.mctypes import VarInt
class ConnectionState(Enum):
HANDSHAKING = 0
STATUS = 1
LOGIN = 2
PLAY = 3
def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Packet]]:
if state == ConnectionState.HANDSHAKING:
return proto.handshaking.clientbound.REGISTRY
if state == ConnectionState.STATUS:
return proto.status.clientbound.REGISTRY
if state == ConnectionState.LOGIN:
return proto.login.clientbound.REGISTRY
if state == ConnectionState.PLAY:
return proto.play.clientbound.REGISTRY
class InvalidState(Exception): class InvalidState(Exception):
pass pass
@ -20,10 +38,10 @@ async def read_varint(stream: asyncio.StreamReader) -> int:
return buf return buf
_STATE_REGS = { _STATE_REGS = {
ConnectionStatus.HANDSHAKING : proto.handshaking, ConnectionState.HANDSHAKING : proto.handshaking,
ConnectionStatus.STATUS : proto.status, ConnectionState.STATUS : proto.status,
ConnectionStatus.LOGIN : proto.login, ConnectionState.LOGIN : proto.login,
ConnectionStatus.PLAY : proto.play, ConnectionState.PLAY : proto.play,
} }
class Dispatcher: class Dispatcher:
@ -32,9 +50,15 @@ class Dispatcher:
_reader : Task _reader : Task
_writer : Task _writer : Task
_dispatching : bool _dispatching : bool
incoming : Queue incoming : Queue
outgoing : Queue outgoing : Queue
connected : bool connected : bool
state : ConnectionState
encryption : bool
compression : Optional[int]
host : str host : str
port : int port : int
@ -49,6 +73,9 @@ class Dispatcher:
host=self.host, host=self.host,
port=self.port, port=self.port,
) )
self.state = ConnectionState.HANDSHAKING
self.encryption = False
self.compression = None
self.connected = True self.connected = True
packet_handshake = proto.handshaking.serverbound.PacketSetProtocol( packet_handshake = proto.handshaking.serverbound.PacketSetProtocol(
@ -70,18 +97,44 @@ class Dispatcher:
async def _down_worker(self): async def _down_worker(self):
while self._dispatching: while self._dispatching:
length = await read_varint(self._down) length = await read_varint(self._down)
buffer = await self._down.read(length) buffer = io.BytesIO(await self._down.read(length))
# TODO encryption # TODO encryption
# TODO compression
await self.incoming.put(buffer) if self.compression is not None:
decompressed_size = VarInt.read(buffer)
if decompressed_size > 0:
decompressor = zlib.decompressobj()
decompressed_data = decompressor.decompress(buffer.read())
if len(decompressed_data) != decompressed_size:
raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
buffer = io.BytesIO(decompressed_data)
packet_id = VarInt.read(buffer)
cls = _registry_from_state(self.state)[self.proto][packet_id]
await self.incoming.put(cls.deserialize(self.proto, buffer))
async def _up_worker(self): async def _up_worker(self):
while self._dispatching: while self._dispatching:
packet = await self.outgoing.get() packet = await self.outgoing.get()
buffer = packet.serialize() buffer = packet.serialize()
length = len(buffer) length = len(buffer)
# TODO compression
if self.compression is not None:
if length > self.compression:
new_buffer = io.BytesIO()
VarInt.write(length, new_buffer)
new_buffer.write(zlib.compress(buffer.read()))
buffer = new_buffer
else:
new_buffer = io.BytesIO()
VarInt.write(0, new_buffer)
new_buffer.write(buffer.read)
buffer = new_buffer
length = len(buffer)
# TODO encryption # TODO encryption
await self._up.write(VarInt.serialize(length) + buffer) await self._up.write(VarInt.serialize(length) + buffer)
async def run(self): async def run(self):

View file

@ -6,8 +6,8 @@ compile()
setup( setup(
name='aiocraft', name='aiocraft',
version='0.0.1', version='0.0.2',
description='asyncio-powered headless minecraft client', description='asyncio-powered headless minecraft client library',
url='https://github.com/alemidev/aiocraft', url='https://github.com/alemidev/aiocraft',
author='alemi', author='alemi',
author_email='me@alemi.dev', author_email='me@alemi.dev',