implemented packet encryption

This commit is contained in:
əlemi 2021-11-10 18:56:54 +01:00
parent 8e68f70e1d
commit 0a688c251e
2 changed files with 174 additions and 84 deletions

View file

@ -1,149 +1,200 @@
import io
import asyncio import asyncio
import zlib import zlib
import logging
from asyncio import StreamReader, StreamWriter, Queue, Task from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum from enum import Enum
from typing import Dict, Optional
from cryptography.hazmat.primitives.ciphers import CipherContext
from .mc import proto from .mc import proto
from .mc.mctypes import VarInt from .mc.mctypes import VarInt
from .mc.packet import Packet
from .mc import encryption
logger = logging.getLogger(__name__)
class ConnectionState(Enum): class ConnectionState(Enum):
NONE = -1
HANDSHAKING = 0 HANDSHAKING = 0
STATUS = 1 STATUS = 1
LOGIN = 2 LOGIN = 2
PLAY = 3 PLAY = 3
def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Packet]]:
if state == ConnectionState.HANDSHAKING:
return proto.handshaking.clientbound.REGISTRY
if state == ConnectionState.STATUS:
return proto.status.clientbound.REGISTRY
if state == ConnectionState.LOGIN:
return proto.login.clientbound.REGISTRY
if state == ConnectionState.PLAY:
return proto.play.clientbound.REGISTRY
class InvalidState(Exception): class InvalidState(Exception):
pass pass
async def read_varint(stream: asyncio.StreamReader) -> int: class ConnectionError(Exception):
"""Utility method to read a VarInt off the socket, because len comes as a VarInt...""" pass
buf = 0
off = 0
while True:
byte = int.from_bytes(await stream.read(1), 'little')
buf |= (byte & 0b01111111) >> (7*off)
if not byte & 0b10000000:
break
off += 1
return buf
_STATE_REGS = { _STATE_REGS = {
ConnectionState.HANDSHAKING : proto.handshaking, ConnectionState.HANDSHAKING : proto.handshaking.clientbound.REGISTRY,
ConnectionState.STATUS : proto.status, ConnectionState.STATUS : proto.status.clientbound.REGISTRY,
ConnectionState.LOGIN : proto.login, ConnectionState.LOGIN : proto.login.clientbound.REGISTRY,
ConnectionState.PLAY : proto.play, ConnectionState.PLAY : proto.play.clientbound.REGISTRY,
} }
class Dispatcher: class Dispatcher:
_down : StreamReader _down : StreamReader
_up : StreamWriter
_reader : Task _reader : Task
_decryptor : CipherContext
_up : StreamWriter
_writer : Task _writer : Task
_encryptor : CipherContext
_dispatching : bool _dispatching : bool
incoming : Queue incoming : Queue[Packet]
outgoing : Queue outgoing : Queue[Packet]
host : str
port : int
proto : int
connected : bool connected : bool
state : ConnectionState state : ConnectionState
encryption : bool encryption : bool
compression : Optional[int] compression : Optional[int]
host : str
port : int
def __init__(self, host:str, port:int): def __init__(self, host:str, port:int):
self.host = host self.host = host
self.port = port self.port = port
self.proto = 340
self.connected = False self.connected = False
self._dispatching = False self._dispatching = False
self.compression = None
self.encryption = False
self.incoming = Queue()
self.outgoing = Queue()
async def write(self, packet:Packet, wait:bool=False) -> int:
await self.outgoing.put(packet)
if wait:
await packet.sent.wait()
return self.outgoing.qsize()
async def start(self):
if self.connected:
raise InvalidState("Dispatcher already connected")
await self.connect()
async def stop(self, block:bool=True):
self._dispatching = False
if block:
await asyncio.gather(self._writer, self._reader)
async def connect(self): async def connect(self):
self._down, self._up = await asyncio.open_connection( self._down, self._up = await asyncio.open_connection(
host=self.host, host=self.host,
port=self.port, port=self.port,
) )
self.state = ConnectionState.HANDSHAKING
self.encryption = False self.encryption = False
self.compression = None self.compression = None
self.connected = True self.connected = True
self.state = ConnectionState.HANDSHAKING
packet_handshake = proto.handshaking.serverbound.PacketSetProtocol(
self.proto,
protocolVersion=self.proto,
serverHost=self.host,
serverPost=self.port,
nextState=3, # play
)
packet_login = proto.login.serverbound.PacketLoginStart(340, username=self.username)
await self.outgoing.put(packet_handshake)
await self.outgoing.put(packet_login)
self._dispatching = True self._dispatching = True
self._reader = asyncio.get_event_loop().create_task(self._down_worker()) self._reader = asyncio.get_event_loop().create_task(self._down_worker())
self._writer = asyncio.get_event_loop().create_task(self._up_worker()) self._writer = asyncio.get_event_loop().create_task(self._up_worker())
async def encrypt(self, secret:bytes):
cipher = encryption.create_AES_cipher(secret)
self._encryptor = cipher.encryptor()
self._decryptor = cipher.decryptor()
self.encryption = True
async def _read_varint(self) -> int:
numRead = 0
result = 0
while True:
data = await self._down.readexactly(1)
if len(data) < 1:
raise ConnectionError("Could not read data off socket")
if self.encryption:
data = self._decryptor.update(data)
buf = int.from_bytes(data, 'little')
result |= (buf & 0b01111111) << (7 * numRead)
numRead +=1
if numRead > 5:
raise ValueError("VarInt is too big")
if buf & 0b10000000 == 0:
break
return result
async def _down_worker(self): async def _down_worker(self):
while self._dispatching: while self._dispatching:
length = await read_varint(self._down) if self.state != ConnectionState.PLAY:
buffer = io.BytesIO(await self._down.read(length)) await self.incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression
try: # these 2 will timeout or raise EOFError if client gets disconnected
length = await self._read_varint()
data = await self._down.readexactly(length)
# TODO encryption if self.encryption:
data = self._decryptor.update(data)
if self.compression is not None: buffer = io.BytesIO(data)
decompressed_size = VarInt.read(buffer)
if decompressed_size > 0:
decompressor = zlib.decompressobj()
decompressed_data = decompressor.decompress(buffer.read())
if len(decompressed_data) != decompressed_size:
raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
buffer = io.BytesIO(decompressed_data)
packet_id = VarInt.read(buffer) if self.compression is not None:
cls = _registry_from_state(self.state)[self.proto][packet_id] decompressed_size = VarInt.read(buffer)
await self.incoming.put(cls.deserialize(self.proto, buffer)) # logger.info("Decompressing packet to %d | %s", decompressed_size, buffer.getvalue())
if decompressed_size > 0:
decompressor = zlib.decompressobj()
decompressed_data = decompressor.decompress(buffer.read())
# logger.info("Obtained %s", decompressed_data)
if len(decompressed_data) != decompressed_size:
raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
buffer = io.BytesIO(decompressed_data)
packet_id = VarInt.read(buffer)
# logger.info("Parsing packet '%d' [%s] | %s", packet_id, str(self.state), buffer.getvalue())
cls = _STATE_REGS[self.state][self.proto][packet_id]
packet = cls.deserialize(self.proto, buffer)
await self.incoming.put(packet)
except AttributeError:
logger.debug("Received unimplemented packet %d", packet_id)
except (ConnectionError, asyncio.IncompleteReadError):
logger.exception("Connection error")
await self.stop(block=False)
except Exception:
logger.exception("Error while processing packet %d | %s", packet_id, buffer.getvalue())
async def _up_worker(self): async def _up_worker(self):
while self._dispatching: while self._dispatching:
packet = await self.outgoing.get() # logger.info("Up packet")
buffer = packet.serialize() try:
length = len(buffer) packet = await asyncio.wait_for(self.outgoing.get(), timeout=5)
buffer = packet.serialize()
# logger.info("Sending packet %s [%s]", str(packet), buffer.getvalue())
length = len(buffer.getvalue()) # ewww TODO
if self.compression is not None: if self.compression is not None:
if length > self.compression: if length > self.compression:
new_buffer = io.BytesIO() new_buffer = io.BytesIO()
VarInt.write(length, new_buffer) VarInt.write(length, new_buffer)
new_buffer.write(zlib.compress(buffer.read())) new_buffer.write(zlib.compress(buffer.read()))
buffer = new_buffer buffer = new_buffer
else: else:
new_buffer = io.BytesIO() new_buffer = io.BytesIO()
VarInt.write(0, new_buffer) VarInt.write(0, new_buffer)
new_buffer.write(buffer.read) new_buffer.write(buffer.read())
buffer = new_buffer buffer = new_buffer
length = len(buffer) length = len(buffer.getvalue())
# TODO encryption data = VarInt.serialize(length) + buffer.getvalue()
if self.encryption:
data = self._encryptor.update(data)
await self._up.write(VarInt.serialize(length) + buffer) self._up.write(data)
await self._up.drain()
async def run(self):
if self.connected:
raise InvalidState("Dispatcher already connected")
await self.connect()
async def stop(self):
self._dispatching = False
await asyncio.gather(self._writer, self._reader)
packet.sent.set() # Notify
except asyncio.TimeoutError:
pass # need this to recheck self._dispatching periodically
except Exception:
logger.exception("Error while sending packet %s", str(packet))

39
aiocraft/mc/encryption.py Normal file
View file

@ -0,0 +1,39 @@
"""Minecraft encryption utilities. These are mostly pasted and edited from https://github.com/ammaraskar/pyCraft"""
# TODO read more about this, improve implementation if possible
import os
from hashlib import sha1
from asyncio import StreamWriter, StreamReader
from dataclasses import dataclass
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.serialization import load_der_public_key
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
def generate_shared_secret() -> bytes:
return os.urandom(16)
def create_AES_cipher(shared_secret:bytes) -> Cipher:
return Cipher(algorithms.AES(shared_secret), modes.CFB8(shared_secret),
backend=default_backend())
def encrypt_token_and_secret(pubkey, verification_token, shared_secret:bytes):
pubkey = load_der_public_key(pubkey, default_backend())
encrypted_token = pubkey.encrypt(verification_token, PKCS1v15())
encrypted_secret = pubkey.encrypt(shared_secret, PKCS1v15())
return encrypted_token, encrypted_secret
def generate_verification_hash(server_id, shared_secret, public_key) -> str:
verification_hash = sha1()
verification_hash.update(server_id.encode('utf-8'))
verification_hash.update(shared_secret)
verification_hash.update(public_key)
return minecraft_sha1_hash_digest(verification_hash)
def minecraft_sha1_hash_digest(sha1_hash) -> str:
return format(int.from_bytes(sha1_hash.digest(), byteorder='big', signed=True), 'x')