made things more generic for a mini server impl

This commit is contained in:
əlemi 2021-11-20 15:07:34 +01:00 committed by alemidev
parent f7ec83c8d5
commit a5b34b4d78
2 changed files with 90 additions and 67 deletions

View file

@ -5,7 +5,7 @@ import zlib
import logging import logging
from asyncio import StreamReader, StreamWriter, Queue, Task from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum from enum import Enum
from typing import Dict, Optional, AsyncIterator from typing import List, Dict, Optional, AsyncIterator
from cryptography.hazmat.primitives.ciphers import CipherContext from cryptography.hazmat.primitives.ciphers import CipherContext
@ -31,19 +31,6 @@ _STATE_REGS = {
ConnectionState.PLAY : proto.play.clientbound.REGISTRY, ConnectionState.PLAY : proto.play.clientbound.REGISTRY,
} }
class PacketFrame:
_packet : Packet
_queue : Queue
def __init__(self, packet:Packet, queue:Queue):
self._packet = packet
self._queue = queue
def __enter__(self):
return self._packet
def __exit__(self):
self.queue.task_done()
class Dispatcher: class Dispatcher:
_down : StreamReader _down : StreamReader
_reader : Optional[Task] _reader : Optional[Task]
@ -69,18 +56,7 @@ class Dispatcher:
_logger : logging.Logger _logger : logging.Logger
def __init__(self): def __init__(self):
self.proto = 340 self._prepare()
self._dispatching = False
self.compression = None
self.encryption = False
self._incoming = Queue()
self._outgoing = Queue()
self._reader = None
self._writer = None
self._host = "localhost"
self._port = 25565
self._logger = LOGGER.getChild(f"{self._host}:{self._port}")
@property @property
def connected(self) -> bool: def connected(self) -> bool:
@ -103,6 +79,71 @@ class Dispatcher:
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass # so we recheck self.connected pass # so we recheck self.connected
def encrypt(self, secret:bytes):
cipher = encryption.create_AES_cipher(secret)
self._encryptor = cipher.encryptor()
self._decryptor = cipher.decryptor()
self.encryption = True
self._logger.info("Encryption enabled")
def _prepare(self, host:Optional[str] = None, port:Optional[int] = None, queue_timeout:int = 1, queue_size:int = 100):
self._host = host or self._host or "localhost"
self._port = port or self._port or 25565
self._logger = LOGGER.getChild(f"@({self._host}:{self._port})")
self.encryption = False
self.compression = None
self.state = ConnectionState.HANDSHAKING
self.proto = 340 # TODO
# Make new queues, do set a max size to sorta propagate back pressure
self._incoming = Queue(queue_size)
self._outgoing = Queue(queue_size)
self._dispatching = False
self._reader = None
self._writer = None
async def connect(self,
host : Optional[str] = None,
port : Optional[int] = None,
queue_timeout : int = 1,
queue_size : int = 100
):
if self.connected:
raise InvalidState("Dispatcher already connected")
self._prepare(host, port, queue_timeout, queue_size)
self._down, self._up = await asyncio.open_connection(
host=self._host,
port=self._port,
)
self._dispatching = True
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
self._writer = asyncio.get_event_loop().create_task(self._up_worker(timeout=queue_timeout))
self._logger.info("Connected")
@classmethod
def serve(cls,
container : List[Dispatcher],
host : Optional[str] = None,
port : Optional[int] = None,
queue_timeout : int = 1,
queue_size : int = 100
):
async def _client_connected(reader:StreamReader, writer:StreamWriter):
dispatcher = cls()
container.append(dispatcher)
dispatcher._prepare(host, port, queue_timeout, queue_size)
dispatcher._down, dispatcher._up = reader, writer
dispatcher._dispatching = True
dispatcher._reader = asyncio.get_event_loop().create_task(dispatcher._down_worker())
dispatcher._writer = asyncio.get_event_loop().create_task(dispatcher._up_worker(timeout=queue_timeout))
dispatcher._logger.info("Serving client")
return _client_connected
async def disconnect(self, block:bool=True): async def disconnect(self, block:bool=True):
self._dispatching = False self._dispatching = False
if block and self._writer and self._reader: if block and self._writer and self._reader:
@ -120,42 +161,6 @@ class Dispatcher:
self._logger.debug("Socket closed") self._logger.debug("Socket closed")
self._logger.info("Disconnected") self._logger.info("Disconnected")
async def connect(self, host:Optional[str] = None, port:Optional[int] = None, queue_timeout:int = 1, queue_size:int = 100):
if self.connected:
raise InvalidState("Dispatcher already connected")
if host is not None:
self._host = host
if port is not None:
self._port = port
self._logger = LOGGER.getChild(f"{self._host}:{self._port}")
self.encryption = False
self.compression = None
self.state = ConnectionState.HANDSHAKING
# self.proto = 340 # TODO
# Make new queues, do set a max size to sorta propagate back pressure
self._incoming = Queue(queue_size)
self._outgoing = Queue(queue_size)
self._down, self._up = await asyncio.open_connection(
host=self._host,
port=self._port,
)
self._dispatching = True
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
self._writer = asyncio.get_event_loop().create_task(self._up_worker(timeout=queue_timeout))
self._logger.info("Connected")
def encrypt(self, secret:bytes):
cipher = encryption.create_AES_cipher(secret)
self._encryptor = cipher.encryptor()
self._decryptor = cipher.decryptor()
self.encryption = True
self._logger.info("Encryption enabled")
async def _read_varint(self) -> int: async def _read_varint(self) -> int:
numRead = 0 numRead = 0
result = 0 result = 0

View file

@ -11,6 +11,12 @@ import aiohttp
class AuthException(Exception): class AuthException(Exception):
pass pass
def _raise_from_json(endpoint:str, data:dict):
err_type = data["error"] if data and "error" in data else "Unknown Error"
err_msg = data["errorMessage"] if data and "errorMessage" in data else "Credentials invalid or token not refreshable anymore"
action = endpoint.rsplit('/',1)[1]
raise AuthException(f"[{action}] {err_type} : {err_msg}")
@dataclass @dataclass
class Profile: class Profile:
id : str id : str
@ -96,15 +102,27 @@ class Token:
"selectedProfile": self.profile.dict() "selectedProfile": self.profile.dict()
}) })
@classmethod
async def server_join(cls, username:str, serverId:str, ip:Optional[str] = None):
params = {"username":username, "serverId":serverId}
if ip:
params["ip"] = ip
return await cls._get(cls.SESSION_SERVER + "/hasJoined", params)
@classmethod @classmethod
async def _post(cls, endpoint:str, data:dict) -> dict: async def _post(cls, endpoint:str, data:dict) -> dict:
async with aiohttp.ClientSession() as sess: async with aiohttp.ClientSession() as sess:
async with sess.post(endpoint, headers=cls.HEADERS, data=json.dumps(data).encode('utf-8')) as res: async with sess.post(endpoint, headers=cls.HEADERS, json=data) as res:
data = await res.json(content_type=None) data = await res.json(content_type=None)
if res.status >= 400: if res.status >= 400:
err_type = data["error"] if data and "error" in data else "Unknown Error" _raise_from_json(endpoint, data)
err_msg = data["errorMessage"] if data and "errorMessage" in data else "Credentials invalid or token not refreshable anymore"
action = endpoint.rsplit('/',1)[1]
raise AuthException(f"[{action}] {err_type} : {err_msg}")
return data return data
@classmethod
async def _get(cls, endpoint:str, data:dict) -> dict:
async with aiohttp.ClientSession() as sess:
async with sess.get(endpoint, headers=cls.HEADERS, params=data) as res:
data = await res.json(content_type=None)
if res.status >= 400:
_raise_from_json(endpoint, data)
return data