improved how parameters are managed, removed duplicates

This commit is contained in:
əlemi 2022-04-18 19:32:31 +02:00
parent 7789888e03
commit 6052dc578b
No known key found for this signature in database
GPG key ID: BBCBFE5D7244634E
2 changed files with 77 additions and 91 deletions

View file

@ -28,11 +28,8 @@ from .util import encryption, helpers
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
class MinecraftClient: class MinecraftClient:
host:str
port:int
username:str
online_mode:bool online_mode:bool
authenticator:Optional[AuthInterface] authenticator:AuthInterface
dispatcher : Dispatcher dispatcher : Dispatcher
logger : logging.Logger logger : logging.Logger
_authenticated : bool _authenticated : bool
@ -42,28 +39,26 @@ class MinecraftClient:
def __init__( def __init__(
self, self,
server:str, server:str,
authenticator:AuthInterface,
online_mode:bool = True, online_mode:bool = True,
authenticator:AuthInterface=None,
username:str = "",
): ):
super().__init__() super().__init__()
if ":" in server: if ":" in server:
_host, _port = server.split(":", 1) _host, _port = server.split(":", 1)
self.host = _host.strip() host = _host.strip()
self.port = int(_port) port = int(_port)
else: else:
self.host = server.strip() host = server.strip()
self.port = 25565 port = 25565
self.username = username
self.online_mode = online_mode self.online_mode = online_mode
self.authenticator = authenticator self.authenticator = authenticator
self._authenticated = False self._authenticated = False
self.dispatcher = Dispatcher() self.dispatcher = Dispatcher().set_host(host, port)
self._processing = False self._processing = False
self.logger = LOGGER.getChild(f"on({self.host}:{self.port})") self.logger = LOGGER.getChild(f"on({server})")
@property @property
def connected(self) -> bool: def connected(self) -> bool:
@ -88,27 +83,18 @@ class MinecraftClient:
async def info(self, host:str="", port:int=0, proto:int=0, ping:bool=False) -> Dict[str, Any]: async def info(self, host:str="", port:int=0, proto:int=0, ping:bool=False) -> Dict[str, Any]:
"""Make a mini connection to asses server status and version""" """Make a mini connection to asses server status and version"""
self.host = host or self.host
self.port = port or self.port
try: try:
await self.dispatcher.connect(self.host, self.port) await self.dispatcher.set_host(host, port).connect()
await self._handshake(ConnectionState.STATUS) await self._handshake(ConnectionState.STATUS)
return await self._status(ping) return await self._status(ping)
finally: finally:
await self.dispatcher.disconnect() await self.dispatcher.disconnect()
async def join(self, host:str="", port:int=0, proto:int=0, packet_whitelist:Optional[Set[Type[Packet]]]=None): # jank packet_whitelist argument! TODO async def join(self, host:str="", port:int=0, proto:int=0):
self.host = host or self.host
self.port = port or self.port
if self.online_mode: if self.online_mode:
await self.authenticate() await self.authenticate()
try: try:
await self.dispatcher.connect( await self.dispatcher.set_host(host, port).set_proto(proto).connect()
host=self.host,
port=self.port,
proto=proto,
packet_whitelist=packet_whitelist
)
await self._handshake(ConnectionState.LOGIN) await self._handshake(ConnectionState.LOGIN)
if await self._login(): if await self._login():
await self._play() await self._play()
@ -120,8 +106,8 @@ class MinecraftClient:
PacketSetProtocol( PacketSetProtocol(
self.dispatcher.proto, self.dispatcher.proto,
protocolVersion=self.dispatcher.proto, protocolVersion=self.dispatcher.proto,
serverHost=self.host, serverHost=self.dispatcher.host,
serverPort=self.port, serverPort=self.dispatcher.port,
nextState=state.value nextState=state.value
) )
) )
@ -160,17 +146,14 @@ class MinecraftClient:
await self.dispatcher.write( await self.dispatcher.write(
PacketLoginStart( PacketLoginStart(
self.dispatcher.proto, self.dispatcher.proto,
username=self.authenticator.selectedProfile.name if self.online_mode and self.authenticator else self.username username=self.authenticator.selectedProfile.name
) )
) )
async for packet in self.dispatcher.packets(): async for packet in self.dispatcher.packets():
if isinstance(packet, PacketEncryptionBegin): if isinstance(packet, PacketEncryptionBegin):
if not self.online_mode: if not self.online_mode or not self.authenticator or not self.authenticator.accessToken: # overkill to check authenticator and accessToken but whatever
self.logger.error("Cannot answer Encryption Request in offline mode") self.logger.error("Cannot answer Encryption Request in offline mode")
return False return False
if not self.authenticator:
self.logger.error("No available token to enable encryption")
return False
secret = encryption.generate_shared_secret() secret = encryption.generate_shared_secret()
token, encrypted_secret = encryption.encrypt_token_and_secret( token, encrypted_secret = encryption.encrypt_token_and_secret(
packet.publicKey, packet.publicKey,

View file

@ -6,6 +6,7 @@ import logging
from asyncio import StreamReader, StreamWriter, Queue, Task from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum from enum import Enum
from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union
from types import ModuleType
from cryptography.hazmat.primitives.ciphers import CipherContext from cryptography.hazmat.primitives.ciphers import CipherContext
@ -39,8 +40,8 @@ class Dispatcher:
_incoming : Queue _incoming : Queue
_outgoing : Queue _outgoing : Queue
_packet_whitelist : Set[Type[Packet]] _packet_whitelist : Optional[Set[Type[Packet]]]
_packet_id_whitelist : Set[int] _packet_id_whitelist : Optional[Set[int]]
host : str host : str
port : int port : int
@ -50,14 +51,13 @@ class Dispatcher:
encryption : bool encryption : bool
compression : Optional[int] compression : Optional[int]
_logger : logging.Logger logger : logging.Logger
def __init__(self, server:bool = False): def __init__(self, server:bool = False):
self.proto = 757 self.proto = 757
self._is_server = server self._is_server = server
self.host = "localhost" self.host = "localhost"
self.port = 25565 self.port = 25565
self._prepare()
@property @property
def is_server(self) -> bool: def is_server(self) -> bool:
@ -84,57 +84,58 @@ class Dispatcher:
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass # so we recheck self.connected pass # so we recheck self.connected
def encrypt(self, secret:bytes): def encrypt(self, secret:Optional[bytes]=None) -> 'Dispatcher':
cipher = encryption.create_AES_cipher(secret) if secret is not None:
self._encryptor = cipher.encryptor() cipher = encryption.create_AES_cipher(secret)
self._decryptor = cipher.decryptor() self._encryptor = cipher.encryptor()
self.encryption = True self._decryptor = cipher.decryptor()
self._logger.info("Encryption enabled") self.encryption = True
self.logger.info("Encryption enabled")
else:
self.encryption = False
self.logger.info("Encryption disabled")
return self
def _prepare(self, def whitelist(self, ids:Optional[List[Type[Packet]]]) -> 'Dispatcher':
host:Optional[str] = None, self._packet_whitelist = set(ids) if ids is not None else None
port:Optional[int] = None,
proto:Optional[int] = None,
queue_size:int = 100,
packet_whitelist : Set[Type[Packet]] = None
):
self.proto = proto or self.proto or 757 # TODO not hardcode this?
self.host = host or self.host or "localhost"
self.port = port or self.port or 25565
self._logger = LOGGER.getChild(f"on({self.host}:{self.port})")
self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set() # just in case make new set
if self._packet_whitelist: if self._packet_whitelist:
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive) self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive)
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect) self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect)
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else None
return self
def set_host(self, host:Optional[str]="localhost", port:Optional[int]=25565) -> 'Dispatcher':
self.host = host or self.host
self.port = port or self.port
self.logger = LOGGER.getChild(f"on({self.host}:{self.port})")
return self
def set_proto(self, proto:Optional[int]=757) -> 'Dispatcher':
self.proto = proto or self.proto
if self._packet_id_whitelist:
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else set()
return self
def set_state(self, state:Optional[ConnectionState]=ConnectionState.HANDSHAKING) -> 'Dispatcher':
self.state = state or self.state
return self
async def connect(self,
reader : Optional[StreamReader] = None,
writer : Optional[StreamWriter] = None,
queue_size : int = 100,
) -> 'Dispatcher':
if self.connected:
raise InvalidState("Dispatcher already connected")
self.encryption = False self.encryption = False
self.compression = None self.compression = None
self.state = ConnectionState.HANDSHAKING
# This can only happen after we know the connection protocol
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else set()
# Make new queues, do set a max size to sorta propagate back pressure
self._incoming = Queue(queue_size) self._incoming = Queue(queue_size)
self._outgoing = Queue(queue_size) self._outgoing = Queue(queue_size)
self._dispatching = False self._dispatching = False
self._reader = None self._reader = None
self._writer = None self._writer = None
async def connect(self,
host : Optional[str] = None,
port : Optional[int] = None,
proto : Optional[int] = None,
reader : Optional[StreamReader] = None,
writer : Optional[StreamWriter] = None,
queue_size : int = 100,
packet_whitelist : Set[Type[Packet]] = None,
):
if self.connected:
raise InvalidState("Dispatcher already connected")
self._prepare(host, port, proto, queue_size, packet_whitelist)
if reader and writer: if reader and writer:
self._down, self._up = reader, writer self._down, self._up = reader, writer
else: else:
@ -146,29 +147,31 @@ class Dispatcher:
self._dispatching = True self._dispatching = True
self._reader = asyncio.get_event_loop().create_task(self._down_worker()) self._reader = asyncio.get_event_loop().create_task(self._down_worker())
self._writer = asyncio.get_event_loop().create_task(self._up_worker()) self._writer = asyncio.get_event_loop().create_task(self._up_worker())
self._logger.info("Connected") self.logger.info("Connected")
return self
async def disconnect(self, block:bool=True): async def disconnect(self, block:bool=True) -> 'Dispatcher':
self._dispatching = False self._dispatching = False
if block and self._writer and self._reader: if block and self._writer and self._reader:
await asyncio.gather(self._writer, self._reader) await asyncio.gather(self._writer, self._reader)
self._logger.debug("Net workers stopped") self.logger.debug("Net workers stopped")
if self._up: if self._up:
if not self._up.is_closing() and self._up.can_write_eof(): if not self._up.is_closing() and self._up.can_write_eof():
try: try:
self._up.write_eof() self._up.write_eof()
except OSError as e: except OSError as e:
self._logger.error("Could not write EOF : %s", str(e)) self.logger.error("Could not write EOF : %s", str(e))
self._up.close() self._up.close()
if block: if block:
await self._up.wait_closed() await self._up.wait_closed()
self._logger.debug("Socket closed") self.logger.debug("Socket closed")
if block: if block:
self._logger.info("Disconnected") self.logger.info("Disconnected")
return self
def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]: def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
# TODO de-jank this, language server gets kinda mad m : ModuleType
# m : Module
if self.state == ConnectionState.HANDSHAKING: if self.state == ConnectionState.HANDSHAKING:
m = minecraft_protocol.handshaking m = minecraft_protocol.handshaking
elif self.state == ConnectionState.STATUS: elif self.state == ConnectionState.STATUS:
@ -235,26 +238,26 @@ class Dispatcher:
packet_id = VarInt.read(buffer, Context(_proto=self.proto)) packet_id = VarInt.read(buffer, Context(_proto=self.proto))
if self.state == ConnectionState.PLAY and self._packet_id_whitelist \ if self.state == ConnectionState.PLAY and self._packet_id_whitelist \
and packet_id not in self._packet_id_whitelist: and packet_id not in self._packet_id_whitelist:
self._logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id) self.logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)
continue # ignore this packet, we rarely need them all, should improve performance continue # ignore this packet, we rarely need them all, should improve performance
cls = self._packet_type_from_registry(packet_id) cls = self._packet_type_from_registry(packet_id)
packet = cls.deserialize(self.proto, buffer) packet = cls.deserialize(self.proto, buffer)
self._logger.debug("[<--] Received | %s", repr(packet)) self.logger.debug("[<--] Received | %s", repr(packet))
await self._incoming.put(packet) await self._incoming.put(packet)
if self.state != ConnectionState.PLAY: if self.state != ConnectionState.PLAY:
await self._incoming.join() # During play we can pre-process packets await self._incoming.join() # During play we can pre-process packets
except (asyncio.TimeoutError, TimeoutError): except (asyncio.TimeoutError, TimeoutError):
self._logger.error("Connection timed out") self.logger.error("Connection timed out")
await self.disconnect(block=False) await self.disconnect(block=False)
except (ConnectionResetError, BrokenPipeError): except (ConnectionResetError, BrokenPipeError):
self._logger.error("Connection reset while reading packet") self.logger.error("Connection reset while reading packet")
await self.disconnect(block=False) await self.disconnect(block=False)
except (asyncio.IncompleteReadError, EOFError): except (asyncio.IncompleteReadError, EOFError):
self._logger.error("Received EOF while reading packet") self.logger.error("Received EOF while reading packet")
await self.disconnect(block=False) await self.disconnect(block=False)
except Exception: except Exception:
self._logger.exception("Exception parsing packet %d", packet_id) self.logger.exception("Exception parsing packet %d", packet_id)
self._logger.debug("%s", buffer.getvalue()) self.logger.debug("%s", buffer.getvalue())
await self.disconnect(block=False) await self.disconnect(block=False)
async def _up_worker(self, timeout=1): async def _up_worker(self, timeout=1):
@ -288,9 +291,9 @@ class Dispatcher:
self._up.write(data) self._up.write(data)
await self._up.drain() await self._up.drain()
self._logger.debug("[-->] Sent | %s", repr(packet)) self.logger.debug("[-->] Sent | %s", repr(packet))
except Exception: except Exception:
self._logger.exception("Exception dispatching packet %s", str(packet)) self.logger.exception("Exception dispatching packet %s", str(packet))
packet.processed.set() # Notify that packet has been processed packet.processed.set() # Notify that packet has been processed