implemented packet encryption

This commit is contained in:
əlemi 2021-11-10 18:56:54 +01:00
parent 8e68f70e1d
commit 0a688c251e
2 changed files with 174 additions and 84 deletions

View file

@ -1,124 +1,176 @@
import io
import asyncio
import zlib
import logging
from asyncio import StreamReader, StreamWriter, Queue, Task
from enum import Enum
from typing import Dict, Optional
from cryptography.hazmat.primitives.ciphers import CipherContext
from .mc import proto
from .mc.mctypes import VarInt
from .mc.packet import Packet
from .mc import encryption
logger = logging.getLogger(__name__)
class ConnectionState(Enum):
NONE = -1
HANDSHAKING = 0
STATUS = 1
LOGIN = 2
PLAY = 3
def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Packet]]:
if state == ConnectionState.HANDSHAKING:
return proto.handshaking.clientbound.REGISTRY
if state == ConnectionState.STATUS:
return proto.status.clientbound.REGISTRY
if state == ConnectionState.LOGIN:
return proto.login.clientbound.REGISTRY
if state == ConnectionState.PLAY:
return proto.play.clientbound.REGISTRY
class InvalidState(Exception):
pass
async def read_varint(stream: asyncio.StreamReader) -> int:
"""Utility method to read a VarInt off the socket, because len comes as a VarInt..."""
buf = 0
off = 0
while True:
byte = int.from_bytes(await stream.read(1), 'little')
buf |= (byte & 0b01111111) >> (7*off)
if not byte & 0b10000000:
break
off += 1
return buf
class ConnectionError(Exception):
pass
_STATE_REGS = {
ConnectionState.HANDSHAKING : proto.handshaking,
ConnectionState.STATUS : proto.status,
ConnectionState.LOGIN : proto.login,
ConnectionState.PLAY : proto.play,
ConnectionState.HANDSHAKING : proto.handshaking.clientbound.REGISTRY,
ConnectionState.STATUS : proto.status.clientbound.REGISTRY,
ConnectionState.LOGIN : proto.login.clientbound.REGISTRY,
ConnectionState.PLAY : proto.play.clientbound.REGISTRY,
}
class Dispatcher:
_down : StreamReader
_up : StreamWriter
_reader : Task
_decryptor : CipherContext
_up : StreamWriter
_writer : Task
_encryptor : CipherContext
_dispatching : bool
incoming : Queue
outgoing : Queue
incoming : Queue[Packet]
outgoing : Queue[Packet]
host : str
port : int
proto : int
connected : bool
state : ConnectionState
encryption : bool
compression : Optional[int]
host : str
port : int
def __init__(self, host:str, port:int):
self.host = host
self.port = port
self.proto = 340
self.connected = False
self._dispatching = False
self.compression = None
self.encryption = False
self.incoming = Queue()
self.outgoing = Queue()
async def write(self, packet:Packet, wait:bool=False) -> int:
await self.outgoing.put(packet)
if wait:
await packet.sent.wait()
return self.outgoing.qsize()
async def start(self):
if self.connected:
raise InvalidState("Dispatcher already connected")
await self.connect()
async def stop(self, block:bool=True):
self._dispatching = False
if block:
await asyncio.gather(self._writer, self._reader)
async def connect(self):
self._down, self._up = await asyncio.open_connection(
host=self.host,
port=self.port,
)
self.state = ConnectionState.HANDSHAKING
self.encryption = False
self.compression = None
self.connected = True
packet_handshake = proto.handshaking.serverbound.PacketSetProtocol(
self.proto,
protocolVersion=self.proto,
serverHost=self.host,
serverPost=self.port,
nextState=3, # play
)
packet_login = proto.login.serverbound.PacketLoginStart(340, username=self.username)
await self.outgoing.put(packet_handshake)
await self.outgoing.put(packet_login)
self.state = ConnectionState.HANDSHAKING
self._dispatching = True
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
self._writer = asyncio.get_event_loop().create_task(self._up_worker())
async def encrypt(self, secret:bytes):
cipher = encryption.create_AES_cipher(secret)
self._encryptor = cipher.encryptor()
self._decryptor = cipher.decryptor()
self.encryption = True
async def _read_varint(self) -> int:
numRead = 0
result = 0
while True:
data = await self._down.readexactly(1)
if len(data) < 1:
raise ConnectionError("Could not read data off socket")
if self.encryption:
data = self._decryptor.update(data)
buf = int.from_bytes(data, 'little')
result |= (buf & 0b01111111) << (7 * numRead)
numRead +=1
if numRead > 5:
raise ValueError("VarInt is too big")
if buf & 0b10000000 == 0:
break
return result
async def _down_worker(self):
while self._dispatching:
length = await read_varint(self._down)
buffer = io.BytesIO(await self._down.read(length))
if self.state != ConnectionState.PLAY:
await self.incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression
try: # these 2 will timeout or raise EOFError if client gets disconnected
length = await self._read_varint()
data = await self._down.readexactly(length)
# TODO encryption
if self.encryption:
data = self._decryptor.update(data)
buffer = io.BytesIO(data)
if self.compression is not None:
decompressed_size = VarInt.read(buffer)
# logger.info("Decompressing packet to %d | %s", decompressed_size, buffer.getvalue())
if decompressed_size > 0:
decompressor = zlib.decompressobj()
decompressed_data = decompressor.decompress(buffer.read())
# logger.info("Obtained %s", decompressed_data)
if len(decompressed_data) != decompressed_size:
raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
buffer = io.BytesIO(decompressed_data)
packet_id = VarInt.read(buffer)
cls = _registry_from_state(self.state)[self.proto][packet_id]
await self.incoming.put(cls.deserialize(self.proto, buffer))
# logger.info("Parsing packet '%d' [%s] | %s", packet_id, str(self.state), buffer.getvalue())
cls = _STATE_REGS[self.state][self.proto][packet_id]
packet = cls.deserialize(self.proto, buffer)
await self.incoming.put(packet)
except AttributeError:
logger.debug("Received unimplemented packet %d", packet_id)
except (ConnectionError, asyncio.IncompleteReadError):
logger.exception("Connection error")
await self.stop(block=False)
except Exception:
logger.exception("Error while processing packet %d | %s", packet_id, buffer.getvalue())
async def _up_worker(self):
while self._dispatching:
packet = await self.outgoing.get()
# logger.info("Up packet")
try:
packet = await asyncio.wait_for(self.outgoing.get(), timeout=5)
buffer = packet.serialize()
length = len(buffer)
# logger.info("Sending packet %s [%s]", str(packet), buffer.getvalue())
length = len(buffer.getvalue()) # ewww TODO
if self.compression is not None:
if length > self.compression:
@ -129,21 +181,20 @@ class Dispatcher:
else:
new_buffer = io.BytesIO()
VarInt.write(0, new_buffer)
new_buffer.write(buffer.read)
new_buffer.write(buffer.read())
buffer = new_buffer
length = len(buffer)
length = len(buffer.getvalue())
# TODO encryption
data = VarInt.serialize(length) + buffer.getvalue()
if self.encryption:
data = self._encryptor.update(data)
await self._up.write(VarInt.serialize(length) + buffer)
async def run(self):
if self.connected:
raise InvalidState("Dispatcher already connected")
await self.connect()
async def stop(self):
self._dispatching = False
await asyncio.gather(self._writer, self._reader)
self._up.write(data)
await self._up.drain()
packet.sent.set() # Notify
except asyncio.TimeoutError:
pass # need this to recheck self._dispatching periodically
except Exception:
logger.exception("Error while sending packet %s", str(packet))

39
aiocraft/mc/encryption.py Normal file
View file

@ -0,0 +1,39 @@
"""Minecraft encryption utilities. These are mostly pasted and edited from https://github.com/ammaraskar/pyCraft"""
# TODO read more about this, improve implementation if possible
import os
from hashlib import sha1
from asyncio import StreamWriter, StreamReader
from dataclasses import dataclass
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.serialization import load_der_public_key
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
def generate_shared_secret() -> bytes:
return os.urandom(16)
def create_AES_cipher(shared_secret:bytes) -> Cipher:
return Cipher(algorithms.AES(shared_secret), modes.CFB8(shared_secret),
backend=default_backend())
def encrypt_token_and_secret(pubkey, verification_token, shared_secret:bytes):
pubkey = load_der_public_key(pubkey, default_backend())
encrypted_token = pubkey.encrypt(verification_token, PKCS1v15())
encrypted_secret = pubkey.encrypt(shared_secret, PKCS1v15())
return encrypted_token, encrypted_secret
def generate_verification_hash(server_id, shared_secret, public_key) -> str:
verification_hash = sha1()
verification_hash.update(server_id.encode('utf-8'))
verification_hash.update(shared_secret)
verification_hash.update(public_key)
return minecraft_sha1_hash_digest(verification_hash)
def minecraft_sha1_hash_digest(sha1_hash) -> str:
return format(int.from_bytes(sha1_hash.digest(), byteorder='big', signed=True), 'x')