restructured project, improved modularity

This commit is contained in:
əlemi 2021-11-17 16:57:02 +01:00
parent f282ed6665
commit cd402f668a
12 changed files with 175 additions and 207 deletions

View file

@ -4,10 +4,10 @@ import asyncio
import logging import logging
from .mc.proto.play.clientbound import PacketChat from .mc.proto.play.clientbound import PacketChat
from .mc.identity import Token from .mc.token import Token
from .dispatcher import ConnectionState from .dispatcher import ConnectionState
from .client import Client from .client import Client
from .helpers import parse_chat from .util.helpers import parse_chat
async def idle(): async def idle():
while True: while True:

View file

@ -3,35 +3,23 @@ import logging
from asyncio import Task from asyncio import Task
from enum import Enum from enum import Enum
from typing import Dict, List, Callable, Type, Optional, Tuple from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator
from .dispatcher import Dispatcher, ConnectionState from .dispatcher import Dispatcher
from .mc.mctypes import VarInt
from .mc.packet import Packet from .mc.packet import Packet
from .mc.identity import Token, AuthException from .mc.token import Token, AuthException
from .mc.definitions import Dimension, Difficulty, Gamemode from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
from .mc import proto, encryption from .mc.proto.handshaking.serverbound import PacketSetProtocol
from .mc.proto.play.serverbound import PacketKeepAlive as PacketKeepAliveResponse
from .mc.proto.play.clientbound import PacketKeepAlive, PacketSetCompression, PacketKickDisconnect
from .mc.proto.login.serverbound import PacketLoginStart, PacketEncryptionBegin as PacketEncryptionResponse
from .mc.proto.login.clientbound import (
PacketCompress, PacketDisconnect, PacketEncryptionBegin, PacketLoginPluginRequest, PacketSuccess
)
from .util import encryption
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Type[Packet]]]:
if state == ConnectionState.HANDSHAKING:
return proto.handshaking.clientbound.REGISTRY
if state == ConnectionState.STATUS:
return proto.status.clientbound.REGISTRY
if state == ConnectionState.LOGIN:
return proto.login.clientbound.REGISTRY
if state == ConnectionState.PLAY:
return proto.play.clientbound.REGISTRY
return {}
_STATE_REGS = {
ConnectionState.HANDSHAKING : proto.handshaking,
ConnectionState.STATUS : proto.status,
ConnectionState.LOGIN : proto.login,
ConnectionState.PLAY : proto.play,
}
class Client: class Client:
host:str host:str
port:int port:int
@ -46,7 +34,7 @@ class Client:
_authenticated : bool _authenticated : bool
_worker : Task _worker : Task
_packet_callbacks : Dict[ConnectionState, Dict[Packet, List[Callable]]] _packet_callbacks : Dict[Type[Packet], List[Callable]]
_logger : logging.Logger _logger : logging.Logger
def __init__( def __init__(
@ -64,6 +52,8 @@ class Client:
self.options = options or { self.options = options or {
"reconnect" : True, "reconnect" : True,
"rctime" : 5.0, "rctime" : 5.0,
"keep-alive" : True,
"poll-timeout" : 1,
} }
self.token = token self.token = token
@ -142,7 +132,7 @@ class Client:
try: try:
while self._processing: # TODO don't busywait even if it doesn't matter much while self._processing: # TODO don't busywait even if it doesn't matter much
await asyncio.sleep(5) await asyncio.sleep(self.options["poll-timeout"])
except KeyboardInterrupt: except KeyboardInterrupt:
self._logger.info("Received SIGINT, stopping...") self._logger.info("Received SIGINT, stopping...")
else: else:
@ -157,171 +147,114 @@ class Client:
async def stop(self, block=True): async def stop(self, block=True):
self._processing = False self._processing = False
await self.dispatcher.disconnect() if self.dispatcher.connected:
await self.dispatcher.disconnect(block=block)
if block: if block:
await self._worker await self._worker
self._logger.info("Minecraft client stopped") self._logger.info("Minecraft client stopped")
async def _client_worker(self): async def _client_worker(self):
while self._processing: while self._processing:
try: try:
await self.authenticate() await self.authenticate()
except AuthException: except AuthException as e:
self._logger.error("Token not refreshable or credentials invalid") self._logger.error(str(e))
await self.stop(block=False) break
try: try:
await self.dispatcher.connect(self.host, self.port) await self.dispatcher.connect(self.host, self.port)
for packet in self._handshake(): await self._handshake()
await self.dispatcher.write(packet) if await self._login():
self.dispatcher.state = ConnectionState.LOGIN await self._play()
await self._process_packets()
except ConnectionRefusedError: except ConnectionRefusedError:
self._logger.error("Server rejected connection") self._logger.error("Server rejected connection")
except Exception: except Exception:
self._logger.exception("Exception in Client connection") self._logger.exception("Exception in Client connection")
if self.dispatcher.connected:
await self.dispatcher.disconnect()
if not self.options["reconnect"]: if not self.options["reconnect"]:
await self.stop(block=False)
break break
await asyncio.sleep(self.options["rctime"]) await asyncio.sleep(self.options["rctime"])
await self.stop(block=False)
def _handshake(self, force:bool=False) -> Tuple[Packet, Packet]: # TODO make this fancier! poll for version and status first async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first
return ( proto.handshaking.serverbound.PacketSetProtocol( await self.dispatcher.write(
PacketSetProtocol(
340, # TODO!!!! 340, # TODO!!!!
protocolVersion=340, protocolVersion=340,
serverHost=self.host, serverHost=self.host,
serverPort=self.port, serverPort=self.port,
nextState=2, # play nextState=2, # play
), )
proto.login.serverbound.PacketLoginStart( )
await self.dispatcher.write(
PacketLoginStart(
340, 340,
username=self.token.profile.name if self.token else self.username username=self.token.profile.name if self.token else self.username
) )
) )
return True
async def _process_packets(self): async def _login(self) -> bool:
while self.dispatcher.connected: self.dispatcher.state = ConnectionState.LOGIN
try: async for packet in self.dispatcher.packets():
packet = await asyncio.wait_for(self.dispatcher.incoming.get(), timeout=5) if isinstance(packet, PacketEncryptionBegin):
self._logger.debug("[ * ] Processing | %s", str(packet)) secret = encryption.generate_shared_secret()
token, encrypted_secret = encryption.encrypt_token_and_secret(
if self.dispatcher.state == ConnectionState.LOGIN: packet.publicKey,
await self.login_logic(packet) packet.verifyToken,
elif self.dispatcher.state == ConnectionState.PLAY: secret
await self.play_logic(packet)
if self.dispatcher.state in self._packet_callbacks:
if Packet in self._packet_callbacks[self.dispatcher.state]: # callback for any packet
for cb in self._packet_callbacks[self.dispatcher.state][Packet]:
try:
await cb(packet)
except Exception as e:
self._logger.exception("Exception while handling callback")
if packet.__class__ in self._packet_callbacks[self.dispatcher.state]: # callback for this packet
for cb in self._packet_callbacks[self.dispatcher.state][packet.__class__]:
try:
await cb(packet)
except Exception as e:
self._logger.exception("Exception while handling callback")
self.dispatcher.incoming.task_done()
except asyncio.TimeoutError:
pass # need this to recheck self._processing periodically
except AuthException:
self._authenticated = False
self._logger.error("Authentication exception")
await self.dispatcher.disconnect(block=False)
except Exception:
self._logger.exception("Exception while processing packet %s", packet)
# TODO move these in separate module
async def login_logic(self, packet:Packet):
if isinstance(packet, proto.login.clientbound.PacketEncryptionBegin):
secret = encryption.generate_shared_secret()
token, encrypted_secret = encryption.encrypt_token_and_secret(
packet.publicKey,
packet.verifyToken,
secret
)
if packet.serverId != '-' and self.token:
await self.token.join(
encryption.generate_verification_hash(
packet.serverId,
secret,
packet.publicKey
)
) )
if packet.serverId != '-' and self.token:
encryption_response = proto.login.serverbound.PacketEncryptionBegin( try:
340, # TODO!!!! await self.token.join(
sharedSecret=encrypted_secret, encryption.generate_verification_hash(
verifyToken=token packet.serverId,
) secret,
packet.publicKey
await self.dispatcher.write(encryption_response, wait=True) )
)
self.dispatcher.encrypt(secret) except AuthException:
self._logger.error("Could not authenticate with Mojang")
elif isinstance(packet, proto.login.clientbound.PacketDisconnect): break
self._logger.error("Kicked while logging in") encryption_response = PacketEncryptionResponse(
await self.dispatcher.disconnect(block=False) 340, # TODO!!!!
# raise Exception("Disconnected while logging in") # TODO make a more specific one, do some shit sharedSecret=encrypted_secret,
verifyToken=token
elif isinstance(packet, proto.login.clientbound.PacketCompress):
self._logger.info("Compression enabled")
self.dispatcher.compression = packet.threshold
elif isinstance(packet, proto.login.clientbound.PacketSuccess):
self._logger.info("Login success, joining world...")
self.dispatcher.state = ConnectionState.PLAY
elif isinstance(packet, proto.login.clientbound.PacketLoginPluginRequest):
pass # TODO ?
async def play_logic(self, packet:Packet):
if isinstance(packet, proto.play.clientbound.PacketSetCompression):
self._logger.info("Compression updated")
self.dispatcher.compression = packet.threshold
elif isinstance(packet, proto.play.clientbound.PacketKeepAlive):
keep_alive_packet = proto.play.serverbound.packet_keep_alive.PacketKeepAlive(340, keepAliveId=packet.keepAliveId)
await self.dispatcher.write(keep_alive_packet)
elif isinstance(packet, proto.play.clientbound.PacketRespawn):
self._logger.info(
"Reloading world: %s (%s) in %s",
Dimension(packet.dimension).name,
Difficulty(packet.difficulty).name,
Gamemode(packet.gamemode).name
)
elif isinstance(packet, proto.play.clientbound.PacketLogin):
self._logger.info(
"Joined world: %s (%s) in %s",
Dimension(packet.dimension).name,
Difficulty(packet.difficulty).name,
Gamemode(packet.gameMode).name
)
elif isinstance(packet, proto.play.clientbound.PacketPosition):
self._logger.info("Position synchronized")
await self.dispatcher.write(
proto.play.serverbound.PacketTeleportConfirm(
340,
teleportId=packet.teleportId
) )
) await self.dispatcher.write(encryption_response, wait=True)
self.dispatcher.encrypt(secret)
elif isinstance(packet, proto.play.clientbound.PacketUpdateHealth): elif isinstance(packet, PacketCompress):
if packet.health <= 0: self._logger.info("Compression enabled")
self._logger.info("Dead, respawning...") self.dispatcher.compression = packet.threshold
await self.dispatcher.write( elif isinstance(packet, PacketLoginPluginRequest):
proto.play.serverbound.PacketClientCommand(self.dispatcher.proto, actionId=0) # respawn self._logger.info("Ignoring plugin request") # TODO ?
) elif isinstance(packet, PacketSuccess):
self._logger.info("Login success, joining world...")
elif isinstance(packet, proto.play.clientbound.PacketKickDisconnect): return True
self._logger.error("Kicked while in game") elif isinstance(packet, PacketDisconnect):
await self.dispatcher.disconnect(block=False) self._logger.error("Kicked while logging in")
break
return False
async def _play(self) -> bool:
self.dispatcher.state = ConnectionState.PLAY
async for packet in self.dispatcher.packets():
self._logger.debug("[ * ] Processing | %s", str(packet))
if isinstance(packet, PacketSetCompression):
self._logger.info("Compression updated")
self.dispatcher.compression = packet.threshold
elif isinstance(packet, PacketKeepAlive):
if self.options["keep-alive"]:
keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId)
await self.dispatcher.write(keep_alive_packet)
elif isinstance(packet, PacketKickDisconnect):
self._logger.error("Kicked while in game")
break
for packet_type in (Packet, packet.__class__): # check both callbacks for base class and instance class
if packet_type in self._packet_callbacks:
for cb in self._packet_callbacks[packet_type]:
try: # TODO run in executor to not block
await cb(packet)
except Exception as e:
self._logger.exception("Exception while handling callback")
return False

View file

@ -1,26 +1,22 @@
import io import io
import asyncio import asyncio
import contextlib
import zlib import zlib
import logging import logging
from asyncio import StreamReader, StreamWriter, Queue, Task from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum from enum import Enum
from typing import Dict, Optional from typing import Dict, Optional, AsyncIterator
from cryptography.hazmat.primitives.ciphers import CipherContext from cryptography.hazmat.primitives.ciphers import CipherContext
from .mc import proto from .mc import proto
from .mc.mctypes import VarInt from .mc.types import VarInt
from .mc.packet import Packet from .mc.packet import Packet
from .mc import encryption from .mc.definitions import ConnectionState
from .util import encryption
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
class ConnectionState(Enum):
NONE = -1
HANDSHAKING = 0
STATUS = 1
LOGIN = 2
PLAY = 3
class InvalidState(Exception): class InvalidState(Exception):
pass pass
@ -35,6 +31,19 @@ _STATE_REGS = {
ConnectionState.PLAY : proto.play.clientbound.REGISTRY, ConnectionState.PLAY : proto.play.clientbound.REGISTRY,
} }
class PacketFrame:
_packet : Packet
_queue : Queue
def __init__(self, packet:Packet, queue:Queue):
self._packet = packet
self._queue = queue
def __enter__(self):
return self._packet
def __exit__(self):
self.queue.task_done()
class Dispatcher: class Dispatcher:
_down : StreamReader _down : StreamReader
_reader : Optional[Task] _reader : Optional[Task]
@ -46,8 +55,8 @@ class Dispatcher:
_dispatching : bool _dispatching : bool
incoming : Queue _incoming : Queue
outgoing : Queue _outgoing : Queue
_host : str _host : str
_port : int _port : int
@ -64,8 +73,8 @@ class Dispatcher:
self._dispatching = False self._dispatching = False
self.compression = None self.compression = None
self.encryption = False self.encryption = False
self.incoming = Queue() self._incoming = Queue()
self.outgoing = Queue() self._outgoing = Queue()
self._reader = None self._reader = None
self._writer = None self._writer = None
self._host = "localhost" self._host = "localhost"
@ -78,24 +87,37 @@ class Dispatcher:
return self._dispatching return self._dispatching
async def write(self, packet:Packet, wait:bool=False) -> int: async def write(self, packet:Packet, wait:bool=False) -> int:
await self.outgoing.put(packet) await self._outgoing.put(packet)
if wait: if wait:
await packet.sent.wait() await packet.processed.wait()
return self.outgoing.qsize() return self._outgoing.qsize()
async def packets(self, timeout=1) -> AsyncIterator[Packet]:
while self.connected:
try: # TODO replace this timed busy-wait with an event which resolves upon disconnection, and await both
packet = await asyncio.wait_for(self._incoming.get(), timeout=timeout)
try:
yield packet
finally:
self._incoming.task_done()
except asyncio.TimeoutError:
pass # so we recheck self.connected
async def disconnect(self, block:bool=True): async def disconnect(self, block:bool=True):
self._dispatching = False self._dispatching = False
if block and self._writer and self._reader: if block and self._writer and self._reader:
await asyncio.gather(self._writer, self._reader) await asyncio.gather(self._writer, self._reader)
self._logger.debug("Net workers stopped")
if self._up: if self._up:
if self._up.can_write_eof(): if self._up.can_write_eof():
self._up.write_eof() self._up.write_eof()
self._up.close() self._up.close()
if block: if block:
await self._up.wait_closed() await self._up.wait_closed()
self._logger.debug("Socket closed")
self._logger.info("Disconnected") self._logger.info("Disconnected")
async def connect(self, host:Optional[str] = None, port:Optional[int] = None): async def connect(self, host:Optional[str] = None, port:Optional[int] = None, queue_timeout:int = 1, queue_size:int = 100):
if self.connected: if self.connected:
raise InvalidState("Dispatcher already connected") raise InvalidState("Dispatcher already connected")
@ -110,9 +132,9 @@ class Dispatcher:
self.state = ConnectionState.HANDSHAKING self.state = ConnectionState.HANDSHAKING
# self.proto = 340 # TODO # self.proto = 340 # TODO
# Make new queues # Make new queues, do set a max size to sorta propagate back pressure
self.incoming = Queue() self._incoming = Queue(queue_size)
self.outgoing = Queue() self._outgoing = Queue(queue_size)
self._down, self._up = await asyncio.open_connection( self._down, self._up = await asyncio.open_connection(
host=self._host, host=self._host,
@ -121,7 +143,7 @@ class Dispatcher:
self._dispatching = True self._dispatching = True
self._reader = asyncio.get_event_loop().create_task(self._down_worker()) self._reader = asyncio.get_event_loop().create_task(self._down_worker())
self._writer = asyncio.get_event_loop().create_task(self._up_worker()) self._writer = asyncio.get_event_loop().create_task(self._up_worker(timeout=queue_timeout))
self._logger.info("Connected") self._logger.info("Connected")
def encrypt(self, secret:bytes): def encrypt(self, secret:bytes):
@ -149,8 +171,6 @@ class Dispatcher:
async def _down_worker(self): async def _down_worker(self):
while self._dispatching: while self._dispatching:
if self.state != ConnectionState.PLAY:
await self.incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression
try: # these 2 will timeout or raise EOFError if client gets disconnected try: # these 2 will timeout or raise EOFError if client gets disconnected
length = await self._read_varint() length = await self._read_varint()
data = await self._down.readexactly(length) data = await self._down.readexactly(length)
@ -173,7 +193,9 @@ class Dispatcher:
cls = _STATE_REGS[self.state][self.proto][packet_id] cls = _STATE_REGS[self.state][self.proto][packet_id]
packet = cls.deserialize(self.proto, buffer) packet = cls.deserialize(self.proto, buffer)
self._logger.debug("[<--] Received | %s", str(packet)) self._logger.debug("[<--] Received | %s", str(packet))
await self.incoming.put(packet) await self._incoming.put(packet)
if self.state == ConnectionState.LOGIN:
await self._incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression
except AttributeError: except AttributeError:
self._logger.debug("Unimplemented packet %d", packet_id) self._logger.debug("Unimplemented packet %d", packet_id)
except asyncio.IncompleteReadError: except asyncio.IncompleteReadError:
@ -182,10 +204,14 @@ class Dispatcher:
except Exception: except Exception:
self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue()) self._logger.exception("Exception parsing packet %d | %s", packet_id, buffer.getvalue())
async def _up_worker(self): async def _up_worker(self, timeout=1):
while self._dispatching: while self._dispatching:
try: try:
packet = await asyncio.wait_for(self.outgoing.get(), timeout=5) packet = await asyncio.wait_for(self._outgoing.get(), timeout=timeout)
except asyncio.TimeoutError:
continue # check again self._dispatching
try:
buffer = packet.serialize() buffer = packet.serialize()
length = len(buffer.getvalue()) # ewww TODO length = len(buffer.getvalue()) # ewww TODO
@ -209,10 +235,9 @@ class Dispatcher:
self._up.write(data) self._up.write(data)
await self._up.drain() await self._up.drain()
packet.sent.set() # Notify
self._logger.debug("[-->] Sent | %s", str(packet)) self._logger.debug("[-->] Sent | %s", str(packet))
except asyncio.TimeoutError:
pass # need this to recheck self._dispatching periodically
except Exception: except Exception:
self._logger.exception("Exception dispatching packet %s", str(packet)) self._logger.exception("Exception dispatching packet %s", str(packet))
packet.processed.set() # Notify that packet has been processed

View file

@ -1,4 +0,0 @@
"""Minecraft definitions"""
from .packet import Packet
from .mctypes import *
from .proto import *

View file

@ -16,3 +16,10 @@ class Gamemode(Enum):
CREATIVE = 1 CREATIVE = 1
ADVENTURE = 2 ADVENTURE = 2
SPECTATOR = 3 SPECTATOR = 3
class ConnectionState(Enum):
NONE = -1
HANDSHAKING = 0
STATUS = 1
LOGIN = 2
PLAY = 3

View file

@ -3,14 +3,14 @@ import json
from asyncio import Event from asyncio import Event
from typing import Tuple, Dict, Any from typing import Tuple, Dict, Any
from .mctypes import Type, VarInt from .types import Type, VarInt
class Packet: class Packet:
__slots__ = 'id', 'definition', 'sent', '_protocol', '_state' __slots__ = 'id', 'definition', '_processed', '_protocol', '_state'
id : int id : int
definition : Tuple[Tuple[str, Type]] definition : Tuple[Tuple[str, Type]]
sent : Event _processed : Event
_protocol : int _protocol : int
_state : int _state : int
@ -19,12 +19,17 @@ class Packet:
def __init__(self, proto:int, **kwargs): def __init__(self, proto:int, **kwargs):
self._protocol = proto self._protocol = proto
self._processed = Event()
self.definition = self._definitions[proto] self.definition = self._definitions[proto]
self.sent = Event()
self.id = self._ids[proto] self.id = self._ids[proto]
for name, t in self.definition: for name, t in self.definition:
setattr(self, name, t._pytype(kwargs[name]) if name in kwargs else None) setattr(self, name, t._pytype(kwargs[name]) if name in kwargs else None)
@property
def processed(self) -> Event:
"""Returns an event which will be set only after the packet has been processed (either sent or raised exc)"""
return self._processed
@classmethod @classmethod
def deserialize(cls, proto:int, buffer:io.BytesIO): def deserialize(cls, proto:int, buffer:io.BytesIO):
return cls(proto, **{ name : t.read(buffer) for (name, t) in cls._definitions[proto] }) return cls(proto, **{ name : t.read(buffer) for (name, t) in cls._definitions[proto] })

View file

@ -101,8 +101,10 @@ class Token:
async with aiohttp.ClientSession() as sess: async with aiohttp.ClientSession() as sess:
async with sess.post(endpoint, headers=cls.HEADERS, data=json.dumps(data).encode('utf-8')) as res: async with sess.post(endpoint, headers=cls.HEADERS, data=json.dumps(data).encode('utf-8')) as res:
data = await res.json(content_type=None) data = await res.json(content_type=None)
logging.info(f"Auth request | {res.status} | {data}")
if res.status >= 400: if res.status >= 400:
raise AuthException(f"Action '{endpoint.rsplit('/',1)[1]}' did not succeed") err_type = data["error"] if data and "error" in data else "Unknown Error"
err_msg = data["errorMessage"] if data and "errorMessage" in data else "Credentials invalid or token not refreshable anymore"
action = endpoint.rsplit('/',1)[1]
raise AuthException(f"[{action}] {err_type} : {err_msg}")
return data return data

View file

View file

@ -7,13 +7,13 @@ import logging
from pathlib import Path from pathlib import Path
from typing import List, Dict, Union from typing import List, Dict, Union
from aiocraft.mc.mctypes import * from aiocraft.mc.types import *
DIR_MAP = {"toClient": "clientbound", "toServer": "serverbound"} DIR_MAP = {"toClient": "clientbound", "toServer": "serverbound"}
PREFACE = """\"\"\"[!] This file is autogenerated\"\"\"\n\n""" PREFACE = """\"\"\"[!] This file is autogenerated\"\"\"\n\n"""
IMPORTS = """from typing import Tuple, List, Dict IMPORTS = """from typing import Tuple, List, Dict
from ....packet import Packet from ....packet import Packet
from ....mctypes import *\n""" from ....types import *\n"""
IMPORT_ALL = """__all__ = [\n\t{all}\n]\n""" IMPORT_ALL = """__all__ = [\n\t{all}\n]\n"""
REGISTRY_ENTRY = """ REGISTRY_ENTRY = """
REGISTRY = {entries}\n""" REGISTRY = {entries}\n"""