moved common code into traits, added callbacks, made it cleaner?
This commit is contained in:
parent
c2e8e80806
commit
540dbd5d90
6 changed files with 185 additions and 122 deletions
|
@ -5,19 +5,19 @@ import logging
|
|||
from .mc.proto.play.clientbound import PacketChat
|
||||
from .mc.token import Token
|
||||
from .dispatcher import ConnectionState
|
||||
from .client import Client
|
||||
from .client import MinecraftClient
|
||||
from .server import MinecraftServer
|
||||
from .util.helpers import parse_chat
|
||||
|
||||
async def idle():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
if sys.argv[1] == "--server":
|
||||
serv = MinecraftServer("0.0.0.0", 25565)
|
||||
host = sys.argv[2] if len(sys.argv) > 2 else "localhost"
|
||||
port = sys.argv[3] if len(sys.argv) > 3 else 25565
|
||||
|
||||
serv = MinecraftServer(host, port)
|
||||
|
||||
serv.run() # will block and start asyncio event loop
|
||||
else:
|
||||
username = sys.argv[1]
|
||||
|
@ -33,7 +33,7 @@ if __name__ == "__main__":
|
|||
host = server.strip()
|
||||
port = 25565
|
||||
|
||||
client = Client(host, port, username=username, password=pwd)
|
||||
client = MinecraftClient(host, port, username=username, password=pwd)
|
||||
|
||||
@client.on_packet(PacketChat, ConnectionState.PLAY)
|
||||
async def print_chat(packet: PacketChat):
|
||||
|
|
|
@ -2,12 +2,14 @@ import asyncio
|
|||
import logging
|
||||
import uuid
|
||||
|
||||
from dataclasses import dataclass
|
||||
from asyncio import Task
|
||||
from enum import Enum
|
||||
|
||||
from typing import Dict, List, Callable, Type, Optional, Tuple, AsyncIterator
|
||||
|
||||
from .dispatcher import Dispatcher
|
||||
from .traits import CallbacksHolder, Runnable
|
||||
from .mc.packet import Packet
|
||||
from .mc.token import Token, AuthException
|
||||
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
|
||||
|
@ -22,10 +24,21 @@ from .util import encryption
|
|||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class Client:
|
||||
@dataclass
|
||||
class ClientOptions:
|
||||
reconnect : bool
|
||||
reconnect_delay : float
|
||||
keep_alive : bool
|
||||
poll_interval : float
|
||||
|
||||
class ClientEvent(Enum):
|
||||
CONNECTED = 0
|
||||
DISCONNECTED = 1
|
||||
|
||||
class MinecraftClient(CallbacksHolder, Runnable):
|
||||
host:str
|
||||
port:int
|
||||
options:dict
|
||||
options:ClientOptions
|
||||
|
||||
username:Optional[str]
|
||||
password:Optional[str]
|
||||
|
@ -37,27 +50,30 @@ class Client:
|
|||
_worker : Task
|
||||
_callbacks = Dict[str, Task]
|
||||
|
||||
_packet_callbacks : Dict[Type[Packet], List[Callable]]
|
||||
_logger : logging.Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host:str,
|
||||
port:int = 25565,
|
||||
options:dict = None,
|
||||
username:Optional[str] = None,
|
||||
password:Optional[str] = None,
|
||||
token:Optional[Token] = None,
|
||||
reconnect:bool = True,
|
||||
reconnect_delay:float = 10.0,
|
||||
keep_alive:bool = True,
|
||||
poll_interval:float = 1.0,
|
||||
|
||||
):
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
self.options = options or {
|
||||
"reconnect" : True,
|
||||
"rctime" : 5.0,
|
||||
"keep-alive" : True,
|
||||
"poll-timeout" : 1,
|
||||
}
|
||||
self.options = ClientOptions(
|
||||
reconnect=reconnect,
|
||||
reconnect_delay=reconnect_delay,
|
||||
keep_alive=keep_alive,
|
||||
poll_interval=poll_interval
|
||||
)
|
||||
|
||||
self.token = token
|
||||
self.username = username
|
||||
|
@ -67,9 +83,6 @@ class Client:
|
|||
self._processing = False
|
||||
self._authenticated = False
|
||||
|
||||
self._packet_callbacks = {}
|
||||
self._callbacks = {}
|
||||
|
||||
self._logger = LOGGER.getChild(f"{self.host}:{self.port}")
|
||||
|
||||
@property
|
||||
|
@ -80,23 +93,21 @@ class Client:
|
|||
def connected(self) -> bool:
|
||||
return self.started and self.dispatcher.connected
|
||||
|
||||
def _run_async(self, func, pkt:Packet):
|
||||
key = str(uuid.uuid4()) # ugly!
|
||||
|
||||
async def wrapper(packet:Packet):
|
||||
try:
|
||||
await func(packet)
|
||||
except Exception as e:
|
||||
self._logger.error("Exception in callback %s for packet %s | %s", func.__name__, packet, str(e))
|
||||
self._callbacks.pop(key, None)
|
||||
|
||||
self._callbacks[key] = asyncio.get_event_loop().create_task(wrapper(pkt))
|
||||
|
||||
def on_packet(self, packet:Type[Packet], *args) -> Callable: # receive *args for retro compatibility
|
||||
def on_connected(self) -> Callable:
|
||||
def wrapper(fun):
|
||||
if packet not in self._packet_callbacks:
|
||||
self._packet_callbacks[packet] = []
|
||||
self._packet_callbacks[packet].append(fun)
|
||||
self.register(ClientEvent.CONNECTED, fun)
|
||||
return fun
|
||||
return wrapper
|
||||
|
||||
def on_disconnected(self) -> Callable:
|
||||
def wrapper(fun):
|
||||
self.register(ClientEvent.DISCONNECTED, fun)
|
||||
return fun
|
||||
return wrapper
|
||||
|
||||
def on_packet(self, packet:Type[Packet]) -> Callable:
|
||||
def wrapper(fun):
|
||||
self.register(packet, fun)
|
||||
return fun
|
||||
return wrapper
|
||||
|
||||
|
@ -136,39 +147,22 @@ class Client:
|
|||
if restart:
|
||||
await self.start()
|
||||
|
||||
def run(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
loop.run_until_complete(self.start())
|
||||
|
||||
async def idle():
|
||||
while True: # TODO don't busywait even if it doesn't matter much
|
||||
await asyncio.sleep(self.options["poll-timeout"])
|
||||
|
||||
try:
|
||||
loop.run_until_complete(idle())
|
||||
except KeyboardInterrupt:
|
||||
self._logger.info("Received SIGINT, stopping...")
|
||||
try:
|
||||
loop.run_until_complete(self.stop())
|
||||
except KeyboardInterrupt:
|
||||
self._logger.info("Received SIGINT, stopping for real")
|
||||
loop.run_until_complete(self.stop(wait_tasks=False))
|
||||
|
||||
async def start(self):
|
||||
if self.started:
|
||||
return
|
||||
self._processing = True
|
||||
self._worker = asyncio.get_event_loop().create_task(self._client_worker())
|
||||
self._logger.info("Minecraft client started")
|
||||
|
||||
async def stop(self, block=True, wait_tasks=True):
|
||||
async def stop(self, force:bool=False):
|
||||
self._processing = False
|
||||
if self.dispatcher.connected:
|
||||
await self.dispatcher.disconnect(block=block)
|
||||
if block:
|
||||
await self.dispatcher.disconnect(block=not force)
|
||||
if not force:
|
||||
await self._worker
|
||||
self._logger.info("Minecraft client stopped")
|
||||
if block and wait_tasks:
|
||||
await asyncio.gather(*list(self._callbacks.values()))
|
||||
if not force:
|
||||
await self.join_callbacks()
|
||||
|
||||
async def _client_worker(self):
|
||||
while self._processing:
|
||||
|
@ -188,9 +182,9 @@ class Client:
|
|||
self._logger.exception("Exception in Client connection")
|
||||
if self.dispatcher.connected:
|
||||
await self.dispatcher.disconnect()
|
||||
if not self.options["reconnect"]:
|
||||
if not self.options.reconnect:
|
||||
break
|
||||
await asyncio.sleep(self.options["rctime"])
|
||||
await asyncio.sleep(self.options.reconnect_delay)
|
||||
await self.stop(block=False)
|
||||
|
||||
async def _handshake(self) -> bool: # TODO make this fancier! poll for version and status first
|
||||
|
@ -255,20 +249,20 @@ class Client:
|
|||
|
||||
async def _play(self) -> bool:
|
||||
self.dispatcher.state = ConnectionState.PLAY
|
||||
self.run_callbacks(ClientEvent.CONNECTED)
|
||||
async for packet in self.dispatcher.packets():
|
||||
self._logger.debug("[ * ] Processing | %s", str(packet))
|
||||
if isinstance(packet, PacketSetCompression):
|
||||
self._logger.info("Compression updated")
|
||||
self.dispatcher.compression = packet.threshold
|
||||
elif isinstance(packet, PacketKeepAlive):
|
||||
if self.options["keep-alive"]:
|
||||
if self.options.keep_alive:
|
||||
keep_alive_packet = PacketKeepAliveResponse(340, keepAliveId=packet.keepAliveId)
|
||||
await self.dispatcher.write(keep_alive_packet)
|
||||
elif isinstance(packet, PacketKickDisconnect):
|
||||
self._logger.error("Kicked while in game")
|
||||
break
|
||||
for packet_type in (Packet, type(packet)): # check both callbacks for base class and instance class
|
||||
if packet_type in self._packet_callbacks:
|
||||
for cb in self._packet_callbacks[packet_type]:
|
||||
self._run_async(cb, packet)
|
||||
self.run_callbacks(Packet, packet)
|
||||
self.run_callbacks(type(packet), packet)
|
||||
self.run_callbacks(ClientEvent.DISCONNECTED)
|
||||
return False
|
||||
|
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
import logging
|
||||
import uuid
|
||||
|
||||
from dataclasses import dataclass
|
||||
from asyncio import Task, StreamReader, StreamWriter
|
||||
from asyncio.base_events import Server # just for typing
|
||||
from enum import Enum
|
||||
|
@ -9,6 +10,7 @@ from enum import Enum
|
|||
from typing import Dict, List, Callable, Coroutine, Type, Optional, Tuple, AsyncIterator
|
||||
|
||||
from .dispatcher import Dispatcher
|
||||
from .traits import CallbacksHolder, Runnable
|
||||
from .mc.packet import Packet
|
||||
from .mc.token import Token, AuthException
|
||||
from .mc.definitions import Dimension, Difficulty, Gamemode, ConnectionState
|
||||
|
@ -23,89 +25,80 @@ from .util import encryption
|
|||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class MinecraftServer:
|
||||
@dataclass
|
||||
class ServerOptions:
|
||||
online_mode : bool
|
||||
spawn_player : bool
|
||||
poll_interval : float
|
||||
|
||||
class ServerEvent(Enum):
|
||||
CLIENT_CONNECTED = 0
|
||||
CLIENT_DISCONNECTED = 1
|
||||
|
||||
class MinecraftServer(CallbacksHolder, Runnable):
|
||||
host:str
|
||||
port:int
|
||||
options:dict
|
||||
options:ServerOptions
|
||||
|
||||
_dispatcher_pool : List[Dispatcher]
|
||||
_processing : bool
|
||||
_server : Server
|
||||
_worker : Task
|
||||
_callbacks = Dict[str, Task]
|
||||
|
||||
_logger : logging.Logger
|
||||
|
||||
_disconnect_handlers : List[Callable]
|
||||
_connect_handlers : List[Callable]
|
||||
_packet_handlers : List[Callable]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host:str,
|
||||
port:int = 25565,
|
||||
options:dict = None,
|
||||
online_mode:bool = False,
|
||||
spawn_player:bool = True,
|
||||
poll_interval:float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
self.options = options or {
|
||||
"poll-timeout" : 1,
|
||||
"online-mode" : False,
|
||||
"spawn-player" : True,
|
||||
}
|
||||
self.options = ServerOptions(
|
||||
online_mode=online_mode,
|
||||
spawn_player=spawn_player,
|
||||
poll_interval=poll_interval,
|
||||
)
|
||||
|
||||
self._dispatcher_pool = []
|
||||
self._processing = False
|
||||
|
||||
self._logger = LOGGER.getChild(f"@({self.host}:{self.port})")
|
||||
|
||||
self._disconnect_handlers = []
|
||||
self._connect_handlers = []
|
||||
self._packet_handlers = []
|
||||
|
||||
@property
|
||||
def started(self) -> bool:
|
||||
return self._processing
|
||||
|
||||
def on_client_connect(self, *args):
|
||||
@property
|
||||
def connected(self) -> int:
|
||||
return len(self._dispatcher_pool)
|
||||
|
||||
def on_connect(self):
|
||||
def wrapper(fun):
|
||||
self._connect_handlers.append(fun)
|
||||
self.register(ServerEvent.CLIENT_CONNECTED, fun)
|
||||
return fun
|
||||
return wrapper
|
||||
|
||||
def on_client_disconnect(self, *args):
|
||||
def on_disconnect(self):
|
||||
def wrapper(fun):
|
||||
self._disconnect_handlers.append(fun)
|
||||
self.register(ServerEvent.CLIENT_DISCONNECTED, fun)
|
||||
return fun
|
||||
return wrapper
|
||||
|
||||
def on_client_packet(self, *args):
|
||||
def on_packet(self, packet:Type[Packet]):
|
||||
def wrapper(fun):
|
||||
self._packet_handlers.append(fun)
|
||||
self.register(packet, fun)
|
||||
return fun
|
||||
return wrapper
|
||||
|
||||
def run(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
loop.run_until_complete(self.start())
|
||||
|
||||
async def idle():
|
||||
while True: # TODO don't busywait even if it doesn't matter much
|
||||
await asyncio.sleep(self.options["poll-timeout"])
|
||||
|
||||
try:
|
||||
loop.run_until_complete(idle())
|
||||
except KeyboardInterrupt:
|
||||
self._logger.info("Received SIGINT, stopping...")
|
||||
try:
|
||||
loop.run_until_complete(self.stop())
|
||||
except KeyboardInterrupt:
|
||||
self._logger.info("Received SIGINT, stopping for real")
|
||||
loop.run_until_complete(self.stop(wait_tasks=False))
|
||||
|
||||
async def start(self):
|
||||
if self.started:
|
||||
return
|
||||
self._server = await asyncio.start_server(
|
||||
self._server_worker, self.host, self.port
|
||||
)
|
||||
|
@ -114,14 +107,15 @@ class MinecraftServer:
|
|||
await self._server.start_serving()
|
||||
self._logger.info("Minecraft server started")
|
||||
|
||||
async def stop(self, block=True, wait_tasks=True):
|
||||
async def stop(self, force:bool = False):
|
||||
self._processing = False
|
||||
self._server.close()
|
||||
await asyncio.gather(*[d.disconnect(block=block) for d in self._dispatcher_pool])
|
||||
if block:
|
||||
await self._server.wait_closed()
|
||||
# if block and wait_tasks: # TODO wait for client workers
|
||||
# await asyncio.gather(*list(self._callbacks.values()))
|
||||
await asyncio.gather(*[d.disconnect(block=not force) for d in self._dispatcher_pool])
|
||||
if not force:
|
||||
await asyncio.gather(
|
||||
self._server.wait_closed(),
|
||||
self.join_callbacks(),
|
||||
)
|
||||
|
||||
async def _disconnect_client(self, dispatcher):
|
||||
if dispatcher.state == ConnectionState.LOGIN:
|
||||
|
@ -178,7 +172,7 @@ class MinecraftServer:
|
|||
self._logger.info("Logging in player")
|
||||
async for packet in dispatcher.packets():
|
||||
if isinstance(packet, PacketLoginStart):
|
||||
if self.options["online-mode"]:
|
||||
if self.options.online_mode:
|
||||
# await dispatcher.write(
|
||||
# PacketEncryptionBegin(
|
||||
# dispatcher.proto,
|
||||
|
@ -207,7 +201,7 @@ class MinecraftServer:
|
|||
async def _play(self, dispatcher:Dispatcher) -> bool:
|
||||
self._logger.info("Player connected")
|
||||
|
||||
if self.options["spawn-player"]:
|
||||
if self.options.spawn_player:
|
||||
await dispatcher.write(
|
||||
PacketLogin(
|
||||
dispatcher.proto,
|
||||
|
@ -245,14 +239,14 @@ class MinecraftServer:
|
|||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(*[cb(dispatcher) for cb in self._connect_handlers])
|
||||
self.run_callbacks(ServerEvent.CLIENT_CONNECTED)
|
||||
|
||||
async for packet in dispatcher.packets():
|
||||
for cb in self._packet_handlers:
|
||||
asyncio.get_event_loop().create_task(cb(dispatcher, packet))
|
||||
pass # TODO handle play
|
||||
# TODO handle play
|
||||
self.run_callbacks(Packet, packet)
|
||||
self.run_callbacks(type(packet), packet)
|
||||
|
||||
await asyncio.gather(*[cb(dispatcher) for cb in self._disconnect_handlers])
|
||||
self.run_callbacks(ServerEvent.CLIENT_DISCONNECTED)
|
||||
|
||||
return False
|
||||
|
||||
|
|
2
aiocraft/traits/__init__.py
Normal file
2
aiocraft/traits/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
|||
from .callbacks import CallbacksHolder
|
||||
from .runnable import Runnable
|
38
aiocraft/traits/callbacks.py
Normal file
38
aiocraft/traits/callbacks.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import asyncio
|
||||
from uuid import uuid4
|
||||
|
||||
from typing import Dict, List, Any, Callable
|
||||
|
||||
class CallbacksHolder:
|
||||
|
||||
_callbacks : Dict[Any, List[Callable]]
|
||||
_tasks : Dict[str, asyncio.Task]
|
||||
|
||||
def __init__(self):
|
||||
self._callbacks = {}
|
||||
|
||||
def register(self, key:Any, callback:Callable):
|
||||
if key not in self._callbacks:
|
||||
self._callbacks[key] = []
|
||||
self._callbacks[key].append(callback)
|
||||
|
||||
def trigger(self, key:Any) -> List[Callable]:
|
||||
if key not in self._callbacks:
|
||||
return []
|
||||
return self._callbacks[key]
|
||||
|
||||
def run_callbacks(self, key:Any, *args) -> None:
|
||||
for cb in self.trigger(key):
|
||||
task_id = str(uuid4())
|
||||
|
||||
async def wrapper(*args):
|
||||
await cb(*args)
|
||||
self._tasks.pop(task_id)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
self._tasks[task_id] = loop.create_task(wrapper(*args))
|
||||
|
||||
async def join_callbacks(self):
|
||||
await asyncio.gather(*list(self._tasks.values()))
|
||||
self._tasks.clear()
|
||||
|
35
aiocraft/traits/runnable.py
Normal file
35
aiocraft/traits/runnable.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
import asyncio
|
||||
import logging
|
||||
|
||||
class Runnable:
|
||||
|
||||
async def start(self):
|
||||
raise NotImplementedError
|
||||
|
||||
async def stop(self, force:bool=False):
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
logging.info("Starting")
|
||||
|
||||
loop.run_until_complete(self.start())
|
||||
|
||||
async def idle():
|
||||
never = asyncio.Event()
|
||||
logging.info("Idling")
|
||||
await never.wait()
|
||||
|
||||
try:
|
||||
loop.run_until_complete(idle())
|
||||
except KeyboardInterrupt:
|
||||
logging.info("Received SIGINT, stopping...")
|
||||
try:
|
||||
loop.run_until_complete(self.stop())
|
||||
except KeyboardInterrupt:
|
||||
logging.info("Received SIGINT, stopping for real")
|
||||
loop.run_until_complete(self.stop(force=True))
|
||||
|
||||
logging.info("Done")
|
||||
|
Loading…
Reference in a new issue