diff --git a/appservice/README.md b/appservice/README.md
index d0b0393..9069967 100644
--- a/appservice/README.md
+++ b/appservice/README.md
@@ -49,20 +49,23 @@ A path can optionally be passed as the first argument to `main.py`. This path wi
Eg. Running `python3 main.py /path/to/my/dir` will store the database and logs in `/path/to/my/dir`.
`$PWD` is used by default if no path is specified.
+After setting up the bridge, send a direct message to `@appservice-discord:domain.tld` containing the channel ID to be bridged (`!bridge 123456`).
+
This bridge is written with:
* `bottle`: Receiving events from the homeserver.
* `urllib3`: Sending requests, thread safety.
* `websockets`: Connecting to Discord. (Big thanks to an anonymous person "nesslersreagent" for figuring out the initial connection mess.)
## NOTES
+
* A basic sqlite database is used for keeping track of bridged rooms.
* Logs are saved to the `appservice.log` file in `$PWD` or the specified directory.
* For avatars to show up on Discord, you must have a [reverse proxy](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf) set up on your homeserver as the bridge does not specify the homeserver port when passing the avatar url.
-* It is not possible to add normal Discord bot functionality like commands as this bridge does not use `discord.py`.
+* It is not possible to add "normal" Discord bot functionality like commands as this bridge does not use `discord.py`.
-* [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) must be enabled for your Discord bot.
+* [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) for members and presence must be enabled for your Discord bot.
* This Appservice might not work well for bridging a large number of rooms since it is mostly synchronous. However, it wouldn't take much effort to port it to `asyncio` and `aiohttp` if desired.
diff --git a/appservice/appservice.py b/appservice/appservice.py
index e479104..c0f464e 100644
--- a/appservice/appservice.py
+++ b/appservice/appservice.py
@@ -2,13 +2,14 @@ import json
import logging
import urllib.parse
import uuid
-from typing import List, Union
+from typing import Union
import bottle
import urllib3
import matrix
-from misc import dict_cls, except_deleted, log_except, request
+from cache import Cache
+from misc import log_except, request
class AppService(bottle.Bottle):
@@ -37,13 +38,17 @@ class AppService(bottle.Bottle):
method="PUT",
)
+ Cache.cache["m_rooms"] = {}
+
def handle_event(self, event: dict) -> None:
event_type = event.get("type")
- if event_type == "m.room.member" or event_type == "m.room.message":
- obj = self.get_event_object(event)
- elif event_type == "m.room.redaction":
- obj = event
+ if event_type in (
+ "m.room.member",
+ "m.room.message",
+ "m.room.redaction",
+ ):
+ obj = matrix.Event(event)
else:
self.logger.info(f"Unknown event type: {event_type}")
return
@@ -86,23 +91,14 @@ class AppService(bottle.Bottle):
def mxc_url(self, mxc: str) -> str:
try:
homeserver, media_id = mxc.replace("mxc://", "").split("/")
- converted = (
- f"https://{self.server_name}/_matrix/media/r0/download/"
- f"{homeserver}/{media_id}"
- )
except ValueError:
- converted = ""
+ return ""
- return converted
-
- def get_event_object(self, event: dict) -> matrix.Event:
- # TODO use caching and invalidate old cache on member events.
- event["author"] = dict_cls(
- self.get_profile(event["sender"]), matrix.User
+ return (
+ f"https://{self.server_name}/_matrix/media/r0/download/"
+ f"{homeserver}/{media_id}"
)
- return matrix.Event(event)
-
def join_room(self, room_id: str, mxid: str = "") -> None:
self.send(
"POST",
@@ -117,42 +113,25 @@ class AppService(bottle.Bottle):
params={"user_id": mxid} if mxid else {},
)
- def get_profile(self, mxid: str) -> dict:
- resp = except_deleted(self.send)("GET", f"/profile/{mxid}")
-
- # No profile exists for the user.
- if not resp:
- return {}
-
- avatar_url = resp.get("avatar_url")
-
- if avatar_url:
- avatar_url = self.mxc_url(avatar_url)
-
- return {
- "avatar_url": avatar_url,
- "displayname": resp.get("displayname"),
- }
-
- def get_members(self, room_id: str) -> List[str]:
- resp = self.send(
- "GET",
- f"/rooms/{room_id}/members",
- params={"membership": "join", "not_membership": "leave"},
- )
-
- return [
- content["sender"]
- for content in resp["chunk"]
- if content["content"]["membership"] == "join"
- ]
-
def get_room_id(self, alias: str) -> str:
+ with Cache.lock:
+ room = Cache.cache["m_rooms"].get(alias)
+ if room:
+ return room
+
resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}")
- # TODO cache ?
+ room_id = resp["room_id"]
- return resp["room_id"]
+ with Cache.lock:
+ Cache.cache["m_rooms"][alias] = room_id
+
+ return room_id
+
+ def get_event(self, event_id: str, room_id: str) -> matrix.Event:
+ resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}")
+
+ return matrix.Event(resp)
def upload(self, url: str) -> str:
"""
@@ -211,12 +190,12 @@ class AppService(bottle.Bottle):
) -> dict:
params["access_token"] = self.as_token
headers = {"Content-Type": content_type}
- content = json.dumps(content) if isinstance(content, dict) else content
+ payload = json.dumps(content) if isinstance(content, dict) else content
endpoint = (
f"{self.base_url}{endpoint}{path}?"
f"{urllib.parse.urlencode(params)}"
)
return self.http.request(
- method, endpoint, body=content, headers=headers
+ method, endpoint, body=payload, headers=headers
)
diff --git a/appservice/cache.py b/appservice/cache.py
new file mode 100644
index 0000000..8d14443
--- /dev/null
+++ b/appservice/cache.py
@@ -0,0 +1,6 @@
+import threading
+
+
+class Cache:
+ cache = {}
+ lock = threading.Lock()
diff --git a/appservice/db.py b/appservice/db.py
index 295fa81..b2603e7 100644
--- a/appservice/db.py
+++ b/appservice/db.py
@@ -4,11 +4,11 @@ import threading
from typing import List
-class DataBase(object):
+class DataBase:
def __init__(self, db_file) -> None:
self.create(db_file)
- # The database is accessed via both the threads.
+ # The database is accessed via multiple threads.
self.lock = threading.Lock()
def create(self, db_file) -> None:
@@ -92,7 +92,7 @@ class DataBase(object):
room = self.cur.fetchone()
- # Return an empty string if nothing is bridged.
+ # Return an empty string if the channel is not bridged.
return "" if not room else room["channel_id"]
def list_channels(self) -> List[str]:
@@ -116,6 +116,8 @@ class DataBase(object):
self.cur.execute("SELECT * FROM users")
users = self.cur.fetchall()
- user = [user for user in users if user["mxid"] == mxid]
+ user: dict = next(
+ iter([user for user in users if user["mxid"] == mxid]), {}
+ )
- return user[0] if user else {}
+ return user
diff --git a/appservice/discord.py b/appservice/discord.py
index 3567416..6c5239b 100644
--- a/appservice/discord.py
+++ b/appservice/discord.py
@@ -1,10 +1,16 @@
from dataclasses import dataclass
+from misc import dict_cls
CDN_URL = "https://cdn.discordapp.com"
+ID_LEN = 18
+
+
+def bitmask(bit: int) -> int:
+ return 1 << bit
@dataclass
-class Channel(object):
+class Channel:
id: str
type: str
guild_id: str = ""
@@ -13,13 +19,32 @@ class Channel(object):
@dataclass
-class Emote(object):
+class Emote:
animated: bool
id: str
name: str
-class User(object):
+@dataclass
+class MessageReference:
+ message_id: str
+ channel_id: str
+ guild_id: str
+
+
+@dataclass
+class Typing:
+ user_id: str
+ channel_id: str
+
+
+@dataclass
+class Webhook:
+ id: str
+ token: str
+
+
+class User:
def __init__(self, user: dict) -> None:
self.discriminator = user["discriminator"]
self.id = user["id"]
@@ -38,45 +63,57 @@ class User(object):
self.avatar_url = f"{CDN_URL}/avatars/{self.id}/{avatar}.{ext}"
-class Message(object):
+class Guild:
+ def __init__(self, guild: dict) -> None:
+ self.guild_id = guild["id"]
+ self.channels = [dict_cls(c, Channel) for c in guild["channels"]]
+ self.emojis = [dict_cls(e, Emote) for e in guild["emojis"]]
+ members = [member["user"] for member in guild["members"]]
+ self.members = [User(m) for m in members]
+
+
+class GuildEmojisUpdate:
+ def __init__(self, update: dict) -> None:
+ self.guild_id = update["guild_id"]
+ self.emojis = [dict_cls(e, Emote) for e in update["emojis"]]
+
+
+class GuildMembersChunk:
+ def __init__(self, chunk: dict) -> None:
+ self.chunk_index = chunk["chunk_index"]
+ self.chunk_count = chunk["chunk_count"]
+ self.guild_id = chunk["guild_id"]
+ self.members = [User(m) for m in chunk["members"]]
+
+
+class GuildMemberUpdate:
+ def __init__(self, update: dict) -> None:
+ self.guild_id = update["guild_id"]
+ self.user = User(update["user"])
+
+
+class Message:
def __init__(self, message: dict) -> None:
self.attachments = message.get("attachments", [])
self.channel_id = message["channel_id"]
self.content = message.get("content", "")
self.id = message["id"]
- self.reference = message.get("message_reference", {}).get(
- "message_id", ""
- )
self.webhook_id = message.get("webhook_id", "")
self.mentions = [
User(mention) for mention in message.get("mentions", [])
]
+ ref = message.get("message_reference")
+
+ self.reference = dict_cls(ref, MessageReference) if ref else None
+
author = message.get("author")
self.author = User(author) if author else None
-@dataclass
-class DeletedMessage(object):
- channel_id: str
- id: str
-
-
-@dataclass
-class Typing(object):
- user_id: str
- channel_id: str
-
-
-@dataclass
-class Webhook(object):
- id: str
- token: str
-
-
-class ChannelType(object):
+class ChannelType:
GUILD_TEXT = 0
DM = 1
GUILD_VOICE = 2
@@ -86,7 +123,7 @@ class ChannelType(object):
GUILD_STORE = 6
-class InteractionResponseType(object):
+class InteractionResponseType:
PONG = 0
ACKNOWLEDGE = 1
CHANNEL_MESSAGE = 2
@@ -94,10 +131,7 @@ class InteractionResponseType(object):
ACKNOWLEDGE_WITH_SOURCE = 5
-class GatewayIntents(object):
- def bitmask(bit: int) -> int:
- return 1 << bit
-
+class GatewayIntents:
GUILDS = bitmask(0)
GUILD_MEMBERS = bitmask(1)
GUILD_BANS = bitmask(2)
@@ -115,7 +149,7 @@ class GatewayIntents(object):
DIRECT_MESSAGE_TYPING = bitmask(14)
-class GatewayOpCodes(object):
+class GatewayOpCodes:
DISPATCH = 0
HEARTBEAT = 1
IDENTIFY = 2
@@ -129,7 +163,7 @@ class GatewayOpCodes(object):
HEARTBEAT_ACK = 11
-class Payloads(object):
+class Payloads:
def __init__(self, token: str) -> None:
self.seq = self.session = None
self.token = token
@@ -143,27 +177,19 @@ class Payloads(object):
"d": {
"token": self.token,
"intents": GatewayIntents.GUILDS
+ | GatewayIntents.GUILD_EMOJIS
+ | GatewayIntents.GUILD_MEMBERS
| GatewayIntents.GUILD_MESSAGES
- | GatewayIntents.GUILD_MESSAGE_TYPING,
+ | GatewayIntents.GUILD_MESSAGE_TYPING
+ | GatewayIntents.GUILD_PRESENCES,
"properties": {
"$os": "discord",
- "$browser": "discord",
+ "$browser": "Discord Client",
"$device": "discord",
},
},
}
- def QUERY(self, guild_id: str, query: str, limit: int = 1) -> dict:
- """
- Return the Payload to query a member from a guild ID.
- Return only a single match if `limit` isn't specified.
- """
-
- return {
- "op": GatewayOpCodes.REQUEST_GUILD_MEMBERS,
- "d": {"guild_id": guild_id, "query": query, "limit": limit},
- }
-
def RESUME(self) -> dict:
return {
"op": GatewayOpCodes.RESUME,
diff --git a/appservice/gateway.py b/appservice/gateway.py
index e1f70c3..7fd2c08 100644
--- a/appservice/gateway.py
+++ b/appservice/gateway.py
@@ -8,31 +8,27 @@ import urllib3
import websockets
import discord
-from misc import dict_cls, log_except, request, wrap_async
+from misc import dict_cls, log_except, request
-class Gateway(object):
+class Gateway:
def __init__(self, http: urllib3.PoolManager, token: str):
self.http = http
self.token = token
self.logger = logging.getLogger("discord")
- self.cdn_url = "https://cdn.discordapp.com"
self.Payloads = discord.Payloads(self.token)
- self.loop = self.websocket = None
-
- self.query_cache = {}
+ self.websocket = None
@log_except
async def run(self) -> None:
- self.loop = asyncio.get_running_loop()
- self.query_ev = asyncio.Event()
-
- self.heartbeat_task = None
+ self.heartbeat_task: asyncio.Future = None
self.resume = False
+ gateway_url = self.get_gateway_url()
+
while True:
try:
- await self.gateway_handler(self.get_gateway_url())
+ await self.gateway_handler(gateway_url)
except (
websockets.ConnectionClosedError,
websockets.InvalidMessage,
@@ -55,26 +51,71 @@ class Gateway(object):
await asyncio.sleep(interval_ms / 1000)
await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT()))
- def query_handler(self, data: dict) -> None:
- members = data["members"]
- guild_id = data["guild_id"]
+ async def handle_resp(self, data: dict) -> None:
+ data_dict = data["d"]
- for member in members:
- user = member["user"]
- self.query_cache[guild_id].append(user)
+ opcode = data["op"]
- self.query_ev.set()
+ seq = data["s"]
+
+ if seq:
+ self.Payloads.seq = seq
+
+ if opcode == discord.GatewayOpCodes.DISPATCH:
+ otype = data["t"]
+
+ if otype == "READY":
+ self.Payloads.session = data_dict["session_id"]
+
+ self.logger.info("READY")
+ else:
+ self.handle_otype(data_dict, otype)
+ elif opcode == discord.GatewayOpCodes.HELLO:
+ heartbeat_interval = data_dict.get("heartbeat_interval")
+
+ self.logger.info(f"Heartbeat Interval: {heartbeat_interval}")
+
+ # Send periodic hearbeats to gateway.
+ self.heartbeat_task = asyncio.ensure_future(
+ self.heartbeat_handler(heartbeat_interval)
+ )
+
+ await self.websocket.send(
+ json.dumps(
+ self.Payloads.RESUME()
+ if self.resume
+ else self.Payloads.IDENTIFY()
+ )
+ )
+ elif opcode == discord.GatewayOpCodes.RECONNECT:
+ self.logger.info("Received RECONNECT.")
+
+ self.resume = True
+ await self.websocket.close()
+ elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
+ self.logger.info("Received INVALID_SESSION.")
+
+ self.resume = False
+ await self.websocket.close()
+ elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
+ # NOP
+ pass
+ else:
+ self.logger.info(
+ "Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}"
+ )
def handle_otype(self, data: dict, otype: str) -> None:
- if otype == "MESSAGE_CREATE" or otype == "MESSAGE_UPDATE":
+ if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"):
obj = discord.Message(data)
- elif otype == "MESSAGE_DELETE":
- obj = dict_cls(data, discord.DeletedMessage)
elif otype == "TYPING_START":
obj = dict_cls(data, discord.Typing)
- elif otype == "GUILD_MEMBERS_CHUNK":
- self.query_handler(data)
- return
+ elif otype == "GUILD_CREATE":
+ obj = discord.Guild(data)
+ elif otype == "GUILD_MEMBER_UPDATE":
+ obj = discord.GuildMemberUpdate(data)
+ elif otype == "GUILD_EMOJIS_UPDATE":
+ obj = discord.GuildEmojisUpdate(data)
else:
self.logger.info(f"Unknown OTYPE: {otype}")
return
@@ -90,119 +131,20 @@ class Gateway(object):
try:
func(obj)
except Exception:
- self.logger.exception(f"Ignoring exception in {func}:")
+ self.logger.exception(f"Ignoring exception in '{func.__name__}':")
async def gateway_handler(self, gateway_url: str) -> None:
async with websockets.connect(
f"{gateway_url}/?v=8&encoding=json"
) as websocket:
self.websocket = websocket
+
async for message in websocket:
- data = json.loads(message)
- data_dict = data.get("d")
-
- opcode = data.get("op")
-
- seq = data.get("s")
- if seq:
- self.Payloads.seq = seq
-
- if opcode == discord.GatewayOpCodes.DISPATCH:
- otype = data.get("t")
-
- if otype == "READY":
- self.Payloads.session = data_dict["session_id"]
-
- self.logger.info("READY")
-
- else:
- self.handle_otype(data_dict, otype)
-
- elif opcode == discord.GatewayOpCodes.HELLO:
- heartbeat_interval = data_dict.get("heartbeat_interval")
-
- self.logger.info(
- f"Heartbeat Interval: {heartbeat_interval}"
- )
-
- # Send periodic hearbeats to gateway.
- self.heartbeat_task = asyncio.ensure_future(
- self.heartbeat_handler(heartbeat_interval)
- )
-
- await websocket.send(
- json.dumps(
- self.Payloads.RESUME()
- if self.resume
- else self.Payloads.IDENTIFY()
- )
- )
-
- elif opcode == discord.GatewayOpCodes.RECONNECT:
- self.logger.info("Received RECONNECT.")
-
- self.resume = True
- await websocket.close()
-
- elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
- self.logger.info("Received INVALID_SESSION.")
-
- self.resume = False
- await websocket.close()
-
- elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
- # NOP
- pass
-
- else:
- self.logger.info(
- f"Unknown OP code {opcode}:\n"
- f"{json.dumps(data, indent=4)}"
- )
-
- @wrap_async
- async def query_member(self, guild_id: str, name: str) -> discord.User:
- """
- Query the members for a given guild and return the first match.
- """
-
- self.query_ev.clear()
-
- def query():
- if not self.query_cache.get(guild_id):
- self.query_cache[guild_id] = []
-
- user = [
- user
- for user in self.query_cache[guild_id]
- if name.lower() in user["username"].lower()
- ]
-
- return None if not user else discord.User(user[0])
-
- user = query()
-
- if user:
- return user
-
- if not self.websocket or self.websocket.closed:
- self.logger.warning("Not fetching members, websocket closed.")
- return
-
- await self.websocket.send(
- json.dumps(self.Payloads.QUERY(guild_id, name))
- )
-
- # TODO clean this mess.
-
- # Wait for our websocket to receive the chunk.
- await asyncio.wait_for(self.query_ev.wait(), timeout=5)
-
- return query()
+ await self.handle_resp(json.loads(message))
def get_channel(self, channel_id: str) -> discord.Channel:
"""
- Get the channel object for a given channel ID.
+ Get the channel for a given channel ID.
"""
resp = self.send("GET", f"/channels/{channel_id}")
@@ -259,9 +201,17 @@ class Gateway(object):
f"{message_id}",
)
- def send_webhook(self, webhook: discord.Webhook, **kwargs) -> str:
- content = {
- **kwargs,
+ def send_webhook(
+ self,
+ webhook: discord.Webhook,
+ avatar_url: str,
+ content: str,
+ username: str,
+ ) -> discord.Message:
+ payload = {
+ "avatar_url": avatar_url,
+ "content": content,
+ "username": username,
# Disable 'everyone' and 'role' mentions.
"allowed_mentions": {"parse": ["users"]},
}
@@ -269,11 +219,11 @@ class Gateway(object):
resp = self.send(
"POST",
f"/webhooks/{webhook.id}/{webhook.token}",
- content,
+ payload,
{"wait": True},
)
- return resp["id"]
+ return discord.Message(resp)
def send_message(self, message: str, channel_id: str) -> None:
self.send(
@@ -294,8 +244,8 @@ class Gateway(object):
}
# 'body' being an empty dict breaks "GET" requests.
- content = json.dumps(content) if content else None
+ payload = json.dumps(content) if content else None
return self.http.request(
- method, endpoint, body=content, headers=headers
+ method, endpoint, body=payload, headers=headers
)
diff --git a/appservice/main.py b/appservice/main.py
index bbdec7a..b2b6c75 100644
--- a/appservice/main.py
+++ b/appservice/main.py
@@ -5,21 +5,19 @@ import os
import re
import sys
import threading
-from typing import Dict, Tuple, Union
+from typing import Dict, List, Tuple
import urllib3
import discord
import matrix
from appservice import AppService
+from cache import Cache
from db import DataBase
from errors import RequestError
from gateway import Gateway
from misc import dict_cls, except_deleted, hash_str
-# TODO should this be cleared periodically ?
-message_cache: Dict[str, Union[discord.Webhook, str]] = {}
-
class MatrixClient(AppService):
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
@@ -27,9 +25,11 @@ class MatrixClient(AppService):
self.db = DataBase(config["database"])
self.discord = DiscordClient(self, config, http)
- self.emote_cache: Dict[str, str] = {}
self.format = "_discord_" # "{@,#}_discord_1234:localhost"
+ for k in ("m_emotes", "m_members", "m_messages"):
+ Cache.cache[k] = {}
+
def handle_bridge(self, message: matrix.Event) -> None:
# Ignore events that aren't for us.
if message.sender.split(":")[
@@ -37,6 +37,7 @@ class MatrixClient(AppService):
] != self.server_name or not message.body.startswith("!bridge"):
return
+ # Get the channel ID.
try:
channel = message.body.split()[1]
except IndexError:
@@ -46,7 +47,7 @@ class MatrixClient(AppService):
try:
channel = self.discord.get_channel(channel)
except RequestError as e:
- # The channel can be invalid or we may not have permission.
+ # The channel can be invalid or we may not have permissions.
self.logger.warning(f"Failed to fetch channel {channel}: {e}")
return
@@ -61,7 +62,15 @@ class MatrixClient(AppService):
self.create_room(channel, message.sender)
def on_member(self, event: matrix.Event) -> None:
- # Ignore events that aren't for us.
+ with Cache.lock:
+ # Just lazily clear the whole member cache on
+ # membership update events.
+ if event.room_id in Cache.cache["m_members"]:
+ self.logger.info(
+ f"Clearing member cache for room '{event.room_id}'."
+ )
+ del Cache.cache["m_members"][event.room_id]
+
if (
event.sender.split(":")[-1] != self.server_name
or event.state_key != self.user_id
@@ -70,7 +79,7 @@ class MatrixClient(AppService):
return
# Join the direct message room.
- self.logger.info(f"Joining direct message room {event.room_id}.")
+ self.logger.info(f"Joining direct message room '{event.room_id}'.")
self.join_room(event.room_id)
def on_message(self, message: matrix.Event) -> None:
@@ -88,51 +97,78 @@ class MatrixClient(AppService):
if not channel_id:
return
- webhook = self.discord.get_webhook(channel_id, "matrix_bridge")
+ author = self.get_members(message.room_id)[message.sender]
+
+ if not author.display_name:
+ author.display_name = message.sender
+
+ webhook = self.discord.get_webhook(
+ channel_id, self.discord.webhook_name
+ )
if message.relates_to and message.reltype == "m.replace":
- relation = message_cache.get(message.relates_to)
+ with Cache.lock:
+ message_id = Cache.cache["m_messages"].get(message.relates_to)
- if not message.new_body or not relation:
+ if not message_id or not message.new_body:
return
- message.new_body = self.process_message(
- channel_id, message.new_body
- )
+ message.new_body = self.process_message(message)
except_deleted(self.discord.edit_webhook)(
- message.new_body, relation["message_id"], webhook
+ message.new_body, message_id, webhook
)
-
else:
message.body = (
f"`{message.body}`: {self.mxc_url(message.attachment)}"
if message.attachment
- else self.process_message(channel_id, message.body)
+ else self.process_message(message)
)
- message_cache[message.event_id] = {
- "message_id": self.discord.send_webhook(
- webhook,
- avatar_url=message.author.avatar_url,
- content=message.body,
- username=message.author.displayname,
- ),
- "webhook": webhook,
- }
+ message_id = self.discord.send_webhook(
+ webhook,
+ self.mxc_url(author.avatar_url),
+ message.body,
+ author.display_name,
+ ).id
- @except_deleted
- def on_redaction(self, event: dict) -> None:
- redacts = event["redacts"]
+ with Cache.lock:
+ Cache.cache["m_messages"][message.id] = message_id
- event = message_cache.get(redacts)
+ def on_redaction(self, event: matrix.Event) -> None:
+ with Cache.lock:
+ message_id = Cache.cache["m_messages"].get(event.redacts)
- if not event:
+ if not message_id:
return
- self.discord.delete_webhook(event["message_id"], event["webhook"])
+ webhook = self.discord.get_webhook(
+ self.db.get_channel(event.room_id), self.discord.webhook_name
+ )
- message_cache.pop(redacts)
+ except_deleted(self.discord.delete_webhook)(message_id, webhook)
+
+ with Cache.lock:
+ del Cache.cache["m_messages"][event.redacts]
+
+ def get_members(self, room_id: str) -> Dict[str, matrix.User]:
+ with Cache.lock:
+ cached = Cache.cache["m_members"].get(room_id)
+
+ if cached:
+ return cached
+
+ resp = self.send("GET", f"/rooms/{room_id}/joined_members")
+
+ joined = resp["joined"]
+
+ for k, v in joined.items():
+ joined[k] = dict_cls(v, matrix.User)
+
+ with Cache.lock:
+ Cache.cache["m_members"][room_id] = joined
+
+ return joined
def create_room(self, channel: discord.Channel, sender: str) -> None:
"""
@@ -164,7 +200,11 @@ class MatrixClient(AppService):
self.db.add_room(resp["room_id"], channel.id)
def create_message_event(
- self, message: str, emotes: dict, edit: str = "", reply: str = ""
+ self,
+ message: str,
+ emotes: dict,
+ edit: str = "",
+ reference: discord.MessageReference = None,
) -> dict:
content = {
"body": message,
@@ -173,20 +213,39 @@ class MatrixClient(AppService):
"formatted_body": self.get_fmt(message, emotes),
}
- event = message_cache.get(reply)
+ if reference:
+ # Reply to a Discord message.
+ with Cache.lock:
+ event_id = Cache.cache["d_messages"].get(reference.message_id)
- if event:
- content = {
- **content,
- "m.relates_to": {
- "m.in_reply_to": {"event_id": event["event_id"]}
- },
- "formatted_body": f"""\
-\
-In reply to\
-{event["mxid"]}
{event["body"]}
\
+ # Reply to a Matrix message. (maybe)
+ if not event_id:
+ with Cache.lock:
+ event_id = [
+ k
+ for k, v in Cache.cache["m_messages"].items()
+ if v == reference.message_id
+ ]
+ event_id = next(iter(event_id), "")
+
+ if reference and event_id:
+ event = except_deleted(self.get_event)(
+ event_id,
+ self.get_room_id(self.discord.matrixify(reference.channel_id)),
+ )
+ if event:
+ content = {
+ **content,
+ "body": (
+ f"> <{event.sender}> {event.body}\n{content['body']}"
+ ),
+ "m.relates_to": {"m.in_reply_to": {"event_id": event.id}},
+ "formatted_body": f"""\
+\
+In reply to\
+{event.sender}
{event.formatted_body}
\
{content["formatted_body"]}""",
- }
+ }
if edit:
content = {
@@ -227,63 +286,67 @@ In reply to\
for emote in emotes
]
- [thread.start() for thread in upload_threads]
- [thread.join() for thread in upload_threads]
+ # Acquire the lock before starting the threads to avoid resource
+ # contention by tens of threads at once.
+ with Cache.lock:
+ for thread in upload_threads:
+ thread.start()
+ for thread in upload_threads:
+ thread.join()
- for emote in emotes:
- emote_ = self.emote_cache.get(emote)
+ with Cache.lock:
+ for emote in emotes:
+ emote_ = Cache.cache["m_emotes"].get(emote)
- if emote_:
- emote = f":{emote}:"
- message = message.replace(
- emote,
- f"""""",
- )
+ )
return message
- def process_message(self, channel_id: str, message: str) -> str:
+ def process_message(self, event: matrix.Event) -> str:
+ message = event.new_body if event.new_body else event.body
+
message = message[:2000] # Discord limit.
+ id_regex = f"[0-9]{{{discord.ID_LEN}}}"
+
emotes = re.findall(r":(\w*):", message)
- mentions = re.findall(r"(@(\w*))", message)
+ mentions = re.findall(
+ f"@{self.format}{id_regex}:{re.escape(self.server_name)}",
+ event.formatted_body,
+ )
- # Remove the puppet user's username from replies.
- message = re.sub(f"<@{self.format}.+?>", "", message)
-
- added_emotes = []
- for emote in emotes:
- # Don't replace emote names with IDs multiple times.
- if emote not in added_emotes:
- added_emotes.append(emote)
- emote_ = self.discord.emote_cache.get(emote)
+ with Cache.lock:
+ for emote in set(emotes):
+ emote_ = Cache.cache["d_emotes"].get(emote)
if emote_:
message = message.replace(f":{emote}:", emote_)
- # Don't unnecessarily fetch the channel.
- if mentions:
- guild_id = self.discord.get_channel(channel_id).guild_id
+ for mention in set(mentions):
+ username = self.db.fetch_user(mention).get("username")
+ if username:
+ match = re.search(id_regex, mention)
- # TODO this can block for too long if a long list is to be fetched.
- for mention in mentions:
- if not mention[1]:
- continue
-
- try:
- member = self.discord.query_member(guild_id, mention[1])
- except (asyncio.TimeoutError, RuntimeError):
- continue
-
- if member:
- message = message.replace(mention[0], member.mention)
+ if match:
+ # Replace the 'mention' so that the user is tagged
+ # in the case of replies aswell.
+ # '> <@_discord_1234:localhost> Message'
+ for replace in (mention, username):
+ message = message.replace(
+ replace, f"<@{match.group()}>"
+ )
return message
def upload_emote(self, emote_name: str, emote_id: str) -> None:
# There won't be a race condition here, since only a unique
# set of emotes are uploaded at a time.
- if emote_name in self.emote_cache:
+ if emote_name in Cache.cache["m_emotes"]:
return
emote_url = f"{discord.CDN_URL}/emojis/{emote_id}"
@@ -291,7 +354,8 @@ height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
# We don't want the message to be dropped entirely if an emote
# fails to upload for some reason.
try:
- self.emote_cache[emote_name] = self.upload(emote_url)
+ # TODO This is not thread safe, but we're protected by the GIL.
+ Cache.cache["m_emotes"][emote_name] = self.upload(emote_url)
except RequestError as e:
self.logger.warning(f"Failed to upload emote {emote_id}: {e}")
@@ -340,66 +404,19 @@ class DiscordClient(Gateway):
super().__init__(http, config["discord_token"])
self.app = appservice
- self.emote_cache: Dict[str, str] = {}
- self.webhook_cache: Dict[str, discord.Webhook] = {}
+ self.webhook_name = "matrix_bridge"
- async def sync(self) -> None:
- """
- Periodically compare the usernames and avatar URLs with Discord
- and update if they differ. Also synchronise emotes.
- """
-
- # TODO use websocket events and requests.
-
- def sync_emotes(guilds: set):
- emotes = []
-
- for guild in guilds:
- [emotes.append(emote) for emote in (self.get_emotes(guild))]
-
- self.emote_cache.clear() # Clear deleted/renamed emotes.
-
- for emote in emotes:
- self.emote_cache[f"{emote.name}"] = (
- f"<{'a' if emote.animated else ''}:"
- f"{emote.name}:{emote.id}>"
- )
-
- def sync_users(guilds: set):
- for guild in guilds:
- [
- self.sync_profile(user, self.matrixify(user.id, user=True))
- for user in self.get_members(guild)
- ]
-
- while True:
- guilds = set() # Avoid duplicates.
-
- try:
- for channel in self.app.db.list_channels():
- guilds.add(self.get_channel(channel).guild_id)
-
- sync_emotes(guilds)
- sync_users(guilds)
- # Don't let the background task die.
- except RequestError:
- self.logger.exception(
- "Ignoring exception during background sync:"
- )
-
- await asyncio.sleep(120) # Check every 2 minutes.
-
- async def start(self) -> None:
- asyncio.ensure_future(self.sync())
-
- await self.run()
+ for k in ("d_emotes", "d_messages", "d_webhooks"):
+ Cache.cache[k] = {}
def to_return(self, message: discord.Message) -> bool:
+ with Cache.lock:
+ hook_ids = [hook.id for hook in Cache.cache["d_webhooks"].values()]
+
return (
message.channel_id not in self.app.db.list_channels()
or not message.author # Embeds can be weird sometimes.
- or message.webhook_id
- in [hook.id for hook in self.webhook_cache.values()]
+ or message.webhook_id in hook_ids
)
def matrixify(self, id: str, user: bool = False) -> str:
@@ -408,11 +425,13 @@ class DiscordClient(Gateway):
f"{self.app.server_name}"
)
- def sync_profile(self, user: discord.User, mxid: str) -> None:
+ def sync_profile(self, user: discord.User) -> None:
"""
Sync the avatar and username for a puppeted user.
"""
+ mxid = self.matrixify(user.id, user=True)
+
profile = self.app.db.fetch_user(mxid)
# User doesn't exist.
@@ -422,10 +441,10 @@ class DiscordClient(Gateway):
username = f"{user.username}#{user.discriminator}"
if user.avatar_url != profile["avatar_url"]:
- self.logger.info(f"Updating avatar for Discord user {user.id}")
+ self.logger.info(f"Updating avatar for Discord user '{user.id}'")
self.app.set_avatar(user.avatar_url, mxid)
if username != profile["username"]:
- self.logger.info(f"Updating username for Discord user {user.id}")
+ self.logger.info(f"Updating username for Discord user '{user.id}'")
self.app.set_nick(username, mxid)
def wrap(self, message: discord.Message) -> Tuple[str, str]:
@@ -457,62 +476,95 @@ class DiscordClient(Gateway):
self.app.set_avatar(message.author.avatar_url, mxid)
if mxid not in self.app.get_members(room_id):
- self.logger.info(f"Inviting user {mxid} to room {room_id}.")
+ self.logger.info(f"Inviting user '{mxid}' to room '{room_id}'.")
self.app.send_invite(room_id, mxid)
self.app.join_room(room_id, mxid)
if message.webhook_id:
# Sync webhooks here as they can't be accessed like guild members.
- self.sync_profile(message.author, mxid)
+ self.sync_profile(message.author)
return mxid, room_id
+ def cache_emotes(self, emotes: List[discord.Emote]):
+ # TODO maybe "namespace" emotes by guild in the cache ?
+ with Cache.lock:
+ for emote in emotes:
+ Cache.cache["d_emotes"][emote.name] = (
+ f"<{'a' if emote.animated else ''}:"
+ f"{emote.name}:{emote.id}>"
+ )
+
+ def on_guild_create(self, guild: discord.Guild) -> None:
+ for member in guild.members:
+ self.sync_profile(member)
+
+ self.cache_emotes(guild.emojis)
+
+ def on_guild_emojis_update(
+ self, update: discord.GuildEmojisUpdate
+ ) -> None:
+ self.cache_emotes(update.emojis)
+
+ def on_guild_member_update(
+ self, update: discord.GuildMemberUpdate
+ ) -> None:
+ self.sync_profile(update.user)
+
def on_message_create(self, message: discord.Message) -> None:
if self.to_return(message):
return
mxid, room_id = self.wrap(message)
- content, emotes = self.process_message(message)
+ content_, emotes = self.process_message(message)
content = self.app.create_message_event(
- content, emotes, reply=message.reference
+ content_, emotes, reference=message.reference
)
- message_cache[message.id] = {
- "body": content["body"],
- "event_id": self.app.send_message(room_id, content, mxid),
- "mxid": mxid,
- "room_id": room_id,
- }
+ with Cache.lock:
+ Cache.cache["d_messages"][message.id] = self.app.send_message(
+ room_id, content, mxid
+ )
- def on_message_delete(self, message: discord.DeletedMessage) -> None:
- event = message_cache.get(message.id)
+ def on_message_delete(self, message: discord.Message) -> None:
+ with Cache.lock:
+ event_id = Cache.cache["d_messages"].get(message.id)
- if not event:
+ if not event_id:
return
- self.app.redact(event["event_id"], event["room_id"], event["mxid"])
+ room_id = self.app.get_room_id(self.matrixify(message.channel_id))
+ event = except_deleted(self.app.get_event)(event_id, room_id)
- message_cache.pop(message.id)
+ if event:
+ self.app.redact(event.id, event.room_id, event.sender)
+
+ with Cache.lock:
+ del Cache.cache["d_messages"][message.id]
def on_message_update(self, message: discord.Message) -> None:
if self.to_return(message):
return
- event = message_cache.get(message.id)
+ with Cache.lock:
+ event_id = Cache.cache["d_messages"].get(message.id)
- if not event:
+ if not event_id:
return
- content, emotes = self.process_message(message)
+ room_id = self.app.get_room_id(self.matrixify(message.channel_id))
+ mxid = self.matrixify(message.author.id, user=True)
+
+ content_, emotes = self.process_message(message)
content = self.app.create_message_event(
- content, emotes, edit=event["event_id"]
+ content_, emotes, edit=event_id
)
- self.app.send_message(event["room_id"], content, event["mxid"])
+ self.app.send_message(room_id, content, mxid)
def on_typing_start(self, typing: discord.Typing) -> None:
if typing.channel_id not in self.app.db.list_channels():
@@ -533,7 +585,8 @@ class DiscordClient(Gateway):
"""
# Check the cache first.
- webhook = self.webhook_cache.get(channel_id)
+ with Cache.lock:
+ webhook = Cache.cache["d_webhooks"].get(channel_id)
if webhook:
return webhook
@@ -551,24 +604,26 @@ class DiscordClient(Gateway):
if not webhook:
webhook = self.create_webhook(channel_id, name)
- self.webhook_cache[channel_id] = webhook
+ with Cache.lock:
+ Cache.cache["d_webhooks"][channel_id] = webhook
return webhook
- def process_message(self, message: discord.Message) -> Tuple[str, str]:
+ def process_message(self, message: discord.Message) -> Tuple[str, Dict]:
content = message.content
emotes = {}
regex = r""
# Mentions can either be in the form of `<@1234>` or `<@!1234>`.
- for char in ("", "!"):
- for member in message.mentions:
+ for member in message.mentions:
+ for char in ("", "!"):
content = content.replace(
f"<@{char}{member.id}>", f"@{member.username}"
)
# `except_deleted` for invalid channels.
- for channel in re.findall(r"<#([0-9]+)>", content):
+ # TODO can this block for too long ?
+ for channel in re.findall(r"<#([0-9]{{{discord.ID_LEN}}})>", content):
channel_ = except_deleted(self.get_channel)(channel)
content = content.replace(
f"<#{channel}>",
@@ -613,6 +668,16 @@ def config_gen(basedir: str, config_file: str) -> dict:
return json.loads(f.read())
+def excepthook(exc_type, exc_value, exc_traceback):
+ if issubclass(exc_type, KeyboardInterrupt):
+ sys.__excepthook__(exc_type, exc_value, exc_traceback)
+ return
+
+ logging.critical(
+ "Unknown exception:", exc_info=(exc_type, exc_value, exc_traceback)
+ )
+
+
def main() -> None:
try:
basedir = sys.argv[1]
@@ -634,9 +699,9 @@ def main() -> None:
],
)
- http = urllib3.PoolManager(maxsize=10)
+ sys.excepthook = excepthook
- app = MatrixClient(config, http)
+ app = MatrixClient(config, urllib3.PoolManager(maxsize=10))
# Start the bottle app in a separate thread.
app_thread = threading.Thread(
@@ -645,7 +710,7 @@ def main() -> None:
app_thread.start()
try:
- asyncio.run(app.discord.start())
+ asyncio.run(app.discord.run())
except KeyboardInterrupt:
sys.exit()
diff --git a/appservice/matrix.py b/appservice/matrix.py
index 5b84060..ab10124 100644
--- a/appservice/matrix.py
+++ b/appservice/matrix.py
@@ -2,20 +2,21 @@ from dataclasses import dataclass
@dataclass
-class User(object):
+class User:
avatar_url: str = ""
- displayname: str = ""
+ display_name: str = ""
-class Event(object):
+class Event:
def __init__(self, event: dict):
- content = event["content"]
+ content = event.get("content", {})
self.attachment = content.get("url")
- self.author = event["author"]
self.body = content.get("body", "").strip()
- self.event_id = event["event_id"]
+ self.formatted_body = content.get("formatted_body", "")
+ self.id = event["event_id"]
self.is_direct = content.get("is_direct", False)
+ self.redacts = event.get("redacts", "")
self.room_id = event["room_id"]
self.sender = event["sender"]
self.state_key = event.get("state_key", "")
diff --git a/appservice/misc.py b/appservice/misc.py
index 3befa4a..7c69459 100644
--- a/appservice/misc.py
+++ b/appservice/misc.py
@@ -1,4 +1,3 @@
-import asyncio
import json
from dataclasses import fields
from typing import Any
@@ -8,13 +7,13 @@ import urllib3
from errors import RequestError
-def dict_cls(dict_var: dict, cls: Any) -> Any:
+def dict_cls(d: dict, cls: Any) -> Any:
"""
Create a dataclass from a dictionary.
"""
field_names = set(f.name for f in fields(cls))
- filtered_dict = {k: v for k, v in dict_var.items() if k in field_names}
+ filtered_dict = {k: v for k, v in d.items() if k in field_names}
return cls(**filtered_dict)
@@ -34,22 +33,6 @@ def log_except(fn):
return wrapper
-def wrap_async(fn):
- """
- Call an asynchronous function from a synchronous one.
- """
-
- def wrapper(self, *args, **kwargs):
- if not self.loop:
- raise RuntimeError("loop is None.")
-
- return asyncio.run_coroutine_threadsafe(
- fn(self, *args, **kwargs), loop=self.loop
- ).result()
-
- return wrapper
-
-
def request(fn):
"""
Either return json data or raise a `RequestError` if the request was
@@ -75,8 +58,7 @@ def request(fn):
def except_deleted(fn):
"""
- Ignore the `RequestError` on 404s, the message might have been
- deleted by someone else already.
+ Ignore the `RequestError` on 404s, the content might have been removed.
"""
def wrapper(*args, **kwargs):
diff --git a/appservice/requirements.txt b/appservice/requirements.txt
index 10a331b..7db4be2 100644
--- a/appservice/requirements.txt
+++ b/appservice/requirements.txt
@@ -1,3 +1,3 @@
-bottle==0.12.19
-urllib3==1.26.3
-websockets==8.1
+bottle
+urllib3
+websockets