From ccadefc2f5a89bc6bb778c96c24c662ed111e341 Mon Sep 17 00:00:00 2001 From: alemidev Date: Fri, 14 Jan 2022 20:07:16 +0100 Subject: [PATCH] no more optional object Context, made a dedicated object --- aiocraft/dispatcher.py | 6 +-- aiocraft/mc/types.py | 114 +++++++++++++++++++++++------------------ 2 files changed, 68 insertions(+), 52 deletions(-) diff --git a/aiocraft/dispatcher.py b/aiocraft/dispatcher.py index beef21a..c235e9d 100644 --- a/aiocraft/dispatcher.py +++ b/aiocraft/dispatcher.py @@ -10,7 +10,7 @@ from typing import List, Dict, Set, Optional, AsyncIterator, Type from cryptography.hazmat.primitives.ciphers import CipherContext from .mc import proto as minecraft_protocol -from .mc.types import VarInt +from .mc.types import VarInt, Context from .mc.packet import Packet from .mc.definitions import ConnectionState from .util import encryption @@ -222,7 +222,7 @@ class Dispatcher: buffer = io.BytesIO(data) if self.compression is not None: - decompressed_size = VarInt.read(buffer) + decompressed_size = VarInt.read(buffer, Context()) if decompressed_size > 0: decompressor = zlib.decompressobj() decompressed_data = decompressor.decompress(buffer.read()) @@ -230,7 +230,7 @@ class Dispatcher: raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}") buffer = io.BytesIO(decompressed_data) - packet_id = VarInt.read(buffer) + packet_id = VarInt.read(buffer, Context()) if self.state == ConnectionState.PLAY and self._packet_id_whitelist \ and packet_id not in self._packet_id_whitelist: self._logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id) diff --git a/aiocraft/mc/types.py b/aiocraft/mc/types.py index 4170ab3..d104e07 100644 --- a/aiocraft/mc/types.py +++ b/aiocraft/mc/types.py @@ -1,6 +1,7 @@ import io import struct import asyncio +import json import uuid import logging @@ -10,27 +11,42 @@ from typing import List, Tuple, Dict, Any, Union, Optional, Callable, Type as Cl from .definitions import Item +class Context(object): + def __init__(self, **kwargs): + for k, v in kwargs: + setattr(self, k, v) + + def __getattr__(self, name) -> Any: + return None # return None rather than raising an exc + + def __str__(self) -> str: + return json.dumps(vars(self), indent=2, default=str, sort_keys=True) + + def __repr__(self) -> str: + values = ( f"{k}={repr(v)}" for k,v in vars(self).items() ) + return f"Context({', '.join(values)})" + class Type(object): pytype : Union[type, Callable] = lambda x : x - def write(self, data:Any, buffer:io.BytesIO, ctx:object=None) -> None: + def write(self, data:Any, buffer:io.BytesIO, ctx:Context) -> None: """Write data to a packet buffer""" raise NotImplementedError - def read(self, buffer:io.BytesIO, ctx:object=None) -> Any: + def read(self, buffer:io.BytesIO, ctx:Context) -> Any: """Read data off a packet buffer""" raise NotImplementedError - def check(self, ctx:object) -> bool: + def check(self, ctx:Context) -> bool: """Check if this type exists in this context""" return True class VoidType(Type): - def write(self, v:None, buffer:io.BytesIO, ctx:object=None): + def write(self, v:None, buffer:io.BytesIO, ctx:Context): pass - def read(self, buffer:io.BytesIO, ctx:object=None) -> None: + def read(self, buffer:io.BytesIO, ctx:Context) -> None: return None Void = VoidType() @@ -38,11 +54,11 @@ Void = VoidType() class UnimplementedDataType(Type): pytype : type = bytes - def write(self, data:bytes, buffer:io.BytesIO, ctx:object=None): + def write(self, data:bytes, buffer:io.BytesIO, ctx:Context): if data: buffer.write(data) - def read(self, buffer:io.BytesIO, ctx:object=None) -> bytes: + def read(self, buffer:io.BytesIO, ctx:Context) -> bytes: return buffer.read() TrailingData = UnimplementedDataType() @@ -56,10 +72,10 @@ class PrimitiveType(Type): self.fmt = fmt self.size = size - def write(self, data:Any, buffer:io.BytesIO, ctx:object=None): + def write(self, data:Any, buffer:io.BytesIO, ctx:Context): buffer.write(struct.pack(self.fmt, data)) - def read(self, buffer:io.BytesIO, ctx:object=None) -> Any: + def read(self, buffer:io.BytesIO, ctx:Context) -> Any: return struct.unpack(self.fmt, buffer.read(self.size))[0] Boolean = PrimitiveType(bool, ">?", 1) @@ -87,14 +103,14 @@ def nbt_to_py(item:pynbt.BaseTag) -> Any: class NBTType(Type): pytype : type = dict - def write(self, data:Optional[dict], buffer:io.BytesIO, ctx:object=None): + def write(self, data:Optional[dict], buffer:io.BytesIO, ctx:Context): if data is None: buffer.write(b'\x00') else: pynbt.NBTFile(value=data).save(buffer) - def read(self, buffer:io.BytesIO, ctx:object=None) -> Optional[dict]: - head = Byte.read(buffer) + def read(self, buffer:io.BytesIO, ctx:Context) -> Optional[dict]: + head = Byte.read(buffer, ctx) if head == 0x0: return None buffer.seek(-1,1) # go back 1 byte @@ -109,7 +125,7 @@ class VarLenPrimitive(Type): def __init__(self, max_bytes:int): self.max_bytes = max_bytes - def write(self, data:int, buffer:io.BytesIO, ctx:object=None): + def write(self, data:int, buffer:io.BytesIO, ctx:Context): count = 0 # TODO raise exceptions while count < self.max_bytes: byte = data & 0b01111111 @@ -121,7 +137,7 @@ class VarLenPrimitive(Type): if not data: break - def read(self, buffer:io.BytesIO, ctx:object=None) -> int: + def read(self, buffer:io.BytesIO, ctx:Context) -> int: numRead = 0 result = 0 while True: @@ -141,13 +157,13 @@ class VarLenPrimitive(Type): def serialize(self, data:int) -> bytes: buf = io.BytesIO() - self.write(data, buf) + self.write(data, buf, Context()) buf.seek(0) return buf.read() def deserialize(self, data:bytes) -> int: buf = io.BytesIO(data) - return self.read(buf) + return self.read(buf, Context()) VarInt = VarLenPrimitive(5) VarLong = VarLenPrimitive(10) @@ -155,12 +171,12 @@ VarLong = VarLenPrimitive(10) class StringType(Type): pytype : type = str - def write(self, data:str, buffer:io.BytesIO, ctx:object=None): + def write(self, data:str, buffer:io.BytesIO, ctx:Context): encoded = data.encode('utf-8') VarInt.write(len(encoded), buffer, ctx=ctx) buffer.write(encoded) - def read(self, buffer:io.BytesIO, ctx:object=None) -> str: + def read(self, buffer:io.BytesIO, ctx:Context) -> str: length = VarInt.read(buffer, ctx=ctx) return buffer.read(length).decode('utf-8') @@ -175,11 +191,11 @@ class BufferType(Type): def __init__(self, count:Type = VarInt): self.count = count - def write(self, data:bytes, buffer:io.BytesIO, ctx:object=None): + def write(self, data:bytes, buffer:io.BytesIO, ctx:Context): self.count.write(len(data), buffer, ctx=ctx) buffer.write(data) - def read(self, buffer:io.BytesIO, ctx:object=None) -> bytes: + def read(self, buffer:io.BytesIO, ctx:Context) -> bytes: length = self.count.read(buffer, ctx=ctx) return buffer.read(length) @@ -199,14 +215,14 @@ class PositionType(Type): # TODO THIS IS FOR 1.12.2!!! Make a generic version-less? - def write(self, data:tuple, buffer:io.BytesIO, ctx:object=None): + def write(self, data:tuple, buffer:io.BytesIO, ctx:Context): packed = ((0x3FFFFFF & data[0]) << 38) \ | ((0xFFF & data[1]) << 26) \ | (0x3FFFFFF & data[2]) UnsignedLong.write(packed, buffer, ctx=ctx) - def read(self, buffer:io.BytesIO, ctx:object=None) -> tuple: - packed = UnsignedLong.read(buffer) + def read(self, buffer:io.BytesIO, ctx:Context) -> tuple: + packed = UnsignedLong.read(buffer, ctx) x = twos_comp(packed >> 38, 26) y = (packed >> 26) & 0xFFF z = twos_comp(packed & 0x3FFFFFF, 26) @@ -218,10 +234,10 @@ class UUIDType(Type): pytype : type = uuid.UUID MAX_SIZE : int = 16 - def write(self, data:uuid.UUID, buffer:io.BytesIO, ctx:object=None): + def write(self, data:uuid.UUID, buffer:io.BytesIO, ctx:Context): buffer.write(int(data).to_bytes(self.MAX_SIZE, 'big')) - def read(self, buffer:io.BytesIO, ctx:object=None) -> uuid.UUID: + def read(self, buffer:io.BytesIO, ctx:Context) -> uuid.UUID: return uuid.UUID(int=int.from_bytes(buffer.read(self.MAX_SIZE), 'big')) UUID = UUIDType() @@ -235,7 +251,7 @@ class ArrayType(Type): self.content = content self.counter = counter - def write(self, data:List[Any], buffer:io.BytesIO, ctx:object=None): + def write(self, data:List[Any], buffer:io.BytesIO, ctx:Context): if isinstance(self.counter, Type): self.counter.write(len(data), buffer, ctx=ctx) for i, el in enumerate(data): @@ -243,7 +259,7 @@ class ArrayType(Type): if isinstance(self.counter, int) and i >= self.counter: break # jank but should do - def read(self, buffer:io.BytesIO, ctx:object=None) -> List[Any]: + def read(self, buffer:io.BytesIO, ctx:Context) -> List[Any]: length = self.counter if isinstance(self.counter, int) else self.counter.read(buffer, ctx=ctx) out = [] for _ in range(length): @@ -257,12 +273,12 @@ class OptionalType(Type): self.t = t self.pytype = t.pytype - def write(self, data:Optional[Any], buffer:io.BytesIO, ctx:object=None): + def write(self, data:Optional[Any], buffer:io.BytesIO, ctx:Context): Boolean.write(bool(data), buffer, ctx=ctx) if data: self.t.write(data, buffer, ctx=ctx) - def read(self, buffer:io.BytesIO, ctx:object=None) -> Optional[Any]: + def read(self, buffer:io.BytesIO, ctx:Context) -> Optional[Any]: if Boolean.read(buffer, ctx=ctx): return self.t.read(buffer, ctx=ctx) return None @@ -276,14 +292,14 @@ class SwitchType(Type): self.mappings = mappings self.default = default - def write(self, data:Any, buffer:io.BytesIO, ctx:object=None): + def write(self, data:Any, buffer:io.BytesIO, ctx:Context): watched = getattr(ctx, self.field, None) if watched is not None and watched in self.mappings: return self.mappings[watched].write(data, buffer, ctx=ctx) elif self.default: return self.default.write(data, buffer, ctx=ctx) - def read(self, buffer:io.BytesIO, ctx:object=None) -> Optional[Any]: + def read(self, buffer:io.BytesIO, ctx:Context) -> Optional[Any]: watched = getattr(ctx, self.field, None) if watched is not None and watched in self.mappings: return self.mappings[watched].read(buffer, ctx=ctx) @@ -298,44 +314,44 @@ class StructType(Type): # TODO sub objects def __init__(self, *args:Tuple[str, Type]): self.fields = args - def write(self, data:Dict[str, Any], buffer:io.BytesIO, ctx:object=None): + def write(self, data:Dict[str, Any], buffer:io.BytesIO, ctx:Context): for k, t in self.fields: t.write(data[k], buffer, ctx=ctx) - def read(self, buffer:io.BytesIO, ctx:object=None) -> Dict[str, Any]: + def read(self, buffer:io.BytesIO, ctx:Context) -> Dict[str, Any]: return { k : t.read(buffer, ctx=ctx) for k, t in self.fields } class SlotType(Type): pytype : type = Item - def write(self, data:Item, buffer:io.BytesIO, ctx:object=None): + def write(self, data:Item, buffer:io.BytesIO, ctx:Context): new_way = ctx._proto > 340 check_type = Boolean if new_way else Short if data: - check_type.write(True if new_way else data.id, buffer) + check_type.write(True if new_way else data.id, buffer, ctx) if new_way: - VarInt.write(data.id, buffer) - Byte.write(data.count, buffer) + VarInt.write(data.id, buffer, ctx) + Byte.write(data.count, buffer, ctx) if not new_way: - Short.write(data.damage, buffer) - NBTTag.write(data.nbt, buffer) # TODO handle None maybe? + Short.write(data.damage, buffer, ctx) + NBTTag.write(data.nbt, buffer, ctx) # TODO handle None maybe? else: - check_type.write(False if new_way else -1, buffer) + check_type.write(False if new_way else -1, buffer, ctx) - def read(self, buffer:io.BytesIO, ctx:object=None) -> Any: + def read(self, buffer:io.BytesIO, ctx:Context) -> Any: slot : Dict[Any, Any] = {} new_way = ctx._proto > 340 check_type = Boolean if new_way else Short - val = check_type.read(buffer) + val = check_type.read(buffer, ctx) if (new_way and val) or val != -1: if new_way: - slot["id"] = VarInt.read(buffer) + slot["id"] = VarInt.read(buffer, ctx) else: slot["id"] = val - slot["count"] = Byte.read(buffer) + slot["count"] = Byte.read(buffer, ctx) if not new_way: - slot["damage"] = Short.read(buffer) - slot["nbt"] = NBTTag.read(buffer) + slot["damage"] = Short.read(buffer, ctx) + slot["nbt"] = NBTTag.read(buffer, ctx) return Item(**slot) Slot = SlotType() @@ -383,11 +399,11 @@ _ENTITY_METADATA_TYPES_NEW = { class EntityMetadataType(Type): pytype : type = dict - def write(self, data:Dict[int, Any], buffer:io.BytesIO, ctx:object=None): + def write(self, data:Dict[int, Any], buffer:io.BytesIO, ctx:Context): logging.error("Sending entity metadata isn't implemented yet") # TODO - buffer.write(0xFF) + buffer.write(b'\xFF') - def read(self, buffer:io.BytesIO, ctx:object=None) -> Dict[int, Any]: + def read(self, buffer:io.BytesIO, ctx:Context) -> Dict[int, Any]: types_map = _ENTITY_METADATA_TYPES_NEW if ctx._proto > 340 else _ENTITY_METADATA_TYPES out : Dict[int, Any] = {} while True: