improved how parameters are managed, removed duplicates
This commit is contained in:
parent
7789888e03
commit
6052dc578b
2 changed files with 77 additions and 91 deletions
|
@ -28,11 +28,8 @@ from .util import encryption, helpers
|
|||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class MinecraftClient:
|
||||
host:str
|
||||
port:int
|
||||
username:str
|
||||
online_mode:bool
|
||||
authenticator:Optional[AuthInterface]
|
||||
authenticator:AuthInterface
|
||||
dispatcher : Dispatcher
|
||||
logger : logging.Logger
|
||||
_authenticated : bool
|
||||
|
@ -42,28 +39,26 @@ class MinecraftClient:
|
|||
def __init__(
|
||||
self,
|
||||
server:str,
|
||||
authenticator:AuthInterface,
|
||||
online_mode:bool = True,
|
||||
authenticator:AuthInterface=None,
|
||||
username:str = "",
|
||||
):
|
||||
super().__init__()
|
||||
if ":" in server:
|
||||
_host, _port = server.split(":", 1)
|
||||
self.host = _host.strip()
|
||||
self.port = int(_port)
|
||||
host = _host.strip()
|
||||
port = int(_port)
|
||||
else:
|
||||
self.host = server.strip()
|
||||
self.port = 25565
|
||||
host = server.strip()
|
||||
port = 25565
|
||||
|
||||
self.username = username
|
||||
self.online_mode = online_mode
|
||||
self.authenticator = authenticator
|
||||
self._authenticated = False
|
||||
|
||||
self.dispatcher = Dispatcher()
|
||||
self.dispatcher = Dispatcher().set_host(host, port)
|
||||
self._processing = False
|
||||
|
||||
self.logger = LOGGER.getChild(f"on({self.host}:{self.port})")
|
||||
self.logger = LOGGER.getChild(f"on({server})")
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
|
@ -88,27 +83,18 @@ class MinecraftClient:
|
|||
|
||||
async def info(self, host:str="", port:int=0, proto:int=0, ping:bool=False) -> Dict[str, Any]:
|
||||
"""Make a mini connection to asses server status and version"""
|
||||
self.host = host or self.host
|
||||
self.port = port or self.port
|
||||
try:
|
||||
await self.dispatcher.connect(self.host, self.port)
|
||||
await self.dispatcher.set_host(host, port).connect()
|
||||
await self._handshake(ConnectionState.STATUS)
|
||||
return await self._status(ping)
|
||||
finally:
|
||||
await self.dispatcher.disconnect()
|
||||
|
||||
async def join(self, host:str="", port:int=0, proto:int=0, packet_whitelist:Optional[Set[Type[Packet]]]=None): # jank packet_whitelist argument! TODO
|
||||
self.host = host or self.host
|
||||
self.port = port or self.port
|
||||
async def join(self, host:str="", port:int=0, proto:int=0):
|
||||
if self.online_mode:
|
||||
await self.authenticate()
|
||||
try:
|
||||
await self.dispatcher.connect(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
proto=proto,
|
||||
packet_whitelist=packet_whitelist
|
||||
)
|
||||
await self.dispatcher.set_host(host, port).set_proto(proto).connect()
|
||||
await self._handshake(ConnectionState.LOGIN)
|
||||
if await self._login():
|
||||
await self._play()
|
||||
|
@ -120,8 +106,8 @@ class MinecraftClient:
|
|||
PacketSetProtocol(
|
||||
self.dispatcher.proto,
|
||||
protocolVersion=self.dispatcher.proto,
|
||||
serverHost=self.host,
|
||||
serverPort=self.port,
|
||||
serverHost=self.dispatcher.host,
|
||||
serverPort=self.dispatcher.port,
|
||||
nextState=state.value
|
||||
)
|
||||
)
|
||||
|
@ -160,17 +146,14 @@ class MinecraftClient:
|
|||
await self.dispatcher.write(
|
||||
PacketLoginStart(
|
||||
self.dispatcher.proto,
|
||||
username=self.authenticator.selectedProfile.name if self.online_mode and self.authenticator else self.username
|
||||
username=self.authenticator.selectedProfile.name
|
||||
)
|
||||
)
|
||||
async for packet in self.dispatcher.packets():
|
||||
if isinstance(packet, PacketEncryptionBegin):
|
||||
if not self.online_mode:
|
||||
if not self.online_mode or not self.authenticator or not self.authenticator.accessToken: # overkill to check authenticator and accessToken but whatever
|
||||
self.logger.error("Cannot answer Encryption Request in offline mode")
|
||||
return False
|
||||
if not self.authenticator:
|
||||
self.logger.error("No available token to enable encryption")
|
||||
return False
|
||||
secret = encryption.generate_shared_secret()
|
||||
token, encrypted_secret = encryption.encrypt_token_and_secret(
|
||||
packet.publicKey,
|
||||
|
|
|
@ -6,6 +6,7 @@ import logging
|
|||
from asyncio import StreamReader, StreamWriter, Queue, Task
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union
|
||||
from types import ModuleType
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers import CipherContext
|
||||
|
||||
|
@ -39,8 +40,8 @@ class Dispatcher:
|
|||
_incoming : Queue
|
||||
_outgoing : Queue
|
||||
|
||||
_packet_whitelist : Set[Type[Packet]]
|
||||
_packet_id_whitelist : Set[int]
|
||||
_packet_whitelist : Optional[Set[Type[Packet]]]
|
||||
_packet_id_whitelist : Optional[Set[int]]
|
||||
|
||||
host : str
|
||||
port : int
|
||||
|
@ -50,14 +51,13 @@ class Dispatcher:
|
|||
encryption : bool
|
||||
compression : Optional[int]
|
||||
|
||||
_logger : logging.Logger
|
||||
logger : logging.Logger
|
||||
|
||||
def __init__(self, server:bool = False):
|
||||
self.proto = 757
|
||||
self._is_server = server
|
||||
self.host = "localhost"
|
||||
self.port = 25565
|
||||
self._prepare()
|
||||
|
||||
@property
|
||||
def is_server(self) -> bool:
|
||||
|
@ -84,57 +84,58 @@ class Dispatcher:
|
|||
except asyncio.TimeoutError:
|
||||
pass # so we recheck self.connected
|
||||
|
||||
def encrypt(self, secret:bytes):
|
||||
cipher = encryption.create_AES_cipher(secret)
|
||||
self._encryptor = cipher.encryptor()
|
||||
self._decryptor = cipher.decryptor()
|
||||
self.encryption = True
|
||||
self._logger.info("Encryption enabled")
|
||||
def encrypt(self, secret:Optional[bytes]=None) -> 'Dispatcher':
|
||||
if secret is not None:
|
||||
cipher = encryption.create_AES_cipher(secret)
|
||||
self._encryptor = cipher.encryptor()
|
||||
self._decryptor = cipher.decryptor()
|
||||
self.encryption = True
|
||||
self.logger.info("Encryption enabled")
|
||||
else:
|
||||
self.encryption = False
|
||||
self.logger.info("Encryption disabled")
|
||||
return self
|
||||
|
||||
def _prepare(self,
|
||||
host:Optional[str] = None,
|
||||
port:Optional[int] = None,
|
||||
proto:Optional[int] = None,
|
||||
queue_size:int = 100,
|
||||
packet_whitelist : Set[Type[Packet]] = None
|
||||
):
|
||||
self.proto = proto or self.proto or 757 # TODO not hardcode this?
|
||||
self.host = host or self.host or "localhost"
|
||||
self.port = port or self.port or 25565
|
||||
self._logger = LOGGER.getChild(f"on({self.host}:{self.port})")
|
||||
self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set() # just in case make new set
|
||||
def whitelist(self, ids:Optional[List[Type[Packet]]]) -> 'Dispatcher':
|
||||
self._packet_whitelist = set(ids) if ids is not None else None
|
||||
if self._packet_whitelist:
|
||||
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive)
|
||||
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect)
|
||||
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else None
|
||||
return self
|
||||
|
||||
def set_host(self, host:Optional[str]="localhost", port:Optional[int]=25565) -> 'Dispatcher':
|
||||
self.host = host or self.host
|
||||
self.port = port or self.port
|
||||
self.logger = LOGGER.getChild(f"on({self.host}:{self.port})")
|
||||
return self
|
||||
|
||||
def set_proto(self, proto:Optional[int]=757) -> 'Dispatcher':
|
||||
self.proto = proto or self.proto
|
||||
if self._packet_id_whitelist:
|
||||
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else set()
|
||||
return self
|
||||
|
||||
def set_state(self, state:Optional[ConnectionState]=ConnectionState.HANDSHAKING) -> 'Dispatcher':
|
||||
self.state = state or self.state
|
||||
return self
|
||||
|
||||
async def connect(self,
|
||||
reader : Optional[StreamReader] = None,
|
||||
writer : Optional[StreamWriter] = None,
|
||||
queue_size : int = 100,
|
||||
) -> 'Dispatcher':
|
||||
if self.connected:
|
||||
raise InvalidState("Dispatcher already connected")
|
||||
|
||||
self.encryption = False
|
||||
self.compression = None
|
||||
self.state = ConnectionState.HANDSHAKING
|
||||
|
||||
# This can only happen after we know the connection protocol
|
||||
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else set()
|
||||
|
||||
# Make new queues, do set a max size to sorta propagate back pressure
|
||||
self._incoming = Queue(queue_size)
|
||||
self._outgoing = Queue(queue_size)
|
||||
self._dispatching = False
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
|
||||
async def connect(self,
|
||||
host : Optional[str] = None,
|
||||
port : Optional[int] = None,
|
||||
proto : Optional[int] = None,
|
||||
reader : Optional[StreamReader] = None,
|
||||
writer : Optional[StreamWriter] = None,
|
||||
queue_size : int = 100,
|
||||
packet_whitelist : Set[Type[Packet]] = None,
|
||||
):
|
||||
if self.connected:
|
||||
raise InvalidState("Dispatcher already connected")
|
||||
|
||||
self._prepare(host, port, proto, queue_size, packet_whitelist)
|
||||
|
||||
if reader and writer:
|
||||
self._down, self._up = reader, writer
|
||||
else:
|
||||
|
@ -146,29 +147,31 @@ class Dispatcher:
|
|||
self._dispatching = True
|
||||
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
|
||||
self._writer = asyncio.get_event_loop().create_task(self._up_worker())
|
||||
self._logger.info("Connected")
|
||||
self.logger.info("Connected")
|
||||
return self
|
||||
|
||||
async def disconnect(self, block:bool=True):
|
||||
async def disconnect(self, block:bool=True) -> 'Dispatcher':
|
||||
self._dispatching = False
|
||||
if block and self._writer and self._reader:
|
||||
await asyncio.gather(self._writer, self._reader)
|
||||
self._logger.debug("Net workers stopped")
|
||||
self.logger.debug("Net workers stopped")
|
||||
if self._up:
|
||||
if not self._up.is_closing() and self._up.can_write_eof():
|
||||
try:
|
||||
self._up.write_eof()
|
||||
except OSError as e:
|
||||
self._logger.error("Could not write EOF : %s", str(e))
|
||||
self.logger.error("Could not write EOF : %s", str(e))
|
||||
self._up.close()
|
||||
if block:
|
||||
await self._up.wait_closed()
|
||||
self._logger.debug("Socket closed")
|
||||
self.logger.debug("Socket closed")
|
||||
if block:
|
||||
self._logger.info("Disconnected")
|
||||
self.logger.info("Disconnected")
|
||||
return self
|
||||
|
||||
def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
|
||||
# TODO de-jank this, language server gets kinda mad
|
||||
# m : Module
|
||||
m : ModuleType
|
||||
|
||||
if self.state == ConnectionState.HANDSHAKING:
|
||||
m = minecraft_protocol.handshaking
|
||||
elif self.state == ConnectionState.STATUS:
|
||||
|
@ -235,26 +238,26 @@ class Dispatcher:
|
|||
packet_id = VarInt.read(buffer, Context(_proto=self.proto))
|
||||
if self.state == ConnectionState.PLAY and self._packet_id_whitelist \
|
||||
and packet_id not in self._packet_id_whitelist:
|
||||
self._logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)
|
||||
self.logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)
|
||||
continue # ignore this packet, we rarely need them all, should improve performance
|
||||
cls = self._packet_type_from_registry(packet_id)
|
||||
packet = cls.deserialize(self.proto, buffer)
|
||||
self._logger.debug("[<--] Received | %s", repr(packet))
|
||||
self.logger.debug("[<--] Received | %s", repr(packet))
|
||||
await self._incoming.put(packet)
|
||||
if self.state != ConnectionState.PLAY:
|
||||
await self._incoming.join() # During play we can pre-process packets
|
||||
except (asyncio.TimeoutError, TimeoutError):
|
||||
self._logger.error("Connection timed out")
|
||||
self.logger.error("Connection timed out")
|
||||
await self.disconnect(block=False)
|
||||
except (ConnectionResetError, BrokenPipeError):
|
||||
self._logger.error("Connection reset while reading packet")
|
||||
self.logger.error("Connection reset while reading packet")
|
||||
await self.disconnect(block=False)
|
||||
except (asyncio.IncompleteReadError, EOFError):
|
||||
self._logger.error("Received EOF while reading packet")
|
||||
self.logger.error("Received EOF while reading packet")
|
||||
await self.disconnect(block=False)
|
||||
except Exception:
|
||||
self._logger.exception("Exception parsing packet %d", packet_id)
|
||||
self._logger.debug("%s", buffer.getvalue())
|
||||
self.logger.exception("Exception parsing packet %d", packet_id)
|
||||
self.logger.debug("%s", buffer.getvalue())
|
||||
await self.disconnect(block=False)
|
||||
|
||||
async def _up_worker(self, timeout=1):
|
||||
|
@ -288,9 +291,9 @@ class Dispatcher:
|
|||
self._up.write(data)
|
||||
await self._up.drain()
|
||||
|
||||
self._logger.debug("[-->] Sent | %s", repr(packet))
|
||||
self.logger.debug("[-->] Sent | %s", repr(packet))
|
||||
except Exception:
|
||||
self._logger.exception("Exception dispatching packet %s", str(packet))
|
||||
self.logger.exception("Exception dispatching packet %s", str(packet))
|
||||
|
||||
packet.processed.set() # Notify that packet has been processed
|
||||
|
||||
|
|
Loading…
Reference in a new issue