moved common code into traits, added callbacks, made it cleaner?

This commit is contained in:
əlemi 2021-11-22 02:27:42 +01:00
parent c2e8e80806
commit 540dbd5d90
6 changed files with 185 additions and 122 deletions

View file

@ -5,19 +5,19 @@ import logging
from .mc.proto.play.clientbound import PacketChat
from .mc.token import Token
from .dispatcher import ConnectionState
from .client import Client
from .client import MinecraftClient
from .server import MinecraftServer
from .util.helpers import parse_chat
async def idle():
while True:
await asyncio.sleep(1)
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
if sys.argv[1] == "--server":
serv = MinecraftServer("0.0.0.0", 25565)
host = sys.argv[2] if len(sys.argv) > 2 else "localhost"
port = sys.argv[3] if len(sys.argv) > 3 else 25565
serv = MinecraftServer(host, port)
serv.run() # will block and start asyncio event loop
else:
username = sys.argv[1]
@ -33,7 +33,7 @@ if __name__ == "__main__":
host = server.strip()
port = 25565
client = Client(host, port, username=username, password=pwd)
client = MinecraftClient(host, port, username=username, password=pwd)
@client.on_packet(PacketChat, ConnectionState.PLAY)
async def print_chat(packet: PacketChat):

View file

@ -2,12 +2,14 @@ import asyncio
import logging
import uuid
from dataclasses import dataclass
from asyncio import Task
from enum import Enum
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator
from .dispatcher import Dispatcher
from .traits import CallbacksHolder, Runnable
from .mc.packet import Packet
from .mc.token import Token, AuthException
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
@ -22,10 +24,21 @@ from .util import encryption
LOGGER = logging.getLogger(__name__)
class Client:
@dataclass
class ClientOptions:
reconnect : bool
reconnect_delay : float
keep_alive : bool
poll_interval : float
class ClientEvent(Enum):
CONNECTED = 0
DISCONNECTED = 1
class MinecraftClient(CallbacksHolder, Runnable):
host:str
port:int
options:dict
options:ClientOptions
username:Optional[str]
password:Optional[str]
@ -37,27 +50,30 @@ class Client:
_worker : Task
_callbacks = Dict[str, Task]
_packet_callbacks : Dict[Type[Packet], List[Callable]]
_logger : logging.Logger
def __init__(
self,
host:str,
port:int = 25565,
options:dict = None,
username:Optional[str] = None,
password:Optional[str] = None,
token:Optional[Token] = None,
reconnect:bool = True,
reconnect_delay:float = 10.0,
keep_alive:bool = True,
poll_interval:float = 1.0,
):
self.host = host
self.port = port
self.options = options or {
"reconnect" : True,
"rctime" : 5.0,
"keep-alive" : True,
"poll-timeout" : 1,
}
self.options = ClientOptions(
reconnect=reconnect,
reconnect_delay=reconnect_delay,
keep_alive=keep_alive,
poll_interval=poll_interval
)
self.token = token
self.username = username
@ -67,9 +83,6 @@ class Client:
self._processing = False
self._authenticated = False
self._packet_callbacks = {}
self._callbacks = {}
self._logger = LOGGER.getChild(f"{self.host}:{self.port}")
@property
@ -80,23 +93,21 @@ class Client:
def connected(self) -> bool:
return self.started and self.dispatcher.connected
def _run_async(self, func, pkt:Packet):
key = str(uuid.uuid4()) # ugly!
async def wrapper(packet:Packet):
try:
await func(packet)
except Exception as e:
self._logger.error("Exception in callback %s for packet %s | %s", func.__name__, packet, str(e))
self._callbacks.pop(key, None)
self._callbacks[key] = asyncio.get_event_loop().create_task(wrapper(pkt))
def on_packet(self, packet:Type[Packet], *args) -> Callable: # receive *args for retro compatibility
def on_connected(self) -> Callable:
def wrapper(fun):
if packet not in self._packet_callbacks:
self._packet_callbacks[packet] = []
self._packet_callbacks[packet].append(fun)
self.register(ClientEvent.CONNECTED, fun)
return fun
return wrapper
def on_disconnected(self) -> Callable:
def wrapper(fun):
self.register(ClientEvent.DISCONNECTED, fun)
return fun
return wrapper
def on_packet(self, packet:Type[Packet]) -> Callable:
def wrapper(fun):
self.register(packet, fun)
return fun
return wrapper
@ -136,39 +147,22 @@ class Client:
if restart:
await self.start()
def run(self):
loop = asyncio.get_event_loop()
loop.run_until_complete(self.start())
async def idle():
while True: # TODO don't busywait even if it doesn't matter much
await asyncio.sleep(self.options["poll-timeout"])
try:
loop.run_until_complete(idle())
except KeyboardInterrupt:
self._logger.info("Received SIGINT, stopping...")
try:
loop.run_until_complete(self.stop())
except KeyboardInterrupt:
self._logger.info("Received SIGINT, stopping for real")
loop.run_until_complete(self.stop(wait_tasks=False))
async def start(self):
if self.started:
return
self._processing = True
self._worker = asyncio.get_event_loop().create_task(self._client_worker())
self._logger.info("Minecraft client started")
async def stop(self, block=True, wait_tasks=True):
async def stop(self, force:bool=False):
self._processing = False
if self.dispatcher.connected:
await self.dispatcher.disconnect(block=block)
if block:
await self.dispatcher.disconnect(block=not force)
if not force:
await self._worker
self._logger.info("Minecraft client stopped")
if block and wait_tasks:
await asyncio.gather(*list(self._callbacks.values()))
if not force:
await self.join_callbacks()
async def _client_worker(self):
while self._processing:
@ -188,9 +182,9 @@ class Client:
self._logger.exception("Exception in Client connection")
if self.dispatcher.connected:
await self.dispatcher.disconnect()
if not self.options["reconnect"]:
if not self.options.reconnect:
break
await asyncio.sleep(self.options["rctime"])
await asyncio.sleep(self.options.reconnect_delay)
await self.stop(block=False)
async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first
@ -255,20 +249,20 @@ class Client:
async def _play(self) -> bool:
self.dispatcher.state = ConnectionState.PLAY
self.run_callbacks(ClientEvent.CONNECTED)
async for packet in self.dispatcher.packets():
self._logger.debug("[ * ] Processing | %s", str(packet))
if isinstance(packet, PacketSetCompression):
self._logger.info("Compression updated")
self.dispatcher.compression = packet.threshold
elif isinstance(packet, PacketKeepAlive):
if self.options["keep-alive"]:
if self.options.keep_alive:
keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId)
await self.dispatcher.write(keep_alive_packet)
elif isinstance(packet, PacketKickDisconnect):
self._logger.error("Kicked while in game")
break
for packet_type in (Packet, type(packet)): # check both callbacks for base class and instance class
if packet_type in self._packet_callbacks:
for cb in self._packet_callbacks[packet_type]:
self._run_async(cb, packet)
self.run_callbacks(Packet, packet)
self.run_callbacks(type(packet), packet)
self.run_callbacks(ClientEvent.DISCONNECTED)
return False

View file

@ -2,6 +2,7 @@ import asyncio
import logging
import uuid
from dataclasses import dataclass
from asyncio import Task, StreamReader, StreamWriter
from asyncio.base_events import Server # just for typing
from enum import Enum
@ -9,6 +10,7 @@ from enum import Enum
from typing import Dict, List, Callable, Coroutine, Type, Optional, Tuple, AsyncIterator
from .dispatcher import Dispatcher
from .traits import CallbacksHolder, Runnable
from .mc.packet import Packet
from .mc.token import Token, AuthException
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
@ -23,89 +25,80 @@ from .util import encryption
LOGGER = logging.getLogger(__name__)
class MinecraftServer:
@dataclass
class ServerOptions:
online_mode : bool
spawn_player : bool
poll_interval : float
class ServerEvent(Enum):
CLIENT_CONNECTED = 0
CLIENT_DISCONNECTED = 1
class MinecraftServer(CallbacksHolder, Runnable):
host:str
port:int
options:dict
options:ServerOptions
_dispatcher_pool : List[Dispatcher]
_processing : bool
_server : Server
_worker : Task
_callbacks = Dict[str, Task]
_logger : logging.Logger
_disconnect_handlers : List[Callable]
_connect_handlers : List[Callable]
_packet_handlers : List[Callable]
def __init__(
self,
host:str,
port:int = 25565,
options:dict = None,
online_mode:bool = False,
spawn_player:bool = True,
poll_interval:float = 1.0,
):
super().__init__()
self.host = host
self.port = port
self.options = options or {
"poll-timeout" : 1,
"online-mode" : False,
"spawn-player" : True,
}
self.options = ServerOptions(
online_mode=online_mode,
spawn_player=spawn_player,
poll_interval=poll_interval,
)
self._dispatcher_pool = []
self._processing = False
self._logger = LOGGER.getChild(f"@({self.host}:{self.port})")
self._disconnect_handlers = []
self._connect_handlers = []
self._packet_handlers = []
@property
def started(self) -> bool:
return self._processing
def on_client_connect(self, *args):
@property
def connected(self) -> int:
return len(self._dispatcher_pool)
def on_connect(self):
def wrapper(fun):
self._connect_handlers.append(fun)
self.register(ServerEvent.CLIENT_CONNECTED, fun)
return fun
return wrapper
def on_client_disconnect(self, *args):
def on_disconnect(self):
def wrapper(fun):
self._disconnect_handlers.append(fun)
self.register(ServerEvent.CLIENT_DISCONNECTED, fun)
return fun
return wrapper
def on_client_packet(self, *args):
def on_packet(self, packet:Type[Packet]):
def wrapper(fun):
self._packet_handlers.append(fun)
self.register(packet, fun)
return fun
return wrapper
def run(self):
loop = asyncio.get_event_loop()
loop.run_until_complete(self.start())
async def idle():
while True: # TODO don't busywait even if it doesn't matter much
await asyncio.sleep(self.options["poll-timeout"])
try:
loop.run_until_complete(idle())
except KeyboardInterrupt:
self._logger.info("Received SIGINT, stopping...")
try:
loop.run_until_complete(self.stop())
except KeyboardInterrupt:
self._logger.info("Received SIGINT, stopping for real")
loop.run_until_complete(self.stop(wait_tasks=False))
async def start(self):
if self.started:
return
self._server = await asyncio.start_server(
self._server_worker, self.host, self.port
)
@ -114,14 +107,15 @@ class MinecraftServer:
await self._server.start_serving()
self._logger.info("Minecraft server started")
async def stop(self, block=True, wait_tasks=True):
async def stop(self, force:bool = False):
self._processing = False
self._server.close()
await asyncio.gather(*[d.disconnect(block=block) for d in self._dispatcher_pool])
if block:
await self._server.wait_closed()
# if block and wait_tasks: # TODO wait for client workers
# await asyncio.gather(*list(self._callbacks.values()))
await asyncio.gather(*[d.disconnect(block=not force) for d in self._dispatcher_pool])
if not force:
await asyncio.gather(
self._server.wait_closed(),
self.join_callbacks(),
)
async def _disconnect_client(self, dispatcher):
if dispatcher.state == ConnectionState.LOGIN:
@ -178,7 +172,7 @@ class MinecraftServer:
self._logger.info("Logging in player")
async for packet in dispatcher.packets():
if isinstance(packet, PacketLoginStart):
if self.options["online-mode"]:
if self.options.online_mode:
# await dispatcher.write(
# PacketEncryptionBegin(
# dispatcher.proto,
@ -207,7 +201,7 @@ class MinecraftServer:
async def _play(self, dispatcher:Dispatcher) -> bool:
self._logger.info("Player connected")
if self.options["spawn-player"]:
if self.options.spawn_player:
await dispatcher.write(
PacketLogin(
dispatcher.proto,
@ -245,14 +239,14 @@ class MinecraftServer:
)
)
await asyncio.gather(*[cb(dispatcher) for cb in self._connect_handlers])
self.run_callbacks(ServerEvent.CLIENT_CONNECTED)
async for packet in dispatcher.packets():
for cb in self._packet_handlers:
asyncio.get_event_loop().create_task(cb(dispatcher, packet))
pass # TODO handle play
# TODO handle play
self.run_callbacks(Packet, packet)
self.run_callbacks(type(packet), packet)
await asyncio.gather(*[cb(dispatcher) for cb in self._disconnect_handlers])
self.run_callbacks(ServerEvent.CLIENT_DISCONNECTED)
return False

View file

@ -0,0 +1,2 @@
from .callbacks import CallbacksHolder
from .runnable import Runnable

View file

@ -0,0 +1,38 @@
import asyncio
from uuid import uuid4
from typing import Dict, List, Any, Callable
class CallbacksHolder:
_callbacks : Dict[Any, List[Callable]]
_tasks : Dict[str, asyncio.Task]
def __init__(self):
self._callbacks = {}
def register(self, key:Any, callback:Callable):
if key not in self._callbacks:
self._callbacks[key] = []
self._callbacks[key].append(callback)
def trigger(self, key:Any) -> List[Callable]:
if key not in self._callbacks:
return []
return self._callbacks[key]
def run_callbacks(self, key:Any, *args) -> None:
for cb in self.trigger(key):
task_id = str(uuid4())
async def wrapper(*args):
await cb(*args)
self._tasks.pop(task_id)
loop = asyncio.get_event_loop()
self._tasks[task_id] = loop.create_task(wrapper(*args))
async def join_callbacks(self):
await asyncio.gather(*list(self._tasks.values()))
self._tasks.clear()

View file

@ -0,0 +1,35 @@
import asyncio
import logging
class Runnable:
async def start(self):
raise NotImplementedError
async def stop(self, force:bool=False):
raise NotImplementedError
def run(self):
loop = asyncio.get_event_loop()
logging.info("Starting")
loop.run_until_complete(self.start())
async def idle():
never = asyncio.Event()
logging.info("Idling")
await never.wait()
try:
loop.run_until_complete(idle())
except KeyboardInterrupt:
logging.info("Received SIGINT, stopping...")
try:
loop.run_until_complete(self.stop())
except KeyboardInterrupt:
logging.info("Received SIGINT, stopping for real")
loop.run_until_complete(self.stop(force=True))
logging.info("Done")