implemented packet whitelist to hopefully improve performance
This commit is contained in:
parent
7621021b83
commit
93e7859304
3 changed files with 39 additions and 25 deletions
|
@ -27,10 +27,11 @@ LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClientOptions:
|
class ClientOptions:
|
||||||
reconnect : bool
|
reconnect : bool = True
|
||||||
reconnect_delay : float
|
reconnect_delay : float = 10.0
|
||||||
keep_alive : bool
|
keep_alive : bool = True
|
||||||
poll_interval : float
|
poll_interval : float = 1.0
|
||||||
|
use_packet_whitelist : bool = True
|
||||||
|
|
||||||
class ClientEvent(Enum):
|
class ClientEvent(Enum):
|
||||||
CONNECTED = 0
|
CONNECTED = 0
|
||||||
|
@ -61,22 +62,13 @@ class MinecraftClient(CallbacksHolder, Runnable):
|
||||||
password:Optional[str] = None,
|
password:Optional[str] = None,
|
||||||
token:Optional[Token] = None,
|
token:Optional[Token] = None,
|
||||||
online_mode:bool = True,
|
online_mode:bool = True,
|
||||||
reconnect:bool = True,
|
**kwargs
|
||||||
reconnect_delay:float = 10.0,
|
|
||||||
keep_alive:bool = True,
|
|
||||||
poll_interval:float = 1.0,
|
|
||||||
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
|
||||||
self.options = ClientOptions(
|
self.options = ClientOptions(**kwargs)
|
||||||
reconnect=reconnect,
|
|
||||||
reconnect_delay=reconnect_delay,
|
|
||||||
keep_alive=keep_alive,
|
|
||||||
poll_interval=poll_interval
|
|
||||||
)
|
|
||||||
|
|
||||||
self.token = token
|
self.token = token
|
||||||
self.username = username
|
self.username = username
|
||||||
|
@ -186,7 +178,13 @@ class MinecraftClient(CallbacksHolder, Runnable):
|
||||||
self._logger.error(str(e))
|
self._logger.error(str(e))
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
await self.dispatcher.connect(self.host, self.port)
|
packet_whitelist = self.callback_keys(filter=Packet) if self.options.use_packet_whitelist else set()
|
||||||
|
await self.dispatcher.connect(
|
||||||
|
self.host,
|
||||||
|
self.port,
|
||||||
|
queue_timeout=self.options.poll_interval,
|
||||||
|
packet_whitelist=packet_whitelist
|
||||||
|
)
|
||||||
await self._handshake()
|
await self._handshake()
|
||||||
if await self._login():
|
if await self._login():
|
||||||
await self._play()
|
await self._play()
|
||||||
|
|
|
@ -5,7 +5,7 @@ import zlib
|
||||||
import logging
|
import logging
|
||||||
from asyncio import StreamReader, StreamWriter, Queue, Task
|
from asyncio import StreamReader, StreamWriter, Queue, Task
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Dict, Optional, AsyncIterator, Type
|
from typing import List, Dict, Set, Optional, AsyncIterator, Type
|
||||||
|
|
||||||
from cryptography.hazmat.primitives.ciphers import CipherContext
|
from cryptography.hazmat.primitives.ciphers import CipherContext
|
||||||
|
|
||||||
|
@ -23,8 +23,6 @@ class InvalidState(Exception):
|
||||||
class ConnectionError(Exception):
|
class ConnectionError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
BROKEN_PACKETS = (77, ) # These packets are still not parseable due to missing data type
|
|
||||||
|
|
||||||
class Dispatcher:
|
class Dispatcher:
|
||||||
_is_server : bool # True when receiving packets from clients
|
_is_server : bool # True when receiving packets from clients
|
||||||
|
|
||||||
|
@ -41,6 +39,9 @@ class Dispatcher:
|
||||||
_incoming : Queue
|
_incoming : Queue
|
||||||
_outgoing : Queue
|
_outgoing : Queue
|
||||||
|
|
||||||
|
_packet_whitelist : Set[Packet]
|
||||||
|
_packet_id_whitelist : Set[int]
|
||||||
|
|
||||||
_host : str
|
_host : str
|
||||||
_port : int
|
_port : int
|
||||||
|
|
||||||
|
@ -89,16 +90,26 @@ class Dispatcher:
|
||||||
self.encryption = True
|
self.encryption = True
|
||||||
self._logger.info("Encryption enabled")
|
self._logger.info("Encryption enabled")
|
||||||
|
|
||||||
def _prepare(self, host:Optional[str] = None, port:Optional[int] = None, queue_timeout:int = 1, queue_size:int = 100):
|
def _prepare(self,
|
||||||
|
host:Optional[str] = None,
|
||||||
|
port:Optional[int] = None,
|
||||||
|
queue_timeout:int = 1,
|
||||||
|
queue_size:int = 100,
|
||||||
|
packet_whitelist : List[Packet] = None
|
||||||
|
):
|
||||||
self._host = host or self._host or "localhost"
|
self._host = host or self._host or "localhost"
|
||||||
self._port = port or self._port or 25565
|
self._port = port or self._port or 25565
|
||||||
self._logger = LOGGER.getChild(f"on({self._host}:{self._port})")
|
self._logger = LOGGER.getChild(f"on({self._host}:{self._port})")
|
||||||
|
self._packet_whitelist = packet_whitelist or set()
|
||||||
|
|
||||||
self.encryption = False
|
self.encryption = False
|
||||||
self.compression = None
|
self.compression = None
|
||||||
self.state = ConnectionState.HANDSHAKING
|
self.state = ConnectionState.HANDSHAKING
|
||||||
self.proto = 340 # TODO
|
self.proto = 340 # TODO
|
||||||
|
|
||||||
|
# This can only happen after we know the connection protocol
|
||||||
|
self._packet_id_whitelist = set((P(self.proto).id for P in packet_whitelist)) if packet_whitelist else set()
|
||||||
|
|
||||||
# Make new queues, do set a max size to sorta propagate back pressure
|
# Make new queues, do set a max size to sorta propagate back pressure
|
||||||
self._incoming = Queue(queue_size)
|
self._incoming = Queue(queue_size)
|
||||||
self._outgoing = Queue(queue_size)
|
self._outgoing = Queue(queue_size)
|
||||||
|
@ -112,12 +123,13 @@ class Dispatcher:
|
||||||
reader : Optional[StreamReader] = None,
|
reader : Optional[StreamReader] = None,
|
||||||
writer : Optional[StreamWriter] = None,
|
writer : Optional[StreamWriter] = None,
|
||||||
queue_timeout : int = 1,
|
queue_timeout : int = 1,
|
||||||
queue_size : int = 100
|
queue_size : int = 100,
|
||||||
|
packet_whitelist : Set[Packet] = None,
|
||||||
):
|
):
|
||||||
if self.connected:
|
if self.connected:
|
||||||
raise InvalidState("Dispatcher already connected")
|
raise InvalidState("Dispatcher already connected")
|
||||||
|
|
||||||
self._prepare(host, port, queue_timeout, queue_size)
|
self._prepare(host, port, queue_timeout, queue_size, packet_whitelist)
|
||||||
|
|
||||||
if reader and writer:
|
if reader and writer:
|
||||||
self._down, self._up = reader, writer
|
self._down, self._up = reader, writer
|
||||||
|
@ -216,8 +228,9 @@ class Dispatcher:
|
||||||
buffer = io.BytesIO(decompressed_data)
|
buffer = io.BytesIO(decompressed_data)
|
||||||
|
|
||||||
packet_id = VarInt.read(buffer)
|
packet_id = VarInt.read(buffer)
|
||||||
if packet_id in BROKEN_PACKETS:
|
if self._packet_id_whitelist and packet_id in self._packet_id_whitelist:
|
||||||
continue # cheap fix, still need to implement NBT, Slot and EntityMetadata...
|
self._logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)
|
||||||
|
continue # ignore this packet, we rarely need them all, should improve performance
|
||||||
cls = self._packet_type_from_registry(packet_id)
|
cls = self._packet_type_from_registry(packet_id)
|
||||||
packet = cls.deserialize(self.proto, buffer)
|
packet = cls.deserialize(self.proto, buffer)
|
||||||
self._logger.debug("[<--] Received | %s", repr(packet))
|
self._logger.debug("[<--] Received | %s", repr(packet))
|
||||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing import Dict, List, Any, Callable
|
from typing import Dict, List, Set, Any, Callable, Type
|
||||||
|
|
||||||
class CallbacksHolder:
|
class CallbacksHolder:
|
||||||
|
|
||||||
|
@ -16,6 +16,9 @@ class CallbacksHolder:
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
self._tasks = {}
|
self._tasks = {}
|
||||||
|
|
||||||
|
def callback_keys(self, filter:Type = None) -> Set[Any]:
|
||||||
|
return set(x for x in self._callbacks.keys() if not filter or isinstance(x, filter))
|
||||||
|
|
||||||
def register(self, key:Any, callback:Callable):
|
def register(self, key:Any, callback:Callable):
|
||||||
if key not in self._callbacks:
|
if key not in self._callbacks:
|
||||||
self._callbacks[key] = []
|
self._callbacks[key] = []
|
||||||
|
|
Loading…
Reference in a new issue