improved how parameters are managed, removed duplicates

This commit is contained in:
əlemi 2022-04-18 19:32:31 +02:00
parent 7789888e03
commit 6052dc578b
No known key found for this signature in database
GPG key ID: BBCBFE5D7244634E
2 changed files with 77 additions and 91 deletions

View file

@ -28,11 +28,8 @@ from .util import encryption, helpers
LOGGER = logging.getLogger(__name__)
class MinecraftClient:
host:str
port:int
username:str
online_mode:bool
authenticator:Optional[AuthInterface]
authenticator:AuthInterface
dispatcher : Dispatcher
logger : logging.Logger
_authenticated : bool
@ -42,28 +39,26 @@ class MinecraftClient:
def __init__(
self,
server:str,
authenticator:AuthInterface,
online_mode:bool = True,
authenticator:AuthInterface=None,
username:str = "",
):
super().__init__()
if ":" in server:
_host, _port = server.split(":", 1)
self.host = _host.strip()
self.port = int(_port)
host = _host.strip()
port = int(_port)
else:
self.host = server.strip()
self.port = 25565
host = server.strip()
port = 25565
self.username = username
self.online_mode = online_mode
self.authenticator = authenticator
self._authenticated = False
self.dispatcher = Dispatcher()
self.dispatcher = Dispatcher().set_host(host, port)
self._processing = False
self.logger = LOGGER.getChild(f"on({self.host}:{self.port})")
self.logger = LOGGER.getChild(f"on({server})")
@property
def connected(self) -> bool:
@ -88,27 +83,18 @@ class MinecraftClient:
async def info(self, host:str="", port:int=0, proto:int=0, ping:bool=False) -> Dict[str, Any]:
"""Make a mini connection to asses server status and version"""
self.host = host or self.host
self.port = port or self.port
try:
await self.dispatcher.connect(self.host, self.port)
await self.dispatcher.set_host(host, port).connect()
await self._handshake(ConnectionState.STATUS)
return await self._status(ping)
finally:
await self.dispatcher.disconnect()
async def join(self, host:str="", port:int=0, proto:int=0, packet_whitelist:Optional[Set[Type[Packet]]]=None): # jank packet_whitelist argument! TODO
self.host = host or self.host
self.port = port or self.port
async def join(self, host:str="", port:int=0, proto:int=0):
if self.online_mode:
await self.authenticate()
try:
await self.dispatcher.connect(
host=self.host,
port=self.port,
proto=proto,
packet_whitelist=packet_whitelist
)
await self.dispatcher.set_host(host, port).set_proto(proto).connect()
await self._handshake(ConnectionState.LOGIN)
if await self._login():
await self._play()
@ -120,8 +106,8 @@ class MinecraftClient:
PacketSetProtocol(
self.dispatcher.proto,
protocolVersion=self.dispatcher.proto,
serverHost=self.host,
serverPort=self.port,
serverHost=self.dispatcher.host,
serverPort=self.dispatcher.port,
nextState=state.value
)
)
@ -160,17 +146,14 @@ class MinecraftClient:
await self.dispatcher.write(
PacketLoginStart(
self.dispatcher.proto,
username=self.authenticator.selectedProfile.name if self.online_mode and self.authenticator else self.username
username=self.authenticator.selectedProfile.name
)
)
async for packet in self.dispatcher.packets():
if isinstance(packet, PacketEncryptionBegin):
if not self.online_mode:
if not self.online_mode or not self.authenticator or not self.authenticator.accessToken: # overkill to check authenticator and accessToken but whatever
self.logger.error("Cannot answer Encryption Request in offline mode")
return False
if not self.authenticator:
self.logger.error("No available token to enable encryption")
return False
secret = encryption.generate_shared_secret()
token, encrypted_secret = encryption.encrypt_token_and_secret(
packet.publicKey,

View file

@ -6,6 +6,7 @@ import logging
from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum
from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union
from types import ModuleType
from cryptography.hazmat.primitives.ciphers import CipherContext
@ -39,8 +40,8 @@ class Dispatcher:
_incoming : Queue
_outgoing : Queue
_packet_whitelist : Set[Type[Packet]]
_packet_id_whitelist : Set[int]
_packet_whitelist : Optional[Set[Type[Packet]]]
_packet_id_whitelist : Optional[Set[int]]
host : str
port : int
@ -50,14 +51,13 @@ class Dispatcher:
encryption : bool
compression : Optional[int]
_logger : logging.Logger
logger : logging.Logger
def __init__(self, server:bool = False):
self.proto = 757
self._is_server = server
self.host = "localhost"
self.port = 25565
self._prepare()
@property
def is_server(self) -> bool:
@ -84,57 +84,58 @@ class Dispatcher:
except asyncio.TimeoutError:
pass # so we recheck self.connected
def encrypt(self, secret:bytes):
def encrypt(self, secret:Optional[bytes]=None) -> 'Dispatcher':
if secret is not None:
cipher = encryption.create_AES_cipher(secret)
self._encryptor = cipher.encryptor()
self._decryptor = cipher.decryptor()
self.encryption = True
self._logger.info("Encryption enabled")
self.logger.info("Encryption enabled")
else:
self.encryption = False
self.logger.info("Encryption disabled")
return self
def _prepare(self,
host:Optional[str] = None,
port:Optional[int] = None,
proto:Optional[int] = None,
queue_size:int = 100,
packet_whitelist : Set[Type[Packet]] = None
):
self.proto = proto or self.proto or 757 # TODO not hardcode this?
self.host = host or self.host or "localhost"
self.port = port or self.port or 25565
self._logger = LOGGER.getChild(f"on({self.host}:{self.port})")
self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set() # just in case make new set
def whitelist(self, ids:Optional[List[Type[Packet]]]) -> 'Dispatcher':
self._packet_whitelist = set(ids) if ids is not None else None
if self._packet_whitelist:
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive)
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect)
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else None
return self
def set_host(self, host:Optional[str]="localhost", port:Optional[int]=25565) -> 'Dispatcher':
self.host = host or self.host
self.port = port or self.port
self.logger = LOGGER.getChild(f"on({self.host}:{self.port})")
return self
def set_proto(self, proto:Optional[int]=757) -> 'Dispatcher':
self.proto = proto or self.proto
if self._packet_id_whitelist:
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else set()
return self
def set_state(self, state:Optional[ConnectionState]=ConnectionState.HANDSHAKING) -> 'Dispatcher':
self.state = state or self.state
return self
async def connect(self,
reader : Optional[StreamReader] = None,
writer : Optional[StreamWriter] = None,
queue_size : int = 100,
) -> 'Dispatcher':
if self.connected:
raise InvalidState("Dispatcher already connected")
self.encryption = False
self.compression = None
self.state = ConnectionState.HANDSHAKING
# This can only happen after we know the connection protocol
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else set()
# Make new queues, do set a max size to sorta propagate back pressure
self._incoming = Queue(queue_size)
self._outgoing = Queue(queue_size)
self._dispatching = False
self._reader = None
self._writer = None
async def connect(self,
host : Optional[str] = None,
port : Optional[int] = None,
proto : Optional[int] = None,
reader : Optional[StreamReader] = None,
writer : Optional[StreamWriter] = None,
queue_size : int = 100,
packet_whitelist : Set[Type[Packet]] = None,
):
if self.connected:
raise InvalidState("Dispatcher already connected")
self._prepare(host, port, proto, queue_size, packet_whitelist)
if reader and writer:
self._down, self._up = reader, writer
else:
@ -146,29 +147,31 @@ class Dispatcher:
self._dispatching = True
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
self._writer = asyncio.get_event_loop().create_task(self._up_worker())
self._logger.info("Connected")
self.logger.info("Connected")
return self
async def disconnect(self, block:bool=True):
async def disconnect(self, block:bool=True) -> 'Dispatcher':
self._dispatching = False
if block and self._writer and self._reader:
await asyncio.gather(self._writer, self._reader)
self._logger.debug("Net workers stopped")
self.logger.debug("Net workers stopped")
if self._up:
if not self._up.is_closing() and self._up.can_write_eof():
try:
self._up.write_eof()
except OSError as e:
self._logger.error("Could not write EOF : %s", str(e))
self.logger.error("Could not write EOF : %s", str(e))
self._up.close()
if block:
await self._up.wait_closed()
self._logger.debug("Socket closed")
self.logger.debug("Socket closed")
if block:
self._logger.info("Disconnected")
self.logger.info("Disconnected")
return self
def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
# TODO de-jank this, language server gets kinda mad
# m : Module
m : ModuleType
if self.state == ConnectionState.HANDSHAKING:
m = minecraft_protocol.handshaking
elif self.state == ConnectionState.STATUS:
@ -235,26 +238,26 @@ class Dispatcher:
packet_id = VarInt.read(buffer, Context(_proto=self.proto))
if self.state == ConnectionState.PLAY and self._packet_id_whitelist \
and packet_id not in self._packet_id_whitelist:
self._logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)
self.logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)
continue # ignore this packet, we rarely need them all, should improve performance
cls = self._packet_type_from_registry(packet_id)
packet = cls.deserialize(self.proto, buffer)
self._logger.debug("[<--] Received | %s", repr(packet))
self.logger.debug("[<--] Received | %s", repr(packet))
await self._incoming.put(packet)
if self.state != ConnectionState.PLAY:
await self._incoming.join() # During play we can pre-process packets
except (asyncio.TimeoutError, TimeoutError):
self._logger.error("Connection timed out")
self.logger.error("Connection timed out")
await self.disconnect(block=False)
except (ConnectionResetError, BrokenPipeError):
self._logger.error("Connection reset while reading packet")
self.logger.error("Connection reset while reading packet")
await self.disconnect(block=False)
except (asyncio.IncompleteReadError, EOFError):
self._logger.error("Received EOF while reading packet")
self.logger.error("Received EOF while reading packet")
await self.disconnect(block=False)
except Exception:
self._logger.exception("Exception parsing packet %d", packet_id)
self._logger.debug("%s", buffer.getvalue())
self.logger.exception("Exception parsing packet %d", packet_id)
self.logger.debug("%s", buffer.getvalue())
await self.disconnect(block=False)
async def _up_worker(self, timeout=1):
@ -288,9 +291,9 @@ class Dispatcher:
self._up.write(data)
await self._up.drain()
self._logger.debug("[-->] Sent | %s", repr(packet))
self.logger.debug("[-->] Sent | %s", repr(packet))
except Exception:
self._logger.exception("Exception dispatching packet %s", str(packet))
self.logger.exception("Exception dispatching packet %s", str(packet))
packet.processed.set() # Notify that packet has been processed