implemented packet serialization with BytesIO
This commit is contained in:
parent
e4a1a374e4
commit
b09646deb1
3 changed files with 104 additions and 47 deletions
|
@ -5,6 +5,18 @@ from enum import Enum
|
||||||
class InvalidState(Exception):
|
class InvalidState(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def read_varint(stream: asyncio.StreamReader) -> int:
|
||||||
|
"""Utility method to read a VarInt off the socket, because len comes as a VarInt..."""
|
||||||
|
buf = 0
|
||||||
|
off = 0
|
||||||
|
while True:
|
||||||
|
byte = int.from_bytes(await stream.read(1), 'little')
|
||||||
|
buf |= (byte & 0b01111111) >> (7*off)
|
||||||
|
if not byte & 0b10000000:
|
||||||
|
break
|
||||||
|
off += 1
|
||||||
|
return buf
|
||||||
|
|
||||||
class Dispatcher:
|
class Dispatcher:
|
||||||
_down : StreamReader
|
_down : StreamReader
|
||||||
_up : StreamWriter
|
_up : StreamWriter
|
||||||
|
@ -32,7 +44,7 @@ class Dispatcher:
|
||||||
|
|
||||||
async def _down_worker(self):
|
async def _down_worker(self):
|
||||||
while self._dispatching:
|
while self._dispatching:
|
||||||
length = await VarInt.read(self._down)
|
length = await read_varint(self._down)
|
||||||
buffer = await self._down.read(length)
|
buffer = await self._down.read(length)
|
||||||
# TODO encryption
|
# TODO encryption
|
||||||
# TODO compression
|
# TODO compression
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import io
|
||||||
import struct
|
import struct
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
@ -5,127 +6,145 @@ from typing import Any
|
||||||
|
|
||||||
class Type(object):
|
class Type(object):
|
||||||
_pytype : type
|
_pytype : type
|
||||||
|
_size : int
|
||||||
_fmt : str
|
_fmt : str
|
||||||
|
|
||||||
@classmethod
|
# These methods will work only for fixed size data, if _size and _fmt are defined.
|
||||||
def serialize(cls, data:Any) -> bytes:
|
# For anything variabile in size, define custom read() and write() classmethods
|
||||||
return struct.pack(cls._fmt, data)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def deserialize(cls, data:bytes) -> Any:
|
def write(cls, data:Any, buffer:io.BytesIO):
|
||||||
return struct.unpack(cls._fmt, data)[0]
|
buffer.write(struct.pack(cls._fmt, data))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def read(cls, buffer:io.BytesIO) -> Any:
|
||||||
|
return struct.unpack(cls._fmt, buffer.read(cls._size))[0]
|
||||||
|
|
||||||
class Boolean(Type):
|
class Boolean(Type):
|
||||||
_pytype : type = bool
|
_pytype : type = bool
|
||||||
|
_size : int = 1
|
||||||
_fmt : str = ">?"
|
_fmt : str = ">?"
|
||||||
|
|
||||||
class Byte(Type):
|
class Byte(Type):
|
||||||
_pytype : type = int
|
_pytype : type = int
|
||||||
|
_size : int = 1
|
||||||
_fmt : str = ">b"
|
_fmt : str = ">b"
|
||||||
|
|
||||||
class UnsignedByte(Type):
|
class UnsignedByte(Type):
|
||||||
_pytype : type = int
|
_pytype : type = int
|
||||||
|
_size : int = 1
|
||||||
_fmt : str = ">B"
|
_fmt : str = ">B"
|
||||||
|
|
||||||
class Short(Type):
|
class Short(Type):
|
||||||
_pytype : type = int
|
_pytype : type = int
|
||||||
|
_size : int = 2
|
||||||
_fmt : str = ">h"
|
_fmt : str = ">h"
|
||||||
|
|
||||||
class UnsignedShort(Type):
|
class UnsignedShort(Type):
|
||||||
_pytype : type = int
|
_pytype : type = int
|
||||||
|
_size : int = 2
|
||||||
_fmt : str = ">H"
|
_fmt : str = ">H"
|
||||||
|
|
||||||
class Int(Type):
|
class Int(Type):
|
||||||
_pytype : type = int
|
_pytype : type = int
|
||||||
|
_size : int = 4
|
||||||
_fmt : str = ">i"
|
_fmt : str = ">i"
|
||||||
|
|
||||||
class UnsignedInt(Type):
|
class UnsignedInt(Type):
|
||||||
_pytype : type = int
|
_pytype : type = int
|
||||||
|
_size : int = 4
|
||||||
_fmt : str = ">I"
|
_fmt : str = ">I"
|
||||||
|
|
||||||
class Long(Type):
|
class Long(Type):
|
||||||
_pytype : type = int
|
_pytype : type = int
|
||||||
|
_size : int = 8
|
||||||
_fmt : str = ">q"
|
_fmt : str = ">q"
|
||||||
|
|
||||||
class UnsignedLong(Type):
|
class UnsignedLong(Type):
|
||||||
_pytype : type = int
|
_pytype : type = int
|
||||||
|
_size : int = 8
|
||||||
_fmt : str = ">Q"
|
_fmt : str = ">Q"
|
||||||
|
|
||||||
class Float(Type):
|
class Float(Type):
|
||||||
_pytype : type = float
|
_pytype : type = float
|
||||||
|
_size : int = 4
|
||||||
_fmt : str = ">f"
|
_fmt : str = ">f"
|
||||||
|
|
||||||
class Double(Type):
|
class Double(Type):
|
||||||
_pytype : type = float
|
_pytype : type = float
|
||||||
|
_size : int = 8
|
||||||
_fmt : str = ">d"
|
_fmt : str = ">d"
|
||||||
|
|
||||||
class VarInt(Type):
|
class VarInt(Type):
|
||||||
_pytype : type = int
|
_pytype : type = int
|
||||||
_maxBytes = 5
|
_size = 5
|
||||||
|
|
||||||
|
# @classmethod
|
||||||
|
# async def read(cls, stream: asyncio.StreamReader) -> int:
|
||||||
|
# """Utility method to read a VarInt off the socket, because len comes as a VarInt..."""
|
||||||
|
# buf = 0
|
||||||
|
# off = 0
|
||||||
|
# while True:
|
||||||
|
# byte = await stream.read(1)
|
||||||
|
# buf |= (byte & 0b01111111) >> (7*off)
|
||||||
|
# if not byte & 0b10000000:
|
||||||
|
# break
|
||||||
|
# off += 1
|
||||||
|
# return buf
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def read(cls, stream: asyncio.StreamReader) -> int:
|
def write(cls, data:int, buffer:io.BytesIO):
|
||||||
"""Utility method to read a VarInt off the socket, because len comes as a VarInt..."""
|
|
||||||
buf = 0
|
|
||||||
off = 0
|
|
||||||
while True:
|
|
||||||
byte = await stream.read(1)
|
|
||||||
buf |= (byte & 0b01111111) >> (7*off)
|
|
||||||
if not byte & 0b10000000:
|
|
||||||
break
|
|
||||||
off += 1
|
|
||||||
return buf
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def serialize(cls, data:int) -> bytes:
|
|
||||||
res : bytearray = bytearray()
|
|
||||||
count = 0
|
count = 0
|
||||||
while True:
|
while True:
|
||||||
if count >= cls._maxBytes:
|
|
||||||
break
|
|
||||||
buf = data >> (7*count)
|
buf = data >> (7*count)
|
||||||
val = (buf & 0b01111111)
|
val = (buf & 0b01111111)
|
||||||
if (buf & 0b0000000) != 0:
|
if (buf & 0b0000000) != 0:
|
||||||
val |= 0b1000000
|
val |= 0b1000000
|
||||||
res.extend(val.to_bytes(1, 'little'))
|
buffer.write(val.to_bytes(1, 'little'))
|
||||||
|
count += 1
|
||||||
|
if count >= cls._size:
|
||||||
|
break
|
||||||
if (buf & 0b0000000) == 0:
|
if (buf & 0b0000000) == 0:
|
||||||
break
|
break
|
||||||
count += 1
|
|
||||||
return res
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def unserialize(cls, data:bytes) -> int:
|
def read(cls, buffer:io.BytesIO) -> int:
|
||||||
numRead = 0
|
numRead = 0
|
||||||
result = 0
|
result = 0
|
||||||
pos = 0
|
|
||||||
while True:
|
while True:
|
||||||
buf = data[0]
|
buf = int.from_bytes(buffer.read(1), 'little')
|
||||||
value = buf & 0b01111111
|
result |= (buf & 0b01111111) << (7 * numRead)
|
||||||
result |= value << (7 * numRead)
|
|
||||||
numRead +=1
|
numRead +=1
|
||||||
if numRead > cls._maxBytes:
|
if numRead > cls._size:
|
||||||
raise ValueError("VarInt is too big")
|
raise ValueError("VarInt is too big")
|
||||||
if buf & 0b10000000 == 0:
|
if buf & 0b10000000 == 0:
|
||||||
break
|
break
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def serialize(cls, data:int) -> bytes:
|
||||||
|
buf = io.BytesIO()
|
||||||
|
cls.write(data, buf)
|
||||||
|
buf.seek(0)
|
||||||
|
return buf.read()
|
||||||
|
|
||||||
class VarLong(VarInt):
|
class VarLong(VarInt):
|
||||||
_pytype : type = int
|
_pytype : type = int
|
||||||
_maxBytes = 10
|
_size = 10
|
||||||
|
|
||||||
class String(Type):
|
class String(Type):
|
||||||
_pytype : type = str
|
_pytype : type = str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def serialize(cls, data:str) -> bytes:
|
def write(cls, data:str, buffer:io.BytesIO):
|
||||||
encoded = data.encode('utf-8')
|
encoded = data.encode('utf-8')
|
||||||
return VarInt.serialize(len(encoded)) + struct.pack(f">{len(encoded)}s", encoded)
|
VarInt.write(len(encoded), buffer)
|
||||||
|
buffer.write(struct.pack(f">{len(encoded)}s", encoded))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def unserialize(cls, data:bytes) -> str:
|
def read(cls, buffer:io.BytesIO) -> str:
|
||||||
length = VarInt.unserialize(data)
|
length = VarInt.read(buffer)
|
||||||
start_index = len(data) - length
|
return struct.unpack(f">{length}s", buffer.read(length))[0]
|
||||||
return struct.unpack(f">{length}s", data[start_index:])[0]
|
|
||||||
|
|
||||||
class Chat(String):
|
class Chat(String):
|
||||||
_pytype : type = str
|
_pytype : type = str
|
||||||
|
@ -169,3 +188,14 @@ class UUID(Type):
|
||||||
_pytype : type = bytes
|
_pytype : type = bytes
|
||||||
# TODO
|
# TODO
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class TrailingByteArray(Type):
|
||||||
|
_pytype : type = bytes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def write(cls, data:bytes, buffer:io.BytesIO):
|
||||||
|
buffer.write(data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def read(cls, buffer:io.BytesIO) -> bytes:
|
||||||
|
return buffer.read()
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
from typing import Tuple, Dict
|
from typing import Tuple, Dict
|
||||||
|
|
||||||
|
@ -18,14 +19,28 @@ class Packet:
|
||||||
setattr(self, name, t._pytype(kwargs[name]) if name in kwargs else None)
|
setattr(self, name, t._pytype(kwargs[name]) if name in kwargs else None)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def deserialize(cls, proto:int, data:bytes):
|
def deserialize(cls, proto:int, buffer:io.BytesIO):
|
||||||
return cls(proto, **{ name : t.deserialize(data) for (name, t) in cls._slots[proto] })
|
pid = VarInt.read(buffer)
|
||||||
|
return cls(proto, **{ name : t.read(buffer) for (name, t) in cls._slots[proto] })
|
||||||
|
|
||||||
|
def serialize(self) -> io.BytesIO:
|
||||||
|
buf = io.BytesIO()
|
||||||
|
VarInt.write(self.id, buf)
|
||||||
|
for name, t in self.slots:
|
||||||
|
t.write(getattr(self, name, None), buf)
|
||||||
|
buf.seek(0)
|
||||||
|
return buf
|
||||||
|
|
||||||
|
def __eq__(self, other) -> bool:
|
||||||
|
if not isinstance(other, self.__class__):
|
||||||
|
return False
|
||||||
|
if self._protocol != other._protocol:
|
||||||
|
return False
|
||||||
|
for name, t in self.slots:
|
||||||
|
if getattr(self, name) != getattr(other, name):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def serialize(self) -> bytes:
|
|
||||||
return VarInt.serialize(self.id) + b''.join(
|
|
||||||
slot[1].serialize(getattr(self, slot[0], None))
|
|
||||||
for slot in self.slots
|
|
||||||
)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
obj = {} # could be done with dict comp but the _ key gets put last :(
|
obj = {} # could be done with dict comp but the _ key gets put last :(
|
||||||
|
|
Loading…
Reference in a new issue