implemented packet serialization with BytesIO

This commit is contained in:
əlemi 2021-10-13 01:30:07 +02:00 committed by alemidev
parent e4a1a374e4
commit b09646deb1
3 changed files with 104 additions and 47 deletions

View file

@ -5,6 +5,18 @@ from enum import Enum
class InvalidState(Exception):
pass
async def read_varint(stream: asyncio.StreamReader) -> int:
"""Utility method to read a VarInt off the socket, because len comes as a VarInt..."""
buf = 0
off = 0
while True:
byte = int.from_bytes(await stream.read(1), 'little')
buf |= (byte & 0b01111111) >> (7*off)
if not byte & 0b10000000:
break
off += 1
return buf
class Dispatcher:
_down : StreamReader
_up : StreamWriter
@ -32,7 +44,7 @@ class Dispatcher:
async def _down_worker(self):
while self._dispatching:
length = await VarInt.read(self._down)
length = await read_varint(self._down)
buffer = await self._down.read(length)
# TODO encryption
# TODO compression

View file

@ -1,3 +1,4 @@
import io
import struct
import asyncio
@ -5,127 +6,145 @@ from typing import Any
class Type(object):
_pytype : type
_size : int
_fmt : str
@classmethod
def serialize(cls, data:Any) -> bytes:
return struct.pack(cls._fmt, data)
# These methods will work only for fixed size data, if _size and _fmt are defined.
# For anything variabile in size, define custom read() and write() classmethods
@classmethod
def deserialize(cls, data:bytes) -> Any:
return struct.unpack(cls._fmt, data)[0]
def write(cls, data:Any, buffer:io.BytesIO):
buffer.write(struct.pack(cls._fmt, data))
@classmethod
def read(cls, buffer:io.BytesIO) -> Any:
return struct.unpack(cls._fmt, buffer.read(cls._size))[0]
class Boolean(Type):
_pytype : type = bool
_size : int = 1
_fmt : str = ">?"
class Byte(Type):
_pytype : type = int
_size : int = 1
_fmt : str = ">b"
class UnsignedByte(Type):
_pytype : type = int
_size : int = 1
_fmt : str = ">B"
class Short(Type):
_pytype : type = int
_size : int = 2
_fmt : str = ">h"
class UnsignedShort(Type):
_pytype : type = int
_size : int = 2
_fmt : str = ">H"
class Int(Type):
_pytype : type = int
_size : int = 4
_fmt : str = ">i"
class UnsignedInt(Type):
_pytype : type = int
_size : int = 4
_fmt : str = ">I"
class Long(Type):
_pytype : type = int
_size : int = 8
_fmt : str = ">q"
class UnsignedLong(Type):
_pytype : type = int
_size : int = 8
_fmt : str = ">Q"
class Float(Type):
_pytype : type = float
_size : int = 4
_fmt : str = ">f"
class Double(Type):
_pytype : type = float
_size : int = 8
_fmt : str = ">d"
class VarInt(Type):
_pytype : type = int
_maxBytes = 5
_size = 5
# @classmethod
# async def read(cls, stream: asyncio.StreamReader) -> int:
# """Utility method to read a VarInt off the socket, because len comes as a VarInt..."""
# buf = 0
# off = 0
# while True:
# byte = await stream.read(1)
# buf |= (byte & 0b01111111) >> (7*off)
# if not byte & 0b10000000:
# break
# off += 1
# return buf
@classmethod
async def read(cls, stream: asyncio.StreamReader) -> int:
"""Utility method to read a VarInt off the socket, because len comes as a VarInt..."""
buf = 0
off = 0
while True:
byte = await stream.read(1)
buf |= (byte & 0b01111111) >> (7*off)
if not byte & 0b10000000:
break
off += 1
return buf
@classmethod
def serialize(cls, data:int) -> bytes:
res : bytearray = bytearray()
def write(cls, data:int, buffer:io.BytesIO):
count = 0
while True:
if count >= cls._maxBytes:
break
buf = data >> (7*count)
val = (buf & 0b01111111)
if (buf & 0b0000000) != 0:
val |= 0b1000000
res.extend(val.to_bytes(1, 'little'))
buffer.write(val.to_bytes(1, 'little'))
count += 1
if count >= cls._size:
break
if (buf & 0b0000000) == 0:
break
count += 1
return res
@classmethod
def unserialize(cls, data:bytes) -> int:
def read(cls, buffer:io.BytesIO) -> int:
numRead = 0
result = 0
pos = 0
while True:
buf = data[0]
value = buf & 0b01111111
result |= value << (7 * numRead)
buf = int.from_bytes(buffer.read(1), 'little')
result |= (buf & 0b01111111) << (7 * numRead)
numRead +=1
if numRead > cls._maxBytes:
if numRead > cls._size:
raise ValueError("VarInt is too big")
if buf & 0b10000000 == 0:
break
return result
@classmethod
def serialize(cls, data:int) -> bytes:
buf = io.BytesIO()
cls.write(data, buf)
buf.seek(0)
return buf.read()
class VarLong(VarInt):
_pytype : type = int
_maxBytes = 10
_size = 10
class String(Type):
_pytype : type = str
@classmethod
def serialize(cls, data:str) -> bytes:
def write(cls, data:str, buffer:io.BytesIO):
encoded = data.encode('utf-8')
return VarInt.serialize(len(encoded)) + struct.pack(f">{len(encoded)}s", encoded)
VarInt.write(len(encoded), buffer)
buffer.write(struct.pack(f">{len(encoded)}s", encoded))
@classmethod
def unserialize(cls, data:bytes) -> str:
length = VarInt.unserialize(data)
start_index = len(data) - length
return struct.unpack(f">{length}s", data[start_index:])[0]
def read(cls, buffer:io.BytesIO) -> str:
length = VarInt.read(buffer)
return struct.unpack(f">{length}s", buffer.read(length))[0]
class Chat(String):
_pytype : type = str
@ -169,3 +188,14 @@ class UUID(Type):
_pytype : type = bytes
# TODO
pass
class TrailingByteArray(Type):
_pytype : type = bytes
@classmethod
def write(cls, data:bytes, buffer:io.BytesIO):
buffer.write(data)
@classmethod
def read(cls, buffer:io.BytesIO) -> bytes:
return buffer.read()

View file

@ -1,3 +1,4 @@
import io
import json
from typing import Tuple, Dict
@ -18,14 +19,28 @@ class Packet:
setattr(self, name, t._pytype(kwargs[name]) if name in kwargs else None)
@classmethod
def deserialize(cls, proto:int, data:bytes):
return cls(proto, **{ name : t.deserialize(data) for (name, t) in cls._slots[proto] })
def deserialize(cls, proto:int, buffer:io.BytesIO):
pid = VarInt.read(buffer)
return cls(proto, **{ name : t.read(buffer) for (name, t) in cls._slots[proto] })
def serialize(self) -> io.BytesIO:
buf = io.BytesIO()
VarInt.write(self.id, buf)
for name, t in self.slots:
t.write(getattr(self, name, None), buf)
buf.seek(0)
return buf
def __eq__(self, other) -> bool:
if not isinstance(other, self.__class__):
return False
if self._protocol != other._protocol:
return False
for name, t in self.slots:
if getattr(self, name) != getattr(other, name):
return False
return True
def serialize(self) -> bytes:
return VarInt.serialize(self.id) + b''.join(
slot[1].serialize(getattr(self, slot[0], None))
for slot in self.slots
)
def __str__(self) -> str:
obj = {} # could be done with dict comp but the _ key gets put last :(