Appservice (#4)

This commit is contained in:
git-bruh 2021-04-17 10:15:51 +05:30 committed by GitHub
parent 4b876734bb
commit d6740b4bd3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 1663 additions and 0 deletions

66
appservice/README.md Normal file
View file

@ -0,0 +1,66 @@
## Installation
`pip install -r requirements.txt`
## Usage
* Run `main.py` to generate `appservice.json`
* Edit `appservice.json`:
```
{
"as_token": "my-secret-as-token",
"hs_token": "my-secret-hs-token",
"user_id": "appservice-discord",
# Homeserver running on the same machine, listening on port 8008.
"homeserver": "http://127.0.0.1:8008",
# Change "localhost" to your server_name.
# Eg. "kde.org" is the server_name in "@testuser:kde.org".
"server_name": "localhost",
"discord_token": "my-secret-discord-token",
"port": 5000, # Port to run the bottle app on.
"database": "/path/to/bridge.db"
}
```
* Create `appservice.yaml` and add it to your homeserver configuration:
```
id: "discord"
url: "http://127.0.0.1:5000"
as_token: "my-secret-as-token"
hs_token: "my-secret-hs-token"
sender_localpart: "appservice-discord"
namespaces:
users:
- exclusive: true
regex: "@_discord.*"
# Work around for temporary bug in dendrite.
- regex: "@appservice-discord"
aliases:
- exclusive: false
regex: "#_discord.*"
rooms: []
```
A path can optionally be passed as the first argument to `main.py`. This path will be used as the base directory for the database and log file.
Eg. Running `python3 main.py /path/to/my/dir` will store the database and logs in `/path/to/my/dir`.
`$PWD` is used by default if no path is specified.
This bridge is written with:
* `bottle`: Receiving events from the homeserver.
* `urllib3`: Sending requests, thread safety.
* `websockets`: Connecting to Discord. (Big thanks to an anonymous person "nesslersreagent" for figuring out the initial connection mess.)
## NOTES
* A basic sqlite database is used for keeping track of bridged rooms.
* Logs are saved to the `appservice.log` file in `$PWD` or the specified directory.
* For avatars to show up on Discord, you must have a [reverse proxy](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf) set up on your homeserver as the bridge does not specify the homeserver port when passing the avatar url.
* It is not possible to add normal Discord bot functionality like commands as this bridge does not use `discord.py`.
NOTE: [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) must be enabled for your Discord bot.

212
appservice/appservice.py Normal file
View file

@ -0,0 +1,212 @@
import json
import logging
import urllib.parse
import uuid
from typing import List, Union
import bottle
import urllib3
import matrix
from misc import dict_cls, log_except, request
class AppService(bottle.Bottle):
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
super(AppService, self).__init__()
self.as_token = config["as_token"]
self.hs_token = config["hs_token"]
self.base_url = config["homeserver"]
self.server_name = config["server_name"]
self.user_id = f"@{config['user_id']}:{self.server_name}"
self.http = http
self.logger = logging.getLogger("appservice")
# TODO better method.
# Map events to functions.
self.mapping = {
"m.room.member": "on_member",
"m.room.message": "on_message",
"m.room.redaction": "on_redaction",
}
# Add route for bottle.
self.route(
"/transactions/<transaction>",
callback=self.receive_event,
method="PUT",
)
def handle_event(self, event: dict) -> None:
event_type = event.get("type")
if event_type == "m.room.member" or event_type == "m.room.message":
obj = self.get_event_object(event)
elif event_type == "m.room.redaction":
obj = event
else:
self.logger.info(f"Unknown event type: {event_type}")
return
func = getattr(self, self.mapping[event_type], None)
if not func:
self.logger.warning(
f"Function '{func}' not defined, ignoring event."
)
return
# We don't catch exceptions here as the homeserver will re-send us
# the event in case of a failure.
func(obj)
@log_except
def receive_event(self, transaction: str) -> dict:
"""
Verify the homeserver's token and handle events.
"""
hs_token = bottle.request.query.getone("access_token")
if not hs_token:
bottle.response.status = 401
return {"errcode": "APPSERVICE_UNAUTHORIZED"}
if hs_token != self.hs_token:
bottle.response.status = 403
return {"errcode": "APPSERVICE_FORBIDDEN"}
events = bottle.request.json.get("events")
for event in events:
self.handle_event(event)
return {}
def get_event_object(self, event: dict) -> matrix.Event:
event["author"] = dict_cls(
self.get_profile(event["sender"]), matrix.User
)
return matrix.Event(event)
def join_room(self, room_id: str, mxid: str = "") -> None:
self.send(
"POST",
f"/join/{room_id}",
params={"user_id": mxid} if mxid else {},
)
def redact(self, event_id: str, room_id: str, mxid: str = "") -> None:
self.send(
"PUT",
f"/rooms/{room_id}/redact/{event_id}/{uuid.uuid4()}",
params={"user_id": mxid} if mxid else {},
)
def get_profile(self, mxid: str) -> dict:
# TODO handle failure, avoid querying this endpoint repeatedly.
resp = self.send("GET", f"/profile/{mxid}")
avatar_url = resp.get("avatar_url", "")[6:].split("/")
avatar_url = (
(
f"https://{self.server_name}/_matrix/media/r0/download/"
f"{avatar_url[0]}/{avatar_url[1]}"
)
if len(avatar_url) > 1
else None
)
return {
"avatar_url": avatar_url,
"displayname": resp.get("displayname"),
}
def get_members(self, room_id: str) -> List[str]:
resp = self.send(
"GET",
f"/rooms/{room_id}/members",
params={"membership": "join", "not_membership": "leave"},
)
return [
content["sender"]
for content in resp["chunk"]
if content["content"]["membership"] == "join"
]
def get_room_id(self, alias: str) -> str:
resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}")
# TODO cache ?
return resp["room_id"]
def upload(self, url: str) -> str:
"""
Upload a file to the homeserver and get the MXC url.
"""
resp = self.http.request("GET", url)
resp = self.send(
"POST",
content=resp.data,
content_type=resp.headers.get("Content-Type"),
params={"filename": f"{uuid.uuid4()}"},
endpoint="/_matrix/media/r0/upload",
)
return resp["content_uri"]
def send_message(
self,
room_id: str,
content: dict,
mxid: str = "",
) -> str:
resp = self.send(
"PUT",
f"/rooms/{room_id}/send/m.room.message/{uuid.uuid4()}",
content,
{"user_id": mxid} if mxid else {},
)
return resp["event_id"]
def send_typing(
self, room_id: str, mxid: str = "", timeout: int = 8000
) -> None:
self.send(
"PUT",
f"/rooms/{room_id}/typing/{mxid}",
{"typing": True, "timeout": timeout},
{"user_id": mxid} if mxid else {},
)
def send_invite(self, room_id: str, mxid: str) -> None:
self.send("POST", f"/rooms/{room_id}/invite", {"user_id": mxid})
@request
def send(
self,
method: str,
path: str = "",
content: Union[bytes, dict] = {},
params: dict = {},
content_type: str = "application/json",
endpoint: str = "/_matrix/client/r0",
) -> dict:
params["access_token"] = self.as_token
headers = {"Content-Type": content_type}
content = json.dumps(content) if isinstance(content, dict) else content
endpoint = (
f"{self.base_url}{endpoint}{path}?"
f"{urllib.parse.urlencode(params)}"
)
return self.http.request(
method, endpoint, body=content, headers=headers
)

132
appservice/db.py Normal file
View file

@ -0,0 +1,132 @@
import os
import sqlite3
import threading
from typing import List
class DataBase(object):
def __init__(self, db_file) -> None:
self.create(db_file)
# The database is accessed via both the threads.
self.lock = threading.Lock()
def create(self, db_file) -> None:
"""
Create a database with the relevant tables if it doesn't already exist.
"""
exists = os.path.exists(db_file)
self.conn = sqlite3.connect(db_file, check_same_thread=False)
self.conn.row_factory = self.dict_factory
self.cur = self.conn.cursor()
if exists:
return
self.cur.execute(
"CREATE TABLE bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);"
)
self.cur.execute(
"CREATE TABLE users(mxid TEXT PRIMARY KEY, "
"avatar_url TEXT, username TEXT);"
)
self.conn.commit()
def dict_factory(self, cursor, row):
"""
https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory
"""
d = {}
for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx]
return d
def add_room(self, room_id: str, channel_id: str) -> None:
"""
Add a bridged room to the database.
"""
with self.lock:
self.cur.execute(
"INSERT INTO bridge (room_id, channel_id) "
f"VALUES ('{room_id}', '{channel_id}')"
)
self.conn.commit()
def add_user(self, mxid: str) -> None:
with self.lock:
self.cur.execute(f"INSERT INTO users (mxid) VALUES ('{mxid}')")
self.conn.commit()
def add_avatar(self, avatar_url: str, mxid: str) -> None:
with self.lock:
self.cur.execute(
f"UPDATE users SET avatar_url = '{avatar_url}'"
f"WHERE mxid = '{mxid}'"
)
self.conn.commit()
def add_username(self, username: str, mxid: str) -> None:
with self.lock:
self.cur.execute(
f"UPDATE users SET username = '{username}'"
f"WHERE mxid = '{mxid}'"
)
self.conn.commit()
def get_channel(self, room_id: str) -> str:
"""
Get the corresponding channel ID for a given room ID.
"""
with self.lock:
self.cur.execute(
"SELECT channel_id FROM bridge WHERE room_id = ?", [room_id]
)
room = self.cur.fetchone()
# Return an empty string if nothing is bridged.
return "" if not room else room["channel_id"]
def list_channels(self) -> List[str]:
"""
Get a list of all the bridged channels.
"""
with self.lock:
self.cur.execute("SELECT channel_id FROM bridge")
channels = self.cur.fetchall()
return [channel["channel_id"] for channel in channels]
def list_users(self) -> List[dict]:
"""
Get a dictionary of all the puppeted users.
"""
with self.lock:
self.cur.execute("SELECT * FROM users")
users = self.cur.fetchall()
return users
def query_user(self, mxid: str) -> bool:
"""
Check whether a puppet user has already been created for a given mxid.
"""
with self.lock:
self.cur.execute("SELECT mxid FROM users")
users = self.cur.fetchall()
return next((True for user in users if user["mxid"] == mxid), False)

175
appservice/discord.py Normal file
View file

@ -0,0 +1,175 @@
from dataclasses import dataclass
CDN_URL = "https://cdn.discordapp.com"
@dataclass
class Channel(object):
id: str
type: str
guild_id: str = ""
name: str = ""
topic: str = ""
@dataclass
class Emote(object):
animated: bool
id: str
name: str
class User(object):
def __init__(self, user: dict) -> None:
self.discriminator = user["discriminator"]
self.id = user["id"]
self.mention = f"<@{self.id}>"
self.username = user["username"]
avatar = user["avatar"]
if not avatar:
# https://discord.com/developers/docs/reference#image-formatting
self.avatar_url = (
f"{CDN_URL}/embed/avatars/{int(self.discriminator) % 5}.png"
)
else:
ext = "gif" if avatar.startswith("a_") else "png"
self.avatar_url = f"{CDN_URL}/avatars/{self.id}/{avatar}.{ext}"
class Message(object):
def __init__(self, message: dict) -> None:
self.attachments = message.get("attachments", [])
self.channel_id = message["channel_id"]
self.content = message.get("content", "")
self.id = message["id"]
self.reference = message.get("message_reference", {}).get(
"message_id", ""
)
self.webhook_id = message.get("webhook_id", "")
self.mentions = [
User(mention) for mention in message.get("mentions", [])
]
author = message.get("author")
self.author = User(author) if author else None
@dataclass
class DeletedMessage(object):
channel_id: str
id: str
@dataclass
class Typing(object):
user_id: str
channel_id: str
@dataclass
class Webhook(object):
id: str
token: str
class ChannelType(object):
GUILD_TEXT = 0
DM = 1
GUILD_VOICE = 2
GROUP_DM = 3
GUILD_CATEGORY = 4
GUILD_NEWS = 5
GUILD_STORE = 6
class InteractionResponseType(object):
PONG = 0
ACKNOWLEDGE = 1
CHANNEL_MESSAGE = 2
CHANNEL_MESSAGE_WITH_SOURCE = 4
ACKNOWLEDGE_WITH_SOURCE = 5
class GatewayIntents(object):
def bitmask(bit: int) -> int:
return 1 << bit
GUILDS = bitmask(0)
GUILD_MEMBERS = bitmask(1)
GUILD_BANS = bitmask(2)
GUILD_EMOJIS = bitmask(3)
GUILD_INTEGRATIONS = bitmask(4)
GUILD_WEBHOOKS = bitmask(5)
GUILD_INVITES = bitmask(6)
GUILD_VOICE_STATES = bitmask(7)
GUILD_PRESENCES = bitmask(8)
GUILD_MESSAGES = bitmask(9)
GUILD_MESSAGE_REACTIONS = bitmask(10)
GUILD_MESSAGE_TYPING = bitmask(11)
DIRECT_MESSAGES = bitmask(12)
DIRECT_MESSAGE_REACTIONS = bitmask(13)
DIRECT_MESSAGE_TYPING = bitmask(14)
class GatewayOpCodes(object):
DISPATCH = 0
HEARTBEAT = 1
IDENTIFY = 2
PRESENCE_UPDATE = 3
VOICE_STATE_UPDATE = 4
RESUME = 6
RECONNECT = 7
REQUEST_GUILD_MEMBERS = 8
INVALID_SESSION = 9
HELLO = 10
HEARTBEAT_ACK = 11
class Payloads(object):
def __init__(self, token: str) -> None:
self.seq = self.session = None
self.token = token
def HEARTBEAT(self) -> dict:
return {"op": GatewayOpCodes.HEARTBEAT, "d": self.seq}
def IDENTIFY(self) -> dict:
return {
"op": GatewayOpCodes.IDENTIFY,
"d": {
"token": self.token,
"intents": GatewayIntents.GUILDS
| GatewayIntents.GUILD_MESSAGES
| GatewayIntents.GUILD_MESSAGE_TYPING,
"properties": {
"$os": "discord",
"$browser": "discord",
"$device": "discord",
},
},
}
def QUERY(self, guild_id: str, query: str, limit: int = 1) -> dict:
"""
Return the Payload to query a member from a guild ID.
Return only a single match if `limit` isn't specified.
"""
return {
"op": GatewayOpCodes.REQUEST_GUILD_MEMBERS,
"d": {"guild_id": guild_id, "query": query, "limit": limit},
}
def RESUME(self) -> dict:
return {
"op": GatewayOpCodes.RESUME,
"d": {
"token": self.token,
"session_id": self.session,
"seq": self.seq,
},
}

5
appservice/errors.py Normal file
View file

@ -0,0 +1,5 @@
class RequestError(Exception):
def __init__(self, status: int, *args):
super().__init__(*args)
self.status = status

299
appservice/gateway.py Normal file
View file

@ -0,0 +1,299 @@
import asyncio
import json
import logging
import urllib.parse
from typing import List
import urllib3
import websockets
import discord
from misc import dict_cls, log_except, request, wrap_async
class Gateway(object):
def __init__(self, http: urllib3.PoolManager, token: str):
self.http = http
self.token = token
self.logger = logging.getLogger("discord")
self.cdn_url = "https://cdn.discordapp.com"
self.Payloads = discord.Payloads(self.token)
self.loop = self.websocket = None
self.query_cache = {}
@log_except
async def run(self) -> None:
self.loop = asyncio.get_running_loop()
self.query_ev = asyncio.Event()
self.heartbeat_task = None
self.resume = False
while True:
try:
await self.gateway_handler(self.get_gateway_url())
except websockets.ConnectionClosedError:
# TODO reconnect ?
self.logger.exception("Quitting, connection lost.")
break
# Stop sending heartbeats until we reconnect.
if self.heartbeat_task and not self.heartbeat_task.cancelled():
self.heartbeat_task.cancel()
def get_gateway_url(self) -> str:
resp = self.send("GET", "/gateway")
return resp["url"]
async def heartbeat_handler(self, interval_ms: int) -> None:
while True:
await asyncio.sleep(interval_ms / 1000)
await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT()))
def query_handler(self, data: dict) -> None:
members = data["members"]
guild_id = data["guild_id"]
for member in members:
user = member["user"]
self.query_cache[guild_id].append(user)
self.query_ev.set()
def handle_otype(self, data: dict, otype: str) -> None:
if data.get("embeds"):
return # TODO embeds
if otype == "MESSAGE_CREATE" or otype == "MESSAGE_UPDATE":
obj = discord.Message(data)
elif otype == "MESSAGE_DELETE":
obj = dict_cls(data, discord.DeletedMessage)
elif otype == "TYPING_START":
obj = dict_cls(data, discord.Typing)
elif otype == "GUILD_MEMBERS_CHUNK":
self.query_handler(data)
return
else:
self.logger.info(f"Unknown OTYPE: {otype}")
return
func = getattr(self, f"on_{otype.lower()}", None)
if not func:
self.logger.warning(
f"Function '{func}' not defined, ignoring message."
)
return
try:
func(obj)
except Exception:
self.logger.exception(f"Ignoring exception in {func}:")
async def gateway_handler(self, gateway_url: str) -> None:
async with websockets.connect(
f"{gateway_url}/?v=8&encoding=json"
) as websocket:
self.websocket = websocket
async for message in websocket:
data = json.loads(message)
data_dict = data.get("d")
opcode = data.get("op")
seq = data.get("s")
if seq:
self.Payloads.seq = seq
if opcode == discord.GatewayOpCodes.DISPATCH:
otype = data.get("t")
if otype == "READY":
self.Payloads.session = data_dict["session_id"]
self.logger.info("READY")
else:
self.handle_otype(data_dict, otype)
elif opcode == discord.GatewayOpCodes.HELLO:
heartbeat_interval = data_dict.get("heartbeat_interval")
self.logger.info(
f"Heartbeat Interval: {heartbeat_interval}"
)
# Send periodic hearbeats to gateway.
self.heartbeat_task = asyncio.ensure_future(
self.heartbeat_handler(heartbeat_interval)
)
await websocket.send(
json.dumps(
self.Payloads.RESUME()
if self.resume
else self.Payloads.IDENTIFY()
)
)
elif opcode == discord.GatewayOpCodes.RECONNECT:
self.logger.info("Received RECONNECT.")
self.resume = True
await websocket.close()
elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
self.logger.info("Received INVALID_SESSION.")
self.resume = False
await websocket.close()
elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
# NOP
pass
else:
self.logger.info(
f"Unknown OP code {opcode}:\n"
f"{json.dumps(data, indent=4)}"
)
@wrap_async
async def query_member(self, guild_id: str, name: str) -> discord.User:
"""
Query the members for a given guild and return the first match.
"""
self.query_ev.clear()
def query():
if not self.query_cache.get(guild_id):
self.query_cache[guild_id] = []
user = [
user
for user in self.query_cache[guild_id]
if name.lower() in user["username"].lower()
]
return None if not user else discord.User(user[0])
user = query()
if user:
return user
if not self.websocket or self.websocket.closed:
self.logger.warning("Not fetching members, websocket closed.")
return
await self.websocket.send(
json.dumps(self.Payloads.QUERY(guild_id, name))
)
# Wait for our websocket to receive the chunk.
await asyncio.wait_for(self.query_ev.wait(), timeout=5)
return query()
def get_channel(self, channel_id: str) -> discord.Channel:
"""
Get the channel object for a given channel ID.
"""
resp = self.send("GET", f"/channels/{channel_id}")
return dict_cls(resp, discord.Channel)
def get_emotes(self, guild_id: str) -> List[discord.Emote]:
"""
Get all the emotes for a given guild.
"""
resp = self.send("GET", f"/guilds/{guild_id}/emojis")
return [dict_cls(emote, discord.Emote) for emote in resp]
def get_members(self, guild_id: str) -> List[discord.User]:
"""
Get all the members for a given guild.
"""
resp = self.send(
"GET", f"/guilds/{guild_id}/members", params={"limit": 1000}
)
return [discord.User(member["user"]) for member in resp]
def create_webhook(self, channel_id: str, name: str) -> discord.Webhook:
"""
Create a webhook with the specified name in a given channel.
"""
resp = self.send(
"POST", f"/channels/{channel_id}/webhooks", {"name": name}
)
return dict_cls(resp, discord.Webhook)
def edit_webhook(
self, content: str, message_id: str, webhook: discord.Webhook
) -> None:
self.send(
"PATCH",
f"/webhooks/{webhook.id}/{webhook.token}/messages/"
f"{message_id}",
{"content": content},
)
def delete_webhook(
self, message_id: str, webhook: discord.Webhook
) -> None:
self.send(
"DELETE",
f"/webhooks/{webhook.id}/{webhook.token}/messages/"
f"{message_id}",
)
def send_webhook(self, webhook: discord.Webhook, **kwargs) -> str:
content = {
**kwargs,
# Disable 'everyone' and 'role' mentions.
"allowed_mentions": {"parse": ["users"]},
}
resp = self.send(
"POST",
f"/webhooks/{webhook.id}/{webhook.token}",
content,
{"wait": True},
)
return resp["id"]
def send_message(self, message: str, channel_id: str) -> None:
self.send(
"POST", f"/channels/{channel_id}/messages", {"content": message}
)
@request
def send(
self, method: str, path: str, content: dict = {}, params: dict = {}
) -> dict:
endpoint = (
f"https://discord.com/api/v8{path}?"
f"{urllib.parse.urlencode(params)}"
)
headers = {
"Authorization": f"Bot {self.token}",
"Content-Type": "application/json",
}
# 'body' being an empty dict breaks "GET" requests.
content = json.dumps(content) if content else None
return self.http.request(
method, endpoint, body=content, headers=headers
)

656
appservice/main.py Normal file
View file

@ -0,0 +1,656 @@
import asyncio
import json
import logging
import os
import re
import sys
import threading
from typing import Dict, Tuple, Union
import urllib3
import discord
import matrix
from appservice import AppService
from db import DataBase
from errors import RequestError
from gateway import Gateway
from misc import dict_cls, except_deleted
# TODO should this be cleared periodically ?
message_cache: Dict[str, Union[discord.Webhook, str]] = {}
class MatrixClient(AppService):
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
super().__init__(config, http)
self.db = DataBase(config["database"])
self.discord = DiscordClient(self, config, http)
self.emote_cache: Dict[str, str] = {}
self.format = "_discord_" # "{@,#}_discord_1234:localhost"
def to_return(self, event: matrix.Event) -> bool:
return event.sender.startswith(("@_discord", self.user_id))
def handle_bridge(self, message: matrix.Event) -> None:
# Ignore events that aren't for us.
if message.sender.split(":")[
-1
] != self.server_name or not message.body.startswith("!bridge"):
return
try:
channel = message.body.split()[1]
except IndexError:
return
# Check if the given channel is valid.
try:
channel = self.discord.get_channel(channel)
except RequestError as e:
# The channel can be invalid or we may not have permission.
self.logger.warning(f"Failed to fetch channel {channel}: {e}")
return
if (
channel.type != discord.ChannelType.GUILD_TEXT
or channel.id in self.db.list_channels()
):
return
self.logger.info(f"Creating bridged room for channel {channel.id}.")
self.create_room(channel, message.sender)
def on_member(self, event: matrix.Event) -> None:
# Ignore events that aren't for us.
if (
event.sender.split(":")[-1] != self.server_name
or event.state_key != self.user_id
or not event.is_direct
):
return
# Join the direct message room.
self.logger.info(f"Joining direct message room {event.room_id}.")
self.join_room(event.room_id)
def on_message(self, message: matrix.Event) -> None:
if self.to_return(message):
return
# Handle bridging commands.
self.handle_bridge(message)
channel_id = self.db.get_channel(message.room_id)
if not channel_id:
return
webhook = self.discord.get_webhook(channel_id, "matrix_bridge")
if message.relates_to and message.reltype == "m.replace":
relation = message_cache.get(message.relates_to)
if not message.new_body or not relation:
return
message.new_body = self.process_message(
channel_id, message.new_body
)
except_deleted(self.discord.edit_webhook)(
message.new_body, relation["message_id"], webhook
)
else:
if not message.body:
return
message.body = self.process_message(channel_id, message.body)
message_cache[message.event_id] = {
"message_id": self.discord.send_webhook(
webhook,
avatar_url=message.author.avatar_url,
content=message.body,
username=message.author.displayname,
),
"webhook": webhook,
}
@except_deleted
def on_redaction(self, event: dict) -> None:
redacts = event["redacts"]
event = message_cache.get(redacts)
if not event:
return
self.discord.delete_webhook(event["message_id"], event["webhook"])
message_cache.pop(redacts)
def create_room(self, channel: discord.Channel, sender: str) -> None:
"""
Create a bridged room and invite the person who invoked the command.
"""
content = {
"room_alias_name": f"{self.format}{channel.id}",
"name": channel.name,
"topic": channel.topic,
"visibility": "private",
"invite": [sender],
"creation_content": {"m.federate": True},
"initial_state": [
{
"type": "m.room.join_rules",
"content": {"join_rule": "public"},
},
{
"type": "m.room.history_visibility",
"content": {"history_visibility": "shared"},
},
],
"power_level_content_override": {"users": {sender: 100}},
}
resp = self.send("POST", "/createRoom", content)
self.db.add_room(resp["room_id"], channel.id)
def create_message_event(
self, message: str, emotes: dict, edit: str = "", reply: str = ""
) -> dict:
content = {
"body": message,
"format": "org.matrix.custom.html",
"msgtype": "m.text",
"formatted_body": self.get_fmt(message, emotes),
}
event = message_cache.get(reply)
if event:
content = {
**content,
"m.relates_to": {
"m.in_reply_to": {"event_id": event["event_id"]}
},
"formatted_body": f"""<mx-reply><blockquote>\
<a href='https://matrix.to/#/{event["room_id"]}/{event["event_id"]}'>\
In reply to</a><a href='https://matrix.to/#/{event["mxid"]}'>\
{event["mxid"]}</a><br>{event["body"]}</blockquote></mx-reply>\
{content["formatted_body"]}""",
}
if edit:
content = {
**content,
"body": f" * {content['body']}",
"formatted_body": f" * {content['formatted_body']}",
"m.relates_to": {"event_id": edit, "rel_type": "m.replace"},
"m.new_content": {**content},
}
return content
def get_fmt(self, message: str, emotes: dict) -> str:
replace = [
# Bold.
("**", "<strong>", "</strong>"),
# Code blocks.
("```", "<pre><code>", "</code></pre>"),
# Spoilers.
("||", "<span data-mx-spoiler>", "</span>"),
# Strikethrough.
("~~", "<del>", "</del>"),
]
for replace_ in replace:
for i in range(1, message.count(replace_[0]) + 1):
if i % 2:
message = message.replace(replace_[0], replace_[1], 1)
else:
message = message.replace(replace_[0], replace_[2], 1)
# Upload emotes in multiple threads so that we don't
# block the Discord bot for too long.
upload_threads = [
threading.Thread(
target=self.upload_emote, args=(emote, emotes[emote])
)
for emote in emotes
]
[thread.start() for thread in upload_threads]
[thread.join() for thread in upload_threads]
for emote in emotes:
emote_ = self.emote_cache.get(emote)
if emote_:
emote = f":{emote}:"
message = message.replace(
emote,
f"""<img alt=\"{emote}\" title=\"{emote}\" \
height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
)
return message
def process_message(self, channel_id: str, message: str) -> str:
message = message[:2000] # Discord limit.
emotes = re.findall(r":(\w*):", message)
mentions = re.findall(r"(@(\w*))", message)
# Remove the puppet user's username from replies.
message = re.sub(f"<@{self.format}.+?>", "", message)
added_emotes = []
for emote in emotes:
# Don't replace emote names with IDs multiple times.
if emote not in added_emotes:
added_emotes.append(emote)
emote_ = self.discord.emote_cache.get(emote)
if emote_:
message = message.replace(f":{emote}:", emote_)
# Don't unnecessarily fetch the channel.
if mentions:
guild_id = self.discord.get_channel(channel_id).guild_id
for mention in mentions:
if not mention[1]:
continue
try:
member = self.discord.query_member(guild_id, mention[1])
except (asyncio.TimeoutError, RuntimeError):
continue
if member:
message = message.replace(mention[0], member.mention)
return message
def upload_emote(self, emote_name: str, emote_id: str) -> None:
# There won't be a race condition here, since only a unique
# set of emotes are uploaded at a time.
if emote_name in self.emote_cache:
return
emote_url = f"{discord.CDN_URL}/emojis/{emote_id}"
# We don't want the message to be dropped entirely if an emote
# fails to upload for some reason.
try:
self.emote_cache[emote_name] = self.upload(emote_url)
except RequestError as e:
self.logger.warning(f"Failed to upload emote {emote_id}: {e}")
def register(self, mxid: str) -> None:
"""
Register a dummy user on the homeserver.
"""
content = {
"type": "m.login.application_service",
# "@test:localhost" -> "test" (Can't register with a full mxid.)
"username": mxid[1:].split(":")[0],
}
resp = self.send("POST", "/register", content)
self.db.add_user(resp["user_id"])
def set_avatar(self, avatar_url: str, mxid: str) -> None:
avatar_uri = self.upload(avatar_url)
self.send(
"PUT",
f"/profile/{mxid}/avatar_url",
{"avatar_url": avatar_uri},
params={"user_id": mxid},
)
self.db.add_avatar(avatar_url, mxid)
def set_nick(self, username: str, mxid: str) -> None:
self.send(
"PUT",
f"/profile/{mxid}/displayname",
{"displayname": username},
params={"user_id": mxid},
)
self.db.add_username(username, mxid)
class DiscordClient(Gateway):
def __init__(
self, appservice: MatrixClient, config: dict, http: urllib3.PoolManager
) -> None:
super().__init__(http, config["discord_token"])
self.app = appservice
self.emote_cache: Dict[str, str] = {}
self.webhook_cache: Dict[str, discord.Webhook] = {}
async def sync(self) -> None:
"""
Periodically compare the usernames and avatar URLs with Discord
and update if they differ. Also synchronise emotes.
"""
def sync_emotes(guilds: set):
# We could store the emotes once and update according
# to gateway events but we're too lazy for that.
emotes = []
for guild in guilds:
[emotes.append(emote) for emote in (self.get_emotes(guild))]
self.emote_cache.clear() # Clears deleted/renamed emotes.
for emote in emotes:
self.emote_cache[f"{emote.name}"] = (
f"<{'a' if emote.animated else ''}:"
f"{emote.name}:{emote.id}>"
)
def sync_users(guilds: set):
# TODO use websockets for this, using IDs from database.
users = []
for guild in guilds:
[users.append(member) for member in self.get_members(guild)]
db_users = self.app.db.list_users()
# Convert a list of dicts:
# [ { "avatar_url": ... } ]
# to a dict that is indexable by Discord IDs:
# { "discord_id": { "avatar_url": ... } }
users_ = {}
for user in db_users:
users_[user["mxid"].split("_")[-1].split(":")[0]] = {**user}
for user in users:
user_ = users_.get(user.id)
if not user_:
continue
mxid = user_["mxid"]
username = f"{user.username}#{user.discriminator}"
if user.avatar_url != user_["avatar_url"]:
self.logger.info(
f"Updating avatar for Discord user {user.id}."
)
self.app.set_avatar(user.avatar_url, mxid)
if username != user_["username"]:
self.logger.info(
f"Updating username for Discord user {user.id}."
)
self.app.set_nick(username, mxid)
while True:
guilds = set() # Avoid duplicates.
try:
for channel in self.app.db.list_channels():
guilds.add(self.get_channel(channel).guild_id)
sync_emotes(guilds)
sync_users(guilds)
# Don't let the background task die.
except RequestError:
self.logger.exception(
"Ignoring exception during background sync:"
)
await asyncio.sleep(120) # Check every 2 minutes.
async def start(self) -> None:
asyncio.ensure_future(self.sync())
await self.run()
def to_return(self, message: discord.Message) -> bool:
return (
message.channel_id not in self.app.db.list_channels()
or not message.author
or message.author.discriminator == "0000"
)
def matrixify(self, id: str, user: bool = False) -> str:
return (
f"{'@' if user else '#'}{self.app.format}{id}:"
f"{self.app.server_name}"
)
def wrap(self, message: discord.Message) -> Tuple[str, str]:
"""
Get the room ID and the puppet's mxid for a given channel ID and a
Discord user.
"""
mxid = self.matrixify(message.author.id, user=True)
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
if not self.app.db.query_user(mxid):
self.logger.info(
f"Creating dummy user for Discord user {message.author.id}."
)
self.app.register(mxid)
self.app.set_nick(
f"{message.author.username}#"
f"{message.author.discriminator}",
mxid,
)
if message.author.avatar_url:
self.app.set_avatar(message.author.avatar_url, mxid)
if mxid not in self.app.get_members(room_id):
self.logger.info(f"Inviting user {mxid} to room {room_id}.")
self.app.send_invite(room_id, mxid)
self.app.join_room(room_id, mxid)
return mxid, room_id
def on_message_create(self, message: discord.Message) -> None:
if self.to_return(message):
return
mxid, room_id = self.wrap(message)
content, emotes = self.process_message(message)
content = self.app.create_message_event(
content, emotes, reply=message.reference
)
message_cache[message.id] = {
"body": content["body"],
"event_id": self.app.send_message(room_id, content, mxid),
"mxid": mxid,
"room_id": room_id,
}
def on_message_delete(self, message: discord.DeletedMessage) -> None:
event = message_cache.get(message.id)
if not event:
return
self.app.redact(event["event_id"], event["room_id"], event["mxid"])
message_cache.pop(message.id)
def on_message_update(self, message: dict) -> None:
if self.to_return(message):
return
event = message_cache.get(message.id)
if not event:
return
content, emotes = self.process_message(message)
content = self.app.create_message_event(
content, emotes, edit=event["event_id"]
)
self.app.send_message(event["room_id"], content, event["mxid"])
def on_typing_start(self, typing: discord.Typing) -> None:
if typing.channel_id not in self.app.db.list_channels():
return
mxid = self.matrixify(typing.user_id, user=True)
room_id = self.app.get_room_id(self.matrixify(typing.channel_id))
if mxid not in self.app.get_members(room_id):
return
self.app.send_typing(room_id, mxid)
def get_webhook(self, channel_id: str, name: str) -> discord.Webhook:
"""
Get the webhook object for the first webhook that matches the specified
name in a given channel, create the webhook if it doesn't exist.
"""
# Check the cache first.
webhook = self.webhook_cache.get(channel_id)
if webhook:
return webhook
webhooks = self.send("GET", f"/channels/{channel_id}/webhooks")
webhook = next(
(
dict_cls(webhook, discord.Webhook)
for webhook in webhooks
if webhook["name"] == name
),
None,
)
if not webhook:
webhook = self.create_webhook(channel_id, name)
self.webhook_cache[channel_id] = webhook
return webhook
def process_message(self, message: discord.Message) -> Tuple[str, str]:
content = message.content
emotes = {}
regex = r"<a?:(\w+):(\d+)>"
# Mentions can either be in the form of `<@1234>` or `<@!1234>`.
for char in ("", "!"):
for member in message.mentions:
content = content.replace(
f"<@{char}{member.id}>", f"@{member.username}"
)
# `except_deleted` for invalid channels.
for channel in re.findall(r"<#([0-9]+)>", content):
channel_ = except_deleted(self.get_channel)(channel)
content = content.replace(
f"<#{channel}>",
f"#{channel_.name}" if channel_ else "deleted-channel",
)
# { "emote_name": "emote_id" }
for emote in re.findall(regex, content):
emotes[emote[0]] = emote[1]
# Replace emote IDs with names.
content = re.sub(regex, r":\g<1>:", content)
# Append attachments to message.
for attachment in message.attachments:
content += f"\n{attachment['url']}"
return content, emotes
def config_gen(basedir: str, config_file: str) -> dict:
config_file = f"{basedir}/{config_file}"
config_dict = {
"as_token": "my-secret-as-token",
"hs_token": "my-secret-hs-token",
"user_id": "appservice-discord",
"homeserver": "http://127.0.0.1:8008",
"server_name": "localhost",
"discord_token": "my-secret-discord-token",
"port": 5000,
"database": f"{basedir}/bridge.db",
}
if not os.path.exists(config_file):
with open(config_file, "w") as f:
json.dump(config_dict, f, indent=4)
print(f"Configuration dumped to '{config_file}'")
sys.exit()
with open(config_file, "r") as f:
return json.loads(f.read())
def main() -> None:
try:
basedir = sys.argv[1]
if not os.path.exists(basedir):
print(f"Path '{basedir}' does not exist!")
sys.exit(1)
basedir = os.path.abspath(basedir)
except IndexError:
basedir = os.getcwd()
config = config_gen(basedir, "appservice.json")
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[
logging.FileHandler(f"{basedir}/appservice.log"),
],
)
http = urllib3.PoolManager(maxsize=10)
app = MatrixClient(config, http)
# Start the bottle app in a separate thread.
app_thread = threading.Thread(
target=app.run, kwargs={"port": int(config["port"])}, daemon=True
)
app_thread.start()
try:
asyncio.run(app.discord.start())
except KeyboardInterrupt:
sys.exit()
if __name__ == "__main__":
main()

26
appservice/matrix.py Normal file
View file

@ -0,0 +1,26 @@
from dataclasses import dataclass
@dataclass
class User(object):
avatar_url: str
displayname: str
class Event(object):
def __init__(self, event: dict):
content = event["content"]
self.author = event["author"]
self.body = content.get("body", "")
self.event_id = event["event_id"]
self.is_direct = content.get("is_direct", False)
self.room_id = event["room_id"]
self.sender = event["sender"]
self.state_key = event.get("state_key", "")
rel = content.get("m.relates_to", {})
self.relates_to = rel.get("event_id")
self.reltype = rel.get("rel_type")
self.new_body = content.get("m.new_content", {}).get("body", "")

89
appservice/misc.py Normal file
View file

@ -0,0 +1,89 @@
import asyncio
import json
from dataclasses import fields
from typing import Any
import urllib3
from errors import RequestError
def dict_cls(dict_var: dict, cls: Any) -> Any:
"""
Create a dataclass from a dictionary.
"""
field_names = set(f.name for f in fields(cls))
filtered_dict = {k: v for k, v in dict_var.items() if k in field_names}
return cls(**filtered_dict)
def log_except(fn):
"""
Log unhandled exceptions to a logger instead of `stderr`.
"""
def wrapper(self, *args, **kwargs):
try:
return fn(self, *args, **kwargs)
except Exception:
self.logger.exception(f"Exception in '{fn.__name__}':")
raise
return wrapper
def wrap_async(fn):
"""
Call an asynchronous function from a synchronous one.
"""
def wrapper(self, *args, **kwargs):
if not self.loop:
raise RuntimeError("loop is None.")
return asyncio.run_coroutine_threadsafe(
fn(self, *args, **kwargs), loop=self.loop
).result()
return wrapper
def request(fn):
"""
Either return json data or raise a `RequestError` if the request was
unsuccessful.
"""
def wrapper(*args, **kwargs):
try:
resp = fn(*args, **kwargs)
except urllib3.exceptions.HTTPError as e:
raise RequestError(None, f"Failed to connect: {e}") from None
if resp.status < 200 or resp.status >= 300:
raise RequestError(
resp.status,
f"Failed to get response from '{resp.geturl()}':\n{resp.data}",
)
return {} if resp.status == 204 else json.loads(resp.data)
return wrapper
def except_deleted(fn):
"""
Ignore the `RequestError` on 404s, the message might have been
deleted by someone else already.
"""
def wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)
except RequestError as e:
if e.status != 404:
raise
return wrapper

View file

@ -0,0 +1,3 @@
bottle==0.12.19
urllib3==1.26.3
websockets==8.1