moved features out of aiocraft, cleaned up client states routines

This commit is contained in:
əlemi 2022-03-08 01:39:03 +01:00
parent 34def8b6cc
commit d3587f65ae
No known key found for this signature in database
GPG key ID: BBCBFE5D7244634E
5 changed files with 59 additions and 269 deletions

View file

@ -8,10 +8,9 @@ from asyncio import Task
from enum import Enum
from time import time
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator, Any
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator, Any, Set
from .dispatcher import Dispatcher
from .traits import CallbacksHolder, Runnable
from .mc.packet import Packet
from .mc.auth import AuthInterface, AuthException, MojangToken, MicrosoftAuthenticator
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
@ -36,11 +35,7 @@ class ClientOptions:
poll_interval : float = 1.0
use_packet_whitelist : bool = True
class ClientEvent(Enum):
CONNECTED = 0
DISCONNECTED = 1
class MinecraftClient(CallbacksHolder, Runnable):
class MinecraftClient:
host:str
port:int
options:ClientOptions
@ -53,7 +48,6 @@ class MinecraftClient(CallbacksHolder, Runnable):
_processing : bool
_authenticated : bool
_worker : Task
_callbacks = Dict[str, Task]
_logger : logging.Logger
@ -90,35 +84,13 @@ class MinecraftClient(CallbacksHolder, Runnable):
self._logger = LOGGER.getChild(f"on({self.host}:{self.port})")
@property
def started(self) -> bool:
return self._processing
@property
def connected(self) -> bool:
return self.started and self.dispatcher.connected
return self.dispatcher.connected
async def write(self, packet:Packet, wait:bool=False):
await self.dispatcher.write(packet, wait)
def on_connected(self) -> Callable:
def wrapper(fun):
self.register(ClientEvent.CONNECTED, fun)
return fun
return wrapper
def on_disconnected(self) -> Callable:
def wrapper(fun):
self.register(ClientEvent.DISCONNECTED, fun)
return fun
return wrapper
def on_packet(self, packet:Type[Packet]) -> Callable:
def wrapper(fun):
self.register(packet, fun)
return fun
return wrapper
async def authenticate(self):
if self._authenticated:
return # Don't spam Auth endpoint!
@ -136,60 +108,49 @@ class MinecraftClient(CallbacksHolder, Runnable):
raise ValueError("No refreshable auth or code to login")
self._authenticated = True
async def change_server(self, server:str):
restart = self.started
if restart:
await self.stop()
if ":" in server:
_host, _port = server.split(":", 1)
self.host = _host.strip()
self.port = int(_port)
else:
self.host = server.strip()
self.port = 25565
self._logger = LOGGER.getChild(f"{self.host}:{self.port}")
if restart:
await self.start()
async def start(self):
await super().start()
if self.started:
return
self._processing = True
self._worker = asyncio.get_event_loop().create_task(self._client_worker())
self._logger.info("Minecraft client started")
async def stop(self, force:bool=False):
self._processing = False
if self.dispatcher.connected:
await self.dispatcher.disconnect(block=not force)
if not force:
await self._worker
self._logger.info("Minecraft client stopped")
if not force:
await self.join_callbacks()
await super().stop(force)
async def info(self, host:str="", port:int=0, ping:bool=False) -> Dict[str, Any]:
async def info(self, host:str="", port:int=0, proto:int=0, ping:bool=False) -> Dict[str, Any]:
"""Make a mini connection to asses server status and version"""
await self.dispatcher.connect(
host or self.host,
port or self.port,
)
#Handshake
self.host = host or self.host
self.port = port or self.port
try:
await self.dispatcher.connect(self.host, self.port)
await self._handshake(ConnectionState.STATUS)
return await self._status(ping)
finally:
await self.dispatcher.disconnect()
async def join(self, host:str="", port:int=0, proto:int=0, packet_whitelist:Optional[Set[Type[Packet]]]=None): # jank packet_whitelist argument! TODO
self.host = host or self.host
self.port = port or self.port
if self.online_mode:
await self.authenticate()
try:
await self.dispatcher.connect(
host=self.host,
port=self.port,
proto=proto,
queue_timeout=self.options.poll_interval,
packet_whitelist=packet_whitelist
)
await self._handshake(ConnectionState.LOGIN)
if await self._login():
await self._play()
finally:
await self.dispatcher.disconnect()
async def _handshake(self, state:ConnectionState):
await self.dispatcher.write(
PacketSetProtocol(
self.dispatcher.proto,
protocolVersion=self.dispatcher.proto,
serverHost=host or self.host,
serverPort=port or self.port,
nextState=ConnectionState.STATUS.value,
serverHost=self.host,
serverPort=self.port,
nextState=state.value
)
)
async def _status(self, ping:bool=False) -> Dict[str, Any]:
self.dispatcher.state = ConnectionState.STATUS
#Request
await self.dispatcher.write(
PacketPingStart(self.dispatcher.proto) #empty packet
)
@ -215,80 +176,16 @@ class MinecraftClient(CallbacksHolder, Runnable):
if packet.time == ping_id:
data['ping'] = int(time() - ping_time)
break
await self.dispatcher.disconnect()
return data
async def _client_worker(self):
try:
self._logger.info("Pinging server")
server_data = await self.info()
self._logger.info(
"Connecting to: %s (%d/%d)",
server_data['version']['name'],
server_data['players']['online'],
server_data['players']['max']
)
except Exception:
self._logger.exception("Exception checking server stats")
return
while self._processing:
if self.online_mode:
try:
await self.authenticate()
except AuthException as e:
self._logger.error(str(e))
break
except Exception as e:
self._logger.exception("Unexpected error while authenticating")
break
try:
packet_whitelist = self.callback_keys(filter=Packet) if self.options.use_packet_whitelist else set()
await self.dispatcher.connect(
host=self.host,
port=self.port,
proto=server_data['version']['protocol'],
queue_timeout=self.options.poll_interval,
packet_whitelist=packet_whitelist
)
self.dispatcher.proto = server_data['version']['protocol'] # TODO maybe check if it's supported?
await self._handshake()
if await self._login():
await self._play()
except ConnectionRefusedError:
self._logger.error("Server rejected connection")
except OSError as e:
self._logger.error("Connection error : %s", str(e))
except Exception:
self._logger.exception("Exception in Client connection")
if self.dispatcher.connected:
await self.dispatcher.disconnect()
if not self.options.reconnect:
break
if self._processing: # if client was stopped exit immediately
await asyncio.sleep(self.options.reconnect_delay)
if self._processing:
await self.stop(force=True)
async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first
await self.dispatcher.write(
PacketSetProtocol(
self.dispatcher.proto,
protocolVersion=self.dispatcher.proto,
serverHost=self.host,
serverPort=self.port,
nextState=2, # play
)
)
async def _login(self) -> bool:
self.dispatcher.state = ConnectionState.LOGIN
await self.dispatcher.write(
PacketLoginStart(
self.dispatcher.proto,
username=self._authenticator.selectedProfile.name if self.online_mode else self._username
)
)
return True
async def _login(self) -> bool:
self.dispatcher.state = ConnectionState.LOGIN
async for packet in self.dispatcher.packets():
if isinstance(packet, PacketEncryptionBegin):
if not self.online_mode:
@ -338,9 +235,8 @@ class MinecraftClient(CallbacksHolder, Runnable):
return False
return False
async def _play(self) -> bool:
async def _play(self):
self.dispatcher.state = ConnectionState.PLAY
self.run_callbacks(ClientEvent.CONNECTED)
async for packet in self.dispatcher.packets():
self._logger.debug("[ * ] Processing %s", packet.__class__.__name__)
if isinstance(packet, PacketSetCompression):
@ -353,7 +249,3 @@ class MinecraftClient(CallbacksHolder, Runnable):
elif isinstance(packet, PacketKickDisconnect):
self._logger.error("Kicked while in game : %s", helpers.parse_chat(packet.reason))
break
self.run_callbacks(Packet, packet)
self.run_callbacks(type(packet), packet)
self.run_callbacks(ClientEvent.DISCONNECTED)
return False

View file

@ -5,7 +5,7 @@ import zlib
import logging
from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum
from typing import List, Dict, Set, Optional, AsyncIterator, Type
from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union
from cryptography.hazmat.primitives.ciphers import CipherContext
@ -39,7 +39,7 @@ class Dispatcher:
_incoming : Queue
_outgoing : Queue
_packet_whitelist : Set[Packet]
_packet_whitelist : Set[Type[Packet]]
_packet_id_whitelist : Set[int]
host : str
@ -95,15 +95,15 @@ class Dispatcher:
host:Optional[str] = None,
port:Optional[int] = None,
proto:Optional[int] = None,
queue_timeout:int = 1,
queue_timeout:float = 1,
queue_size:int = 100,
packet_whitelist : List[Packet] = None
packet_whitelist : Set[Type[Packet]] = None
):
self.proto = proto or self.proto or 757 # TODO not hardcode this?
self.host = host or self.host or "localhost"
self.port = port or self.port or 25565
self._logger = LOGGER.getChild(f"on({self.host}:{self.port})")
self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set()
self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set() # just in case make new set
if self._packet_whitelist:
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive)
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect)
@ -128,9 +128,9 @@ class Dispatcher:
proto : Optional[int] = None,
reader : Optional[StreamReader] = None,
writer : Optional[StreamWriter] = None,
queue_timeout : int = 1,
queue_timeout : float = 1,
queue_size : int = 100,
packet_whitelist : Set[Packet] = None,
packet_whitelist : Set[Type[Packet]] = None,
):
if self.connected:
raise InvalidState("Dispatcher already connected")
@ -165,33 +165,34 @@ class Dispatcher:
if block:
await self._up.wait_closed()
self._logger.debug("Socket closed")
self._logger.info("Disconnected")
if block:
self._logger.info("Disconnected")
def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
# TODO de-jank this, language server gets kinda mad
reg = None
# m : Module
if self.state == ConnectionState.HANDSHAKING:
reg = minecraft_protocol.handshaking
m = minecraft_protocol.handshaking
elif self.state == ConnectionState.STATUS:
reg = minecraft_protocol.status
m = minecraft_protocol.status
elif self.state == ConnectionState.LOGIN:
reg = minecraft_protocol.login
m = minecraft_protocol.login
elif self.state == ConnectionState.PLAY:
reg = minecraft_protocol.play
m = minecraft_protocol.play
else:
raise InvalidState("Cannot access registries from invalid state")
if self.is_server:
reg = reg.serverbound.REGISTRY
reg = m.serverbound.REGISTRY
else:
reg = reg.clientbound.REGISTRY
reg = m.clientbound.REGISTRY
if not self.proto:
raise InvalidState("Cannot access registries from invalid protocol")
reg = reg[self.proto]
proto_reg = reg[self.proto]
return reg[packet_id]
return proto_reg[packet_id]
async def _read_varint(self) -> int:
numRead = 0

View file

@ -1,2 +0,0 @@
from .callbacks import CallbacksHolder
from .runnable import Runnable

View file

@ -1,56 +0,0 @@
import asyncio
import uuid
import logging
from inspect import isclass
from typing import Dict, List, Set, Any, Callable, Type
class CallbacksHolder:
_callbacks : Dict[Any, List[Callable]]
_tasks : Dict[uuid.UUID, asyncio.Event]
_logger : logging.Logger
def __init__(self):
super().__init__()
self._callbacks = {}
self._tasks = {}
def callback_keys(self, filter:Type = None) -> Set[Any]:
return set(x for x in self._callbacks.keys() if not filter or (isclass(x) and issubclass(x, filter)))
def register(self, key:Any, callback:Callable):
if key not in self._callbacks:
self._callbacks[key] = []
self._callbacks[key].append(callback)
return callback
def trigger(self, key:Any) -> List[Callable]:
if key not in self._callbacks:
return []
return self._callbacks[key]
def _wrap(self, cb:Callable, uid:uuid.UUID) -> Callable:
async def wrapper(*args):
try:
ret = await cb(*args)
except Exception:
logging.exception("Exception processing callback")
ret = None
self._tasks[uid].set()
self._tasks.pop(uid)
return ret
return wrapper
def run_callbacks(self, key:Any, *args) -> None:
for cb in self.trigger(key):
task_id = uuid.uuid4()
self._tasks[task_id] = asyncio.Event()
asyncio.get_event_loop().create_task(self._wrap(cb, task_id)(*args))
async def join_callbacks(self):
await asyncio.gather(*list(t.wait() for t in self._tasks.values()))
self._tasks.clear()

View file

@ -1,45 +0,0 @@
import asyncio
import logging
from typing import Optional
from signal import signal, SIGINT, SIGTERM, SIGABRT
class Runnable:
_is_running : bool
_stop_task : Optional[asyncio.Task]
def __init__(self):
self._is_running = False
self._stop_task = None
async def start(self):
self._is_running = True
async def stop(self, force:bool=False):
self._is_running = False
def run(self):
logging.info("Starting process")
def signal_handler(signum, __):
if signum == SIGINT:
if self._stop_task:
self._stop_task.cancel()
logging.info("Received SIGINT, terminating")
else:
logging.info("Received SIGINT, stopping gracefully...")
self._stop_task = asyncio.get_event_loop().create_task(self.stop(force=self._stop_task is not None))
signal(SIGINT, signal_handler)
loop = asyncio.get_event_loop()
async def main():
await self.start()
while self._is_running:
await asyncio.sleep(1)
loop.run_until_complete(main())
logging.info("Process finished")