moved common code into traits, added callbacks, made it cleaner?

This commit is contained in:
əlemi 2021-11-22 02:27:42 +01:00
parent c2e8e80806
commit 540dbd5d90
6 changed files with 185 additions and 122 deletions

View file

@ -5,19 +5,19 @@ import logging
from .mc.proto.play.clientbound import PacketChat from .mc.proto.play.clientbound import PacketChat
from .mc.token import Token from .mc.token import Token
from .dispatcher import ConnectionState from .dispatcher import ConnectionState
from .client import Client from .client import MinecraftClient
from .server import MinecraftServer from .server import MinecraftServer
from .util.helpers import parse_chat from .util.helpers import parse_chat
async def idle():
while True:
await asyncio.sleep(1)
if __name__ == "__main__": if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
if sys.argv[1] == "--server": if sys.argv[1] == "--server":
serv = MinecraftServer("0.0.0.0", 25565) host = sys.argv[2] if len(sys.argv) > 2 else "localhost"
port = sys.argv[3] if len(sys.argv) > 3 else 25565
serv = MinecraftServer(host, port)
serv.run() # will block and start asyncio event loop serv.run() # will block and start asyncio event loop
else: else:
username = sys.argv[1] username = sys.argv[1]
@ -33,7 +33,7 @@ if __name__ == "__main__":
host = server.strip() host = server.strip()
port = 25565 port = 25565
client = Client(host, port, username=username, password=pwd) client = MinecraftClient(host, port, username=username, password=pwd)
@client.on_packet(PacketChat, ConnectionState.PLAY) @client.on_packet(PacketChat, ConnectionState.PLAY)
async def print_chat(packet: PacketChat): async def print_chat(packet: PacketChat):

View file

@ -2,12 +2,14 @@ import asyncio
import logging import logging
import uuid import uuid
from dataclasses import dataclass
from asyncio import Task from asyncio import Task
from enum import Enum from enum import Enum
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator
from .dispatcher import Dispatcher from .dispatcher import Dispatcher
from .traits import CallbacksHolder, Runnable
from .mc.packet import Packet from .mc.packet import Packet
from .mc.token import Token, AuthException from .mc.token import Token, AuthException
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
@ -22,10 +24,21 @@ from .util import encryption
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
class Client: @dataclass
class ClientOptions:
reconnect : bool
reconnect_delay : float
keep_alive : bool
poll_interval : float
class ClientEvent(Enum):
CONNECTED = 0
DISCONNECTED = 1
class MinecraftClient(CallbacksHolder, Runnable):
host:str host:str
port:int port:int
options:dict options:ClientOptions
username:Optional[str] username:Optional[str]
password:Optional[str] password:Optional[str]
@ -37,27 +50,30 @@ class Client:
_worker : Task _worker : Task
_callbacks = Dict[str, Task] _callbacks = Dict[str, Task]
_packet_callbacks : Dict[Type[Packet], List[Callable]]
_logger : logging.Logger _logger : logging.Logger
def __init__( def __init__(
self, self,
host:str, host:str,
port:int = 25565, port:int = 25565,
options:dict = None,
username:Optional[str] = None, username:Optional[str] = None,
password:Optional[str] = None, password:Optional[str] = None,
token:Optional[Token] = None, token:Optional[Token] = None,
reconnect:bool = True,
reconnect_delay:float = 10.0,
keep_alive:bool = True,
poll_interval:float = 1.0,
): ):
self.host = host self.host = host
self.port = port self.port = port
self.options = options or { self.options = ClientOptions(
"reconnect" : True, reconnect=reconnect,
"rctime" : 5.0, reconnect_delay=reconnect_delay,
"keep-alive" : True, keep_alive=keep_alive,
"poll-timeout" : 1, poll_interval=poll_interval
} )
self.token = token self.token = token
self.username = username self.username = username
@ -67,9 +83,6 @@ class Client:
self._processing = False self._processing = False
self._authenticated = False self._authenticated = False
self._packet_callbacks = {}
self._callbacks = {}
self._logger = LOGGER.getChild(f"{self.host}:{self.port}") self._logger = LOGGER.getChild(f"{self.host}:{self.port}")
@property @property
@ -80,23 +93,21 @@ class Client:
def connected(self) -> bool: def connected(self) -> bool:
return self.started and self.dispatcher.connected return self.started and self.dispatcher.connected
def _run_async(self, func, pkt:Packet): def on_connected(self) -> Callable:
key = str(uuid.uuid4()) # ugly!
async def wrapper(packet:Packet):
try:
await func(packet)
except Exception as e:
self._logger.error("Exception in callback %s for packet %s | %s", func.__name__, packet, str(e))
self._callbacks.pop(key, None)
self._callbacks[key] = asyncio.get_event_loop().create_task(wrapper(pkt))
def on_packet(self, packet:Type[Packet], *args) -> Callable: # receive *args for retro compatibility
def wrapper(fun): def wrapper(fun):
if packet not in self._packet_callbacks: self.register(ClientEvent.CONNECTED, fun)
self._packet_callbacks[packet] = [] return fun
self._packet_callbacks[packet].append(fun) return wrapper
def on_disconnected(self) -> Callable:
def wrapper(fun):
self.register(ClientEvent.DISCONNECTED, fun)
return fun
return wrapper
def on_packet(self, packet:Type[Packet]) -> Callable:
def wrapper(fun):
self.register(packet, fun)
return fun return fun
return wrapper return wrapper
@ -136,39 +147,22 @@ class Client:
if restart: if restart:
await self.start() await self.start()
def run(self):
loop = asyncio.get_event_loop()
loop.run_until_complete(self.start())
async def idle():
while True: # TODO don't busywait even if it doesn't matter much
await asyncio.sleep(self.options["poll-timeout"])
try:
loop.run_until_complete(idle())
except KeyboardInterrupt:
self._logger.info("Received SIGINT, stopping...")
try:
loop.run_until_complete(self.stop())
except KeyboardInterrupt:
self._logger.info("Received SIGINT, stopping for real")
loop.run_until_complete(self.stop(wait_tasks=False))
async def start(self): async def start(self):
if self.started:
return
self._processing = True self._processing = True
self._worker = asyncio.get_event_loop().create_task(self._client_worker()) self._worker = asyncio.get_event_loop().create_task(self._client_worker())
self._logger.info("Minecraft client started") self._logger.info("Minecraft client started")
async def stop(self, block=True, wait_tasks=True): async def stop(self, force:bool=False):
self._processing = False self._processing = False
if self.dispatcher.connected: if self.dispatcher.connected:
await self.dispatcher.disconnect(block=block) await self.dispatcher.disconnect(block=not force)
if block: if not force:
await self._worker await self._worker
self._logger.info("Minecraft client stopped") self._logger.info("Minecraft client stopped")
if block and wait_tasks: if not force:
await asyncio.gather(*list(self._callbacks.values())) await self.join_callbacks()
async def _client_worker(self): async def _client_worker(self):
while self._processing: while self._processing:
@ -188,9 +182,9 @@ class Client:
self._logger.exception("Exception in Client connection") self._logger.exception("Exception in Client connection")
if self.dispatcher.connected: if self.dispatcher.connected:
await self.dispatcher.disconnect() await self.dispatcher.disconnect()
if not self.options["reconnect"]: if not self.options.reconnect:
break break
await asyncio.sleep(self.options["rctime"]) await asyncio.sleep(self.options.reconnect_delay)
await self.stop(block=False) await self.stop(block=False)
async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first
@ -255,20 +249,20 @@ class Client:
async def _play(self) -> bool: async def _play(self) -> bool:
self.dispatcher.state = ConnectionState.PLAY self.dispatcher.state = ConnectionState.PLAY
self.run_callbacks(ClientEvent.CONNECTED)
async for packet in self.dispatcher.packets(): async for packet in self.dispatcher.packets():
self._logger.debug("[ * ] Processing | %s", str(packet)) self._logger.debug("[ * ] Processing | %s", str(packet))
if isinstance(packet, PacketSetCompression): if isinstance(packet, PacketSetCompression):
self._logger.info("Compression updated") self._logger.info("Compression updated")
self.dispatcher.compression = packet.threshold self.dispatcher.compression = packet.threshold
elif isinstance(packet, PacketKeepAlive): elif isinstance(packet, PacketKeepAlive):
if self.options["keep-alive"]: if self.options.keep_alive:
keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId) keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId)
await self.dispatcher.write(keep_alive_packet) await self.dispatcher.write(keep_alive_packet)
elif isinstance(packet, PacketKickDisconnect): elif isinstance(packet, PacketKickDisconnect):
self._logger.error("Kicked while in game") self._logger.error("Kicked while in game")
break break
for packet_type in (Packet, type(packet)): # check both callbacks for base class and instance class self.run_callbacks(Packet, packet)
if packet_type in self._packet_callbacks: self.run_callbacks(type(packet), packet)
for cb in self._packet_callbacks[packet_type]: self.run_callbacks(ClientEvent.DISCONNECTED)
self._run_async(cb, packet)
return False return False

View file

@ -2,6 +2,7 @@ import asyncio
import logging import logging
import uuid import uuid
from dataclasses import dataclass
from asyncio import Task, StreamReader, StreamWriter from asyncio import Task, StreamReader, StreamWriter
from asyncio.base_events import Server # just for typing from asyncio.base_events import Server # just for typing
from enum import Enum from enum import Enum
@ -9,6 +10,7 @@ from enum import Enum
from typing import Dict, List, Callable, Coroutine, Type, Optional, Tuple, AsyncIterator from typing import Dict, List, Callable, Coroutine, Type, Optional, Tuple, AsyncIterator
from .dispatcher import Dispatcher from .dispatcher import Dispatcher
from .traits import CallbacksHolder, Runnable
from .mc.packet import Packet from .mc.packet import Packet
from .mc.token import Token, AuthException from .mc.token import Token, AuthException
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
@ -23,89 +25,80 @@ from .util import encryption
LOGGER = logging.getLogger(__name__) LOGGER = logging.getLogger(__name__)
class MinecraftServer: @dataclass
class ServerOptions:
online_mode : bool
spawn_player : bool
poll_interval : float
class ServerEvent(Enum):
CLIENT_CONNECTED = 0
CLIENT_DISCONNECTED = 1
class MinecraftServer(CallbacksHolder, Runnable):
host:str host:str
port:int port:int
options:dict options:ServerOptions
_dispatcher_pool : List[Dispatcher] _dispatcher_pool : List[Dispatcher]
_processing : bool _processing : bool
_server : Server _server : Server
_worker : Task _worker : Task
_callbacks = Dict[str, Task]
_logger : logging.Logger _logger : logging.Logger
_disconnect_handlers : List[Callable]
_connect_handlers : List[Callable]
_packet_handlers : List[Callable]
def __init__( def __init__(
self, self,
host:str, host:str,
port:int = 25565, port:int = 25565,
options:dict = None, online_mode:bool = False,
spawn_player:bool = True,
poll_interval:float = 1.0,
): ):
super().__init__()
self.host = host self.host = host
self.port = port self.port = port
self.options = options or { self.options = ServerOptions(
"poll-timeout" : 1, online_mode=online_mode,
"online-mode" : False, spawn_player=spawn_player,
"spawn-player" : True, poll_interval=poll_interval,
} )
self._dispatcher_pool = [] self._dispatcher_pool = []
self._processing = False self._processing = False
self._logger = LOGGER.getChild(f"@({self.host}:{self.port})") self._logger = LOGGER.getChild(f"@({self.host}:{self.port})")
self._disconnect_handlers = []
self._connect_handlers = []
self._packet_handlers = []
@property @property
def started(self) -> bool: def started(self) -> bool:
return self._processing return self._processing
def on_client_connect(self, *args): @property
def connected(self) -> int:
return len(self._dispatcher_pool)
def on_connect(self):
def wrapper(fun): def wrapper(fun):
self._connect_handlers.append(fun) self.register(ServerEvent.CLIENT_CONNECTED, fun)
return fun return fun
return wrapper return wrapper
def on_client_disconnect(self, *args): def on_disconnect(self):
def wrapper(fun): def wrapper(fun):
self._disconnect_handlers.append(fun) self.register(ServerEvent.CLIENT_DISCONNECTED, fun)
return fun return fun
return wrapper return wrapper
def on_client_packet(self, *args): def on_packet(self, packet:Type[Packet]):
def wrapper(fun): def wrapper(fun):
self._packet_handlers.append(fun) self.register(packet, fun)
return fun return fun
return wrapper return wrapper
def run(self):
loop = asyncio.get_event_loop()
loop.run_until_complete(self.start())
async def idle():
while True: # TODO don't busywait even if it doesn't matter much
await asyncio.sleep(self.options["poll-timeout"])
try:
loop.run_until_complete(idle())
except KeyboardInterrupt:
self._logger.info("Received SIGINT, stopping...")
try:
loop.run_until_complete(self.stop())
except KeyboardInterrupt:
self._logger.info("Received SIGINT, stopping for real")
loop.run_until_complete(self.stop(wait_tasks=False))
async def start(self): async def start(self):
if self.started:
return
self._server = await asyncio.start_server( self._server = await asyncio.start_server(
self._server_worker, self.host, self.port self._server_worker, self.host, self.port
) )
@ -114,14 +107,15 @@ class MinecraftServer:
await self._server.start_serving() await self._server.start_serving()
self._logger.info("Minecraft server started") self._logger.info("Minecraft server started")
async def stop(self, block=True, wait_tasks=True): async def stop(self, force:bool = False):
self._processing = False self._processing = False
self._server.close() self._server.close()
await asyncio.gather(*[d.disconnect(block=block) for d in self._dispatcher_pool]) await asyncio.gather(*[d.disconnect(block=not force) for d in self._dispatcher_pool])
if block: if not force:
await self._server.wait_closed() await asyncio.gather(
# if block and wait_tasks: # TODO wait for client workers self._server.wait_closed(),
# await asyncio.gather(*list(self._callbacks.values())) self.join_callbacks(),
)
async def _disconnect_client(self, dispatcher): async def _disconnect_client(self, dispatcher):
if dispatcher.state == ConnectionState.LOGIN: if dispatcher.state == ConnectionState.LOGIN:
@ -178,7 +172,7 @@ class MinecraftServer:
self._logger.info("Logging in player") self._logger.info("Logging in player")
async for packet in dispatcher.packets(): async for packet in dispatcher.packets():
if isinstance(packet, PacketLoginStart): if isinstance(packet, PacketLoginStart):
if self.options["online-mode"]: if self.options.online_mode:
# await dispatcher.write( # await dispatcher.write(
# PacketEncryptionBegin( # PacketEncryptionBegin(
# dispatcher.proto, # dispatcher.proto,
@ -207,7 +201,7 @@ class MinecraftServer:
async def _play(self, dispatcher:Dispatcher) -> bool: async def _play(self, dispatcher:Dispatcher) -> bool:
self._logger.info("Player connected") self._logger.info("Player connected")
if self.options["spawn-player"]: if self.options.spawn_player:
await dispatcher.write( await dispatcher.write(
PacketLogin( PacketLogin(
dispatcher.proto, dispatcher.proto,
@ -245,14 +239,14 @@ class MinecraftServer:
) )
) )
await asyncio.gather(*[cb(dispatcher) for cb in self._connect_handlers]) self.run_callbacks(ServerEvent.CLIENT_CONNECTED)
async for packet in dispatcher.packets(): async for packet in dispatcher.packets():
for cb in self._packet_handlers: # TODO handle play
asyncio.get_event_loop().create_task(cb(dispatcher, packet)) self.run_callbacks(Packet, packet)
pass # TODO handle play self.run_callbacks(type(packet), packet)
await asyncio.gather(*[cb(dispatcher) for cb in self._disconnect_handlers]) self.run_callbacks(ServerEvent.CLIENT_DISCONNECTED)
return False return False

View file

@ -0,0 +1,2 @@
from .callbacks import CallbacksHolder
from .runnable import Runnable

View file

@ -0,0 +1,38 @@
import asyncio
from uuid import uuid4
from typing import Dict, List, Any, Callable
class CallbacksHolder:
_callbacks : Dict[Any, List[Callable]]
_tasks : Dict[str, asyncio.Task]
def __init__(self):
self._callbacks = {}
def register(self, key:Any, callback:Callable):
if key not in self._callbacks:
self._callbacks[key] = []
self._callbacks[key].append(callback)
def trigger(self, key:Any) -> List[Callable]:
if key not in self._callbacks:
return []
return self._callbacks[key]
def run_callbacks(self, key:Any, *args) -> None:
for cb in self.trigger(key):
task_id = str(uuid4())
async def wrapper(*args):
await cb(*args)
self._tasks.pop(task_id)
loop = asyncio.get_event_loop()
self._tasks[task_id] = loop.create_task(wrapper(*args))
async def join_callbacks(self):
await asyncio.gather(*list(self._tasks.values()))
self._tasks.clear()

View file

@ -0,0 +1,35 @@
import asyncio
import logging
class Runnable:
async def start(self):
raise NotImplementedError
async def stop(self, force:bool=False):
raise NotImplementedError
def run(self):
loop = asyncio.get_event_loop()
logging.info("Starting")
loop.run_until_complete(self.start())
async def idle():
never = asyncio.Event()
logging.info("Idling")
await never.wait()
try:
loop.run_until_complete(idle())
except KeyboardInterrupt:
logging.info("Received SIGINT, stopping...")
try:
loop.run_until_complete(self.stop())
except KeyboardInterrupt:
logging.info("Received SIGINT, stopping for real")
loop.run_until_complete(self.stop(force=True))
logging.info("Done")