made properties private, added UDP

UDP is added with asyncio-dgram library for now, maybe will implement it
directly as asyncio.Protocol later?
This commit is contained in:
əlemi 2022-05-09 23:50:54 +02:00
parent aa5aed5fbc
commit 4e2f0d87c1
No known key found for this signature in database
GPG key ID: BBCBFE5D7244634E
3 changed files with 127 additions and 62 deletions

View file

@ -1,6 +1,6 @@
[metadata]
name = aiocraft
version = 0.0.8
version = 0.0.9
author = alemi
author_email = me@alemi.dev
description = asyncio-powered headless minecraft client library
@ -20,6 +20,7 @@ install_requires =
cryptography
aiohttp
termcolor
asyncio-dgram
package_dir =
= src
packages = find:

View file

@ -10,7 +10,7 @@ from time import time
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator, Any, Set
from .dispatcher import Dispatcher
from .dispatcher import Dispatcher, Transport
from .mc.packet import Packet
from .mc.auth import AuthInterface, AuthException, MojangAuthenticator, MicrosoftAuthenticator
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
@ -40,6 +40,7 @@ class MinecraftClient:
self,
server:str,
authenticator:AuthInterface,
use_udp:bool=False,
online_mode:bool = True,
):
super().__init__()
@ -55,7 +56,8 @@ class MinecraftClient:
self.authenticator = authenticator
self._authenticated = False
self.dispatcher = Dispatcher().set_host(host, port)
_transp = Transport.UDP if use_udp else Transport.TCP
self.dispatcher = Dispatcher().set_host(host, port, transport=_transp)
self._processing = False
self.logger = LOGGER.getChild(f"on({server})")
@ -106,10 +108,10 @@ class MinecraftClient:
async def _handshake(self, state:ConnectionState):
await self.dispatcher.write(
PacketSetProtocol(
self.dispatcher.proto,
protocolVersion=self.dispatcher.proto,
serverHost=self.dispatcher.host,
serverPort=self.dispatcher.port,
self.dispatcher._proto,
protocolVersion=self.dispatcher._proto,
serverHost=self.dispatcher._host,
serverPort=self.dispatcher._port,
nextState=state.value
)
)
@ -117,7 +119,7 @@ class MinecraftClient:
async def _status(self, ping:bool=False) -> Dict[str, Any]:
self.dispatcher.state = ConnectionState.STATUS
await self.dispatcher.write(
PacketPingStart(self.dispatcher.proto) #empty packet
PacketPingStart(self.dispatcher._proto) #empty packet
)
#Response
data : Dict[str, Any] = {}
@ -133,7 +135,7 @@ class MinecraftClient:
ping_time = time()
await self.dispatcher.write(
PacketPing(
self.dispatcher.proto,
self.dispatcher._proto,
time=ping_id,
)
)
@ -147,7 +149,7 @@ class MinecraftClient:
self.dispatcher.state = ConnectionState.LOGIN
await self.dispatcher.write(
PacketLoginStart(
self.dispatcher.proto,
self.dispatcher._proto,
username=self.authenticator.selectedProfile.name
)
)
@ -178,7 +180,7 @@ class MinecraftClient:
else:
self.logger.warning("Server gave an offline-mode serverId but still requested Encryption")
encryption_response = PacketEncryptionResponse(
self.dispatcher.proto,
self.dispatcher._proto,
sharedSecret=encrypted_secret,
verifyToken=token
)
@ -186,7 +188,7 @@ class MinecraftClient:
self.dispatcher.encrypt(secret)
elif isinstance(packet, PacketCompress):
self.logger.info("Compression enabled")
self.dispatcher.compression = packet.threshold
self.dispatcher._compression = packet.threshold
elif isinstance(packet, PacketLoginPluginRequest):
self.logger.info("Ignoring plugin request") # TODO ?
elif isinstance(packet, PacketSuccess):
@ -203,7 +205,7 @@ class MinecraftClient:
self.logger.debug("[ * ] Processing %s", packet.__class__.__name__)
if isinstance(packet, PacketSetCompression):
self.logger.info("Compression updated")
self.dispatcher.compression = packet.threshold
self.dispatcher._compression = packet.threshold
elif isinstance(packet, PacketKeepAlive):
keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId)
await self.dispatcher.write(keep_alive_packet)

View file

@ -9,6 +9,7 @@ from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union
from types import ModuleType
from cryptography.hazmat.primitives.ciphers import CipherContext
import asyncio_dgram
from .mc import proto as minecraft_protocol
from .mc.types import VarInt, Context
@ -18,6 +19,10 @@ from .util import encryption
LOGGER = logging.getLogger(__name__)
class Transport(Enum):
TCP = 0
UDP = 1
class InvalidState(Exception):
pass
@ -27,11 +32,11 @@ class ConnectionError(Exception):
class Dispatcher:
_is_server : bool # True when receiving packets from clients
_down : StreamReader
_down : Union[StreamReader, asyncio_dgram.DatagramServer]
_reader : Optional[Task]
_decryptor : CipherContext
_up : StreamWriter
_up : Union[StreamWriter, asyncio_dgram.DatagramClient]
_writer : Optional[Task]
_encryptor : CipherContext
@ -45,26 +50,53 @@ class Dispatcher:
_log_ignored_packets : bool
host : str
port : int
_host : str
_port : int
_transport : Transport
proto : int
state : ConnectionState
encryption : bool
compression : Optional[int]
_proto : int
_encryption : bool
_compression : Optional[int]
state : ConnectionState # TODO make getter/setter ?
logger : logging.Logger
def __init__(self, server:bool = False):
self.proto = 757
self._proto = 757
self._is_server = server
self.host = "localhost"
self.port = 25565
self._host = "localhost"
self._port = 25565
self._transport = Transport.TCP
self._dispatching = False
self._packet_whitelist = None
self._packet_id_whitelist = None
self._log_ignored_packets = False
@property
def proto(self) -> int:
return self._proto
@property
def host(self) -> str:
return self._host
@property
def port(self) -> int:
return self._port
@property
def transport(self) -> Transport:
return self._transport
@property
def encryption(self) -> bool:
return self._encryption
@property
def compression(self) -> Optional[int]:
return self._compression
@property
def is_server(self) -> bool:
return self._is_server
@ -95,10 +127,10 @@ class Dispatcher:
cipher = encryption.create_AES_cipher(secret)
self._encryptor = cipher.encryptor()
self._decryptor = cipher.decryptor()
self.encryption = True
self._encryption = True
self.logger.info("Encryption enabled")
else:
self.encryption = False
self._encryption = False
self.logger.info("Encryption disabled")
return self
@ -107,19 +139,25 @@ class Dispatcher:
if self._packet_whitelist:
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive)
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect)
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else None
self._packet_id_whitelist = set((P(self._proto).id for P in self._packet_whitelist)) if self._packet_whitelist else None
return self
def set_host(self, host:Optional[str]="localhost", port:Optional[int]=25565) -> 'Dispatcher':
self.host = host or self.host
self.port = port or self.port
self.logger = LOGGER.getChild(f"on({self.host}:{self.port})")
def set_host(
self,
host:Optional[str]="localhost",
port:Optional[int]=25565,
transport:Transport=Transport.TCP
) -> 'Dispatcher':
self._host = host or self._host
self._port = port or self._port
self._transport = transport
self.logger = LOGGER.getChild(f"on({self._host}:{self._port})")
return self
def set_proto(self, proto:Optional[int]=757) -> 'Dispatcher':
self.proto = proto or self.proto
self._proto = proto or self._proto
if self._packet_whitelist:
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist))
self._packet_id_whitelist = set((P(self._proto).id for P in self._packet_whitelist))
return self
def set_state(self, state:Optional[ConnectionState]=ConnectionState.HANDSHAKING) -> 'Dispatcher':
@ -138,8 +176,8 @@ class Dispatcher:
if self.connected:
raise InvalidState("Dispatcher already connected")
self.encryption = False
self.compression = None
self._encryption = False
self._compression = None
self._incoming = Queue(queue_size)
self._outgoing = Queue(queue_size)
self._dispatching = False
@ -149,11 +187,15 @@ class Dispatcher:
if reader and writer:
self._down, self._up = reader, writer
else: # TODO put a timeout here and throw exception
self.logger.debug("Attempting to connect to %s:%d", self.host, self.port)
self._down, self._up = await asyncio.open_connection(
host=self.host,
port=self.port,
)
self.logger.debug("Attempting to connect to %s:%d", self._host, self._port)
if self.transport == Transport.TCP:
self._down, self._up = await asyncio.open_connection(
host=self._host,
port=self._port,
)
else:
self._up = await asyncio_dgram.connect((self.host, self.port))
self._down = await asyncio_dgram.bind(("0.0.0.0", self.port))
self._dispatching = True
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
@ -167,14 +209,14 @@ class Dispatcher:
await asyncio.gather(self._writer, self._reader)
self.logger.debug("Net workers stopped")
if self._up:
if not self._up.is_closing() and self._up.can_write_eof():
if isinstance(self._up, StreamWriter) and not self._up.is_closing() and self._up.can_write_eof():
try:
self._up.write_eof()
self.logger.debug("Wrote EOF on socket")
except OSError as e:
self.logger.error("Could not write EOF : %s", str(e))
self._up.close()
if block:
if block and isinstance(self._up, StreamWriter):
await self._up.wait_closed()
self.logger.debug("Socket closed")
if block:
@ -200,19 +242,21 @@ class Dispatcher:
else:
reg = m.clientbound.REGISTRY
if not self.proto:
if not self._proto:
raise InvalidState("Cannot access registries from invalid protocol")
proto_reg = reg[self.proto]
proto_reg = reg[self._proto]
return proto_reg[packet_id]
async def _read_varint(self) -> int:
async def _read_varint_from_stream(self) -> int:
if not isinstance(self._down, StreamReader):
raise InvalidState("Requires a TCP connection")
numRead = 0
result = 0
while True:
data = await self._down.readexactly(1)
if self.encryption:
if self._encryption:
data = self._decryptor.update(data)
buf = int.from_bytes(data, 'little')
result |= (buf & 0b01111111) << (7 * numRead)
@ -224,22 +268,43 @@ class Dispatcher:
return result
async def _read_packet(self) -> bytes:
length = await self._read_varint()
return await self._down.readexactly(length)
if isinstance(self._down, StreamReader):
length = await self._read_varint_from_stream()
return await self._down.readexactly(length)
elif isinstance(self._down, asyncio_dgram.DatagramServer):
data, source = await self._down.recv()
if source != self.host:
self.logger.warning("Host %s sent buffer '%s'", source, str(data))
return b''
# TODO do I need to discard size or maybe check it and merge with next packet?
return data
else:
self.logger.error("Unknown protocol, could not read from stream")
return b''
async def _write_packet(self, data:bytes):
if isinstance(self._up, StreamWriter):
self._up.write(data)
await self._up.drain() # TODO maybe no need to call drain?
elif isinstance(self._up, asyncio_dgram.DatagramClient):
await self._up.send(data)
else:
self.logger.error("Unknown protocol, could not send packet")
async def _down_worker(self, timeout:float=30):
while self._dispatching:
try: # Will timeout or raise EOFError if client gets disconnected
data = await asyncio.wait_for(self._read_packet(), timeout=timeout)
if not data:
continue
if self.encryption:
if self._encryption:
data = self._decryptor.update(data)
buffer = io.BytesIO(data)
if self.compression is not None:
decompressed_size = VarInt.read(buffer, Context(_proto=self.proto))
if self._compression is not None:
decompressed_size = VarInt.read(buffer, Context(_proto=self._proto))
if decompressed_size > 0:
decompressor = zlib.decompressobj()
decompressed_data = decompressor.decompress(buffer.read())
@ -247,14 +312,14 @@ class Dispatcher:
raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
buffer = io.BytesIO(decompressed_data)
packet_id = VarInt.read(buffer, Context(_proto=self.proto))
packet_id = VarInt.read(buffer, Context(_proto=self._proto))
if self.state == ConnectionState.PLAY and self._packet_id_whitelist \
and packet_id not in self._packet_id_whitelist:
if self._log_ignored_packets:
self.logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)
continue # ignore this packet, we rarely need them all, should improve performance
cls = self._packet_type_from_registry(packet_id)
packet = cls.deserialize(self.proto, buffer)
packet = cls.deserialize(self._proto, buffer)
self.logger.debug("[<--] Received | %s", repr(packet))
await self._incoming.put(packet)
if self.state != ConnectionState.PLAY:
@ -288,26 +353,23 @@ class Dispatcher:
buffer = packet.serialize()
length = len(buffer.getvalue()) # ewww TODO
if self.compression is not None:
if length > self.compression:
if self._compression is not None:
if length > self._compression:
new_buffer = io.BytesIO()
VarInt.write(length, new_buffer, Context(_proto=self.proto))
VarInt.write(length, new_buffer, Context(_proto=self._proto))
new_buffer.write(zlib.compress(buffer.read()))
buffer = new_buffer
else:
new_buffer = io.BytesIO()
VarInt.write(0, new_buffer, Context(_proto=self.proto))
VarInt.write(0, new_buffer, Context(_proto=self._proto))
new_buffer.write(buffer.read())
buffer = new_buffer
length = len(buffer.getvalue())
data = VarInt.serialize(length) + buffer.getvalue()
if self.encryption:
if self._encryption:
data = self._encryptor.update(data)
self._up.write(data)
await self._up.drain() # TODO maybe no need to call drain?
await self._write_packet(data)
self.logger.debug("[-->] Sent | %s", repr(packet))
except Exception:
self.logger.exception("Exception dispatching packet %s", str(packet))