aioappsrv/appservice/main.py

735 lines
23 KiB
Python
Raw Normal View History

2021-04-17 06:45:51 +02:00
import asyncio
import json
import logging
import os
import re
import sys
import threading
from typing import Dict, List, Tuple
2021-04-17 06:45:51 +02:00
2021-06-28 11:48:59 +02:00
import markdown
2021-04-17 06:45:51 +02:00
import urllib3
import discord
import matrix
from appservice import AppService
from cache import Cache
2021-04-17 06:45:51 +02:00
from db import DataBase
from errors import RequestError
from gateway import Gateway
from misc import dict_cls, except_deleted, hash_str
2021-04-17 06:45:51 +02:00
class MatrixClient(AppService):
    """
    Matrix half of the bridge.

    Relays Matrix room events to Discord through webhooks and provides the
    helpers (`register`, `set_nick`, `set_avatar`, `create_message_event`,
    ...) that `DiscordClient` calls to act on the Matrix side.

    NOTE(review): `on_member`, `on_message` and `on_redaction` are
    presumably dispatched by the `AppService` base class; confirm there.
    """

    def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
        """
        Open the database, create the Discord client and initialise the
        Matrix side caches.

        `config` must contain the "database" key; the whole dict is also
        handed to `AppService` and `DiscordClient`.
        """
        super().__init__(config, http)

        self.db = DataBase(config["database"])
        self.discord = DiscordClient(self, config, http)
        self.format = "_discord_"  # "{@,#}_discord_1234:localhost"

        # TODO Find a cleaner way to use these keys.
        for k in ("m_emotes", "m_members", "m_messages"):
            Cache.cache[k] = {}

    def _is_local(self, mxid: str) -> bool:
        """
        Return True when `mxid` belongs to our own homeserver.
        """
        # Split on the first colon only: the server name may itself contain
        # a colon when it carries a port ("@user:example.org:8448").
        return mxid.partition(":")[2] == self.server_name

    def handle_bridge(self, message: matrix.Event) -> None:
        """
        Handle the "!bridge <channel_id>" command sent by a local user.

        Creates a bridged room for the given Discord channel when it is a
        guild text channel that is not bridged yet; does nothing otherwise.
        """
        # Ignore events that aren't for us.
        if not self._is_local(
            message.sender
        ) or not message.body.startswith("!bridge"):
            return

        # Get the channel ID.
        try:
            channel_id = message.body.split()[1]
        except IndexError:
            return

        # Check if the given channel is valid.
        try:
            channel = self.discord.get_channel(channel_id)
        except RequestError as e:
            # The channel can be invalid or we may not have permissions.
            self.logger.warning(f"Failed to fetch channel {channel_id}: {e}")
            return

        if (
            channel.type != discord.ChannelType.GUILD_TEXT
            or channel.id in self.db.list_channels()
        ):
            return

        self.logger.info(f"Creating bridged room for channel {channel.id}.")
        self.create_room(channel, message.sender)

    def on_member(self, event: matrix.Event) -> None:
        """
        Invalidate the member cache of the room and join direct message
        rooms that a local user invited the bridge user to.
        """
        with Cache.lock:
            # Just lazily clear the whole member cache on
            # membership update events.
            if event.room_id in Cache.cache["m_members"]:
                self.logger.info(
                    f"Clearing member cache for room '{event.room_id}'."
                )
                del Cache.cache["m_members"][event.room_id]

        if (
            not self._is_local(event.sender)
            or event.state_key != self.user_id
            or not event.is_direct
        ):
            return

        # Join the direct message room.
        self.logger.info(f"Joining direct message room '{event.room_id}'.")
        self.join_room(event.room_id)

    def on_message(self, message: matrix.Event) -> None:
        """
        Relay a Matrix message, or an edit of one, to the bridged channel.

        New messages are sent through the channel webhook and their Discord
        message ID is remembered in the "m_messages" cache so that later
        edits and redactions can be mapped back.
        """
        # Ignore our own puppets, the bridge user and empty messages.
        if (
            message.sender.startswith((f"@{self.format}", self.user_id))
            or not message.body
        ):
            return

        # Handle bridging commands.
        self.handle_bridge(message)

        channel_id = self.db.get_channel(message.room_id)
        if not channel_id:
            return

        author = self.get_members(message.room_id)[message.sender]
        webhook = self.discord.get_webhook(
            channel_id, self.discord.webhook_name
        )

        if message.relates_to and message.reltype == "m.replace":
            with Cache.lock:
                message_id = Cache.cache["m_messages"].get(message.relates_to)

            # TODO validate if the original author sent the edit.
            if not message_id or not message.new_body:
                return

            message.new_body = self.process_message(message)
            except_deleted(self.discord.edit_webhook)(
                message.new_body, message_id, webhook
            )
        else:
            message.body = (
                f"`{message.body}`: {self.mxc_url(message.attachment)}"
                if message.attachment
                else self.process_message(message)
            )
            message_id = self.discord.send_webhook(
                webhook,
                self.mxc_url(author.avatar_url) if author.avatar_url else None,
                message.body,
                author.display_name if author.display_name else message.sender,
            ).id

            with Cache.lock:
                Cache.cache["m_messages"][message.id] = message_id

    def on_redaction(self, event: matrix.Event) -> None:
        """
        Delete the Discord copy of a redacted Matrix message, if we have
        one cached.
        """
        with Cache.lock:
            message_id = Cache.cache["m_messages"].get(event.redacts)

        if not message_id:
            return

        webhook = self.discord.get_webhook(
            self.db.get_channel(event.room_id), self.discord.webhook_name
        )
        except_deleted(self.discord.delete_webhook)(message_id, webhook)

        with Cache.lock:
            # `pop` instead of `del`: the entry may already have been
            # removed by a concurrent redaction of the same event.
            Cache.cache["m_messages"].pop(event.redacts, None)

    def get_members(self, room_id: str) -> Dict[str, matrix.User]:
        """
        Return the joined members of a room, keyed by mxid.

        The result is cached per room in "m_members" until `on_member`
        clears it.
        """
        with Cache.lock:
            cached = Cache.cache["m_members"].get(room_id)

        if cached:
            return cached

        resp = self.send("GET", f"/rooms/{room_id}/joined_members")
        joined = {
            mxid: dict_cls(member, matrix.User)
            for mxid, member in resp["joined"].items()
        }

        with Cache.lock:
            Cache.cache["m_members"][room_id] = joined

        return joined

    def create_room(self, channel: discord.Channel, sender: str) -> None:
        """
        Create a bridged room and invite the person who invoked the command.
        """
        content = {
            "room_alias_name": f"{self.format}{channel.id}",
            "name": channel.name,
            "topic": channel.topic if channel.topic else channel.name,
            "visibility": "private",
            "invite": [sender],
            "creation_content": {"m.federate": True},
            "initial_state": [
                {
                    "type": "m.room.join_rules",
                    "content": {"join_rule": "public"},
                },
                {
                    "type": "m.room.history_visibility",
                    "content": {"history_visibility": "shared"},
                },
            ],
            "power_level_content_override": {
                "users": {sender: 100, self.user_id: 100}
            },
        }

        resp = self.send("POST", "/createRoom", content)
        self.db.add_room(resp["room_id"], channel.id)

    @staticmethod
    def _strip_reply_fallback(body: str) -> str:
        """
        Remove the leading reply fallback ("> " quoted lines) from a plain
        text body, keeping the remaining lines and their line breaks.
        """
        lines = body.split("\n")
        start = 0

        # Only the leading run of quoted lines is the fallback; lines
        # starting with "> " after a regular line belong to the message.
        while start < len(lines) and lines[start].startswith("> "):
            start += 1

        # Drop the blank line that separates the fallback from the message.
        if 0 < start < len(lines) and not lines[start]:
            start += 1

        return "\n".join(lines[start:])

    def _reply_target(self, reference: discord.Message) -> str:
        """
        Return the Matrix event ID that mirrors the Discord message
        `reference`, or a falsy value when none is cached.
        """
        with Cache.lock:
            # Reply to a Discord message.
            ref_id = Cache.cache["d_messages"].get(reference.id)

            # Reply to a Matrix message. (maybe)
            if not ref_id:
                ref_id = next(
                    (
                        event_id
                        for event_id, message_id in Cache.cache[
                            "m_messages"
                        ].items()
                        if message_id == reference.id
                    ),
                    "",
                )

        return ref_id

    def create_message_event(
        self,
        message: str,
        emotes: dict,
        edit: str = "",
        reference: discord.Message = None,
    ) -> dict:
        """
        Build the content of an "m.text" event for a Discord message.

        message:   Plain text body.
        emotes:    Mapping of emote name to Discord emote ID, see `get_fmt`.
        edit:      Event ID being replaced; when set the content is wrapped
                   as an "m.replace" edit.
        reference: Discord message being replied to; when its Matrix
                   counterpart can be found the reply fallback and the
                   "m.in_reply_to" relation are added.

        Returns the event content dict.
        """
        content = {
            "body": message,
            "format": "org.matrix.custom.html",
            "msgtype": "m.text",
            "formatted_body": self.get_fmt(message, emotes),
        }

        event = None
        if reference:
            ref_id = self._reply_target(reference)
            if ref_id:
                event = except_deleted(self.get_event)(
                    ref_id,
                    self.get_room_id(
                        self.discord.matrixify(reference.channel_id)
                    ),
                )

        if event:
            # Content with the reply fallbacks stripped.
            quoted_body = self._strip_reply_fallback(event.body)
            quoted_html = (
                # DOTALL so that a fallback spanning several lines is
                # removed as well.
                re.sub(
                    "<mx-reply>.*</mx-reply>",
                    "",
                    event.formatted_body,
                    flags=re.DOTALL,
                )
                if event.formatted_body
                else quoted_body
            )
            # Every line of the quoted message gets the "> " prefix.
            fallback = "\n".join(
                f"> {line}"
                for line in f"<{event.sender}> {quoted_body}".split("\n")
            )

            content = {
                **content,
                "body": f"{fallback}\n{content['body']}",
                "m.relates_to": {"m.in_reply_to": {"event_id": event.id}},
                "formatted_body": (
                    "<mx-reply><blockquote>"
                    f'<a href="https://matrix.to/#/{event.room_id}/'
                    f'{event.id}">In reply to</a> '
                    f'<a href="https://matrix.to/#/{event.sender}">'
                    f"{event.sender}</a><br>{quoted_html}"
                    "</blockquote></mx-reply>"
                    f"{content['formatted_body']}"
                ),
            }

        if edit:
            content = {
                **content,
                "body": f" * {content['body']}",
                "formatted_body": f" * {content['formatted_body']}",
                "m.relates_to": {"event_id": edit, "rel_type": "m.replace"},
                # The right hand side is evaluated before `content` is
                # rebound, so this is the unprefixed content.
                "m.new_content": {**content},
            }

        return content

    def get_fmt(self, message: str, emotes: dict) -> str:
        """
        Render `message` as HTML and replace ":name:" emote tags with
        inline images of the uploaded emotes.

        emotes: Mapping of emote name to Discord emote ID.
        """
        message = markdown.markdown(message)

        # Upload emotes in multiple threads so that we don't
        # block the Discord bot for too long.
        upload_threads = [
            threading.Thread(target=self.upload_emote, args=(name, emote_id))
            for name, emote_id in emotes.items()
        ]

        # Acquire the lock before starting the threads to avoid resource
        # contention by tens of threads at once.
        with Cache.lock:
            for thread in upload_threads:
                thread.start()
            for thread in upload_threads:
                thread.join()

        with Cache.lock:
            for name in emotes:
                uri = Cache.cache["m_emotes"].get(name)
                # Emotes that failed to upload are left as plain text.
                if uri:
                    tag = f":{name}:"
                    message = message.replace(
                        tag,
                        f'<img alt="{tag}" title="{tag}" '
                        f'height="32" src="{uri}" data-mx-emoticon />',
                    )

        return message

    def process_message(self, event: matrix.Event) -> str:
        """
        Convert the body of a Matrix event to Discord syntax.

        Known ":name:" emotes are replaced by their cached Discord form and
        mentions of puppeted users by "<@id>". The result is trimmed to
        `discord.MESSAGE_LIMIT` characters.
        """
        message = event.new_body if event.new_body else event.body
        id_regex = f"[0-9]{{{discord.ID_LEN}}}"

        emotes = re.findall(r":(\w*):", message)
        mentions = re.findall(
            f"@{self.format}{id_regex}:{re.escape(self.server_name)}",
            # The formatted body can be empty (`create_message_event`
            # guards against that too); `re.findall` needs a string.
            event.formatted_body or "",
        )

        with Cache.lock:
            for emote in set(emotes):
                emote_ = Cache.cache["d_emotes"].get(emote)
                if emote_:
                    message = message.replace(f":{emote}:", emote_)

        for mention in set(mentions):
            # `fetch_user` yields a falsy value for unknown users.
            profile = self.db.fetch_user(mention) or {}
            username = profile.get("username")
            if not username:
                continue

            match = re.search(id_regex, mention)
            if match:
                # Replace the 'mention' so that the user is tagged
                # in the case of replies as well.
                # '> <@_discord_1234:localhost> Message'
                for replace in (mention, username):
                    message = message.replace(replace, f"<@{match.group()}>")

        # We trim the message later as emotes take up extra characters too.
        return message[: discord.MESSAGE_LIMIT]

    def upload_emote(self, emote_name: str, emote_id: str) -> None:
        """
        Upload a Discord emote to the homeserver and cache the resulting
        URI in "m_emotes" under the emote name. Upload failures are logged
        and otherwise ignored.
        """
        # There won't be a race condition here, since only a unique
        # set of emotes are uploaded at a time.
        if emote_name in Cache.cache["m_emotes"]:
            return

        emote_url = f"{discord.CDN_URL}/emojis/{emote_id}"

        # We don't want the message to be dropped entirely if an emote
        # fails to upload for some reason.
        try:
            # TODO This is not thread safe, but we're protected by the GIL.
            Cache.cache["m_emotes"][emote_name] = self.upload(emote_url)
        except RequestError as e:
            self.logger.warning(f"Failed to upload emote {emote_id}: {e}")

    def register(self, mxid: str) -> None:
        """
        Register a dummy user on the homeserver.
        """
        content = {
            "type": "m.login.application_service",
            # "@test:localhost" -> "test" (Can't register with a full mxid.)
            "username": mxid[1:].split(":")[0],
        }

        resp = self.send("POST", "/register", content)
        self.db.add_user(resp["user_id"])

    def set_avatar(self, avatar_url: str, mxid: str) -> None:
        """
        Upload the image at `avatar_url`, set it as the avatar of `mxid`
        and record the source URL in the database.
        """
        avatar_uri = self.upload(avatar_url)

        self.send(
            "PUT",
            f"/profile/{mxid}/avatar_url",
            {"avatar_url": avatar_uri},
            params={"user_id": mxid},
        )
        self.db.add_avatar(avatar_url, mxid)

    def set_nick(self, username: str, mxid: str) -> None:
        """
        Set the display name of `mxid` and record it in the database.
        """
        self.send(
            "PUT",
            f"/profile/{mxid}/displayname",
            {"displayname": username},
            params={"user_id": mxid},
        )
        self.db.add_username(username, mxid)
class DiscordClient(Gateway):
    """
    Discord half of the bridge.

    Relays Discord gateway events to Matrix through the puppeted users
    managed by `MatrixClient` (available as `self.app`).

    NOTE(review): the `on_*` handlers are presumably dispatched by the
    `Gateway` base class; confirm there.
    """

    def __init__(
        self, appservice: MatrixClient, config: dict, http: urllib3.PoolManager
    ) -> None:
        """
        Store the Matrix client and initialise the Discord side caches.

        `config` must contain the "discord_token" key.
        """
        super().__init__(http, config["discord_token"])

        self.app = appservice
        self.webhook_name = "matrix_bridge"

        # TODO Find a cleaner way to use these keys.
        for k in ("d_emotes", "d_messages", "d_webhooks"):
            Cache.cache[k] = {}

    def to_return(self, message: discord.Message) -> bool:
        """
        Return True when `message` must not be bridged: it is from an
        unbridged channel, has no author, or was sent by one of our own
        cached webhooks.
        """
        with Cache.lock:
            hook_ids = {hook.id for hook in Cache.cache["d_webhooks"].values()}

        return (
            message.channel_id not in self.app.db.list_channels()
            or not message.author  # Embeds can be weird sometimes.
            or message.webhook_id in hook_ids
        )

    def matrixify(self, id: str, user: bool = False) -> str:
        """
        Build the mxid (`user=True`) or room alias for a Discord ID,
        e.g. "@_discord_1234:localhost" or "#_discord_1234:localhost".
        """
        return (
            f"{'@' if user else '#'}{self.app.format}{id}:"
            f"{self.app.server_name}"
        )

    @staticmethod
    def _puppet_id(message: discord.Message) -> str:
        """
        Return the ID used to build the puppet's mxid for the author of
        `message`: the author ID, with the hashed username added for
        webhook messages.
        """
        if message.webhook_id:
            hashed = hash_str(message.author.username)
            return str(int(message.author.id) + hashed)

        return message.author.id

    def sync_profile(self, user: discord.User) -> None:
        """
        Sync the avatar and username for a puppeted user.
        """
        mxid = self.matrixify(user.id, user=True)
        profile = self.app.db.fetch_user(mxid)

        # User doesn't exist.
        if not profile:
            return

        username = f"{user.username}#{user.discriminator}"

        # Users without an avatar have nothing to upload (`wrap` applies
        # the same guard when it creates the puppet).
        if user.avatar_url and user.avatar_url != profile["avatar_url"]:
            self.logger.info(f"Updating avatar for Discord user '{user.id}'")
            self.app.set_avatar(user.avatar_url, mxid)

        if username != profile["username"]:
            self.logger.info(f"Updating username for Discord user '{user.id}'")
            self.app.set_nick(username, mxid)

    def wrap(self, message: discord.Message) -> Tuple[str, str]:
        """
        Get the room ID and the puppet's mxid for a given channel ID and a
        Discord user.

        The puppet is registered and joined to the room on first use.
        Returns `(mxid, room_id)`.
        """
        if message.webhook_id:
            # Rewritten in place because `sync_profile` below derives the
            # mxid from `message.author.id`.
            message.author.id = self._puppet_id(message)

        mxid = self.matrixify(message.author.id, user=True)
        room_id = self.app.get_room_id(self.matrixify(message.channel_id))

        if not self.app.db.fetch_user(mxid):
            self.logger.info(
                f"Creating dummy user for Discord user {message.author.id}."
            )
            self.app.register(mxid)

            self.app.set_nick(
                f"{message.author.username}#"
                f"{message.author.discriminator}",
                mxid,
            )

            if message.author.avatar_url:
                self.app.set_avatar(message.author.avatar_url, mxid)

        if mxid not in self.app.get_members(room_id):
            self.logger.info(f"Inviting user '{mxid}' to room '{room_id}'.")

            self.app.send_invite(room_id, mxid)
            self.app.join_room(room_id, mxid)

        if message.webhook_id:
            # Sync webhooks here as they can't be accessed like guild members.
            self.sync_profile(message.author)

        return mxid, room_id

    def cache_emotes(self, emotes: List[discord.Emote]) -> None:
        """
        Store the Discord form ("<:name:id>", "<a:name:id>" when animated)
        of each emote in the "d_emotes" cache, keyed by emote name.
        """
        # TODO maybe "namespace" emotes by guild in the cache ?
        with Cache.lock:
            for emote in emotes:
                Cache.cache["d_emotes"][emote.name] = (
                    f"<{'a' if emote.animated else ''}:"
                    f"{emote.name}:{emote.id}>"
                )

    def on_guild_create(self, guild: discord.Guild) -> None:
        """Sync the profiles of all guild members and cache the emotes."""
        for member in guild.members:
            self.sync_profile(member)

        self.cache_emotes(guild.emojis)

    def on_guild_emojis_update(
        self, update: discord.GuildEmojisUpdate
    ) -> None:
        """Cache the emotes carried by the update."""
        self.cache_emotes(update.emojis)

    def on_guild_member_update(
        self, update: discord.GuildMemberUpdate
    ) -> None:
        """Sync the profile of the updated member."""
        self.sync_profile(update.user)

    def on_message_create(self, message: discord.Message) -> None:
        """
        Relay a new Discord message to the bridged room and remember the
        resulting event ID in the "d_messages" cache.
        """
        if self.to_return(message):
            return

        mxid, room_id = self.wrap(message)
        content_, emotes = self.process_message(message)

        content = self.app.create_message_event(
            content_, emotes, reference=message.referenced_message
        )

        # Send before taking the lock so that the cache is not blocked for
        # the duration of the request.
        event_id = self.app.send_message(room_id, content, mxid)

        with Cache.lock:
            Cache.cache["d_messages"][message.id] = event_id

    def on_message_delete(self, message: discord.Message) -> None:
        """
        Redact the Matrix copy of a deleted Discord message, if we have
        one cached.
        """
        with Cache.lock:
            event_id = Cache.cache["d_messages"].get(message.id)

        if not event_id:
            return

        room_id = self.app.get_room_id(self.matrixify(message.channel_id))
        event = except_deleted(self.app.get_event)(event_id, room_id)

        if event:
            self.app.redact(event.id, event.room_id, event.sender)

        with Cache.lock:
            # `pop` instead of `del`: the entry may already be gone if the
            # same deletion was handled concurrently.
            Cache.cache["d_messages"].pop(message.id, None)

    def on_message_update(self, message: discord.Message) -> None:
        """
        Relay the edit of a previously bridged Discord message as an
        "m.replace" event sent by the author's puppet.
        """
        if self.to_return(message):
            return

        with Cache.lock:
            event_id = Cache.cache["d_messages"].get(message.id)

        if not event_id:
            return

        room_id = self.app.get_room_id(self.matrixify(message.channel_id))
        # Same derivation as in `wrap`, so that webhook authors resolve to
        # the puppet that was registered when the message was created.
        mxid = self.matrixify(self._puppet_id(message), user=True)

        # It is possible that a webhook edits its own old message
        # after changing its name, hence we generate a new mxid from
        # the hashed username, but that mxid hasn't been registered before,
        # so the request fails with:
        # M_FORBIDDEN: Application service has not registered this user
        if not self.app.db.fetch_user(mxid):
            return

        content_, emotes = self.process_message(message)

        content = self.app.create_message_event(
            content_, emotes, edit=event_id
        )

        self.app.send_message(room_id, content, mxid)

    def on_typing_start(self, typing: discord.Typing) -> None:
        """
        Forward a typing notification for puppets that already joined the
        bridged room.
        """
        if typing.channel_id not in self.app.db.list_channels():
            return

        mxid = self.matrixify(typing.user_id, user=True)
        room_id = self.app.get_room_id(self.matrixify(typing.channel_id))

        if mxid not in self.app.get_members(room_id):
            return

        self.app.send_typing(room_id, mxid)

    def get_webhook(self, channel_id: str, name: str) -> discord.Webhook:
        """
        Get the webhook object for the first webhook that matches the specified
        name in a given channel, create the webhook if it doesn't exist.
        """
        # Check the cache first.
        with Cache.lock:
            webhook = Cache.cache["d_webhooks"].get(channel_id)

        if webhook:
            return webhook

        webhooks = self.send("GET", f"/channels/{channel_id}/webhooks")
        webhook = next(
            (
                dict_cls(webhook, discord.Webhook)
                for webhook in webhooks
                if webhook["name"] == name
            ),
            None,
        )

        if not webhook:
            webhook = self.create_webhook(channel_id, name)

        with Cache.lock:
            Cache.cache["d_webhooks"][channel_id] = webhook

        return webhook

    def process_message(self, message: discord.Message) -> Tuple[str, Dict]:
        """
        Convert the content of a Discord message to plain text for Matrix.

        User and channel mentions are replaced by names, custom emotes by
        ":name:" tags, and attachment URLs are appended on separate lines.

        Returns `(content, emotes)` where `emotes` maps emote names to
        Discord emote IDs.
        """
        content = message.content
        emotes = {}
        regex = r"<a?:(\w+):(\d+)>"

        # Mentions can either be in the form of `<@1234>` or `<@!1234>`.
        for member in message.mentions:
            for char in ("", "!"):
                content = content.replace(
                    f"<@{char}{member.id}>", f"@{member.username}"
                )

        # `except_deleted` for invalid channels.
        # TODO can this block for too long ?
        channel_ids = re.findall(
            f"<#([0-9]{{{discord.ID_LEN}}})>", content
        )
        # Look up every mentioned channel only once; `str.replace` already
        # rewrites all of its occurrences.
        for channel in dict.fromkeys(channel_ids):
            discord_channel = except_deleted(self.get_channel)(channel)
            name = (
                discord_channel.name if discord_channel else "deleted-channel"
            )
            content = content.replace(f"<#{channel}>", f"#{name}")

        # { "emote_name": "emote_id" }
        for emote in re.findall(regex, content):
            emotes[emote[0]] = emote[1]

        # Replace emote IDs with names.
        content = re.sub(regex, r":\g<1>:", content)

        # Append attachments to message.
        for attachment in message.attachments:
            content += f"\n{attachment['url']}"

        return content, emotes
def config_gen(basedir: str, config_file: str) -> dict:
    """
    Load the bridge configuration from `basedir`/`config_file`.

    When the file does not exist yet, a template with placeholder values
    is written to that path, its location is printed and the process
    exits so that the template can be filled in first.

    Returns the parsed configuration.
    """
    config_path = f"{basedir}/{config_file}"

    if os.path.exists(config_path):
        with open(config_path, "r") as fp:
            return json.load(fp)

    # First run: dump the template and stop.
    template = {
        "as_token": "my-secret-as-token",
        "hs_token": "my-secret-hs-token",
        "user_id": "appservice-discord",
        "homeserver": "http://127.0.0.1:8008",
        "server_name": "localhost",
        "discord_token": "my-secret-discord-token",
        "port": 5000,
        "database": f"{basedir}/bridge.db",
    }

    with open(config_path, "w") as fp:
        json.dump(template, fp, indent=4)

    print(f"Configuration dumped to '{config_path}'")
    sys.exit()
def excepthook(exc_type, exc_value, exc_traceback):
    """
    Replacement for `sys.excepthook` that records uncaught exceptions
    through `logging`.

    `KeyboardInterrupt` is passed to the default hook instead of being
    logged.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        details = (exc_type, exc_value, exc_traceback)
        logging.critical("Unknown exception:", exc_info=details)
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)
2021-04-17 06:45:51 +02:00
def main() -> None:
    """
    Entry point of the bridge.

    The optional first command line argument is the base directory that
    holds "appservice.json" and receives "appservice.log"; the current
    working directory is used when it is omitted.
    """
    if len(sys.argv) > 1:
        basedir = sys.argv[1]
        if not os.path.exists(basedir):
            print(f"Path '{basedir}' does not exist!")
            sys.exit(1)
        basedir = os.path.abspath(basedir)
    else:
        basedir = os.getcwd()

    config = config_gen(basedir, "appservice.json")

    log_file = logging.FileHandler(f"{basedir}/appservice.log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[log_file],
    )

    # Route uncaught exceptions to the log file.
    sys.excepthook = excepthook

    http = urllib3.PoolManager(maxsize=10)
    app = MatrixClient(config, http)

    # Start the bottle app in a separate thread.
    threading.Thread(
        target=app.run, kwargs={"port": int(config["port"])}, daemon=True
    ).start()

    try:
        asyncio.run(app.discord.run())
    except KeyboardInterrupt:
        sys.exit()


if __name__ == "__main__":
    main()