* Use websocket events to sync usernames/avatars instead of periodic syncing.

* Use caches for fetched room_ids, room state (for usernames and avatars). Also switch to using a single cache with locks.

* Don't store full message objects in cache, just store the relation of matrix event IDs to discord message IDs and vice-versa. Content can be fetched from the server instead.

* Don't rely on websocket events for mentioning Discord users, mentions are now done by mentioning the dummy matrix user. The ID to be mentioned is extracted from the MXID instead.

* General clean-ups.
This commit is contained in:
git-bruh 2021-05-09 15:46:00 +05:30
parent b21c82ccd0
commit 4713a00016
No known key found for this signature in database
GPG key ID: E1475C50075ADCE6
10 changed files with 454 additions and 440 deletions

View file

@ -49,20 +49,23 @@ A path can optionally be passed as the first argument to `main.py`. This path wi
Eg. Running `python3 main.py /path/to/my/dir` will store the database and logs in `/path/to/my/dir`.
`$PWD` is used by default if no path is specified.
After setting up the bridge, send a direct message to `@appservice-discord:domain.tld` containing the channel ID to be bridged (`!bridge 123456`).
This bridge is written with:
* `bottle`: Receiving events from the homeserver.
* `urllib3`: Sending requests, thread safety.
* `websockets`: Connecting to Discord. (Big thanks to an anonymous person "nesslersreagent" for figuring out the initial connection mess.)
## NOTES
* A basic sqlite database is used for keeping track of bridged rooms.
* Logs are saved to the `appservice.log` file in `$PWD` or the specified directory.
* For avatars to show up on Discord, you must have a [reverse proxy](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf) set up on your homeserver as the bridge does not specify the homeserver port when passing the avatar url.
* It is not possible to add normal Discord bot functionality like commands as this bridge does not use `discord.py`.
* It is not possible to add "normal" Discord bot functionality like commands as this bridge does not use `discord.py`.
* [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) must be enabled for your Discord bot.
* [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) for members and presence must be enabled for your Discord bot.
* This Appservice might not work well for bridging a large number of rooms since it is mostly synchronous. However, it wouldn't take much effort to port it to `asyncio` and `aiohttp` if desired.

View file

@ -2,13 +2,14 @@ import json
import logging
import urllib.parse
import uuid
from typing import List, Union
from typing import Union
import bottle
import urllib3
import matrix
from misc import dict_cls, except_deleted, log_except, request
from cache import Cache
from misc import log_except, request
class AppService(bottle.Bottle):
@ -37,13 +38,17 @@ class AppService(bottle.Bottle):
method="PUT",
)
Cache.cache["m_rooms"] = {}
def handle_event(self, event: dict) -> None:
event_type = event.get("type")
if event_type == "m.room.member" or event_type == "m.room.message":
obj = self.get_event_object(event)
elif event_type == "m.room.redaction":
obj = event
if event_type in (
"m.room.member",
"m.room.message",
"m.room.redaction",
):
obj = matrix.Event(event)
else:
self.logger.info(f"Unknown event type: {event_type}")
return
@ -86,23 +91,14 @@ class AppService(bottle.Bottle):
def mxc_url(self, mxc: str) -> str:
try:
homeserver, media_id = mxc.replace("mxc://", "").split("/")
converted = (
f"https://{self.server_name}/_matrix/media/r0/download/"
f"{homeserver}/{media_id}"
)
except ValueError:
converted = ""
return ""
return converted
def get_event_object(self, event: dict) -> matrix.Event:
# TODO use caching and invalidate old cache on member events.
event["author"] = dict_cls(
self.get_profile(event["sender"]), matrix.User
return (
f"https://{self.server_name}/_matrix/media/r0/download/"
f"{homeserver}/{media_id}"
)
return matrix.Event(event)
def join_room(self, room_id: str, mxid: str = "") -> None:
self.send(
"POST",
@ -117,42 +113,25 @@ class AppService(bottle.Bottle):
params={"user_id": mxid} if mxid else {},
)
def get_profile(self, mxid: str) -> dict:
resp = except_deleted(self.send)("GET", f"/profile/{mxid}")
# No profile exists for the user.
if not resp:
return {}
avatar_url = resp.get("avatar_url")
if avatar_url:
avatar_url = self.mxc_url(avatar_url)
return {
"avatar_url": avatar_url,
"displayname": resp.get("displayname"),
}
def get_members(self, room_id: str) -> List[str]:
resp = self.send(
"GET",
f"/rooms/{room_id}/members",
params={"membership": "join", "not_membership": "leave"},
)
return [
content["sender"]
for content in resp["chunk"]
if content["content"]["membership"] == "join"
]
def get_room_id(self, alias: str) -> str:
with Cache.lock:
room = Cache.cache["m_rooms"].get(alias)
if room:
return room
resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}")
# TODO cache ?
room_id = resp["room_id"]
return resp["room_id"]
with Cache.lock:
Cache.cache["m_rooms"][alias] = room_id
return room_id
def get_event(self, event_id: str, room_id: str) -> matrix.Event:
resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}")
return matrix.Event(resp)
def upload(self, url: str) -> str:
"""
@ -211,12 +190,12 @@ class AppService(bottle.Bottle):
) -> dict:
params["access_token"] = self.as_token
headers = {"Content-Type": content_type}
content = json.dumps(content) if isinstance(content, dict) else content
payload = json.dumps(content) if isinstance(content, dict) else content
endpoint = (
f"{self.base_url}{endpoint}{path}?"
f"{urllib.parse.urlencode(params)}"
)
return self.http.request(
method, endpoint, body=content, headers=headers
method, endpoint, body=payload, headers=headers
)

6
appservice/cache.py Normal file
View file

@ -0,0 +1,6 @@
import threading
class Cache:
cache = {}
lock = threading.Lock()

View file

@ -4,11 +4,11 @@ import threading
from typing import List
class DataBase(object):
class DataBase:
def __init__(self, db_file) -> None:
self.create(db_file)
# The database is accessed via both the threads.
# The database is accessed via multiple threads.
self.lock = threading.Lock()
def create(self, db_file) -> None:
@ -92,7 +92,7 @@ class DataBase(object):
room = self.cur.fetchone()
# Return an empty string if nothing is bridged.
# Return an empty string if the channel is not bridged.
return "" if not room else room["channel_id"]
def list_channels(self) -> List[str]:
@ -116,6 +116,8 @@ class DataBase(object):
self.cur.execute("SELECT * FROM users")
users = self.cur.fetchall()
user = [user for user in users if user["mxid"] == mxid]
user: dict = next(
iter([user for user in users if user["mxid"] == mxid]), {}
)
return user[0] if user else {}
return user

View file

@ -1,10 +1,16 @@
from dataclasses import dataclass
from misc import dict_cls
CDN_URL = "https://cdn.discordapp.com"
ID_LEN = 18
def bitmask(bit: int) -> int:
return 1 << bit
@dataclass
class Channel(object):
class Channel:
id: str
type: str
guild_id: str = ""
@ -13,13 +19,32 @@ class Channel(object):
@dataclass
class Emote(object):
class Emote:
animated: bool
id: str
name: str
class User(object):
@dataclass
class MessageReference:
message_id: str
channel_id: str
guild_id: str
@dataclass
class Typing:
user_id: str
channel_id: str
@dataclass
class Webhook:
id: str
token: str
class User:
def __init__(self, user: dict) -> None:
self.discriminator = user["discriminator"]
self.id = user["id"]
@ -38,45 +63,57 @@ class User(object):
self.avatar_url = f"{CDN_URL}/avatars/{self.id}/{avatar}.{ext}"
class Message(object):
class Guild:
def __init__(self, guild: dict) -> None:
self.guild_id = guild["id"]
self.channels = [dict_cls(c, Channel) for c in guild["channels"]]
self.emojis = [dict_cls(e, Emote) for e in guild["emojis"]]
members = [member["user"] for member in guild["members"]]
self.members = [User(m) for m in members]
class GuildEmojisUpdate:
def __init__(self, update: dict) -> None:
self.guild_id = update["guild_id"]
self.emojis = [dict_cls(e, Emote) for e in update["emojis"]]
class GuildMembersChunk:
def __init__(self, chunk: dict) -> None:
self.chunk_index = chunk["chunk_index"]
self.chunk_count = chunk["chunk_count"]
self.guild_id = chunk["guild_id"]
self.members = [User(m) for m in chunk["members"]]
class GuildMemberUpdate:
def __init__(self, update: dict) -> None:
self.guild_id = update["guild_id"]
self.user = User(update["user"])
class Message:
def __init__(self, message: dict) -> None:
self.attachments = message.get("attachments", [])
self.channel_id = message["channel_id"]
self.content = message.get("content", "")
self.id = message["id"]
self.reference = message.get("message_reference", {}).get(
"message_id", ""
)
self.webhook_id = message.get("webhook_id", "")
self.mentions = [
User(mention) for mention in message.get("mentions", [])
]
ref = message.get("message_reference")
self.reference = dict_cls(ref, MessageReference) if ref else None
author = message.get("author")
self.author = User(author) if author else None
@dataclass
class DeletedMessage(object):
channel_id: str
id: str
@dataclass
class Typing(object):
user_id: str
channel_id: str
@dataclass
class Webhook(object):
id: str
token: str
class ChannelType(object):
class ChannelType:
GUILD_TEXT = 0
DM = 1
GUILD_VOICE = 2
@ -86,7 +123,7 @@ class ChannelType(object):
GUILD_STORE = 6
class InteractionResponseType(object):
class InteractionResponseType:
PONG = 0
ACKNOWLEDGE = 1
CHANNEL_MESSAGE = 2
@ -94,10 +131,7 @@ class InteractionResponseType(object):
ACKNOWLEDGE_WITH_SOURCE = 5
class GatewayIntents(object):
def bitmask(bit: int) -> int:
return 1 << bit
class GatewayIntents:
GUILDS = bitmask(0)
GUILD_MEMBERS = bitmask(1)
GUILD_BANS = bitmask(2)
@ -115,7 +149,7 @@ class GatewayIntents(object):
DIRECT_MESSAGE_TYPING = bitmask(14)
class GatewayOpCodes(object):
class GatewayOpCodes:
DISPATCH = 0
HEARTBEAT = 1
IDENTIFY = 2
@ -129,7 +163,7 @@ class GatewayOpCodes(object):
HEARTBEAT_ACK = 11
class Payloads(object):
class Payloads:
def __init__(self, token: str) -> None:
self.seq = self.session = None
self.token = token
@ -143,27 +177,19 @@ class Payloads(object):
"d": {
"token": self.token,
"intents": GatewayIntents.GUILDS
| GatewayIntents.GUILD_EMOJIS
| GatewayIntents.GUILD_MEMBERS
| GatewayIntents.GUILD_MESSAGES
| GatewayIntents.GUILD_MESSAGE_TYPING,
| GatewayIntents.GUILD_MESSAGE_TYPING
| GatewayIntents.GUILD_PRESENCES,
"properties": {
"$os": "discord",
"$browser": "discord",
"$browser": "Discord Client",
"$device": "discord",
},
},
}
def QUERY(self, guild_id: str, query: str, limit: int = 1) -> dict:
"""
Return the Payload to query a member from a guild ID.
Return only a single match if `limit` isn't specified.
"""
return {
"op": GatewayOpCodes.REQUEST_GUILD_MEMBERS,
"d": {"guild_id": guild_id, "query": query, "limit": limit},
}
def RESUME(self) -> dict:
return {
"op": GatewayOpCodes.RESUME,

View file

@ -8,31 +8,27 @@ import urllib3
import websockets
import discord
from misc import dict_cls, log_except, request, wrap_async
from misc import dict_cls, log_except, request
class Gateway(object):
class Gateway:
def __init__(self, http: urllib3.PoolManager, token: str):
self.http = http
self.token = token
self.logger = logging.getLogger("discord")
self.cdn_url = "https://cdn.discordapp.com"
self.Payloads = discord.Payloads(self.token)
self.loop = self.websocket = None
self.query_cache = {}
self.websocket = None
@log_except
async def run(self) -> None:
self.loop = asyncio.get_running_loop()
self.query_ev = asyncio.Event()
self.heartbeat_task = None
self.heartbeat_task: asyncio.Future = None
self.resume = False
gateway_url = self.get_gateway_url()
while True:
try:
await self.gateway_handler(self.get_gateway_url())
await self.gateway_handler(gateway_url)
except (
websockets.ConnectionClosedError,
websockets.InvalidMessage,
@ -55,26 +51,71 @@ class Gateway(object):
await asyncio.sleep(interval_ms / 1000)
await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT()))
def query_handler(self, data: dict) -> None:
members = data["members"]
guild_id = data["guild_id"]
async def handle_resp(self, data: dict) -> None:
data_dict = data["d"]
for member in members:
user = member["user"]
self.query_cache[guild_id].append(user)
opcode = data["op"]
self.query_ev.set()
seq = data["s"]
if seq:
self.Payloads.seq = seq
if opcode == discord.GatewayOpCodes.DISPATCH:
otype = data["t"]
if otype == "READY":
self.Payloads.session = data_dict["session_id"]
self.logger.info("READY")
else:
self.handle_otype(data_dict, otype)
elif opcode == discord.GatewayOpCodes.HELLO:
heartbeat_interval = data_dict.get("heartbeat_interval")
self.logger.info(f"Heartbeat Interval: {heartbeat_interval}")
# Send periodic hearbeats to gateway.
self.heartbeat_task = asyncio.ensure_future(
self.heartbeat_handler(heartbeat_interval)
)
await self.websocket.send(
json.dumps(
self.Payloads.RESUME()
if self.resume
else self.Payloads.IDENTIFY()
)
)
elif opcode == discord.GatewayOpCodes.RECONNECT:
self.logger.info("Received RECONNECT.")
self.resume = True
await self.websocket.close()
elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
self.logger.info("Received INVALID_SESSION.")
self.resume = False
await self.websocket.close()
elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
# NOP
pass
else:
self.logger.info(
"Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}"
)
def handle_otype(self, data: dict, otype: str) -> None:
if otype == "MESSAGE_CREATE" or otype == "MESSAGE_UPDATE":
if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"):
obj = discord.Message(data)
elif otype == "MESSAGE_DELETE":
obj = dict_cls(data, discord.DeletedMessage)
elif otype == "TYPING_START":
obj = dict_cls(data, discord.Typing)
elif otype == "GUILD_MEMBERS_CHUNK":
self.query_handler(data)
return
elif otype == "GUILD_CREATE":
obj = discord.Guild(data)
elif otype == "GUILD_MEMBER_UPDATE":
obj = discord.GuildMemberUpdate(data)
elif otype == "GUILD_EMOJIS_UPDATE":
obj = discord.GuildEmojisUpdate(data)
else:
self.logger.info(f"Unknown OTYPE: {otype}")
return
@ -90,119 +131,20 @@ class Gateway(object):
try:
func(obj)
except Exception:
self.logger.exception(f"Ignoring exception in {func}:")
self.logger.exception(f"Ignoring exception in '{func.__name__}':")
async def gateway_handler(self, gateway_url: str) -> None:
async with websockets.connect(
f"{gateway_url}/?v=8&encoding=json"
) as websocket:
self.websocket = websocket
async for message in websocket:
data = json.loads(message)
data_dict = data.get("d")
opcode = data.get("op")
seq = data.get("s")
if seq:
self.Payloads.seq = seq
if opcode == discord.GatewayOpCodes.DISPATCH:
otype = data.get("t")
if otype == "READY":
self.Payloads.session = data_dict["session_id"]
self.logger.info("READY")
else:
self.handle_otype(data_dict, otype)
elif opcode == discord.GatewayOpCodes.HELLO:
heartbeat_interval = data_dict.get("heartbeat_interval")
self.logger.info(
f"Heartbeat Interval: {heartbeat_interval}"
)
# Send periodic hearbeats to gateway.
self.heartbeat_task = asyncio.ensure_future(
self.heartbeat_handler(heartbeat_interval)
)
await websocket.send(
json.dumps(
self.Payloads.RESUME()
if self.resume
else self.Payloads.IDENTIFY()
)
)
elif opcode == discord.GatewayOpCodes.RECONNECT:
self.logger.info("Received RECONNECT.")
self.resume = True
await websocket.close()
elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
self.logger.info("Received INVALID_SESSION.")
self.resume = False
await websocket.close()
elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
# NOP
pass
else:
self.logger.info(
f"Unknown OP code {opcode}:\n"
f"{json.dumps(data, indent=4)}"
)
@wrap_async
async def query_member(self, guild_id: str, name: str) -> discord.User:
"""
Query the members for a given guild and return the first match.
"""
self.query_ev.clear()
def query():
if not self.query_cache.get(guild_id):
self.query_cache[guild_id] = []
user = [
user
for user in self.query_cache[guild_id]
if name.lower() in user["username"].lower()
]
return None if not user else discord.User(user[0])
user = query()
if user:
return user
if not self.websocket or self.websocket.closed:
self.logger.warning("Not fetching members, websocket closed.")
return
await self.websocket.send(
json.dumps(self.Payloads.QUERY(guild_id, name))
)
# TODO clean this mess.
# Wait for our websocket to receive the chunk.
await asyncio.wait_for(self.query_ev.wait(), timeout=5)
return query()
await self.handle_resp(json.loads(message))
def get_channel(self, channel_id: str) -> discord.Channel:
"""
Get the channel object for a given channel ID.
Get the channel for a given channel ID.
"""
resp = self.send("GET", f"/channels/{channel_id}")
@ -259,9 +201,17 @@ class Gateway(object):
f"{message_id}",
)
def send_webhook(self, webhook: discord.Webhook, **kwargs) -> str:
content = {
**kwargs,
def send_webhook(
self,
webhook: discord.Webhook,
avatar_url: str,
content: str,
username: str,
) -> discord.Message:
payload = {
"avatar_url": avatar_url,
"content": content,
"username": username,
# Disable 'everyone' and 'role' mentions.
"allowed_mentions": {"parse": ["users"]},
}
@ -269,11 +219,11 @@ class Gateway(object):
resp = self.send(
"POST",
f"/webhooks/{webhook.id}/{webhook.token}",
content,
payload,
{"wait": True},
)
return resp["id"]
return discord.Message(resp)
def send_message(self, message: str, channel_id: str) -> None:
self.send(
@ -294,8 +244,8 @@ class Gateway(object):
}
# 'body' being an empty dict breaks "GET" requests.
content = json.dumps(content) if content else None
payload = json.dumps(content) if content else None
return self.http.request(
method, endpoint, body=content, headers=headers
method, endpoint, body=payload, headers=headers
)

View file

@ -5,21 +5,19 @@ import os
import re
import sys
import threading
from typing import Dict, Tuple, Union
from typing import Dict, List, Tuple
import urllib3
import discord
import matrix
from appservice import AppService
from cache import Cache
from db import DataBase
from errors import RequestError
from gateway import Gateway
from misc import dict_cls, except_deleted, hash_str
# TODO should this be cleared periodically ?
message_cache: Dict[str, Union[discord.Webhook, str]] = {}
class MatrixClient(AppService):
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
@ -27,9 +25,11 @@ class MatrixClient(AppService):
self.db = DataBase(config["database"])
self.discord = DiscordClient(self, config, http)
self.emote_cache: Dict[str, str] = {}
self.format = "_discord_" # "{@,#}_discord_1234:localhost"
for k in ("m_emotes", "m_members", "m_messages"):
Cache.cache[k] = {}
def handle_bridge(self, message: matrix.Event) -> None:
# Ignore events that aren't for us.
if message.sender.split(":")[
@ -37,6 +37,7 @@ class MatrixClient(AppService):
] != self.server_name or not message.body.startswith("!bridge"):
return
# Get the channel ID.
try:
channel = message.body.split()[1]
except IndexError:
@ -46,7 +47,7 @@ class MatrixClient(AppService):
try:
channel = self.discord.get_channel(channel)
except RequestError as e:
# The channel can be invalid or we may not have permission.
# The channel can be invalid or we may not have permissions.
self.logger.warning(f"Failed to fetch channel {channel}: {e}")
return
@ -61,7 +62,15 @@ class MatrixClient(AppService):
self.create_room(channel, message.sender)
def on_member(self, event: matrix.Event) -> None:
# Ignore events that aren't for us.
with Cache.lock:
# Just lazily clear the whole member cache on
# membership update events.
if event.room_id in Cache.cache["m_members"]:
self.logger.info(
f"Clearing member cache for room '{event.room_id}'."
)
del Cache.cache["m_members"][event.room_id]
if (
event.sender.split(":")[-1] != self.server_name
or event.state_key != self.user_id
@ -70,7 +79,7 @@ class MatrixClient(AppService):
return
# Join the direct message room.
self.logger.info(f"Joining direct message room {event.room_id}.")
self.logger.info(f"Joining direct message room '{event.room_id}'.")
self.join_room(event.room_id)
def on_message(self, message: matrix.Event) -> None:
@ -88,51 +97,78 @@ class MatrixClient(AppService):
if not channel_id:
return
webhook = self.discord.get_webhook(channel_id, "matrix_bridge")
author = self.get_members(message.room_id)[message.sender]
if not author.display_name:
author.display_name = message.sender
webhook = self.discord.get_webhook(
channel_id, self.discord.webhook_name
)
if message.relates_to and message.reltype == "m.replace":
relation = message_cache.get(message.relates_to)
with Cache.lock:
message_id = Cache.cache["m_messages"].get(message.relates_to)
if not message.new_body or not relation:
if not message_id or not message.new_body:
return
message.new_body = self.process_message(
channel_id, message.new_body
)
message.new_body = self.process_message(message)
except_deleted(self.discord.edit_webhook)(
message.new_body, relation["message_id"], webhook
message.new_body, message_id, webhook
)
else:
message.body = (
f"`{message.body}`: {self.mxc_url(message.attachment)}"
if message.attachment
else self.process_message(channel_id, message.body)
else self.process_message(message)
)
message_cache[message.event_id] = {
"message_id": self.discord.send_webhook(
webhook,
avatar_url=message.author.avatar_url,
content=message.body,
username=message.author.displayname,
),
"webhook": webhook,
}
message_id = self.discord.send_webhook(
webhook,
self.mxc_url(author.avatar_url),
message.body,
author.display_name,
).id
@except_deleted
def on_redaction(self, event: dict) -> None:
redacts = event["redacts"]
with Cache.lock:
Cache.cache["m_messages"][message.id] = message_id
event = message_cache.get(redacts)
def on_redaction(self, event: matrix.Event) -> None:
with Cache.lock:
message_id = Cache.cache["m_messages"].get(event.redacts)
if not event:
if not message_id:
return
self.discord.delete_webhook(event["message_id"], event["webhook"])
webhook = self.discord.get_webhook(
self.db.get_channel(event.room_id), self.discord.webhook_name
)
message_cache.pop(redacts)
except_deleted(self.discord.delete_webhook)(message_id, webhook)
with Cache.lock:
del Cache.cache["m_messages"][event.redacts]
def get_members(self, room_id: str) -> Dict[str, matrix.User]:
with Cache.lock:
cached = Cache.cache["m_members"].get(room_id)
if cached:
return cached
resp = self.send("GET", f"/rooms/{room_id}/joined_members")
joined = resp["joined"]
for k, v in joined.items():
joined[k] = dict_cls(v, matrix.User)
with Cache.lock:
Cache.cache["m_members"][room_id] = joined
return joined
def create_room(self, channel: discord.Channel, sender: str) -> None:
"""
@ -164,7 +200,11 @@ class MatrixClient(AppService):
self.db.add_room(resp["room_id"], channel.id)
def create_message_event(
self, message: str, emotes: dict, edit: str = "", reply: str = ""
self,
message: str,
emotes: dict,
edit: str = "",
reference: discord.MessageReference = None,
) -> dict:
content = {
"body": message,
@ -173,20 +213,39 @@ class MatrixClient(AppService):
"formatted_body": self.get_fmt(message, emotes),
}
event = message_cache.get(reply)
if reference:
# Reply to a Discord message.
with Cache.lock:
event_id = Cache.cache["d_messages"].get(reference.message_id)
if event:
content = {
**content,
"m.relates_to": {
"m.in_reply_to": {"event_id": event["event_id"]}
},
"formatted_body": f"""<mx-reply><blockquote>\
<a href='https://matrix.to/#/{event["room_id"]}/{event["event_id"]}'>\
In reply to</a><a href='https://matrix.to/#/{event["mxid"]}'>\
{event["mxid"]}</a><br>{event["body"]}</blockquote></mx-reply>\
# Reply to a Matrix message. (maybe)
if not event_id:
with Cache.lock:
event_id = [
k
for k, v in Cache.cache["m_messages"].items()
if v == reference.message_id
]
event_id = next(iter(event_id), "")
if reference and event_id:
event = except_deleted(self.get_event)(
event_id,
self.get_room_id(self.discord.matrixify(reference.channel_id)),
)
if event:
content = {
**content,
"body": (
f"> <{event.sender}> {event.body}\n{content['body']}"
),
"m.relates_to": {"m.in_reply_to": {"event_id": event.id}},
"formatted_body": f"""<mx-reply><blockquote>\
<a href='https://matrix.to/#/{event.room_id}/{event.id}'>\
In reply to</a><a href='https://matrix.to/#/{event.sender}'>\
{event.sender}</a><br>{event.formatted_body}</blockquote></mx-reply>\
{content["formatted_body"]}""",
}
}
if edit:
content = {
@ -227,63 +286,67 @@ In reply to</a><a href='https://matrix.to/#/{event["mxid"]}'>\
for emote in emotes
]
[thread.start() for thread in upload_threads]
[thread.join() for thread in upload_threads]
# Acquire the lock before starting the threads to avoid resource
# contention by tens of threads at once.
with Cache.lock:
for thread in upload_threads:
thread.start()
for thread in upload_threads:
thread.join()
for emote in emotes:
emote_ = self.emote_cache.get(emote)
with Cache.lock:
for emote in emotes:
emote_ = Cache.cache["m_emotes"].get(emote)
if emote_:
emote = f":{emote}:"
message = message.replace(
emote,
f"""<img alt=\"{emote}\" title=\"{emote}\" \
if emote_:
emote = f":{emote}:"
message = message.replace(
emote,
f"""<img alt=\"{emote}\" title=\"{emote}\" \
height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
)
)
return message
def process_message(self, channel_id: str, message: str) -> str:
def process_message(self, event: matrix.Event) -> str:
message = event.new_body if event.new_body else event.body
message = message[:2000] # Discord limit.
id_regex = f"[0-9]{{{discord.ID_LEN}}}"
emotes = re.findall(r":(\w*):", message)
mentions = re.findall(r"(@(\w*))", message)
mentions = re.findall(
f"@{self.format}{id_regex}:{re.escape(self.server_name)}",
event.formatted_body,
)
# Remove the puppet user's username from replies.
message = re.sub(f"<@{self.format}.+?>", "", message)
added_emotes = []
for emote in emotes:
# Don't replace emote names with IDs multiple times.
if emote not in added_emotes:
added_emotes.append(emote)
emote_ = self.discord.emote_cache.get(emote)
with Cache.lock:
for emote in set(emotes):
emote_ = Cache.cache["d_emotes"].get(emote)
if emote_:
message = message.replace(f":{emote}:", emote_)
# Don't unnecessarily fetch the channel.
if mentions:
guild_id = self.discord.get_channel(channel_id).guild_id
for mention in set(mentions):
username = self.db.fetch_user(mention).get("username")
if username:
match = re.search(id_regex, mention)
# TODO this can block for too long if a long list is to be fetched.
for mention in mentions:
if not mention[1]:
continue
try:
member = self.discord.query_member(guild_id, mention[1])
except (asyncio.TimeoutError, RuntimeError):
continue
if member:
message = message.replace(mention[0], member.mention)
if match:
# Replace the 'mention' so that the user is tagged
# in the case of replies aswell.
# '> <@_discord_1234:localhost> Message'
for replace in (mention, username):
message = message.replace(
replace, f"<@{match.group()}>"
)
return message
def upload_emote(self, emote_name: str, emote_id: str) -> None:
# There won't be a race condition here, since only a unique
# set of emotes are uploaded at a time.
if emote_name in self.emote_cache:
if emote_name in Cache.cache["m_emotes"]:
return
emote_url = f"{discord.CDN_URL}/emojis/{emote_id}"
@ -291,7 +354,8 @@ height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
# We don't want the message to be dropped entirely if an emote
# fails to upload for some reason.
try:
self.emote_cache[emote_name] = self.upload(emote_url)
# TODO This is not thread safe, but we're protected by the GIL.
Cache.cache["m_emotes"][emote_name] = self.upload(emote_url)
except RequestError as e:
self.logger.warning(f"Failed to upload emote {emote_id}: {e}")
@ -340,66 +404,19 @@ class DiscordClient(Gateway):
super().__init__(http, config["discord_token"])
self.app = appservice
self.emote_cache: Dict[str, str] = {}
self.webhook_cache: Dict[str, discord.Webhook] = {}
self.webhook_name = "matrix_bridge"
async def sync(self) -> None:
"""
Periodically compare the usernames and avatar URLs with Discord
and update if they differ. Also synchronise emotes.
"""
# TODO use websocket events and requests.
def sync_emotes(guilds: set):
emotes = []
for guild in guilds:
[emotes.append(emote) for emote in (self.get_emotes(guild))]
self.emote_cache.clear() # Clear deleted/renamed emotes.
for emote in emotes:
self.emote_cache[f"{emote.name}"] = (
f"<{'a' if emote.animated else ''}:"
f"{emote.name}:{emote.id}>"
)
def sync_users(guilds: set):
for guild in guilds:
[
self.sync_profile(user, self.matrixify(user.id, user=True))
for user in self.get_members(guild)
]
while True:
guilds = set() # Avoid duplicates.
try:
for channel in self.app.db.list_channels():
guilds.add(self.get_channel(channel).guild_id)
sync_emotes(guilds)
sync_users(guilds)
# Don't let the background task die.
except RequestError:
self.logger.exception(
"Ignoring exception during background sync:"
)
await asyncio.sleep(120) # Check every 2 minutes.
async def start(self) -> None:
asyncio.ensure_future(self.sync())
await self.run()
for k in ("d_emotes", "d_messages", "d_webhooks"):
Cache.cache[k] = {}
def to_return(self, message: discord.Message) -> bool:
with Cache.lock:
hook_ids = [hook.id for hook in Cache.cache["d_webhooks"].values()]
return (
message.channel_id not in self.app.db.list_channels()
or not message.author # Embeds can be weird sometimes.
or message.webhook_id
in [hook.id for hook in self.webhook_cache.values()]
or message.webhook_id in hook_ids
)
def matrixify(self, id: str, user: bool = False) -> str:
@ -408,11 +425,13 @@ class DiscordClient(Gateway):
f"{self.app.server_name}"
)
def sync_profile(self, user: discord.User, mxid: str) -> None:
def sync_profile(self, user: discord.User) -> None:
"""
Sync the avatar and username for a puppeted user.
"""
mxid = self.matrixify(user.id, user=True)
profile = self.app.db.fetch_user(mxid)
# User doesn't exist.
@ -422,10 +441,10 @@ class DiscordClient(Gateway):
username = f"{user.username}#{user.discriminator}"
if user.avatar_url != profile["avatar_url"]:
self.logger.info(f"Updating avatar for Discord user {user.id}")
self.logger.info(f"Updating avatar for Discord user '{user.id}'")
self.app.set_avatar(user.avatar_url, mxid)
if username != profile["username"]:
self.logger.info(f"Updating username for Discord user {user.id}")
self.logger.info(f"Updating username for Discord user '{user.id}'")
self.app.set_nick(username, mxid)
def wrap(self, message: discord.Message) -> Tuple[str, str]:
@ -457,62 +476,95 @@ class DiscordClient(Gateway):
self.app.set_avatar(message.author.avatar_url, mxid)
if mxid not in self.app.get_members(room_id):
self.logger.info(f"Inviting user {mxid} to room {room_id}.")
self.logger.info(f"Inviting user '{mxid}' to room '{room_id}'.")
self.app.send_invite(room_id, mxid)
self.app.join_room(room_id, mxid)
if message.webhook_id:
# Sync webhooks here as they can't be accessed like guild members.
self.sync_profile(message.author, mxid)
self.sync_profile(message.author)
return mxid, room_id
def cache_emotes(self, emotes: List[discord.Emote]):
# TODO maybe "namespace" emotes by guild in the cache ?
with Cache.lock:
for emote in emotes:
Cache.cache["d_emotes"][emote.name] = (
f"<{'a' if emote.animated else ''}:"
f"{emote.name}:{emote.id}>"
)
def on_guild_create(self, guild: discord.Guild) -> None:
for member in guild.members:
self.sync_profile(member)
self.cache_emotes(guild.emojis)
def on_guild_emojis_update(
self, update: discord.GuildEmojisUpdate
) -> None:
self.cache_emotes(update.emojis)
def on_guild_member_update(
self, update: discord.GuildMemberUpdate
) -> None:
self.sync_profile(update.user)
def on_message_create(self, message: discord.Message) -> None:
if self.to_return(message):
return
mxid, room_id = self.wrap(message)
content, emotes = self.process_message(message)
content_, emotes = self.process_message(message)
content = self.app.create_message_event(
content, emotes, reply=message.reference
content_, emotes, reference=message.reference
)
message_cache[message.id] = {
"body": content["body"],
"event_id": self.app.send_message(room_id, content, mxid),
"mxid": mxid,
"room_id": room_id,
}
with Cache.lock:
Cache.cache["d_messages"][message.id] = self.app.send_message(
room_id, content, mxid
)
def on_message_delete(self, message: discord.DeletedMessage) -> None:
event = message_cache.get(message.id)
def on_message_delete(self, message: discord.Message) -> None:
with Cache.lock:
event_id = Cache.cache["d_messages"].get(message.id)
if not event:
if not event_id:
return
self.app.redact(event["event_id"], event["room_id"], event["mxid"])
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
event = except_deleted(self.app.get_event)(event_id, room_id)
message_cache.pop(message.id)
if event:
self.app.redact(event.id, event.room_id, event.sender)
with Cache.lock:
del Cache.cache["d_messages"][message.id]
def on_message_update(self, message: discord.Message) -> None:
if self.to_return(message):
return
event = message_cache.get(message.id)
with Cache.lock:
event_id = Cache.cache["d_messages"].get(message.id)
if not event:
if not event_id:
return
content, emotes = self.process_message(message)
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
mxid = self.matrixify(message.author.id, user=True)
content_, emotes = self.process_message(message)
content = self.app.create_message_event(
content, emotes, edit=event["event_id"]
content_, emotes, edit=event_id
)
self.app.send_message(event["room_id"], content, event["mxid"])
self.app.send_message(room_id, content, mxid)
def on_typing_start(self, typing: discord.Typing) -> None:
if typing.channel_id not in self.app.db.list_channels():
@ -533,7 +585,8 @@ class DiscordClient(Gateway):
"""
# Check the cache first.
webhook = self.webhook_cache.get(channel_id)
with Cache.lock:
webhook = Cache.cache["d_webhooks"].get(channel_id)
if webhook:
return webhook
@ -551,24 +604,26 @@ class DiscordClient(Gateway):
if not webhook:
webhook = self.create_webhook(channel_id, name)
self.webhook_cache[channel_id] = webhook
with Cache.lock:
Cache.cache["d_webhooks"][channel_id] = webhook
return webhook
def process_message(self, message: discord.Message) -> Tuple[str, str]:
def process_message(self, message: discord.Message) -> Tuple[str, Dict]:
content = message.content
emotes = {}
regex = r"<a?:(\w+):(\d+)>"
# Mentions can either be in the form of `<@1234>` or `<@!1234>`.
for char in ("", "!"):
for member in message.mentions:
for member in message.mentions:
for char in ("", "!"):
content = content.replace(
f"<@{char}{member.id}>", f"@{member.username}"
)
# `except_deleted` for invalid channels.
for channel in re.findall(r"<#([0-9]+)>", content):
# TODO can this block for too long ?
for channel in re.findall(r"<#([0-9]{{{discord.ID_LEN}}})>", content):
channel_ = except_deleted(self.get_channel)(channel)
content = content.replace(
f"<#{channel}>",
@ -613,6 +668,16 @@ def config_gen(basedir: str, config_file: str) -> dict:
return json.loads(f.read())
def excepthook(exc_type, exc_value, exc_traceback):
if issubclass(exc_type, KeyboardInterrupt):
sys.__excepthook__(exc_type, exc_value, exc_traceback)
return
logging.critical(
"Unknown exception:", exc_info=(exc_type, exc_value, exc_traceback)
)
def main() -> None:
try:
basedir = sys.argv[1]
@ -634,9 +699,9 @@ def main() -> None:
],
)
http = urllib3.PoolManager(maxsize=10)
sys.excepthook = excepthook
app = MatrixClient(config, http)
app = MatrixClient(config, urllib3.PoolManager(maxsize=10))
# Start the bottle app in a separate thread.
app_thread = threading.Thread(
@ -645,7 +710,7 @@ def main() -> None:
app_thread.start()
try:
asyncio.run(app.discord.start())
asyncio.run(app.discord.run())
except KeyboardInterrupt:
sys.exit()

View file

@ -2,20 +2,21 @@ from dataclasses import dataclass
@dataclass
class User(object):
class User:
avatar_url: str = ""
displayname: str = ""
display_name: str = ""
class Event(object):
class Event:
def __init__(self, event: dict):
content = event["content"]
content = event.get("content", {})
self.attachment = content.get("url")
self.author = event["author"]
self.body = content.get("body", "").strip()
self.event_id = event["event_id"]
self.formatted_body = content.get("formatted_body", "")
self.id = event["event_id"]
self.is_direct = content.get("is_direct", False)
self.redacts = event.get("redacts", "")
self.room_id = event["room_id"]
self.sender = event["sender"]
self.state_key = event.get("state_key", "")

View file

@ -1,4 +1,3 @@
import asyncio
import json
from dataclasses import fields
from typing import Any
@ -8,13 +7,13 @@ import urllib3
from errors import RequestError
def dict_cls(dict_var: dict, cls: Any) -> Any:
def dict_cls(d: dict, cls: Any) -> Any:
"""
Create a dataclass from a dictionary.
"""
field_names = set(f.name for f in fields(cls))
filtered_dict = {k: v for k, v in dict_var.items() if k in field_names}
filtered_dict = {k: v for k, v in d.items() if k in field_names}
return cls(**filtered_dict)
@ -34,22 +33,6 @@ def log_except(fn):
return wrapper
def wrap_async(fn):
"""
Call an asynchronous function from a synchronous one.
"""
def wrapper(self, *args, **kwargs):
if not self.loop:
raise RuntimeError("loop is None.")
return asyncio.run_coroutine_threadsafe(
fn(self, *args, **kwargs), loop=self.loop
).result()
return wrapper
def request(fn):
"""
Either return json data or raise a `RequestError` if the request was
@ -75,8 +58,7 @@ def request(fn):
def except_deleted(fn):
"""
Ignore the `RequestError` on 404s, the message might have been
deleted by someone else already.
Ignore the `RequestError` on 404s, the content might have been removed.
"""
def wrapper(*args, **kwargs):

View file

@ -1,3 +1,3 @@
bottle==0.12.19
urllib3==1.26.3
websockets==8.1
bottle
urllib3
websockets