fixed VarInt reading, some definitions improvements

This commit is contained in:
əlemi 2021-11-10 18:55:46 +01:00
parent a485b25150
commit ba3c0e14e9
2 changed files with 44 additions and 24 deletions

View file

@ -1,32 +1,38 @@
import io import io
import json import json
from typing import Tuple, Dict from asyncio import Event
from typing import Tuple, Dict, Any
from .mctypes import Type, VarInt from .mctypes import Type, VarInt
class Packet: class Packet:
__slots__ = 'id', 'definition', 'sent', '_protocol', '_state'
id : int id : int
slots : Tuple[Tuple[str, Type]] definition : Tuple[Tuple[str, Type]]
sent : Event
_protocol : int
_state : int
_ids : Dict[int, int] # definitions are compiled at install time _ids : Dict[int, int] # definitions are compiled at install time
_slots : Dict[int, Tuple[Tuple[str, Type]]] # definitions are compiled at install time _definitions : Dict[int, Tuple[Tuple[str, Type]]] # definitions are compiled at install time
def __init__(self, proto:int, **kwargs): def __init__(self, proto:int, **kwargs):
self._protocol = proto self._protocol = proto
self.slots = self._slots[proto] self.definition = self._definitions[proto]
self.sent = Event()
self.id = self._ids[proto] self.id = self._ids[proto]
for name, t in self.slots: for name, t in self.definition:
setattr(self, name, t._pytype(kwargs[name]) if name in kwargs else None) setattr(self, name, t._pytype(kwargs[name]) if name in kwargs else None)
@classmethod @classmethod
def deserialize(cls, proto:int, buffer:io.BytesIO): def deserialize(cls, proto:int, buffer:io.BytesIO):
pid = VarInt.read(buffer) return cls(proto, **{ name : t.read(buffer) for (name, t) in cls._definitions[proto] })
return cls(proto, **{ name : t.read(buffer) for (name, t) in cls._slots[proto] })
def serialize(self) -> io.BytesIO: def serialize(self) -> io.BytesIO:
buf = io.BytesIO() buf = io.BytesIO()
VarInt.write(self.id, buf) VarInt.write(self.id, buf)
for name, t in self.slots: for name, t in self.definition:
t.write(getattr(self, name, None), buf) t.write(getattr(self, name, None), buf)
buf.seek(0) buf.seek(0)
return buf return buf
@ -36,22 +42,23 @@ class Packet:
return False return False
if self._protocol != other._protocol: if self._protocol != other._protocol:
return False return False
for name, t in self.slots: for name, t in self.definition:
if getattr(self, name) != getattr(other, name): if getattr(self, name) != getattr(other, name):
return False return False
return True return True
def __str__(self) -> str: def __str__(self) -> str:
obj = {} # could be done with dict comp but the _ key gets put last :( obj : Dict[str, Any] = {} # could be done with dict comp but the _ key gets put last :(
obj["_"] = self.__class__.__name__ obj["_"] = self.__class__.__name__
obj["_proto"] = self._protocol obj["_proto"] = self._protocol
for key, t in self.slots: obj["_state"] = self._state
for key, t in self.definition:
obj[key] = getattr(self, key, None) obj[key] = getattr(self, key, None)
return json.dumps(obj, indent=2, default=str) return json.dumps(obj, indent=2, default=str)
def __repr__(self) -> str: def __repr__(self) -> str:
attrs = (f"{key}={repr(getattr(self, key, None))}" for (key, t) in self.slots) attrs = (f"{key}={repr(getattr(self, key, None))}" for (key, t) in self.definition)
return f"{self.__class__.__name__}({self._protocol}, {', '.join(attrs)})" return f"{self.__class__.__name__}({self._protocol}, {', '.join(attrs)})"

View file

@ -21,15 +21,19 @@ OBJECT = """
class {name}(Packet): class {name}(Packet):
{fields} {fields}
_state : int = {state}
_ids : Dict[int, int] = {ids} _ids : Dict[int, int] = {ids}
_slots : Dict[int, List[Tuple[str, Type]]] = {slots} _definitions : Dict[int, List[Tuple[str, Type]]] = {slots}
""" """
TYPE_MAP = { TYPE_MAP = {
"varint": VarInt, "varint": VarInt,
"u8": UnsignedShort, "u8": Byte,
"u16": UnsignedInt, "i8": Byte,
"u32": UnsignedLong, "u16": UnsignedShort,
"u32": UnsignedInt,
"u64": UnsignedLong,
"i16": Short, "i16": Short,
"i32": Int, "i32": Int,
"i64": Long, "i64": Long,
@ -45,11 +49,14 @@ TYPE_MAP = {
"entityMetadata": EntityMetadata, "entityMetadata": EntityMetadata,
} }
def mctype(name:str) -> Type: def mctype(slot_type:Any) -> Type:
if not isinstance(name, str): if isinstance(slot_type, str) and slot_type in TYPE_MAP:
return TrailingByteArray return TYPE_MAP[slot_type]
if name in TYPE_MAP: if isinstance(slot_type, list):
return TYPE_MAP[name] name = slot_type[0]
if name == "buffer":
return ByteArray
# TODO composite data types
return TrailingByteArray return TrailingByteArray
def snake_to_camel(name:str) -> str: def snake_to_camel(name:str) -> str:
@ -74,22 +81,25 @@ class PacketClassWriter:
ids : str ids : str
slots : str slots : str
fields: str fields: str
state : int
def __init__(self, title:str, ids:str, slots:str, fields:str): def __init__(self, title:str, ids:str, slots:str, fields:str, state:int):
self.title = title self.title = title
self.ids = ids self.ids = ids
self.slots = slots self.slots = slots
self.fields = fields self.fields = fields
self.state = state
def compile(self) -> str: def compile(self) -> str:
return PREFACE + \ return PREFACE + \
IMPORTS + \ IMPORTS + \
OBJECT.format( OBJECT.format(
name=self.title, name=self.title,
ids=self.ids, ids='{\n\t\t' + ',\n\t\t'.join(self.ids) + '\n\t}\n',
slots=self.slots, slots=self.slots,
fields=self.fields, fields=self.fields,
state=self.state,
) )
def _make_module(path:Path, contents:dict): def _make_module(path:Path, contents:dict):
@ -187,6 +197,8 @@ def compile():
"slots" : packet[1], "slots" : packet[1],
} }
_STATE_MAP = {"handshaking": 0, "status":1, "login":2, "play":3}
_make_module(mc_path / 'proto', { k:"*" for k in PACKETS.keys() }) _make_module(mc_path / 'proto', { k:"*" for k in PACKETS.keys() })
for state in PACKETS.keys(): for state in PACKETS.keys():
_make_module(mc_path / f"proto/{state}", { k:"*" for k in PACKETS[state].keys() }) _make_module(mc_path / f"proto/{state}", { k:"*" for k in PACKETS[state].keys() })
@ -215,9 +227,10 @@ def compile():
f.write( f.write(
PacketClassWriter( PacketClassWriter(
pkt["name"], pkt["name"],
'{\n\t\t' + ',\n\t\t'.join(ids) + '\n\t}\n', ids,
'{\n\t\t' + '\n\t\t'.join(slots) + '\n\t}\n', '{\n\t\t' + '\n\t\t'.join(slots) + '\n\t}\n',
'\n\t'.join(fields), '\n\t'.join(fields),
_STATE_MAP[state]
).compile() ).compile()
) )