diff --git a/aiocraft/dispatcher.py b/aiocraft/dispatcher.py index 99ddd2d..b7fad60 100644 --- a/aiocraft/dispatcher.py +++ b/aiocraft/dispatcher.py @@ -1,149 +1,200 @@ +import io import asyncio import zlib +import logging from asyncio import StreamReader, StreamWriter, Queue, Task from enum import Enum +from typing import Dict, Optional + +from cryptography.hazmat.primitives.ciphers import CipherContext from .mc import proto from .mc.mctypes import VarInt +from .mc.packet import Packet +from .mc import encryption + +logger = logging.getLogger(__name__) class ConnectionState(Enum): + NONE = -1 HANDSHAKING = 0 STATUS = 1 LOGIN = 2 PLAY = 3 -def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Packet]]: - if state == ConnectionState.HANDSHAKING: - return proto.handshaking.clientbound.REGISTRY - if state == ConnectionState.STATUS: - return proto.status.clientbound.REGISTRY - if state == ConnectionState.LOGIN: - return proto.login.clientbound.REGISTRY - if state == ConnectionState.PLAY: - return proto.play.clientbound.REGISTRY - class InvalidState(Exception): pass -async def read_varint(stream: asyncio.StreamReader) -> int: - """Utility method to read a VarInt off the socket, because len comes as a VarInt...""" - buf = 0 - off = 0 - while True: - byte = int.from_bytes(await stream.read(1), 'little') - buf |= (byte & 0b01111111) >> (7*off) - if not byte & 0b10000000: - break - off += 1 - return buf +class ConnectionError(Exception): + pass _STATE_REGS = { - ConnectionState.HANDSHAKING : proto.handshaking, - ConnectionState.STATUS : proto.status, - ConnectionState.LOGIN : proto.login, - ConnectionState.PLAY : proto.play, + ConnectionState.HANDSHAKING : proto.handshaking.clientbound.REGISTRY, + ConnectionState.STATUS : proto.status.clientbound.REGISTRY, + ConnectionState.LOGIN : proto.login.clientbound.REGISTRY, + ConnectionState.PLAY : proto.play.clientbound.REGISTRY, } class Dispatcher: _down : StreamReader - _up : StreamWriter _reader : Task + _decryptor : CipherContext + + _up : StreamWriter _writer : Task + _encryptor : CipherContext + _dispatching : bool - incoming : Queue - outgoing : Queue + incoming : Queue[Packet] + outgoing : Queue[Packet] + host : str + port : int + + proto : int connected : bool state : ConnectionState encryption : bool compression : Optional[int] - host : str - port : int - def __init__(self, host:str, port:int): self.host = host self.port = port + + self.proto = 340 self.connected = False self._dispatching = False + self.compression = None + self.encryption = False + + self.incoming = Queue() + self.outgoing = Queue() + + async def write(self, packet:Packet, wait:bool=False) -> int: + await self.outgoing.put(packet) + if wait: + await packet.sent.wait() + return self.outgoing.qsize() + + async def start(self): + if self.connected: + raise InvalidState("Dispatcher already connected") + await self.connect() + + async def stop(self, block:bool=True): + self._dispatching = False + if block: + await asyncio.gather(self._writer, self._reader) async def connect(self): self._down, self._up = await asyncio.open_connection( host=self.host, port=self.port, ) - self.state = ConnectionState.HANDSHAKING self.encryption = False self.compression = None self.connected = True - - packet_handshake = proto.handshaking.serverbound.PacketSetProtocol( - self.proto, - protocolVersion=self.proto, - serverHost=self.host, - serverPost=self.port, - nextState=3, # play - ) - packet_login = proto.login.serverbound.PacketLoginStart(340, username=self.username) - - await self.outgoing.put(packet_handshake) - await self.outgoing.put(packet_login) + self.state = ConnectionState.HANDSHAKING self._dispatching = True self._reader = asyncio.get_event_loop().create_task(self._down_worker()) self._writer = asyncio.get_event_loop().create_task(self._up_worker()) + async def encrypt(self, secret:bytes): + cipher = encryption.create_AES_cipher(secret) + self._encryptor = cipher.encryptor() + self._decryptor = cipher.decryptor() + self.encryption = True + + + async def _read_varint(self) -> int: + numRead = 0 + result = 0 + while True: + data = await self._down.readexactly(1) + if len(data) < 1: + raise ConnectionError("Could not read data off socket") + if self.encryption: + data = self._decryptor.update(data) + buf = int.from_bytes(data, 'little') + result |= (buf & 0b01111111) << (7 * numRead) + numRead +=1 + if numRead > 5: + raise ValueError("VarInt is too big") + if buf & 0b10000000 == 0: + break + return result + async def _down_worker(self): while self._dispatching: - length = await read_varint(self._down) - buffer = io.BytesIO(await self._down.read(length)) + if self.state != ConnectionState.PLAY: + await self.incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression + try: # these 2 will timeout or raise EOFError if client gets disconnected + length = await self._read_varint() + data = await self._down.readexactly(length) - # TODO encryption + if self.encryption: + data = self._decryptor.update(data) - if self.compression is not None: - decompressed_size = VarInt.read(buffer) - if decompressed_size > 0: - decompressor = zlib.decompressobj() - decompressed_data = decompressor.decompress(buffer.read()) - if len(decompressed_data) != decompressed_size: - raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}") - buffer = io.BytesIO(decompressed_data) + buffer = io.BytesIO(data) - packet_id = VarInt.read(buffer) - cls = _registry_from_state(self.state)[self.proto][packet_id] - await self.incoming.put(cls.deserialize(self.proto, buffer)) + if self.compression is not None: + decompressed_size = VarInt.read(buffer) + # logger.info("Decompressing packet to %d | %s", decompressed_size, buffer.getvalue()) + if decompressed_size > 0: + decompressor = zlib.decompressobj() + decompressed_data = decompressor.decompress(buffer.read()) + # logger.info("Obtained %s", decompressed_data) + if len(decompressed_data) != decompressed_size: + raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}") + buffer = io.BytesIO(decompressed_data) + + packet_id = VarInt.read(buffer) + # logger.info("Parsing packet '%d' [%s] | %s", packet_id, str(self.state), buffer.getvalue()) + cls = _STATE_REGS[self.state][self.proto][packet_id] + packet = cls.deserialize(self.proto, buffer) + await self.incoming.put(packet) + except AttributeError: + logger.debug("Received unimplemented packet %d", packet_id) + except (ConnectionError, asyncio.IncompleteReadError): + logger.exception("Connection error") + await self.stop(block=False) + except Exception: + logger.exception("Error while processing packet %d | %s", packet_id, buffer.getvalue()) async def _up_worker(self): while self._dispatching: - packet = await self.outgoing.get() - buffer = packet.serialize() - length = len(buffer) + # logger.info("Up packet") + try: + packet = await asyncio.wait_for(self.outgoing.get(), timeout=5) + buffer = packet.serialize() + # logger.info("Sending packet %s [%s]", str(packet), buffer.getvalue()) + length = len(buffer.getvalue()) # ewww TODO - if self.compression is not None: - if length > self.compression: - new_buffer = io.BytesIO() - VarInt.write(length, new_buffer) - new_buffer.write(zlib.compress(buffer.read())) - buffer = new_buffer - else: - new_buffer = io.BytesIO() - VarInt.write(0, new_buffer) - new_buffer.write(buffer.read) - buffer = new_buffer - length = len(buffer) + if self.compression is not None: + if length > self.compression: + new_buffer = io.BytesIO() + VarInt.write(length, new_buffer) + new_buffer.write(zlib.compress(buffer.read())) + buffer = new_buffer + else: + new_buffer = io.BytesIO() + VarInt.write(0, new_buffer) + new_buffer.write(buffer.read()) + buffer = new_buffer + length = len(buffer.getvalue()) - # TODO encryption + data = VarInt.serialize(length) + buffer.getvalue() + if self.encryption: + data = self._encryptor.update(data) - await self._up.write(VarInt.serialize(length) + buffer) - - async def run(self): - if self.connected: - raise InvalidState("Dispatcher already connected") - await self.connect() - - async def stop(self): - self._dispatching = False - await asyncio.gather(self._writer, self._reader) + self._up.write(data) + await self._up.drain() + packet.sent.set() # Notify + except asyncio.TimeoutError: + pass # need this to recheck self._dispatching periodically + except Exception: + logger.exception("Error while sending packet %s", str(packet)) diff --git a/aiocraft/mc/encryption.py b/aiocraft/mc/encryption.py new file mode 100644 index 0000000..421d632 --- /dev/null +++ b/aiocraft/mc/encryption.py @@ -0,0 +1,39 @@ +"""Minecraft encryption utilities. These are mostly pasted and edited from https://github.com/ammaraskar/pyCraft""" +# TODO read more about this, improve implementation if possible +import os +from hashlib import sha1 +from asyncio import StreamWriter, StreamReader +from dataclasses import dataclass + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 +from cryptography.hazmat.primitives.serialization import load_der_public_key +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + + +def generate_shared_secret() -> bytes: + return os.urandom(16) + + +def create_AES_cipher(shared_secret:bytes) -> Cipher: + return Cipher(algorithms.AES(shared_secret), modes.CFB8(shared_secret), + backend=default_backend()) + +def encrypt_token_and_secret(pubkey, verification_token, shared_secret:bytes): + pubkey = load_der_public_key(pubkey, default_backend()) + + encrypted_token = pubkey.encrypt(verification_token, PKCS1v15()) + encrypted_secret = pubkey.encrypt(shared_secret, PKCS1v15()) + return encrypted_token, encrypted_secret + +def generate_verification_hash(server_id, shared_secret, public_key) -> str: + verification_hash = sha1() + + verification_hash.update(server_id.encode('utf-8')) + verification_hash.update(shared_secret) + verification_hash.update(public_key) + + return minecraft_sha1_hash_digest(verification_hash) + +def minecraft_sha1_hash_digest(sha1_hash) -> str: + return format(int.from_bytes(sha1_hash.digest(), byteorder='big', signed=True), 'x')