diff --git a/setup.cfg b/setup.cfg index 4dcf6b2..2dd182d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = aiocraft -version = 0.0.8 +version = 0.0.9 author = alemi author_email = me@alemi.dev description = asyncio-powered headless minecraft client library @@ -20,6 +20,7 @@ install_requires = cryptography aiohttp termcolor + asyncio-dgram package_dir = = src packages = find: diff --git a/src/aiocraft/client.py b/src/aiocraft/client.py index d067e92..77e5f45 100644 --- a/src/aiocraft/client.py +++ b/src/aiocraft/client.py @@ -10,7 +10,7 @@ from time import time from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator, Any, Set -from .dispatcher import Dispatcher +from .dispatcher import Dispatcher, Transport from .mc.packet import Packet from .mc.auth import AuthInterface, AuthException, MojangAuthenticator, MicrosoftAuthenticator from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState @@ -40,6 +40,7 @@ class MinecraftClient: self, server:str, authenticator:AuthInterface, + use_udp:bool=False, online_mode:bool = True, ): super().__init__() @@ -55,7 +56,8 @@ class MinecraftClient: self.authenticator = authenticator self._authenticated = False - self.dispatcher = Dispatcher().set_host(host, port) + _transp = Transport.UDP if use_udp else Transport.TCP + self.dispatcher = Dispatcher().set_host(host, port, transport=_transp) self._processing = False self.logger = LOGGER.getChild(f"on({server})") @@ -106,10 +108,10 @@ class MinecraftClient: async def _handshake(self, state:ConnectionState): await self.dispatcher.write( PacketSetProtocol( - self.dispatcher.proto, - protocolVersion=self.dispatcher.proto, - serverHost=self.dispatcher.host, - serverPort=self.dispatcher.port, + self.dispatcher._proto, + protocolVersion=self.dispatcher._proto, + serverHost=self.dispatcher._host, + serverPort=self.dispatcher._port, nextState=state.value ) ) @@ -117,7 +119,7 @@ class MinecraftClient: async def _status(self, ping:bool=False) -> Dict[str, Any]: self.dispatcher.state = ConnectionState.STATUS await self.dispatcher.write( - PacketPingStart(self.dispatcher.proto) #empty packet + PacketPingStart(self.dispatcher._proto) #empty packet ) #Response data : Dict[str, Any] = {} @@ -133,7 +135,7 @@ class MinecraftClient: ping_time = time() await self.dispatcher.write( PacketPing( - self.dispatcher.proto, + self.dispatcher._proto, time=ping_id, ) ) @@ -147,7 +149,7 @@ class MinecraftClient: self.dispatcher.state = ConnectionState.LOGIN await self.dispatcher.write( PacketLoginStart( - self.dispatcher.proto, + self.dispatcher._proto, username=self.authenticator.selectedProfile.name ) ) @@ -178,7 +180,7 @@ class MinecraftClient: else: self.logger.warning("Server gave an offline-mode serverId but still requested Encryption") encryption_response = PacketEncryptionResponse( - self.dispatcher.proto, + self.dispatcher._proto, sharedSecret=encrypted_secret, verifyToken=token ) @@ -186,7 +188,7 @@ class MinecraftClient: self.dispatcher.encrypt(secret) elif isinstance(packet, PacketCompress): self.logger.info("Compression enabled") - self.dispatcher.compression = packet.threshold + self.dispatcher._compression = packet.threshold elif isinstance(packet, PacketLoginPluginRequest): self.logger.info("Ignoring plugin request") # TODO ? elif isinstance(packet, PacketSuccess): @@ -203,7 +205,7 @@ class MinecraftClient: self.logger.debug("[ * ] Processing %s", packet.__class__.__name__) if isinstance(packet, PacketSetCompression): self.logger.info("Compression updated") - self.dispatcher.compression = packet.threshold + self.dispatcher._compression = packet.threshold elif isinstance(packet, PacketKeepAlive): keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId) await self.dispatcher.write(keep_alive_packet) diff --git a/src/aiocraft/dispatcher.py b/src/aiocraft/dispatcher.py index 5569e30..d3d9d11 100644 --- a/src/aiocraft/dispatcher.py +++ b/src/aiocraft/dispatcher.py @@ -9,6 +9,7 @@ from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union from types import ModuleType from cryptography.hazmat.primitives.ciphers import CipherContext +import asyncio_dgram from .mc import proto as minecraft_protocol from .mc.types import VarInt, Context @@ -18,6 +19,10 @@ from .util import encryption LOGGER = logging.getLogger(__name__) +class Transport(Enum): + TCP = 0 + UDP = 1 + class InvalidState(Exception): pass @@ -27,11 +32,11 @@ class ConnectionError(Exception): class Dispatcher: _is_server : bool # True when receiving packets from clients - _down : StreamReader + _down : Union[StreamReader, asyncio_dgram.DatagramServer] _reader : Optional[Task] _decryptor : CipherContext - _up : StreamWriter + _up : Union[StreamWriter, asyncio_dgram.DatagramClient] _writer : Optional[Task] _encryptor : CipherContext @@ -45,26 +50,53 @@ class Dispatcher: _log_ignored_packets : bool - host : str - port : int + _host : str + _port : int + _transport : Transport - proto : int - state : ConnectionState - encryption : bool - compression : Optional[int] + _proto : int + _encryption : bool + _compression : Optional[int] + + state : ConnectionState # TODO make getter/setter ? logger : logging.Logger def __init__(self, server:bool = False): - self.proto = 757 + self._proto = 757 self._is_server = server - self.host = "localhost" - self.port = 25565 + self._host = "localhost" + self._port = 25565 + self._transport = Transport.TCP self._dispatching = False self._packet_whitelist = None self._packet_id_whitelist = None self._log_ignored_packets = False + @property + def proto(self) -> int: + return self._proto + + @property + def host(self) -> str: + return self._host + + @property + def port(self) -> int: + return self._port + + @property + def transport(self) -> Transport: + return self._transport + + @property + def encryption(self) -> bool: + return self._encryption + + @property + def compression(self) -> Optional[int]: + return self._compression + @property def is_server(self) -> bool: return self._is_server @@ -95,10 +127,10 @@ class Dispatcher: cipher = encryption.create_AES_cipher(secret) self._encryptor = cipher.encryptor() self._decryptor = cipher.decryptor() - self.encryption = True + self._encryption = True self.logger.info("Encryption enabled") else: - self.encryption = False + self._encryption = False self.logger.info("Encryption disabled") return self @@ -107,19 +139,25 @@ class Dispatcher: if self._packet_whitelist: self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive) self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect) - self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else None + self._packet_id_whitelist = set((P(self._proto).id for P in self._packet_whitelist)) if self._packet_whitelist else None return self - def set_host(self, host:Optional[str]="localhost", port:Optional[int]=25565) -> 'Dispatcher': - self.host = host or self.host - self.port = port or self.port - self.logger = LOGGER.getChild(f"on({self.host}:{self.port})") + def set_host( + self, + host:Optional[str]="localhost", + port:Optional[int]=25565, + transport:Transport=Transport.TCP + ) -> 'Dispatcher': + self._host = host or self._host + self._port = port or self._port + self._transport = transport + self.logger = LOGGER.getChild(f"on({self._host}:{self._port})") return self def set_proto(self, proto:Optional[int]=757) -> 'Dispatcher': - self.proto = proto or self.proto + self._proto = proto or self._proto if self._packet_whitelist: - self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) + self._packet_id_whitelist = set((P(self._proto).id for P in self._packet_whitelist)) return self def set_state(self, state:Optional[ConnectionState]=ConnectionState.HANDSHAKING) -> 'Dispatcher': @@ -138,8 +176,8 @@ class Dispatcher: if self.connected: raise InvalidState("Dispatcher already connected") - self.encryption = False - self.compression = None + self._encryption = False + self._compression = None self._incoming = Queue(queue_size) self._outgoing = Queue(queue_size) self._dispatching = False @@ -149,11 +187,15 @@ class Dispatcher: if reader and writer: self._down, self._up = reader, writer else: # TODO put a timeout here and throw exception - self.logger.debug("Attempting to connect to %s:%d", self.host, self.port) - self._down, self._up = await asyncio.open_connection( - host=self.host, - port=self.port, - ) + self.logger.debug("Attempting to connect to %s:%d", self._host, self._port) + if self.transport == Transport.TCP: + self._down, self._up = await asyncio.open_connection( + host=self._host, + port=self._port, + ) + else: + self._up = await asyncio_dgram.connect((self.host, self.port)) + self._down = await asyncio_dgram.bind(("0.0.0.0", self.port)) self._dispatching = True self._reader = asyncio.get_event_loop().create_task(self._down_worker()) @@ -167,14 +209,14 @@ class Dispatcher: await asyncio.gather(self._writer, self._reader) self.logger.debug("Net workers stopped") if self._up: - if not self._up.is_closing() and self._up.can_write_eof(): + if isinstance(self._up, StreamWriter) and not self._up.is_closing() and self._up.can_write_eof(): try: self._up.write_eof() self.logger.debug("Wrote EOF on socket") except OSError as e: self.logger.error("Could not write EOF : %s", str(e)) self._up.close() - if block: + if block and isinstance(self._up, StreamWriter): await self._up.wait_closed() self.logger.debug("Socket closed") if block: @@ -200,19 +242,21 @@ class Dispatcher: else: reg = m.clientbound.REGISTRY - if not self.proto: + if not self._proto: raise InvalidState("Cannot access registries from invalid protocol") - proto_reg = reg[self.proto] + proto_reg = reg[self._proto] return proto_reg[packet_id] - async def _read_varint(self) -> int: + async def _read_varint_from_stream(self) -> int: + if not isinstance(self._down, StreamReader): + raise InvalidState("Requires a TCP connection") numRead = 0 result = 0 while True: data = await self._down.readexactly(1) - if self.encryption: + if self._encryption: data = self._decryptor.update(data) buf = int.from_bytes(data, 'little') result |= (buf & 0b01111111) << (7 * numRead) @@ -224,22 +268,43 @@ class Dispatcher: return result async def _read_packet(self) -> bytes: - length = await self._read_varint() - return await self._down.readexactly(length) + if isinstance(self._down, StreamReader): + length = await self._read_varint_from_stream() + return await self._down.readexactly(length) + elif isinstance(self._down, asyncio_dgram.DatagramServer): + data, source = await self._down.recv() + if source != self.host: + self.logger.warning("Host %s sent buffer '%s'", source, str(data)) + return b'' + # TODO do I need to discard size or maybe check it and merge with next packet? + return data + else: + self.logger.error("Unknown protocol, could not read from stream") + return b'' + async def _write_packet(self, data:bytes): + if isinstance(self._up, StreamWriter): + self._up.write(data) + await self._up.drain() # TODO maybe no need to call drain? + elif isinstance(self._up, asyncio_dgram.DatagramClient): + await self._up.send(data) + else: + self.logger.error("Unknown protocol, could not send packet") async def _down_worker(self, timeout:float=30): while self._dispatching: try: # Will timeout or raise EOFError if client gets disconnected data = await asyncio.wait_for(self._read_packet(), timeout=timeout) + if not data: + continue - if self.encryption: + if self._encryption: data = self._decryptor.update(data) buffer = io.BytesIO(data) - if self.compression is not None: - decompressed_size = VarInt.read(buffer, Context(_proto=self.proto)) + if self._compression is not None: + decompressed_size = VarInt.read(buffer, Context(_proto=self._proto)) if decompressed_size > 0: decompressor = zlib.decompressobj() decompressed_data = decompressor.decompress(buffer.read()) @@ -247,14 +312,14 @@ class Dispatcher: raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}") buffer = io.BytesIO(decompressed_data) - packet_id = VarInt.read(buffer, Context(_proto=self.proto)) + packet_id = VarInt.read(buffer, Context(_proto=self._proto)) if self.state == ConnectionState.PLAY and self._packet_id_whitelist \ and packet_id not in self._packet_id_whitelist: if self._log_ignored_packets: self.logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id) continue # ignore this packet, we rarely need them all, should improve performance cls = self._packet_type_from_registry(packet_id) - packet = cls.deserialize(self.proto, buffer) + packet = cls.deserialize(self._proto, buffer) self.logger.debug("[<--] Received | %s", repr(packet)) await self._incoming.put(packet) if self.state != ConnectionState.PLAY: @@ -288,26 +353,23 @@ class Dispatcher: buffer = packet.serialize() length = len(buffer.getvalue()) # ewww TODO - if self.compression is not None: - if length > self.compression: + if self._compression is not None: + if length > self._compression: new_buffer = io.BytesIO() - VarInt.write(length, new_buffer, Context(_proto=self.proto)) + VarInt.write(length, new_buffer, Context(_proto=self._proto)) new_buffer.write(zlib.compress(buffer.read())) buffer = new_buffer else: new_buffer = io.BytesIO() - VarInt.write(0, new_buffer, Context(_proto=self.proto)) + VarInt.write(0, new_buffer, Context(_proto=self._proto)) new_buffer.write(buffer.read()) buffer = new_buffer length = len(buffer.getvalue()) data = VarInt.serialize(length) + buffer.getvalue() - if self.encryption: + if self._encryption: data = self._encryptor.update(data) - - self._up.write(data) - await self._up.drain() # TODO maybe no need to call drain? - + await self._write_packet(data) self.logger.debug("[-->] Sent | %s", repr(packet)) except Exception: self.logger.exception("Exception dispatching packet %s", str(packet))