chore: remove unneeded files

This commit is contained in:
əlemi 2024-01-29 16:52:14 +01:00
parent 1f04e0a88f
commit 25bcf38492
Signed by: alemi
GPG key ID: A4895B84D311642C
7 changed files with 0 additions and 1478 deletions

View file

@ -1,203 +0,0 @@
import json
import logging
import urllib.parse
import uuid
from typing import Union
import bottle
import urllib3
import matrix
from cache import Cache
from misc import log_except, request
class AppService(bottle.Bottle):
    """
    Matrix application service built on top of a bottle app.

    Receives the transactions pushed by the homeserver on
    `PUT /transactions/<transaction>`, dispatches every event to the
    handler named in `self.mapping` and wraps the client-server API calls
    that are made with the appservice token.
    """

    def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
        """
        Read the tokens and homeserver details from `config`, register the
        transaction route and reset the room alias cache.

        `config` must contain "as_token", "hs_token", "homeserver",
        "server_name" and "user_id". `http` is used for every outgoing
        request.
        """
        super().__init__()

        self.as_token = config["as_token"]
        self.hs_token = config["hs_token"]
        self.base_url = config["homeserver"]
        self.server_name = config["server_name"]
        self.user_id = f"@{config['user_id']}:{self.server_name}"
        self.http = http
        self.logger = logging.getLogger("appservice")

        # Map events to functions.
        self.mapping = {
            "m.room.member": "on_member",
            "m.room.message": "on_message",
            "m.room.redaction": "on_redaction",
        }

        # Add route for bottle.
        self.route(
            "/transactions/<transaction>",
            callback=self.receive_event,
            method="PUT",
        )

        Cache.cache["m_rooms"] = {}

    def handle_event(self, event: dict) -> None:
        """
        Dispatch one event to the handler configured in `self.mapping`.

        Events whose type is not mapped are logged and ignored, and so are
        events whose handler method does not exist on this instance.
        """
        event_type = event.get("type")

        # `self.mapping` is the single source of truth for supported types.
        # The isinstance check keeps an unhashable "type" from raising.
        handler_name = (
            self.mapping.get(event_type)
            if isinstance(event_type, str)
            else None
        )
        if handler_name is None:
            self.logger.info(f"Unknown event type: {event_type}")
            return

        obj = matrix.Event(event)

        func = getattr(self, handler_name, None)
        if not func:
            # Log the handler's name; `func` itself is always None here.
            self.logger.warning(
                f"Function '{handler_name}' not defined, ignoring event."
            )
            return

        # We don't catch exceptions here as the homeserver will re-send us
        # the event in case of a failure.
        func(obj)

    @log_except
    def receive_event(self, transaction: str) -> dict:
        """
        Verify the homeserver's token and handle events.

        Responds with 401 when no token was sent, 403 when the token does
        not match `hs_token` and 400 when the body is not a JSON object.
        """
        hs_token = bottle.request.query.getone("access_token")

        if not hs_token:
            bottle.response.status = 401
            return {"errcode": "APPSERVICE_UNAUTHORIZED"}

        if hs_token != self.hs_token:
            bottle.response.status = 403
            return {"errcode": "APPSERVICE_FORBIDDEN"}

        body = bottle.request.json
        if not isinstance(body, dict):
            # A missing or non-object body used to raise AttributeError.
            bottle.response.status = 400
            return {"errcode": "M_BAD_JSON"}

        # A transaction without an "events" key is treated as empty.
        for event in body.get("events") or []:
            self.handle_event(event)

        return {}

    def mxc_url(self, mxc: str) -> str:
        """
        Convert an `mxc://server/media_id` URI to a download URL on
        `self.server_name`. Returns an empty string when the URI does not
        consist of exactly a server and a media ID.
        """
        try:
            homeserver, media_id = mxc.replace("mxc://", "").split("/")
        except ValueError:
            return ""

        return (
            f"https://{self.server_name}/_matrix/media/r0/download/"
            f"{homeserver}/{media_id}"
        )

    def join_room(self, room_id: str, mxid: str = "") -> None:
        """
        Join `room_id`, acting as `mxid` when one is given.
        """
        self.send(
            "POST",
            f"/join/{room_id}",
            params={"user_id": mxid} if mxid else {},
        )

    def redact(self, event_id: str, room_id: str, mxid: str = "") -> None:
        """
        Redact `event_id` in `room_id`, acting as `mxid` when one is given.
        """
        self.send(
            "PUT",
            f"/rooms/{room_id}/redact/{event_id}/{uuid.uuid4()}",
            params={"user_id": mxid} if mxid else {},
        )

    def get_room_id(self, alias: str) -> str:
        """
        Resolve a room alias to a room ID, caching the result in
        `Cache.cache["m_rooms"]`.
        """
        with Cache.lock:
            room = Cache.cache["m_rooms"].get(alias)
            if room:
                return room

        resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}")
        room_id = resp["room_id"]

        with Cache.lock:
            Cache.cache["m_rooms"][alias] = room_id

        return room_id

    def get_event(self, event_id: str, room_id: str) -> matrix.Event:
        """
        Fetch a single event from a room.
        """
        resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}")
        return matrix.Event(resp)

    def upload(self, url: str) -> str:
        """
        Upload a file to the homeserver and get the MXC url.
        """
        resp = self.http.request("GET", url)

        resp = self.send(
            "POST",
            content=resp.data,
            content_type=resp.headers.get("Content-Type"),
            params={"filename": f"{uuid.uuid4()}"},
            endpoint="/_matrix/media/r0/upload",
        )

        return resp["content_uri"]

    def send_message(
        self,
        room_id: str,
        content: dict,
        mxid: str = "",
    ) -> str:
        """
        Send an `m.room.message` event and return the new event's ID.
        """
        resp = self.send(
            "PUT",
            f"/rooms/{room_id}/send/m.room.message/{uuid.uuid4()}",
            content,
            {"user_id": mxid} if mxid else {},
        )

        return resp["event_id"]

    def send_typing(
        self, room_id: str, mxid: str = "", timeout: int = 8000
    ) -> None:
        """
        Mark `mxid` as typing in `room_id`; `timeout` is forwarded as-is.
        """
        self.send(
            "PUT",
            f"/rooms/{room_id}/typing/{mxid}",
            {"typing": True, "timeout": timeout},
            {"user_id": mxid} if mxid else {},
        )

    def send_invite(self, room_id: str, mxid: str) -> None:
        """
        Invite `mxid` to `room_id`.
        """
        self.send("POST", f"/rooms/{room_id}/invite", {"user_id": mxid})

    @request
    def send(
        self,
        method: str,
        path: str = "",
        content: Union[bytes, dict, None] = None,
        params: Union[dict, None] = None,
        content_type: str = "application/json",
        endpoint: str = "/_matrix/client/r0",
    ) -> dict:
        """
        Perform an authenticated request against the homeserver.

        A `dict` body is JSON-encoded and `bytes` are sent verbatim. When
        `content` is omitted an empty JSON object is sent, and when
        `params` is omitted the query string is empty.
        """
        # None replaces the old mutable `{}` defaults; behavior is the same.
        if content is None:
            content = {}
        if params is None:
            params = {}

        headers = {
            "Authorization": f"Bearer {self.as_token}",
            "Content-Type": content_type,
        }
        payload = json.dumps(content) if isinstance(content, dict) else content
        endpoint = (
            f"{self.base_url}{endpoint}{path}?"
            f"{urllib.parse.urlencode(params)}"
        )

        return self.http.request(
            method, endpoint, body=payload, headers=headers
        )

View file

@ -1,6 +0,0 @@
import threading
class Cache:
    """
    Process-wide cache shared through class attributes.
    """

    # Guards access to `cache`.
    lock = threading.Lock()

    # Top-level keys are created by the code that uses the cache.
    cache: dict = {}

View file

@ -1,120 +0,0 @@
import os
import sqlite3
import threading
from typing import List
class DataBase:
    """
    SQLite storage for bridged rooms (`bridge` table) and puppeted users
    (`users` table). One connection and cursor are shared; every query
    method serializes its access through `self.lock`.
    """

    def __init__(self, db_file) -> None:
        """
        Open (and if needed initialize) the database at `db_file`.
        """
        # The database is accessed via multiple threads.
        self.lock = threading.Lock()

        self.create(db_file)

    def create(self, db_file) -> None:
        """
        Create a database with the relevant tables if it doesn't already exist.
        """
        self.conn = sqlite3.connect(db_file, check_same_thread=False)
        self.conn.row_factory = self.dict_factory
        self.cur = self.conn.cursor()

        # `IF NOT EXISTS` instead of checking whether the file exists: a
        # file that exists but holds no tables still gets its schema.
        self.cur.execute(
            "CREATE TABLE IF NOT EXISTS "
            "bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);"
        )
        self.cur.execute(
            "CREATE TABLE IF NOT EXISTS users(mxid TEXT PRIMARY KEY, "
            "avatar_url TEXT, username TEXT);"
        )

        self.conn.commit()

    def dict_factory(self, cursor, row):
        """
        Row factory returning each row as a `{column name: value}` dict.

        https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory
        """
        return {
            col[0]: row[idx] for idx, col in enumerate(cursor.description)
        }

    def add_room(self, room_id: str, channel_id: str) -> None:
        """
        Add a bridged room to the database.
        """
        with self.lock:
            self.cur.execute(
                "INSERT INTO bridge (room_id, channel_id) VALUES (?, ?)",
                [room_id, channel_id],
            )
            self.conn.commit()

    def add_user(self, mxid: str) -> None:
        """
        Add a user row for `mxid`; avatar and username start out as NULL.
        """
        with self.lock:
            self.cur.execute("INSERT INTO users (mxid) VALUES (?)", [mxid])
            self.conn.commit()

    def add_avatar(self, avatar_url: str, mxid: str) -> None:
        """
        Store the avatar URL for an existing user.
        """
        with self.lock:
            self.cur.execute(
                "UPDATE users SET avatar_url = (?) WHERE mxid = (?)",
                [avatar_url, mxid],
            )
            self.conn.commit()

    def add_username(self, username: str, mxid: str) -> None:
        """
        Store the username for an existing user.
        """
        with self.lock:
            self.cur.execute(
                "UPDATE users SET username = (?) WHERE mxid = (?)",
                [username, mxid],
            )
            self.conn.commit()

    def get_channel(self, room_id: str) -> str:
        """
        Get the corresponding channel ID for a given room ID.
        """
        with self.lock:
            self.cur.execute(
                "SELECT channel_id FROM bridge WHERE room_id = ?", [room_id]
            )
            room = self.cur.fetchone()

        # Return an empty string if the channel is not bridged.
        return "" if not room else room["channel_id"]

    def list_channels(self) -> List[str]:
        """
        Get a list of all the bridged channels.
        """
        with self.lock:
            self.cur.execute("SELECT channel_id FROM bridge")
            channels = self.cur.fetchall()

        return [channel["channel_id"] for channel in channels]

    def fetch_user(self, mxid: str) -> dict:
        """
        Fetch the profile for a bridged user.

        Returns an empty dict when the user is not in the database.
        """
        with self.lock:
            self.cur.execute("SELECT * FROM users where mxid = ?", [mxid])
            user = self.cur.fetchone()

        return {} if not user else user

View file

@ -1,5 +0,0 @@
class RequestError(Exception):
    """
    Raised when an HTTP request fails; `status` carries the status value
    given by the raiser, the remaining arguments go to `Exception`.
    """

    def __init__(self, status: int, *args):
        self.status = status
        super().__init__(*args)

View file

@ -1,260 +0,0 @@
import asyncio
import json
import logging
import urllib.parse
from typing import Dict, List
import urllib3
import websockets
import discord
from misc import dict_cls, log_except, request
class Gateway:
    """
    Client for the Discord gateway (websocket) and REST API.

    Dispatch events are forwarded to `on_<event name>` methods, which are
    looked up on the instance at dispatch time.
    """

    def __init__(self, http: urllib3.PoolManager, token: str):
        """
        Store the HTTP pool and bot token and prepare the gateway payloads.
        """
        self.http = http
        self.token = token
        self.logger = logging.getLogger("discord")
        self.Payloads = discord.Payloads(self.token)
        self.websocket = None

    @log_except
    async def run(self) -> None:
        """
        Connect to the gateway and keep reconnecting whenever the
        connection ends.
        """
        self.heartbeat_task: asyncio.Future = None
        self.resume = False

        gateway_url = self.get_gateway_url()

        while True:
            try:
                await self.gateway_handler(gateway_url)
            except (
                websockets.ConnectionClosedError,
                websockets.InvalidMessage,
            ):
                self.logger.exception("Connection lost, reconnecting.")

            # Stop sending heartbeats until we reconnect.
            if self.heartbeat_task and not self.heartbeat_task.cancelled():
                self.heartbeat_task.cancel()

    def get_gateway_url(self) -> str:
        """
        Ask the REST API for the gateway URL.
        """
        resp = self.send("GET", "/gateway")

        return resp["url"]

    async def heartbeat_handler(self, interval_ms: int) -> None:
        """
        Send a heartbeat payload every `interval_ms` milliseconds, forever.
        """
        while True:
            await asyncio.sleep(interval_ms / 1000)
            await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT()))

    async def handle_resp(self, data: dict) -> None:
        """
        React to a single decoded gateway payload, based on its opcode.
        """
        data_dict = data["d"]
        opcode = data["op"]
        seq = data["s"]

        # Remember the latest sequence number on the payload builder.
        if seq:
            self.Payloads.seq = seq

        if opcode == discord.GatewayOpCodes.DISPATCH:
            otype = data["t"]

            if otype == "READY":
                self.Payloads.session = data_dict["session_id"]
                self.logger.info("READY")
            else:
                self.handle_otype(data_dict, otype)
        elif opcode == discord.GatewayOpCodes.HELLO:
            heartbeat_interval = data_dict.get("heartbeat_interval")
            self.logger.info(f"Heartbeat Interval: {heartbeat_interval}")

            # Send periodic heartbeats to gateway.
            self.heartbeat_task = asyncio.ensure_future(
                self.heartbeat_handler(heartbeat_interval)
            )

            await self.websocket.send(
                json.dumps(
                    self.Payloads.RESUME()
                    if self.resume
                    else self.Payloads.IDENTIFY()
                )
            )
        elif opcode == discord.GatewayOpCodes.RECONNECT:
            self.logger.info("Received RECONNECT.")

            self.resume = True
            await self.websocket.close()
        elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
            self.logger.info("Received INVALID_SESSION.")

            self.resume = False
            await self.websocket.close()
        elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
            # NOP
            pass
        else:
            # The `f` prefix was missing, so the placeholders were logged
            # literally instead of the opcode and payload.
            self.logger.info(
                f"Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}"
            )

    def handle_otype(self, data: dict, otype: str) -> None:
        """
        Wrap the payload of a dispatch event in its `discord` type and pass
        it to the matching `on_<event name>` method, if there is one.

        Exceptions raised by the handler are logged and swallowed.
        """
        if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"):
            obj = discord.Message(data)
        elif otype == "TYPING_START":
            obj = dict_cls(data, discord.Typing)
        elif otype == "GUILD_CREATE":
            obj = discord.Guild(data)
        elif otype == "GUILD_MEMBER_UPDATE":
            obj = discord.GuildMemberUpdate(data)
        elif otype == "GUILD_EMOJIS_UPDATE":
            obj = discord.GuildEmojisUpdate(data)
        else:
            return

        handler_name = f"on_{otype.lower()}"
        func = getattr(self, handler_name, None)

        if not func:
            # Log the handler's name; `func` itself is always None here.
            self.logger.warning(
                f"Function '{handler_name}' not defined, ignoring message."
            )
            return

        try:
            func(obj)
        except Exception:
            self.logger.exception(f"Ignoring exception in '{func.__name__}':")

    async def gateway_handler(self, gateway_url: str) -> None:
        """
        Open the websocket and process messages until it closes.
        """
        async with websockets.connect(
            f"{gateway_url}/?v=8&encoding=json"
        ) as websocket:
            self.websocket = websocket

            async for message in websocket:
                await self.handle_resp(json.loads(message))

    def get_channel(self, channel_id: str) -> discord.Channel:
        """
        Get the channel for a given channel ID.
        """
        resp = self.send("GET", f"/channels/{channel_id}")

        return dict_cls(resp, discord.Channel)

    def get_channels(self, guild_id: str) -> Dict[str, discord.Channel]:
        """
        Get all channels for a given guild ID.
        """
        resp = self.send("GET", f"/guilds/{guild_id}/channels")

        return {
            channel["id"]: dict_cls(channel, discord.Channel)
            for channel in resp
        }

    def get_emotes(self, guild_id: str) -> List[discord.Emote]:
        """
        Get all the emotes for a given guild.
        """
        resp = self.send("GET", f"/guilds/{guild_id}/emojis")

        return [dict_cls(emote, discord.Emote) for emote in resp]

    def get_members(self, guild_id: str) -> List[discord.User]:
        """
        Get all the members for a given guild.
        """
        resp = self.send(
            "GET", f"/guilds/{guild_id}/members", params={"limit": 1000}
        )

        return [discord.User(member["user"]) for member in resp]

    def create_webhook(self, channel_id: str, name: str) -> discord.Webhook:
        """
        Create a webhook with the specified name in a given channel.
        """
        resp = self.send(
            "POST", f"/channels/{channel_id}/webhooks", {"name": name}
        )

        return dict_cls(resp, discord.Webhook)

    def edit_webhook(
        self, content: str, message_id: str, webhook: discord.Webhook
    ) -> None:
        """
        Replace the content of a message that was sent through `webhook`.
        """
        self.send(
            "PATCH",
            f"/webhooks/{webhook.id}/{webhook.token}/messages/"
            f"{message_id}",
            {"content": content},
        )

    def delete_webhook(
        self, message_id: str, webhook: discord.Webhook
    ) -> None:
        """
        Delete a message that was sent through `webhook`.
        """
        self.send(
            "DELETE",
            f"/webhooks/{webhook.id}/{webhook.token}/messages/"
            f"{message_id}",
        )

    def send_webhook(
        self,
        webhook: discord.Webhook,
        avatar_url: str,
        content: str,
        username: str,
    ) -> discord.Message:
        """
        Send a message through `webhook` with the given avatar and username
        and return the created message.
        """
        payload = {
            "avatar_url": avatar_url,
            "content": content,
            "username": username,
            # Disable 'everyone' and 'role' mentions.
            "allowed_mentions": {"parse": ["users"]},
        }

        resp = self.send(
            "POST",
            f"/webhooks/{webhook.id}/{webhook.token}",
            payload,
            {"wait": True},
        )

        return discord.Message(resp)

    def send_message(self, message: str, channel_id: str) -> None:
        """
        Send a plain message to a channel as the bot.
        """
        self.send(
            "POST", f"/channels/{channel_id}/messages", {"content": message}
        )

    @request
    def send(
        self, method: str, path: str, content: dict = None, params: dict = None
    ) -> dict:
        """
        Perform an authenticated request against the Discord REST API.

        `content` is JSON-encoded when non-empty; `params` become the query
        string. Both may be omitted.
        """
        # None replaces the old mutable `{}` defaults; behavior is the same.
        endpoint = (
            f"https://discord.com/api/v8{path}?"
            f"{urllib.parse.urlencode(params or {})}"
        )
        headers = {
            "Authorization": f"Bot {self.token}",
            "Content-Type": "application/json",
        }

        # 'body' being an empty dict breaks "GET" requests.
        payload = json.dumps(content) if content else None

        return self.http.request(
            method, endpoint, body=payload, headers=headers
        )

View file

@ -1,800 +0,0 @@
import asyncio
import json
import logging
import os
import re
import sys
import threading
import urllib.parse
from typing import Dict, List, Tuple
import markdown
import urllib3
import discord
import matrix
from appservice import AppService
from cache import Cache
from db import DataBase
from errors import RequestError
from gateway import Gateway
from misc import dict_cls, except_deleted, hash_str
class MatrixClient(AppService):
    """
    Matrix side of the bridge: handles the events delivered to the
    application service and relays them to Discord via `self.discord`.
    """

    def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
        """
        Open the database, create the Discord client and reset the
        Matrix-side cache sections.
        """
        super().__init__(config, http)

        self.db = DataBase(config["database"])
        self.discord = DiscordClient(self, config, http)
        self.format = "_discord_"  # "{@,#}_discord_1234:localhost"
        self.id_regex = "[0-9]+"  # Snowflakes may have variable length

        # TODO Find a cleaner way to use these keys.
        for section in ("m_emotes", "m_members", "m_messages"):
            Cache.cache[section] = {}

    def handle_bridge(self, message: matrix.Event) -> None:
        """
        Handle a `!bridge <channel id>` command sent by a local user by
        creating a bridged room for that Discord channel.
        """
        # Ignore events that aren't for us.
        sender_domain = message.sender.split(":")[-1]
        if sender_domain != self.server_name:
            return
        if not message.body.startswith("!bridge"):
            return

        # Get the channel ID.
        words = message.body.split()
        if len(words) < 2:
            return
        channel = words[1]

        # Check if the given channel is valid.
        try:
            channel = self.discord.get_channel(channel)
        except RequestError as e:
            # The channel can be invalid or we may not have permissions.
            self.logger.warning(f"Failed to fetch channel {channel}: {e}")
            return

        if channel.type != discord.ChannelType.GUILD_TEXT:
            return
        if channel.id in self.db.list_channels():
            return

        self.logger.info(f"Creating bridged room for channel {channel.id}.")

        self.create_room(channel, message.sender)

    def on_member(self, event: matrix.Event) -> None:
        """
        Invalidate the member cache of the room and join direct message
        rooms that a local user invited the appservice user to.
        """
        with Cache.lock:
            member_cache = Cache.cache["m_members"]
            # Just lazily clear the whole member cache on
            # membership update events.
            if event.room_id in member_cache:
                self.logger.info(
                    f"Clearing member cache for room '{event.room_id}'."
                )
                del member_cache[event.room_id]

        is_direct_invite_for_us = (
            event.sender.split(":")[-1] == self.server_name
            and event.state_key == self.user_id
            and event.is_direct
        )
        if not is_direct_invite_for_us:
            return

        # Join the direct message room.
        self.logger.info(f"Joining direct message room '{event.room_id}'.")
        self.join_room(event.room_id)

    def on_message(self, message: matrix.Event) -> None:
        """
        Relay a Matrix message (or an edit of a relayed message) to the
        bridged Discord channel through the bridge webhook.
        """
        own_prefixes = (f"@{self.format}", self.user_id)
        if message.sender.startswith(own_prefixes) or not message.body:
            return

        # Handle bridging commands.
        self.handle_bridge(message)

        channel_id = self.db.get_channel(message.room_id)
        if not channel_id:
            return

        author = self.get_members(message.room_id)[message.sender]

        webhook = self.discord.get_webhook(
            channel_id, self.discord.webhook_name
        )

        if message.relates_to and message.reltype == "m.replace":
            with Cache.lock:
                message_id = Cache.cache["m_messages"].get(message.relates_to)

            # TODO validate if the original author sent the edit.

            if not message_id or not message.new_body:
                return

            message.new_body = self.process_message(message)

            except_deleted(self.discord.edit_webhook)(
                message.new_body, message_id, webhook
            )
            return

        if message.attachment:
            message.body = (
                f"`{message.body}`: {self.mxc_url(message.attachment)}"
            )
        else:
            message.body = self.process_message(message)

        avatar = self.mxc_url(author.avatar_url) if author.avatar_url else None
        name = author.display_name if author.display_name else message.sender

        message_id = self.discord.send_webhook(
            webhook, avatar, message.body, name
        ).id

        with Cache.lock:
            Cache.cache["m_messages"][message.id] = message_id

    def on_redaction(self, event: matrix.Event) -> None:
        """
        Delete the Discord counterpart of a redacted Matrix message.
        """
        with Cache.lock:
            message_id = Cache.cache["m_messages"].get(event.redacts)

        if not message_id:
            return

        channel_id = self.db.get_channel(event.room_id)
        webhook = self.discord.get_webhook(
            channel_id, self.discord.webhook_name
        )

        except_deleted(self.discord.delete_webhook)(message_id, webhook)

        with Cache.lock:
            del Cache.cache["m_messages"][event.redacts]

    def get_members(self, room_id: str) -> Dict[str, matrix.User]:
        """
        Get the joined members of a room keyed by mxid, using the member
        cache when it has an entry for the room.
        """
        with Cache.lock:
            cached = Cache.cache["m_members"].get(room_id)
            if cached:
                return cached

        resp = self.send("GET", f"/rooms/{room_id}/joined_members")

        joined = {
            mxid: dict_cls(profile, matrix.User)
            for mxid, profile in resp["joined"].items()
        }

        with Cache.lock:
            Cache.cache["m_members"][room_id] = joined

        return joined

    def create_room(self, channel: discord.Channel, sender: str) -> None:
        """
        Create a bridged room and invite the person who invoked the command.
        """
        join_rules = {
            "type": "m.room.join_rules",
            "content": {"join_rule": "public"},
        }
        history_visibility = {
            "type": "m.room.history_visibility",
            "content": {"history_visibility": "shared"},
        }

        content = {
            "room_alias_name": f"{self.format}{channel.id}",
            "name": channel.name,
            "topic": channel.topic if channel.topic else channel.name,
            "visibility": "private",
            "invite": [sender],
            "creation_content": {"m.federate": True},
            "initial_state": [join_rules, history_visibility],
            "power_level_content_override": {
                "users": {sender: 100, self.user_id: 100}
            },
        }

        resp = self.send("POST", "/createRoom", content)

        self.db.add_room(resp["room_id"], channel.id)

    def create_message_event(
        self,
        message: str,
        emotes: dict,
        edit: str = "",
        reference: discord.Message = None,
    ) -> dict:
        """
        Build the content of an `m.room.message` event for `message`.

        HTML formatting is added when `get_fmt` changes the text. When
        `reference` maps to a known Matrix event the content becomes a
        reply to it, and when `edit` is set it becomes an `m.replace` edit
        of that event ID.
        """
        content = {
            "body": message,
            "msgtype": "m.text",
        }

        fmt = self.get_fmt(message, emotes)

        if fmt != message:
            content = {
                **content,
                "format": "org.matrix.custom.html",
                "formatted_body": fmt,
            }

        ref_id = None

        if reference:
            # Reply to a Discord message.
            with Cache.lock:
                ref_id = Cache.cache["d_messages"].get(reference.id)

            # Reply to a Matrix message. (maybe)
            if not ref_id:
                with Cache.lock:
                    ref_id = next(
                        (
                            event_id
                            for event_id, discord_id in Cache.cache[
                                "m_messages"
                            ].items()
                            if discord_id == reference.id
                        ),
                        "",
                    )

        if ref_id:
            event = except_deleted(self.get_event)(
                ref_id,
                self.get_room_id(self.discord.matrixify(reference.channel_id)),
            )
            if event:
                # Content with the reply fallbacks stripped: drop the
                # leading lines that start with "> ". Lines starting with
                # "> " after a regular line are kept. The remaining lines
                # are concatenated without a separator.
                lines = event.body.split("\n")
                first_regular = next(
                    (
                        index
                        for index, line in enumerate(lines)
                        if not line.startswith("> ")
                    ),
                    len(lines),
                )
                event.body = "".join(lines[first_regular:])

                if event.formatted_body:
                    # re.DOTALL allows the match to span newlines.
                    event.formatted_body = re.sub(
                        "<mx-reply.+?</mx-reply>",
                        "",
                        event.formatted_body,
                        flags=re.DOTALL,
                    )
                else:
                    event.formatted_body = event.body

                quoted = (
                    event.formatted_body
                    if event.formatted_body
                    else event.body
                )
                reply_html = (
                    "<mx-reply><blockquote>"
                    f'<a href="https://matrix.to/#/{event.room_id}/{event.id}">'
                    "In reply to</a>"
                    f'<a href="https://matrix.to/#/{event.sender}">'
                    f"{event.sender}</a>"
                    f"<br />{quoted}"
                    "</blockquote></mx-reply>"
                    f"{content.get('formatted_body', content['body'])}"
                )

                content = {
                    **content,
                    "body": (
                        f"> <{event.sender}> {event.body}\n{content['body']}"
                    ),
                    "m.relates_to": {"m.in_reply_to": {"event_id": event.id}},
                    "format": "org.matrix.custom.html",
                    "formatted_body": reply_html,
                }

        if edit:
            # "m.new_content" is a copy of the content before the " * "
            # prefixes are applied.
            original = {**content}
            edited = {**content}
            edited["body"] = f" * {content['body']}"
            edited["formatted_body"] = (
                f" * {content.get('formatted_body', content['body'])}"
            )
            edited["m.relates_to"] = {
                "event_id": edit,
                "rel_type": "m.replace",
            }
            edited["m.new_content"] = original
            content = edited

        return content

    def get_fmt(self, message: str, emotes: dict) -> str:
        """
        Render `message` from markdown to HTML and replace `:name:` emotes
        with `<img>` tags for every emote that is in the emote cache after
        the uploads.
        """
        message = markdown.markdown(message)
        for old, new in (("<p>", ""), ("</p>", ""), ("\n", "<br />")):
            message = message.replace(old, new)

        # Upload emotes in multiple threads so that we don't
        # block the Discord bot for too long.
        upload_threads = [
            threading.Thread(target=self.upload_emote, args=(name, emote_id))
            for name, emote_id in emotes.items()
        ]

        # Acquire the lock before starting the threads to avoid resource
        # contention by tens of threads at once.
        with Cache.lock:
            for thread in upload_threads:
                thread.start()
            for thread in upload_threads:
                thread.join()

        with Cache.lock:
            for name in emotes:
                uploaded = Cache.cache["m_emotes"].get(name)
                if not uploaded:
                    continue

                shortcode = f":{name}:"
                message = message.replace(
                    shortcode,
                    f'<img alt="{shortcode}" title="{shortcode}" '
                    f'height="32" src="{uploaded}" data-mx-emoticon />',
                )

        return message

    def mention_regex(self, encode: bool, id_as_group: bool) -> str:
        """
        Build the regex matching the mxid of a puppeted Discord user.

        With `encode` the "@" and ":" are percent-encoded; with
        `id_as_group` the user ID and the optional hash become capture
        groups.
        """
        at_sign = urllib.parse.quote("@") if encode else "@"
        separator = urllib.parse.quote(":") if encode else ":"

        snowflake = f"({self.id_regex})" if id_as_group else self.id_regex
        hashed = f"(?:-{snowflake})?"

        return (
            f"{at_sign}{self.format}{snowflake}{hashed}{separator}"
            f"{re.escape(self.server_name)}"
        )

    def process_message(self, event: matrix.Event) -> str:
        """
        Convert the body of a Matrix message to Discord's syntax: known
        emotes and mentions of puppeted users are replaced, then the
        result is cut to `discord.MESSAGE_LIMIT` characters.
        """
        message = event.new_body or event.body

        emotes = re.findall(r":(\w*):", message)

        plain_regex = self.mention_regex(encode=False, id_as_group=True)
        # For clients that properly encode mentions.
        # 'https://matrix.to/#/%40_discord_...%3Adomain.tld'
        encoded_regex = self.mention_regex(encode=True, id_as_group=True)

        mentions = list(re.finditer(plain_regex, event.formatted_body))
        mentions.extend(re.finditer(encoded_regex, event.formatted_body))

        with Cache.lock:
            for emote in set(emotes):
                discord_emote = Cache.cache["d_emotes"].get(emote)
                if discord_emote:
                    message = message.replace(f":{emote}:", discord_emote)

        for mention in set(mentions):
            matched = mention.group(0)

            # Unquote just in-case we matched an encoded username.
            username = self.db.fetch_user(urllib.parse.unquote(matched)).get(
                "username"
            )
            if not username:
                continue

            if mention.group(2):
                # Replace mention with plain text for hashed users (webhooks)
                message = message.replace(matched, f"@{username}")
                continue

            # Replace the 'mention' so that the user is tagged
            # in the case of replies as well.
            # '> <@_discord_1234:localhost> Message'
            for needle in (matched, username):
                message = message.replace(needle, f"<@{mention.group(1)}>")

        # We trim the message later as emotes take up extra characters too.
        return message[: discord.MESSAGE_LIMIT]

    def upload_emote(self, emote_name: str, emote_id: str) -> None:
        """
        Upload a Discord emote to the homeserver and cache its MXC URI,
        unless the emote name is already cached.
        """
        # There won't be a race condition here, since only a unique
        # set of emotes are uploaded at a time.
        emote_cache = Cache.cache["m_emotes"]
        if emote_name in emote_cache:
            return

        emote_url = f"{discord.CDN_URL}/emojis/{emote_id}"

        # We don't want the message to be dropped entirely if an emote
        # fails to upload for some reason.
        try:
            # TODO This is not thread safe, but we're protected by the GIL.
            emote_cache[emote_name] = self.upload(emote_url)
        except RequestError as e:
            self.logger.warning(f"Failed to upload emote {emote_id}: {e}")

    def register(self, mxid: str) -> None:
        """
        Register a dummy user on the homeserver.
        """
        # "@test:localhost" -> "test" (Can't register with a full mxid.)
        localpart = mxid[1:].split(":")[0]

        resp = self.send(
            "POST",
            "/register",
            {"type": "m.login.application_service", "username": localpart},
        )

        self.db.add_user(resp["user_id"])

    def set_avatar(self, avatar_url: str, mxid: str) -> None:
        """
        Upload `avatar_url`, set it as the avatar of `mxid` and remember
        the source URL in the database.
        """
        avatar_uri = self.upload(avatar_url)

        self.send(
            "PUT",
            f"/profile/{mxid}/avatar_url",
            {"avatar_url": avatar_uri},
            params={"user_id": mxid},
        )

        self.db.add_avatar(avatar_url, mxid)

    def set_nick(self, username: str, mxid: str) -> None:
        """
        Set the display name of `mxid` and remember it in the database.
        """
        self.send(
            "PUT",
            f"/profile/{mxid}/displayname",
            {"displayname": username},
            params={"user_id": mxid},
        )

        self.db.add_username(username, mxid)
class DiscordClient(Gateway):
    """
    Discord side of the bridge: handles gateway events and relays them to
    Matrix through the `MatrixClient` stored in `self.app`.
    """

    def __init__(
        self, appservice: MatrixClient, config: dict, http: urllib3.PoolManager
    ) -> None:
        """
        Set up the gateway with `config["discord_token"]` and reset the
        Discord-side cache sections.
        """
        super().__init__(http, config["discord_token"])

        self.app = appservice
        self.webhook_name = "matrix_bridge"

        # TODO Find a cleaner way to use these keys.
        for k in ("d_emotes", "d_messages", "d_webhooks"):
            Cache.cache[k] = {}

    def to_return(self, message: discord.Message) -> bool:
        """
        Tell whether `message` must be ignored: it is in an unbridged
        channel, has no author, or was sent by one of the cached webhooks.
        """
        with Cache.lock:
            hook_ids = [hook.id for hook in Cache.cache["d_webhooks"].values()]

        return (
            message.channel_id not in self.app.db.list_channels()
            or not message.author  # Embeds can be weird sometimes.
            or message.webhook_id in hook_ids
        )

    def matrixify(self, id: str, user: bool = False, hashed: str = "") -> str:
        """
        Build the mxid (`user=True`) or the room alias for a Discord ID,
        with `-<hashed>` appended to the ID when `hashed` is given.
        """
        return (
            f"{'@' if user else '#'}{self.app.format}"
            f"{id}{'-' + hashed if hashed else ''}:"
            f"{self.app.server_name}"
        )

    def sync_profile(self, user: discord.User, hashed: str = "") -> None:
        """
        Sync the avatar and username for a puppeted user.
        """
        mxid = self.matrixify(user.id, user=True, hashed=hashed)

        profile = self.app.db.fetch_user(mxid)

        # User doesn't exist.
        if not profile:
            return

        username = f"{user.username}#{user.discriminator}"

        if user.avatar_url != profile["avatar_url"]:
            self.logger.info(f"Updating avatar for Discord user '{user.id}'")
            self.app.set_avatar(user.avatar_url, mxid)
        if username != profile["username"]:
            self.logger.info(f"Updating username for Discord user '{user.id}'")
            self.app.set_nick(username, mxid)

    def wrap(self, message: discord.Message) -> Tuple[str, str]:
        """
        Get the room ID and the puppet's mxid for a given channel ID and a
        Discord user.

        The puppet is registered and added to the room when needed.
        """
        hashed = ""
        if message.webhook_id and message.webhook_id != message.application_id:
            hashed = str(hash_str(message.author.username))

        mxid = self.matrixify(message.author.id, user=True, hashed=hashed)
        room_id = self.app.get_room_id(self.matrixify(message.channel_id))

        if not self.app.db.fetch_user(mxid):
            self.logger.info(
                f"Creating dummy user for Discord user {message.author.id}."
            )
            self.app.register(mxid)

            self.app.set_nick(
                f"{message.author.username}#"
                f"{message.author.discriminator}",
                mxid,
            )

            if message.author.avatar_url:
                self.app.set_avatar(message.author.avatar_url, mxid)

        if mxid not in self.app.get_members(room_id):
            self.logger.info(f"Inviting user '{mxid}' to room '{room_id}'.")

            self.app.send_invite(room_id, mxid)
            self.app.join_room(room_id, mxid)

        if message.webhook_id:
            # Sync webhooks here as they can't be accessed like guild members.
            self.sync_profile(message.author, hashed=hashed)

        return mxid, room_id

    def cache_emotes(self, emotes: List[discord.Emote]):
        """
        Store the Discord markup of every emote in the cache, keyed by the
        emote's name.
        """
        # TODO maybe "namespace" emotes by guild in the cache ?
        with Cache.lock:
            for emote in emotes:
                Cache.cache["d_emotes"][emote.name] = (
                    f"<{'a' if emote.animated else ''}:"
                    f"{emote.name}:{emote.id}>"
                )

    def on_guild_create(self, guild: discord.Guild) -> None:
        """
        Sync the profile of every guild member and cache the guild emotes.
        """
        for member in guild.members:
            self.sync_profile(member)

        self.cache_emotes(guild.emojis)

    def on_guild_emojis_update(
        self, update: discord.GuildEmojisUpdate
    ) -> None:
        """
        Cache the emotes carried by the update.
        """
        self.cache_emotes(update.emojis)

    def on_guild_member_update(
        self, update: discord.GuildMemberUpdate
    ) -> None:
        """
        Sync the profile of the updated member.
        """
        self.sync_profile(update.user)

    def on_message_create(self, message: discord.Message) -> None:
        """
        Relay a new Discord message to the bridged Matrix room and remember
        which Matrix event it became.
        """
        if self.to_return(message):
            return

        mxid, room_id = self.wrap(message)

        content_, emotes = self.process_message(message)

        content = self.app.create_message_event(
            content_, emotes, reference=message.referenced_message
        )

        # Send before taking the lock, so that the cache is not blocked
        # for every other thread during the HTTP request.
        event_id = self.app.send_message(room_id, content, mxid)

        with Cache.lock:
            Cache.cache["d_messages"][message.id] = event_id

    def on_message_delete(self, message: discord.Message) -> None:
        """
        Redact the Matrix counterpart of a deleted Discord message.
        """
        with Cache.lock:
            event_id = Cache.cache["d_messages"].get(message.id)

        if not event_id:
            return

        room_id = self.app.get_room_id(self.matrixify(message.channel_id))
        event = except_deleted(self.app.get_event)(event_id, room_id)

        if event:
            self.app.redact(event.id, event.room_id, event.sender)

        with Cache.lock:
            del Cache.cache["d_messages"][message.id]

    def on_message_update(self, message: discord.Message) -> None:
        """
        Relay an edit of a previously relayed Discord message as an
        `m.replace` event.
        """
        if self.to_return(message):
            return

        with Cache.lock:
            event_id = Cache.cache["d_messages"].get(message.id)

        if not event_id:
            return

        room_id = self.app.get_room_id(self.matrixify(message.channel_id))
        mxid = self.matrixify(message.author.id, user=True)

        # It is possible that a webhook edits its own old message
        # after changing its name, hence we generate a new mxid from
        # the hashed username, but that mxid hasn't been registered before,
        # so the request fails with:
        # M_FORBIDDEN: Application service has not registered this user
        if not self.app.db.fetch_user(mxid):
            return

        content_, emotes = self.process_message(message)

        content = self.app.create_message_event(
            content_, emotes, edit=event_id
        )

        self.app.send_message(room_id, content, mxid)

    def on_typing_start(self, typing: discord.Typing) -> None:
        """
        Forward a typing notification when the typing user's puppet is a
        member of the bridged room.
        """
        if typing.channel_id not in self.app.db.list_channels():
            return

        mxid = self.matrixify(typing.user_id, user=True)
        room_id = self.app.get_room_id(self.matrixify(typing.channel_id))

        if mxid not in self.app.get_members(room_id):
            return

        self.app.send_typing(room_id, mxid)

    def get_webhook(self, channel_id: str, name: str) -> discord.Webhook:
        """
        Get the webhook object for the first webhook that matches the specified
        name in a given channel, create the webhook if it doesn't exist.
        """
        # Check the cache first.
        with Cache.lock:
            webhook = Cache.cache["d_webhooks"].get(channel_id)

        if webhook:
            return webhook

        webhooks = self.send("GET", f"/channels/{channel_id}/webhooks")
        webhook = next(
            (
                dict_cls(webhook, discord.Webhook)
                for webhook in webhooks
                if webhook["name"] == name
            ),
            None,
        )

        if not webhook:
            webhook = self.create_webhook(channel_id, name)

        with Cache.lock:
            Cache.cache["d_webhooks"][channel_id] = webhook

        return webhook

    def process_message(self, message: discord.Message) -> Tuple[str, Dict]:
        """
        Convert the content of a Discord message to plain text for Matrix.

        Returns the converted content and a `{emote name: emote id}` dict
        of the custom emotes found in it.
        """
        content = message.content
        regex = r"<a?:(\w+):(\d+)>"

        # Mentions can either be in the form of `<@1234>` or `<@!1234>`.
        for member in message.mentions:
            for char in ("", "!"):
                content = content.replace(
                    f"<@{char}{member.id}>", f"@{member.username}"
                )

        # Replace channel IDs with names.
        channels = re.findall(r"<#([0-9]+)>", content)
        if channels:
            if not message.guild_id:
                self.logger.warning(
                    f"Message '{message.id}' in channel '{message.channel_id}' does not have a guild_id!"
                )
            else:
                discord_channels = self.get_channels(message.guild_id)
                for channel in channels:
                    discord_channel = discord_channels.get(channel)
                    name = (
                        discord_channel.name
                        if discord_channel
                        else "deleted-channel"
                    )
                    content = content.replace(f"<#{channel}>", f"#{name}")

        # { "emote_name": "emote_id" }
        emotes = dict(re.findall(regex, content))

        # Replace emote IDs with names.
        content = re.sub(regex, r":\g<1>:", content)

        # Append attachments to message.
        for attachment in message.attachments:
            content += f"\n{attachment['url']}"

        # Append stickers to message.
        for sticker in message.stickers:
            if sticker.format_type != 3:  # 3 == Lottie format.
                content += f"\n{discord.CDN_URL}/stickers/{sticker.id}.png"

        return content, emotes
def config_gen(basedir: str, config_file: str) -> dict:
    """
    Load the configuration stored at `basedir/config_file`.

    When the file does not exist, a template configuration is written to
    that path and the process exits.
    """
    config_file = f"{basedir}/{config_file}"

    if os.path.exists(config_file):
        with open(config_file, "r") as handle:
            return json.load(handle)

    template = {
        "as_token": "my-secret-as-token",
        "hs_token": "my-secret-hs-token",
        "user_id": "appservice-discord",
        "homeserver": "http://127.0.0.1:8008",
        "server_name": "localhost",
        "discord_token": "my-secret-discord-token",
        "port": 5000,
        "database": f"{basedir}/bridge.db",
    }

    with open(config_file, "w") as handle:
        json.dump(template, handle, indent=4)

    print(f"Configuration dumped to '{config_file}'")
    sys.exit()
def excepthook(exc_type, exc_value, exc_traceback):
    """
    Replacement for `sys.excepthook`: log uncaught exceptions as critical
    and pass `KeyboardInterrupt` on to the default hook.
    """
    if not issubclass(exc_type, KeyboardInterrupt):
        details = (exc_type, exc_value, exc_traceback)
        logging.critical("Unknown exception:", exc_info=details)
        return

    sys.__excepthook__(exc_type, exc_value, exc_traceback)
def main() -> None:
    """
    Entry point: load the configuration from the base directory (first
    command line argument, or the current directory), set up logging to a
    file, serve the appservice in a daemon thread and run the Discord
    client until interrupted.
    """
    if len(sys.argv) > 1:
        basedir = sys.argv[1]
        if not os.path.exists(basedir):
            print(f"Path '{basedir}' does not exist!")
            sys.exit(1)
        basedir = os.path.abspath(basedir)
    else:
        basedir = os.getcwd()

    config = config_gen(basedir, "appservice.json")

    log_file = logging.FileHandler(f"{basedir}/appservice.log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[log_file],
    )

    sys.excepthook = excepthook

    app = MatrixClient(config, urllib3.PoolManager(maxsize=10))

    # Start the bottle app in a separate thread.
    server_options = {"port": int(config["port"])}
    app_thread = threading.Thread(
        target=app.run, kwargs=server_options, daemon=True
    )
    app_thread.start()

    try:
        asyncio.run(app.discord.run())
    except KeyboardInterrupt:
        sys.exit()


if __name__ == "__main__":
    main()

View file

@ -1,84 +0,0 @@
import json
from dataclasses import fields
from typing import Any
import urllib3
from errors import RequestError
def dict_cls(d: dict, cls: Any) -> Any:
    """
    Create a dataclass from a dictionary.

    Keys that are not fields of `cls` are dropped.
    """
    known = {field.name for field in fields(cls)}

    kwargs = {}
    for key, value in d.items():
        if key in known:
            kwargs[key] = value

    return cls(**kwargs)
def log_except(fn):
    """
    Log unhandled exceptions to a logger instead of `stderr`.

    Decorator for methods of objects that have a `logger` attribute. The
    exception is logged with its traceback and then re-raised. Coroutine
    functions get an `async` wrapper, so that exceptions raised while the
    coroutine runs are logged too; a plain wrapper would only see the
    creation of the coroutine object.
    """
    import functools
    import inspect

    if inspect.iscoroutinefunction(fn):

        @functools.wraps(fn)
        async def async_wrapper(self, *args, **kwargs):
            try:
                return await fn(self, *args, **kwargs)
            except Exception:
                self.logger.exception(f"Exception in '{fn.__name__}':")
                raise

        return async_wrapper

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        try:
            return fn(self, *args, **kwargs)
        except Exception:
            self.logger.exception(f"Exception in '{fn.__name__}':")
            raise

    return wrapper
def request(fn):
    """
    Either return json data or raise a `RequestError` if the request was
    unsuccessful.

    The decorated function must return a response object with `status`,
    `data` and `geturl()`. A status outside of 2xx raises `RequestError`
    with that status; a connection failure raises it with status `None`.
    Responses without a body (204, or any empty 2xx body) yield `{}`.
    """

    def wrapper(*args, **kwargs):
        try:
            resp = fn(*args, **kwargs)
        except urllib3.exceptions.HTTPError as e:
            raise RequestError(None, f"Failed to connect: {e}") from None

        if resp.status < 200 or resp.status >= 300:
            raise RequestError(
                resp.status,
                f"Failed to get response from '{resp.geturl()}':\n{resp.data}",
            )

        # An empty body is not valid JSON; treat it like a 204.
        if resp.status == 204 or not resp.data:
            return {}

        return json.loads(resp.data)

    return wrapper
def except_deleted(fn):
    """
    Ignore the `RequestError` on 404s, the content might have been removed.

    The wrapped call returns `None` in that case; every other
    `RequestError` is re-raised.
    """

    def wrapper(*args, **kwargs):
        try:
            result = fn(*args, **kwargs)
        except RequestError as e:
            if e.status == 404:
                return None
            raise
        return result

    return wrapper
def hash_str(string: str) -> int:
    """
    Create the hash for a string

    Starts at 5381, multiplies the running value by 33 and adds the code
    point of each character; the result is reduced to 32 bits.
    """
    value = 5381

    for char in string:
        value = value * 33 + ord(char)

    return value & 0xFFFFFFFF