diff --git a/aiocraft/mc/packet.py b/aiocraft/mc/packet.py index 1b91f03..feb110c 100644 --- a/aiocraft/mc/packet.py +++ b/aiocraft/mc/packet.py @@ -1,32 +1,38 @@ import io import json -from typing import Tuple, Dict +from asyncio import Event +from typing import Tuple, Dict, Any from .mctypes import Type, VarInt class Packet: + __slots__ = 'id', 'definition', 'sent', '_protocol', '_state' + id : int - slots : Tuple[Tuple[str, Type]] + definition : Tuple[Tuple[str, Type]] + sent : Event + _protocol : int + _state : int _ids : Dict[int, int] # definitions are compiled at install time - _slots : Dict[int, Tuple[Tuple[str, Type]]] # definitions are compiled at install time + _definitions : Dict[int, Tuple[Tuple[str, Type]]] # definitions are compiled at install time def __init__(self, proto:int, **kwargs): self._protocol = proto - self.slots = self._slots[proto] + self.definition = self._definitions[proto] + self.sent = Event() self.id = self._ids[proto] - for name, t in self.slots: + for name, t in self.definition: setattr(self, name, t._pytype(kwargs[name]) if name in kwargs else None) @classmethod def deserialize(cls, proto:int, buffer:io.BytesIO): - pid = VarInt.read(buffer) - return cls(proto, **{ name : t.read(buffer) for (name, t) in cls._slots[proto] }) + return cls(proto, **{ name : t.read(buffer) for (name, t) in cls._definitions[proto] }) def serialize(self) -> io.BytesIO: buf = io.BytesIO() VarInt.write(self.id, buf) - for name, t in self.slots: + for name, t in self.definition: t.write(getattr(self, name, None), buf) buf.seek(0) return buf @@ -36,22 +42,23 @@ class Packet: return False if self._protocol != other._protocol: return False - for name, t in self.slots: + for name, t in self.definition: if getattr(self, name) != getattr(other, name): return False return True def __str__(self) -> str: - obj = {} # could be done with dict comp but the _ key gets put last :( + obj : Dict[str, Any] = {} # could be done with dict comp but the _ key gets put last :( obj["_"] = self.__class__.__name__ obj["_proto"] = self._protocol - for key, t in self.slots: + obj["_state"] = self._state + for key, t in self.definition: obj[key] = getattr(self, key, None) return json.dumps(obj, indent=2, default=str) def __repr__(self) -> str: - attrs = (f"{key}={repr(getattr(self, key, None))}" for (key, t) in self.slots) + attrs = (f"{key}={repr(getattr(self, key, None))}" for (key, t) in self.definition) return f"{self.__class__.__name__}({self._protocol}, {', '.join(attrs)})" diff --git a/compiler/proto.py b/compiler/proto.py index 3e0cee9..389e1f8 100644 --- a/compiler/proto.py +++ b/compiler/proto.py @@ -21,15 +21,19 @@ OBJECT = """ class {name}(Packet): {fields} + _state : int = {state} + _ids : Dict[int, int] = {ids} - _slots : Dict[int, List[Tuple[str, Type]]] = {slots} + _definitions : Dict[int, List[Tuple[str, Type]]] = {slots} """ TYPE_MAP = { "varint": VarInt, - "u8": UnsignedShort, - "u16": UnsignedInt, - "u32": UnsignedLong, + "u8": Byte, + "i8": Byte, + "u16": UnsignedShort, + "u32": UnsignedInt, + "u64": UnsignedLong, "i16": Short, "i32": Int, "i64": Long, @@ -45,11 +49,14 @@ TYPE_MAP = { "entityMetadata": EntityMetadata, } -def mctype(name:str) -> Type: - if not isinstance(name, str): - return TrailingByteArray - if name in TYPE_MAP: - return TYPE_MAP[name] +def mctype(slot_type:Any) -> Type: + if isinstance(slot_type, str) and slot_type in TYPE_MAP: + return TYPE_MAP[slot_type] + if isinstance(slot_type, list): + name = slot_type[0] + if name == "buffer": + return ByteArray + # TODO composite data types return TrailingByteArray def snake_to_camel(name:str) -> str: @@ -74,22 +81,25 @@ class PacketClassWriter: ids : str slots : str fields: str + state : int - def __init__(self, title:str, ids:str, slots:str, fields:str): + def __init__(self, title:str, ids:str, slots:str, fields:str, state:int): self.title = title self.ids = ids self.slots = slots self.fields = fields + self.state = state def compile(self) -> str: return PREFACE + \ IMPORTS + \ OBJECT.format( name=self.title, - ids=self.ids, + ids='{\n\t\t' + ',\n\t\t'.join(self.ids) + '\n\t}\n', slots=self.slots, fields=self.fields, + state=self.state, ) def _make_module(path:Path, contents:dict): @@ -187,6 +197,8 @@ def compile(): "slots" : packet[1], } + _STATE_MAP = {"handshaking": 0, "status":1, "login":2, "play":3} + _make_module(mc_path / 'proto', { k:"*" for k in PACKETS.keys() }) for state in PACKETS.keys(): _make_module(mc_path / f"proto/{state}", { k:"*" for k in PACKETS[state].keys() }) @@ -215,9 +227,10 @@ def compile(): f.write( PacketClassWriter( pkt["name"], - '{\n\t\t' + ',\n\t\t'.join(ids) + '\n\t}\n', + ids, '{\n\t\t' + '\n\t\t'.join(slots) + '\n\t}\n', '\n\t'.join(fields), + _STATE_MAP[state] ).compile() )