diff --git a/src/aioappsrv/appservice.py b/src/aioappsrv/appservice.py deleted file mode 100644 index 8dec182..0000000 --- a/src/aioappsrv/appservice.py +++ /dev/null @@ -1,203 +0,0 @@ -import json -import logging -import urllib.parse -import uuid -from typing import Union - -import bottle -import urllib3 - -import matrix -from cache import Cache -from misc import log_except, request - - -class AppService(bottle.Bottle): - def __init__(self, config: dict, http: urllib3.PoolManager) -> None: - super(AppService, self).__init__() - - self.as_token = config["as_token"] - self.hs_token = config["hs_token"] - self.base_url = config["homeserver"] - self.server_name = config["server_name"] - self.user_id = f"@{config['user_id']}:{self.server_name}" - self.http = http - self.logger = logging.getLogger("appservice") - - # Map events to functions. - self.mapping = { - "m.room.member": "on_member", - "m.room.message": "on_message", - "m.room.redaction": "on_redaction", - } - - # Add route for bottle. - self.route( - "/transactions/", - callback=self.receive_event, - method="PUT", - ) - - Cache.cache["m_rooms"] = {} - - def handle_event(self, event: dict) -> None: - event_type = event.get("type") - - if event_type in ( - "m.room.member", - "m.room.message", - "m.room.redaction", - ): - obj = matrix.Event(event) - else: - self.logger.info(f"Unknown event type: {event_type}") - return - - func = getattr(self, self.mapping[event_type], None) - - if not func: - self.logger.warning( - f"Function '{func}' not defined, ignoring event." - ) - return - - # We don't catch exceptions here as the homeserver will re-send us - # the event in case of a failure. - func(obj) - - @log_except - def receive_event(self, transaction: str) -> dict: - """ - Verify the homeserver's token and handle events. - """ - - hs_token = bottle.request.query.getone("access_token") - - if not hs_token: - bottle.response.status = 401 - return {"errcode": "APPSERVICE_UNAUTHORIZED"} - - if hs_token != self.hs_token: - bottle.response.status = 403 - return {"errcode": "APPSERVICE_FORBIDDEN"} - - events = bottle.request.json.get("events") - - for event in events: - self.handle_event(event) - - return {} - - def mxc_url(self, mxc: str) -> str: - try: - homeserver, media_id = mxc.replace("mxc://", "").split("/") - except ValueError: - return "" - - return ( - f"https://{self.server_name}/_matrix/media/r0/download/" - f"{homeserver}/{media_id}" - ) - - def join_room(self, room_id: str, mxid: str = "") -> None: - self.send( - "POST", - f"/join/{room_id}", - params={"user_id": mxid} if mxid else {}, - ) - - def redact(self, event_id: str, room_id: str, mxid: str = "") -> None: - self.send( - "PUT", - f"/rooms/{room_id}/redact/{event_id}/{uuid.uuid4()}", - params={"user_id": mxid} if mxid else {}, - ) - - def get_room_id(self, alias: str) -> str: - with Cache.lock: - room = Cache.cache["m_rooms"].get(alias) - if room: - return room - - resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}") - - room_id = resp["room_id"] - - with Cache.lock: - Cache.cache["m_rooms"][alias] = room_id - - return room_id - - def get_event(self, event_id: str, room_id: str) -> matrix.Event: - resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}") - - return matrix.Event(resp) - - def upload(self, url: str) -> str: - """ - Upload a file to the homeserver and get the MXC url. - """ - - resp = self.http.request("GET", url) - - resp = self.send( - "POST", - content=resp.data, - content_type=resp.headers.get("Content-Type"), - params={"filename": f"{uuid.uuid4()}"}, - endpoint="/_matrix/media/r0/upload", - ) - - return resp["content_uri"] - - def send_message( - self, - room_id: str, - content: dict, - mxid: str = "", - ) -> str: - resp = self.send( - "PUT", - f"/rooms/{room_id}/send/m.room.message/{uuid.uuid4()}", - content, - {"user_id": mxid} if mxid else {}, - ) - - return resp["event_id"] - - def send_typing( - self, room_id: str, mxid: str = "", timeout: int = 8000 - ) -> None: - self.send( - "PUT", - f"/rooms/{room_id}/typing/{mxid}", - {"typing": True, "timeout": timeout}, - {"user_id": mxid} if mxid else {}, - ) - - def send_invite(self, room_id: str, mxid: str) -> None: - self.send("POST", f"/rooms/{room_id}/invite", {"user_id": mxid}) - - @request - def send( - self, - method: str, - path: str = "", - content: Union[bytes, dict] = {}, - params: dict = {}, - content_type: str = "application/json", - endpoint: str = "/_matrix/client/r0", - ) -> dict: - headers = { - "Authorization": f"Bearer {self.as_token}", - "Content-Type": content_type, - } - payload = json.dumps(content) if isinstance(content, dict) else content - endpoint = ( - f"{self.base_url}{endpoint}{path}?" - f"{urllib.parse.urlencode(params)}" - ) - - return self.http.request( - method, endpoint, body=payload, headers=headers - ) diff --git a/src/aioappsrv/cache.py b/src/aioappsrv/cache.py deleted file mode 100644 index 536887b..0000000 --- a/src/aioappsrv/cache.py +++ /dev/null @@ -1,6 +0,0 @@ -import threading - - -class Cache: - cache = {} - lock = threading.Lock() diff --git a/src/aioappsrv/db.py b/src/aioappsrv/db.py deleted file mode 100644 index a699f61..0000000 --- a/src/aioappsrv/db.py +++ /dev/null @@ -1,120 +0,0 @@ -import os -import sqlite3 -import threading -from typing import List - - -class DataBase: - def __init__(self, db_file) -> None: - self.create(db_file) - - # The database is accessed via multiple threads. - self.lock = threading.Lock() - - def create(self, db_file) -> None: - """ - Create a database with the relevant tables if it doesn't already exist. - """ - - exists = os.path.exists(db_file) - - self.conn = sqlite3.connect(db_file, check_same_thread=False) - self.conn.row_factory = self.dict_factory - - self.cur = self.conn.cursor() - - if exists: - return - - self.cur.execute( - "CREATE TABLE bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);" - ) - - self.cur.execute( - "CREATE TABLE users(mxid TEXT PRIMARY KEY, " - "avatar_url TEXT, username TEXT);" - ) - - self.conn.commit() - - def dict_factory(self, cursor, row): - """ - https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory - """ - - d = {} - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d - - def add_room(self, room_id: str, channel_id: str) -> None: - """ - Add a bridged room to the database. - """ - - with self.lock: - self.cur.execute( - "INSERT INTO bridge (room_id, channel_id) VALUES (?, ?)", - [room_id, channel_id], - ) - self.conn.commit() - - def add_user(self, mxid: str) -> None: - with self.lock: - self.cur.execute("INSERT INTO users (mxid) VALUES (?)", [mxid]) - self.conn.commit() - - def add_avatar(self, avatar_url: str, mxid: str) -> None: - with self.lock: - self.cur.execute( - "UPDATE users SET avatar_url = (?) WHERE mxid = (?)", - [avatar_url, mxid], - ) - self.conn.commit() - - def add_username(self, username: str, mxid: str) -> None: - with self.lock: - self.cur.execute( - "UPDATE users SET username = (?) WHERE mxid = (?)", - [username, mxid], - ) - self.conn.commit() - - def get_channel(self, room_id: str) -> str: - """ - Get the corresponding channel ID for a given room ID. - """ - - with self.lock: - self.cur.execute( - "SELECT channel_id FROM bridge WHERE room_id = ?", [room_id] - ) - - room = self.cur.fetchone() - - # Return an empty string if the channel is not bridged. - return "" if not room else room["channel_id"] - - def list_channels(self) -> List[str]: - """ - Get a list of all the bridged channels. - """ - - with self.lock: - self.cur.execute("SELECT channel_id FROM bridge") - - channels = self.cur.fetchall() - - return [channel["channel_id"] for channel in channels] - - def fetch_user(self, mxid: str) -> dict: - """ - Fetch the profile for a bridged user. - """ - - with self.lock: - self.cur.execute("SELECT * FROM users where mxid = ?", [mxid]) - - user = self.cur.fetchone() - - return {} if not user else user diff --git a/src/aioappsrv/errors.py b/src/aioappsrv/errors.py deleted file mode 100644 index 502b06a..0000000 --- a/src/aioappsrv/errors.py +++ /dev/null @@ -1,5 +0,0 @@ -class RequestError(Exception): - def __init__(self, status: int, *args): - super().__init__(*args) - - self.status = status diff --git a/src/aioappsrv/gateway.py b/src/aioappsrv/gateway.py deleted file mode 100644 index 16686d4..0000000 --- a/src/aioappsrv/gateway.py +++ /dev/null @@ -1,260 +0,0 @@ -import asyncio -import json -import logging -import urllib.parse -from typing import Dict, List - -import urllib3 -import websockets - -import discord -from misc import dict_cls, log_except, request - - -class Gateway: - def __init__(self, http: urllib3.PoolManager, token: str): - self.http = http - self.token = token - self.logger = logging.getLogger("discord") - self.Payloads = discord.Payloads(self.token) - self.websocket = None - - @log_except - async def run(self) -> None: - self.heartbeat_task: asyncio.Future = None - self.resume = False - - gateway_url = self.get_gateway_url() - - while True: - try: - await self.gateway_handler(gateway_url) - except ( - websockets.ConnectionClosedError, - websockets.InvalidMessage, - ): - self.logger.exception("Connection lost, reconnecting.") - - # Stop sending heartbeats until we reconnect. - if self.heartbeat_task and not self.heartbeat_task.cancelled(): - self.heartbeat_task.cancel() - - def get_gateway_url(self) -> str: - resp = self.send("GET", "/gateway") - - return resp["url"] - - async def heartbeat_handler(self, interval_ms: int) -> None: - while True: - await asyncio.sleep(interval_ms / 1000) - await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT())) - - async def handle_resp(self, data: dict) -> None: - data_dict = data["d"] - - opcode = data["op"] - - seq = data["s"] - - if seq: - self.Payloads.seq = seq - - if opcode == discord.GatewayOpCodes.DISPATCH: - otype = data["t"] - - if otype == "READY": - self.Payloads.session = data_dict["session_id"] - - self.logger.info("READY") - else: - self.handle_otype(data_dict, otype) - elif opcode == discord.GatewayOpCodes.HELLO: - heartbeat_interval = data_dict.get("heartbeat_interval") - - self.logger.info(f"Heartbeat Interval: {heartbeat_interval}") - - # Send periodic hearbeats to gateway. - self.heartbeat_task = asyncio.ensure_future( - self.heartbeat_handler(heartbeat_interval) - ) - - await self.websocket.send( - json.dumps( - self.Payloads.RESUME() - if self.resume - else self.Payloads.IDENTIFY() - ) - ) - elif opcode == discord.GatewayOpCodes.RECONNECT: - self.logger.info("Received RECONNECT.") - - self.resume = True - await self.websocket.close() - elif opcode == discord.GatewayOpCodes.INVALID_SESSION: - self.logger.info("Received INVALID_SESSION.") - - self.resume = False - await self.websocket.close() - elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK: - # NOP - pass - else: - self.logger.info( - "Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}" - ) - - def handle_otype(self, data: dict, otype: str) -> None: - if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"): - obj = discord.Message(data) - elif otype == "TYPING_START": - obj = dict_cls(data, discord.Typing) - elif otype == "GUILD_CREATE": - obj = discord.Guild(data) - elif otype == "GUILD_MEMBER_UPDATE": - obj = discord.GuildMemberUpdate(data) - elif otype == "GUILD_EMOJIS_UPDATE": - obj = discord.GuildEmojisUpdate(data) - else: - return - - func = getattr(self, f"on_{otype.lower()}", None) - - if not func: - self.logger.warning( - f"Function '{func}' not defined, ignoring message." - ) - return - - try: - func(obj) - except Exception: - self.logger.exception(f"Ignoring exception in '{func.__name__}':") - - async def gateway_handler(self, gateway_url: str) -> None: - async with websockets.connect( - f"{gateway_url}/?v=8&encoding=json" - ) as websocket: - self.websocket = websocket - - async for message in websocket: - await self.handle_resp(json.loads(message)) - - def get_channel(self, channel_id: str) -> discord.Channel: - """ - Get the channel for a given channel ID. - """ - - resp = self.send("GET", f"/channels/{channel_id}") - - return dict_cls(resp, discord.Channel) - - def get_channels(self, guild_id: str) -> Dict[str, discord.Channel]: - """ - Get all channels for a given guild ID. - """ - - resp = self.send("GET", f"/guilds/{guild_id}/channels") - - return { - channel["id"]: dict_cls(channel, discord.Channel) - for channel in resp - } - - def get_emotes(self, guild_id: str) -> List[discord.Emote]: - """ - Get all the emotes for a given guild. - """ - - resp = self.send("GET", f"/guilds/{guild_id}/emojis") - - return [dict_cls(emote, discord.Emote) for emote in resp] - - def get_members(self, guild_id: str) -> List[discord.User]: - """ - Get all the members for a given guild. - """ - - resp = self.send( - "GET", f"/guilds/{guild_id}/members", params={"limit": 1000} - ) - - return [discord.User(member["user"]) for member in resp] - - def create_webhook(self, channel_id: str, name: str) -> discord.Webhook: - """ - Create a webhook with the specified name in a given channel. - """ - - resp = self.send( - "POST", f"/channels/{channel_id}/webhooks", {"name": name} - ) - - return dict_cls(resp, discord.Webhook) - - def edit_webhook( - self, content: str, message_id: str, webhook: discord.Webhook - ) -> None: - self.send( - "PATCH", - f"/webhooks/{webhook.id}/{webhook.token}/messages/" - f"{message_id}", - {"content": content}, - ) - - def delete_webhook( - self, message_id: str, webhook: discord.Webhook - ) -> None: - self.send( - "DELETE", - f"/webhooks/{webhook.id}/{webhook.token}/messages/" - f"{message_id}", - ) - - def send_webhook( - self, - webhook: discord.Webhook, - avatar_url: str, - content: str, - username: str, - ) -> discord.Message: - payload = { - "avatar_url": avatar_url, - "content": content, - "username": username, - # Disable 'everyone' and 'role' mentions. - "allowed_mentions": {"parse": ["users"]}, - } - - resp = self.send( - "POST", - f"/webhooks/{webhook.id}/{webhook.token}", - payload, - {"wait": True}, - ) - - return discord.Message(resp) - - def send_message(self, message: str, channel_id: str) -> None: - self.send( - "POST", f"/channels/{channel_id}/messages", {"content": message} - ) - - @request - def send( - self, method: str, path: str, content: dict = {}, params: dict = {} - ) -> dict: - endpoint = ( - f"https://discord.com/api/v8{path}?" - f"{urllib.parse.urlencode(params)}" - ) - headers = { - "Authorization": f"Bot {self.token}", - "Content-Type": "application/json", - } - - # 'body' being an empty dict breaks "GET" requests. - payload = json.dumps(content) if content else None - - return self.http.request( - method, endpoint, body=payload, headers=headers - ) diff --git a/src/aioappsrv/main.py b/src/aioappsrv/main.py deleted file mode 100644 index fb9fe75..0000000 --- a/src/aioappsrv/main.py +++ /dev/null @@ -1,800 +0,0 @@ -import asyncio -import json -import logging -import os -import re -import sys -import threading -import urllib.parse -from typing import Dict, List, Tuple - -import markdown -import urllib3 - -import discord -import matrix -from appservice import AppService -from cache import Cache -from db import DataBase -from errors import RequestError -from gateway import Gateway -from misc import dict_cls, except_deleted, hash_str - - -class MatrixClient(AppService): - def __init__(self, config: dict, http: urllib3.PoolManager) -> None: - super().__init__(config, http) - - self.db = DataBase(config["database"]) - self.discord = DiscordClient(self, config, http) - self.format = "_discord_" # "{@,#}_discord_1234:localhost" - self.id_regex = "[0-9]+" # Snowflakes may have variable length - - # TODO Find a cleaner way to use these keys. - for k in ("m_emotes", "m_members", "m_messages"): - Cache.cache[k] = {} - - def handle_bridge(self, message: matrix.Event) -> None: - # Ignore events that aren't for us. - if message.sender.split(":")[ - -1 - ] != self.server_name or not message.body.startswith("!bridge"): - return - - # Get the channel ID. - try: - channel = message.body.split()[1] - except IndexError: - return - - # Check if the given channel is valid. - try: - channel = self.discord.get_channel(channel) - except RequestError as e: - # The channel can be invalid or we may not have permissions. - self.logger.warning(f"Failed to fetch channel {channel}: {e}") - return - - if ( - channel.type != discord.ChannelType.GUILD_TEXT - or channel.id in self.db.list_channels() - ): - return - - self.logger.info(f"Creating bridged room for channel {channel.id}.") - - self.create_room(channel, message.sender) - - def on_member(self, event: matrix.Event) -> None: - with Cache.lock: - # Just lazily clear the whole member cache on - # membership update events. - if event.room_id in Cache.cache["m_members"]: - self.logger.info( - f"Clearing member cache for room '{event.room_id}'." - ) - del Cache.cache["m_members"][event.room_id] - - if ( - event.sender.split(":")[-1] != self.server_name - or event.state_key != self.user_id - or not event.is_direct - ): - return - - # Join the direct message room. - self.logger.info(f"Joining direct message room '{event.room_id}'.") - self.join_room(event.room_id) - - def on_message(self, message: matrix.Event) -> None: - if ( - message.sender.startswith((f"@{self.format}", self.user_id)) - or not message.body - ): - return - - # Handle bridging commands. - self.handle_bridge(message) - - channel_id = self.db.get_channel(message.room_id) - - if not channel_id: - return - - author = self.get_members(message.room_id)[message.sender] - - webhook = self.discord.get_webhook( - channel_id, self.discord.webhook_name - ) - - if message.relates_to and message.reltype == "m.replace": - with Cache.lock: - message_id = Cache.cache["m_messages"].get(message.relates_to) - - # TODO validate if the original author sent the edit. - - if not message_id or not message.new_body: - return - - message.new_body = self.process_message(message) - - except_deleted(self.discord.edit_webhook)( - message.new_body, message_id, webhook - ) - else: - message.body = ( - f"`{message.body}`: {self.mxc_url(message.attachment)}" - if message.attachment - else self.process_message(message) - ) - - message_id = self.discord.send_webhook( - webhook, - self.mxc_url(author.avatar_url) if author.avatar_url else None, - message.body, - author.display_name if author.display_name else message.sender, - ).id - - with Cache.lock: - Cache.cache["m_messages"][message.id] = message_id - - def on_redaction(self, event: matrix.Event) -> None: - with Cache.lock: - message_id = Cache.cache["m_messages"].get(event.redacts) - - if not message_id: - return - - webhook = self.discord.get_webhook( - self.db.get_channel(event.room_id), self.discord.webhook_name - ) - - except_deleted(self.discord.delete_webhook)(message_id, webhook) - - with Cache.lock: - del Cache.cache["m_messages"][event.redacts] - - def get_members(self, room_id: str) -> Dict[str, matrix.User]: - with Cache.lock: - cached = Cache.cache["m_members"].get(room_id) - - if cached: - return cached - - resp = self.send("GET", f"/rooms/{room_id}/joined_members") - - joined = resp["joined"] - - for k, v in joined.items(): - joined[k] = dict_cls(v, matrix.User) - - with Cache.lock: - Cache.cache["m_members"][room_id] = joined - - return joined - - def create_room(self, channel: discord.Channel, sender: str) -> None: - """ - Create a bridged room and invite the person who invoked the command. - """ - - content = { - "room_alias_name": f"{self.format}{channel.id}", - "name": channel.name, - "topic": channel.topic if channel.topic else channel.name, - "visibility": "private", - "invite": [sender], - "creation_content": {"m.federate": True}, - "initial_state": [ - { - "type": "m.room.join_rules", - "content": {"join_rule": "public"}, - }, - { - "type": "m.room.history_visibility", - "content": {"history_visibility": "shared"}, - }, - ], - "power_level_content_override": { - "users": {sender: 100, self.user_id: 100} - }, - } - - resp = self.send("POST", "/createRoom", content) - - self.db.add_room(resp["room_id"], channel.id) - - def create_message_event( - self, - message: str, - emotes: dict, - edit: str = "", - reference: discord.Message = None, - ) -> dict: - content = { - "body": message, - "msgtype": "m.text", - } - - fmt = self.get_fmt(message, emotes) - - if fmt != message: - content = { - **content, - "format": "org.matrix.custom.html", - "formatted_body": fmt, - } - - ref_id = None - - if reference: - # Reply to a Discord message. - with Cache.lock: - ref_id = Cache.cache["d_messages"].get(reference.id) - - # Reply to a Matrix message. (maybe) - if not ref_id: - with Cache.lock: - ref_id = [ - k - for k, v in Cache.cache["m_messages"].items() - if v == reference.id - ] - ref_id = next(iter(ref_id), "") - - if ref_id: - event = except_deleted(self.get_event)( - ref_id, - self.get_room_id(self.discord.matrixify(reference.channel_id)), - ) - if event: - # Content with the reply fallbacks stripped. - tmp = "" - # We don't want to strip lines starting with "> " after - # encountering a regular line, so we use this variable. - got_fallback = True - for line in event.body.split("\n"): - if not line.startswith("> "): - got_fallback = False - if not got_fallback: - tmp += line - - event.body = tmp - event.formatted_body = ( - # re.DOTALL allows the match to span newlines. - re.sub( - "", - "", - event.formatted_body, - flags=re.DOTALL, - ) - if event.formatted_body - else event.body - ) - - content = { - **content, - "body": ( - f"> <{event.sender}> {event.body}\n{content['body']}" - ), - "m.relates_to": {"m.in_reply_to": {"event_id": event.id}}, - "format": "org.matrix.custom.html", - "formatted_body": f"""
\ -In reply to\ -{event.sender}\ -
{event.formatted_body if event.formatted_body else event.body}\ -
\ -{content.get("formatted_body", content['body'])}""", - } - - if edit: - content = { - **content, - "body": f" * {content['body']}", - "formatted_body": f" * {content.get('formatted_body', content['body'])}", - "m.relates_to": {"event_id": edit, "rel_type": "m.replace"}, - "m.new_content": {**content}, - } - - return content - - def get_fmt(self, message: str, emotes: dict) -> str: - message = ( - markdown.markdown(message) - .replace("

", "") - .replace("

", "") - .replace("\n", "
") - ) - - # Upload emotes in multiple threads so that we don't - # block the Discord bot for too long. - upload_threads = [ - threading.Thread( - target=self.upload_emote, args=(emote, emotes[emote]) - ) - for emote in emotes - ] - - # Acquire the lock before starting the threads to avoid resource - # contention by tens of threads at once. - with Cache.lock: - for thread in upload_threads: - thread.start() - for thread in upload_threads: - thread.join() - - with Cache.lock: - for emote in emotes: - emote_ = Cache.cache["m_emotes"].get(emote) - - if emote_: - emote = f":{emote}:" - message = message.replace( - emote, - f"""\"{emote}\"""", - ) - - return message - - def mention_regex(self, encode: bool, id_as_group: bool) -> str: - mention = "@" - colon = ":" - snowflake = self.id_regex - - if encode: - mention = urllib.parse.quote(mention) - colon = urllib.parse.quote(colon) - - if id_as_group: - snowflake = f"({snowflake})" - - hashed = f"(?:-{snowflake})?" - - return f"{mention}{self.format}{snowflake}{hashed}{colon}{re.escape(self.server_name)}" - - def process_message(self, event: matrix.Event) -> str: - message = event.new_body if event.new_body else event.body - - emotes = re.findall(r":(\w*):", message) - - mentions = list( - re.finditer( - self.mention_regex(encode=False, id_as_group=True), - event.formatted_body, - ) - ) - # For clients that properly encode mentions. - # 'https://matrix.to/#/%40_discord_...%3Adomain.tld' - mentions.extend( - re.finditer( - self.mention_regex(encode=True, id_as_group=True), - event.formatted_body, - ) - ) - - with Cache.lock: - for emote in set(emotes): - emote_ = Cache.cache["d_emotes"].get(emote) - if emote_: - message = message.replace(f":{emote}:", emote_) - - for mention in set(mentions): - # Unquote just in-case we matched an encoded username. - username = self.db.fetch_user( - urllib.parse.unquote(mention.group(0)) - ).get("username") - if username: - if mention.group(2): - # Replace mention with plain text for hashed users (webhooks) - message = message.replace(mention.group(0), f"@{username}") - else: - # Replace the 'mention' so that the user is tagged - # in the case of replies aswell. - # '> <@_discord_1234:localhost> Message' - for replace in (mention.group(0), username): - message = message.replace( - replace, f"<@{mention.group(1)}>" - ) - - # We trim the message later as emotes take up extra characters too. - return message[: discord.MESSAGE_LIMIT] - - def upload_emote(self, emote_name: str, emote_id: str) -> None: - # There won't be a race condition here, since only a unique - # set of emotes are uploaded at a time. - if emote_name in Cache.cache["m_emotes"]: - return - - emote_url = f"{discord.CDN_URL}/emojis/{emote_id}" - - # We don't want the message to be dropped entirely if an emote - # fails to upload for some reason. - try: - # TODO This is not thread safe, but we're protected by the GIL. - Cache.cache["m_emotes"][emote_name] = self.upload(emote_url) - except RequestError as e: - self.logger.warning(f"Failed to upload emote {emote_id}: {e}") - - def register(self, mxid: str) -> None: - """ - Register a dummy user on the homeserver. - """ - - content = { - "type": "m.login.application_service", - # "@test:localhost" -> "test" (Can't register with a full mxid.) - "username": mxid[1:].split(":")[0], - } - - resp = self.send("POST", "/register", content) - - self.db.add_user(resp["user_id"]) - - def set_avatar(self, avatar_url: str, mxid: str) -> None: - avatar_uri = self.upload(avatar_url) - - self.send( - "PUT", - f"/profile/{mxid}/avatar_url", - {"avatar_url": avatar_uri}, - params={"user_id": mxid}, - ) - - self.db.add_avatar(avatar_url, mxid) - - def set_nick(self, username: str, mxid: str) -> None: - self.send( - "PUT", - f"/profile/{mxid}/displayname", - {"displayname": username}, - params={"user_id": mxid}, - ) - - self.db.add_username(username, mxid) - - -class DiscordClient(Gateway): - def __init__( - self, appservice: MatrixClient, config: dict, http: urllib3.PoolManager - ) -> None: - super().__init__(http, config["discord_token"]) - - self.app = appservice - self.webhook_name = "matrix_bridge" - - # TODO Find a cleaner way to use these keys. - for k in ("d_emotes", "d_messages", "d_webhooks"): - Cache.cache[k] = {} - - def to_return(self, message: discord.Message) -> bool: - with Cache.lock: - hook_ids = [hook.id for hook in Cache.cache["d_webhooks"].values()] - - return ( - message.channel_id not in self.app.db.list_channels() - or not message.author # Embeds can be weird sometimes. - or message.webhook_id in hook_ids - ) - - def matrixify(self, id: str, user: bool = False, hashed: str = "") -> str: - return ( - f"{'@' if user else '#'}{self.app.format}" - f"{id}{'-' + hashed if hashed else ''}:" - f"{self.app.server_name}" - ) - - def sync_profile(self, user: discord.User, hashed: str = "") -> None: - """ - Sync the avatar and username for a puppeted user. - """ - - mxid = self.matrixify(user.id, user=True, hashed=hashed) - - profile = self.app.db.fetch_user(mxid) - - # User doesn't exist. - if not profile: - return - - username = f"{user.username}#{user.discriminator}" - - if user.avatar_url != profile["avatar_url"]: - self.logger.info(f"Updating avatar for Discord user '{user.id}'") - self.app.set_avatar(user.avatar_url, mxid) - if username != profile["username"]: - self.logger.info(f"Updating username for Discord user '{user.id}'") - self.app.set_nick(username, mxid) - - def wrap(self, message: discord.Message) -> Tuple[str, str]: - """ - Get the room ID and the puppet's mxid for a given channel ID and a - Discord user. - """ - - hashed = "" - if message.webhook_id and message.webhook_id != message.application_id: - hashed = str(hash_str(message.author.username)) - - mxid = self.matrixify(message.author.id, user=True, hashed=hashed) - room_id = self.app.get_room_id(self.matrixify(message.channel_id)) - - if not self.app.db.fetch_user(mxid): - self.logger.info( - f"Creating dummy user for Discord user {message.author.id}." - ) - self.app.register(mxid) - - self.app.set_nick( - f"{message.author.username}#" - f"{message.author.discriminator}", - mxid, - ) - - if message.author.avatar_url: - self.app.set_avatar(message.author.avatar_url, mxid) - - if mxid not in self.app.get_members(room_id): - self.logger.info(f"Inviting user '{mxid}' to room '{room_id}'.") - - self.app.send_invite(room_id, mxid) - self.app.join_room(room_id, mxid) - - if message.webhook_id: - # Sync webhooks here as they can't be accessed like guild members. - self.sync_profile(message.author, hashed=hashed) - - return mxid, room_id - - def cache_emotes(self, emotes: List[discord.Emote]): - # TODO maybe "namespace" emotes by guild in the cache ? - with Cache.lock: - for emote in emotes: - Cache.cache["d_emotes"][emote.name] = ( - f"<{'a' if emote.animated else ''}:" - f"{emote.name}:{emote.id}>" - ) - - def on_guild_create(self, guild: discord.Guild) -> None: - for member in guild.members: - self.sync_profile(member) - - self.cache_emotes(guild.emojis) - - def on_guild_emojis_update( - self, update: discord.GuildEmojisUpdate - ) -> None: - self.cache_emotes(update.emojis) - - def on_guild_member_update( - self, update: discord.GuildMemberUpdate - ) -> None: - self.sync_profile(update.user) - - def on_message_create(self, message: discord.Message) -> None: - if self.to_return(message): - return - - mxid, room_id = self.wrap(message) - - content_, emotes = self.process_message(message) - - content = self.app.create_message_event( - content_, emotes, reference=message.referenced_message - ) - - with Cache.lock: - Cache.cache["d_messages"][message.id] = self.app.send_message( - room_id, content, mxid - ) - - def on_message_delete(self, message: discord.Message) -> None: - with Cache.lock: - event_id = Cache.cache["d_messages"].get(message.id) - - if not event_id: - return - - room_id = self.app.get_room_id(self.matrixify(message.channel_id)) - event = except_deleted(self.app.get_event)(event_id, room_id) - - if event: - self.app.redact(event.id, event.room_id, event.sender) - - with Cache.lock: - del Cache.cache["d_messages"][message.id] - - def on_message_update(self, message: discord.Message) -> None: - if self.to_return(message): - return - - with Cache.lock: - event_id = Cache.cache["d_messages"].get(message.id) - - if not event_id: - return - - room_id = self.app.get_room_id(self.matrixify(message.channel_id)) - mxid = self.matrixify(message.author.id, user=True) - - # It is possible that a webhook edit's it's own old message - # after changing it's name, hence we generate a new mxid from - # the hashed username, but that mxid hasn't been registered before, - # so the request fails with: - # M_FORBIDDEN: Application service has not registered this user - if not self.app.db.fetch_user(mxid): - return - - content_, emotes = self.process_message(message) - - content = self.app.create_message_event( - content_, emotes, edit=event_id - ) - - self.app.send_message(room_id, content, mxid) - - def on_typing_start(self, typing: discord.Typing) -> None: - if typing.channel_id not in self.app.db.list_channels(): - return - - mxid = self.matrixify(typing.user_id, user=True) - room_id = self.app.get_room_id(self.matrixify(typing.channel_id)) - - if mxid not in self.app.get_members(room_id): - return - - self.app.send_typing(room_id, mxid) - - def get_webhook(self, channel_id: str, name: str) -> discord.Webhook: - """ - Get the webhook object for the first webhook that matches the specified - name in a given channel, create the webhook if it doesn't exist. - """ - - # Check the cache first. - with Cache.lock: - webhook = Cache.cache["d_webhooks"].get(channel_id) - - if webhook: - return webhook - - webhooks = self.send("GET", f"/channels/{channel_id}/webhooks") - webhook = next( - ( - dict_cls(webhook, discord.Webhook) - for webhook in webhooks - if webhook["name"] == name - ), - None, - ) - - if not webhook: - webhook = self.create_webhook(channel_id, name) - - with Cache.lock: - Cache.cache["d_webhooks"][channel_id] = webhook - - return webhook - - def process_message(self, message: discord.Message) -> Tuple[str, Dict]: - content = message.content - emotes = {} - regex = r"" - - # Mentions can either be in the form of `<@1234>` or `<@!1234>`. - for member in message.mentions: - for char in ("", "!"): - content = content.replace( - f"<@{char}{member.id}>", f"@{member.username}" - ) - - # Replace channel IDs with names. - channels = re.findall("<#([0-9]+)>", content) - if channels: - if not message.guild_id: - self.logger.warning( - f"Message '{message.id}' in channel '{message.channel_id}' does not have a guild_id!" - ) - else: - discord_channels = self.get_channels(message.guild_id) - for channel in channels: - discord_channel = discord_channels.get(channel) - name = ( - discord_channel.name - if discord_channel - else "deleted-channel" - ) - content = content.replace(f"<#{channel}>", f"#{name}") - - # { "emote_name": "emote_id" } - for emote in re.findall(regex, content): - emotes[emote[0]] = emote[1] - - # Replace emote IDs with names. - content = re.sub(regex, r":\g<1>:", content) - - # Append attachments to message. - for attachment in message.attachments: - content += f"\n{attachment['url']}" - - # Append stickers to message. - for sticker in message.stickers: - if sticker.format_type != 3: # 3 == Lottie format. - content += f"\n{discord.CDN_URL}/stickers/{sticker.id}.png" - - return content, emotes - - -def config_gen(basedir: str, config_file: str) -> dict: - config_file = f"{basedir}/{config_file}" - - config_dict = { - "as_token": "my-secret-as-token", - "hs_token": "my-secret-hs-token", - "user_id": "appservice-discord", - "homeserver": "http://127.0.0.1:8008", - "server_name": "localhost", - "discord_token": "my-secret-discord-token", - "port": 5000, - "database": f"{basedir}/bridge.db", - } - - if not os.path.exists(config_file): - with open(config_file, "w") as f: - json.dump(config_dict, f, indent=4) - print(f"Configuration dumped to '{config_file}'") - sys.exit() - - with open(config_file, "r") as f: - return json.loads(f.read()) - - -def excepthook(exc_type, exc_value, exc_traceback): - if issubclass(exc_type, KeyboardInterrupt): - sys.__excepthook__(exc_type, exc_value, exc_traceback) - return - - logging.critical( - "Unknown exception:", exc_info=(exc_type, exc_value, exc_traceback) - ) - - -def main() -> None: - try: - basedir = sys.argv[1] - if not os.path.exists(basedir): - print(f"Path '{basedir}' does not exist!") - sys.exit(1) - basedir = os.path.abspath(basedir) - except IndexError: - basedir = os.getcwd() - - config = config_gen(basedir, "appservice.json") - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(name)s:%(levelname)s:%(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - handlers=[ - logging.FileHandler(f"{basedir}/appservice.log"), - ], - ) - - sys.excepthook = excepthook - - app = MatrixClient(config, urllib3.PoolManager(maxsize=10)) - - # Start the bottle app in a separate thread. - app_thread = threading.Thread( - target=app.run, kwargs={"port": int(config["port"])}, daemon=True - ) - app_thread.start() - - try: - asyncio.run(app.discord.run()) - except KeyboardInterrupt: - sys.exit() - - -if __name__ == "__main__": - main() diff --git a/src/aioappsrv/misc.py b/src/aioappsrv/misc.py deleted file mode 100644 index be77c9d..0000000 --- a/src/aioappsrv/misc.py +++ /dev/null @@ -1,84 +0,0 @@ -import json -from dataclasses import fields -from typing import Any - -import urllib3 - -from errors import RequestError - - -def dict_cls(d: dict, cls: Any) -> Any: - """ - Create a dataclass from a dictionary. - """ - - field_names = set(f.name for f in fields(cls)) - filtered_dict = {k: v for k, v in d.items() if k in field_names} - - return cls(**filtered_dict) - - -def log_except(fn): - """ - Log unhandled exceptions to a logger instead of `stderr`. - """ - - def wrapper(self, *args, **kwargs): - try: - return fn(self, *args, **kwargs) - except Exception: - self.logger.exception(f"Exception in '{fn.__name__}':") - raise - - return wrapper - - -def request(fn): - """ - Either return json data or raise a `RequestError` if the request was - unsuccessful. - """ - - def wrapper(*args, **kwargs): - try: - resp = fn(*args, **kwargs) - except urllib3.exceptions.HTTPError as e: - raise RequestError(None, f"Failed to connect: {e}") from None - - if resp.status < 200 or resp.status >= 300: - raise RequestError( - resp.status, - f"Failed to get response from '{resp.geturl()}':\n{resp.data}", - ) - - return {} if resp.status == 204 else json.loads(resp.data) - - return wrapper - - -def except_deleted(fn): - """ - Ignore the `RequestError` on 404s, the content might have been removed. - """ - - def wrapper(*args, **kwargs): - try: - return fn(*args, **kwargs) - except RequestError as e: - if e.status != 404: - raise - - return wrapper - - -def hash_str(string: str) -> int: - """ - Create the hash for a string - """ - - hash = 5381 - - for ch in string: - hash = ((hash << 5) + hash) + ord(ch) - - return hash & 0xFFFFFFFF