moved features out of aiocraft, cleaned up client states routines
This commit is contained in:
parent
34def8b6cc
commit
d3587f65ae
5 changed files with 59 additions and 269 deletions
|
@ -8,10 +8,9 @@ from asyncio import Task
|
|||
from enum import Enum
|
||||
from time import time
|
||||
|
||||
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator, Any
|
||||
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator, Any, Set
|
||||
|
||||
from .dispatcher import Dispatcher
|
||||
from .traits import CallbacksHolder, Runnable
|
||||
from .mc.packet import Packet
|
||||
from .mc.auth import AuthInterface, AuthException, MojangToken, MicrosoftAuthenticator
|
||||
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
|
||||
|
@ -36,11 +35,7 @@ class ClientOptions:
|
|||
poll_interval : float = 1.0
|
||||
use_packet_whitelist : bool = True
|
||||
|
||||
class ClientEvent(Enum):
|
||||
CONNECTED = 0
|
||||
DISCONNECTED = 1
|
||||
|
||||
class MinecraftClient(CallbacksHolder, Runnable):
|
||||
class MinecraftClient:
|
||||
host:str
|
||||
port:int
|
||||
options:ClientOptions
|
||||
|
@ -53,7 +48,6 @@ class MinecraftClient(CallbacksHolder, Runnable):
|
|||
_processing : bool
|
||||
_authenticated : bool
|
||||
_worker : Task
|
||||
_callbacks = Dict[str, Task]
|
||||
|
||||
_logger : logging.Logger
|
||||
|
||||
|
@ -90,35 +84,13 @@ class MinecraftClient(CallbacksHolder, Runnable):
|
|||
|
||||
self._logger = LOGGER.getChild(f"on({self.host}:{self.port})")
|
||||
|
||||
@property
|
||||
def started(self) -> bool:
|
||||
return self._processing
|
||||
|
||||
@property
|
||||
def connected(self) -> bool:
|
||||
return self.started and self.dispatcher.connected
|
||||
return self.dispatcher.connected
|
||||
|
||||
async def write(self, packet:Packet, wait:bool=False):
|
||||
await self.dispatcher.write(packet, wait)
|
||||
|
||||
def on_connected(self) -> Callable:
|
||||
def wrapper(fun):
|
||||
self.register(ClientEvent.CONNECTED, fun)
|
||||
return fun
|
||||
return wrapper
|
||||
|
||||
def on_disconnected(self) -> Callable:
|
||||
def wrapper(fun):
|
||||
self.register(ClientEvent.DISCONNECTED, fun)
|
||||
return fun
|
||||
return wrapper
|
||||
|
||||
def on_packet(self, packet:Type[Packet]) -> Callable:
|
||||
def wrapper(fun):
|
||||
self.register(packet, fun)
|
||||
return fun
|
||||
return wrapper
|
||||
|
||||
async def authenticate(self):
|
||||
if self._authenticated:
|
||||
return # Don't spam Auth endpoint!
|
||||
|
@ -136,60 +108,49 @@ class MinecraftClient(CallbacksHolder, Runnable):
|
|||
raise ValueError("No refreshable auth or code to login")
|
||||
self._authenticated = True
|
||||
|
||||
async def change_server(self, server:str):
|
||||
restart = self.started
|
||||
if restart:
|
||||
await self.stop()
|
||||
|
||||
if ":" in server:
|
||||
_host, _port = server.split(":", 1)
|
||||
self.host = _host.strip()
|
||||
self.port = int(_port)
|
||||
else:
|
||||
self.host = server.strip()
|
||||
self.port = 25565
|
||||
self._logger = LOGGER.getChild(f"{self.host}:{self.port}")
|
||||
|
||||
if restart:
|
||||
await self.start()
|
||||
|
||||
async def start(self):
|
||||
await super().start()
|
||||
if self.started:
|
||||
return
|
||||
self._processing = True
|
||||
self._worker = asyncio.get_event_loop().create_task(self._client_worker())
|
||||
self._logger.info("Minecraft client started")
|
||||
|
||||
async def stop(self, force:bool=False):
|
||||
self._processing = False
|
||||
if self.dispatcher.connected:
|
||||
await self.dispatcher.disconnect(block=not force)
|
||||
if not force:
|
||||
await self._worker
|
||||
self._logger.info("Minecraft client stopped")
|
||||
if not force:
|
||||
await self.join_callbacks()
|
||||
await super().stop(force)
|
||||
|
||||
async def info(self, host:str="", port:int=0, ping:bool=False) -> Dict[str, Any]:
|
||||
async def info(self, host:str="", port:int=0, proto:int=0, ping:bool=False) -> Dict[str, Any]:
|
||||
"""Make a mini connection to asses server status and version"""
|
||||
await self.dispatcher.connect(
|
||||
host or self.host,
|
||||
port or self.port,
|
||||
)
|
||||
#Handshake
|
||||
self.host = host or self.host
|
||||
self.port = port or self.port
|
||||
try:
|
||||
await self.dispatcher.connect(self.host, self.port)
|
||||
await self._handshake(ConnectionState.STATUS)
|
||||
return await self._status(ping)
|
||||
finally:
|
||||
await self.dispatcher.disconnect()
|
||||
|
||||
async def join(self, host:str="", port:int=0, proto:int=0, packet_whitelist:Optional[Set[Type[Packet]]]=None): # jank packet_whitelist argument! TODO
|
||||
self.host = host or self.host
|
||||
self.port = port or self.port
|
||||
if self.online_mode:
|
||||
await self.authenticate()
|
||||
try:
|
||||
await self.dispatcher.connect(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
proto=proto,
|
||||
queue_timeout=self.options.poll_interval,
|
||||
packet_whitelist=packet_whitelist
|
||||
)
|
||||
await self._handshake(ConnectionState.LOGIN)
|
||||
if await self._login():
|
||||
await self._play()
|
||||
finally:
|
||||
await self.dispatcher.disconnect()
|
||||
|
||||
async def _handshake(self, state:ConnectionState):
|
||||
await self.dispatcher.write(
|
||||
PacketSetProtocol(
|
||||
self.dispatcher.proto,
|
||||
protocolVersion=self.dispatcher.proto,
|
||||
serverHost=host or self.host,
|
||||
serverPort=port or self.port,
|
||||
nextState=ConnectionState.STATUS.value,
|
||||
serverHost=self.host,
|
||||
serverPort=self.port,
|
||||
nextState=state.value
|
||||
)
|
||||
)
|
||||
|
||||
async def _status(self, ping:bool=False) -> Dict[str, Any]:
|
||||
self.dispatcher.state = ConnectionState.STATUS
|
||||
#Request
|
||||
await self.dispatcher.write(
|
||||
PacketPingStart(self.dispatcher.proto) #empty packet
|
||||
)
|
||||
|
@ -215,80 +176,16 @@ class MinecraftClient(CallbacksHolder, Runnable):
|
|||
if packet.time == ping_id:
|
||||
data['ping'] = int(time() - ping_time)
|
||||
break
|
||||
await self.dispatcher.disconnect()
|
||||
return data
|
||||
|
||||
async def _client_worker(self):
|
||||
try:
|
||||
self._logger.info("Pinging server")
|
||||
server_data = await self.info()
|
||||
self._logger.info(
|
||||
"Connecting to: %s (%d/%d)",
|
||||
server_data['version']['name'],
|
||||
server_data['players']['online'],
|
||||
server_data['players']['max']
|
||||
)
|
||||
except Exception:
|
||||
self._logger.exception("Exception checking server stats")
|
||||
return
|
||||
while self._processing:
|
||||
if self.online_mode:
|
||||
try:
|
||||
await self.authenticate()
|
||||
except AuthException as e:
|
||||
self._logger.error(str(e))
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.exception("Unexpected error while authenticating")
|
||||
break
|
||||
try:
|
||||
packet_whitelist = self.callback_keys(filter=Packet) if self.options.use_packet_whitelist else set()
|
||||
await self.dispatcher.connect(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
proto=server_data['version']['protocol'],
|
||||
queue_timeout=self.options.poll_interval,
|
||||
packet_whitelist=packet_whitelist
|
||||
)
|
||||
self.dispatcher.proto = server_data['version']['protocol'] # TODO maybe check if it's supported?
|
||||
await self._handshake()
|
||||
if await self._login():
|
||||
await self._play()
|
||||
except ConnectionRefusedError:
|
||||
self._logger.error("Server rejected connection")
|
||||
except OSError as e:
|
||||
self._logger.error("Connection error : %s", str(e))
|
||||
except Exception:
|
||||
self._logger.exception("Exception in Client connection")
|
||||
if self.dispatcher.connected:
|
||||
await self.dispatcher.disconnect()
|
||||
if not self.options.reconnect:
|
||||
break
|
||||
if self._processing: # if client was stopped exit immediately
|
||||
await asyncio.sleep(self.options.reconnect_delay)
|
||||
if self._processing:
|
||||
await self.stop(force=True)
|
||||
|
||||
async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first
|
||||
await self.dispatcher.write(
|
||||
PacketSetProtocol(
|
||||
self.dispatcher.proto,
|
||||
protocolVersion=self.dispatcher.proto,
|
||||
serverHost=self.host,
|
||||
serverPort=self.port,
|
||||
nextState=2, # play
|
||||
)
|
||||
)
|
||||
async def _login(self) -> bool:
|
||||
self.dispatcher.state = ConnectionState.LOGIN
|
||||
await self.dispatcher.write(
|
||||
PacketLoginStart(
|
||||
self.dispatcher.proto,
|
||||
username=self._authenticator.selectedProfile.name if self.online_mode else self._username
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
async def _login(self) -> bool:
|
||||
self.dispatcher.state = ConnectionState.LOGIN
|
||||
async for packet in self.dispatcher.packets():
|
||||
if isinstance(packet, PacketEncryptionBegin):
|
||||
if not self.online_mode:
|
||||
|
@ -338,9 +235,8 @@ class MinecraftClient(CallbacksHolder, Runnable):
|
|||
return False
|
||||
return False
|
||||
|
||||
async def _play(self) -> bool:
|
||||
async def _play(self):
|
||||
self.dispatcher.state = ConnectionState.PLAY
|
||||
self.run_callbacks(ClientEvent.CONNECTED)
|
||||
async for packet in self.dispatcher.packets():
|
||||
self._logger.debug("[ * ] Processing %s", packet.__class__.__name__)
|
||||
if isinstance(packet, PacketSetCompression):
|
||||
|
@ -353,7 +249,3 @@ class MinecraftClient(CallbacksHolder, Runnable):
|
|||
elif isinstance(packet, PacketKickDisconnect):
|
||||
self._logger.error("Kicked while in game : %s", helpers.parse_chat(packet.reason))
|
||||
break
|
||||
self.run_callbacks(Packet, packet)
|
||||
self.run_callbacks(type(packet), packet)
|
||||
self.run_callbacks(ClientEvent.DISCONNECTED)
|
||||
return False
|
||||
|
|
|
@ -5,7 +5,7 @@ import zlib
|
|||
import logging
|
||||
from asyncio import StreamReader, StreamWriter, Queue, Task
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Set, Optional, AsyncIterator, Type
|
||||
from typing import List, Dict, Set, Optional, AsyncIterator, Type, Union
|
||||
|
||||
from cryptography.hazmat.primitives.ciphers import CipherContext
|
||||
|
||||
|
@ -39,7 +39,7 @@ class Dispatcher:
|
|||
_incoming : Queue
|
||||
_outgoing : Queue
|
||||
|
||||
_packet_whitelist : Set[Packet]
|
||||
_packet_whitelist : Set[Type[Packet]]
|
||||
_packet_id_whitelist : Set[int]
|
||||
|
||||
host : str
|
||||
|
@ -95,15 +95,15 @@ class Dispatcher:
|
|||
host:Optional[str] = None,
|
||||
port:Optional[int] = None,
|
||||
proto:Optional[int] = None,
|
||||
queue_timeout:int = 1,
|
||||
queue_timeout:float = 1,
|
||||
queue_size:int = 100,
|
||||
packet_whitelist : List[Packet] = None
|
||||
packet_whitelist : Set[Type[Packet]] = None
|
||||
):
|
||||
self.proto = proto or self.proto or 757 # TODO not hardcode this?
|
||||
self.host = host or self.host or "localhost"
|
||||
self.port = port or self.port or 25565
|
||||
self._logger = LOGGER.getChild(f"on({self.host}:{self.port})")
|
||||
self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set()
|
||||
self._packet_whitelist = set(packet_whitelist) if packet_whitelist else set() # just in case make new set
|
||||
if self._packet_whitelist:
|
||||
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKeepAlive)
|
||||
self._packet_whitelist.add(minecraft_protocol.play.clientbound.PacketKickDisconnect)
|
||||
|
@ -128,9 +128,9 @@ class Dispatcher:
|
|||
proto : Optional[int] = None,
|
||||
reader : Optional[StreamReader] = None,
|
||||
writer : Optional[StreamWriter] = None,
|
||||
queue_timeout : int = 1,
|
||||
queue_timeout : float = 1,
|
||||
queue_size : int = 100,
|
||||
packet_whitelist : Set[Packet] = None,
|
||||
packet_whitelist : Set[Type[Packet]] = None,
|
||||
):
|
||||
if self.connected:
|
||||
raise InvalidState("Dispatcher already connected")
|
||||
|
@ -165,33 +165,34 @@ class Dispatcher:
|
|||
if block:
|
||||
await self._up.wait_closed()
|
||||
self._logger.debug("Socket closed")
|
||||
self._logger.info("Disconnected")
|
||||
if block:
|
||||
self._logger.info("Disconnected")
|
||||
|
||||
def _packet_type_from_registry(self, packet_id:int) -> Type[Packet]:
|
||||
# TODO de-jank this, language server gets kinda mad
|
||||
reg = None
|
||||
# m : Module
|
||||
if self.state == ConnectionState.HANDSHAKING:
|
||||
reg = minecraft_protocol.handshaking
|
||||
m = minecraft_protocol.handshaking
|
||||
elif self.state == ConnectionState.STATUS:
|
||||
reg = minecraft_protocol.status
|
||||
m = minecraft_protocol.status
|
||||
elif self.state == ConnectionState.LOGIN:
|
||||
reg = minecraft_protocol.login
|
||||
m = minecraft_protocol.login
|
||||
elif self.state == ConnectionState.PLAY:
|
||||
reg = minecraft_protocol.play
|
||||
m = minecraft_protocol.play
|
||||
else:
|
||||
raise InvalidState("Cannot access registries from invalid state")
|
||||
|
||||
if self.is_server:
|
||||
reg = reg.serverbound.REGISTRY
|
||||
reg = m.serverbound.REGISTRY
|
||||
else:
|
||||
reg = reg.clientbound.REGISTRY
|
||||
reg = m.clientbound.REGISTRY
|
||||
|
||||
if not self.proto:
|
||||
raise InvalidState("Cannot access registries from invalid protocol")
|
||||
|
||||
reg = reg[self.proto]
|
||||
proto_reg = reg[self.proto]
|
||||
|
||||
return reg[packet_id]
|
||||
return proto_reg[packet_id]
|
||||
|
||||
async def _read_varint(self) -> int:
|
||||
numRead = 0
|
||||
|
|
|
@ -1,2 +0,0 @@
|
|||
from .callbacks import CallbacksHolder
|
||||
from .runnable import Runnable
|
|
@ -1,56 +0,0 @@
|
|||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
from inspect import isclass
|
||||
from typing import Dict, List, Set, Any, Callable, Type
|
||||
|
||||
class CallbacksHolder:
|
||||
|
||||
_callbacks : Dict[Any, List[Callable]]
|
||||
_tasks : Dict[uuid.UUID, asyncio.Event]
|
||||
|
||||
_logger : logging.Logger
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._callbacks = {}
|
||||
self._tasks = {}
|
||||
|
||||
def callback_keys(self, filter:Type = None) -> Set[Any]:
|
||||
return set(x for x in self._callbacks.keys() if not filter or (isclass(x) and issubclass(x, filter)))
|
||||
|
||||
def register(self, key:Any, callback:Callable):
|
||||
if key not in self._callbacks:
|
||||
self._callbacks[key] = []
|
||||
self._callbacks[key].append(callback)
|
||||
return callback
|
||||
|
||||
def trigger(self, key:Any) -> List[Callable]:
|
||||
if key not in self._callbacks:
|
||||
return []
|
||||
return self._callbacks[key]
|
||||
|
||||
def _wrap(self, cb:Callable, uid:uuid.UUID) -> Callable:
|
||||
async def wrapper(*args):
|
||||
try:
|
||||
ret = await cb(*args)
|
||||
except Exception:
|
||||
logging.exception("Exception processing callback")
|
||||
ret = None
|
||||
self._tasks[uid].set()
|
||||
self._tasks.pop(uid)
|
||||
return ret
|
||||
return wrapper
|
||||
|
||||
def run_callbacks(self, key:Any, *args) -> None:
|
||||
for cb in self.trigger(key):
|
||||
task_id = uuid.uuid4()
|
||||
self._tasks[task_id] = asyncio.Event()
|
||||
|
||||
asyncio.get_event_loop().create_task(self._wrap(cb, task_id)(*args))
|
||||
|
||||
async def join_callbacks(self):
|
||||
await asyncio.gather(*list(t.wait() for t in self._tasks.values()))
|
||||
self._tasks.clear()
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
from typing import Optional
|
||||
from signal import signal, SIGINT, SIGTERM, SIGABRT
|
||||
|
||||
class Runnable:
|
||||
_is_running : bool
|
||||
_stop_task : Optional[asyncio.Task]
|
||||
|
||||
def __init__(self):
|
||||
self._is_running = False
|
||||
self._stop_task = None
|
||||
|
||||
async def start(self):
|
||||
self._is_running = True
|
||||
|
||||
async def stop(self, force:bool=False):
|
||||
self._is_running = False
|
||||
|
||||
def run(self):
|
||||
logging.info("Starting process")
|
||||
|
||||
def signal_handler(signum, __):
|
||||
if signum == SIGINT:
|
||||
if self._stop_task:
|
||||
self._stop_task.cancel()
|
||||
logging.info("Received SIGINT, terminating")
|
||||
else:
|
||||
logging.info("Received SIGINT, stopping gracefully...")
|
||||
self._stop_task = asyncio.get_event_loop().create_task(self.stop(force=self._stop_task is not None))
|
||||
|
||||
signal(SIGINT, signal_handler)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
async def main():
|
||||
await self.start()
|
||||
while self._is_running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
loop.run_until_complete(main())
|
||||
|
||||
logging.info("Process finished")
|
||||
|
Loading…
Reference in a new issue