restructured project, improved modularity

This commit is contained in:
əlemi 2021-11-17 16:57:02 +01:00
parent f282ed6665
commit cd402f668a
12 changed files with 175 additions and 207 deletions

View file

@ -4,10 +4,10 @@ import asyncio
import logging
from .mc.proto.play.clientbound import PacketChat
from .mc.identity import Token
from .mc.token import Token
from .dispatcher import ConnectionState
from .client import Client
from .helpers import parse_chat
from .util.helpers import parse_chat
async def idle():
while True:

View file

@ -3,35 +3,23 @@ import logging
from asyncio import Task
from enum import Enum
from typing import Dict, List, Callable, Type, Optional, Tuple
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator
from .dispatcher import Dispatcher, ConnectionState
from .mc.mctypes import VarInt
from .dispatcher import Dispatcher
from .mc.packet import Packet
from .mc.identity import Token, AuthException
from .mc.definitions import Dimension, Difficulty, Gamemode
from .mc import proto, encryption
from .mc.token import Token, AuthException
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
from .mc.proto.handshaking.serverbound import PacketSetProtocol
from .mc.proto.play.serverbound import PacketKeepAlive as PacketKeepAliveResponse
from .mc.proto.play.clientbound import PacketKeepAlive, PacketSetCompression, PacketKickDisconnect
from .mc.proto.login.serverbound import PacketLoginStart, PacketEncryptionBegin as PacketEncryptionResponse
from .mc.proto.login.clientbound import (
PacketCompress, PacketDisconnect, PacketEncryptionBegin, PacketLoginPluginRequest, PacketSuccess
)
from .util import encryption
LOGGER = logging.getLogger(__name__)
def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Type[Packet]]]:
if state == ConnectionState.HANDSHAKING:
return proto.handshaking.clientbound.REGISTRY
if state == ConnectionState.STATUS:
return proto.status.clientbound.REGISTRY
if state == ConnectionState.LOGIN:
return proto.login.clientbound.REGISTRY
if state == ConnectionState.PLAY:
return proto.play.clientbound.REGISTRY
return {}
_STATE_REGS = {
ConnectionState.HANDSHAKING : proto.handshaking,
ConnectionState.STATUS : proto.status,
ConnectionState.LOGIN : proto.login,
ConnectionState.PLAY : proto.play,
}
class Client:
host:str
port:int
@ -46,7 +34,7 @@ class Client:
_authenticated : bool
_worker : Task
_packet_callbacks : Dict[ConnectionState, Dict[Packet, List[Callable]]]
_packet_callbacks : Dict[Type[Packet], List[Callable]]
_logger : logging.Logger
def __init__(
@ -64,6 +52,8 @@ class Client:
self.options = options or {
"reconnect" : True,
"rctime" : 5.0,
"keep-alive" : True,
"poll-timeout" : 1,
}
self.token = token
@ -142,7 +132,7 @@ class Client:
try:
while self._processing: # TODO don't busywait even if it doesn't matter much
await asyncio.sleep(5)
await asyncio.sleep(self.options["poll-timeout"])
except KeyboardInterrupt:
self._logger.info("Received SIGINT, stopping...")
else:
@ -157,171 +147,114 @@ class Client:
async def stop(self, block=True):
self._processing = False
await self.dispatcher.disconnect()
if self.dispatcher.connected:
await self.dispatcher.disconnect(block=block)
if block:
await self._worker
self._logger.info("Minecraft client stopped")
self._logger.info("Minecraft client stopped")
async def _client_worker(self):
while self._processing:
try:
await self.authenticate()
except AuthException:
self._logger.error("Token not refreshable or credentials invalid")
await self.stop(block=False)
except AuthException as e:
self._logger.error(str(e))
break
try:
await self.dispatcher.connect(self.host, self.port)
for packet in self._handshake():
await self.dispatcher.write(packet)
self.dispatcher.state = ConnectionState.LOGIN
await self._process_packets()
await self._handshake()
if await self._login():
await self._play()
except ConnectionRefusedError:
self._logger.error("Server rejected connection")
except Exception:
self._logger.exception("Exception in Client connection")
if self.dispatcher.connected:
await self.dispatcher.disconnect()
if not self.options["reconnect"]:
await self.stop(block=False)
break
await asyncio.sleep(self.options["rctime"])
await self.stop(block=False)
def _handshake(self, force:bool=False) -> Tuple[Packet, Packet]: # TODO make this fancier! poll for version and status first
return ( proto.handshaking.serverbound.PacketSetProtocol(
async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first
await self.dispatcher.write(
PacketSetProtocol(
340, # TODO!!!!
protocolVersion=340,
serverHost=self.host,
serverPort=self.port,
nextState=2, # play
),
proto.login.serverbound.PacketLoginStart(
)
)
await self.dispatcher.write(
PacketLoginStart(
340,
username=self.token.profile.name if self.token else self.username
)
)
return True
async def _process_packets(self):
while self.dispatcher.connected:
try:
packet = await asyncio.wait_for(self.dispatcher.incoming.get(), timeout=5)
self._logger.debug("[ * ] Processing | %s", str(packet))
if self.dispatcher.state == ConnectionState.LOGIN:
await self.login_logic(packet)
elif self.dispatcher.state == ConnectionState.PLAY:
await self.play_logic(packet)
if self.dispatcher.state in self._packet_callbacks:
if Packet in self._packet_callbacks[self.dispatcher.state]: # callback for any packet
for cb in self._packet_callbacks[self.dispatcher.state][Packet]:
try:
await cb(packet)
except Exception as e:
self._logger.exception("Exception while handling callback")
if packet.__class__ in self._packet_callbacks[self.dispatcher.state]: # callback for this packet
for cb in self._packet_callbacks[self.dispatcher.state][packet.__class__]:
try:
await cb(packet)
except Exception as e:
self._logger.exception("Exception while handling callback")
self.dispatcher.incoming.task_done()
except asyncio.TimeoutError:
pass # need this to recheck self._processing periodically
except AuthException:
self._authenticated = False
self._logger.error("Authentication exception")
await self.dispatcher.disconnect(block=False)
except Exception:
self._logger.exception("Exception while processing packet %s", packet)
# TODO move these in separate module
async def login_logic(self, packet:Packet):
if isinstance(packet, proto.login.clientbound.PacketEncryptionBegin):
secret = encryption.generate_shared_secret()
token, encrypted_secret = encryption.encrypt_token_and_secret(
packet.publicKey,
packet.verifyToken,
secret
)
if packet.serverId != '-' and self.token:
await self.token.join(
encryption.generate_verification_hash(
packet.serverId,
secret,
packet.publicKey
)
async def _login(self) -> bool:
self.dispatcher.state = ConnectionState.LOGIN
async for packet in self.dispatcher.packets():
if isinstance(packet, PacketEncryptionBegin):
secret = encryption.generate_shared_secret()
token, encrypted_secret = encryption.encrypt_token_and_secret(
packet.publicKey,
packet.verifyToken,
secret
)
encryption_response = proto.login.serverbound.PacketEncryptionBegin(
340, # TODO!!!!
sharedSecret=encrypted_secret,
verifyToken=token
)
await self.dispatcher.write(encryption_response, wait=True)
self.dispatcher.encrypt(secret)
elif isinstance(packet, proto.login.clientbound.PacketDisconnect):
self._logger.error("Kicked while logging in")
await self.dispatcher.disconnect(block=False)
# raise Exception("Disconnected while logging in") # TODO make a more specific one, do some shit
elif isinstance(packet, proto.login.clientbound.PacketCompress):
self._logger.info("Compression enabled")
self.dispatcher.compression = packet.threshold
elif isinstance(packet, proto.login.clientbound.PacketSuccess):
self._logger.info("Login success, joining world...")
self.dispatcher.state = ConnectionState.PLAY
elif isinstance(packet, proto.login.clientbound.PacketLoginPluginRequest):
pass # TODO ?
async def play_logic(self, packet:Packet):
if isinstance(packet, proto.play.clientbound.PacketSetCompression):
self._logger.info("Compression updated")
self.dispatcher.compression = packet.threshold
elif isinstance(packet, proto.play.clientbound.PacketKeepAlive):
keep_alive_packet = proto.play.serverbound.packet_keep_alive.PacketKeepAlive(340, keepAliveId=packet.keepAliveId)
await self.dispatcher.write(keep_alive_packet)
elif isinstance(packet, proto.play.clientbound.PacketRespawn):
self._logger.info(
"Reloading world: %s (%s) in %s",
Dimension(packet.dimension).name,
Difficulty(packet.difficulty).name,
Gamemode(packet.gamemode).name
)
elif isinstance(packet, proto.play.clientbound.PacketLogin):
self._logger.info(
"Joined world: %s (%s) in %s",
Dimension(packet.dimension).name,
Difficulty(packet.difficulty).name,
Gamemode(packet.gameMode).name
)
elif isinstance(packet, proto.play.clientbound.PacketPosition):
self._logger.info("Position synchronized")
await self.dispatcher.write(
proto.play.serverbound.PacketTeleportConfirm(
340,
teleportId=packet.teleportId
if packet.serverId != '-' and self.token:
try:
await self.token.join(
encryption.generate_verification_hash(
packet.serverId,
secret,
packet.publicKey
)
)
except AuthException:
self._logger.error("Could not authenticate with Mojang")
break
encryption_response = PacketEncryptionResponse(
340, # TODO!!!!
sharedSecret=encrypted_secret,
verifyToken=token
)
)
elif isinstance(packet, proto.play.clientbound.PacketUpdateHealth):
if packet.health <= 0:
self._logger.info("Dead, respawning...")
await self.dispatcher.write(
proto.play.serverbound.PacketClientCommand(self.dispatcher.proto, actionId=0) # respawn
)
elif isinstance(packet, proto.play.clientbound.PacketKickDisconnect):
self._logger.error("Kicked while in game")
await self.dispatcher.disconnect(block=False)
await self.dispatcher.write(encryption_response, wait=True)
self.dispatcher.encrypt(secret)
elif isinstance(packet, PacketCompress):
self._logger.info("Compression enabled")
self.dispatcher.compression = packet.threshold
elif isinstance(packet, PacketLoginPluginRequest):
self._logger.info("Ignoring plugin request") # TODO ?
elif isinstance(packet, PacketSuccess):
self._logger.info("Login success, joining world...")
return True
elif isinstance(packet, PacketDisconnect):
self._logger.error("Kicked while logging in")
break
return False
async def _play(self) -> bool:
self.dispatcher.state = ConnectionState.PLAY
async for packet in self.dispatcher.packets():
self._logger.debug("[ * ] Processing | %s", str(packet))
if isinstance(packet, PacketSetCompression):
self._logger.info("Compression updated")
self.dispatcher.compression = packet.threshold
elif isinstance(packet, PacketKeepAlive):
if self.options["keep-alive"]:
keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId)
await self.dispatcher.write(keep_alive_packet)
elif isinstance(packet, PacketKickDisconnect):
self._logger.error("Kicked while in game")
break
for packet_type in (Packet, packet.__class__): # check both callbacks for base class and instance class
if packet_type in self._packet_callbacks:
for cb in self._packet_callbacks[packet_type]:
try: # TODO run in executor to not block
await cb(packet)
except Exception as e:
self._logger.exception("Exception while handling callback")
return False

View file

@ -1,26 +1,22 @@
import io
import asyncio
import contextlib
import zlib
import logging
from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum
from typing import Dict, Optional
from typing import Dict, Optional, AsyncIterator
from cryptography.hazmat.primitives.ciphers import CipherContext
from .mc import proto
from .mc.mctypes import VarInt
from .mc.types import VarInt
from .mc.packet import Packet
from .mc import encryption
from .mc.definitions import ConnectionState
from .util import encryption
LOGGER = logging.getLogger(__name__)
class ConnectionState(Enum):
NONE = -1
HANDSHAKING = 0
STATUS = 1
LOGIN = 2
PLAY = 3
class InvalidState(Exception):
pass
@ -35,6 +31,19 @@ _STATE_REGS = {
ConnectionState.PLAY : proto.play.clientbound.REGISTRY,
}
class PacketFrame:
_packet : Packet
_queue : Queue
def __init__(self, packet:Packet, queue:Queue):
self._packet = packet
self._queue = queue
def __enter__(self):
return self._packet
def __exit__(self):
self.queue.task_done()
class Dispatcher:
_down : StreamReader
_reader : Optional[Task]
@ -46,8 +55,8 @@ class Dispatcher:
_dispatching : bool
incoming : Queue
outgoing : Queue
_incoming : Queue
_outgoing : Queue
_host : str
_port : int
@ -64,8 +73,8 @@ class Dispatcher:
self._dispatching = False
self.compression = None
self.encryption = False
self.incoming = Queue()
self.outgoing = Queue()
self._incoming = Queue()
self._outgoing = Queue()
self._reader = None
self._writer = None
self._host = "localhost"
@ -78,24 +87,37 @@ class Dispatcher:
return self._dispatching
async def write(self, packet:Packet, wait:bool=False) -> int:
await self.outgoing.put(packet)
await self._outgoing.put(packet)
if wait:
await packet.sent.wait()
return self.outgoing.qsize()
await packet.processed.wait()
return self._outgoing.qsize()
async def packets(self, timeout=1) -> AsyncIterator[Packet]:
while self.connected:
try: # TODO replace this timed busy-wait with an event which resolves upon disconnection, and await both
packet = await asyncio.wait_for(self._incoming.get(), timeout=timeout)
try:
yield packet
finally:
self._incoming.task_done()
except asyncio.TimeoutError:
pass # so we recheck self.connected
async def disconnect(self, block:bool=True):
self._dispatching = False
if block and self._writer and self._reader:
await asyncio.gather(self._writer, self._reader)
self._logger.debug("Net workers stopped")
if self._up:
if self._up.can_write_eof():
self._up.write_eof()
self._up.close()
if block:
await self._up.wait_closed()
self._logger.debug("Socket closed")
self._logger.info("Disconnected")
async def connect(self, host:Optional[str] = None, port:Optional[int] = None):
async def connect(self, host:Optional[str] = None, port:Optional[int] = None, queue_timeout:int = 1, queue_size:int = 100):
if self.connected:
raise InvalidState("Dispatcher already connected")
@ -110,9 +132,9 @@ class Dispatcher:
self.state = ConnectionState.HANDSHAKING
# self.proto = 340 # TODO
# Make new queues
self.incoming = Queue()
self.outgoing = Queue()
# Make new queues, do set a max size to sorta propagate back pressure
self._incoming = Queue(queue_size)
self._outgoing = Queue(queue_size)
self._down, self._up = await asyncio.open_connection(
host=self._host,
@ -121,7 +143,7 @@ class Dispatcher:
self._dispatching = True
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
self._writer = asyncio.get_event_loop().create_task(self._up_worker())
self._writer = asyncio.get_event_loop().create_task(self._up_worker(timeout=queue_timeout))
self._logger.info("Connected")
def encrypt(self, secret:bytes):
@ -149,8 +171,6 @@ class Dispatcher:
async def _down_worker(self):
while self._dispatching:
if self.state != ConnectionState.PLAY:
await self.incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression
try: # these 2 will timeout or raise EOFError if client gets disconnected
length = await self._read_varint()
data = await self._down.readexactly(length)
@ -173,7 +193,9 @@ class Dispatcher:
cls = _STATE_REGS[self.state][self.proto][packet_id]
packet = cls.deserialize(self.proto, buffer)
self._logger.debug("[<--] Received | %s", str(packet))
await self.incoming.put(packet)
await self._incoming.put(packet)
if self.state == ConnectionState.LOGIN:
await self._incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression
except AttributeError:
self._logger.debug("Unimplemented packet %d", packet_id)
except asyncio.IncompleteReadError:
@ -182,10 +204,14 @@ class Dispatcher:
except Exception:
self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue())
async def _up_worker(self):
async def _up_worker(self, timeout=1):
while self._dispatching:
try:
packet = await asyncio.wait_for(self.outgoing.get(), timeout=5)
packet = await asyncio.wait_for(self._outgoing.get(), timeout=timeout)
except asyncio.TimeoutError:
continue # check again self._dispatching
try:
buffer = packet.serialize()
length = len(buffer.getvalue()) # ewww TODO
@ -209,10 +235,9 @@ class Dispatcher:
self._up.write(data)
await self._up.drain()
packet.sent.set() # Notify
self._logger.debug("[-->] Sent | %s", str(packet))
except asyncio.TimeoutError:
pass # need this to recheck self._dispatching periodically
except Exception:
self._logger.exception("Exception dispatching packet %s", str(packet))
packet.processed.set() # Notify that packet has been processed

View file

@ -1,4 +0,0 @@
"""Minecraft definitions"""
from .packet import Packet
from .mctypes import *
from .proto import *

View file

@ -16,3 +16,10 @@ class Gamemode(Enum):
CREATIVE = 1
ADVENTURE = 2
SPECTATOR = 3
class ConnectionState(Enum):
NONE = -1
HANDSHAKING = 0
STATUS = 1
LOGIN = 2
PLAY = 3

View file

@ -3,14 +3,14 @@ import json
from asyncio import Event
from typing import Tuple, Dict, Any
from .mctypes import Type, VarInt
from .types import Type, VarInt
class Packet:
__slots__ = 'id', 'definition', 'sent', '_protocol', '_state'
__slots__ = 'id', 'definition', '_processed', '_protocol', '_state'
id : int
definition : Tuple[Tuple[str, Type]]
sent : Event
_processed : Event
_protocol : int
_state : int
@ -19,12 +19,17 @@ class Packet:
def __init__(self, proto:int, **kwargs):
self._protocol = proto
self._processed = Event()
self.definition = self._definitions[proto]
self.sent = Event()
self.id = self._ids[proto]
for name, t in self.definition:
setattr(self, name, t._pytype(kwargs[name]) if name in kwargs else None)
@property
def processed(self) -> Event:
"""Returns an event which will be set only after the packet has been processed (either sent or raised exc)"""
return self._processed
@classmethod
def deserialize(cls, proto:int, buffer:io.BytesIO):
return cls(proto, **{ name : t.read(buffer) for (name, t) in cls._definitions[proto] })

View file

@ -101,8 +101,10 @@ class Token:
async with aiohttp.ClientSession() as sess:
async with sess.post(endpoint, headers=cls.HEADERS, data=json.dumps(data).encode('utf-8')) as res:
data = await res.json(content_type=None)
logging.info(f"Auth request | {res.status} | {data}")
if res.status >= 400:
raise AuthException(f"Action '{endpoint.rsplit('/',1)[1]}' did not succeed")
err_type = data["error"] if data and "error" in data else "Unknown Error"
err_msg = data["errorMessage"] if data and "errorMessage" in data else "Credentials invalid or token not refreshable anymore"
action = endpoint.rsplit('/',1)[1]
raise AuthException(f"[{action}] {err_type} : {err_msg}")
return data

View file

View file

@ -7,13 +7,13 @@ import logging
from pathlib import Path
from typing import List, Dict, Union
from aiocraft.mc.mctypes import *
from aiocraft.mc.types import *
DIR_MAP = {"toClient": "clientbound", "toServer": "serverbound"}
PREFACE = """\"\"\"[!] This file is autogenerated\"\"\"\n\n"""
IMPORTS = """from typing import Tuple, List, Dict
from ....packet import Packet
from ....mctypes import *\n"""
from ....types import *\n"""
IMPORT_ALL = """__all__ = [\n\t{all}\n]\n"""
REGISTRY_ENTRY = """
REGISTRY = {entries}\n"""