made things more generic for a mini server impl
This commit is contained in:
parent
f7ec83c8d5
commit
a5b34b4d78
2 changed files with 90 additions and 67 deletions
|
@ -5,7 +5,7 @@ import zlib
|
||||||
import logging
|
import logging
|
||||||
from asyncio import StreamReader, StreamWriter, Queue, Task
|
from asyncio import StreamReader, StreamWriter, Queue, Task
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional, AsyncIterator
|
from typing import List, Dict, Optional, AsyncIterator
|
||||||
|
|
||||||
from cryptography.hazmat.primitives.ciphers import CipherContext
|
from cryptography.hazmat.primitives.ciphers import CipherContext
|
||||||
|
|
||||||
|
@ -31,19 +31,6 @@ _STATE_REGS = {
|
||||||
ConnectionState.PLAY : proto.play.clientbound.REGISTRY,
|
ConnectionState.PLAY : proto.play.clientbound.REGISTRY,
|
||||||
}
|
}
|
||||||
|
|
||||||
class PacketFrame:
|
|
||||||
_packet : Packet
|
|
||||||
_queue : Queue
|
|
||||||
|
|
||||||
def __init__(self, packet:Packet, queue:Queue):
|
|
||||||
self._packet = packet
|
|
||||||
self._queue = queue
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self._packet
|
|
||||||
def __exit__(self):
|
|
||||||
self.queue.task_done()
|
|
||||||
|
|
||||||
class Dispatcher:
|
class Dispatcher:
|
||||||
_down : StreamReader
|
_down : StreamReader
|
||||||
_reader : Optional[Task]
|
_reader : Optional[Task]
|
||||||
|
@ -69,18 +56,7 @@ class Dispatcher:
|
||||||
_logger : logging.Logger
|
_logger : logging.Logger
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.proto = 340
|
self._prepare()
|
||||||
self._dispatching = False
|
|
||||||
self.compression = None
|
|
||||||
self.encryption = False
|
|
||||||
self._incoming = Queue()
|
|
||||||
self._outgoing = Queue()
|
|
||||||
self._reader = None
|
|
||||||
self._writer = None
|
|
||||||
self._host = "localhost"
|
|
||||||
self._port = 25565
|
|
||||||
|
|
||||||
self._logger = LOGGER.getChild(f"{self._host}:{self._port}")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def connected(self) -> bool:
|
def connected(self) -> bool:
|
||||||
|
@ -103,6 +79,71 @@ class Dispatcher:
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
pass # so we recheck self.connected
|
pass # so we recheck self.connected
|
||||||
|
|
||||||
|
def encrypt(self, secret:bytes):
|
||||||
|
cipher = encryption.create_AES_cipher(secret)
|
||||||
|
self._encryptor = cipher.encryptor()
|
||||||
|
self._decryptor = cipher.decryptor()
|
||||||
|
self.encryption = True
|
||||||
|
self._logger.info("Encryption enabled")
|
||||||
|
|
||||||
|
def _prepare(self, host:Optional[str] = None, port:Optional[int] = None, queue_timeout:int = 1, queue_size:int = 100):
|
||||||
|
self._host = host or self._host or "localhost"
|
||||||
|
self._port = port or self._port or 25565
|
||||||
|
self._logger = LOGGER.getChild(f"@({self._host}:{self._port})")
|
||||||
|
|
||||||
|
self.encryption = False
|
||||||
|
self.compression = None
|
||||||
|
self.state = ConnectionState.HANDSHAKING
|
||||||
|
self.proto = 340 # TODO
|
||||||
|
|
||||||
|
# Make new queues, do set a max size to sorta propagate back pressure
|
||||||
|
self._incoming = Queue(queue_size)
|
||||||
|
self._outgoing = Queue(queue_size)
|
||||||
|
self._dispatching = False
|
||||||
|
self._reader = None
|
||||||
|
self._writer = None
|
||||||
|
|
||||||
|
async def connect(self,
|
||||||
|
host : Optional[str] = None,
|
||||||
|
port : Optional[int] = None,
|
||||||
|
queue_timeout : int = 1,
|
||||||
|
queue_size : int = 100
|
||||||
|
):
|
||||||
|
if self.connected:
|
||||||
|
raise InvalidState("Dispatcher already connected")
|
||||||
|
|
||||||
|
self._prepare(host, port, queue_timeout, queue_size)
|
||||||
|
|
||||||
|
self._down, self._up = await asyncio.open_connection(
|
||||||
|
host=self._host,
|
||||||
|
port=self._port,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._dispatching = True
|
||||||
|
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
|
||||||
|
self._writer = asyncio.get_event_loop().create_task(self._up_worker(timeout=queue_timeout))
|
||||||
|
self._logger.info("Connected")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def serve(cls,
|
||||||
|
container : List[Dispatcher],
|
||||||
|
host : Optional[str] = None,
|
||||||
|
port : Optional[int] = None,
|
||||||
|
queue_timeout : int = 1,
|
||||||
|
queue_size : int = 100
|
||||||
|
):
|
||||||
|
async def _client_connected(reader:StreamReader, writer:StreamWriter):
|
||||||
|
dispatcher = cls()
|
||||||
|
container.append(dispatcher)
|
||||||
|
dispatcher._prepare(host, port, queue_timeout, queue_size)
|
||||||
|
|
||||||
|
dispatcher._down, dispatcher._up = reader, writer
|
||||||
|
dispatcher._dispatching = True
|
||||||
|
dispatcher._reader = asyncio.get_event_loop().create_task(dispatcher._down_worker())
|
||||||
|
dispatcher._writer = asyncio.get_event_loop().create_task(dispatcher._up_worker(timeout=queue_timeout))
|
||||||
|
dispatcher._logger.info("Serving client")
|
||||||
|
return _client_connected
|
||||||
|
|
||||||
async def disconnect(self, block:bool=True):
|
async def disconnect(self, block:bool=True):
|
||||||
self._dispatching = False
|
self._dispatching = False
|
||||||
if block and self._writer and self._reader:
|
if block and self._writer and self._reader:
|
||||||
|
@ -120,42 +161,6 @@ class Dispatcher:
|
||||||
self._logger.debug("Socket closed")
|
self._logger.debug("Socket closed")
|
||||||
self._logger.info("Disconnected")
|
self._logger.info("Disconnected")
|
||||||
|
|
||||||
async def connect(self, host:Optional[str] = None, port:Optional[int] = None, queue_timeout:int = 1, queue_size:int = 100):
|
|
||||||
if self.connected:
|
|
||||||
raise InvalidState("Dispatcher already connected")
|
|
||||||
|
|
||||||
if host is not None:
|
|
||||||
self._host = host
|
|
||||||
if port is not None:
|
|
||||||
self._port = port
|
|
||||||
self._logger = LOGGER.getChild(f"{self._host}:{self._port}")
|
|
||||||
|
|
||||||
self.encryption = False
|
|
||||||
self.compression = None
|
|
||||||
self.state = ConnectionState.HANDSHAKING
|
|
||||||
# self.proto = 340 # TODO
|
|
||||||
|
|
||||||
# Make new queues, do set a max size to sorta propagate back pressure
|
|
||||||
self._incoming = Queue(queue_size)
|
|
||||||
self._outgoing = Queue(queue_size)
|
|
||||||
|
|
||||||
self._down, self._up = await asyncio.open_connection(
|
|
||||||
host=self._host,
|
|
||||||
port=self._port,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._dispatching = True
|
|
||||||
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
|
|
||||||
self._writer = asyncio.get_event_loop().create_task(self._up_worker(timeout=queue_timeout))
|
|
||||||
self._logger.info("Connected")
|
|
||||||
|
|
||||||
def encrypt(self, secret:bytes):
|
|
||||||
cipher = encryption.create_AES_cipher(secret)
|
|
||||||
self._encryptor = cipher.encryptor()
|
|
||||||
self._decryptor = cipher.decryptor()
|
|
||||||
self.encryption = True
|
|
||||||
self._logger.info("Encryption enabled")
|
|
||||||
|
|
||||||
async def _read_varint(self) -> int:
|
async def _read_varint(self) -> int:
|
||||||
numRead = 0
|
numRead = 0
|
||||||
result = 0
|
result = 0
|
||||||
|
|
|
@ -11,6 +11,12 @@ import aiohttp
|
||||||
class AuthException(Exception):
|
class AuthException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _raise_from_json(endpoint:str, data:dict):
|
||||||
|
err_type = data["error"] if data and "error" in data else "Unknown Error"
|
||||||
|
err_msg = data["errorMessage"] if data and "errorMessage" in data else "Credentials invalid or token not refreshable anymore"
|
||||||
|
action = endpoint.rsplit('/',1)[1]
|
||||||
|
raise AuthException(f"[{action}] {err_type} : {err_msg}")
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Profile:
|
class Profile:
|
||||||
id : str
|
id : str
|
||||||
|
@ -96,15 +102,27 @@ class Token:
|
||||||
"selectedProfile": self.profile.dict()
|
"selectedProfile": self.profile.dict()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def server_join(cls, username:str, serverId:str, ip:Optional[str] = None):
|
||||||
|
params = {"username":username, "serverId":serverId}
|
||||||
|
if ip:
|
||||||
|
params["ip"] = ip
|
||||||
|
return await cls._get(cls.SESSION_SERVER + "/hasJoined", params)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _post(cls, endpoint:str, data:dict) -> dict:
|
async def _post(cls, endpoint:str, data:dict) -> dict:
|
||||||
async with aiohttp.ClientSession() as sess:
|
async with aiohttp.ClientSession() as sess:
|
||||||
async with sess.post(endpoint, headers=cls.HEADERS, data=json.dumps(data).encode('utf-8')) as res:
|
async with sess.post(endpoint, headers=cls.HEADERS, json=data) as res:
|
||||||
data = await res.json(content_type=None)
|
data = await res.json(content_type=None)
|
||||||
if res.status >= 400:
|
if res.status >= 400:
|
||||||
err_type = data["error"] if data and "error" in data else "Unknown Error"
|
_raise_from_json(endpoint, data)
|
||||||
err_msg = data["errorMessage"] if data and "errorMessage" in data else "Credentials invalid or token not refreshable anymore"
|
|
||||||
action = endpoint.rsplit('/',1)[1]
|
|
||||||
raise AuthException(f"[{action}] {err_type} : {err_msg}")
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def _get(cls, endpoint:str, data:dict) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as sess:
|
||||||
|
async with sess.get(endpoint, headers=cls.HEADERS, params=data) as res:
|
||||||
|
data = await res.json(content_type=None)
|
||||||
|
if res.status >= 400:
|
||||||
|
_raise_from_json(endpoint, data)
|
||||||
|
return data
|
||||||
|
|
Loading…
Reference in a new issue