implemented packet encryption
This commit is contained in:
parent
8e68f70e1d
commit
0a688c251e
2 changed files with 174 additions and 84 deletions
|
@ -1,149 +1,200 @@
|
||||||
|
import io
|
||||||
import asyncio
|
import asyncio
|
||||||
import zlib
|
import zlib
|
||||||
|
import logging
|
||||||
from asyncio import StreamReader, StreamWriter, Queue, Task
|
from asyncio import StreamReader, StreamWriter, Queue, Task
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from cryptography.hazmat.primitives.ciphers import CipherContext
|
||||||
|
|
||||||
from .mc import proto
|
from .mc import proto
|
||||||
from .mc.mctypes import VarInt
|
from .mc.mctypes import VarInt
|
||||||
|
from .mc.packet import Packet
|
||||||
|
from .mc import encryption
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class ConnectionState(Enum):
|
class ConnectionState(Enum):
|
||||||
|
NONE = -1
|
||||||
HANDSHAKING = 0
|
HANDSHAKING = 0
|
||||||
STATUS = 1
|
STATUS = 1
|
||||||
LOGIN = 2
|
LOGIN = 2
|
||||||
PLAY = 3
|
PLAY = 3
|
||||||
|
|
||||||
def _registry_from_state(state:ConnectionState) -> Dict[int, Dict[int, Packet]]:
|
|
||||||
if state == ConnectionState.HANDSHAKING:
|
|
||||||
return proto.handshaking.clientbound.REGISTRY
|
|
||||||
if state == ConnectionState.STATUS:
|
|
||||||
return proto.status.clientbound.REGISTRY
|
|
||||||
if state == ConnectionState.LOGIN:
|
|
||||||
return proto.login.clientbound.REGISTRY
|
|
||||||
if state == ConnectionState.PLAY:
|
|
||||||
return proto.play.clientbound.REGISTRY
|
|
||||||
|
|
||||||
class InvalidState(Exception):
|
class InvalidState(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def read_varint(stream: asyncio.StreamReader) -> int:
|
class ConnectionError(Exception):
|
||||||
"""Utility method to read a VarInt off the socket, because len comes as a VarInt..."""
|
pass
|
||||||
buf = 0
|
|
||||||
off = 0
|
|
||||||
while True:
|
|
||||||
byte = int.from_bytes(await stream.read(1), 'little')
|
|
||||||
buf |= (byte & 0b01111111) >> (7*off)
|
|
||||||
if not byte & 0b10000000:
|
|
||||||
break
|
|
||||||
off += 1
|
|
||||||
return buf
|
|
||||||
|
|
||||||
_STATE_REGS = {
|
_STATE_REGS = {
|
||||||
ConnectionState.HANDSHAKING : proto.handshaking,
|
ConnectionState.HANDSHAKING : proto.handshaking.clientbound.REGISTRY,
|
||||||
ConnectionState.STATUS : proto.status,
|
ConnectionState.STATUS : proto.status.clientbound.REGISTRY,
|
||||||
ConnectionState.LOGIN : proto.login,
|
ConnectionState.LOGIN : proto.login.clientbound.REGISTRY,
|
||||||
ConnectionState.PLAY : proto.play,
|
ConnectionState.PLAY : proto.play.clientbound.REGISTRY,
|
||||||
}
|
}
|
||||||
|
|
||||||
class Dispatcher:
|
class Dispatcher:
|
||||||
_down : StreamReader
|
_down : StreamReader
|
||||||
_up : StreamWriter
|
|
||||||
_reader : Task
|
_reader : Task
|
||||||
|
_decryptor : CipherContext
|
||||||
|
|
||||||
|
_up : StreamWriter
|
||||||
_writer : Task
|
_writer : Task
|
||||||
|
_encryptor : CipherContext
|
||||||
|
|
||||||
_dispatching : bool
|
_dispatching : bool
|
||||||
|
|
||||||
incoming : Queue
|
incoming : Queue[Packet]
|
||||||
outgoing : Queue
|
outgoing : Queue[Packet]
|
||||||
|
|
||||||
|
host : str
|
||||||
|
port : int
|
||||||
|
|
||||||
|
proto : int
|
||||||
connected : bool
|
connected : bool
|
||||||
state : ConnectionState
|
state : ConnectionState
|
||||||
encryption : bool
|
encryption : bool
|
||||||
compression : Optional[int]
|
compression : Optional[int]
|
||||||
|
|
||||||
host : str
|
|
||||||
port : int
|
|
||||||
|
|
||||||
def __init__(self, host:str, port:int):
|
def __init__(self, host:str, port:int):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
|
||||||
|
self.proto = 340
|
||||||
self.connected = False
|
self.connected = False
|
||||||
self._dispatching = False
|
self._dispatching = False
|
||||||
|
self.compression = None
|
||||||
|
self.encryption = False
|
||||||
|
|
||||||
|
self.incoming = Queue()
|
||||||
|
self.outgoing = Queue()
|
||||||
|
|
||||||
|
async def write(self, packet:Packet, wait:bool=False) -> int:
|
||||||
|
await self.outgoing.put(packet)
|
||||||
|
if wait:
|
||||||
|
await packet.sent.wait()
|
||||||
|
return self.outgoing.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self.connected:
|
||||||
|
raise InvalidState("Dispatcher already connected")
|
||||||
|
await self.connect()
|
||||||
|
|
||||||
|
async def stop(self, block:bool=True):
|
||||||
|
self._dispatching = False
|
||||||
|
if block:
|
||||||
|
await asyncio.gather(self._writer, self._reader)
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
self._down, self._up = await asyncio.open_connection(
|
self._down, self._up = await asyncio.open_connection(
|
||||||
host=self.host,
|
host=self.host,
|
||||||
port=self.port,
|
port=self.port,
|
||||||
)
|
)
|
||||||
self.state = ConnectionState.HANDSHAKING
|
|
||||||
self.encryption = False
|
self.encryption = False
|
||||||
self.compression = None
|
self.compression = None
|
||||||
self.connected = True
|
self.connected = True
|
||||||
|
self.state = ConnectionState.HANDSHAKING
|
||||||
packet_handshake = proto.handshaking.serverbound.PacketSetProtocol(
|
|
||||||
self.proto,
|
|
||||||
protocolVersion=self.proto,
|
|
||||||
serverHost=self.host,
|
|
||||||
serverPost=self.port,
|
|
||||||
nextState=3, # play
|
|
||||||
)
|
|
||||||
packet_login = proto.login.serverbound.PacketLoginStart(340, username=self.username)
|
|
||||||
|
|
||||||
await self.outgoing.put(packet_handshake)
|
|
||||||
await self.outgoing.put(packet_login)
|
|
||||||
|
|
||||||
self._dispatching = True
|
self._dispatching = True
|
||||||
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
|
self._reader = asyncio.get_event_loop().create_task(self._down_worker())
|
||||||
self._writer = asyncio.get_event_loop().create_task(self._up_worker())
|
self._writer = asyncio.get_event_loop().create_task(self._up_worker())
|
||||||
|
|
||||||
|
async def encrypt(self, secret:bytes):
|
||||||
|
cipher = encryption.create_AES_cipher(secret)
|
||||||
|
self._encryptor = cipher.encryptor()
|
||||||
|
self._decryptor = cipher.decryptor()
|
||||||
|
self.encryption = True
|
||||||
|
|
||||||
|
|
||||||
|
async def _read_varint(self) -> int:
|
||||||
|
numRead = 0
|
||||||
|
result = 0
|
||||||
|
while True:
|
||||||
|
data = await self._down.readexactly(1)
|
||||||
|
if len(data) < 1:
|
||||||
|
raise ConnectionError("Could not read data off socket")
|
||||||
|
if self.encryption:
|
||||||
|
data = self._decryptor.update(data)
|
||||||
|
buf = int.from_bytes(data, 'little')
|
||||||
|
result |= (buf & 0b01111111) << (7 * numRead)
|
||||||
|
numRead +=1
|
||||||
|
if numRead > 5:
|
||||||
|
raise ValueError("VarInt is too big")
|
||||||
|
if buf & 0b10000000 == 0:
|
||||||
|
break
|
||||||
|
return result
|
||||||
|
|
||||||
async def _down_worker(self):
|
async def _down_worker(self):
|
||||||
while self._dispatching:
|
while self._dispatching:
|
||||||
length = await read_varint(self._down)
|
if self.state != ConnectionState.PLAY:
|
||||||
buffer = io.BytesIO(await self._down.read(length))
|
await self.incoming.join() # During login we cannot pre-process any packet, first need to maybe get encryption/compression
|
||||||
|
try: # these 2 will timeout or raise EOFError if client gets disconnected
|
||||||
|
length = await self._read_varint()
|
||||||
|
data = await self._down.readexactly(length)
|
||||||
|
|
||||||
# TODO encryption
|
if self.encryption:
|
||||||
|
data = self._decryptor.update(data)
|
||||||
|
|
||||||
if self.compression is not None:
|
buffer = io.BytesIO(data)
|
||||||
decompressed_size = VarInt.read(buffer)
|
|
||||||
if decompressed_size > 0:
|
|
||||||
decompressor = zlib.decompressobj()
|
|
||||||
decompressed_data = decompressor.decompress(buffer.read())
|
|
||||||
if len(decompressed_data) != decompressed_size:
|
|
||||||
raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
|
|
||||||
buffer = io.BytesIO(decompressed_data)
|
|
||||||
|
|
||||||
packet_id = VarInt.read(buffer)
|
if self.compression is not None:
|
||||||
cls = _registry_from_state(self.state)[self.proto][packet_id]
|
decompressed_size = VarInt.read(buffer)
|
||||||
await self.incoming.put(cls.deserialize(self.proto, buffer))
|
# logger.info("Decompressing packet to %d | %s", decompressed_size, buffer.getvalue())
|
||||||
|
if decompressed_size > 0:
|
||||||
|
decompressor = zlib.decompressobj()
|
||||||
|
decompressed_data = decompressor.decompress(buffer.read())
|
||||||
|
# logger.info("Obtained %s", decompressed_data)
|
||||||
|
if len(decompressed_data) != decompressed_size:
|
||||||
|
raise ValueError(f"Failed decompressing packet: expected size is {decompressed_size}, but actual size is {len(decompressed_data)}")
|
||||||
|
buffer = io.BytesIO(decompressed_data)
|
||||||
|
|
||||||
|
packet_id = VarInt.read(buffer)
|
||||||
|
# logger.info("Parsing packet '%d' [%s] | %s", packet_id, str(self.state), buffer.getvalue())
|
||||||
|
cls = _STATE_REGS[self.state][self.proto][packet_id]
|
||||||
|
packet = cls.deserialize(self.proto, buffer)
|
||||||
|
await self.incoming.put(packet)
|
||||||
|
except AttributeError:
|
||||||
|
logger.debug("Received unimplemented packet %d", packet_id)
|
||||||
|
except (ConnectionError, asyncio.IncompleteReadError):
|
||||||
|
logger.exception("Connection error")
|
||||||
|
await self.stop(block=False)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error while processing packet %d | %s", packet_id, buffer.getvalue())
|
||||||
|
|
||||||
async def _up_worker(self):
|
async def _up_worker(self):
|
||||||
while self._dispatching:
|
while self._dispatching:
|
||||||
packet = await self.outgoing.get()
|
# logger.info("Up packet")
|
||||||
buffer = packet.serialize()
|
try:
|
||||||
length = len(buffer)
|
packet = await asyncio.wait_for(self.outgoing.get(), timeout=5)
|
||||||
|
buffer = packet.serialize()
|
||||||
|
# logger.info("Sending packet %s [%s]", str(packet), buffer.getvalue())
|
||||||
|
length = len(buffer.getvalue()) # ewww TODO
|
||||||
|
|
||||||
if self.compression is not None:
|
if self.compression is not None:
|
||||||
if length > self.compression:
|
if length > self.compression:
|
||||||
new_buffer = io.BytesIO()
|
new_buffer = io.BytesIO()
|
||||||
VarInt.write(length, new_buffer)
|
VarInt.write(length, new_buffer)
|
||||||
new_buffer.write(zlib.compress(buffer.read()))
|
new_buffer.write(zlib.compress(buffer.read()))
|
||||||
buffer = new_buffer
|
buffer = new_buffer
|
||||||
else:
|
else:
|
||||||
new_buffer = io.BytesIO()
|
new_buffer = io.BytesIO()
|
||||||
VarInt.write(0, new_buffer)
|
VarInt.write(0, new_buffer)
|
||||||
new_buffer.write(buffer.read)
|
new_buffer.write(buffer.read())
|
||||||
buffer = new_buffer
|
buffer = new_buffer
|
||||||
length = len(buffer)
|
length = len(buffer.getvalue())
|
||||||
|
|
||||||
# TODO encryption
|
data = VarInt.serialize(length) + buffer.getvalue()
|
||||||
|
if self.encryption:
|
||||||
|
data = self._encryptor.update(data)
|
||||||
|
|
||||||
await self._up.write(VarInt.serialize(length) + buffer)
|
self._up.write(data)
|
||||||
|
await self._up.drain()
|
||||||
async def run(self):
|
|
||||||
if self.connected:
|
|
||||||
raise InvalidState("Dispatcher already connected")
|
|
||||||
await self.connect()
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
self._dispatching = False
|
|
||||||
await asyncio.gather(self._writer, self._reader)
|
|
||||||
|
|
||||||
|
packet.sent.set() # Notify
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass # need this to recheck self._dispatching periodically
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error while sending packet %s", str(packet))
|
||||||
|
|
||||||
|
|
39
aiocraft/mc/encryption.py
Normal file
39
aiocraft/mc/encryption.py
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
"""Minecraft encryption utilities. These are mostly pasted and edited from https://github.com/ammaraskar/pyCraft"""
|
||||||
|
# TODO read more about this, improve implementation if possible
|
||||||
|
import os
|
||||||
|
from hashlib import sha1
|
||||||
|
from asyncio import StreamWriter, StreamReader
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
|
||||||
|
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
||||||
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
|
||||||
|
|
||||||
|
def generate_shared_secret() -> bytes:
|
||||||
|
return os.urandom(16)
|
||||||
|
|
||||||
|
|
||||||
|
def create_AES_cipher(shared_secret:bytes) -> Cipher:
|
||||||
|
return Cipher(algorithms.AES(shared_secret), modes.CFB8(shared_secret),
|
||||||
|
backend=default_backend())
|
||||||
|
|
||||||
|
def encrypt_token_and_secret(pubkey, verification_token, shared_secret:bytes):
|
||||||
|
pubkey = load_der_public_key(pubkey, default_backend())
|
||||||
|
|
||||||
|
encrypted_token = pubkey.encrypt(verification_token, PKCS1v15())
|
||||||
|
encrypted_secret = pubkey.encrypt(shared_secret, PKCS1v15())
|
||||||
|
return encrypted_token, encrypted_secret
|
||||||
|
|
||||||
|
def generate_verification_hash(server_id, shared_secret, public_key) -> str:
|
||||||
|
verification_hash = sha1()
|
||||||
|
|
||||||
|
verification_hash.update(server_id.encode('utf-8'))
|
||||||
|
verification_hash.update(shared_secret)
|
||||||
|
verification_hash.update(public_key)
|
||||||
|
|
||||||
|
return minecraft_sha1_hash_digest(verification_hash)
|
||||||
|
|
||||||
|
def minecraft_sha1_hash_digest(sha1_hash) -> str:
|
||||||
|
return format(int.from_bytes(sha1_hash.digest(), byteorder='big', signed=True), 'x')
|
Loading…
Reference in a new issue