diff --git a/appservice/README.md b/appservice/README.md index d0b0393..9069967 100644 --- a/appservice/README.md +++ b/appservice/README.md @@ -49,20 +49,23 @@ A path can optionally be passed as the first argument to `main.py`. This path wi Eg. Running `python3 main.py /path/to/my/dir` will store the database and logs in `/path/to/my/dir`. `$PWD` is used by default if no path is specified. +After setting up the bridge, send a direct message to `@appservice-discord:domain.tld` containing the channel ID to be bridged (`!bridge 123456`). + This bridge is written with: * `bottle`: Receiving events from the homeserver. * `urllib3`: Sending requests, thread safety. * `websockets`: Connecting to Discord. (Big thanks to an anonymous person "nesslersreagent" for figuring out the initial connection mess.) ## NOTES + * A basic sqlite database is used for keeping track of bridged rooms. * Logs are saved to the `appservice.log` file in `$PWD` or the specified directory. * For avatars to show up on Discord, you must have a [reverse proxy](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf) set up on your homeserver as the bridge does not specify the homeserver port when passing the avatar url. -* It is not possible to add normal Discord bot functionality like commands as this bridge does not use `discord.py`. +* It is not possible to add "normal" Discord bot functionality like commands as this bridge does not use `discord.py`. -* [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) must be enabled for your Discord bot. +* [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) for members and presence must be enabled for your Discord bot. * This Appservice might not work well for bridging a large number of rooms since it is mostly synchronous. However, it wouldn't take much effort to port it to `asyncio` and `aiohttp` if desired. diff --git a/appservice/appservice.py b/appservice/appservice.py index e479104..c0f464e 100644 --- a/appservice/appservice.py +++ b/appservice/appservice.py @@ -2,13 +2,14 @@ import json import logging import urllib.parse import uuid -from typing import List, Union +from typing import Union import bottle import urllib3 import matrix -from misc import dict_cls, except_deleted, log_except, request +from cache import Cache +from misc import log_except, request class AppService(bottle.Bottle): @@ -37,13 +38,17 @@ class AppService(bottle.Bottle): method="PUT", ) + Cache.cache["m_rooms"] = {} + def handle_event(self, event: dict) -> None: event_type = event.get("type") - if event_type == "m.room.member" or event_type == "m.room.message": - obj = self.get_event_object(event) - elif event_type == "m.room.redaction": - obj = event + if event_type in ( + "m.room.member", + "m.room.message", + "m.room.redaction", + ): + obj = matrix.Event(event) else: self.logger.info(f"Unknown event type: {event_type}") return @@ -86,23 +91,14 @@ class AppService(bottle.Bottle): def mxc_url(self, mxc: str) -> str: try: homeserver, media_id = mxc.replace("mxc://", "").split("/") - converted = ( - f"https://{self.server_name}/_matrix/media/r0/download/" - f"{homeserver}/{media_id}" - ) except ValueError: - converted = "" + return "" - return converted - - def get_event_object(self, event: dict) -> matrix.Event: - # TODO use caching and invalidate old cache on member events. - event["author"] = dict_cls( - self.get_profile(event["sender"]), matrix.User + return ( + f"https://{self.server_name}/_matrix/media/r0/download/" + f"{homeserver}/{media_id}" ) - return matrix.Event(event) - def join_room(self, room_id: str, mxid: str = "") -> None: self.send( "POST", @@ -117,42 +113,25 @@ class AppService(bottle.Bottle): params={"user_id": mxid} if mxid else {}, ) - def get_profile(self, mxid: str) -> dict: - resp = except_deleted(self.send)("GET", f"/profile/{mxid}") - - # No profile exists for the user. - if not resp: - return {} - - avatar_url = resp.get("avatar_url") - - if avatar_url: - avatar_url = self.mxc_url(avatar_url) - - return { - "avatar_url": avatar_url, - "displayname": resp.get("displayname"), - } - - def get_members(self, room_id: str) -> List[str]: - resp = self.send( - "GET", - f"/rooms/{room_id}/members", - params={"membership": "join", "not_membership": "leave"}, - ) - - return [ - content["sender"] - for content in resp["chunk"] - if content["content"]["membership"] == "join" - ] - def get_room_id(self, alias: str) -> str: + with Cache.lock: + room = Cache.cache["m_rooms"].get(alias) + if room: + return room + resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}") - # TODO cache ? + room_id = resp["room_id"] - return resp["room_id"] + with Cache.lock: + Cache.cache["m_rooms"][alias] = room_id + + return room_id + + def get_event(self, event_id: str, room_id: str) -> matrix.Event: + resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}") + + return matrix.Event(resp) def upload(self, url: str) -> str: """ @@ -211,12 +190,12 @@ class AppService(bottle.Bottle): ) -> dict: params["access_token"] = self.as_token headers = {"Content-Type": content_type} - content = json.dumps(content) if isinstance(content, dict) else content + payload = json.dumps(content) if isinstance(content, dict) else content endpoint = ( f"{self.base_url}{endpoint}{path}?" f"{urllib.parse.urlencode(params)}" ) return self.http.request( - method, endpoint, body=content, headers=headers + method, endpoint, body=payload, headers=headers ) diff --git a/appservice/cache.py b/appservice/cache.py new file mode 100644 index 0000000..8d14443 --- /dev/null +++ b/appservice/cache.py @@ -0,0 +1,6 @@ +import threading + + +class Cache: + cache = {} + lock = threading.Lock() diff --git a/appservice/db.py b/appservice/db.py index 295fa81..b2603e7 100644 --- a/appservice/db.py +++ b/appservice/db.py @@ -4,11 +4,11 @@ import threading from typing import List -class DataBase(object): +class DataBase: def __init__(self, db_file) -> None: self.create(db_file) - # The database is accessed via both the threads. + # The database is accessed via multiple threads. self.lock = threading.Lock() def create(self, db_file) -> None: @@ -92,7 +92,7 @@ class DataBase(object): room = self.cur.fetchone() - # Return an empty string if nothing is bridged. + # Return an empty string if the channel is not bridged. return "" if not room else room["channel_id"] def list_channels(self) -> List[str]: @@ -116,6 +116,8 @@ class DataBase(object): self.cur.execute("SELECT * FROM users") users = self.cur.fetchall() - user = [user for user in users if user["mxid"] == mxid] + user: dict = next( + iter([user for user in users if user["mxid"] == mxid]), {} + ) - return user[0] if user else {} + return user diff --git a/appservice/discord.py b/appservice/discord.py index 3567416..6c5239b 100644 --- a/appservice/discord.py +++ b/appservice/discord.py @@ -1,10 +1,16 @@ from dataclasses import dataclass +from misc import dict_cls CDN_URL = "https://cdn.discordapp.com" +ID_LEN = 18 + + +def bitmask(bit: int) -> int: + return 1 << bit @dataclass -class Channel(object): +class Channel: id: str type: str guild_id: str = "" @@ -13,13 +19,32 @@ class Channel(object): @dataclass -class Emote(object): +class Emote: animated: bool id: str name: str -class User(object): +@dataclass +class MessageReference: + message_id: str + channel_id: str + guild_id: str + + +@dataclass +class Typing: + user_id: str + channel_id: str + + +@dataclass +class Webhook: + id: str + token: str + + +class User: def __init__(self, user: dict) -> None: self.discriminator = user["discriminator"] self.id = user["id"] @@ -38,45 +63,57 @@ class User(object): self.avatar_url = f"{CDN_URL}/avatars/{self.id}/{avatar}.{ext}" -class Message(object): +class Guild: + def __init__(self, guild: dict) -> None: + self.guild_id = guild["id"] + self.channels = [dict_cls(c, Channel) for c in guild["channels"]] + self.emojis = [dict_cls(e, Emote) for e in guild["emojis"]] + members = [member["user"] for member in guild["members"]] + self.members = [User(m) for m in members] + + +class GuildEmojisUpdate: + def __init__(self, update: dict) -> None: + self.guild_id = update["guild_id"] + self.emojis = [dict_cls(e, Emote) for e in update["emojis"]] + + +class GuildMembersChunk: + def __init__(self, chunk: dict) -> None: + self.chunk_index = chunk["chunk_index"] + self.chunk_count = chunk["chunk_count"] + self.guild_id = chunk["guild_id"] + self.members = [User(m) for m in chunk["members"]] + + +class GuildMemberUpdate: + def __init__(self, update: dict) -> None: + self.guild_id = update["guild_id"] + self.user = User(update["user"]) + + +class Message: def __init__(self, message: dict) -> None: self.attachments = message.get("attachments", []) self.channel_id = message["channel_id"] self.content = message.get("content", "") self.id = message["id"] - self.reference = message.get("message_reference", {}).get( - "message_id", "" - ) self.webhook_id = message.get("webhook_id", "") self.mentions = [ User(mention) for mention in message.get("mentions", []) ] + ref = message.get("message_reference") + + self.reference = dict_cls(ref, MessageReference) if ref else None + author = message.get("author") self.author = User(author) if author else None -@dataclass -class DeletedMessage(object): - channel_id: str - id: str - - -@dataclass -class Typing(object): - user_id: str - channel_id: str - - -@dataclass -class Webhook(object): - id: str - token: str - - -class ChannelType(object): +class ChannelType: GUILD_TEXT = 0 DM = 1 GUILD_VOICE = 2 @@ -86,7 +123,7 @@ class ChannelType(object): GUILD_STORE = 6 -class InteractionResponseType(object): +class InteractionResponseType: PONG = 0 ACKNOWLEDGE = 1 CHANNEL_MESSAGE = 2 @@ -94,10 +131,7 @@ class InteractionResponseType(object): ACKNOWLEDGE_WITH_SOURCE = 5 -class GatewayIntents(object): - def bitmask(bit: int) -> int: - return 1 << bit - +class GatewayIntents: GUILDS = bitmask(0) GUILD_MEMBERS = bitmask(1) GUILD_BANS = bitmask(2) @@ -115,7 +149,7 @@ class GatewayIntents(object): DIRECT_MESSAGE_TYPING = bitmask(14) -class GatewayOpCodes(object): +class GatewayOpCodes: DISPATCH = 0 HEARTBEAT = 1 IDENTIFY = 2 @@ -129,7 +163,7 @@ class GatewayOpCodes(object): HEARTBEAT_ACK = 11 -class Payloads(object): +class Payloads: def __init__(self, token: str) -> None: self.seq = self.session = None self.token = token @@ -143,27 +177,19 @@ class Payloads(object): "d": { "token": self.token, "intents": GatewayIntents.GUILDS + | GatewayIntents.GUILD_EMOJIS + | GatewayIntents.GUILD_MEMBERS | GatewayIntents.GUILD_MESSAGES - | GatewayIntents.GUILD_MESSAGE_TYPING, + | GatewayIntents.GUILD_MESSAGE_TYPING + | GatewayIntents.GUILD_PRESENCES, "properties": { "$os": "discord", - "$browser": "discord", + "$browser": "Discord Client", "$device": "discord", }, }, } - def QUERY(self, guild_id: str, query: str, limit: int = 1) -> dict: - """ - Return the Payload to query a member from a guild ID. - Return only a single match if `limit` isn't specified. - """ - - return { - "op": GatewayOpCodes.REQUEST_GUILD_MEMBERS, - "d": {"guild_id": guild_id, "query": query, "limit": limit}, - } - def RESUME(self) -> dict: return { "op": GatewayOpCodes.RESUME, diff --git a/appservice/gateway.py b/appservice/gateway.py index e1f70c3..7fd2c08 100644 --- a/appservice/gateway.py +++ b/appservice/gateway.py @@ -8,31 +8,27 @@ import urllib3 import websockets import discord -from misc import dict_cls, log_except, request, wrap_async +from misc import dict_cls, log_except, request -class Gateway(object): +class Gateway: def __init__(self, http: urllib3.PoolManager, token: str): self.http = http self.token = token self.logger = logging.getLogger("discord") - self.cdn_url = "https://cdn.discordapp.com" self.Payloads = discord.Payloads(self.token) - self.loop = self.websocket = None - - self.query_cache = {} + self.websocket = None @log_except async def run(self) -> None: - self.loop = asyncio.get_running_loop() - self.query_ev = asyncio.Event() - - self.heartbeat_task = None + self.heartbeat_task: asyncio.Future = None self.resume = False + gateway_url = self.get_gateway_url() + while True: try: - await self.gateway_handler(self.get_gateway_url()) + await self.gateway_handler(gateway_url) except ( websockets.ConnectionClosedError, websockets.InvalidMessage, @@ -55,26 +51,71 @@ class Gateway(object): await asyncio.sleep(interval_ms / 1000) await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT())) - def query_handler(self, data: dict) -> None: - members = data["members"] - guild_id = data["guild_id"] + async def handle_resp(self, data: dict) -> None: + data_dict = data["d"] - for member in members: - user = member["user"] - self.query_cache[guild_id].append(user) + opcode = data["op"] - self.query_ev.set() + seq = data["s"] + + if seq: + self.Payloads.seq = seq + + if opcode == discord.GatewayOpCodes.DISPATCH: + otype = data["t"] + + if otype == "READY": + self.Payloads.session = data_dict["session_id"] + + self.logger.info("READY") + else: + self.handle_otype(data_dict, otype) + elif opcode == discord.GatewayOpCodes.HELLO: + heartbeat_interval = data_dict.get("heartbeat_interval") + + self.logger.info(f"Heartbeat Interval: {heartbeat_interval}") + + # Send periodic hearbeats to gateway. + self.heartbeat_task = asyncio.ensure_future( + self.heartbeat_handler(heartbeat_interval) + ) + + await self.websocket.send( + json.dumps( + self.Payloads.RESUME() + if self.resume + else self.Payloads.IDENTIFY() + ) + ) + elif opcode == discord.GatewayOpCodes.RECONNECT: + self.logger.info("Received RECONNECT.") + + self.resume = True + await self.websocket.close() + elif opcode == discord.GatewayOpCodes.INVALID_SESSION: + self.logger.info("Received INVALID_SESSION.") + + self.resume = False + await self.websocket.close() + elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK: + # NOP + pass + else: + self.logger.info( + "Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}" + ) def handle_otype(self, data: dict, otype: str) -> None: - if otype == "MESSAGE_CREATE" or otype == "MESSAGE_UPDATE": + if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"): obj = discord.Message(data) - elif otype == "MESSAGE_DELETE": - obj = dict_cls(data, discord.DeletedMessage) elif otype == "TYPING_START": obj = dict_cls(data, discord.Typing) - elif otype == "GUILD_MEMBERS_CHUNK": - self.query_handler(data) - return + elif otype == "GUILD_CREATE": + obj = discord.Guild(data) + elif otype == "GUILD_MEMBER_UPDATE": + obj = discord.GuildMemberUpdate(data) + elif otype == "GUILD_EMOJIS_UPDATE": + obj = discord.GuildEmojisUpdate(data) else: self.logger.info(f"Unknown OTYPE: {otype}") return @@ -90,119 +131,20 @@ class Gateway(object): try: func(obj) except Exception: - self.logger.exception(f"Ignoring exception in {func}:") + self.logger.exception(f"Ignoring exception in '{func.__name__}':") async def gateway_handler(self, gateway_url: str) -> None: async with websockets.connect( f"{gateway_url}/?v=8&encoding=json" ) as websocket: self.websocket = websocket + async for message in websocket: - data = json.loads(message) - data_dict = data.get("d") - - opcode = data.get("op") - - seq = data.get("s") - if seq: - self.Payloads.seq = seq - - if opcode == discord.GatewayOpCodes.DISPATCH: - otype = data.get("t") - - if otype == "READY": - self.Payloads.session = data_dict["session_id"] - - self.logger.info("READY") - - else: - self.handle_otype(data_dict, otype) - - elif opcode == discord.GatewayOpCodes.HELLO: - heartbeat_interval = data_dict.get("heartbeat_interval") - - self.logger.info( - f"Heartbeat Interval: {heartbeat_interval}" - ) - - # Send periodic hearbeats to gateway. - self.heartbeat_task = asyncio.ensure_future( - self.heartbeat_handler(heartbeat_interval) - ) - - await websocket.send( - json.dumps( - self.Payloads.RESUME() - if self.resume - else self.Payloads.IDENTIFY() - ) - ) - - elif opcode == discord.GatewayOpCodes.RECONNECT: - self.logger.info("Received RECONNECT.") - - self.resume = True - await websocket.close() - - elif opcode == discord.GatewayOpCodes.INVALID_SESSION: - self.logger.info("Received INVALID_SESSION.") - - self.resume = False - await websocket.close() - - elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK: - # NOP - pass - - else: - self.logger.info( - f"Unknown OP code {opcode}:\n" - f"{json.dumps(data, indent=4)}" - ) - - @wrap_async - async def query_member(self, guild_id: str, name: str) -> discord.User: - """ - Query the members for a given guild and return the first match. - """ - - self.query_ev.clear() - - def query(): - if not self.query_cache.get(guild_id): - self.query_cache[guild_id] = [] - - user = [ - user - for user in self.query_cache[guild_id] - if name.lower() in user["username"].lower() - ] - - return None if not user else discord.User(user[0]) - - user = query() - - if user: - return user - - if not self.websocket or self.websocket.closed: - self.logger.warning("Not fetching members, websocket closed.") - return - - await self.websocket.send( - json.dumps(self.Payloads.QUERY(guild_id, name)) - ) - - # TODO clean this mess. - - # Wait for our websocket to receive the chunk. - await asyncio.wait_for(self.query_ev.wait(), timeout=5) - - return query() + await self.handle_resp(json.loads(message)) def get_channel(self, channel_id: str) -> discord.Channel: """ - Get the channel object for a given channel ID. + Get the channel for a given channel ID. """ resp = self.send("GET", f"/channels/{channel_id}") @@ -259,9 +201,17 @@ class Gateway(object): f"{message_id}", ) - def send_webhook(self, webhook: discord.Webhook, **kwargs) -> str: - content = { - **kwargs, + def send_webhook( + self, + webhook: discord.Webhook, + avatar_url: str, + content: str, + username: str, + ) -> discord.Message: + payload = { + "avatar_url": avatar_url, + "content": content, + "username": username, # Disable 'everyone' and 'role' mentions. "allowed_mentions": {"parse": ["users"]}, } @@ -269,11 +219,11 @@ class Gateway(object): resp = self.send( "POST", f"/webhooks/{webhook.id}/{webhook.token}", - content, + payload, {"wait": True}, ) - return resp["id"] + return discord.Message(resp) def send_message(self, message: str, channel_id: str) -> None: self.send( @@ -294,8 +244,8 @@ class Gateway(object): } # 'body' being an empty dict breaks "GET" requests. - content = json.dumps(content) if content else None + payload = json.dumps(content) if content else None return self.http.request( - method, endpoint, body=content, headers=headers + method, endpoint, body=payload, headers=headers ) diff --git a/appservice/main.py b/appservice/main.py index bbdec7a..b2b6c75 100644 --- a/appservice/main.py +++ b/appservice/main.py @@ -5,21 +5,19 @@ import os import re import sys import threading -from typing import Dict, Tuple, Union +from typing import Dict, List, Tuple import urllib3 import discord import matrix from appservice import AppService +from cache import Cache from db import DataBase from errors import RequestError from gateway import Gateway from misc import dict_cls, except_deleted, hash_str -# TODO should this be cleared periodically ? -message_cache: Dict[str, Union[discord.Webhook, str]] = {} - class MatrixClient(AppService): def __init__(self, config: dict, http: urllib3.PoolManager) -> None: @@ -27,9 +25,11 @@ class MatrixClient(AppService): self.db = DataBase(config["database"]) self.discord = DiscordClient(self, config, http) - self.emote_cache: Dict[str, str] = {} self.format = "_discord_" # "{@,#}_discord_1234:localhost" + for k in ("m_emotes", "m_members", "m_messages"): + Cache.cache[k] = {} + def handle_bridge(self, message: matrix.Event) -> None: # Ignore events that aren't for us. if message.sender.split(":")[ @@ -37,6 +37,7 @@ class MatrixClient(AppService): ] != self.server_name or not message.body.startswith("!bridge"): return + # Get the channel ID. try: channel = message.body.split()[1] except IndexError: @@ -46,7 +47,7 @@ class MatrixClient(AppService): try: channel = self.discord.get_channel(channel) except RequestError as e: - # The channel can be invalid or we may not have permission. + # The channel can be invalid or we may not have permissions. self.logger.warning(f"Failed to fetch channel {channel}: {e}") return @@ -61,7 +62,15 @@ class MatrixClient(AppService): self.create_room(channel, message.sender) def on_member(self, event: matrix.Event) -> None: - # Ignore events that aren't for us. + with Cache.lock: + # Just lazily clear the whole member cache on + # membership update events. + if event.room_id in Cache.cache["m_members"]: + self.logger.info( + f"Clearing member cache for room '{event.room_id}'." + ) + del Cache.cache["m_members"][event.room_id] + if ( event.sender.split(":")[-1] != self.server_name or event.state_key != self.user_id @@ -70,7 +79,7 @@ class MatrixClient(AppService): return # Join the direct message room. - self.logger.info(f"Joining direct message room {event.room_id}.") + self.logger.info(f"Joining direct message room '{event.room_id}'.") self.join_room(event.room_id) def on_message(self, message: matrix.Event) -> None: @@ -88,51 +97,78 @@ class MatrixClient(AppService): if not channel_id: return - webhook = self.discord.get_webhook(channel_id, "matrix_bridge") + author = self.get_members(message.room_id)[message.sender] + + if not author.display_name: + author.display_name = message.sender + + webhook = self.discord.get_webhook( + channel_id, self.discord.webhook_name + ) if message.relates_to and message.reltype == "m.replace": - relation = message_cache.get(message.relates_to) + with Cache.lock: + message_id = Cache.cache["m_messages"].get(message.relates_to) - if not message.new_body or not relation: + if not message_id or not message.new_body: return - message.new_body = self.process_message( - channel_id, message.new_body - ) + message.new_body = self.process_message(message) except_deleted(self.discord.edit_webhook)( - message.new_body, relation["message_id"], webhook + message.new_body, message_id, webhook ) - else: message.body = ( f"`{message.body}`: {self.mxc_url(message.attachment)}" if message.attachment - else self.process_message(channel_id, message.body) + else self.process_message(message) ) - message_cache[message.event_id] = { - "message_id": self.discord.send_webhook( - webhook, - avatar_url=message.author.avatar_url, - content=message.body, - username=message.author.displayname, - ), - "webhook": webhook, - } + message_id = self.discord.send_webhook( + webhook, + self.mxc_url(author.avatar_url), + message.body, + author.display_name, + ).id - @except_deleted - def on_redaction(self, event: dict) -> None: - redacts = event["redacts"] + with Cache.lock: + Cache.cache["m_messages"][message.id] = message_id - event = message_cache.get(redacts) + def on_redaction(self, event: matrix.Event) -> None: + with Cache.lock: + message_id = Cache.cache["m_messages"].get(event.redacts) - if not event: + if not message_id: return - self.discord.delete_webhook(event["message_id"], event["webhook"]) + webhook = self.discord.get_webhook( + self.db.get_channel(event.room_id), self.discord.webhook_name + ) - message_cache.pop(redacts) + except_deleted(self.discord.delete_webhook)(message_id, webhook) + + with Cache.lock: + del Cache.cache["m_messages"][event.redacts] + + def get_members(self, room_id: str) -> Dict[str, matrix.User]: + with Cache.lock: + cached = Cache.cache["m_members"].get(room_id) + + if cached: + return cached + + resp = self.send("GET", f"/rooms/{room_id}/joined_members") + + joined = resp["joined"] + + for k, v in joined.items(): + joined[k] = dict_cls(v, matrix.User) + + with Cache.lock: + Cache.cache["m_members"][room_id] = joined + + return joined def create_room(self, channel: discord.Channel, sender: str) -> None: """ @@ -164,7 +200,11 @@ class MatrixClient(AppService): self.db.add_room(resp["room_id"], channel.id) def create_message_event( - self, message: str, emotes: dict, edit: str = "", reply: str = "" + self, + message: str, + emotes: dict, + edit: str = "", + reference: discord.MessageReference = None, ) -> dict: content = { "body": message, @@ -173,20 +213,39 @@ class MatrixClient(AppService): "formatted_body": self.get_fmt(message, emotes), } - event = message_cache.get(reply) + if reference: + # Reply to a Discord message. + with Cache.lock: + event_id = Cache.cache["d_messages"].get(reference.message_id) - if event: - content = { - **content, - "m.relates_to": { - "m.in_reply_to": {"event_id": event["event_id"]} - }, - "formatted_body": f"""
\ -\ -In reply to\ -{event["mxid"]}
{event["body"]}
\ + # Reply to a Matrix message. (maybe) + if not event_id: + with Cache.lock: + event_id = [ + k + for k, v in Cache.cache["m_messages"].items() + if v == reference.message_id + ] + event_id = next(iter(event_id), "") + + if reference and event_id: + event = except_deleted(self.get_event)( + event_id, + self.get_room_id(self.discord.matrixify(reference.channel_id)), + ) + if event: + content = { + **content, + "body": ( + f"> <{event.sender}> {event.body}\n{content['body']}" + ), + "m.relates_to": {"m.in_reply_to": {"event_id": event.id}}, + "formatted_body": f"""
\ +\ +In reply to\ +{event.sender}
{event.formatted_body}
\ {content["formatted_body"]}""", - } + } if edit: content = { @@ -227,63 +286,67 @@ In reply to\ for emote in emotes ] - [thread.start() for thread in upload_threads] - [thread.join() for thread in upload_threads] + # Acquire the lock before starting the threads to avoid resource + # contention by tens of threads at once. + with Cache.lock: + for thread in upload_threads: + thread.start() + for thread in upload_threads: + thread.join() - for emote in emotes: - emote_ = self.emote_cache.get(emote) + with Cache.lock: + for emote in emotes: + emote_ = Cache.cache["m_emotes"].get(emote) - if emote_: - emote = f":{emote}:" - message = message.replace( - emote, - f"""\"{emote}\"""", - ) + ) return message - def process_message(self, channel_id: str, message: str) -> str: + def process_message(self, event: matrix.Event) -> str: + message = event.new_body if event.new_body else event.body + message = message[:2000] # Discord limit. + id_regex = f"[0-9]{{{discord.ID_LEN}}}" + emotes = re.findall(r":(\w*):", message) - mentions = re.findall(r"(@(\w*))", message) + mentions = re.findall( + f"@{self.format}{id_regex}:{re.escape(self.server_name)}", + event.formatted_body, + ) - # Remove the puppet user's username from replies. - message = re.sub(f"<@{self.format}.+?>", "", message) - - added_emotes = [] - for emote in emotes: - # Don't replace emote names with IDs multiple times. - if emote not in added_emotes: - added_emotes.append(emote) - emote_ = self.discord.emote_cache.get(emote) + with Cache.lock: + for emote in set(emotes): + emote_ = Cache.cache["d_emotes"].get(emote) if emote_: message = message.replace(f":{emote}:", emote_) - # Don't unnecessarily fetch the channel. - if mentions: - guild_id = self.discord.get_channel(channel_id).guild_id + for mention in set(mentions): + username = self.db.fetch_user(mention).get("username") + if username: + match = re.search(id_regex, mention) - # TODO this can block for too long if a long list is to be fetched. - for mention in mentions: - if not mention[1]: - continue - - try: - member = self.discord.query_member(guild_id, mention[1]) - except (asyncio.TimeoutError, RuntimeError): - continue - - if member: - message = message.replace(mention[0], member.mention) + if match: + # Replace the 'mention' so that the user is tagged + # in the case of replies aswell. + # '> <@_discord_1234:localhost> Message' + for replace in (mention, username): + message = message.replace( + replace, f"<@{match.group()}>" + ) return message def upload_emote(self, emote_name: str, emote_id: str) -> None: # There won't be a race condition here, since only a unique # set of emotes are uploaded at a time. - if emote_name in self.emote_cache: + if emote_name in Cache.cache["m_emotes"]: return emote_url = f"{discord.CDN_URL}/emojis/{emote_id}" @@ -291,7 +354,8 @@ height=\"32\" src=\"{emote_}\" data-mx-emoticon />""", # We don't want the message to be dropped entirely if an emote # fails to upload for some reason. try: - self.emote_cache[emote_name] = self.upload(emote_url) + # TODO This is not thread safe, but we're protected by the GIL. + Cache.cache["m_emotes"][emote_name] = self.upload(emote_url) except RequestError as e: self.logger.warning(f"Failed to upload emote {emote_id}: {e}") @@ -340,66 +404,19 @@ class DiscordClient(Gateway): super().__init__(http, config["discord_token"]) self.app = appservice - self.emote_cache: Dict[str, str] = {} - self.webhook_cache: Dict[str, discord.Webhook] = {} + self.webhook_name = "matrix_bridge" - async def sync(self) -> None: - """ - Periodically compare the usernames and avatar URLs with Discord - and update if they differ. Also synchronise emotes. - """ - - # TODO use websocket events and requests. - - def sync_emotes(guilds: set): - emotes = [] - - for guild in guilds: - [emotes.append(emote) for emote in (self.get_emotes(guild))] - - self.emote_cache.clear() # Clear deleted/renamed emotes. - - for emote in emotes: - self.emote_cache[f"{emote.name}"] = ( - f"<{'a' if emote.animated else ''}:" - f"{emote.name}:{emote.id}>" - ) - - def sync_users(guilds: set): - for guild in guilds: - [ - self.sync_profile(user, self.matrixify(user.id, user=True)) - for user in self.get_members(guild) - ] - - while True: - guilds = set() # Avoid duplicates. - - try: - for channel in self.app.db.list_channels(): - guilds.add(self.get_channel(channel).guild_id) - - sync_emotes(guilds) - sync_users(guilds) - # Don't let the background task die. - except RequestError: - self.logger.exception( - "Ignoring exception during background sync:" - ) - - await asyncio.sleep(120) # Check every 2 minutes. - - async def start(self) -> None: - asyncio.ensure_future(self.sync()) - - await self.run() + for k in ("d_emotes", "d_messages", "d_webhooks"): + Cache.cache[k] = {} def to_return(self, message: discord.Message) -> bool: + with Cache.lock: + hook_ids = [hook.id for hook in Cache.cache["d_webhooks"].values()] + return ( message.channel_id not in self.app.db.list_channels() or not message.author # Embeds can be weird sometimes. - or message.webhook_id - in [hook.id for hook in self.webhook_cache.values()] + or message.webhook_id in hook_ids ) def matrixify(self, id: str, user: bool = False) -> str: @@ -408,11 +425,13 @@ class DiscordClient(Gateway): f"{self.app.server_name}" ) - def sync_profile(self, user: discord.User, mxid: str) -> None: + def sync_profile(self, user: discord.User) -> None: """ Sync the avatar and username for a puppeted user. """ + mxid = self.matrixify(user.id, user=True) + profile = self.app.db.fetch_user(mxid) # User doesn't exist. @@ -422,10 +441,10 @@ class DiscordClient(Gateway): username = f"{user.username}#{user.discriminator}" if user.avatar_url != profile["avatar_url"]: - self.logger.info(f"Updating avatar for Discord user {user.id}") + self.logger.info(f"Updating avatar for Discord user '{user.id}'") self.app.set_avatar(user.avatar_url, mxid) if username != profile["username"]: - self.logger.info(f"Updating username for Discord user {user.id}") + self.logger.info(f"Updating username for Discord user '{user.id}'") self.app.set_nick(username, mxid) def wrap(self, message: discord.Message) -> Tuple[str, str]: @@ -457,62 +476,95 @@ class DiscordClient(Gateway): self.app.set_avatar(message.author.avatar_url, mxid) if mxid not in self.app.get_members(room_id): - self.logger.info(f"Inviting user {mxid} to room {room_id}.") + self.logger.info(f"Inviting user '{mxid}' to room '{room_id}'.") self.app.send_invite(room_id, mxid) self.app.join_room(room_id, mxid) if message.webhook_id: # Sync webhooks here as they can't be accessed like guild members. - self.sync_profile(message.author, mxid) + self.sync_profile(message.author) return mxid, room_id + def cache_emotes(self, emotes: List[discord.Emote]): + # TODO maybe "namespace" emotes by guild in the cache ? + with Cache.lock: + for emote in emotes: + Cache.cache["d_emotes"][emote.name] = ( + f"<{'a' if emote.animated else ''}:" + f"{emote.name}:{emote.id}>" + ) + + def on_guild_create(self, guild: discord.Guild) -> None: + for member in guild.members: + self.sync_profile(member) + + self.cache_emotes(guild.emojis) + + def on_guild_emojis_update( + self, update: discord.GuildEmojisUpdate + ) -> None: + self.cache_emotes(update.emojis) + + def on_guild_member_update( + self, update: discord.GuildMemberUpdate + ) -> None: + self.sync_profile(update.user) + def on_message_create(self, message: discord.Message) -> None: if self.to_return(message): return mxid, room_id = self.wrap(message) - content, emotes = self.process_message(message) + content_, emotes = self.process_message(message) content = self.app.create_message_event( - content, emotes, reply=message.reference + content_, emotes, reference=message.reference ) - message_cache[message.id] = { - "body": content["body"], - "event_id": self.app.send_message(room_id, content, mxid), - "mxid": mxid, - "room_id": room_id, - } + with Cache.lock: + Cache.cache["d_messages"][message.id] = self.app.send_message( + room_id, content, mxid + ) - def on_message_delete(self, message: discord.DeletedMessage) -> None: - event = message_cache.get(message.id) + def on_message_delete(self, message: discord.Message) -> None: + with Cache.lock: + event_id = Cache.cache["d_messages"].get(message.id) - if not event: + if not event_id: return - self.app.redact(event["event_id"], event["room_id"], event["mxid"]) + room_id = self.app.get_room_id(self.matrixify(message.channel_id)) + event = except_deleted(self.app.get_event)(event_id, room_id) - message_cache.pop(message.id) + if event: + self.app.redact(event.id, event.room_id, event.sender) + + with Cache.lock: + del Cache.cache["d_messages"][message.id] def on_message_update(self, message: discord.Message) -> None: if self.to_return(message): return - event = message_cache.get(message.id) + with Cache.lock: + event_id = Cache.cache["d_messages"].get(message.id) - if not event: + if not event_id: return - content, emotes = self.process_message(message) + room_id = self.app.get_room_id(self.matrixify(message.channel_id)) + mxid = self.matrixify(message.author.id, user=True) + + content_, emotes = self.process_message(message) content = self.app.create_message_event( - content, emotes, edit=event["event_id"] + content_, emotes, edit=event_id ) - self.app.send_message(event["room_id"], content, event["mxid"]) + self.app.send_message(room_id, content, mxid) def on_typing_start(self, typing: discord.Typing) -> None: if typing.channel_id not in self.app.db.list_channels(): @@ -533,7 +585,8 @@ class DiscordClient(Gateway): """ # Check the cache first. - webhook = self.webhook_cache.get(channel_id) + with Cache.lock: + webhook = Cache.cache["d_webhooks"].get(channel_id) if webhook: return webhook @@ -551,24 +604,26 @@ class DiscordClient(Gateway): if not webhook: webhook = self.create_webhook(channel_id, name) - self.webhook_cache[channel_id] = webhook + with Cache.lock: + Cache.cache["d_webhooks"][channel_id] = webhook return webhook - def process_message(self, message: discord.Message) -> Tuple[str, str]: + def process_message(self, message: discord.Message) -> Tuple[str, Dict]: content = message.content emotes = {} regex = r"" # Mentions can either be in the form of `<@1234>` or `<@!1234>`. - for char in ("", "!"): - for member in message.mentions: + for member in message.mentions: + for char in ("", "!"): content = content.replace( f"<@{char}{member.id}>", f"@{member.username}" ) # `except_deleted` for invalid channels. - for channel in re.findall(r"<#([0-9]+)>", content): + # TODO can this block for too long ? + for channel in re.findall(r"<#([0-9]{{{discord.ID_LEN}}})>", content): channel_ = except_deleted(self.get_channel)(channel) content = content.replace( f"<#{channel}>", @@ -613,6 +668,16 @@ def config_gen(basedir: str, config_file: str) -> dict: return json.loads(f.read()) +def excepthook(exc_type, exc_value, exc_traceback): + if issubclass(exc_type, KeyboardInterrupt): + sys.__excepthook__(exc_type, exc_value, exc_traceback) + return + + logging.critical( + "Unknown exception:", exc_info=(exc_type, exc_value, exc_traceback) + ) + + def main() -> None: try: basedir = sys.argv[1] @@ -634,9 +699,9 @@ def main() -> None: ], ) - http = urllib3.PoolManager(maxsize=10) + sys.excepthook = excepthook - app = MatrixClient(config, http) + app = MatrixClient(config, urllib3.PoolManager(maxsize=10)) # Start the bottle app in a separate thread. app_thread = threading.Thread( @@ -645,7 +710,7 @@ def main() -> None: app_thread.start() try: - asyncio.run(app.discord.start()) + asyncio.run(app.discord.run()) except KeyboardInterrupt: sys.exit() diff --git a/appservice/matrix.py b/appservice/matrix.py index 5b84060..ab10124 100644 --- a/appservice/matrix.py +++ b/appservice/matrix.py @@ -2,20 +2,21 @@ from dataclasses import dataclass @dataclass -class User(object): +class User: avatar_url: str = "" - displayname: str = "" + display_name: str = "" -class Event(object): +class Event: def __init__(self, event: dict): - content = event["content"] + content = event.get("content", {}) self.attachment = content.get("url") - self.author = event["author"] self.body = content.get("body", "").strip() - self.event_id = event["event_id"] + self.formatted_body = content.get("formatted_body", "") + self.id = event["event_id"] self.is_direct = content.get("is_direct", False) + self.redacts = event.get("redacts", "") self.room_id = event["room_id"] self.sender = event["sender"] self.state_key = event.get("state_key", "") diff --git a/appservice/misc.py b/appservice/misc.py index 3befa4a..7c69459 100644 --- a/appservice/misc.py +++ b/appservice/misc.py @@ -1,4 +1,3 @@ -import asyncio import json from dataclasses import fields from typing import Any @@ -8,13 +7,13 @@ import urllib3 from errors import RequestError -def dict_cls(dict_var: dict, cls: Any) -> Any: +def dict_cls(d: dict, cls: Any) -> Any: """ Create a dataclass from a dictionary. """ field_names = set(f.name for f in fields(cls)) - filtered_dict = {k: v for k, v in dict_var.items() if k in field_names} + filtered_dict = {k: v for k, v in d.items() if k in field_names} return cls(**filtered_dict) @@ -34,22 +33,6 @@ def log_except(fn): return wrapper -def wrap_async(fn): - """ - Call an asynchronous function from a synchronous one. - """ - - def wrapper(self, *args, **kwargs): - if not self.loop: - raise RuntimeError("loop is None.") - - return asyncio.run_coroutine_threadsafe( - fn(self, *args, **kwargs), loop=self.loop - ).result() - - return wrapper - - def request(fn): """ Either return json data or raise a `RequestError` if the request was @@ -75,8 +58,7 @@ def request(fn): def except_deleted(fn): """ - Ignore the `RequestError` on 404s, the message might have been - deleted by someone else already. + Ignore the `RequestError` on 404s, the content might have been removed. """ def wrapper(*args, **kwargs): diff --git a/appservice/requirements.txt b/appservice/requirements.txt index 10a331b..7db4be2 100644 --- a/appservice/requirements.txt +++ b/appservice/requirements.txt @@ -1,3 +1,3 @@ -bottle==0.12.19 -urllib3==1.26.3 -websockets==8.1 +bottle +urllib3 +websockets