diff --git a/aiocraft/mc/types.py b/aiocraft/mc/types.py index b1c0f0d..f6ffe71 100644 --- a/aiocraft/mc/types.py +++ b/aiocraft/mc/types.py @@ -3,8 +3,12 @@ import struct import asyncio import uuid +import logging +import pynbt + from typing import List, Tuple, Dict, Any, Union, Optional, Type as Class + class Type(object): pytype : type @@ -44,8 +48,6 @@ class UnimplementedDataType(Type): TrailingData = UnimplementedDataType() EntityMetadata = UnimplementedDataType() EntityMetadataItem = UnimplementedDataType() -NBTTag = UnimplementedDataType() -Slot = UnimplementedDataType() class PrimitiveType(Type): size : int @@ -75,6 +77,21 @@ Float = PrimitiveType(float, ">f", 4) Double = PrimitiveType(float, ">d", 8) Angle = PrimitiveType(int, ">b", 1) +class NBTType(Type): + + def write(self, data:pynbt.NBTFile, buffer:io.BytesIO, ctx:object=None): + data.save(buffer) + + def read(self, buffer:io.BytesIO, ctx:object=None) -> Optional[pynbt.NBTFile]: + head = Byte.read(buffer) + if head == 0x0: + return None + buffer.seek(-1,1) # go back 1 byte + return pynbt.NBTFile(io=buffer) + +NBTTag = NBTType() +# NBTTag = TrailingData + class VarLenPrimitive(Type): pytype : type = int max_bytes : int @@ -159,6 +176,13 @@ class BufferType(Type): ByteArray = BufferType() IntegerByteArray = BufferType(Int) +def twos_comp(val, bits): + """compute the 2's complement of int value val""" + if (val & (1 << (bits - 1))) != 0: # if sign bit is set e.g., 8bit: 128-255 + val = val - (1 << bits) # compute negative value + return val # return positive value as is + + class PositionType(Type): pytype : type = tuple MAX_SIZE : int = 8 @@ -166,14 +190,16 @@ class PositionType(Type): # TODO THIS IS FOR 1.12.2!!! Make a generic version-less? def write(self, data:tuple, buffer:io.BytesIO, ctx:object=None): - packed = ((0x3FFFFFF & data[0]) << 38) | ((0xFFF & data[1]) << 26) | (0x3FFFFFF & data[2]) + packed = ((0x3FFFFFF & data[0]) << 38) \ + | ((0xFFF & data[1]) << 12) \ + | (0x3FFFFFF & data[2]) UnsignedLong.write(packed, buffer, ctx=ctx) def read(self, buffer:io.BytesIO, ctx:object=None) -> tuple: packed = UnsignedLong.read(buffer) - x = packed >> 38 - y = (packed >> 24) & 0xFFF - z = packed & 0x3FFFFFF + x = twos_comp(packed >> 38, 26) + y = (packed >> 26) & 0xFFF + z = twos_comp(packed & 0x3FFFFFF, 26) return (x, y, z) Position = PositionType() @@ -209,7 +235,10 @@ class ArrayType(Type): def read(self, buffer:io.BytesIO, ctx:object=None) -> List[Any]: length = self.counter if isinstance(self.counter, int) else self.counter.read(buffer, ctx=ctx) - return [ self.content.read(buffer, ctx=ctx) for _ in range(length) ] + out = [] + for _ in range(length): + out.append(self.content.read(buffer, ctx=ctx)) + return out class OptionalType(Type): t : Type @@ -238,15 +267,15 @@ class SwitchType(Type): self.default = default def write(self, data:Any, buffer:io.BytesIO, ctx:object=None): - watched = getattr(ctx, self.field) - if watched in self.mappings: + watched = getattr(ctx, self.field, None) + if watched is not None and watched in self.mappings: return self.mappings[watched].write(data, buffer, ctx=ctx) elif self.default: return self.default.write(data, buffer, ctx=ctx) def read(self, buffer:io.BytesIO, ctx:object=None) -> Optional[Any]: - watched = getattr(ctx, self.field) - if watched in self.mappings: + watched = getattr(ctx, self.field, None) + if watched is not None and watched in self.mappings: return self.mappings[watched].read(buffer, ctx=ctx) elif self.default: return self.default.read(buffer, ctx=ctx) @@ -265,5 +294,41 @@ class StructType(Type): def read(self, buffer:io.BytesIO, ctx:object=None) -> Dict[str, Any]: return { k : t.read(buffer, ctx=ctx) for k, t in self.fields } + +class SlotType(Type): + pytype : type = dict + + def write(self, data:Any, buffer:io.BytesIO, ctx:object=None): + new_way = ctx._proto > 340 + check_type = Boolean if new_way else Short + if data: + check_type.write(True if new_way else data["id"], buffer) + if new_way: + VarInt.write(data["id"], buffer) + Byte.write(data["count"], buffer) + if not new_way: + Short.write(data["damage"], buffer) + NBTTag.write(data["nbt"], buffer) + else: + check_type.write(False if new_way else -1, buffer) + + def read(self, buffer:io.BytesIO, ctx:object=None) -> Any: + slot = {} + new_way = ctx._proto > 340 + check_type = Boolean if new_way else Short + val = check_type.read(buffer) + if (new_way and val) or val != -1: + if new_way: + slot["id"] = VarInt.read(buffer) + else: + slot["id"] = val + slot["count"] = Byte.read(buffer) + if not new_way: + slot["damage"] = Short.read(buffer) + slot["nbt"] = NBTTag.read(buffer) + return slot + +Slot = SlotType() +# Slot = TrailingData diff --git a/requirements.txt b/requirements.txt index 6787c69..9dc0c11 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +pynbt cryptography aiohttp termcolor