better logging, catch some disconnect errors

This commit is contained in:
əlemi 2021-11-11 12:25:07 +01:00
parent dadf9ec013
commit d8dc04e663
2 changed files with 83 additions and 48 deletions

View file

@ -11,7 +11,7 @@ from .mc.packet import Packet
from .mc.identity import Token
from .mc import proto, encryption
logger = logging.getLogger(__name__)
LOGGER = logging.getLogger(__name__)
def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Type[Packet]]]:
if state == ConnectionState.HANDSHAKING:
@ -34,6 +34,7 @@ _STATE_REGS = {
class Client:
host:str
port:int
username:Optional[str]
password:Optional[str]
token:Optional[Token]
@ -43,11 +44,12 @@ class Client:
_worker : Task
_packet_callbacks : Dict[ConnectionState, Dict[Packet, List[Callable]]]
_logger : logging.Logger
def __init__(
self,
host:str,
port:int = 25565,
port:int,
username:Optional[str] = None,
password:Optional[str] = None,
token:Optional[Token] = None,
@ -59,11 +61,21 @@ class Client:
self.username = username
self.password = password
self.dispatcher = Dispatcher(host, port)
self.dispatcher = Dispatcher()
self._processing = False
self._packet_callbacks = {}
self._logger = LOGGER.getChild(f"{self.host}:{self.port}")
@property
def started(self) -> bool:
return self._processing
@property
def connected(self) -> bool:
return self.started and self.dispatcher.connected
def on(self, hook):
def wrapper(fun):
pass # TODO
@ -83,7 +95,7 @@ class Client:
if not self.token:
if self.username and self.password:
self.token = await Token.authenticate(self.username, self.password)
logger.info("Authenticated from credentials")
self._logger.info("Authenticated from credentials")
return True
return False
try:
@ -91,11 +103,28 @@ class Client:
except Exception: # idk TODO
try:
await self.token.refresh()
logger.info("Refreshed Token")
self._logger.warning("Refreshed Token")
except Exception:
return False
return True
async def change_server(self, server:str):
restart = self.started
if restart:
await self.stop()
if ":" in server:
_host, _port = server.split(":", 1)
self.host = _host.strip()
self.port = int(_port)
else:
self.host = server.strip()
self.port = 25565
self._logger = LOGGER.getChild(f"{self.host}:{self.port}")
if restart:
await self.start()
async def run(self):
await self.start()
@ -103,34 +132,36 @@ class Client:
while True: # TODO don't busywait even if it doesn't matter much
await asyncio.sleep(5)
except KeyboardInterrupt:
logger.info("Received SIGINT, stopping...")
self._logger.info("Received SIGINT, stopping...")
await self.stop()
async def start(self):
self._processing = True
self._worker = asyncio.get_event_loop().create_task(self._client_worker())
logger.info("Minecraft client started")
self._logger.info("Minecraft client started")
async def stop(self, block=True):
self._processing = False
await self.dispatcher.disconnect()
if block:
await self._worker
logger.info("Minecraft client stopped")
self._logger.info("Minecraft client stopped")
async def _client_worker(self):
while self._processing:
if not await self.authenticate():
raise Exception("Token not refreshable or credentials invalid") # TODO!
try:
await self.dispatcher.connect()
await self.dispatcher.connect(self.host, self.port)
for packet in self._handshake():
await self.dispatcher.write(packet)
self.dispatcher.state = ConnectionState.LOGIN
await self._process_packets()
except ConnectionRefusedError:
self._logger.error("Server rejected connection")
except Exception:
logger.exception("Connection terminated")
self._logger.exception("Exception in Client connection")
await asyncio.sleep(2)
def _handshake(self, force:bool=False) -> Tuple[Packet, Packet]: # TODO make this fancier! poll for version and status first
@ -151,7 +182,7 @@ class Client:
while self.dispatcher.connected:
try:
packet = await asyncio.wait_for(self.dispatcher.incoming.get(), timeout=5)
logger.debug("[ * ] Processing | %s", str(packet))
self._logger.debug("[ * ] Processing | %s", str(packet))
if self.dispatcher.state == ConnectionState.LOGIN:
await self.login_logic(packet)
@ -169,7 +200,7 @@ class Client:
except asyncio.TimeoutError:
pass # need this to recheck self._processing periodically
except Exception:
logger.exception("Exception while processing packet %s", packet)
self._logger.exception("Exception while processing packet %s", packet)
# TODO move these in separate module
@ -203,16 +234,16 @@ class Client:
self.dispatcher.encrypt(secret)
elif isinstance(packet, proto.login.clientbound.PacketDisconnect):
logger.error("Disconnected while logging in")
self._logger.error("Kicked while logging in")
await self.dispatcher.disconnect(block=False)
# raise Exception("Disconnected while logging in") # TODO make a more specific one, do some shit
elif isinstance(packet, proto.login.clientbound.PacketCompress):
logger.info("Compression enabled")
self._logger.info("Compression enabled")
self.dispatcher.compression = packet.threshold
elif isinstance(packet, proto.login.clientbound.PacketSuccess):
logger.info("Login success, joining world...")
self._logger.info("Login success, joining world...")
self.dispatcher.state = ConnectionState.PLAY
elif isinstance(packet, proto.login.clientbound.PacketLoginPluginRequest):
@ -220,7 +251,7 @@ class Client:
async def play_logic(self, packet:Packet):
if isinstance(packet, proto.play.clientbound.PacketSetCompression):
logger.info("Compression updated")
self._logger.info("Compression updated")
self.dispatcher.compression = packet.threshold
elif isinstance(packet, proto.play.clientbound.PacketKeepAlive):
@ -228,7 +259,7 @@ class Client:
await self.dispatcher.write(keep_alive_packet)
elif isinstance(packet, proto.play.clientbound.PacketPosition):
logger.info("Position synchronized")
self._logger.info("Position synchronized")
await self.dispatcher.write(
proto.play.serverbound.PacketTeleportConfirm(
340,
@ -238,12 +269,12 @@ class Client:
elif isinstance(packet, proto.play.clientbound.PacketUpdateHealth):
if packet.health <= 0:
logger.info("Dead, respawning...")
self._logger.info("Dead, respawning...")
await self.dispatcher.write(
proto.play.serverbound.PacketClientCommand(self.dispatcher.proto, actionId=0) # respawn
)
elif isinstance(packet, proto.play.clientbound.PacketKickDisconnect):
logger.error("Disconnected")
self._logger.error("Kicked while in game")
await self.dispatcher.disconnect(block=False)

View file

@ -13,7 +13,7 @@ from .mc.mctypes import VarInt
from .mc.packet import Packet
from .mc import encryption
logger = logging.getLogger(__name__)
LOGGER = logging.getLogger(__name__)
class ConnectionState(Enum):
NONE = -1
@ -37,11 +37,11 @@ _STATE_REGS = {
class Dispatcher:
_down : StreamReader
_reader : Task
_reader : Optional[Task]
_decryptor : CipherContext
_up : StreamWriter
_writer : Task
_writer : Optional[Task]
_encryptor : CipherContext
_dispatching : bool
@ -49,18 +49,17 @@ class Dispatcher:
incoming : Queue
outgoing : Queue
host : str
port : int
_host : str
_port : int
proto : int
state : ConnectionState
encryption : bool
compression : Optional[int]
def __init__(self, host:str, port:int):
self.host = host
self.port = port
_logger : logging.Logger
def __init__(self):
self.proto = 340
self._dispatching = False
self.compression = None
@ -69,6 +68,10 @@ class Dispatcher:
self.outgoing = Queue()
self._reader = None
self._writer = None
self._host = "localhost"
self._port = 25565
self._logger = LOGGER.getChild(f"{self._host}:{self._port}")
@property
def connected(self) -> bool:
@ -89,43 +92,49 @@ class Dispatcher:
self._up.close()
if block:
await self._up.wait_closed()
logger.info("Disconnected")
self._logger.info("Disconnected")
async def connect(self):
async def connect(self, host:Optional[str] = None, port:Optional[int] = None):
if self.connected:
raise InvalidState("Dispatcher already connected")
if host is not None:
self._host = host
if port is not None:
self._port = port
self._logger = LOGGER.getChild(f"{self._host}:{self._port}")
self.encryption = False
self.compression = None
self.state = ConnectionState.HANDSHAKING
# self.proto = 340 # TODO
# Make new queues
self.incoming = Queue()
self.outgoing = Queue()
self._down, self._up = await asyncio.open_connection(
host=self.host,
port=self.port,
host=self._host,
port=self._port,
)
self._dispatching = True
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
self._writer = asyncio.get_event_loop().create_task(self._up_worker())
logger.info("Connected")
self._logger.info("Connected")
def encrypt(self, secret:bytes):
cipher = encryption.create_AES_cipher(secret)
self._encryptor = cipher.encryptor()
self._decryptor = cipher.decryptor()
self.encryption = True
logger.info("Encryption enabled")
self._logger.info("Encryption enabled")
async def _read_varint(self) -> int:
numRead = 0
result = 0
while True:
data = await self._down.readexactly(1)
if len(data) < 1:
raise ConnectionError("Could not read data off socket")
if self.encryption:
data = self._decryptor.update(data)
buf = int.from_bytes(data, 'little')
@ -152,36 +161,31 @@ class Dispatcher:
if self.compression is not None:
decompressed_size = VarInt.read(buffer)
# logger.info("Decompressing packet to %d | %s", decompressed_size, buffer.getvalue())
if decompressed_size > 0:
decompressor = zlib.decompressobj()
decompressed_data = decompressor.decompress(buffer.read())
# logger.info("Obtained %s", decompressed_data)
if len(decompressed_data) != decompressed_size:
raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
buffer = io.BytesIO(decompressed_data)
packet_id = VarInt.read(buffer)
# logger.info("Parsing packet '%d' [%s] | %s", packet_id, str(self.state), buffer.getvalue())
cls = _STATE_REGS[self.state][self.proto][packet_id]
packet = cls.deserialize(self.proto, buffer)
logger.debug("[<--] Received | %s", str(packet))
self._logger.debug("[<--] Received | %s", str(packet))
await self.incoming.put(packet)
except AttributeError:
logger.debug("Received unimplemented packet %d", packet_id)
except (ConnectionError, asyncio.IncompleteReadError):
logger.exception("Connection error")
await self.stop(block=False)
self._logger.debug("Unimplemented packet %d", packet_id)
except asyncio.IncompleteReadError:
self._logger.error("EOF from server")
await self.disconnect(block=False)
except Exception:
logger.exception("Error while processing packet %d | %s", packet_id, buffer.getvalue())
self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue())
async def _up_worker(self):
while self._dispatching:
# logger.info("Up packet")
try:
packet = await asyncio.wait_for(self.outgoing.get(), timeout=5)
buffer = packet.serialize()
# logger.info("Sending packet %s [%s]", str(packet), buffer.getvalue())
length = len(buffer.getvalue()) # ewww TODO
if self.compression is not None:
@ -205,9 +209,9 @@ class Dispatcher:
await self._up.drain()
packet.sent.set() # Notify
logger.debug("[-->] Sent | %s", str(packet))
self._logger.debug("[-->] Sent | %s", str(packet))
except asyncio.TimeoutError:
pass # need this to recheck self._dispatching periodically
except Exception:
logger.exception("Exception dispatching packet %s", str(packet))
self._logger.exception("Exception dispatching packet %s", str(packet))