improved how parameters are managed, removed duplicates
This commit is contained in:
parent
7789888e03
commit
6052dc578b
2 changed files with 77 additions and 91 deletions
|
@ -28,11 +28,8 @@ from .util import encryption, helpers
|
||||||
LOGGER = logging.getLogger(__name__)
|
LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
class MinecraftClient:
|
class MinecraftClient:
|
||||||
host:str
|
|
||||||
port:int
|
|
||||||
username:str
|
|
||||||
online_mode:bool
|
online_mode:bool
|
||||||
authenticator:Optional[AuthInterface]
|
authenticator:AuthInterface
|
||||||
dispatcher : Dispatcher
|
dispatcher : Dispatcher
|
||||||
logger : logging.Logger
|
logger : logging.Logger
|
||||||
_authenticated : bool
|
_authenticated : bool
|
||||||
|
@ -42,28 +39,26 @@ class MinecraftClient:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
server:str,
|
server:str,
|
||||||
|
authenticator:AuthInterface,
|
||||||
online_mode:bool = True,
|
online_mode:bool = True,
|
||||||
authenticator:AuthInterface=None,
|
|
||||||
username:str = "",
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if ":" in server:
|
if ":" in server:
|
||||||
_host, _port = server.split(":", 1)
|
_host, _port = server.split(":", 1)
|
||||||
self.host = _host.strip()
|
host = _host.strip()
|
||||||
self.port = int(_port)
|
port = int(_port)
|
||||||
else:
|
else:
|
||||||
self.host = server.strip()
|
host = server.strip()
|
||||||
self.port = 25565
|
port = 25565
|
||||||
|
|
||||||
self.username = username
|
|
||||||
self.online_mode = online_mode
|
self.online_mode = online_mode
|
||||||
self.authenticator = authenticator
|
self.authenticator = authenticator
|
||||||
self._authenticated = False
|
self._authenticated = False
|
||||||
|
|
||||||
self.dispatcher = Dispatcher()
|
self.dispatcher = Dispatcher().set_host(host, port)
|
||||||
self._processing = False
|
self._processing = False
|
||||||
|
|
||||||
self.logger = LOGGER.getChild(f"on({self.host}:{self.port})")
|
self.logger = LOGGER.getChild(f"on({server})")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def connected(self) -> bool:
|
def connected(self) -> bool:
|
||||||
|
@ -88,27 +83,18 @@ class MinecraftClient:
|
||||||
|
|
||||||
async def info(self, host:str="", port:int=0, proto:int=0, ping:bool=False) -> Dict[str, Any]:
|
async def info(self, host:str="", port:int=0, proto:int=0, ping:bool=False) -> Dict[str, Any]:
|
||||||
"""Make a mini connection to asses server status and version"""
|
"""Make a mini connection to asses server status and version"""
|
||||||
self.host = host or self.host
|
|
||||||
self.port = port or self.port
|
|
||||||
try:
|
try:
|
||||||
await self.dispatcher.connect(self.host, self.port)
|
await self.dispatcher.set_host(host, port).connect()
|
||||||
await self._handshake(ConnectionState.STATUS)
|
await self._handshake(ConnectionState.STATUS)
|
||||||
return await self._status(ping)
|
return await self._status(ping)
|
||||||
finally:
|
finally:
|
||||||
await self.dispatcher.disconnect()
|
await self.dispatcher.disconnect()
|
||||||
|
|
||||||
async def join(self, host:str="", port:int=0, proto:int=0, packet_whitelist:Optional[Set[Type[Packet]]]=None): # jank packet_whitelist argument! TODO
|
async def join(self, host:str="", port:int=0, proto:int=0):
|
||||||
self.host = host or self.host
|
|
||||||
self.port = port or self.port
|
|
||||||
if self.online_mode:
|
if self.online_mode:
|
||||||
await self.authenticate()
|
await self.authenticate()
|
||||||
try:
|
try:
|
||||||
await self.dispatcher.connect(
|
await self.dispatcher.set_host(host, port).set_proto(proto).connect()
|
||||||
host=self.host,
|
|
||||||
port=self.port,
|
|
||||||
proto=proto,
|
|
||||||
packet_whitelist=packet_whitelist
|
|
||||||
)
|
|
||||||
await self._handshake(ConnectionState.LOGIN)
|
await self._handshake(ConnectionState.LOGIN)
|
||||||
if await self._login():
|
if await self._login():
|
||||||
await self._play()
|
await self._play()
|
||||||
|
@ -120,8 +106,8 @@ class MinecraftClient:
|
||||||
PacketSetProtocol(
|
PacketSetProtocol(
|
||||||
self.dispatcher.proto,
|
self.dispatcher.proto,
|
||||||
protocolVersion=self.dispatcher.proto,
|
protocolVersion=self.dispatcher.proto,
|
||||||
serverHost=self.host,
|
serverHost=self.dispatcher.host,
|
||||||
serverPort=self.port,
|
serverPort=self.dispatcher.port,
|
||||||
nextState=state.value
|
nextState=state.value
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -160,17 +146,14 @@ class MinecraftClient:
|
||||||
await self.dispatcher.write(
|
await self.dispatcher.write(
|
||||||
PacketLoginStart(
|
PacketLoginStart(
|
||||||
self.dispatcher.proto,
|
self.dispatcher.proto,
|
||||||
username=self.authenticator.selectedProfile.name if self.online_mode and self.authenticator else self.username
|
username=self.authenticator.selectedProfile.name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
async for packet in self.dispatcher.packets():
|
async for packet in self.dispatcher.packets():
|
||||||
if isinstance(packet, PacketEncryptionBegin):
|
if isinstance(packet, PacketEncryptionBegin):
|
||||||
if not self.online_mode:
|
if not self.online_mode or not self.authenticator or not self.authenticator.accessToken: # overkill to check authenticator and accessToken but whatever
|
||||||
self.logger.error("Cannot answer Encryption Request in offline mode")
|
self.logger.error("Cannot answer Encryption Request in offline mode")
|
||||||
return False
|
return False
|
||||||
if not self.authenticator:
|
|
||||||
self.logger.error("No available token to enable encryption")
|
|
||||||
return False
|
|
||||||
secret = encryption.generate_shared_secret()
|
secret = encryption.generate_shared_secret()
|
||||||
token, encrypted_secret = encryption.encrypt_token_and_secret(
|
token, encrypted_secret = encryption.encrypt_token_and_secret(
|
||||||
packet.publicKey,
|
packet.publicKey,
|
||||||
|
|
|
@ -6,6 +6,7 @@ import logging
|
||||||
from asyncio import StreamReader, StreamWriter, Queue, Task
|
from asyncio import StreamReader, StreamWriter, Queue, Task
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union
|
from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union
|
||||||
|
from types import ModuleType
|
||||||
|
|
||||||
from cryptography.hazmat.primitives.ciphers import CipherContext
|
from cryptography.hazmat.primitives.ciphers import CipherContext
|
||||||
|
|
||||||
|
@ -39,8 +40,8 @@ class Dispatcher:
|
||||||
_incoming : Queue
|
_incoming : Queue
|
||||||
_outgoing : Queue
|
_outgoing : Queue
|
||||||
|
|
||||||
_packet_whitelist : Set[Type[Packet]]
|
_packet_whitelist : Optional[Set[Type[Packet]]]
|
||||||
_packet_id_whitelist : Set[int]
|
_packet_id_whitelist : Optional[Set[int]]
|
||||||
|
|
||||||
host : str
|
host : str
|
||||||
port : int
|
port : int
|
||||||
|
@ -50,14 +51,13 @@ class Dispatcher:
|
||||||
encryption : bool
|
encryption : bool
|
||||||
compression : Optional[int]
|
compression : Optional[int]
|
||||||
|
|
||||||
_logger : logging.Logger
|
logger : logging.Logger
|
||||||
|
|
||||||
def __init__(self, server:bool = False):
|
def __init__(self, server:bool = False):
|
||||||
self.proto = 757
|
self.proto = 757
|
||||||
self._is_server = server
|
self._is_server = server
|
||||||
self.host = "localhost"
|
self.host = "localhost"
|
||||||
self.port = 25565
|
self.port = 25565
|
||||||
self._prepare()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_server(self) -> bool:
|
def is_server(self) -> bool:
|
||||||
|
@ -84,57 +84,58 @@ class Dispatcher:
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
pass # so we recheck self.connected
|
pass # so we recheck self.connected
|
||||||
|
|
||||||
def encrypt(self, secret:bytes):
|
def encrypt(self, secret:Optional[bytes]=None) -> 'Dispatcher':
|
||||||
cipher = encryption.create_AES_cipher(secret)
|
if secret is not None:
|
||||||
self._encryptor = cipher.encryptor()
|
cipher = encryption.create_AES_cipher(secret)
|
||||||
self._decryptor = cipher.decryptor()
|
self._encryptor = cipher.encryptor()
|
||||||
self.encryption = True
|
self._decryptor = cipher.decryptor()
|
||||||
self._logger.info("Encryption enabled")
|
self.encryption = True
|
||||||
|
self.logger.info("Encryption enabled")
|
||||||
|
else:
|
||||||
|
self.encryption = False
|
||||||
|
self.logger.info("Encryption disabled")
|
||||||
|
return self
|
||||||
|
|
||||||
def _prepare(self,
|
def whitelist(self, ids:Optional[List[Type[Packet]]]) -> 'Dispatcher':
|
||||||
host:Optional[str] = None,
|
self._packet_whitelist = set(ids) if ids is not None else None
|
||||||
port:Optional[int] = None,
|
|
||||||
proto:Optional[int] = None,
|
|
||||||
queue_size:int = 100,
|
|
||||||
packet_whitelist : Set[Type[Packet]] = None
|
|
||||||
):
|
|
||||||
self.proto = proto or self.proto or 757 # TODO not hardcode this?
|
|
||||||
self.host = host or self.host or "localhost"
|
|
||||||
self.port = port or self.port or 25565
|
|
||||||
self._logger = LOGGER.getChild(f"on({self.host}:{self.port})")
|
|
||||||
self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set() # just in case make new set
|
|
||||||
if self._packet_whitelist:
|
if self._packet_whitelist:
|
||||||
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive)
|
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive)
|
||||||
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect)
|
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect)
|
||||||
|
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else None
|
||||||
|
return self
|
||||||
|
|
||||||
|
def set_host(self, host:Optional[str]="localhost", port:Optional[int]=25565) -> 'Dispatcher':
|
||||||
|
self.host = host or self.host
|
||||||
|
self.port = port or self.port
|
||||||
|
self.logger = LOGGER.getChild(f"on({self.host}:{self.port})")
|
||||||
|
return self
|
||||||
|
|
||||||
|
def set_proto(self, proto:Optional[int]=757) -> 'Dispatcher':
|
||||||
|
self.proto = proto or self.proto
|
||||||
|
if self._packet_id_whitelist:
|
||||||
|
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else set()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def set_state(self, state:Optional[ConnectionState]=ConnectionState.HANDSHAKING) -> 'Dispatcher':
|
||||||
|
self.state = state or self.state
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def connect(self,
|
||||||
|
reader : Optional[StreamReader] = None,
|
||||||
|
writer : Optional[StreamWriter] = None,
|
||||||
|
queue_size : int = 100,
|
||||||
|
) -> 'Dispatcher':
|
||||||
|
if self.connected:
|
||||||
|
raise InvalidState("Dispatcher already connected")
|
||||||
|
|
||||||
self.encryption = False
|
self.encryption = False
|
||||||
self.compression = None
|
self.compression = None
|
||||||
self.state = ConnectionState.HANDSHAKING
|
|
||||||
|
|
||||||
# This can only happen after we know the connection protocol
|
|
||||||
self._packet_id_whitelist = set((P(self.proto).id for P in self._packet_whitelist)) if self._packet_whitelist else set()
|
|
||||||
|
|
||||||
# Make new queues, do set a max size to sorta propagate back pressure
|
|
||||||
self._incoming = Queue(queue_size)
|
self._incoming = Queue(queue_size)
|
||||||
self._outgoing = Queue(queue_size)
|
self._outgoing = Queue(queue_size)
|
||||||
self._dispatching = False
|
self._dispatching = False
|
||||||
self._reader = None
|
self._reader = None
|
||||||
self._writer = None
|
self._writer = None
|
||||||
|
|
||||||
async def connect(self,
|
|
||||||
host : Optional[str] = None,
|
|
||||||
port : Optional[int] = None,
|
|
||||||
proto : Optional[int] = None,
|
|
||||||
reader : Optional[StreamReader] = None,
|
|
||||||
writer : Optional[StreamWriter] = None,
|
|
||||||
queue_size : int = 100,
|
|
||||||
packet_whitelist : Set[Type[Packet]] = None,
|
|
||||||
):
|
|
||||||
if self.connected:
|
|
||||||
raise InvalidState("Dispatcher already connected")
|
|
||||||
|
|
||||||
self._prepare(host, port, proto, queue_size, packet_whitelist)
|
|
||||||
|
|
||||||
if reader and writer:
|
if reader and writer:
|
||||||
self._down, self._up = reader, writer
|
self._down, self._up = reader, writer
|
||||||
else:
|
else:
|
||||||
|
@ -146,29 +147,31 @@ class Dispatcher:
|
||||||
self._dispatching = True
|
self._dispatching = True
|
||||||
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
|
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
|
||||||
self._writer = asyncio.get_event_loop().create_task(self._up_worker())
|
self._writer = asyncio.get_event_loop().create_task(self._up_worker())
|
||||||
self._logger.info("Connected")
|
self.logger.info("Connected")
|
||||||
|
return self
|
||||||
|
|
||||||
async def disconnect(self, block:bool=True):
|
async def disconnect(self, block:bool=True) -> 'Dispatcher':
|
||||||
self._dispatching = False
|
self._dispatching = False
|
||||||
if block and self._writer and self._reader:
|
if block and self._writer and self._reader:
|
||||||
await asyncio.gather(self._writer, self._reader)
|
await asyncio.gather(self._writer, self._reader)
|
||||||
self._logger.debug("Net workers stopped")
|
self.logger.debug("Net workers stopped")
|
||||||
if self._up:
|
if self._up:
|
||||||
if not self._up.is_closing() and self._up.can_write_eof():
|
if not self._up.is_closing() and self._up.can_write_eof():
|
||||||
try:
|
try:
|
||||||
self._up.write_eof()
|
self._up.write_eof()
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
self._logger.error("Could not write EOF : %s", str(e))
|
self.logger.error("Could not write EOF : %s", str(e))
|
||||||
self._up.close()
|
self._up.close()
|
||||||
if block:
|
if block:
|
||||||
await self._up.wait_closed()
|
await self._up.wait_closed()
|
||||||
self._logger.debug("Socket closed")
|
self.logger.debug("Socket closed")
|
||||||
if block:
|
if block:
|
||||||
self._logger.info("Disconnected")
|
self.logger.info("Disconnected")
|
||||||
|
return self
|
||||||
|
|
||||||
def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
|
def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
|
||||||
# TODO de-jank this, language server gets kinda mad
|
m : ModuleType
|
||||||
# m : Module
|
|
||||||
if self.state == ConnectionState.HANDSHAKING:
|
if self.state == ConnectionState.HANDSHAKING:
|
||||||
m = minecraft_protocol.handshaking
|
m = minecraft_protocol.handshaking
|
||||||
elif self.state == ConnectionState.STATUS:
|
elif self.state == ConnectionState.STATUS:
|
||||||
|
@ -235,26 +238,26 @@ class Dispatcher:
|
||||||
packet_id = VarInt.read(buffer, Context(_proto=self.proto))
|
packet_id = VarInt.read(buffer, Context(_proto=self.proto))
|
||||||
if self.state == ConnectionState.PLAY and self._packet_id_whitelist \
|
if self.state == ConnectionState.PLAY and self._packet_id_whitelist \
|
||||||
and packet_id not in self._packet_id_whitelist:
|
and packet_id not in self._packet_id_whitelist:
|
||||||
self._logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)
|
self.logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)
|
||||||
continue # ignore this packet, we rarely need them all, should improve performance
|
continue # ignore this packet, we rarely need them all, should improve performance
|
||||||
cls = self._packet_type_from_registry(packet_id)
|
cls = self._packet_type_from_registry(packet_id)
|
||||||
packet = cls.deserialize(self.proto, buffer)
|
packet = cls.deserialize(self.proto, buffer)
|
||||||
self._logger.debug("[<--] Received | %s", repr(packet))
|
self.logger.debug("[<--] Received | %s", repr(packet))
|
||||||
await self._incoming.put(packet)
|
await self._incoming.put(packet)
|
||||||
if self.state != ConnectionState.PLAY:
|
if self.state != ConnectionState.PLAY:
|
||||||
await self._incoming.join() # During play we can pre-process packets
|
await self._incoming.join() # During play we can pre-process packets
|
||||||
except (asyncio.TimeoutError, TimeoutError):
|
except (asyncio.TimeoutError, TimeoutError):
|
||||||
self._logger.error("Connection timed out")
|
self.logger.error("Connection timed out")
|
||||||
await self.disconnect(block=False)
|
await self.disconnect(block=False)
|
||||||
except (ConnectionResetError, BrokenPipeError):
|
except (ConnectionResetError, BrokenPipeError):
|
||||||
self._logger.error("Connection reset while reading packet")
|
self.logger.error("Connection reset while reading packet")
|
||||||
await self.disconnect(block=False)
|
await self.disconnect(block=False)
|
||||||
except (asyncio.IncompleteReadError, EOFError):
|
except (asyncio.IncompleteReadError, EOFError):
|
||||||
self._logger.error("Received EOF while reading packet")
|
self.logger.error("Received EOF while reading packet")
|
||||||
await self.disconnect(block=False)
|
await self.disconnect(block=False)
|
||||||
except Exception:
|
except Exception:
|
||||||
self._logger.exception("Exception parsing packet %d", packet_id)
|
self.logger.exception("Exception parsing packet %d", packet_id)
|
||||||
self._logger.debug("%s", buffer.getvalue())
|
self.logger.debug("%s", buffer.getvalue())
|
||||||
await self.disconnect(block=False)
|
await self.disconnect(block=False)
|
||||||
|
|
||||||
async def _up_worker(self, timeout=1):
|
async def _up_worker(self, timeout=1):
|
||||||
|
@ -288,9 +291,9 @@ class Dispatcher:
|
||||||
self._up.write(data)
|
self._up.write(data)
|
||||||
await self._up.drain()
|
await self._up.drain()
|
||||||
|
|
||||||
self._logger.debug("[-->] Sent | %s", repr(packet))
|
self.logger.debug("[-->] Sent | %s", repr(packet))
|
||||||
except Exception:
|
except Exception:
|
||||||
self._logger.exception("Exception dispatching packet %s", str(packet))
|
self.logger.exception("Exception dispatching packet %s", str(packet))
|
||||||
|
|
||||||
packet.processed.set() # Notify that packet has been processed
|
packet.processed.set() # Notify that packet has been processed
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue