made properties private, added UDP

UDP is added with asyncio-dgram library for now, maybe will implement it
directly as asyncio.Protocol later?
This commit is contained in:
əlemi 2022-05-09 23:50:54 +02:00
parent aa5aed5fbc
commit 4e2f0d87c1
No known key found for this signature in database
GPG key ID: BBCBFE5D7244634E
3 changed files with 127 additions and 62 deletions

View file

@ -1,6 +1,6 @@
[metadata] [metadata]
name = aiocraft name = aiocraft
version = 0.0.8 version = 0.0.9
author = alemi author = alemi
author_email = me@alemi.dev author_email = me@alemi.dev
description = asyncio-powered headless minecraft client library description = asyncio-powered headless minecraft client library
@ -20,6 +20,7 @@ install_requires =
cryptography cryptography
aiohttp aiohttp
termcolor termcolor
asyncio-dgram
package_dir = package_dir =
= src = src
packages = find: packages = find:

View file

@ -10,7 +10,7 @@ from time import time
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator, Any, Set from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator, Any, Set
from .dispatcher import Dispatcher from .dispatcher import Dispatcher, Transport
from .mc.packet import Packet from .mc.packet import Packet
from .mc.auth import AuthInterface, AuthException, MojangAuthenticator, MicrosoftAuthenticator from .mc.auth import AuthInterface, AuthException, MojangAuthenticator, MicrosoftAuthenticator
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
@ -40,6 +40,7 @@ class MinecraftClient:
self, self,
server:str, server:str,
authenticator:AuthInterface, authenticator:AuthInterface,
use_udp:bool=False,
online_mode:bool = True, online_mode:bool = True,
): ):
super().__init__() super().__init__()
@ -55,7 +56,8 @@ class MinecraftClient:
self.authenticator = authenticator self.authenticator = authenticator
self._authenticated = False self._authenticated = False
self.dispatcher = Dispatcher().set_host(host, port) _transp = Transport.UDP if use_udp else Transport.TCP
self.dispatcher = Dispatcher().set_host(host, port, transport=_transp)
self._processing = False self._processing = False
self.logger = LOGGER.getChild(f"on({server})") self.logger = LOGGER.getChild(f"on({server})")
@ -106,10 +108,10 @@ class MinecraftClient:
async def _handshake(self, state:ConnectionState): async def _handshake(self, state:ConnectionState):
await self.dispatcher.write( await self.dispatcher.write(
PacketSetProtocol( PacketSetProtocol(
self.dispatcher.proto, self.dispatcher._proto,
protocolVersion=self.dispatcher.proto, protocolVersion=self.dispatcher._proto,
serverHost=self.dispatcher.host, serverHost=self.dispatcher._host,
serverPort=self.dispatcher.port, serverPort=self.dispatcher._port,
nextState=state.value nextState=state.value
) )
) )
@ -117,7 +119,7 @@ class MinecraftClient:
async def _status(self, ping:bool=False) -> Dict[str, Any]: async def _status(self, ping:bool=False) -> Dict[str, Any]:
self.dispatcher.state = ConnectionState.STATUS self.dispatcher.state = ConnectionState.STATUS
await self.dispatcher.write( await self.dispatcher.write(
PacketPingStart(self.dispatcher.proto) #empty packet PacketPingStart(self.dispatcher._proto) #empty packet
) )
#Response #Response
data : Dict[str, Any] = {} data : Dict[str, Any] = {}
@ -133,7 +135,7 @@ class MinecraftClient:
ping_time = time() ping_time = time()
await self.dispatcher.write( await self.dispatcher.write(
PacketPing( PacketPing(
self.dispatcher.proto, self.dispatcher._proto,
time=ping_id, time=ping_id,
) )
) )
@ -147,7 +149,7 @@ class MinecraftClient:
self.dispatcher.state = ConnectionState.LOGIN self.dispatcher.state = ConnectionState.LOGIN
await self.dispatcher.write( await self.dispatcher.write(
PacketLoginStart( PacketLoginStart(
self.dispatcher.proto, self.dispatcher._proto,
username=self.authenticator.selectedProfile.name username=self.authenticator.selectedProfile.name
) )
) )
@ -178,7 +180,7 @@ class MinecraftClient:
else: else:
self.logger.warning("Server gave an offline-mode serverId but still requested Encryption") self.logger.warning("Server gave an offline-mode serverId but still requested Encryption")
encryption_response = PacketEncryptionResponse( encryption_response = PacketEncryptionResponse(
self.dispatcher.proto, self.dispatcher._proto,
sharedSecret=encrypted_secret, sharedSecret=encrypted_secret,
verifyToken=token verifyToken=token
) )
@ -186,7 +188,7 @@ class MinecraftClient:
self.dispatcher.encrypt(secret) self.dispatcher.encrypt(secret)
elif isinstance(packet, PacketCompress): elif isinstance(packet, PacketCompress):
self.logger.info("Compression enabled") self.logger.info("Compression enabled")
self.dispatcher.compression = packet.threshold self.dispatcher._compression = packet.threshold
elif isinstance(packet, PacketLoginPluginRequest): elif isinstance(packet, PacketLoginPluginRequest):
self.logger.info("Ignoring plugin request") # TODO ? self.logger.info("Ignoring plugin request") # TODO ?
elif isinstance(packet, PacketSuccess): elif isinstance(packet, PacketSuccess):
@ -203,7 +205,7 @@ class MinecraftClient:
self.logger.debug("[ * ] Processing %s", packet.__class__.__name__) self.logger.debug("[ * ] Processing %s", packet.__class__.__name__)
if isinstance(packet, PacketSetCompression): if isinstance(packet, PacketSetCompression):
self.logger.info("Compression updated") self.logger.info("Compression updated")
self.dispatcher.compression = packet.threshold self.dispatcher._compression = packet.threshold
elif isinstance(packet, PacketKeepAlive): elif isinstance(packet, PacketKeepAlive):
keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId) keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId)
await self.dispatcher.write(keep_alive_packet) await self.dispatcher.write(keep_alive_packet)

View file

@ -9,6 +9,7 @@ from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union
from types import ModuleType from types import ModuleType
from cryptography.hazmat.primitives.ciphers import CipherContext from cryptography.hazmat.primitives.ciphers import CipherContext
import asyncio_dgram
from .mc import proto as minecraft_protocol from .mc import proto as minecraft_protocol
from .mc.types import VarInt, Context from .mc.types import VarInt, Context
@ -18,6 +19,10 @@ from .util import encryption
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
class Transport(Enum):
TCP = 0
UDP = 1
class InvalidState(Exception): class InvalidState(Exception):
pass pass
@ -27,11 +32,11 @@ class ConnectionError(Exception):
class Dispatcher: class Dispatcher:
_is_server : bool # True when receiving packets from clients _is_server : bool # True when receiving packets from clients
_down : StreamReader _down : Union[StreamReader, asyncio_dgram.DatagramServer]
_reader : Optional[Task] _reader : Optional[Task]
_decryptor : CipherContext _decryptor : CipherContext
_up : StreamWriter _up : Union[StreamWriter, asyncio_dgram.DatagramClient]
_writer : Optional[Task] _writer : Optional[Task]
_encryptor : CipherContext _encryptor : CipherContext
@ -45,26 +50,53 @@ class Dispatcher:
_log_ignored_packets : bool _log_ignored_packets : bool
host : str _host : str
port : int _port : int
_transport : Transport
proto : int _proto : int
state : ConnectionState
encryption : bool
compression : Optional[int]
_encryption : bool
_compression : Optional[int]
state : ConnectionState # TODO make getter/setter ?
logger : logging.Logger logger : logging.Logger
def __init__(self, server:bool = False): def __init__(self, server:bool = False):
self.proto = 757 self._proto = 757
self._is_server = server self._is_server = server
self.host = "localhost" self._host = "localhost"
self.port = 25565 self._port = 25565
self._transport = Transport.TCP
self._dispatching = False self._dispatching = False
self._packet_whitelist = None self._packet_whitelist = None
self._packet_id_whitelist = None self._packet_id_whitelist = None
self._log_ignored_packets = False self._log_ignored_packets = False
@property
def proto(self) -> int:
return self._proto
@property
def host(self) -> str:
return self._host
@property
def port(self) -> int:
return self._port
@property
def transport(self) -> Transport:
return self._transport
@property
def encryption(self) -> bool:
return self._encryption
@property
def compression(self) -> Optional[int]:
return self._compression
@property @property
def is_server(self) -> bool: def is_server(self) -> bool:
return self._is_server return self._is_server
@ -95,10 +127,10 @@ class Dispatcher:
cipher = encryption.create_AES_cipher(secret) cipher = encryption.create_AES_cipher(secret)
self._encryptor = cipher.encryptor() self._encryptor = cipher.encryptor()
self._decryptor = cipher.decryptor() self._decryptor = cipher.decryptor()
self.encryption = True self._encryption = True
self.logger.info("Encryption enabled") self.logger.info("Encryption enabled")
else: else:
self.encryption = False self._encryption = False
self.logger.info("Encryption disabled") self.logger.info("Encryption disabled")
return self return self
@ -107,19 +139,25 @@ class Dispatcher:
if self._packet_whitelist: if self._packet_whitelist:
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive) self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive)
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect) self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect)
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else None self._packet_id_whitelist = set((P(self._proto).id for P in self._packet_whitelist)) if self._packet_whitelist else None
return self return self
def set_host(self, host:Optional[str]="localhost", port:Optional[int]=25565) -> 'Dispatcher': def set_host(
self.host = host or self.host self,
self.port = port or self.port host:Optional[str]="localhost",
self.logger = LOGGER.getChild(f"on({self.host}:{self.port})") port:Optional[int]=25565,
transport:Transport=Transport.TCP
) -> 'Dispatcher':
self._host = host or self._host
self._port = port or self._port
self._transport = transport
self.logger = LOGGER.getChild(f"on({self._host}:{self._port})")
return self return self
def set_proto(self, proto:Optional[int]=757) -> 'Dispatcher': def set_proto(self, proto:Optional[int]=757) -> 'Dispatcher':
self.proto = proto or self.proto self._proto = proto or self._proto
if self._packet_whitelist: if self._packet_whitelist:
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) self._packet_id_whitelist = set((P(self._proto).id for P in self._packet_whitelist))
return self return self
def set_state(self, state:Optional[ConnectionState]=ConnectionState.HANDSHAKING) -> 'Dispatcher': def set_state(self, state:Optional[ConnectionState]=ConnectionState.HANDSHAKING) -> 'Dispatcher':
@ -138,8 +176,8 @@ class Dispatcher:
if self.connected: if self.connected:
raise InvalidState("Dispatcher already connected") raise InvalidState("Dispatcher already connected")
self.encryption = False self._encryption = False
self.compression = None self._compression = None
self._incoming = Queue(queue_size) self._incoming = Queue(queue_size)
self._outgoing = Queue(queue_size) self._outgoing = Queue(queue_size)
self._dispatching = False self._dispatching = False
@ -149,11 +187,15 @@ class Dispatcher:
if reader and writer: if reader and writer:
self._down, self._up = reader, writer self._down, self._up = reader, writer
else: # TODO put a timeout here and throw exception else: # TODO put a timeout here and throw exception
self.logger.debug("Attempting to connect to %s:%d", self.host, self.port) self.logger.debug("Attempting to connect to %s:%d", self._host, self._port)
if self.transport == Transport.TCP:
self._down, self._up = await asyncio.open_connection( self._down, self._up = await asyncio.open_connection(
host=self.host, host=self._host,
port=self.port, port=self._port,
) )
else:
self._up = await asyncio_dgram.connect((self.host, self.port))
self._down = await asyncio_dgram.bind(("0.0.0.0", self.port))
self._dispatching = True self._dispatching = True
self._reader = asyncio.get_event_loop().create_task(self._down_worker()) self._reader = asyncio.get_event_loop().create_task(self._down_worker())
@ -167,14 +209,14 @@ class Dispatcher:
await asyncio.gather(self._writer, self._reader) await asyncio.gather(self._writer, self._reader)
self.logger.debug("Net workers stopped") self.logger.debug("Net workers stopped")
if self._up: if self._up:
if not self._up.is_closing() and self._up.can_write_eof(): if isinstance(self._up, StreamWriter) and not self._up.is_closing() and self._up.can_write_eof():
try: try:
self._up.write_eof() self._up.write_eof()
self.logger.debug("Wrote EOF on socket") self.logger.debug("Wrote EOF on socket")
except OSError as e: except OSError as e:
self.logger.error("Could not write EOF : %s", str(e)) self.logger.error("Could not write EOF : %s", str(e))
self._up.close() self._up.close()
if block: if block and isinstance(self._up, StreamWriter):
await self._up.wait_closed() await self._up.wait_closed()
self.logger.debug("Socket closed") self.logger.debug("Socket closed")
if block: if block:
@ -200,19 +242,21 @@ class Dispatcher:
else: else:
reg = m.clientbound.REGISTRY reg = m.clientbound.REGISTRY
if not self.proto: if not self._proto:
raise InvalidState("Cannot access registries from invalid protocol") raise InvalidState("Cannot access registries from invalid protocol")
proto_reg = reg[self.proto] proto_reg = reg[self._proto]
return proto_reg[packet_id] return proto_reg[packet_id]
async def _read_varint(self) -> int: async def _read_varint_from_stream(self) -> int:
if not isinstance(self._down, StreamReader):
raise InvalidState("Requires a TCP connection")
numRead = 0 numRead = 0
result = 0 result = 0
while True: while True:
data = await self._down.readexactly(1) data = await self._down.readexactly(1)
if self.encryption: if self._encryption:
data = self._decryptor.update(data) data = self._decryptor.update(data)
buf = int.from_bytes(data, 'little') buf = int.from_bytes(data, 'little')
result |= (buf & 0b01111111) << (7 * numRead) result |= (buf & 0b01111111) << (7 * numRead)
@ -224,22 +268,43 @@ class Dispatcher:
return result return result
async def _read_packet(self) -> bytes: async def _read_packet(self) -> bytes:
length = await self._read_varint() if isinstance(self._down, StreamReader):
length = await self._read_varint_from_stream()
return await self._down.readexactly(length) return await self._down.readexactly(length)
elif isinstance(self._down, asyncio_dgram.DatagramServer):
data, source = await self._down.recv()
if source != self.host:
self.logger.warning("Host %s sent buffer '%s'", source, str(data))
return b''
# TODO do I need to discard size or maybe check it and merge with next packet?
return data
else:
self.logger.error("Unknown protocol, could not read from stream")
return b''
async def _write_packet(self, data:bytes):
if isinstance(self._up, StreamWriter):
self._up.write(data)
await self._up.drain() # TODO maybe no need to call drain?
elif isinstance(self._up, asyncio_dgram.DatagramClient):
await self._up.send(data)
else:
self.logger.error("Unknown protocol, could not send packet")
async def _down_worker(self, timeout:float=30): async def _down_worker(self, timeout:float=30):
while self._dispatching: while self._dispatching:
try: # Will timeout or raise EOFError if client gets disconnected try: # Will timeout or raise EOFError if client gets disconnected
data = await asyncio.wait_for(self._read_packet(), timeout=timeout) data = await asyncio.wait_for(self._read_packet(), timeout=timeout)
if not data:
continue
if self.encryption: if self._encryption:
data = self._decryptor.update(data) data = self._decryptor.update(data)
buffer = io.BytesIO(data) buffer = io.BytesIO(data)
if self.compression is not None: if self._compression is not None:
decompressed_size = VarInt.read(buffer, Context(_proto=self.proto)) decompressed_size = VarInt.read(buffer, Context(_proto=self._proto))
if decompressed_size > 0: if decompressed_size > 0:
decompressor = zlib.decompressobj() decompressor = zlib.decompressobj()
decompressed_data = decompressor.decompress(buffer.read()) decompressed_data = decompressor.decompress(buffer.read())
@ -247,14 +312,14 @@ class Dispatcher:
raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}") raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
buffer = io.BytesIO(decompressed_data) buffer = io.BytesIO(decompressed_data)
packet_id = VarInt.read(buffer, Context(_proto=self.proto)) packet_id = VarInt.read(buffer, Context(_proto=self._proto))
if self.state == ConnectionState.PLAY and self._packet_id_whitelist \ if self.state == ConnectionState.PLAY and self._packet_id_whitelist \
and packet_id not in self._packet_id_whitelist: and packet_id not in self._packet_id_whitelist:
if self._log_ignored_packets: if self._log_ignored_packets:
self.logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id) self.logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)
continue # ignore this packet, we rarely need them all, should improve performance continue # ignore this packet, we rarely need them all, should improve performance
cls = self._packet_type_from_registry(packet_id) cls = self._packet_type_from_registry(packet_id)
packet = cls.deserialize(self.proto, buffer) packet = cls.deserialize(self._proto, buffer)
self.logger.debug("[<--] Received | %s", repr(packet)) self.logger.debug("[<--] Received | %s", repr(packet))
await self._incoming.put(packet) await self._incoming.put(packet)
if self.state != ConnectionState.PLAY: if self.state != ConnectionState.PLAY:
@ -288,26 +353,23 @@ class Dispatcher:
buffer = packet.serialize() buffer = packet.serialize()
length = len(buffer.getvalue()) # ewww TODO length = len(buffer.getvalue()) # ewww TODO
if self.compression is not None: if self._compression is not None:
if length > self.compression: if length > self._compression:
new_buffer = io.BytesIO() new_buffer = io.BytesIO()
VarInt.write(length, new_buffer, Context(_proto=self.proto)) VarInt.write(length, new_buffer, Context(_proto=self._proto))
new_buffer.write(zlib.compress(buffer.read())) new_buffer.write(zlib.compress(buffer.read()))
buffer = new_buffer buffer = new_buffer
else: else:
new_buffer = io.BytesIO() new_buffer = io.BytesIO()
VarInt.write(0, new_buffer, Context(_proto=self.proto)) VarInt.write(0, new_buffer, Context(_proto=self._proto))
new_buffer.write(buffer.read()) new_buffer.write(buffer.read())
buffer = new_buffer buffer = new_buffer
length = len(buffer.getvalue()) length = len(buffer.getvalue())
data = VarInt.serialize(length) + buffer.getvalue() data = VarInt.serialize(length) + buffer.getvalue()
if self.encryption: if self._encryption:
data = self._encryptor.update(data) data = self._encryptor.update(data)
await self._write_packet(data)
self._up.write(data)
await self._up.drain() # TODO maybe no need to call drain?
self.logger.debug("[-->] Sent | %s", repr(packet)) self.logger.debug("[-->] Sent | %s", repr(packet))
except Exception: except Exception:
self.logger.exception("Exception dispatching packet %s", str(packet)) self.logger.exception("Exception dispatching packet %s", str(packet))