no more optional object Context, made a dedicated object

This commit is contained in:
əlemi 2022-01-14 20:07:16 +01:00
parent a3f03756b6
commit ccadefc2f5
2 changed files with 68 additions and 52 deletions

View file

@ -10,7 +10,7 @@ from typing import List, Dict, Set, Optional, AsyncIterator, Type
from cryptography.hazmat.primitives.ciphers import CipherContext from cryptography.hazmat.primitives.ciphers import CipherContext
from .mc import proto as minecraft_protocol from .mc import proto as minecraft_protocol
from .mc.types import VarInt from .mc.types import VarInt, Context
from .mc.packet import Packet from .mc.packet import Packet
from .mc.definitions import ConnectionState from .mc.definitions import ConnectionState
from .util import encryption from .util import encryption
@ -222,7 +222,7 @@ class Dispatcher:
buffer = io.BytesIO(data) buffer = io.BytesIO(data)
if self.compression is not None: if self.compression is not None:
decompressed_size = VarInt.read(buffer) decompressed_size = VarInt.read(buffer, Context())
if decompressed_size > 0: if decompressed_size > 0:
decompressor = zlib.decompressobj() decompressor = zlib.decompressobj()
decompressed_data = decompressor.decompress(buffer.read()) decompressed_data = decompressor.decompress(buffer.read())
@ -230,7 +230,7 @@ class Dispatcher:
raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}") raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
buffer = io.BytesIO(decompressed_data) buffer = io.BytesIO(decompressed_data)
packet_id = VarInt.read(buffer) packet_id = VarInt.read(buffer, Context())
if self.state == ConnectionState.PLAY and self._packet_id_whitelist \ if self.state == ConnectionState.PLAY and self._packet_id_whitelist \
and packet_id not in self._packet_id_whitelist: and packet_id not in self._packet_id_whitelist:
self._logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id) self._logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)

View file

@ -1,6 +1,7 @@
import io import io
import struct import struct
import asyncio import asyncio
import json
import uuid import uuid
import logging import logging
@ -10,27 +11,42 @@ from typing import List, Tuple, Dict, Any, Union, Optional, Callable, Type as Cl
from .definitions import Item from .definitions import Item
class Context(object):
def __init__(self, **kwargs):
for k, v in kwargs:
setattr(self, k, v)
def __getattr__(self, name) -> Any:
return None # return None rather than raising an exc
def __str__(self) -> str:
return json.dumps(vars(self), indent=2, default=str, sort_keys=True)
def __repr__(self) -> str:
values = ( f"{k}={repr(v)}" for k,v in vars(self).items() )
return f"Context({', '.join(values)})"
class Type(object): class Type(object):
pytype : Union[type, Callable] = lambda x : x pytype : Union[type, Callable] = lambda x : x
def write(self, data:Any, buffer:io.BytesIO, ctx:object=None) -> None: def write(self, data:Any, buffer:io.BytesIO, ctx:Context) -> None:
"""Write data to a packet buffer""" """Write data to a packet buffer"""
raise NotImplementedError raise NotImplementedError
def read(self, buffer:io.BytesIO, ctx:object=None) -> Any: def read(self, buffer:io.BytesIO, ctx:Context) -> Any:
"""Read data off a packet buffer""" """Read data off a packet buffer"""
raise NotImplementedError raise NotImplementedError
def check(self, ctx:object) -> bool: def check(self, ctx:Context) -> bool:
"""Check if this type exists in this context""" """Check if this type exists in this context"""
return True return True
class VoidType(Type): class VoidType(Type):
def write(self, v:None, buffer:io.BytesIO, ctx:object=None): def write(self, v:None, buffer:io.BytesIO, ctx:Context):
pass pass
def read(self, buffer:io.BytesIO, ctx:object=None) -> None: def read(self, buffer:io.BytesIO, ctx:Context) -> None:
return None return None
Void = VoidType() Void = VoidType()
@ -38,11 +54,11 @@ Void = VoidType()
class UnimplementedDataType(Type): class UnimplementedDataType(Type):
pytype : type = bytes pytype : type = bytes
def write(self, data:bytes, buffer:io.BytesIO, ctx:object=None): def write(self, data:bytes, buffer:io.BytesIO, ctx:Context):
if data: if data:
buffer.write(data) buffer.write(data)
def read(self, buffer:io.BytesIO, ctx:object=None) -> bytes: def read(self, buffer:io.BytesIO, ctx:Context) -> bytes:
return buffer.read() return buffer.read()
TrailingData = UnimplementedDataType() TrailingData = UnimplementedDataType()
@ -56,10 +72,10 @@ class PrimitiveType(Type):
self.fmt = fmt self.fmt = fmt
self.size = size self.size = size
def write(self, data:Any, buffer:io.BytesIO, ctx:object=None): def write(self, data:Any, buffer:io.BytesIO, ctx:Context):
buffer.write(struct.pack(self.fmt, data)) buffer.write(struct.pack(self.fmt, data))
def read(self, buffer:io.BytesIO, ctx:object=None) -> Any: def read(self, buffer:io.BytesIO, ctx:Context) -> Any:
return struct.unpack(self.fmt, buffer.read(self.size))[0] return struct.unpack(self.fmt, buffer.read(self.size))[0]
Boolean = PrimitiveType(bool, ">?", 1) Boolean = PrimitiveType(bool, ">?", 1)
@ -87,14 +103,14 @@ def nbt_to_py(item:pynbt.BaseTag) -> Any:
class NBTType(Type): class NBTType(Type):
pytype : type = dict pytype : type = dict
def write(self, data:Optional[dict], buffer:io.BytesIO, ctx:object=None): def write(self, data:Optional[dict], buffer:io.BytesIO, ctx:Context):
if data is None: if data is None:
buffer.write(b'\x00') buffer.write(b'\x00')
else: else:
pynbt.NBTFile(value=data).save(buffer) pynbt.NBTFile(value=data).save(buffer)
def read(self, buffer:io.BytesIO, ctx:object=None) -> Optional[dict]: def read(self, buffer:io.BytesIO, ctx:Context) -> Optional[dict]:
head = Byte.read(buffer) head = Byte.read(buffer, ctx)
if head == 0x0: if head == 0x0:
return None return None
buffer.seek(-1,1) # go back 1 byte buffer.seek(-1,1) # go back 1 byte
@ -109,7 +125,7 @@ class VarLenPrimitive(Type):
def __init__(self, max_bytes:int): def __init__(self, max_bytes:int):
self.max_bytes = max_bytes self.max_bytes = max_bytes
def write(self, data:int, buffer:io.BytesIO, ctx:object=None): def write(self, data:int, buffer:io.BytesIO, ctx:Context):
count = 0 # TODO raise exceptions count = 0 # TODO raise exceptions
while count < self.max_bytes: while count < self.max_bytes:
byte = data & 0b01111111 byte = data & 0b01111111
@ -121,7 +137,7 @@ class VarLenPrimitive(Type):
if not data: if not data:
break break
def read(self, buffer:io.BytesIO, ctx:object=None) -> int: def read(self, buffer:io.BytesIO, ctx:Context) -> int:
numRead = 0 numRead = 0
result = 0 result = 0
while True: while True:
@ -141,13 +157,13 @@ class VarLenPrimitive(Type):
def serialize(self, data:int) -> bytes: def serialize(self, data:int) -> bytes:
buf = io.BytesIO() buf = io.BytesIO()
self.write(data, buf) self.write(data, buf, Context())
buf.seek(0) buf.seek(0)
return buf.read() return buf.read()
def deserialize(self, data:bytes) -> int: def deserialize(self, data:bytes) -> int:
buf = io.BytesIO(data) buf = io.BytesIO(data)
return self.read(buf) return self.read(buf, Context())
VarInt = VarLenPrimitive(5) VarInt = VarLenPrimitive(5)
VarLong = VarLenPrimitive(10) VarLong = VarLenPrimitive(10)
@ -155,12 +171,12 @@ VarLong = VarLenPrimitive(10)
class StringType(Type): class StringType(Type):
pytype : type = str pytype : type = str
def write(self, data:str, buffer:io.BytesIO, ctx:object=None): def write(self, data:str, buffer:io.BytesIO, ctx:Context):
encoded = data.encode('utf-8') encoded = data.encode('utf-8')
VarInt.write(len(encoded), buffer, ctx=ctx) VarInt.write(len(encoded), buffer, ctx=ctx)
buffer.write(encoded) buffer.write(encoded)
def read(self, buffer:io.BytesIO, ctx:object=None) -> str: def read(self, buffer:io.BytesIO, ctx:Context) -> str:
length = VarInt.read(buffer, ctx=ctx) length = VarInt.read(buffer, ctx=ctx)
return buffer.read(length).decode('utf-8') return buffer.read(length).decode('utf-8')
@ -175,11 +191,11 @@ class BufferType(Type):
def __init__(self, count:Type = VarInt): def __init__(self, count:Type = VarInt):
self.count = count self.count = count
def write(self, data:bytes, buffer:io.BytesIO, ctx:object=None): def write(self, data:bytes, buffer:io.BytesIO, ctx:Context):
self.count.write(len(data), buffer, ctx=ctx) self.count.write(len(data), buffer, ctx=ctx)
buffer.write(data) buffer.write(data)
def read(self, buffer:io.BytesIO, ctx:object=None) -> bytes: def read(self, buffer:io.BytesIO, ctx:Context) -> bytes:
length = self.count.read(buffer, ctx=ctx) length = self.count.read(buffer, ctx=ctx)
return buffer.read(length) return buffer.read(length)
@ -199,14 +215,14 @@ class PositionType(Type):
# TODO THIS IS FOR 1.12.2!!! Make a generic version-less? # TODO THIS IS FOR 1.12.2!!! Make a generic version-less?
def write(self, data:tuple, buffer:io.BytesIO, ctx:object=None): def write(self, data:tuple, buffer:io.BytesIO, ctx:Context):
packed = ((0x3FFFFFF & data[0]) << 38) \ packed = ((0x3FFFFFF & data[0]) << 38) \
| ((0xFFF & data[1]) << 26) \ | ((0xFFF & data[1]) << 26) \
| (0x3FFFFFF & data[2]) | (0x3FFFFFF & data[2])
UnsignedLong.write(packed, buffer, ctx=ctx) UnsignedLong.write(packed, buffer, ctx=ctx)
def read(self, buffer:io.BytesIO, ctx:object=None) -> tuple: def read(self, buffer:io.BytesIO, ctx:Context) -> tuple:
packed = UnsignedLong.read(buffer) packed = UnsignedLong.read(buffer, ctx)
x = twos_comp(packed >> 38, 26) x = twos_comp(packed >> 38, 26)
y = (packed >> 26) & 0xFFF y = (packed >> 26) & 0xFFF
z = twos_comp(packed & 0x3FFFFFF, 26) z = twos_comp(packed & 0x3FFFFFF, 26)
@ -218,10 +234,10 @@ class UUIDType(Type):
pytype : type = uuid.UUID pytype : type = uuid.UUID
MAX_SIZE : int = 16 MAX_SIZE : int = 16
def write(self, data:uuid.UUID, buffer:io.BytesIO, ctx:object=None): def write(self, data:uuid.UUID, buffer:io.BytesIO, ctx:Context):
buffer.write(int(data).to_bytes(self.MAX_SIZE, 'big')) buffer.write(int(data).to_bytes(self.MAX_SIZE, 'big'))
def read(self, buffer:io.BytesIO, ctx:object=None) -> uuid.UUID: def read(self, buffer:io.BytesIO, ctx:Context) -> uuid.UUID:
return uuid.UUID(int=int.from_bytes(buffer.read(self.MAX_SIZE), 'big')) return uuid.UUID(int=int.from_bytes(buffer.read(self.MAX_SIZE), 'big'))
UUID = UUIDType() UUID = UUIDType()
@ -235,7 +251,7 @@ class ArrayType(Type):
self.content = content self.content = content
self.counter = counter self.counter = counter
def write(self, data:List[Any], buffer:io.BytesIO, ctx:object=None): def write(self, data:List[Any], buffer:io.BytesIO, ctx:Context):
if isinstance(self.counter, Type): if isinstance(self.counter, Type):
self.counter.write(len(data), buffer, ctx=ctx) self.counter.write(len(data), buffer, ctx=ctx)
for i, el in enumerate(data): for i, el in enumerate(data):
@ -243,7 +259,7 @@ class ArrayType(Type):
if isinstance(self.counter, int) and i >= self.counter: if isinstance(self.counter, int) and i >= self.counter:
break # jank but should do break # jank but should do
def read(self, buffer:io.BytesIO, ctx:object=None) -> List[Any]: def read(self, buffer:io.BytesIO, ctx:Context) -> List[Any]:
length = self.counter if isinstance(self.counter, int) else self.counter.read(buffer, ctx=ctx) length = self.counter if isinstance(self.counter, int) else self.counter.read(buffer, ctx=ctx)
out = [] out = []
for _ in range(length): for _ in range(length):
@ -257,12 +273,12 @@ class OptionalType(Type):
self.t = t self.t = t
self.pytype = t.pytype self.pytype = t.pytype
def write(self, data:Optional[Any], buffer:io.BytesIO, ctx:object=None): def write(self, data:Optional[Any], buffer:io.BytesIO, ctx:Context):
Boolean.write(bool(data), buffer, ctx=ctx) Boolean.write(bool(data), buffer, ctx=ctx)
if data: if data:
self.t.write(data, buffer, ctx=ctx) self.t.write(data, buffer, ctx=ctx)
def read(self, buffer:io.BytesIO, ctx:object=None) -> Optional[Any]: def read(self, buffer:io.BytesIO, ctx:Context) -> Optional[Any]:
if Boolean.read(buffer, ctx=ctx): if Boolean.read(buffer, ctx=ctx):
return self.t.read(buffer, ctx=ctx) return self.t.read(buffer, ctx=ctx)
return None return None
@ -276,14 +292,14 @@ class SwitchType(Type):
self.mappings = mappings self.mappings = mappings
self.default = default self.default = default
def write(self, data:Any, buffer:io.BytesIO, ctx:object=None): def write(self, data:Any, buffer:io.BytesIO, ctx:Context):
watched = getattr(ctx, self.field, None) watched = getattr(ctx, self.field, None)
if watched is not None and watched in self.mappings: if watched is not None and watched in self.mappings:
return self.mappings[watched].write(data, buffer, ctx=ctx) return self.mappings[watched].write(data, buffer, ctx=ctx)
elif self.default: elif self.default:
return self.default.write(data, buffer, ctx=ctx) return self.default.write(data, buffer, ctx=ctx)
def read(self, buffer:io.BytesIO, ctx:object=None) -> Optional[Any]: def read(self, buffer:io.BytesIO, ctx:Context) -> Optional[Any]:
watched = getattr(ctx, self.field, None) watched = getattr(ctx, self.field, None)
if watched is not None and watched in self.mappings: if watched is not None and watched in self.mappings:
return self.mappings[watched].read(buffer, ctx=ctx) return self.mappings[watched].read(buffer, ctx=ctx)
@ -298,44 +314,44 @@ class StructType(Type): # TODO sub objects
def __init__(self, *args:Tuple[str, Type]): def __init__(self, *args:Tuple[str, Type]):
self.fields = args self.fields = args
def write(self, data:Dict[str, Any], buffer:io.BytesIO, ctx:object=None): def write(self, data:Dict[str, Any], buffer:io.BytesIO, ctx:Context):
for k, t in self.fields: for k, t in self.fields:
t.write(data[k], buffer, ctx=ctx) t.write(data[k], buffer, ctx=ctx)
def read(self, buffer:io.BytesIO, ctx:object=None) -> Dict[str, Any]: def read(self, buffer:io.BytesIO, ctx:Context) -> Dict[str, Any]:
return { k : t.read(buffer, ctx=ctx) for k, t in self.fields } return { k : t.read(buffer, ctx=ctx) for k, t in self.fields }
class SlotType(Type): class SlotType(Type):
pytype : type = Item pytype : type = Item
def write(self, data:Item, buffer:io.BytesIO, ctx:object=None): def write(self, data:Item, buffer:io.BytesIO, ctx:Context):
new_way = ctx._proto > 340 new_way = ctx._proto > 340
check_type = Boolean if new_way else Short check_type = Boolean if new_way else Short
if data: if data:
check_type.write(True if new_way else data.id, buffer) check_type.write(True if new_way else data.id, buffer, ctx)
if new_way: if new_way:
VarInt.write(data.id, buffer) VarInt.write(data.id, buffer, ctx)
Byte.write(data.count, buffer) Byte.write(data.count, buffer, ctx)
if not new_way: if not new_way:
Short.write(data.damage, buffer) Short.write(data.damage, buffer, ctx)
NBTTag.write(data.nbt, buffer) # TODO handle None maybe? NBTTag.write(data.nbt, buffer, ctx) # TODO handle None maybe?
else: else:
check_type.write(False if new_way else -1, buffer) check_type.write(False if new_way else -1, buffer, ctx)
def read(self, buffer:io.BytesIO, ctx:object=None) -> Any: def read(self, buffer:io.BytesIO, ctx:Context) -> Any:
slot : Dict[Any, Any] = {} slot : Dict[Any, Any] = {}
new_way = ctx._proto > 340 new_way = ctx._proto > 340
check_type = Boolean if new_way else Short check_type = Boolean if new_way else Short
val = check_type.read(buffer) val = check_type.read(buffer, ctx)
if (new_way and val) or val != -1: if (new_way and val) or val != -1:
if new_way: if new_way:
slot["id"] = VarInt.read(buffer) slot["id"] = VarInt.read(buffer, ctx)
else: else:
slot["id"] = val slot["id"] = val
slot["count"] = Byte.read(buffer) slot["count"] = Byte.read(buffer, ctx)
if not new_way: if not new_way:
slot["damage"] = Short.read(buffer) slot["damage"] = Short.read(buffer, ctx)
slot["nbt"] = NBTTag.read(buffer) slot["nbt"] = NBTTag.read(buffer, ctx)
return Item(**slot) return Item(**slot)
Slot = SlotType() Slot = SlotType()
@ -383,11 +399,11 @@ _ENTITY_METADATA_TYPES_NEW = {
class EntityMetadataType(Type): class EntityMetadataType(Type):
pytype : type = dict pytype : type = dict
def write(self, data:Dict[int, Any], buffer:io.BytesIO, ctx:object=None): def write(self, data:Dict[int, Any], buffer:io.BytesIO, ctx:Context):
logging.error("Sending entity metadata isn't implemented yet") # TODO logging.error("Sending entity metadata isn't implemented yet") # TODO
buffer.write(0xFF) buffer.write(b'\xFF')
def read(self, buffer:io.BytesIO, ctx:object=None) -> Dict[int, Any]: def read(self, buffer:io.BytesIO, ctx:Context) -> Dict[int, Any]:
types_map = _ENTITY_METADATA_TYPES_NEW if ctx._proto > 340 else _ENTITY_METADATA_TYPES types_map = _ENTITY_METADATA_TYPES_NEW if ctx._proto > 340 else _ENTITY_METADATA_TYPES
out : Dict[int, Any] = {} out : Dict[int, Any] = {}
while True: while True: