better logging, catch some disconnect errors
This commit is contained in:
parent
dadf9ec013
commit
d8dc04e663
2 changed files with 83 additions and 48 deletions
|
@ -11,7 +11,7 @@ from .mc.packet import Packet
|
|||
from .mc.identity import Token
|
||||
from .mc import proto, encryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Type[Packet]]]:
|
||||
if state == ConnectionState.HANDSHAKING:
|
||||
|
@ -34,6 +34,7 @@ _STATE_REGS = {
|
|||
class Client:
|
||||
host:str
|
||||
port:int
|
||||
|
||||
username:Optional[str]
|
||||
password:Optional[str]
|
||||
token:Optional[Token]
|
||||
|
@ -43,11 +44,12 @@ class Client:
|
|||
_worker : Task
|
||||
|
||||
_packet_callbacks : Dict[ConnectionState, Dict[Packet, List[Callable]]]
|
||||
_logger : logging.Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host:str,
|
||||
port:int = 25565,
|
||||
port:int,
|
||||
username:Optional[str] = None,
|
||||
password:Optional[str] = None,
|
||||
token:Optional[Token] = None,
|
||||
|
@ -59,11 +61,21 @@ class Client:
|
|||
self.username = username
|
||||
self.password = password
|
||||
|
||||
self.dispatcher = Dispatcher(host, port)
|
||||
self.dispatcher = Dispatcher()
|
||||
self._processing = False
|
||||
|
||||
self._packet_callbacks = {}
|
||||
|
||||
self._logger = LOGGER.getChild(f"{self.host}:{self.port}")
|
||||
|
||||
@property
|
||||
def started(self) -> bool:
|
||||
return self._processing
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
return self.started and self.dispatcher.connected
|
||||
|
||||
def on(self, hook):
|
||||
def wrapper(fun):
|
||||
pass # TODO
|
||||
|
@ -83,7 +95,7 @@ class Client:
|
|||
if not self.token:
|
||||
if self.username and self.password:
|
||||
self.token = await Token.authenticate(self.username, self.password)
|
||||
logger.info("Authenticated from credentials")
|
||||
self._logger.info("Authenticated from credentials")
|
||||
return True
|
||||
return False
|
||||
try:
|
||||
|
@ -91,11 +103,28 @@ class Client:
|
|||
except Exception: # idk TODO
|
||||
try:
|
||||
await self.token.refresh()
|
||||
logger.info("Refreshed Token")
|
||||
self._logger.warning("Refreshed Token")
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def change_server(self, server:str):
|
||||
restart = self.started
|
||||
if restart:
|
||||
await self.stop()
|
||||
|
||||
if ":" in server:
|
||||
_host, _port = server.split(":", 1)
|
||||
self.host = _host.strip()
|
||||
self.port = int(_port)
|
||||
else:
|
||||
self.host = server.strip()
|
||||
self.port = 25565
|
||||
self._logger = LOGGER.getChild(f"{self.host}:{self.port}")
|
||||
|
||||
if restart:
|
||||
await self.start()
|
||||
|
||||
async def run(self):
|
||||
await self.start()
|
||||
|
||||
|
@ -103,34 +132,36 @@ class Client:
|
|||
while True: # TODO don't busywait even if it doesn't matter much
|
||||
await asyncio.sleep(5)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received SIGINT, stopping...")
|
||||
self._logger.info("Received SIGINT, stopping...")
|
||||
|
||||
await self.stop()
|
||||
|
||||
async def start(self):
|
||||
self._processing = True
|
||||
self._worker = asyncio.get_event_loop().create_task(self._client_worker())
|
||||
logger.info("Minecraft client started")
|
||||
self._logger.info("Minecraft client started")
|
||||
|
||||
async def stop(self, block=True):
|
||||
self._processing = False
|
||||
await self.dispatcher.disconnect()
|
||||
if block:
|
||||
await self._worker
|
||||
logger.info("Minecraft client stopped")
|
||||
self._logger.info("Minecraft client stopped")
|
||||
|
||||
async def _client_worker(self):
|
||||
while self._processing:
|
||||
if not await self.authenticate():
|
||||
raise Exception("Token not refreshable or credentials invalid") # TODO!
|
||||
try:
|
||||
await self.dispatcher.connect()
|
||||
await self.dispatcher.connect(self.host, self.port)
|
||||
for packet in self._handshake():
|
||||
await self.dispatcher.write(packet)
|
||||
self.dispatcher.state = ConnectionState.LOGIN
|
||||
await self._process_packets()
|
||||
except ConnectionRefusedError:
|
||||
self._logger.error("Server rejected connection")
|
||||
except Exception:
|
||||
logger.exception("Connection terminated")
|
||||
self._logger.exception("Exception in Client connection")
|
||||
await asyncio.sleep(2)
|
||||
|
||||
def _handshake(self, force:bool=False) -> Tuple[Packet, Packet]: # TODO make this fancier! poll for version and status first
|
||||
|
@ -151,7 +182,7 @@ class Client:
|
|||
while self.dispatcher.connected:
|
||||
try:
|
||||
packet = await asyncio.wait_for(self.dispatcher.incoming.get(), timeout=5)
|
||||
logger.debug("[ * ] Processing | %s", str(packet))
|
||||
self._logger.debug("[ * ] Processing | %s", str(packet))
|
||||
|
||||
if self.dispatcher.state == ConnectionState.LOGIN:
|
||||
await self.login_logic(packet)
|
||||
|
@ -169,7 +200,7 @@ class Client:
|
|||
except asyncio.TimeoutError:
|
||||
pass # need this to recheck self._processing periodically
|
||||
except Exception:
|
||||
logger.exception("Exception while processing packet %s", packet)
|
||||
self._logger.exception("Exception while processing packet %s", packet)
|
||||
|
||||
# TODO move these in separate module
|
||||
|
||||
|
@ -203,16 +234,16 @@ class Client:
|
|||
self.dispatcher.encrypt(secret)
|
||||
|
||||
elif isinstance(packet, proto.login.clientbound.PacketDisconnect):
|
||||
logger.error("Disconnected while logging in")
|
||||
self._logger.error("Kicked while logging in")
|
||||
await self.dispatcher.disconnect(block=False)
|
||||
# raise Exception("Disconnected while logging in") # TODO make a more specific one, do some shit
|
||||
|
||||
elif isinstance(packet, proto.login.clientbound.PacketCompress):
|
||||
logger.info("Compression enabled")
|
||||
self._logger.info("Compression enabled")
|
||||
self.dispatcher.compression = packet.threshold
|
||||
|
||||
elif isinstance(packet, proto.login.clientbound.PacketSuccess):
|
||||
logger.info("Login success, joining world...")
|
||||
self._logger.info("Login success, joining world...")
|
||||
self.dispatcher.state = ConnectionState.PLAY
|
||||
|
||||
elif isinstance(packet, proto.login.clientbound.PacketLoginPluginRequest):
|
||||
|
@ -220,7 +251,7 @@ class Client:
|
|||
|
||||
async def play_logic(self, packet:Packet):
|
||||
if isinstance(packet, proto.play.clientbound.PacketSetCompression):
|
||||
logger.info("Compression updated")
|
||||
self._logger.info("Compression updated")
|
||||
self.dispatcher.compression = packet.threshold
|
||||
|
||||
elif isinstance(packet, proto.play.clientbound.PacketKeepAlive):
|
||||
|
@ -228,7 +259,7 @@ class Client:
|
|||
await self.dispatcher.write(keep_alive_packet)
|
||||
|
||||
elif isinstance(packet, proto.play.clientbound.PacketPosition):
|
||||
logger.info("Position synchronized")
|
||||
self._logger.info("Position synchronized")
|
||||
await self.dispatcher.write(
|
||||
proto.play.serverbound.PacketTeleportConfirm(
|
||||
340,
|
||||
|
@ -238,12 +269,12 @@ class Client:
|
|||
|
||||
elif isinstance(packet, proto.play.clientbound.PacketUpdateHealth):
|
||||
if packet.health <= 0:
|
||||
logger.info("Dead, respawning...")
|
||||
self._logger.info("Dead, respawning...")
|
||||
await self.dispatcher.write(
|
||||
proto.play.serverbound.PacketClientCommand(self.dispatcher.proto, actionId=0) # respawn
|
||||
)
|
||||
|
||||
elif isinstance(packet, proto.play.clientbound.PacketKickDisconnect):
|
||||
logger.error("Disconnected")
|
||||
self._logger.error("Kicked while in game")
|
||||
await self.dispatcher.disconnect(block=False)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from .mc.mctypes import VarInt
|
|||
from .mc.packet import Packet
|
||||
from .mc import encryption
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class ConnectionState(Enum):
|
||||
NONE = -1
|
||||
|
@ -37,11 +37,11 @@ _STATE_REGS = {
|
|||
|
||||
class Dispatcher:
|
||||
_down : StreamReader
|
||||
_reader : Task
|
||||
_reader : Optional[Task]
|
||||
_decryptor : CipherContext
|
||||
|
||||
_up : StreamWriter
|
||||
_writer : Task
|
||||
_writer : Optional[Task]
|
||||
_encryptor : CipherContext
|
||||
|
||||
_dispatching : bool
|
||||
|
@ -49,18 +49,17 @@ class Dispatcher:
|
|||
incoming : Queue
|
||||
outgoing : Queue
|
||||
|
||||
host : str
|
||||
port : int
|
||||
_host : str
|
||||
_port : int
|
||||
|
||||
proto : int
|
||||
state : ConnectionState
|
||||
encryption : bool
|
||||
compression : Optional[int]
|
||||
|
||||
def __init__(self, host:str, port:int):
|
||||
self.host = host
|
||||
self.port = port
|
||||
_logger : logging.Logger
|
||||
|
||||
def __init__(self):
|
||||
self.proto = 340
|
||||
self._dispatching = False
|
||||
self.compression = None
|
||||
|
@ -69,6 +68,10 @@ class Dispatcher:
|
|||
self.outgoing = Queue()
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
self._host = "localhost"
|
||||
self._port = 25565
|
||||
|
||||
self._logger = LOGGER.getChild(f"{self._host}:{self._port}")
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
|
@ -89,43 +92,49 @@ class Dispatcher:
|
|||
self._up.close()
|
||||
if block:
|
||||
await self._up.wait_closed()
|
||||
logger.info("Disconnected")
|
||||
self._logger.info("Disconnected")
|
||||
|
||||
async def connect(self):
|
||||
async def connect(self, host:Optional[str] = None, port:Optional[int] = None):
|
||||
if self.connected:
|
||||
raise InvalidState("Dispatcher already connected")
|
||||
|
||||
if host is not None:
|
||||
self._host = host
|
||||
if port is not None:
|
||||
self._port = port
|
||||
self._logger = LOGGER.getChild(f"{self._host}:{self._port}")
|
||||
|
||||
self.encryption = False
|
||||
self.compression = None
|
||||
self.state = ConnectionState.HANDSHAKING
|
||||
# self.proto = 340 # TODO
|
||||
|
||||
# Make new queues
|
||||
self.incoming = Queue()
|
||||
self.outgoing = Queue()
|
||||
|
||||
self._down, self._up = await asyncio.open_connection(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
host=self._host,
|
||||
port=self._port,
|
||||
)
|
||||
|
||||
self._dispatching = True
|
||||
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
|
||||
self._writer = asyncio.get_event_loop().create_task(self._up_worker())
|
||||
logger.info("Connected")
|
||||
self._logger.info("Connected")
|
||||
|
||||
def encrypt(self, secret:bytes):
|
||||
cipher = encryption.create_AES_cipher(secret)
|
||||
self._encryptor = cipher.encryptor()
|
||||
self._decryptor = cipher.decryptor()
|
||||
self.encryption = True
|
||||
logger.info("Encryption enabled")
|
||||
self._logger.info("Encryption enabled")
|
||||
|
||||
async def _read_varint(self) -> int:
|
||||
numRead = 0
|
||||
result = 0
|
||||
while True:
|
||||
data = await self._down.readexactly(1)
|
||||
if len(data) < 1:
|
||||
raise ConnectionError("Could not read data off socket")
|
||||
if self.encryption:
|
||||
data = self._decryptor.update(data)
|
||||
buf = int.from_bytes(data, 'little')
|
||||
|
@ -152,36 +161,31 @@ class Dispatcher:
|
|||
|
||||
if self.compression is not None:
|
||||
decompressed_size = VarInt.read(buffer)
|
||||
# logger.info("Decompressing packet to %d | %s", decompressed_size, buffer.getvalue())
|
||||
if decompressed_size > 0:
|
||||
decompressor = zlib.decompressobj()
|
||||
decompressed_data = decompressor.decompress(buffer.read())
|
||||
# logger.info("Obtained %s", decompressed_data)
|
||||
if len(decompressed_data) != decompressed_size:
|
||||
raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
|
||||
buffer = io.BytesIO(decompressed_data)
|
||||
|
||||
packet_id = VarInt.read(buffer)
|
||||
# logger.info("Parsing packet '%d' [%s] | %s", packet_id, str(self.state), buffer.getvalue())
|
||||
cls = _STATE_REGS[self.state][self.proto][packet_id]
|
||||
packet = cls.deserialize(self.proto, buffer)
|
||||
logger.debug("[<--] Received | %s", str(packet))
|
||||
self._logger.debug("[<--] Received | %s", str(packet))
|
||||
await self.incoming.put(packet)
|
||||
except AttributeError:
|
||||
logger.debug("Received unimplemented packet %d", packet_id)
|
||||
except (ConnectionError, asyncio.IncompleteReadError):
|
||||
logger.exception("Connection error")
|
||||
await self.stop(block=False)
|
||||
self._logger.debug("Unimplemented packet %d", packet_id)
|
||||
except asyncio.IncompleteReadError:
|
||||
self._logger.error("EOF from server")
|
||||
await self.disconnect(block=False)
|
||||
except Exception:
|
||||
logger.exception("Error while processing packet %d | %s", packet_id, buffer.getvalue())
|
||||
self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue())
|
||||
|
||||
async def _up_worker(self):
|
||||
while self._dispatching:
|
||||
# logger.info("Up packet")
|
||||
try:
|
||||
packet = await asyncio.wait_for(self.outgoing.get(), timeout=5)
|
||||
buffer = packet.serialize()
|
||||
# logger.info("Sending packet %s [%s]", str(packet), buffer.getvalue())
|
||||
length = len(buffer.getvalue()) # ewww TODO
|
||||
|
||||
if self.compression is not None:
|
||||
|
@ -205,9 +209,9 @@ class Dispatcher:
|
|||
await self._up.drain()
|
||||
|
||||
packet.sent.set() # Notify
|
||||
logger.debug("[-->] Sent | %s", str(packet))
|
||||
self._logger.debug("[-->] Sent | %s", str(packet))
|
||||
except asyncio.TimeoutError:
|
||||
pass # need this to recheck self._dispatching periodically
|
||||
except Exception:
|
||||
logger.exception("Exception dispatching packet %s", str(packet))
|
||||
self._logger.exception("Exception dispatching packet %s", str(packet))
|
||||
|
||||
|
|
Loading…
Reference in a new issue