diff --git a/aiocraft/mc/packet.py b/aiocraft/mc/packet.py index 1c5ee98..01b0061 100644 --- a/aiocraft/mc/packet.py +++ b/aiocraft/mc/packet.py @@ -23,7 +23,7 @@ class Packet: self.definition = self._definitions[proto] self.id = self._ids[proto] for name, t in self.definition: - setattr(self, name, t._pytype(kwargs[name]) if name in kwargs else None) + setattr(self, name, t.pytype(kwargs[name]) if name in kwargs else None) @property def processed(self) -> Event: @@ -32,14 +32,18 @@ class Packet: @classmethod def deserialize(cls, proto:int, buffer:io.BytesIO): - return cls(proto, **{ name : t.read(buffer) for (name, t) in cls._definitions[proto] }) + pkt = cls(proto) + for k, t in cls._definitions[proto]: + setattr(pkt, k, t.read(buffer, ctx=pkt)) + return pkt + # return cls(proto, **{ name : t.read(buffer) for (name, t) in cls._definitions[proto] }) def serialize(self) -> io.BytesIO: buf = io.BytesIO() VarInt.write(self.id, buf) for name, t in self.definition: if getattr(self, name, None) is not None: # minecraft proto has no null type: this is an optional field left unset - t.write(getattr(self, name, None), buf) + t.write(getattr(self, name), buf, ctx=self) buf.seek(0) return buf diff --git a/aiocraft/mc/types.py b/aiocraft/mc/types.py index 844a51d..edfd21b 100644 --- a/aiocraft/mc/types.py +++ b/aiocraft/mc/types.py @@ -3,99 +3,77 @@ import struct import asyncio import uuid -from typing import List, Any, Optional, Type as Class +from typing import List, Tuple, Dict, Any, Optional, Type as Class class Type(object): - _pytype : type - _size : int - _fmt : str + pytype : type - # These methods will work only for fixed size data, if _size and _fmt are defined. - # For anything variabile in size, define custom read() and write() classmethods + def write(self, data:Any, buffer:io.BytesIO, ctx:object=None) -> None: + """Write data to a packet buffer""" + raise NotImplementedError + + def read(self, buffer:io.BytesIO, ctx:object=None) -> Any: + """Read data off a packet buffer""" + raise NotImplementedError - @classmethod - def write(cls, data:Any, buffer:io.BytesIO): - buffer.write(struct.pack(cls._fmt, data)) + def check(self, ctx:object) -> bool: + """Check if this type exists in this context""" + return True - @classmethod - def read(cls, buffer:io.BytesIO) -> Any: - return struct.unpack(cls._fmt, buffer.read(cls._size))[0] +class UnimplementedDataType(Type): + pytype : type = bytes -class TrailingByteArray(Type): - _pytype : type = bytes - - @classmethod - def write(cls, data:bytes, buffer:io.BytesIO): + def write(self, data:bytes, buffer:io.BytesIO, ctx:object=None): if data: buffer.write(data) - @classmethod - def read(cls, buffer:io.BytesIO) -> bytes: + def read(self, buffer:io.BytesIO, ctx:object=None) -> bytes: return buffer.read() -class Boolean(Type): - _pytype : type = bool - _size : int = 1 - _fmt : str = ">?" +TrailingData = UnimplementedDataType() +EntityMetadata = UnimplementedDataType() +EntityMetadataItem = UnimplementedDataType() +NBTTag = UnimplementedDataType() +Slot = UnimplementedDataType() -class Byte(Type): - _pytype : type = int - _size : int = 1 - _fmt : str = ">b" +class PrimitiveType(Type): + size : int + fmt : str -class UnsignedByte(Type): - _pytype : type = int - _size : int = 1 - _fmt : str = ">B" + def __init__(self, pytype:type, fmt:str, size:int): + self.pytype = pytype + self.fmt = fmt + self.size = size -class Short(Type): - _pytype : type = int - _size : int = 2 - _fmt : str = ">h" + def write(self, data:Any, buffer:io.BytesIO, ctx:object=None): + buffer.write(struct.pack(self.fmt, data)) -class UnsignedShort(Type): - _pytype : type = int - _size : int = 2 - _fmt : str = ">H" + def read(self, buffer:io.BytesIO, ctx:object=None) -> Any: + return struct.unpack(self.fmt, buffer.read(self.size))[0] -class Int(Type): - _pytype : type = int - _size : int = 4 - _fmt : str = ">i" +Boolean = PrimitiveType(bool, ">?", 1) +Byte = PrimitiveType(int, ">b", 1) +UnsignedByte = PrimitiveType(int, ">B", 1) +Short = PrimitiveType(int, ">h", 2) +UnsignedShort = PrimitiveType(int, ">H", 2) +Int = PrimitiveType(int, ">i", 4) +UnsignedInt = PrimitiveType(int, ">I", 4) +Long = PrimitiveType(int, ">q", 8) +UnsignedLong = PrimitiveType(int, ">Q", 8) +Float = PrimitiveType(float, ">f", 4) +Double = PrimitiveType(float, ">d", 8) +Angle = PrimitiveType(int, ">b", 1) -class UnsignedInt(Type): - _pytype : type = int - _size : int = 4 - _fmt : str = ">I" +class VarLenPrimitive(Type): + pytype : type = int + max_bytes : int -class Long(Type): - _pytype : type = int - _size : int = 8 - _fmt : str = ">q" + def __init__(self, max_bytes:int): + self.max_bytes = max_bytes -class UnsignedLong(Type): - _pytype : type = int - _size : int = 8 - _fmt : str = ">Q" - -class Float(Type): - _pytype : type = float - _size : int = 4 - _fmt : str = ">f" - -class Double(Type): - _pytype : type = float - _size : int = 8 - _fmt : str = ">d" - -class VarInt(Type): - _pytype : type = int - _size = 5 - - @classmethod - def write(cls, data:int, buffer:io.BytesIO): + def write(self, data:int, buffer:io.BytesIO, ctx:object=None): count = 0 - while count < cls._size: + while count < self.max_bytes: byte = data & 0b01111111 data >>= 7 if data > 0: @@ -105,179 +83,173 @@ class VarInt(Type): if not data: break - @classmethod - def read(cls, buffer:io.BytesIO) -> int: + def read(self, buffer:io.BytesIO, ctx:object=None) -> int: numRead = 0 result = 0 while True: data = buffer.read(1) if len(data) < 1: - raise ValueError("VarInt is too short") + raise ValueError("VarInt/VarLong is too short") buf = int.from_bytes(data, 'little') result |= (buf & 0b01111111) << (7 * numRead) numRead +=1 - if numRead > cls._size: - raise ValueError("VarInt is too big") + if numRead > self.max_bytes: + raise ValueError("VarInt/VarLong is too big") if buf & 0b10000000 == 0: break return result - @classmethod - def serialize(cls, data:int) -> bytes: + # utility methods since VarInt is super used + + def serialize(self, data:int) -> bytes: buf = io.BytesIO() - cls.write(data, buf) + self.write(data, buf) buf.seek(0) return buf.read() - @classmethod - def deserialize(cls, data:bytes) -> int: + def deserialize(self, data:bytes) -> int: buf = io.BytesIO(data) - return cls.read(buf) + return self.read(buf, ctx=ctx) -class VarLong(VarInt): - _pytype : type = int - _size = 10 +VarInt = VarLenPrimitive(5) +VarLong = VarLenPrimitive(10) -class EntityMetadata(TrailingByteArray): - # TODO - pass +class StringType(Type): + pytype : type = str -class Slot(TrailingByteArray): - _pytype : type = bytes - # TODO - pass - - -class Maybe(Type): # TODO better name without - _t : Class[Type] = TrailingByteArray - _pytype : type = bytes - - def __init__(self, t:Class[Type]): - self._t = t - self._pytype = t._pytype - self._size = Boolean._size + t._size - - @classmethod - def write(cls, data:Optional[Any], buffer:io.BytesIO): - Boolean.write(bool(data), buffer) - if data: - cls._t.write(data, buffer) - - @classmethod - def read(cls, buffer:io.BytesIO) -> Optional[Any]: - if Boolean.read(buffer): - return cls._t.read(buffer) - return None - -class Array(Type): - _counter : Class[Type] = VarInt - _content : Class[Type] = Byte - _pytype : type = bytes - - def __init__(self, content:Class[Type] = Byte, counter:Class[Type] = VarInt): - self._content = content - self._counter = counter - - @classmethod - def write(cls, data:List[Any], buffer:io.BytesIO): - cls._counter.write(len(data), buffer) - for el in data: - cls._content.write(el, buffer) - - @classmethod - def read(cls, buffer:io.BytesIO) -> List[Any]: - length = cls._counter.read(buffer) - return [ cls._content.read(buffer) for _ in range(length) ] - -class String(Type): - _pytype : type = str - - @classmethod - def write(cls, data:str, buffer:io.BytesIO): + def write(self, data:str, buffer:io.BytesIO, ctx:object=None): encoded = data.encode('utf-8') - VarInt.write(len(encoded), buffer) + VarInt.write(len(encoded), buffer, ctx=ctx) buffer.write(encoded) - @classmethod - def read(cls, buffer:io.BytesIO) -> str: - length = VarInt.read(buffer) + def read(self, buffer:io.BytesIO, ctx:object=None) -> str: + length = VarInt.read(buffer, ctx=ctx) return buffer.read(length).decode('utf-8') -class ByteArray(Type): - _pytype : type = bytes +String = StringType() +Chat = StringType() +Identifier = StringType() - @classmethod - def write(cls, data:bytes, buffer:io.BytesIO): - VarInt.write(len(data), buffer) +class BufferType(Type): + pytype : type = bytes + count : Type + + def __init__(self, count:Type = VarInt): + self.count = count + + def write(self, data:bytes, buffer:io.BytesIO, ctx:object=None): + self.count.write(len(data), buffer, ctx=ctx) buffer.write(data) - @classmethod - def read(cls, buffer:io.BytesIO) -> bytes: - length = VarInt.read(buffer) + def read(self, buffer:io.BytesIO, ctx:object=None) -> bytes: + length = self.count.read(buffer, ctx=ctx) return buffer.read(length) -class IntegerByteArray(Type): - _pytype : type = bytes +ByteArray = BufferType() +IntegerByteArray = BufferType(Int) - @classmethod - def write(cls, data:bytes, buffer:io.BytesIO): - Int.write(len(data), buffer) - buffer.write(data) +class PositionType(Type): + pytype : type = tuple + MAX_SIZE : int = 8 - @classmethod - def read(cls, buffer:io.BytesIO) -> bytes: - length = Int.read(buffer) - return buffer.read(length) + # TODO THIS IS FOR 1.12.2!!! Make a generic version-less? -class Chat(String): - _pytype : type = str - -class Identifier(String): - _pytype : type = str - -class Angle(Type): - _pytype : type = int - _size : int = 1 - _fmt : str = ">b" - -class EntityMetadataItem(Type): - _pytype : type = bytes - # TODO - pass - -class NBTTag(Type): - _pytype : type = bytes - # TODO - pass - -class Position(Type): - _pytype : type = tuple - _size = 8 - - # TODO THIS IS FOR 1.12.2!!! - - @classmethod - def write(cls, data:tuple, buffer:io.BytesIO): + def write(self, data:tuple, buffer:io.BytesIO, ctx:object=None): packed = ((0x3FFFFFF & data[0]) << 38) | ((0xFFF & data[1]) << 26) | (0x3FFFFFF & data[2]) - UnsignedLong.write(packed, buffer) + UnsignedLong.write(packed, buffer, ctx=ctx) - @classmethod - def read(cls, buffer:io.BytesIO) -> tuple: + def read(self, buffer:io.BytesIO, ctx:object=None) -> tuple: packed = UnsignedLong.read(buffer) x = packed >> 38 y = (packed >> 24) & 0xFFF z = packed & 0x3FFFFFF return (x, y, z) -class UUID(Type): - _pytype : type = str - _size = 16 +Position = PositionType() - @classmethod - def write(cls, data:uuid.UUID, buffer:io.BytesIO): - buffer.write(int(data).to_bytes(cls._size, 'big')) +class UUIDType(Type): + pytype : type = str # TODO maybe use partial with uuid constructor? + MAX_SIZE : int = 16 + + def write(self, data:uuid.UUID, buffer:io.BytesIO, ctx:object=None): + buffer.write(int(data).to_bytes(self.MAX_SIZE, 'big')) + + def read(self, buffer:io.BytesIO, ctx:object=None) -> uuid.UUID: + return uuid.UUID(int=int.from_bytes(buffer.read(self.MAX_SIZE), 'big')) + +UUID = UUIDType() + +class ArrayType(Type): + pytype : type = list + counter : Type + content : Type + + def __init__(self, content:Type, counter:Type = VarInt): + self.content = content + self.counter = counter + + def write(self, data:List[Any], buffer:io.BytesIO, ctx:object=None): + self.counter.write(len(data), buffer, ctx=ctx) + for el in data: + self.content.write(el, buffer, ctx=ctx) + + def read(self, buffer:io.BytesIO, ctx:object=None) -> List[Any]: + length = self.counter.read(buffer, ctx=ctx) + return [ self.content.read(buffer, ctx=ctx) for _ in range(length) ] + +class OptionalType(Type): + t : Type + + def __init__(self, t:Type): + self.t = t + self.pytype = t.pytype + + def write(self, data:Optional[Any], buffer:io.BytesIO, ctx:object=None): + Boolean.write(bool(data), buffer, ctx=ctx) + if data: + self.t.write(data, buffer, ctx=ctx) + + def read(self, buffer:io.BytesIO, ctx:object=None) -> Optional[Any]: + if Boolean.read(buffer, ctx=ctx): + return self.t.read(buffer, ctx=ctx) + return None + +class SwitchType(Type): + field : str + mappings : Dict[Any, Type] + + def __init__(self, watch:str, mappings:Dict[Any, Type], default:Type = None): + self.field = watch + self.mappings = mappings + self.default = default + + def write(self, data:Any, buffer:io.BytesIO, ctx:object=None): + watched = getattr(ctx, self.field) + if watched in self.mappings: + return self.mappings[watched].write(data, buffer, ctx=ctx) + elif self.default: + return self.default.write(data, buffer, ctx=ctx) + + def read(self, buffer:io.BytesIO, ctx:object=None) -> Dict[str, Any]: + watched = getattr(ctx, self.field) + if watched in self.mappings: + return self.mappings[watched].read(buffer, ctx=ctx) + elif self.default: + return self.default.read(buffer, ctx=ctx) + return {} + +class StructType(Type): + pytype : type = dict + fields : Tuple[Tuple[str, Type], ...] + + def __init__(self, *args:Tuple[str, Type]): + self.fields = args + + def write(self, data:Dict[str, Any], buffer:io.BytesIO, ctx:object=None): + for k, t in self.fields: + t.write(data[k], buffer, ctx=ctx) + + def read(self, buffer:io.BytesIO, ctx:object=None) -> Dict[str, Any]: + return { k : t.read(buffer, ctx=ctx) for k, t in self.fields } - @classmethod - def read(cls, buffer:io.BytesIO) -> uuid.UUID: - return uuid.UUID(int=int.from_bytes(buffer.read(cls._size), 'big')) diff --git a/compiler/proto.py b/compiler/proto.py index c707672..c7faadc 100644 --- a/compiler/proto.py +++ b/compiler/proto.py @@ -5,7 +5,7 @@ import keyword import logging from pathlib import Path -from typing import List, Dict, Union, Type as Class +from typing import Tuple, List, Dict, Union, Set, Type as Class from aiocraft.mc.types import * @@ -13,7 +13,7 @@ from aiocraft.mc.types import * DIR_MAP = {"toClient": "clientbound", "toServer": "serverbound"} PREFACE = """\"\"\"[!] This file is autogenerated\"\"\"\n\n""" -IMPORTS = """from typing import Tuple, List, Dict +IMPORTS = """from typing import Tuple, List, Dict, Union from ....packet import Packet from ....types import *\n""" IMPORT_ALL = """__all__ = [\n\t{all}\n]\n""" @@ -30,91 +30,226 @@ class {name}(Packet): _definitions : Dict[int, List[Tuple[str, Type]]] = {definitions} """ +class Ref: + name : str + args : tuple + + def __equals__(self, other) -> bool: + if self.args: + return self.name == other.name and self.args == other.args + return self.name == other.name + + def __init__(self, name:str, *args): + self.name = name or "anon" + self.args = args + + def __repr__(self) -> str: + if self.args: + out = self.name + "(" + for arg in self.args: + out += repr(arg) + ", " + out += ")" + return out + return self.name + TYPE_MAP = { - "varint": VarInt, - "u8": Byte, - "i8": Byte, - "u16": UnsignedShort, - "u32": UnsignedInt, - "u64": UnsignedLong, - "i16": Short, - "i32": Int, - "i64": Long, - "f32": Float, - "f64": Double, - "bool": Boolean, - "UUID": UUID, - "string": String, - "nbt": NBTTag, - "slot": Slot, - "position": Position, - "entityMetadataItem": EntityMetadataItem, - "entityMetadata": EntityMetadata, + "varint": Ref('VarInt'), + "u8": Ref('Byte'), + "i8": Ref('Byte'), + "u16": Ref('UnsignedShort'), + "u32": Ref('UnsignedInt'), + "u64": Ref('UnsignedLong'), + "i16": Ref('Short'), + "i32": Ref('Int'), + "i64": Ref('Long'), + "f32": Ref('Float'), + "f64": Ref('Double'), + "bool": Ref('Boolean'), + "UUID": Ref('UUID'), + "string": Ref('String'), + "nbt": Ref('NBTTag'), + "slot": Ref('Slot'), + "position": Ref('Position'), + "entityMetadataItem": Ref('EntityMetadataItem'), + "entityMetadata": Ref('EntityMetadata'), } -def mctype(slot_type:Any) -> Class[Type]: +HINT_MAP = { + "varint": 'int', + "u8": 'int', + "i8": 'int', + "u16": 'int', + "u32": 'int', + "u64": 'int', + "i16": 'int', + "i32": 'int', + "i64": 'int', + "f32": 'float', + "f64": 'float', + "bool": 'bool', + "UUID": 'str', + "string": 'str', + "nbt": 'bytes', + "slot": 'dict', + "position": 'tuple', + "entityMetadataItem": 'bytes', + "entityMetadata": 'bytes', +} + +def _format_line(i, depth:int=0) -> str: + nl = ('\n' if depth > 0 else " ") + tab = '\t' * depth + return nl + tab + \ + f",{nl}{tab}".join(f"{repr(e)}" for e in i) + \ + nl + ('\t' * (depth-1)) + +def format_dict(d:dict, depth:int=1) -> str: + return "{" + _format_line((Ref(f"{k} : {v}") for k,v in sorted(d.items())), depth) + "}" + +def format_list(l:list, depth:int=0) -> str: + return "[" + _format_line(l, depth) + "]" + +def format_tuple(l:list, depth:int=0) -> str: + return "(" + _format_line(l, depth) + ")" + +def mctype(slot_type:Any) -> Ref: if isinstance(slot_type, str) and slot_type in TYPE_MAP: return TYPE_MAP[slot_type] if isinstance(slot_type, list): - name = slot_type[0] - if name == "buffer": - if "countType" in slot_type[1] and slot_type[1]["countType"] == "integer": - return IntegerByteArray - return ByteArray - # TODO composite data types - return TrailingByteArray + t = slot_type[0] + v = slot_type[1] + if t == "buffer": # Array of bytes + if "countType" in v and v["countType"] == "integer": + return Ref('IntegerByteArray') + return Ref('ByteArray') + elif t == "array": # Generic array + return Ref('ArrayType', mctype(v["type"]), (mctype(v["countType"]) if "countType" in v else 'VarInt')) + elif t == "container": # Struct + return Ref('StructType', Ref(", ".join(format_tuple((p["name"], mctype(p["type"]))) for p in v if "name" in p))) # some fields are anonymous??? + elif t == "option": # Optional + return Ref('OptionalType', mctype(v)) + elif t == "switch": # Union + return Ref('SwitchType', + v["compareTo"].split('/')[-1], + Ref(format_dict({int(k) if k.isnumeric() else repr(k):mctype(x) for k,x in v["fields"].items()}, depth=0)), + v["default"] if "default" in v and v['default'] != 'void' else None, + ) + # return SwitchType(mctype(v)) # TODO + # elif t == "mapper": # ???? + # return TrailingData + else: + logging.error("Encountered unknown composite data type : %s", t) + return Ref('TrailingData') + +def mchint(slot_type:Any) -> Ref: + if isinstance(slot_type, str) and slot_type in HINT_MAP: + return Ref(HINT_MAP[slot_type]) + if isinstance(slot_type, list): + t = slot_type[0] + if t == "buffer": # Array of bytes + return Ref('bytes') + elif t == "array": # Generic array + return Ref('list') + elif t == "container": # Struct + return Ref('dict') + elif t == "option": # Optional + return Ref('tuple') + elif t == "switch": # Union + return Ref('bytes') + # return SwitchType(mctype(v)) # TODO + # elif t == "mapper": # ???? + # return TrailingData + return Ref('bytes') + +def pytype(t:list) -> str: + vals = set(str(x) for x in t) + if len(vals) <= 1: + return next(iter(vals)) + return 'Union[' + ','.join(x for x in vals) + ']' def snake_to_camel(name:str) -> str: return "".join(x.capitalize() for x in name.split("_")) -def parse_slot(slot: dict) -> str: - name = slot["name"] if "name" in slot else "anon" - if keyword.iskeyword(name): - name = "is_" + name - t = mctype(slot["type"] if "type" in slot else "restBuffer") - return f"(\"{name}\", {t.__name__})" - -def parse_field(slot: dict) -> str: - name = slot["name"] if "name" in slot else "anon" - if keyword.iskeyword(name): - name = "is_" + name - t = mctype(slot["type"] if "type" in slot else "restBuffer") - return f"{name} : {t._pytype.__name__}" +# def parse_slot(slot: dict) -> str: +# name = slot["name"] if "name" in slot else "anon" +# if keyword.iskeyword(name): +# name = "is_" + name +# t = mctype(slot["type"] if "type" in slot else "restBuffer") +# return f"(\"{name}\", {t.__name__})" +# +# def parse_field(slot: dict) -> str: +# name = slot["name"] if "name" in slot else "anon" +# if keyword.iskeyword(name): +# name = "is_" + name +# t = mctype(slot["type"] if "type" in slot else "restBuffer") +# return f"{name} : {t._pytype.__name__}" class PacketClassWriter: - title : str - ids : str - attrs : List[str] - slots : str - fields: str + name : str + attrs : Set[str] + types : Dict[str, List[Type]] + hints : Dict[str, List[Type]] + ids : Dict[int, int] + definitions : Dict[int, List[Tuple[str, Type]]] state : int - - def __init__(self, title:str, ids:str, attrs:List[str], slots:str, fields:str, state:int): - self.title = title - self.ids = ids - self.attrs = attrs - self.slots = slots - self.fields = fields + def __init__(self, pkt:dict, state:int): + self.name = pkt["name"] self.state = state + self.attrs = set() + self.ids = {} + self.types = {} + self.hints = {} + self.definitions = {} + for v, defn in pkt["definitions"].items(): + self.ids[v] = defn["id"] + self.definitions[v] = [] + for field in defn["slots"]: + if "name" not in field: + logging.error("Skipping anonymous field %s", str(field)) + continue + field_name = field["name"] if not keyword.iskeyword(field["name"]) else "is_" + field["name"] + self.attrs.add(field_name) + self.definitions[v].append((field_name, mctype(field["type"]))) + if field_name not in self.types: + self.types[field_name] = set() + self.types[field_name].add(mctype(field["type"])) + if field_name not in self.hints: + self.hints[field_name] = set() + self.hints[field_name].add(mchint(field["type"])) def compile(self) -> str: return PREFACE + \ IMPORTS + \ OBJECT.format( - name=self.title, - ids='{\n\t\t' + ',\n\t\t'.join(self.ids) + '\n\t}\n', - definitions='{\n\t\t' + '\n\t\t'.join(self.slots) + '\n\t}\n', - slots=', '.join((f"'is_{x}'" if keyword.iskeyword(x) else f"'{x}'") for x in (list(self.attrs) + ["id"])), # TODO de-jank! - fields='\n\t'.join(self.fields), + name=self.name, + ids=format_dict(self.ids, depth=2), + definitions=format_dict({ k : Ref(format_list(Ref(format_tuple(x)) for x in v)) for k,v in self.definitions.items() }, depth=2), + slots=format_tuple(["id"] + list(self.attrs), depth=0), # TODO jank fix when no slots + fields="\n\t" + "\n\t".join(f"{a} : {pytype(self.hints[a])}" for a in self.attrs), state=self.state, ) +class RegistryClassWriter: + registry : dict + + def __init__(self, registry:dict): + self.registry = registry + + def compile(self) -> str: + return REGISTRY_ENTRY.format( + entries='{\n\t' + ",\n\t".join(( + str(v) + " : { " + ", ".join( + f"{pid}:{clazz}" for (pid, clazz) in self.registry[v].items() + ) + ' }' ) for v in self.registry.keys() + ) + '\n}' + ) + def _make_module(path:Path, contents:dict): os.mkdir(path) imports = "" - for key in contents: - imports += f"from .{key} import {contents[key]}\n" + for key, value in contents.items(): + imports += f"from .{key} import {value}\n" with open(path / "__init__.py", "w") as f: f.write(PREFACE + imports) @@ -156,7 +291,6 @@ def compile(): } } - # TODO load all versions! all_versions = os.listdir(mc_path / f'{folder_name}/data/pc/') all_versions.remove("common") all_proto_numbers = [] @@ -175,6 +309,7 @@ def compile(): with open(mc_path / f'{folder_name}/data/pc/{v}/protocol.json') as f: data = json.load(f) + # Build data structure containing all packets with all their definitions for different versions for state in ("handshaking", "status", "login", "play"): for _direction in ("toClient", "toServer"): direction = DIR_MAP[_direction] @@ -212,45 +347,24 @@ def compile(): _make_module(mc_path / f"proto/{state}", { k:"*" for k in PACKETS[state].keys() }) for direction in PACKETS[state].keys(): registry = {} + _make_module(mc_path / f"proto/{state}/{direction}", { k:snake_to_camel(k) for k in PACKETS[state][direction].keys() }) for packet in PACKETS[state][direction].keys(): pkt = PACKETS[state][direction][packet] - slots = [] - fields = set() - attrs = set() - ids = [] - for v in sorted(PACKETS[state][direction][packet]["definitions"].keys()): - defn = pkt["definitions"][v] + + for v, defn in pkt["definitions"].items(): if v not in registry: registry[v] = {} registry[v][defn['id']] = snake_to_camel(packet) - ids.append(f"{v} : 0x{defn['id']:02X}") - v_slots = [] - v_fields = [] - for slot in defn["slots"]: - v_slots.append(parse_slot(slot)) - fields.add(parse_field(slot)) - if "name" in slot: - attrs.add(slot["name"]) - slots.append(f"{v} : [ {','.join(v_slots)} ],") with open(mc_path / f"proto/{state}/{direction}/{packet}.py", "w") as f: - f.write( - PacketClassWriter( - pkt["name"], ids, attrs, slots, fields, _STATE_MAP[state] - ).compile() - ) + f.write(PacketClassWriter(pkt, _STATE_MAP[state]).compile()) with open(mc_path / f"proto/{state}/{direction}/__init__.py", "a") as f: - f.write( # TODO make this thing actually readable, maybe not using nested joins and generators - REGISTRY_ENTRY.format( - entries='{\n\t' + ",\n\t".join(( - str(v) + " : { " + ", ".join( - f"{pid}:{clazz}" for (pid, clazz) in registry[v].items() - ) + ' }' ) for v in registry.keys() - ) + '\n}' - ) - ) + f.write(RegistryClassWriter(registry).compile()) + +if __name__ == "__main__": + compile()