moved features out of aiocraft, cleaned up client states routines

This commit is contained in:
əlemi 2022-03-08 01:39:03 +01:00
parent 34def8b6cc
commit d3587f65ae
No known key found for this signature in database
GPG key ID: BBCBFE5D7244634E
5 changed files with 59 additions and 269 deletions

View file

@ -8,10 +8,9 @@ from asyncio import Task
from enum import Enum from enum import Enum
from time import time from time import time
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator, Any from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator, Any, Set
from .dispatcher import Dispatcher from .dispatcher import Dispatcher
from .traits import CallbacksHolder, Runnable
from .mc.packet import Packet from .mc.packet import Packet
from .mc.auth import AuthInterface, AuthException, MojangToken, MicrosoftAuthenticator from .mc.auth import AuthInterface, AuthException, MojangToken, MicrosoftAuthenticator
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
@ -36,11 +35,7 @@ class ClientOptions:
poll_interval : float = 1.0 poll_interval : float = 1.0
use_packet_whitelist : bool = True use_packet_whitelist : bool = True
class ClientEvent(Enum): class MinecraftClient:
CONNECTED = 0
DISCONNECTED = 1
class MinecraftClient(CallbacksHolder, Runnable):
host:str host:str
port:int port:int
options:ClientOptions options:ClientOptions
@ -53,7 +48,6 @@ class MinecraftClient(CallbacksHolder, Runnable):
_processing : bool _processing : bool
_authenticated : bool _authenticated : bool
_worker : Task _worker : Task
_callbacks = Dict[str, Task]
_logger : logging.Logger _logger : logging.Logger
@ -90,35 +84,13 @@ class MinecraftClient(CallbacksHolder, Runnable):
self._logger = LOGGER.getChild(f"on({self.host}:{self.port})") self._logger = LOGGER.getChild(f"on({self.host}:{self.port})")
@property
def started(self) -> bool:
return self._processing
@property @property
def connected(self) -> bool: def connected(self) -> bool:
return self.started and self.dispatcher.connected return self.dispatcher.connected
async def write(self, packet:Packet, wait:bool=False): async def write(self, packet:Packet, wait:bool=False):
await self.dispatcher.write(packet, wait) await self.dispatcher.write(packet, wait)
def on_connected(self) -> Callable:
def wrapper(fun):
self.register(ClientEvent.CONNECTED, fun)
return fun
return wrapper
def on_disconnected(self) -> Callable:
def wrapper(fun):
self.register(ClientEvent.DISCONNECTED, fun)
return fun
return wrapper
def on_packet(self, packet:Type[Packet]) -> Callable:
def wrapper(fun):
self.register(packet, fun)
return fun
return wrapper
async def authenticate(self): async def authenticate(self):
if self._authenticated: if self._authenticated:
return # Don't spam Auth endpoint! return # Don't spam Auth endpoint!
@ -136,60 +108,49 @@ class MinecraftClient(CallbacksHolder, Runnable):
raise ValueError("No refreshable auth or code to login") raise ValueError("No refreshable auth or code to login")
self._authenticated = True self._authenticated = True
async def change_server(self, server:str): async def info(self, host:str="", port:int=0, proto:int=0, ping:bool=False) -> Dict[str, Any]:
restart = self.started
if restart:
await self.stop()
if ":" in server:
_host, _port = server.split(":", 1)
self.host = _host.strip()
self.port = int(_port)
else:
self.host = server.strip()
self.port = 25565
self._logger = LOGGER.getChild(f"{self.host}:{self.port}")
if restart:
await self.start()
async def start(self):
await super().start()
if self.started:
return
self._processing = True
self._worker = asyncio.get_event_loop().create_task(self._client_worker())
self._logger.info("Minecraft client started")
async def stop(self, force:bool=False):
self._processing = False
if self.dispatcher.connected:
await self.dispatcher.disconnect(block=not force)
if not force:
await self._worker
self._logger.info("Minecraft client stopped")
if not force:
await self.join_callbacks()
await super().stop(force)
async def info(self, host:str="", port:int=0, ping:bool=False) -> Dict[str, Any]:
"""Make a mini connection to asses server status and version""" """Make a mini connection to asses server status and version"""
self.host = host or self.host
self.port = port or self.port
try:
await self.dispatcher.connect(self.host, self.port)
await self._handshake(ConnectionState.STATUS)
return await self._status(ping)
finally:
await self.dispatcher.disconnect()
async def join(self, host:str="", port:int=0, proto:int=0, packet_whitelist:Optional[Set[Type[Packet]]]=None): # jank packet_whitelist argument! TODO
self.host = host or self.host
self.port = port or self.port
if self.online_mode:
await self.authenticate()
try:
await self.dispatcher.connect( await self.dispatcher.connect(
host or self.host, host=self.host,
port or self.port, port=self.port,
proto=proto,
queue_timeout=self.options.poll_interval,
packet_whitelist=packet_whitelist
) )
#Handshake await self._handshake(ConnectionState.LOGIN)
if await self._login():
await self._play()
finally:
await self.dispatcher.disconnect()
async def _handshake(self, state:ConnectionState):
await self.dispatcher.write( await self.dispatcher.write(
PacketSetProtocol( PacketSetProtocol(
self.dispatcher.proto, self.dispatcher.proto,
protocolVersion=self.dispatcher.proto, protocolVersion=self.dispatcher.proto,
serverHost=host or self.host, serverHost=self.host,
serverPort=port or self.port, serverPort=self.port,
nextState=ConnectionState.STATUS.value, nextState=state.value
) )
) )
async def _status(self, ping:bool=False) -> Dict[str, Any]:
self.dispatcher.state = ConnectionState.STATUS self.dispatcher.state = ConnectionState.STATUS
#Request
await self.dispatcher.write( await self.dispatcher.write(
PacketPingStart(self.dispatcher.proto) #empty packet PacketPingStart(self.dispatcher.proto) #empty packet
) )
@ -215,80 +176,16 @@ class MinecraftClient(CallbacksHolder, Runnable):
if packet.time == ping_id: if packet.time == ping_id:
data['ping'] = int(time() - ping_time) data['ping'] = int(time() - ping_time)
break break
await self.dispatcher.disconnect()
return data return data
async def _client_worker(self): async def _login(self) -> bool:
try: self.dispatcher.state = ConnectionState.LOGIN
self._logger.info("Pinging server")
server_data = await self.info()
self._logger.info(
"Connecting to: %s (%d/%d)",
server_data['version']['name'],
server_data['players']['online'],
server_data['players']['max']
)
except Exception:
self._logger.exception("Exception checking server stats")
return
while self._processing:
if self.online_mode:
try:
await self.authenticate()
except AuthException as e:
self._logger.error(str(e))
break
except Exception as e:
self._logger.exception("Unexpected error while authenticating")
break
try:
packet_whitelist = self.callback_keys(filter=Packet) if self.options.use_packet_whitelist else set()
await self.dispatcher.connect(
host=self.host,
port=self.port,
proto=server_data['version']['protocol'],
queue_timeout=self.options.poll_interval,
packet_whitelist=packet_whitelist
)
self.dispatcher.proto = server_data['version']['protocol'] # TODO maybe check if it's supported?
await self._handshake()
if await self._login():
await self._play()
except ConnectionRefusedError:
self._logger.error("Server rejected connection")
except OSError as e:
self._logger.error("Connection error : %s", str(e))
except Exception:
self._logger.exception("Exception in Client connection")
if self.dispatcher.connected:
await self.dispatcher.disconnect()
if not self.options.reconnect:
break
if self._processing: # if client was stopped exit immediately
await asyncio.sleep(self.options.reconnect_delay)
if self._processing:
await self.stop(force=True)
async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first
await self.dispatcher.write(
PacketSetProtocol(
self.dispatcher.proto,
protocolVersion=self.dispatcher.proto,
serverHost=self.host,
serverPort=self.port,
nextState=2, # play
)
)
await self.dispatcher.write( await self.dispatcher.write(
PacketLoginStart( PacketLoginStart(
self.dispatcher.proto, self.dispatcher.proto,
username=self._authenticator.selectedProfile.name if self.online_mode else self._username username=self._authenticator.selectedProfile.name if self.online_mode else self._username
) )
) )
return True
async def _login(self) -> bool:
self.dispatcher.state = ConnectionState.LOGIN
async for packet in self.dispatcher.packets(): async for packet in self.dispatcher.packets():
if isinstance(packet, PacketEncryptionBegin): if isinstance(packet, PacketEncryptionBegin):
if not self.online_mode: if not self.online_mode:
@ -338,9 +235,8 @@ class MinecraftClient(CallbacksHolder, Runnable):
return False return False
return False return False
async def _play(self) -> bool: async def _play(self):
self.dispatcher.state = ConnectionState.PLAY self.dispatcher.state = ConnectionState.PLAY
self.run_callbacks(ClientEvent.CONNECTED)
async for packet in self.dispatcher.packets(): async for packet in self.dispatcher.packets():
self._logger.debug("[ * ] Processing %s", packet.__class__.__name__) self._logger.debug("[ * ] Processing %s", packet.__class__.__name__)
if isinstance(packet, PacketSetCompression): if isinstance(packet, PacketSetCompression):
@ -353,7 +249,3 @@ class MinecraftClient(CallbacksHolder, Runnable):
elif isinstance(packet, PacketKickDisconnect): elif isinstance(packet, PacketKickDisconnect):
self._logger.error("Kicked while in game : %s", helpers.parse_chat(packet.reason)) self._logger.error("Kicked while in game : %s", helpers.parse_chat(packet.reason))
break break
self.run_callbacks(Packet, packet)
self.run_callbacks(type(packet), packet)
self.run_callbacks(ClientEvent.DISCONNECTED)
return False

View file

@ -5,7 +5,7 @@ import zlib
import logging import logging
from asyncio import StreamReader, StreamWriter, Queue, Task from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum from enum import Enum
from typing import List, Dict, Set, Optional, AsyncIterator, Type from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union
from cryptography.hazmat.primitives.ciphers import CipherContext from cryptography.hazmat.primitives.ciphers import CipherContext
@ -39,7 +39,7 @@ class Dispatcher:
_incoming : Queue _incoming : Queue
_outgoing : Queue _outgoing : Queue
_packet_whitelist : Set[Packet] _packet_whitelist : Set[Type[Packet]]
_packet_id_whitelist : Set[int] _packet_id_whitelist : Set[int]
host : str host : str
@ -95,15 +95,15 @@ class Dispatcher:
host:Optional[str] = None, host:Optional[str] = None,
port:Optional[int] = None, port:Optional[int] = None,
proto:Optional[int] = None, proto:Optional[int] = None,
queue_timeout:int = 1, queue_timeout:float = 1,
queue_size:int = 100, queue_size:int = 100,
packet_whitelist : List[Packet] = None packet_whitelist : Set[Type[Packet]] = None
): ):
self.proto = proto or self.proto or 757 # TODO not hardcode this? self.proto = proto or self.proto or 757 # TODO not hardcode this?
self.host = host or self.host or "localhost" self.host = host or self.host or "localhost"
self.port = port or self.port or 25565 self.port = port or self.port or 25565
self._logger = LOGGER.getChild(f"on({self.host}:{self.port})") self._logger = LOGGER.getChild(f"on({self.host}:{self.port})")
self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set() self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set() # just in case make new set
if self._packet_whitelist: if self._packet_whitelist:
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive) self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive)
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect) self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect)
@ -128,9 +128,9 @@ class Dispatcher:
proto : Optional[int] = None, proto : Optional[int] = None,
reader : Optional[StreamReader] = None, reader : Optional[StreamReader] = None,
writer : Optional[StreamWriter] = None, writer : Optional[StreamWriter] = None,
queue_timeout : int = 1, queue_timeout : float = 1,
queue_size : int = 100, queue_size : int = 100,
packet_whitelist : Set[Packet] = None, packet_whitelist : Set[Type[Packet]] = None,
): ):
if self.connected: if self.connected:
raise InvalidState("Dispatcher already connected") raise InvalidState("Dispatcher already connected")
@ -165,33 +165,34 @@ class Dispatcher:
if block: if block:
await self._up.wait_closed() await self._up.wait_closed()
self._logger.debug("Socket closed") self._logger.debug("Socket closed")
if block:
self._logger.info("Disconnected") self._logger.info("Disconnected")
def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]: def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
# TODO de-jank this, language server gets kinda mad # TODO de-jank this, language server gets kinda mad
reg = None # m : Module
if self.state == ConnectionState.HANDSHAKING: if self.state == ConnectionState.HANDSHAKING:
reg = minecraft_protocol.handshaking m = minecraft_protocol.handshaking
elif self.state == ConnectionState.STATUS: elif self.state == ConnectionState.STATUS:
reg = minecraft_protocol.status m = minecraft_protocol.status
elif self.state == ConnectionState.LOGIN: elif self.state == ConnectionState.LOGIN:
reg = minecraft_protocol.login m = minecraft_protocol.login
elif self.state == ConnectionState.PLAY: elif self.state == ConnectionState.PLAY:
reg = minecraft_protocol.play m = minecraft_protocol.play
else: else:
raise InvalidState("Cannot access registries from invalid state") raise InvalidState("Cannot access registries from invalid state")
if self.is_server: if self.is_server:
reg = reg.serverbound.REGISTRY reg = m.serverbound.REGISTRY
else: else:
reg = reg.clientbound.REGISTRY reg = m.clientbound.REGISTRY
if not self.proto: if not self.proto:
raise InvalidState("Cannot access registries from invalid protocol") raise InvalidState("Cannot access registries from invalid protocol")
reg = reg[self.proto] proto_reg = reg[self.proto]
return reg[packet_id] return proto_reg[packet_id]
async def _read_varint(self) -> int: async def _read_varint(self) -> int:
numRead = 0 numRead = 0

View file

@ -1,2 +0,0 @@
from .callbacks import CallbacksHolder
from .runnable import Runnable

View file

@ -1,56 +0,0 @@
import asyncio
import uuid
import logging
from inspect import isclass
from typing import Dict, List, Set, Any, Callable, Type
class CallbacksHolder:
_callbacks : Dict[Any, List[Callable]]
_tasks : Dict[uuid.UUID, asyncio.Event]
_logger : logging.Logger
def __init__(self):
super().__init__()
self._callbacks = {}
self._tasks = {}
def callback_keys(self, filter:Type = None) -> Set[Any]:
return set(x for x in self._callbacks.keys() if not filter or (isclass(x) and issubclass(x, filter)))
def register(self, key:Any, callback:Callable):
if key not in self._callbacks:
self._callbacks[key] = []
self._callbacks[key].append(callback)
return callback
def trigger(self, key:Any) -> List[Callable]:
if key not in self._callbacks:
return []
return self._callbacks[key]
def _wrap(self, cb:Callable, uid:uuid.UUID) -> Callable:
async def wrapper(*args):
try:
ret = await cb(*args)
except Exception:
logging.exception("Exception processing callback")
ret = None
self._tasks[uid].set()
self._tasks.pop(uid)
return ret
return wrapper
def run_callbacks(self, key:Any, *args) -> None:
for cb in self.trigger(key):
task_id = uuid.uuid4()
self._tasks[task_id] = asyncio.Event()
asyncio.get_event_loop().create_task(self._wrap(cb, task_id)(*args))
async def join_callbacks(self):
await asyncio.gather(*list(t.wait() for t in self._tasks.values()))
self._tasks.clear()

View file

@ -1,45 +0,0 @@
import asyncio
import logging
from typing import Optional
from signal import signal, SIGINT, SIGTERM, SIGABRT
class Runnable:
_is_running : bool
_stop_task : Optional[asyncio.Task]
def __init__(self):
self._is_running = False
self._stop_task = None
async def start(self):
self._is_running = True
async def stop(self, force:bool=False):
self._is_running = False
def run(self):
logging.info("Starting process")
def signal_handler(signum, __):
if signum == SIGINT:
if self._stop_task:
self._stop_task.cancel()
logging.info("Received SIGINT, terminating")
else:
logging.info("Received SIGINT, stopping gracefully...")
self._stop_task = asyncio.get_event_loop().create_task(self.stop(force=self._stop_task is not None))
signal(SIGINT, signal_handler)
loop = asyncio.get_event_loop()
async def main():
await self.start()
while self._is_running:
await asyncio.sleep(1)
loop.run_until_complete(main())
logging.info("Process finished")