no more optional object Context, made a dedicated object

This commit is contained in:
əlemi 2022-01-14 20:07:16 +01:00
parent a3f03756b6
commit ccadefc2f5
2 changed files with 68 additions and 52 deletions

View file

@ -10,7 +10,7 @@ from typing import List, Dict, Set, Optional, AsyncIterator, Type
from cryptography.hazmat.primitives.ciphers import CipherContext
from .mc import proto as minecraft_protocol
from .mc.types import VarInt
from .mc.types import VarInt, Context
from .mc.packet import Packet
from .mc.definitions import ConnectionState
from .util import encryption
@ -222,7 +222,7 @@ class Dispatcher:
buffer = io.BytesIO(data)
if self.compression is not None:
decompressed_size = VarInt.read(buffer)
decompressed_size = VarInt.read(buffer, Context())
if decompressed_size > 0:
decompressor = zlib.decompressobj()
decompressed_data = decompressor.decompress(buffer.read())
@ -230,7 +230,7 @@ class Dispatcher:
raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
buffer = io.BytesIO(decompressed_data)
packet_id = VarInt.read(buffer)
packet_id = VarInt.read(buffer, Context())
if self.state == ConnectionState.PLAY and self._packet_id_whitelist \
and packet_id not in self._packet_id_whitelist:
self._logger.debug("[<--] Received | Packet(0x%02x) (ignored)", packet_id)

View file

@ -1,6 +1,7 @@
import io
import struct
import asyncio
import json
import uuid
import logging
@ -10,27 +11,42 @@ from typing import List, Tuple, Dict, Any, Union, Optional, Callable, Type as Cl
from .definitions import Item
class Context(object):
def __init__(self, **kwargs):
for k, v in kwargs:
setattr(self, k, v)
def __getattr__(self, name) -> Any:
return None # return None rather than raising an exc
def __str__(self) -> str:
return json.dumps(vars(self), indent=2, default=str, sort_keys=True)
def __repr__(self) -> str:
values = ( f"{k}={repr(v)}" for k,v in vars(self).items() )
return f"Context({', '.join(values)})"
class Type(object):
pytype : Union[type, Callable] = lambda x : x
def write(self, data:Any, buffer:io.BytesIO, ctx:object=None) -> None:
def write(self, data:Any, buffer:io.BytesIO, ctx:Context) -> None:
"""Write data to a packet buffer"""
raise NotImplementedError
def read(self, buffer:io.BytesIO, ctx:object=None) -> Any:
def read(self, buffer:io.BytesIO, ctx:Context) -> Any:
"""Read data off a packet buffer"""
raise NotImplementedError
def check(self, ctx:object) -> bool:
def check(self, ctx:Context) -> bool:
"""Check if this type exists in this context"""
return True
class VoidType(Type):
def write(self, v:None, buffer:io.BytesIO, ctx:object=None):
def write(self, v:None, buffer:io.BytesIO, ctx:Context):
pass
def read(self, buffer:io.BytesIO, ctx:object=None) -> None:
def read(self, buffer:io.BytesIO, ctx:Context) -> None:
return None
Void = VoidType()
@ -38,11 +54,11 @@ Void = VoidType()
class UnimplementedDataType(Type):
pytype : type = bytes
def write(self, data:bytes, buffer:io.BytesIO, ctx:object=None):
def write(self, data:bytes, buffer:io.BytesIO, ctx:Context):
if data:
buffer.write(data)
def read(self, buffer:io.BytesIO, ctx:object=None) -> bytes:
def read(self, buffer:io.BytesIO, ctx:Context) -> bytes:
return buffer.read()
TrailingData = UnimplementedDataType()
@ -56,10 +72,10 @@ class PrimitiveType(Type):
self.fmt = fmt
self.size = size
def write(self, data:Any, buffer:io.BytesIO, ctx:object=None):
def write(self, data:Any, buffer:io.BytesIO, ctx:Context):
buffer.write(struct.pack(self.fmt, data))
def read(self, buffer:io.BytesIO, ctx:object=None) -> Any:
def read(self, buffer:io.BytesIO, ctx:Context) -> Any:
return struct.unpack(self.fmt, buffer.read(self.size))[0]
Boolean = PrimitiveType(bool, ">?", 1)
@ -87,14 +103,14 @@ def nbt_to_py(item:pynbt.BaseTag) -> Any:
class NBTType(Type):
pytype : type = dict
def write(self, data:Optional[dict], buffer:io.BytesIO, ctx:object=None):
def write(self, data:Optional[dict], buffer:io.BytesIO, ctx:Context):
if data is None:
buffer.write(b'\x00')
else:
pynbt.NBTFile(value=data).save(buffer)
def read(self, buffer:io.BytesIO, ctx:object=None) -> Optional[dict]:
head = Byte.read(buffer)
def read(self, buffer:io.BytesIO, ctx:Context) -> Optional[dict]:
head = Byte.read(buffer, ctx)
if head == 0x0:
return None
buffer.seek(-1,1) # go back 1 byte
@ -109,7 +125,7 @@ class VarLenPrimitive(Type):
def __init__(self, max_bytes:int):
self.max_bytes = max_bytes
def write(self, data:int, buffer:io.BytesIO, ctx:object=None):
def write(self, data:int, buffer:io.BytesIO, ctx:Context):
count = 0 # TODO raise exceptions
while count < self.max_bytes:
byte = data & 0b01111111
@ -121,7 +137,7 @@ class VarLenPrimitive(Type):
if not data:
break
def read(self, buffer:io.BytesIO, ctx:object=None) -> int:
def read(self, buffer:io.BytesIO, ctx:Context) -> int:
numRead = 0
result = 0
while True:
@ -141,13 +157,13 @@ class VarLenPrimitive(Type):
def serialize(self, data:int) -> bytes:
buf = io.BytesIO()
self.write(data, buf)
self.write(data, buf, Context())
buf.seek(0)
return buf.read()
def deserialize(self, data:bytes) -> int:
buf = io.BytesIO(data)
return self.read(buf)
return self.read(buf, Context())
VarInt = VarLenPrimitive(5)
VarLong = VarLenPrimitive(10)
@ -155,12 +171,12 @@ VarLong = VarLenPrimitive(10)
class StringType(Type):
pytype : type = str
def write(self, data:str, buffer:io.BytesIO, ctx:object=None):
def write(self, data:str, buffer:io.BytesIO, ctx:Context):
encoded = data.encode('utf-8')
VarInt.write(len(encoded), buffer, ctx=ctx)
buffer.write(encoded)
def read(self, buffer:io.BytesIO, ctx:object=None) -> str:
def read(self, buffer:io.BytesIO, ctx:Context) -> str:
length = VarInt.read(buffer, ctx=ctx)
return buffer.read(length).decode('utf-8')
@ -175,11 +191,11 @@ class BufferType(Type):
def __init__(self, count:Type = VarInt):
self.count = count
def write(self, data:bytes, buffer:io.BytesIO, ctx:object=None):
def write(self, data:bytes, buffer:io.BytesIO, ctx:Context):
self.count.write(len(data), buffer, ctx=ctx)
buffer.write(data)
def read(self, buffer:io.BytesIO, ctx:object=None) -> bytes:
def read(self, buffer:io.BytesIO, ctx:Context) -> bytes:
length = self.count.read(buffer, ctx=ctx)
return buffer.read(length)
@ -199,14 +215,14 @@ class PositionType(Type):
# TODO THIS IS FOR 1.12.2!!! Make a generic version-less?
def write(self, data:tuple, buffer:io.BytesIO, ctx:object=None):
def write(self, data:tuple, buffer:io.BytesIO, ctx:Context):
packed = ((0x3FFFFFF & data[0]) << 38) \
| ((0xFFF & data[1]) << 26) \
| (0x3FFFFFF & data[2])
UnsignedLong.write(packed, buffer, ctx=ctx)
def read(self, buffer:io.BytesIO, ctx:object=None) -> tuple:
packed = UnsignedLong.read(buffer)
def read(self, buffer:io.BytesIO, ctx:Context) -> tuple:
packed = UnsignedLong.read(buffer, ctx)
x = twos_comp(packed >> 38, 26)
y = (packed >> 26) & 0xFFF
z = twos_comp(packed & 0x3FFFFFF, 26)
@ -218,10 +234,10 @@ class UUIDType(Type):
pytype : type = uuid.UUID
MAX_SIZE : int = 16
def write(self, data:uuid.UUID, buffer:io.BytesIO, ctx:object=None):
def write(self, data:uuid.UUID, buffer:io.BytesIO, ctx:Context):
buffer.write(int(data).to_bytes(self.MAX_SIZE, 'big'))
def read(self, buffer:io.BytesIO, ctx:object=None) -> uuid.UUID:
def read(self, buffer:io.BytesIO, ctx:Context) -> uuid.UUID:
return uuid.UUID(int=int.from_bytes(buffer.read(self.MAX_SIZE), 'big'))
UUID = UUIDType()
@ -235,7 +251,7 @@ class ArrayType(Type):
self.content = content
self.counter = counter
def write(self, data:List[Any], buffer:io.BytesIO, ctx:object=None):
def write(self, data:List[Any], buffer:io.BytesIO, ctx:Context):
if isinstance(self.counter, Type):
self.counter.write(len(data), buffer, ctx=ctx)
for i, el in enumerate(data):
@ -243,7 +259,7 @@ class ArrayType(Type):
if isinstance(self.counter, int) and i >= self.counter:
break # jank but should do
def read(self, buffer:io.BytesIO, ctx:object=None) -> List[Any]:
def read(self, buffer:io.BytesIO, ctx:Context) -> List[Any]:
length = self.counter if isinstance(self.counter, int) else self.counter.read(buffer, ctx=ctx)
out = []
for _ in range(length):
@ -257,12 +273,12 @@ class OptionalType(Type):
self.t = t
self.pytype = t.pytype
def write(self, data:Optional[Any], buffer:io.BytesIO, ctx:object=None):
def write(self, data:Optional[Any], buffer:io.BytesIO, ctx:Context):
Boolean.write(bool(data), buffer, ctx=ctx)
if data:
self.t.write(data, buffer, ctx=ctx)
def read(self, buffer:io.BytesIO, ctx:object=None) -> Optional[Any]:
def read(self, buffer:io.BytesIO, ctx:Context) -> Optional[Any]:
if Boolean.read(buffer, ctx=ctx):
return self.t.read(buffer, ctx=ctx)
return None
@ -276,14 +292,14 @@ class SwitchType(Type):
self.mappings = mappings
self.default = default
def write(self, data:Any, buffer:io.BytesIO, ctx:object=None):
def write(self, data:Any, buffer:io.BytesIO, ctx:Context):
watched = getattr(ctx, self.field, None)
if watched is not None and watched in self.mappings:
return self.mappings[watched].write(data, buffer, ctx=ctx)
elif self.default:
return self.default.write(data, buffer, ctx=ctx)
def read(self, buffer:io.BytesIO, ctx:object=None) -> Optional[Any]:
def read(self, buffer:io.BytesIO, ctx:Context) -> Optional[Any]:
watched = getattr(ctx, self.field, None)
if watched is not None and watched in self.mappings:
return self.mappings[watched].read(buffer, ctx=ctx)
@ -298,44 +314,44 @@ class StructType(Type): # TODO sub objects
def __init__(self, *args:Tuple[str, Type]):
self.fields = args
def write(self, data:Dict[str, Any], buffer:io.BytesIO, ctx:object=None):
def write(self, data:Dict[str, Any], buffer:io.BytesIO, ctx:Context):
for k, t in self.fields:
t.write(data[k], buffer, ctx=ctx)
def read(self, buffer:io.BytesIO, ctx:object=None) -> Dict[str, Any]:
def read(self, buffer:io.BytesIO, ctx:Context) -> Dict[str, Any]:
return { k : t.read(buffer, ctx=ctx) for k, t in self.fields }
class SlotType(Type):
pytype : type = Item
def write(self, data:Item, buffer:io.BytesIO, ctx:object=None):
def write(self, data:Item, buffer:io.BytesIO, ctx:Context):
new_way = ctx._proto > 340
check_type = Boolean if new_way else Short
if data:
check_type.write(True if new_way else data.id, buffer)
check_type.write(True if new_way else data.id, buffer, ctx)
if new_way:
VarInt.write(data.id, buffer)
Byte.write(data.count, buffer)
VarInt.write(data.id, buffer, ctx)
Byte.write(data.count, buffer, ctx)
if not new_way:
Short.write(data.damage, buffer)
NBTTag.write(data.nbt, buffer) # TODO handle None maybe?
Short.write(data.damage, buffer, ctx)
NBTTag.write(data.nbt, buffer, ctx) # TODO handle None maybe?
else:
check_type.write(False if new_way else -1, buffer)
check_type.write(False if new_way else -1, buffer, ctx)
def read(self, buffer:io.BytesIO, ctx:object=None) -> Any:
def read(self, buffer:io.BytesIO, ctx:Context) -> Any:
slot : Dict[Any, Any] = {}
new_way = ctx._proto > 340
check_type = Boolean if new_way else Short
val = check_type.read(buffer)
val = check_type.read(buffer, ctx)
if (new_way and val) or val != -1:
if new_way:
slot["id"] = VarInt.read(buffer)
slot["id"] = VarInt.read(buffer, ctx)
else:
slot["id"] = val
slot["count"] = Byte.read(buffer)
slot["count"] = Byte.read(buffer, ctx)
if not new_way:
slot["damage"] = Short.read(buffer)
slot["nbt"] = NBTTag.read(buffer)
slot["damage"] = Short.read(buffer, ctx)
slot["nbt"] = NBTTag.read(buffer, ctx)
return Item(**slot)
Slot = SlotType()
@ -383,11 +399,11 @@ _ENTITY_METADATA_TYPES_NEW = {
class EntityMetadataType(Type):
pytype : type = dict
def write(self, data:Dict[int, Any], buffer:io.BytesIO, ctx:object=None):
def write(self, data:Dict[int, Any], buffer:io.BytesIO, ctx:Context):
logging.error("Sending entity metadata isn't implemented yet") # TODO
buffer.write(0xFF)
buffer.write(b'\xFF')
def read(self, buffer:io.BytesIO, ctx:object=None) -> Dict[int, Any]:
def read(self, buffer:io.BytesIO, ctx:Context) -> Dict[int, Any]:
types_map = _ENTITY_METADATA_TYPES_NEW if ctx._proto > 340 else _ENTITY_METADATA_TYPES
out : Dict[int, Any] = {}
while True: