diff --git a/aiocraft/dispatcher.py b/aiocraft/dispatcher.py index 1117d2c..3c6432a 100644 --- a/aiocraft/dispatcher.py +++ b/aiocraft/dispatcher.py @@ -5,7 +5,7 @@ import zlib import logging from asyncio import StreamReader, StreamWriter, Queue, Task from enum import Enum -from typing import Dict, Optional, AsyncIterator +from typing import List, Dict, Optional, AsyncIterator from cryptography.hazmat.primitives.ciphers import CipherContext @@ -31,19 +31,6 @@ _STATE_REGS = { ConnectionState.PLAY : proto.play.clientbound.REGISTRY, } -class PacketFrame: - _packet : Packet - _queue : Queue - - def __init__(self, packet:Packet, queue:Queue): - self._packet = packet - self._queue = queue - - def __enter__(self): - return self._packet - def __exit__(self): - self.queue.task_done() - class Dispatcher: _down : StreamReader _reader : Optional[Task] @@ -69,18 +56,7 @@ class Dispatcher: _logger : logging.Logger def __init__(self): - self.proto = 340 - self._dispatching = False - self.compression = None - self.encryption = False - self._incoming = Queue() - self._outgoing = Queue() - self._reader = None - self._writer = None - self._host = "localhost" - self._port = 25565 - - self._logger = LOGGER.getChild(f"{self._host}:{self._port}") + self._prepare() @property def connected(self) -> bool: @@ -103,6 +79,71 @@ class Dispatcher: except asyncio.TimeoutError: pass # so we recheck self.connected + def encrypt(self, secret:bytes): + cipher = encryption.create_AES_cipher(secret) + self._encryptor = cipher.encryptor() + self._decryptor = cipher.decryptor() + self.encryption = True + self._logger.info("Encryption enabled") + + def _prepare(self, host:Optional[str] = None, port:Optional[int] = None, queue_timeout:int = 1, queue_size:int = 100): + self._host = host or self._host or "localhost" + self._port = port or self._port or 25565 + self._logger = LOGGER.getChild(f"@({self._host}:{self._port})") + + self.encryption = False + self.compression = None + self.state = ConnectionState.HANDSHAKING + self.proto = 340 # TODO + + # Make new queues, do set a max size to sorta propagate back pressure + self._incoming = Queue(queue_size) + self._outgoing = Queue(queue_size) + self._dispatching = False + self._reader = None + self._writer = None + + async def connect(self, + host : Optional[str] = None, + port : Optional[int] = None, + queue_timeout : int = 1, + queue_size : int = 100 + ): + if self.connected: + raise InvalidState("Dispatcher already connected") + + self._prepare(host, port, queue_timeout, queue_size) + + self._down, self._up = await asyncio.open_connection( + host=self._host, + port=self._port, + ) + + self._dispatching = True + self._reader = asyncio.get_event_loop().create_task(self._down_worker()) + self._writer = asyncio.get_event_loop().create_task(self._up_worker(timeout=queue_timeout)) + self._logger.info("Connected") + + @classmethod + def serve(cls, + container : List[Dispatcher], + host : Optional[str] = None, + port : Optional[int] = None, + queue_timeout : int = 1, + queue_size : int = 100 + ): + async def _client_connected(reader:StreamReader, writer:StreamWriter): + dispatcher = cls() + container.append(dispatcher) + dispatcher._prepare(host, port, queue_timeout, queue_size) + + dispatcher._down, dispatcher._up = reader, writer + dispatcher._dispatching = True + dispatcher._reader = asyncio.get_event_loop().create_task(dispatcher._down_worker()) + dispatcher._writer = asyncio.get_event_loop().create_task(dispatcher._up_worker(timeout=queue_timeout)) + dispatcher._logger.info("Serving client") + return _client_connected + async def disconnect(self, block:bool=True): self._dispatching = False if block and self._writer and self._reader: @@ -120,42 +161,6 @@ class Dispatcher: self._logger.debug("Socket closed") self._logger.info("Disconnected") - async def connect(self, host:Optional[str] = None, port:Optional[int] = None, queue_timeout:int = 1, queue_size:int = 100): - if self.connected: - raise InvalidState("Dispatcher already connected") - - if host is not None: - self._host = host - if port is not None: - self._port = port - self._logger = LOGGER.getChild(f"{self._host}:{self._port}") - - self.encryption = False - self.compression = None - self.state = ConnectionState.HANDSHAKING - # self.proto = 340 # TODO - - # Make new queues, do set a max size to sorta propagate back pressure - self._incoming = Queue(queue_size) - self._outgoing = Queue(queue_size) - - self._down, self._up = await asyncio.open_connection( - host=self._host, - port=self._port, - ) - - self._dispatching = True - self._reader = asyncio.get_event_loop().create_task(self._down_worker()) - self._writer = asyncio.get_event_loop().create_task(self._up_worker(timeout=queue_timeout)) - self._logger.info("Connected") - - def encrypt(self, secret:bytes): - cipher = encryption.create_AES_cipher(secret) - self._encryptor = cipher.encryptor() - self._decryptor = cipher.decryptor() - self.encryption = True - self._logger.info("Encryption enabled") - async def _read_varint(self) -> int: numRead = 0 result = 0 diff --git a/aiocraft/mc/token.py b/aiocraft/mc/token.py index beafe20..1cc9363 100644 --- a/aiocraft/mc/token.py +++ b/aiocraft/mc/token.py @@ -11,6 +11,12 @@ import aiohttp class AuthException(Exception): pass +def _raise_from_json(endpoint:str, data:dict): + err_type = data["error"] if data and "error" in data else "Unknown Error" + err_msg = data["errorMessage"] if data and "errorMessage" in data else "Credentials invalid or token not refreshable anymore" + action = endpoint.rsplit('/',1)[1] + raise AuthException(f"[{action}] {err_type} : {err_msg}") + @dataclass class Profile: id : str @@ -96,15 +102,27 @@ class Token: "selectedProfile": self.profile.dict() }) + @classmethod + async def server_join(cls, username:str, serverId:str, ip:Optional[str] = None): + params = {"username":username, "serverId":serverId} + if ip: + params["ip"] = ip + return await cls._get(cls.SESSION_SERVER + "/hasJoined", params) + @classmethod async def _post(cls, endpoint:str, data:dict) -> dict: async with aiohttp.ClientSession() as sess: - async with sess.post(endpoint, headers=cls.HEADERS, data=json.dumps(data).encode('utf-8')) as res: + async with sess.post(endpoint, headers=cls.HEADERS, json=data) as res: data = await res.json(content_type=None) if res.status >= 400: - err_type = data["error"] if data and "error" in data else "Unknown Error" - err_msg = data["errorMessage"] if data and "errorMessage" in data else "Credentials invalid or token not refreshable anymore" - action = endpoint.rsplit('/',1)[1] - raise AuthException(f"[{action}] {err_type} : {err_msg}") + _raise_from_json(endpoint, data) return data + @classmethod + async def _get(cls, endpoint:str, data:dict) -> dict: + async with aiohttp.ClientSession() as sess: + async with sess.get(endpoint, headers=cls.HEADERS, params=data) as res: + data = await res.json(content_type=None) + if res.status >= 400: + _raise_from_json(endpoint, data) + return data