From d6740b4bd32057fd99071f77591c9cd387d749f7 Mon Sep 17 00:00:00 2001 From: git-bruh Date: Sat, 17 Apr 2021 10:15:51 +0530 Subject: [PATCH] Appservice (#4) --- appservice/README.md | 66 ++++ appservice/appservice.py | 212 ++++++++++++ appservice/db.py | 132 ++++++++ appservice/discord.py | 175 ++++++++++ appservice/errors.py | 5 + appservice/gateway.py | 299 ++++++++++++++++ appservice/main.py | 656 ++++++++++++++++++++++++++++++++++++ appservice/matrix.py | 26 ++ appservice/misc.py | 89 +++++ appservice/requirements.txt | 3 + 10 files changed, 1663 insertions(+) create mode 100644 appservice/README.md create mode 100644 appservice/appservice.py create mode 100644 appservice/db.py create mode 100644 appservice/discord.py create mode 100644 appservice/errors.py create mode 100644 appservice/gateway.py create mode 100644 appservice/main.py create mode 100644 appservice/matrix.py create mode 100644 appservice/misc.py create mode 100644 appservice/requirements.txt diff --git a/appservice/README.md b/appservice/README.md new file mode 100644 index 0000000..1b3616b --- /dev/null +++ b/appservice/README.md @@ -0,0 +1,66 @@ +## Installation + +`pip install -r requirements.txt` + +## Usage + +* Run `main.py` to generate `appservice.json` + +* Edit `appservice.json`: + +``` +{ + "as_token": "my-secret-as-token", + "hs_token": "my-secret-hs-token", + "user_id": "appservice-discord", + # Homeserver running on the same machine, listening on port 8008. + "homeserver": "http://127.0.0.1:8008", + # Change "localhost" to your server_name. + # Eg. "kde.org" is the server_name in "@testuser:kde.org". + "server_name": "localhost", + "discord_token": "my-secret-discord-token", + "port": 5000, # Port to run the bottle app on. + "database": "/path/to/bridge.db" +} +``` + +* Create `appservice.yaml` and add it to your homeserver configuration: + +``` +id: "discord" +url: "http://127.0.0.1:5000" +as_token: "my-secret-as-token" +hs_token: "my-secret-hs-token" +sender_localpart: "appservice-discord" +namespaces: + users: + - exclusive: true + regex: "@_discord.*" + # Work around for temporary bug in dendrite. + - regex: "@appservice-discord" + aliases: + - exclusive: false + regex: "#_discord.*" + rooms: [] +``` + +A path can optionally be passed as the first argument to `main.py`. This path will be used as the base directory for the database and log file. + +Eg. Running `python3 main.py /path/to/my/dir` will store the database and logs in `/path/to/my/dir`. +`$PWD` is used by default if no path is specified. + +This bridge is written with: +* `bottle`: Receiving events from the homeserver. +* `urllib3`: Sending requests, thread safety. +* `websockets`: Connecting to Discord. (Big thanks to an anonymous person "nesslersreagent" for figuring out the initial connection mess.) + +## NOTES +* A basic sqlite database is used for keeping track of bridged rooms. + +* Logs are saved to the `appservice.log` file in `$PWD` or the specified directory. + +* For avatars to show up on Discord, you must have a [reverse proxy](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf) set up on your homeserver as the bridge does not specify the homeserver port when passing the avatar url. + +* It is not possible to add normal Discord bot functionality like commands as this bridge does not use `discord.py`. + +NOTE: [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) must be enabled for your Discord bot. diff --git a/appservice/appservice.py b/appservice/appservice.py new file mode 100644 index 0000000..654c365 --- /dev/null +++ b/appservice/appservice.py @@ -0,0 +1,212 @@ +import json +import logging +import urllib.parse +import uuid +from typing import List, Union + +import bottle +import urllib3 + +import matrix +from misc import dict_cls, log_except, request + + +class AppService(bottle.Bottle): + def __init__(self, config: dict, http: urllib3.PoolManager) -> None: + super(AppService, self).__init__() + + self.as_token = config["as_token"] + self.hs_token = config["hs_token"] + self.base_url = config["homeserver"] + self.server_name = config["server_name"] + self.user_id = f"@{config['user_id']}:{self.server_name}" + self.http = http + self.logger = logging.getLogger("appservice") + + # TODO better method. + # Map events to functions. + self.mapping = { + "m.room.member": "on_member", + "m.room.message": "on_message", + "m.room.redaction": "on_redaction", + } + + # Add route for bottle. + self.route( + "/transactions/", + callback=self.receive_event, + method="PUT", + ) + + def handle_event(self, event: dict) -> None: + event_type = event.get("type") + + if event_type == "m.room.member" or event_type == "m.room.message": + obj = self.get_event_object(event) + elif event_type == "m.room.redaction": + obj = event + else: + self.logger.info(f"Unknown event type: {event_type}") + return + + func = getattr(self, self.mapping[event_type], None) + + if not func: + self.logger.warning( + f"Function '{func}' not defined, ignoring event." + ) + return + + # We don't catch exceptions here as the homeserver will re-send us + # the event in case of a failure. + func(obj) + + @log_except + def receive_event(self, transaction: str) -> dict: + """ + Verify the homeserver's token and handle events. + """ + + hs_token = bottle.request.query.getone("access_token") + + if not hs_token: + bottle.response.status = 401 + return {"errcode": "APPSERVICE_UNAUTHORIZED"} + + if hs_token != self.hs_token: + bottle.response.status = 403 + return {"errcode": "APPSERVICE_FORBIDDEN"} + + events = bottle.request.json.get("events") + + for event in events: + self.handle_event(event) + + return {} + + def get_event_object(self, event: dict) -> matrix.Event: + event["author"] = dict_cls( + self.get_profile(event["sender"]), matrix.User + ) + + return matrix.Event(event) + + def join_room(self, room_id: str, mxid: str = "") -> None: + self.send( + "POST", + f"/join/{room_id}", + params={"user_id": mxid} if mxid else {}, + ) + + def redact(self, event_id: str, room_id: str, mxid: str = "") -> None: + self.send( + "PUT", + f"/rooms/{room_id}/redact/{event_id}/{uuid.uuid4()}", + params={"user_id": mxid} if mxid else {}, + ) + + def get_profile(self, mxid: str) -> dict: + # TODO handle failure, avoid querying this endpoint repeatedly. + resp = self.send("GET", f"/profile/{mxid}") + + avatar_url = resp.get("avatar_url", "")[6:].split("/") + avatar_url = ( + ( + f"https://{self.server_name}/_matrix/media/r0/download/" + f"{avatar_url[0]}/{avatar_url[1]}" + ) + if len(avatar_url) > 1 + else None + ) + + return { + "avatar_url": avatar_url, + "displayname": resp.get("displayname"), + } + + def get_members(self, room_id: str) -> List[str]: + resp = self.send( + "GET", + f"/rooms/{room_id}/members", + params={"membership": "join", "not_membership": "leave"}, + ) + + return [ + content["sender"] + for content in resp["chunk"] + if content["content"]["membership"] == "join" + ] + + def get_room_id(self, alias: str) -> str: + resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}") + + # TODO cache ? + + return resp["room_id"] + + def upload(self, url: str) -> str: + """ + Upload a file to the homeserver and get the MXC url. + """ + + resp = self.http.request("GET", url) + + resp = self.send( + "POST", + content=resp.data, + content_type=resp.headers.get("Content-Type"), + params={"filename": f"{uuid.uuid4()}"}, + endpoint="/_matrix/media/r0/upload", + ) + + return resp["content_uri"] + + def send_message( + self, + room_id: str, + content: dict, + mxid: str = "", + ) -> str: + resp = self.send( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{uuid.uuid4()}", + content, + {"user_id": mxid} if mxid else {}, + ) + + return resp["event_id"] + + def send_typing( + self, room_id: str, mxid: str = "", timeout: int = 8000 + ) -> None: + self.send( + "PUT", + f"/rooms/{room_id}/typing/{mxid}", + {"typing": True, "timeout": timeout}, + {"user_id": mxid} if mxid else {}, + ) + + def send_invite(self, room_id: str, mxid: str) -> None: + self.send("POST", f"/rooms/{room_id}/invite", {"user_id": mxid}) + + @request + def send( + self, + method: str, + path: str = "", + content: Union[bytes, dict] = {}, + params: dict = {}, + content_type: str = "application/json", + endpoint: str = "/_matrix/client/r0", + ) -> dict: + params["access_token"] = self.as_token + headers = {"Content-Type": content_type} + content = json.dumps(content) if isinstance(content, dict) else content + endpoint = ( + f"{self.base_url}{endpoint}{path}?" + f"{urllib.parse.urlencode(params)}" + ) + + return self.http.request( + method, endpoint, body=content, headers=headers + ) diff --git a/appservice/db.py b/appservice/db.py new file mode 100644 index 0000000..1598260 --- /dev/null +++ b/appservice/db.py @@ -0,0 +1,132 @@ +import os +import sqlite3 +import threading +from typing import List + + +class DataBase(object): + def __init__(self, db_file) -> None: + self.create(db_file) + + # The database is accessed via both the threads. + self.lock = threading.Lock() + + def create(self, db_file) -> None: + """ + Create a database with the relevant tables if it doesn't already exist. + """ + + exists = os.path.exists(db_file) + + self.conn = sqlite3.connect(db_file, check_same_thread=False) + self.conn.row_factory = self.dict_factory + + self.cur = self.conn.cursor() + + if exists: + return + + self.cur.execute( + "CREATE TABLE bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);" + ) + + self.cur.execute( + "CREATE TABLE users(mxid TEXT PRIMARY KEY, " + "avatar_url TEXT, username TEXT);" + ) + + self.conn.commit() + + def dict_factory(self, cursor, row): + """ + https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory + """ + + d = {} + for idx, col in enumerate(cursor.description): + d[col[0]] = row[idx] + return d + + def add_room(self, room_id: str, channel_id: str) -> None: + """ + Add a bridged room to the database. + """ + + with self.lock: + self.cur.execute( + "INSERT INTO bridge (room_id, channel_id) " + f"VALUES ('{room_id}', '{channel_id}')" + ) + self.conn.commit() + + def add_user(self, mxid: str) -> None: + with self.lock: + self.cur.execute(f"INSERT INTO users (mxid) VALUES ('{mxid}')") + self.conn.commit() + + def add_avatar(self, avatar_url: str, mxid: str) -> None: + with self.lock: + self.cur.execute( + f"UPDATE users SET avatar_url = '{avatar_url}'" + f"WHERE mxid = '{mxid}'" + ) + self.conn.commit() + + def add_username(self, username: str, mxid: str) -> None: + with self.lock: + self.cur.execute( + f"UPDATE users SET username = '{username}'" + f"WHERE mxid = '{mxid}'" + ) + self.conn.commit() + + def get_channel(self, room_id: str) -> str: + """ + Get the corresponding channel ID for a given room ID. + """ + + with self.lock: + self.cur.execute( + "SELECT channel_id FROM bridge WHERE room_id = ?", [room_id] + ) + + room = self.cur.fetchone() + + # Return an empty string if nothing is bridged. + return "" if not room else room["channel_id"] + + def list_channels(self) -> List[str]: + """ + Get a list of all the bridged channels. + """ + + with self.lock: + self.cur.execute("SELECT channel_id FROM bridge") + + channels = self.cur.fetchall() + + return [channel["channel_id"] for channel in channels] + + def list_users(self) -> List[dict]: + """ + Get a dictionary of all the puppeted users. + """ + + with self.lock: + self.cur.execute("SELECT * FROM users") + + users = self.cur.fetchall() + + return users + + def query_user(self, mxid: str) -> bool: + """ + Check whether a puppet user has already been created for a given mxid. + """ + + with self.lock: + self.cur.execute("SELECT mxid FROM users") + + users = self.cur.fetchall() + + return next((True for user in users if user["mxid"] == mxid), False) diff --git a/appservice/discord.py b/appservice/discord.py new file mode 100644 index 0000000..3567416 --- /dev/null +++ b/appservice/discord.py @@ -0,0 +1,175 @@ +from dataclasses import dataclass + +CDN_URL = "https://cdn.discordapp.com" + + +@dataclass +class Channel(object): + id: str + type: str + guild_id: str = "" + name: str = "" + topic: str = "" + + +@dataclass +class Emote(object): + animated: bool + id: str + name: str + + +class User(object): + def __init__(self, user: dict) -> None: + self.discriminator = user["discriminator"] + self.id = user["id"] + self.mention = f"<@{self.id}>" + self.username = user["username"] + + avatar = user["avatar"] + + if not avatar: + # https://discord.com/developers/docs/reference#image-formatting + self.avatar_url = ( + f"{CDN_URL}/embed/avatars/{int(self.discriminator) % 5}.png" + ) + else: + ext = "gif" if avatar.startswith("a_") else "png" + self.avatar_url = f"{CDN_URL}/avatars/{self.id}/{avatar}.{ext}" + + +class Message(object): + def __init__(self, message: dict) -> None: + self.attachments = message.get("attachments", []) + self.channel_id = message["channel_id"] + self.content = message.get("content", "") + self.id = message["id"] + self.reference = message.get("message_reference", {}).get( + "message_id", "" + ) + self.webhook_id = message.get("webhook_id", "") + + self.mentions = [ + User(mention) for mention in message.get("mentions", []) + ] + + author = message.get("author") + + self.author = User(author) if author else None + + +@dataclass +class DeletedMessage(object): + channel_id: str + id: str + + +@dataclass +class Typing(object): + user_id: str + channel_id: str + + +@dataclass +class Webhook(object): + id: str + token: str + + +class ChannelType(object): + GUILD_TEXT = 0 + DM = 1 + GUILD_VOICE = 2 + GROUP_DM = 3 + GUILD_CATEGORY = 4 + GUILD_NEWS = 5 + GUILD_STORE = 6 + + +class InteractionResponseType(object): + PONG = 0 + ACKNOWLEDGE = 1 + CHANNEL_MESSAGE = 2 + CHANNEL_MESSAGE_WITH_SOURCE = 4 + ACKNOWLEDGE_WITH_SOURCE = 5 + + +class GatewayIntents(object): + def bitmask(bit: int) -> int: + return 1 << bit + + GUILDS = bitmask(0) + GUILD_MEMBERS = bitmask(1) + GUILD_BANS = bitmask(2) + GUILD_EMOJIS = bitmask(3) + GUILD_INTEGRATIONS = bitmask(4) + GUILD_WEBHOOKS = bitmask(5) + GUILD_INVITES = bitmask(6) + GUILD_VOICE_STATES = bitmask(7) + GUILD_PRESENCES = bitmask(8) + GUILD_MESSAGES = bitmask(9) + GUILD_MESSAGE_REACTIONS = bitmask(10) + GUILD_MESSAGE_TYPING = bitmask(11) + DIRECT_MESSAGES = bitmask(12) + DIRECT_MESSAGE_REACTIONS = bitmask(13) + DIRECT_MESSAGE_TYPING = bitmask(14) + + +class GatewayOpCodes(object): + DISPATCH = 0 + HEARTBEAT = 1 + IDENTIFY = 2 + PRESENCE_UPDATE = 3 + VOICE_STATE_UPDATE = 4 + RESUME = 6 + RECONNECT = 7 + REQUEST_GUILD_MEMBERS = 8 + INVALID_SESSION = 9 + HELLO = 10 + HEARTBEAT_ACK = 11 + + +class Payloads(object): + def __init__(self, token: str) -> None: + self.seq = self.session = None + self.token = token + + def HEARTBEAT(self) -> dict: + return {"op": GatewayOpCodes.HEARTBEAT, "d": self.seq} + + def IDENTIFY(self) -> dict: + return { + "op": GatewayOpCodes.IDENTIFY, + "d": { + "token": self.token, + "intents": GatewayIntents.GUILDS + | GatewayIntents.GUILD_MESSAGES + | GatewayIntents.GUILD_MESSAGE_TYPING, + "properties": { + "$os": "discord", + "$browser": "discord", + "$device": "discord", + }, + }, + } + + def QUERY(self, guild_id: str, query: str, limit: int = 1) -> dict: + """ + Return the Payload to query a member from a guild ID. + Return only a single match if `limit` isn't specified. + """ + + return { + "op": GatewayOpCodes.REQUEST_GUILD_MEMBERS, + "d": {"guild_id": guild_id, "query": query, "limit": limit}, + } + + def RESUME(self) -> dict: + return { + "op": GatewayOpCodes.RESUME, + "d": { + "token": self.token, + "session_id": self.session, + "seq": self.seq, + }, + } diff --git a/appservice/errors.py b/appservice/errors.py new file mode 100644 index 0000000..4200f50 --- /dev/null +++ b/appservice/errors.py @@ -0,0 +1,5 @@ +class RequestError(Exception): + def __init__(self, status: int, *args): + super().__init__(*args) + + self.status = status diff --git a/appservice/gateway.py b/appservice/gateway.py new file mode 100644 index 0000000..36421ee --- /dev/null +++ b/appservice/gateway.py @@ -0,0 +1,299 @@ +import asyncio +import json +import logging +import urllib.parse +from typing import List + +import urllib3 +import websockets + +import discord +from misc import dict_cls, log_except, request, wrap_async + + +class Gateway(object): + def __init__(self, http: urllib3.PoolManager, token: str): + self.http = http + self.token = token + self.logger = logging.getLogger("discord") + self.cdn_url = "https://cdn.discordapp.com" + self.Payloads = discord.Payloads(self.token) + self.loop = self.websocket = None + + self.query_cache = {} + + @log_except + async def run(self) -> None: + self.loop = asyncio.get_running_loop() + self.query_ev = asyncio.Event() + + self.heartbeat_task = None + self.resume = False + + while True: + try: + await self.gateway_handler(self.get_gateway_url()) + except websockets.ConnectionClosedError: + # TODO reconnect ? + self.logger.exception("Quitting, connection lost.") + break + + # Stop sending heartbeats until we reconnect. + if self.heartbeat_task and not self.heartbeat_task.cancelled(): + self.heartbeat_task.cancel() + + def get_gateway_url(self) -> str: + resp = self.send("GET", "/gateway") + + return resp["url"] + + async def heartbeat_handler(self, interval_ms: int) -> None: + while True: + await asyncio.sleep(interval_ms / 1000) + await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT())) + + def query_handler(self, data: dict) -> None: + members = data["members"] + guild_id = data["guild_id"] + + for member in members: + user = member["user"] + self.query_cache[guild_id].append(user) + + self.query_ev.set() + + def handle_otype(self, data: dict, otype: str) -> None: + if data.get("embeds"): + return # TODO embeds + + if otype == "MESSAGE_CREATE" or otype == "MESSAGE_UPDATE": + obj = discord.Message(data) + elif otype == "MESSAGE_DELETE": + obj = dict_cls(data, discord.DeletedMessage) + elif otype == "TYPING_START": + obj = dict_cls(data, discord.Typing) + elif otype == "GUILD_MEMBERS_CHUNK": + self.query_handler(data) + return + else: + self.logger.info(f"Unknown OTYPE: {otype}") + return + + func = getattr(self, f"on_{otype.lower()}", None) + + if not func: + self.logger.warning( + f"Function '{func}' not defined, ignoring message." + ) + return + + try: + func(obj) + except Exception: + self.logger.exception(f"Ignoring exception in {func}:") + + async def gateway_handler(self, gateway_url: str) -> None: + async with websockets.connect( + f"{gateway_url}/?v=8&encoding=json" + ) as websocket: + self.websocket = websocket + async for message in websocket: + data = json.loads(message) + data_dict = data.get("d") + + opcode = data.get("op") + + seq = data.get("s") + if seq: + self.Payloads.seq = seq + + if opcode == discord.GatewayOpCodes.DISPATCH: + otype = data.get("t") + + if otype == "READY": + self.Payloads.session = data_dict["session_id"] + + self.logger.info("READY") + + else: + self.handle_otype(data_dict, otype) + + elif opcode == discord.GatewayOpCodes.HELLO: + heartbeat_interval = data_dict.get("heartbeat_interval") + + self.logger.info( + f"Heartbeat Interval: {heartbeat_interval}" + ) + + # Send periodic hearbeats to gateway. + self.heartbeat_task = asyncio.ensure_future( + self.heartbeat_handler(heartbeat_interval) + ) + + await websocket.send( + json.dumps( + self.Payloads.RESUME() + if self.resume + else self.Payloads.IDENTIFY() + ) + ) + + elif opcode == discord.GatewayOpCodes.RECONNECT: + self.logger.info("Received RECONNECT.") + + self.resume = True + await websocket.close() + + elif opcode == discord.GatewayOpCodes.INVALID_SESSION: + self.logger.info("Received INVALID_SESSION.") + + self.resume = False + await websocket.close() + + elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK: + # NOP + pass + + else: + self.logger.info( + f"Unknown OP code {opcode}:\n" + f"{json.dumps(data, indent=4)}" + ) + + @wrap_async + async def query_member(self, guild_id: str, name: str) -> discord.User: + """ + Query the members for a given guild and return the first match. + """ + + self.query_ev.clear() + + def query(): + if not self.query_cache.get(guild_id): + self.query_cache[guild_id] = [] + + user = [ + user + for user in self.query_cache[guild_id] + if name.lower() in user["username"].lower() + ] + + return None if not user else discord.User(user[0]) + + user = query() + + if user: + return user + + if not self.websocket or self.websocket.closed: + self.logger.warning("Not fetching members, websocket closed.") + return + + await self.websocket.send( + json.dumps(self.Payloads.QUERY(guild_id, name)) + ) + + # Wait for our websocket to receive the chunk. + await asyncio.wait_for(self.query_ev.wait(), timeout=5) + + return query() + + def get_channel(self, channel_id: str) -> discord.Channel: + """ + Get the channel object for a given channel ID. + """ + + resp = self.send("GET", f"/channels/{channel_id}") + + return dict_cls(resp, discord.Channel) + + def get_emotes(self, guild_id: str) -> List[discord.Emote]: + """ + Get all the emotes for a given guild. + """ + + resp = self.send("GET", f"/guilds/{guild_id}/emojis") + + return [dict_cls(emote, discord.Emote) for emote in resp] + + def get_members(self, guild_id: str) -> List[discord.User]: + """ + Get all the members for a given guild. + """ + + resp = self.send( + "GET", f"/guilds/{guild_id}/members", params={"limit": 1000} + ) + + return [discord.User(member["user"]) for member in resp] + + def create_webhook(self, channel_id: str, name: str) -> discord.Webhook: + """ + Create a webhook with the specified name in a given channel. + """ + + resp = self.send( + "POST", f"/channels/{channel_id}/webhooks", {"name": name} + ) + + return dict_cls(resp, discord.Webhook) + + def edit_webhook( + self, content: str, message_id: str, webhook: discord.Webhook + ) -> None: + self.send( + "PATCH", + f"/webhooks/{webhook.id}/{webhook.token}/messages/" + f"{message_id}", + {"content": content}, + ) + + def delete_webhook( + self, message_id: str, webhook: discord.Webhook + ) -> None: + self.send( + "DELETE", + f"/webhooks/{webhook.id}/{webhook.token}/messages/" + f"{message_id}", + ) + + def send_webhook(self, webhook: discord.Webhook, **kwargs) -> str: + content = { + **kwargs, + # Disable 'everyone' and 'role' mentions. + "allowed_mentions": {"parse": ["users"]}, + } + + resp = self.send( + "POST", + f"/webhooks/{webhook.id}/{webhook.token}", + content, + {"wait": True}, + ) + + return resp["id"] + + def send_message(self, message: str, channel_id: str) -> None: + self.send( + "POST", f"/channels/{channel_id}/messages", {"content": message} + ) + + @request + def send( + self, method: str, path: str, content: dict = {}, params: dict = {} + ) -> dict: + endpoint = ( + f"https://discord.com/api/v8{path}?" + f"{urllib.parse.urlencode(params)}" + ) + headers = { + "Authorization": f"Bot {self.token}", + "Content-Type": "application/json", + } + + # 'body' being an empty dict breaks "GET" requests. + content = json.dumps(content) if content else None + + return self.http.request( + method, endpoint, body=content, headers=headers + ) diff --git a/appservice/main.py b/appservice/main.py new file mode 100644 index 0000000..aab6678 --- /dev/null +++ b/appservice/main.py @@ -0,0 +1,656 @@ +import asyncio +import json +import logging +import os +import re +import sys +import threading +from typing import Dict, Tuple, Union + +import urllib3 + +import discord +import matrix +from appservice import AppService +from db import DataBase +from errors import RequestError +from gateway import Gateway +from misc import dict_cls, except_deleted + +# TODO should this be cleared periodically ? +message_cache: Dict[str, Union[discord.Webhook, str]] = {} + + +class MatrixClient(AppService): + def __init__(self, config: dict, http: urllib3.PoolManager) -> None: + super().__init__(config, http) + + self.db = DataBase(config["database"]) + self.discord = DiscordClient(self, config, http) + self.emote_cache: Dict[str, str] = {} + self.format = "_discord_" # "{@,#}_discord_1234:localhost" + + def to_return(self, event: matrix.Event) -> bool: + return event.sender.startswith(("@_discord", self.user_id)) + + def handle_bridge(self, message: matrix.Event) -> None: + # Ignore events that aren't for us. + if message.sender.split(":")[ + -1 + ] != self.server_name or not message.body.startswith("!bridge"): + return + + try: + channel = message.body.split()[1] + except IndexError: + return + + # Check if the given channel is valid. + try: + channel = self.discord.get_channel(channel) + except RequestError as e: + # The channel can be invalid or we may not have permission. + self.logger.warning(f"Failed to fetch channel {channel}: {e}") + return + + if ( + channel.type != discord.ChannelType.GUILD_TEXT + or channel.id in self.db.list_channels() + ): + return + + self.logger.info(f"Creating bridged room for channel {channel.id}.") + + self.create_room(channel, message.sender) + + def on_member(self, event: matrix.Event) -> None: + # Ignore events that aren't for us. + if ( + event.sender.split(":")[-1] != self.server_name + or event.state_key != self.user_id + or not event.is_direct + ): + return + + # Join the direct message room. + self.logger.info(f"Joining direct message room {event.room_id}.") + self.join_room(event.room_id) + + def on_message(self, message: matrix.Event) -> None: + if self.to_return(message): + return + + # Handle bridging commands. + self.handle_bridge(message) + + channel_id = self.db.get_channel(message.room_id) + + if not channel_id: + return + + webhook = self.discord.get_webhook(channel_id, "matrix_bridge") + + if message.relates_to and message.reltype == "m.replace": + + relation = message_cache.get(message.relates_to) + + if not message.new_body or not relation: + return + + message.new_body = self.process_message( + channel_id, message.new_body + ) + + except_deleted(self.discord.edit_webhook)( + message.new_body, relation["message_id"], webhook + ) + + else: + if not message.body: + return + + message.body = self.process_message(channel_id, message.body) + message_cache[message.event_id] = { + "message_id": self.discord.send_webhook( + webhook, + avatar_url=message.author.avatar_url, + content=message.body, + username=message.author.displayname, + ), + "webhook": webhook, + } + + @except_deleted + def on_redaction(self, event: dict) -> None: + redacts = event["redacts"] + + event = message_cache.get(redacts) + + if not event: + return + + self.discord.delete_webhook(event["message_id"], event["webhook"]) + + message_cache.pop(redacts) + + def create_room(self, channel: discord.Channel, sender: str) -> None: + """ + Create a bridged room and invite the person who invoked the command. + """ + + content = { + "room_alias_name": f"{self.format}{channel.id}", + "name": channel.name, + "topic": channel.topic, + "visibility": "private", + "invite": [sender], + "creation_content": {"m.federate": True}, + "initial_state": [ + { + "type": "m.room.join_rules", + "content": {"join_rule": "public"}, + }, + { + "type": "m.room.history_visibility", + "content": {"history_visibility": "shared"}, + }, + ], + "power_level_content_override": {"users": {sender: 100}}, + } + + resp = self.send("POST", "/createRoom", content) + + self.db.add_room(resp["room_id"], channel.id) + + def create_message_event( + self, message: str, emotes: dict, edit: str = "", reply: str = "" + ) -> dict: + content = { + "body": message, + "format": "org.matrix.custom.html", + "msgtype": "m.text", + "formatted_body": self.get_fmt(message, emotes), + } + + event = message_cache.get(reply) + + if event: + content = { + **content, + "m.relates_to": { + "m.in_reply_to": {"event_id": event["event_id"]} + }, + "formatted_body": f"""
\ +\ +In reply to\ +{event["mxid"]}
{event["body"]}
\ +{content["formatted_body"]}""", + } + + if edit: + content = { + **content, + "body": f" * {content['body']}", + "formatted_body": f" * {content['formatted_body']}", + "m.relates_to": {"event_id": edit, "rel_type": "m.replace"}, + "m.new_content": {**content}, + } + + return content + + def get_fmt(self, message: str, emotes: dict) -> str: + replace = [ + # Bold. + ("**", "", ""), + # Code blocks. + ("```", "
", "
"), + # Spoilers. + ("||", "", ""), + # Strikethrough. + ("~~", "", ""), + ] + + for replace_ in replace: + for i in range(1, message.count(replace_[0]) + 1): + if i % 2: + message = message.replace(replace_[0], replace_[1], 1) + else: + message = message.replace(replace_[0], replace_[2], 1) + + # Upload emotes in multiple threads so that we don't + # block the Discord bot for too long. + upload_threads = [ + threading.Thread( + target=self.upload_emote, args=(emote, emotes[emote]) + ) + for emote in emotes + ] + + [thread.start() for thread in upload_threads] + [thread.join() for thread in upload_threads] + + for emote in emotes: + emote_ = self.emote_cache.get(emote) + + if emote_: + emote = f":{emote}:" + message = message.replace( + emote, + f"""\"{emote}\"""", + ) + + return message + + def process_message(self, channel_id: str, message: str) -> str: + message = message[:2000] # Discord limit. + + emotes = re.findall(r":(\w*):", message) + mentions = re.findall(r"(@(\w*))", message) + + # Remove the puppet user's username from replies. + message = re.sub(f"<@{self.format}.+?>", "", message) + + added_emotes = [] + for emote in emotes: + # Don't replace emote names with IDs multiple times. + if emote not in added_emotes: + added_emotes.append(emote) + emote_ = self.discord.emote_cache.get(emote) + if emote_: + message = message.replace(f":{emote}:", emote_) + + # Don't unnecessarily fetch the channel. + if mentions: + guild_id = self.discord.get_channel(channel_id).guild_id + + for mention in mentions: + if not mention[1]: + continue + + try: + member = self.discord.query_member(guild_id, mention[1]) + except (asyncio.TimeoutError, RuntimeError): + continue + + if member: + message = message.replace(mention[0], member.mention) + + return message + + def upload_emote(self, emote_name: str, emote_id: str) -> None: + # There won't be a race condition here, since only a unique + # set of emotes are uploaded at a time. + if emote_name in self.emote_cache: + return + + emote_url = f"{discord.CDN_URL}/emojis/{emote_id}" + + # We don't want the message to be dropped entirely if an emote + # fails to upload for some reason. + try: + self.emote_cache[emote_name] = self.upload(emote_url) + except RequestError as e: + self.logger.warning(f"Failed to upload emote {emote_id}: {e}") + + def register(self, mxid: str) -> None: + """ + Register a dummy user on the homeserver. + """ + + content = { + "type": "m.login.application_service", + # "@test:localhost" -> "test" (Can't register with a full mxid.) + "username": mxid[1:].split(":")[0], + } + + resp = self.send("POST", "/register", content) + + self.db.add_user(resp["user_id"]) + + def set_avatar(self, avatar_url: str, mxid: str) -> None: + avatar_uri = self.upload(avatar_url) + + self.send( + "PUT", + f"/profile/{mxid}/avatar_url", + {"avatar_url": avatar_uri}, + params={"user_id": mxid}, + ) + + self.db.add_avatar(avatar_url, mxid) + + def set_nick(self, username: str, mxid: str) -> None: + self.send( + "PUT", + f"/profile/{mxid}/displayname", + {"displayname": username}, + params={"user_id": mxid}, + ) + + self.db.add_username(username, mxid) + + +class DiscordClient(Gateway): + def __init__( + self, appservice: MatrixClient, config: dict, http: urllib3.PoolManager + ) -> None: + super().__init__(http, config["discord_token"]) + + self.app = appservice + self.emote_cache: Dict[str, str] = {} + self.webhook_cache: Dict[str, discord.Webhook] = {} + + async def sync(self) -> None: + """ + Periodically compare the usernames and avatar URLs with Discord + and update if they differ. Also synchronise emotes. + """ + + def sync_emotes(guilds: set): + # We could store the emotes once and update according + # to gateway events but we're too lazy for that. + emotes = [] + + for guild in guilds: + [emotes.append(emote) for emote in (self.get_emotes(guild))] + + self.emote_cache.clear() # Clears deleted/renamed emotes. + + for emote in emotes: + self.emote_cache[f"{emote.name}"] = ( + f"<{'a' if emote.animated else ''}:" + f"{emote.name}:{emote.id}>" + ) + + def sync_users(guilds: set): + # TODO use websockets for this, using IDs from database. + + users = [] + + for guild in guilds: + [users.append(member) for member in self.get_members(guild)] + + db_users = self.app.db.list_users() + + # Convert a list of dicts: + # [ { "avatar_url": ... } ] + # to a dict that is indexable by Discord IDs: + # { "discord_id": { "avatar_url": ... } } + users_ = {} + + for user in db_users: + users_[user["mxid"].split("_")[-1].split(":")[0]] = {**user} + + for user in users: + user_ = users_.get(user.id) + + if not user_: + continue + + mxid = user_["mxid"] + username = f"{user.username}#{user.discriminator}" + + if user.avatar_url != user_["avatar_url"]: + self.logger.info( + f"Updating avatar for Discord user {user.id}." + ) + self.app.set_avatar(user.avatar_url, mxid) + + if username != user_["username"]: + self.logger.info( + f"Updating username for Discord user {user.id}." + ) + self.app.set_nick(username, mxid) + + while True: + guilds = set() # Avoid duplicates. + + try: + for channel in self.app.db.list_channels(): + guilds.add(self.get_channel(channel).guild_id) + + sync_emotes(guilds) + sync_users(guilds) + # Don't let the background task die. + except RequestError: + self.logger.exception( + "Ignoring exception during background sync:" + ) + + await asyncio.sleep(120) # Check every 2 minutes. + + async def start(self) -> None: + asyncio.ensure_future(self.sync()) + + await self.run() + + def to_return(self, message: discord.Message) -> bool: + return ( + message.channel_id not in self.app.db.list_channels() + or not message.author + or message.author.discriminator == "0000" + ) + + def matrixify(self, id: str, user: bool = False) -> str: + return ( + f"{'@' if user else '#'}{self.app.format}{id}:" + f"{self.app.server_name}" + ) + + def wrap(self, message: discord.Message) -> Tuple[str, str]: + """ + Get the room ID and the puppet's mxid for a given channel ID and a + Discord user. + """ + + mxid = self.matrixify(message.author.id, user=True) + room_id = self.app.get_room_id(self.matrixify(message.channel_id)) + + if not self.app.db.query_user(mxid): + self.logger.info( + f"Creating dummy user for Discord user {message.author.id}." + ) + self.app.register(mxid) + + self.app.set_nick( + f"{message.author.username}#" + f"{message.author.discriminator}", + mxid, + ) + + if message.author.avatar_url: + self.app.set_avatar(message.author.avatar_url, mxid) + + if mxid not in self.app.get_members(room_id): + self.logger.info(f"Inviting user {mxid} to room {room_id}.") + + self.app.send_invite(room_id, mxid) + self.app.join_room(room_id, mxid) + + return mxid, room_id + + def on_message_create(self, message: discord.Message) -> None: + if self.to_return(message): + return + + mxid, room_id = self.wrap(message) + + content, emotes = self.process_message(message) + + content = self.app.create_message_event( + content, emotes, reply=message.reference + ) + + message_cache[message.id] = { + "body": content["body"], + "event_id": self.app.send_message(room_id, content, mxid), + "mxid": mxid, + "room_id": room_id, + } + + def on_message_delete(self, message: discord.DeletedMessage) -> None: + event = message_cache.get(message.id) + + if not event: + return + + self.app.redact(event["event_id"], event["room_id"], event["mxid"]) + + message_cache.pop(message.id) + + def on_message_update(self, message: dict) -> None: + if self.to_return(message): + return + + event = message_cache.get(message.id) + + if not event: + return + + content, emotes = self.process_message(message) + + content = self.app.create_message_event( + content, emotes, edit=event["event_id"] + ) + + self.app.send_message(event["room_id"], content, event["mxid"]) + + def on_typing_start(self, typing: discord.Typing) -> None: + if typing.channel_id not in self.app.db.list_channels(): + return + + mxid = self.matrixify(typing.user_id, user=True) + room_id = self.app.get_room_id(self.matrixify(typing.channel_id)) + + if mxid not in self.app.get_members(room_id): + return + + self.app.send_typing(room_id, mxid) + + def get_webhook(self, channel_id: str, name: str) -> discord.Webhook: + """ + Get the webhook object for the first webhook that matches the specified + name in a given channel, create the webhook if it doesn't exist. + """ + + # Check the cache first. + webhook = self.webhook_cache.get(channel_id) + + if webhook: + return webhook + + webhooks = self.send("GET", f"/channels/{channel_id}/webhooks") + webhook = next( + ( + dict_cls(webhook, discord.Webhook) + for webhook in webhooks + if webhook["name"] == name + ), + None, + ) + + if not webhook: + webhook = self.create_webhook(channel_id, name) + + self.webhook_cache[channel_id] = webhook + + return webhook + + def process_message(self, message: discord.Message) -> Tuple[str, str]: + content = message.content + emotes = {} + regex = r"" + + # Mentions can either be in the form of `<@1234>` or `<@!1234>`. + for char in ("", "!"): + for member in message.mentions: + content = content.replace( + f"<@{char}{member.id}>", f"@{member.username}" + ) + + # `except_deleted` for invalid channels. + for channel in re.findall(r"<#([0-9]+)>", content): + channel_ = except_deleted(self.get_channel)(channel) + content = content.replace( + f"<#{channel}>", + f"#{channel_.name}" if channel_ else "deleted-channel", + ) + + # { "emote_name": "emote_id" } + for emote in re.findall(regex, content): + emotes[emote[0]] = emote[1] + + # Replace emote IDs with names. + content = re.sub(regex, r":\g<1>:", content) + + # Append attachments to message. + for attachment in message.attachments: + content += f"\n{attachment['url']}" + + return content, emotes + + +def config_gen(basedir: str, config_file: str) -> dict: + config_file = f"{basedir}/{config_file}" + + config_dict = { + "as_token": "my-secret-as-token", + "hs_token": "my-secret-hs-token", + "user_id": "appservice-discord", + "homeserver": "http://127.0.0.1:8008", + "server_name": "localhost", + "discord_token": "my-secret-discord-token", + "port": 5000, + "database": f"{basedir}/bridge.db", + } + + if not os.path.exists(config_file): + with open(config_file, "w") as f: + json.dump(config_dict, f, indent=4) + print(f"Configuration dumped to '{config_file}'") + sys.exit() + + with open(config_file, "r") as f: + return json.loads(f.read()) + + +def main() -> None: + try: + basedir = sys.argv[1] + if not os.path.exists(basedir): + print(f"Path '{basedir}' does not exist!") + sys.exit(1) + basedir = os.path.abspath(basedir) + except IndexError: + basedir = os.getcwd() + + config = config_gen(basedir, "appservice.json") + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(name)s:%(levelname)s:%(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[ + logging.FileHandler(f"{basedir}/appservice.log"), + ], + ) + + http = urllib3.PoolManager(maxsize=10) + + app = MatrixClient(config, http) + + # Start the bottle app in a separate thread. + app_thread = threading.Thread( + target=app.run, kwargs={"port": int(config["port"])}, daemon=True + ) + app_thread.start() + + try: + asyncio.run(app.discord.start()) + except KeyboardInterrupt: + sys.exit() + + +if __name__ == "__main__": + main() diff --git a/appservice/matrix.py b/appservice/matrix.py new file mode 100644 index 0000000..985052f --- /dev/null +++ b/appservice/matrix.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass + + +@dataclass +class User(object): + avatar_url: str + displayname: str + + +class Event(object): + def __init__(self, event: dict): + content = event["content"] + + self.author = event["author"] + self.body = content.get("body", "") + self.event_id = event["event_id"] + self.is_direct = content.get("is_direct", False) + self.room_id = event["room_id"] + self.sender = event["sender"] + self.state_key = event.get("state_key", "") + + rel = content.get("m.relates_to", {}) + + self.relates_to = rel.get("event_id") + self.reltype = rel.get("rel_type") + self.new_body = content.get("m.new_content", {}).get("body", "") diff --git a/appservice/misc.py b/appservice/misc.py new file mode 100644 index 0000000..45f7c67 --- /dev/null +++ b/appservice/misc.py @@ -0,0 +1,89 @@ +import asyncio +import json +from dataclasses import fields +from typing import Any + +import urllib3 + +from errors import RequestError + + +def dict_cls(dict_var: dict, cls: Any) -> Any: + """ + Create a dataclass from a dictionary. + """ + + field_names = set(f.name for f in fields(cls)) + filtered_dict = {k: v for k, v in dict_var.items() if k in field_names} + + return cls(**filtered_dict) + + +def log_except(fn): + """ + Log unhandled exceptions to a logger instead of `stderr`. + """ + + def wrapper(self, *args, **kwargs): + try: + return fn(self, *args, **kwargs) + except Exception: + self.logger.exception(f"Exception in '{fn.__name__}':") + raise + + return wrapper + + +def wrap_async(fn): + """ + Call an asynchronous function from a synchronous one. + """ + + def wrapper(self, *args, **kwargs): + if not self.loop: + raise RuntimeError("loop is None.") + + return asyncio.run_coroutine_threadsafe( + fn(self, *args, **kwargs), loop=self.loop + ).result() + + return wrapper + + +def request(fn): + """ + Either return json data or raise a `RequestError` if the request was + unsuccessful. + """ + + def wrapper(*args, **kwargs): + try: + resp = fn(*args, **kwargs) + except urllib3.exceptions.HTTPError as e: + raise RequestError(None, f"Failed to connect: {e}") from None + + if resp.status < 200 or resp.status >= 300: + raise RequestError( + resp.status, + f"Failed to get response from '{resp.geturl()}':\n{resp.data}", + ) + + return {} if resp.status == 204 else json.loads(resp.data) + + return wrapper + + +def except_deleted(fn): + """ + Ignore the `RequestError` on 404s, the message might have been + deleted by someone else already. + """ + + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except RequestError as e: + if e.status != 404: + raise + + return wrapper diff --git a/appservice/requirements.txt b/appservice/requirements.txt new file mode 100644 index 0000000..10a331b --- /dev/null +++ b/appservice/requirements.txt @@ -0,0 +1,3 @@ +bottle==0.12.19 +urllib3==1.26.3 +websockets==8.1