diff --git a/aiocraft/__main__.py b/aiocraft/__main__.py index e8a7ebe..be30fe5 100644 --- a/aiocraft/__main__.py +++ b/aiocraft/__main__.py @@ -4,10 +4,10 @@ import asyncio import logging from .mc.proto.play.clientbound import PacketChat -from .mc.identity import Token +from .mc.token import Token from .dispatcher import ConnectionState from .client import Client -from .helpers import parse_chat +from .util.helpers import parse_chat async def idle(): while True: diff --git a/aiocraft/client.py b/aiocraft/client.py index e3af021..87f4ad8 100644 --- a/aiocraft/client.py +++ b/aiocraft/client.py @@ -3,35 +3,23 @@ import logging from asyncio import Task from enum import Enum -from typing import Dict, List, Callable, Type, Optional, Tuple +from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator -from .dispatcher import Dispatcher, ConnectionState -from .mc.mctypes import VarInt +from .dispatcher import Dispatcher from .mc.packet import Packet -from .mc.identity import Token, AuthException -from .mc.definitions import Dimension, Difficulty, Gamemode -from .mc import proto, encryption +from .mc.token import Token, AuthException +from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState +from .mc.proto.handshaking.serverbound import PacketSetProtocol +from .mc.proto.play.serverbound import PacketKeepAlive as PacketKeepAliveResponse +from .mc.proto.play.clientbound import PacketKeepAlive, PacketSetCompression, PacketKickDisconnect +from .mc.proto.login.serverbound import PacketLoginStart, PacketEncryptionBegin as PacketEncryptionResponse +from .mc.proto.login.clientbound import ( + PacketCompress, PacketDisconnect, PacketEncryptionBegin, PacketLoginPluginRequest, PacketSuccess +) +from .util import encryption LOGGER = logging.getLogger(__name__) -def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Type[Packet]]]: - if state == ConnectionState.HANDSHAKING: - return proto.handshaking.clientbound.REGISTRY - if state == ConnectionState.STATUS: - return proto.status.clientbound.REGISTRY - if state == ConnectionState.LOGIN: - return proto.login.clientbound.REGISTRY - if state == ConnectionState.PLAY: - return proto.play.clientbound.REGISTRY - return {} - -_STATE_REGS = { - ConnectionState.HANDSHAKING : proto.handshaking, - ConnectionState.STATUS : proto.status, - ConnectionState.LOGIN : proto.login, - ConnectionState.PLAY : proto.play, -} - class Client: host:str port:int @@ -46,7 +34,7 @@ class Client: _authenticated : bool _worker : Task - _packet_callbacks : Dict[ConnectionState, Dict[Packet, List[Callable]]] + _packet_callbacks : Dict[Type[Packet], List[Callable]] _logger : logging.Logger def __init__( @@ -64,6 +52,8 @@ class Client: self.options = options or { "reconnect" : True, "rctime" : 5.0, + "keep-alive" : True, + "poll-timeout" : 1, } self.token = token @@ -142,7 +132,7 @@ class Client: try: while self._processing: # TODO don't busywait even if it doesn't matter much - await asyncio.sleep(5) + await asyncio.sleep(self.options["poll-timeout"]) except KeyboardInterrupt: self._logger.info("Received SIGINT, stopping...") else: @@ -157,171 +147,114 @@ class Client: async def stop(self, block=True): self._processing = False - await self.dispatcher.disconnect() + if self.dispatcher.connected: + await self.dispatcher.disconnect(block=block) if block: await self._worker - self._logger.info("Minecraft client stopped") + self._logger.info("Minecraft client stopped") async def _client_worker(self): while self._processing: try: await self.authenticate() - except AuthException: - self._logger.error("Token not refreshable or credentials invalid") - await self.stop(block=False) + except AuthException as e: + self._logger.error(str(e)) + break try: await self.dispatcher.connect(self.host, self.port) - for packet in self._handshake(): - await self.dispatcher.write(packet) - self.dispatcher.state = ConnectionState.LOGIN - await self._process_packets() + await self._handshake() + if await self._login(): + await self._play() except ConnectionRefusedError: self._logger.error("Server rejected connection") except Exception: self._logger.exception("Exception in Client connection") + if self.dispatcher.connected: + await self.dispatcher.disconnect() if not self.options["reconnect"]: - await self.stop(block=False) break await asyncio.sleep(self.options["rctime"]) + await self.stop(block=False) - def _handshake(self, force:bool=False) -> Tuple[Packet, Packet]: # TODO make this fancier! poll for version and status first - return ( proto.handshaking.serverbound.PacketSetProtocol( + async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first + await self.dispatcher.write( + PacketSetProtocol( 340, # TODO!!!! protocolVersion=340, serverHost=self.host, serverPort=self.port, nextState=2, # play - ), - proto.login.serverbound.PacketLoginStart( + ) + ) + await self.dispatcher.write( + PacketLoginStart( 340, username=self.token.profile.name if self.token else self.username ) ) + return True - async def _process_packets(self): - while self.dispatcher.connected: - try: - packet = await asyncio.wait_for(self.dispatcher.incoming.get(), timeout=5) - self._logger.debug("[ * ] Processing | %s", str(packet)) - - if self.dispatcher.state == ConnectionState.LOGIN: - await self.login_logic(packet) - elif self.dispatcher.state == ConnectionState.PLAY: - await self.play_logic(packet) - - if self.dispatcher.state in self._packet_callbacks: - if Packet in self._packet_callbacks[self.dispatcher.state]: # callback for any packet - for cb in self._packet_callbacks[self.dispatcher.state][Packet]: - try: - await cb(packet) - except Exception as e: - self._logger.exception("Exception while handling callback") - if packet.__class__ in self._packet_callbacks[self.dispatcher.state]: # callback for this packet - for cb in self._packet_callbacks[self.dispatcher.state][packet.__class__]: - try: - await cb(packet) - except Exception as e: - self._logger.exception("Exception while handling callback") - - self.dispatcher.incoming.task_done() - except asyncio.TimeoutError: - pass # need this to recheck self._processing periodically - except AuthException: - self._authenticated = False - self._logger.error("Authentication exception") - await self.dispatcher.disconnect(block=False) - except Exception: - self._logger.exception("Exception while processing packet %s", packet) - - # TODO move these in separate module - - async def login_logic(self, packet:Packet): - if isinstance(packet, proto.login.clientbound.PacketEncryptionBegin): - secret = encryption.generate_shared_secret() - - token, encrypted_secret = encryption.encrypt_token_and_secret( - packet.publicKey, - packet.verifyToken, - secret - ) - - if packet.serverId != '-' and self.token: - await self.token.join( - encryption.generate_verification_hash( - packet.serverId, - secret, - packet.publicKey - ) + async def _login(self) -> bool: + self.dispatcher.state = ConnectionState.LOGIN + async for packet in self.dispatcher.packets(): + if isinstance(packet, PacketEncryptionBegin): + secret = encryption.generate_shared_secret() + token, encrypted_secret = encryption.encrypt_token_and_secret( + packet.publicKey, + packet.verifyToken, + secret ) - - encryption_response = proto.login.serverbound.PacketEncryptionBegin( - 340, # TODO!!!! - sharedSecret=encrypted_secret, - verifyToken=token - ) - - await self.dispatcher.write(encryption_response, wait=True) - - self.dispatcher.encrypt(secret) - - elif isinstance(packet, proto.login.clientbound.PacketDisconnect): - self._logger.error("Kicked while logging in") - await self.dispatcher.disconnect(block=False) - # raise Exception("Disconnected while logging in") # TODO make a more specific one, do some shit - - elif isinstance(packet, proto.login.clientbound.PacketCompress): - self._logger.info("Compression enabled") - self.dispatcher.compression = packet.threshold - - elif isinstance(packet, proto.login.clientbound.PacketSuccess): - self._logger.info("Login success, joining world...") - self.dispatcher.state = ConnectionState.PLAY - - elif isinstance(packet, proto.login.clientbound.PacketLoginPluginRequest): - pass # TODO ? - - async def play_logic(self, packet:Packet): - if isinstance(packet, proto.play.clientbound.PacketSetCompression): - self._logger.info("Compression updated") - self.dispatcher.compression = packet.threshold - - elif isinstance(packet, proto.play.clientbound.PacketKeepAlive): - keep_alive_packet = proto.play.serverbound.packet_keep_alive.PacketKeepAlive(340, keepAliveId=packet.keepAliveId) - await self.dispatcher.write(keep_alive_packet) - - elif isinstance(packet, proto.play.clientbound.PacketRespawn): - self._logger.info( - "Reloading world: %s (%s) in %s", - Dimension(packet.dimension).name, - Difficulty(packet.difficulty).name, - Gamemode(packet.gamemode).name - ) - - elif isinstance(packet, proto.play.clientbound.PacketLogin): - self._logger.info( - "Joined world: %s (%s) in %s", - Dimension(packet.dimension).name, - Difficulty(packet.difficulty).name, - Gamemode(packet.gameMode).name - ) - - elif isinstance(packet, proto.play.clientbound.PacketPosition): - self._logger.info("Position synchronized") - await self.dispatcher.write( - proto.play.serverbound.PacketTeleportConfirm( - 340, - teleportId=packet.teleportId + if packet.serverId != '-' and self.token: + try: + await self.token.join( + encryption.generate_verification_hash( + packet.serverId, + secret, + packet.publicKey + ) + ) + except AuthException: + self._logger.error("Could not authenticate with Mojang") + break + encryption_response = PacketEncryptionResponse( + 340, # TODO!!!! + sharedSecret=encrypted_secret, + verifyToken=token ) - ) - - elif isinstance(packet, proto.play.clientbound.PacketUpdateHealth): - if packet.health <= 0: - self._logger.info("Dead, respawning...") - await self.dispatcher.write( - proto.play.serverbound.PacketClientCommand(self.dispatcher.proto, actionId=0) # respawn - ) - - elif isinstance(packet, proto.play.clientbound.PacketKickDisconnect): - self._logger.error("Kicked while in game") - await self.dispatcher.disconnect(block=False) + await self.dispatcher.write(encryption_response, wait=True) + self.dispatcher.encrypt(secret) + elif isinstance(packet, PacketCompress): + self._logger.info("Compression enabled") + self.dispatcher.compression = packet.threshold + elif isinstance(packet, PacketLoginPluginRequest): + self._logger.info("Ignoring plugin request") # TODO ? + elif isinstance(packet, PacketSuccess): + self._logger.info("Login success, joining world...") + return True + elif isinstance(packet, PacketDisconnect): + self._logger.error("Kicked while logging in") + break + return False + async def _play(self) -> bool: + self.dispatcher.state = ConnectionState.PLAY + async for packet in self.dispatcher.packets(): + self._logger.debug("[ * ] Processing | %s", str(packet)) + if isinstance(packet, PacketSetCompression): + self._logger.info("Compression updated") + self.dispatcher.compression = packet.threshold + elif isinstance(packet, PacketKeepAlive): + if self.options["keep-alive"]: + keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId) + await self.dispatcher.write(keep_alive_packet) + elif isinstance(packet, PacketKickDisconnect): + self._logger.error("Kicked while in game") + break + for packet_type in (Packet, packet.__class__): # check both callbacks for base class and instance class + if packet_type in self._packet_callbacks: + for cb in self._packet_callbacks[packet_type]: + try: # TODO run in executor to not block + await cb(packet) + except Exception as e: + self._logger.exception("Exception while handling callback") + return False diff --git a/aiocraft/dispatcher.py b/aiocraft/dispatcher.py index 5cdeb37..66dc6bc 100644 --- a/aiocraft/dispatcher.py +++ b/aiocraft/dispatcher.py @@ -1,26 +1,22 @@ import io import asyncio +import contextlib import zlib import logging from asyncio import StreamReader, StreamWriter, Queue, Task from enum import Enum -from typing import Dict, Optional +from typing import Dict, Optional, AsyncIterator from cryptography.hazmat.primitives.ciphers import CipherContext from .mc import proto -from .mc.mctypes import VarInt +from .mc.types import VarInt from .mc.packet import Packet -from .mc import encryption +from .mc.definitions import ConnectionState +from .util import encryption LOGGER = logging.getLogger(__name__) -class ConnectionState(Enum): - NONE = -1 - HANDSHAKING = 0 - STATUS = 1 - LOGIN = 2 - PLAY = 3 class InvalidState(Exception): pass @@ -35,6 +31,19 @@ _STATE_REGS = { ConnectionState.PLAY : proto.play.clientbound.REGISTRY, } +class PacketFrame: + _packet : Packet + _queue : Queue + + def __init__(self, packet:Packet, queue:Queue): + self._packet = packet + self._queue = queue + + def __enter__(self): + return self._packet + def __exit__(self): + self.queue.task_done() + class Dispatcher: _down : StreamReader _reader : Optional[Task] @@ -46,8 +55,8 @@ class Dispatcher: _dispatching : bool - incoming : Queue - outgoing : Queue + _incoming : Queue + _outgoing : Queue _host : str _port : int @@ -64,8 +73,8 @@ class Dispatcher: self._dispatching = False self.compression = None self.encryption = False - self.incoming = Queue() - self.outgoing = Queue() + self._incoming = Queue() + self._outgoing = Queue() self._reader = None self._writer = None self._host = "localhost" @@ -78,24 +87,37 @@ class Dispatcher: return self._dispatching async def write(self, packet:Packet, wait:bool=False) -> int: - await self.outgoing.put(packet) + await self._outgoing.put(packet) if wait: - await packet.sent.wait() - return self.outgoing.qsize() + await packet.processed.wait() + return self._outgoing.qsize() + + async def packets(self, timeout=1) -> AsyncIterator[Packet]: + while self.connected: + try: # TODO replace this timed busy-wait with an event which resolves upon disconnection, and await both + packet = await asyncio.wait_for(self._incoming.get(), timeout=timeout) + try: + yield packet + finally: + self._incoming.task_done() + except asyncio.TimeoutError: + pass # so we recheck self.connected async def disconnect(self, block:bool=True): self._dispatching = False if block and self._writer and self._reader: await asyncio.gather(self._writer, self._reader) + self._logger.debug("Net workers stopped") if self._up: if self._up.can_write_eof(): self._up.write_eof() self._up.close() if block: await self._up.wait_closed() + self._logger.debug("Socket closed") self._logger.info("Disconnected") - async def connect(self, host:Optional[str] = None, port:Optional[int] = None): + async def connect(self, host:Optional[str] = None, port:Optional[int] = None, queue_timeout:int = 1, queue_size:int = 100): if self.connected: raise InvalidState("Dispatcher already connected") @@ -110,9 +132,9 @@ class Dispatcher: self.state = ConnectionState.HANDSHAKING # self.proto = 340 # TODO - # Make new queues - self.incoming = Queue() - self.outgoing = Queue() + # Make new queues, do set a max size to sorta propagate back pressure + self._incoming = Queue(queue_size) + self._outgoing = Queue(queue_size) self._down, self._up = await asyncio.open_connection( host=self._host, @@ -121,7 +143,7 @@ class Dispatcher: self._dispatching = True self._reader = asyncio.get_event_loop().create_task(self._down_worker()) - self._writer = asyncio.get_event_loop().create_task(self._up_worker()) + self._writer = asyncio.get_event_loop().create_task(self._up_worker(timeout=queue_timeout)) self._logger.info("Connected") def encrypt(self, secret:bytes): @@ -149,8 +171,6 @@ class Dispatcher: async def _down_worker(self): while self._dispatching: - if self.state != ConnectionState.PLAY: - await self.incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression try: # these 2 will timeout or raise EOFError if client gets disconnected length = await self._read_varint() data = await self._down.readexactly(length) @@ -173,7 +193,9 @@ class Dispatcher: cls = _STATE_REGS[self.state][self.proto][packet_id] packet = cls.deserialize(self.proto, buffer) self._logger.debug("[<--] Received | %s", str(packet)) - await self.incoming.put(packet) + await self._incoming.put(packet) + if self.state == ConnectionState.LOGIN: + await self._incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression except AttributeError: self._logger.debug("Unimplemented packet %d", packet_id) except asyncio.IncompleteReadError: @@ -182,10 +204,14 @@ class Dispatcher: except Exception: self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue()) - async def _up_worker(self): + async def _up_worker(self, timeout=1): while self._dispatching: try: - packet = await asyncio.wait_for(self.outgoing.get(), timeout=5) + packet = await asyncio.wait_for(self._outgoing.get(), timeout=timeout) + except asyncio.TimeoutError: + continue # check again self._dispatching + + try: buffer = packet.serialize() length = len(buffer.getvalue()) # ewww TODO @@ -209,10 +235,9 @@ class Dispatcher: self._up.write(data) await self._up.drain() - packet.sent.set() # Notify self._logger.debug("[-->] Sent | %s", str(packet)) - except asyncio.TimeoutError: - pass # need this to recheck self._dispatching periodically except Exception: self._logger.exception("Exception dispatching packet %s", str(packet)) + packet.processed.set() # Notify that packet has been processed + diff --git a/aiocraft/mc/__init__.py b/aiocraft/mc/__init__.py index 8c30523..e69de29 100644 --- a/aiocraft/mc/__init__.py +++ b/aiocraft/mc/__init__.py @@ -1,4 +0,0 @@ -"""Minecraft definitions""" -from .packet import Packet -from .mctypes import * -from .proto import * diff --git a/aiocraft/mc/definitions.py b/aiocraft/mc/definitions.py index 054139f..730da3d 100644 --- a/aiocraft/mc/definitions.py +++ b/aiocraft/mc/definitions.py @@ -16,3 +16,10 @@ class Gamemode(Enum): CREATIVE = 1 ADVENTURE = 2 SPECTATOR = 3 + +class ConnectionState(Enum): + NONE = -1 + HANDSHAKING = 0 + STATUS = 1 + LOGIN = 2 + PLAY = 3 diff --git a/aiocraft/mc/packet.py b/aiocraft/mc/packet.py index feb110c..0badf0c 100644 --- a/aiocraft/mc/packet.py +++ b/aiocraft/mc/packet.py @@ -3,14 +3,14 @@ import json from asyncio import Event from typing import Tuple, Dict, Any -from .mctypes import Type, VarInt +from .types import Type, VarInt class Packet: - __slots__ = 'id', 'definition', 'sent', '_protocol', '_state' + __slots__ = 'id', 'definition', '_processed', '_protocol', '_state' id : int definition : Tuple[Tuple[str, Type]] - sent : Event + _processed : Event _protocol : int _state : int @@ -19,12 +19,17 @@ class Packet: def __init__(self, proto:int, **kwargs): self._protocol = proto + self._processed = Event() self.definition = self._definitions[proto] - self.sent = Event() self.id = self._ids[proto] for name, t in self.definition: setattr(self, name, t._pytype(kwargs[name]) if name in kwargs else None) + @property + def processed(self) -> Event: + """Returns an event which will be set only after the packet has been processed (either sent or raised exc)""" + return self._processed + @classmethod def deserialize(cls, proto:int, buffer:io.BytesIO): return cls(proto, **{ name : t.read(buffer) for (name, t) in cls._definitions[proto] }) diff --git a/aiocraft/mc/identity.py b/aiocraft/mc/token.py similarity index 89% rename from aiocraft/mc/identity.py rename to aiocraft/mc/token.py index 59ee6b1..beafe20 100644 --- a/aiocraft/mc/identity.py +++ b/aiocraft/mc/token.py @@ -101,8 +101,10 @@ class Token: async with aiohttp.ClientSession() as sess: async with sess.post(endpoint, headers=cls.HEADERS, data=json.dumps(data).encode('utf-8')) as res: data = await res.json(content_type=None) - logging.info(f"Auth request | {res.status} | {data}") if res.status >= 400: - raise AuthException(f"Action '{endpoint.rsplit('/',1)[1]}' did not succeed") + err_type = data["error"] if data and "error" in data else "Unknown Error" + err_msg = data["errorMessage"] if data and "errorMessage" in data else "Credentials invalid or token not refreshable anymore" + action = endpoint.rsplit('/',1)[1] + raise AuthException(f"[{action}] {err_type} : {err_msg}") return data diff --git a/aiocraft/mc/mctypes.py b/aiocraft/mc/types.py similarity index 100% rename from aiocraft/mc/mctypes.py rename to aiocraft/mc/types.py diff --git a/aiocraft/util/__init__.py b/aiocraft/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aiocraft/mc/encryption.py b/aiocraft/util/encryption.py similarity index 100% rename from aiocraft/mc/encryption.py rename to aiocraft/util/encryption.py diff --git a/aiocraft/helpers.py b/aiocraft/util/helpers.py similarity index 100% rename from aiocraft/helpers.py rename to aiocraft/util/helpers.py diff --git a/compiler/proto.py b/compiler/proto.py index 59ca132..a6bab24 100644 --- a/compiler/proto.py +++ b/compiler/proto.py @@ -7,13 +7,13 @@ import logging from pathlib import Path from typing import List, Dict, Union -from aiocraft.mc.mctypes import * +from aiocraft.mc.types import * DIR_MAP = {"toClient": "clientbound", "toServer": "serverbound"} PREFACE = """\"\"\"[!] This file is autogenerated\"\"\"\n\n""" IMPORTS = """from typing import Tuple, List, Dict from ....packet import Packet -from ....mctypes import *\n""" +from ....types import *\n""" IMPORT_ALL = """__all__ = [\n\t{all}\n]\n""" REGISTRY_ENTRY = """ REGISTRY = {entries}\n"""