Appservice (#4)
This commit is contained in:
parent
4b876734bb
commit
d6740b4bd3
10 changed files with 1663 additions and 0 deletions
66
appservice/README.md
Normal file
66
appservice/README.md
Normal file
|
@ -0,0 +1,66 @@
|
|||
## Installation
|
||||
|
||||
`pip install -r requirements.txt`
|
||||
|
||||
## Usage
|
||||
|
||||
* Run `main.py` to generate `appservice.json`
|
||||
|
||||
* Edit `appservice.json`:
|
||||
|
||||
```
|
||||
{
|
||||
"as_token": "my-secret-as-token",
|
||||
"hs_token": "my-secret-hs-token",
|
||||
"user_id": "appservice-discord",
|
||||
# Homeserver running on the same machine, listening on port 8008.
|
||||
"homeserver": "http://127.0.0.1:8008",
|
||||
# Change "localhost" to your server_name.
|
||||
# Eg. "kde.org" is the server_name in "@testuser:kde.org".
|
||||
"server_name": "localhost",
|
||||
"discord_token": "my-secret-discord-token",
|
||||
"port": 5000, # Port to run the bottle app on.
|
||||
"database": "/path/to/bridge.db"
|
||||
}
|
||||
```
|
||||
|
||||
* Create `appservice.yaml` and add it to your homeserver configuration:
|
||||
|
||||
```
|
||||
id: "discord"
|
||||
url: "http://127.0.0.1:5000"
|
||||
as_token: "my-secret-as-token"
|
||||
hs_token: "my-secret-hs-token"
|
||||
sender_localpart: "appservice-discord"
|
||||
namespaces:
|
||||
users:
|
||||
- exclusive: true
|
||||
regex: "@_discord.*"
|
||||
# Work around for temporary bug in dendrite.
|
||||
- regex: "@appservice-discord"
|
||||
aliases:
|
||||
- exclusive: false
|
||||
regex: "#_discord.*"
|
||||
rooms: []
|
||||
```
|
||||
|
||||
A path can optionally be passed as the first argument to `main.py`. This path will be used as the base directory for the database and log file.
|
||||
|
||||
Eg. Running `python3 main.py /path/to/my/dir` will store the database and logs in `/path/to/my/dir`.
|
||||
`$PWD` is used by default if no path is specified.
|
||||
|
||||
This bridge is written with:
|
||||
* `bottle`: Receiving events from the homeserver.
|
||||
* `urllib3`: Sending requests, thread safety.
|
||||
* `websockets`: Connecting to Discord. (Big thanks to an anonymous person "nesslersreagent" for figuring out the initial connection mess.)
|
||||
|
||||
## NOTES
|
||||
* A basic sqlite database is used for keeping track of bridged rooms.
|
||||
|
||||
* Logs are saved to the `appservice.log` file in `$PWD` or the specified directory.
|
||||
|
||||
* For avatars to show up on Discord, you must have a [reverse proxy](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf) set up on your homeserver as the bridge does not specify the homeserver port when passing the avatar url.
|
||||
|
||||
* It is not possible to add normal Discord bot functionality like commands as this bridge does not use `discord.py`.
|
||||
|
||||
NOTE: [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) must be enabled for your Discord bot.
|
212
appservice/appservice.py
Normal file
212
appservice/appservice.py
Normal file
|
@ -0,0 +1,212 @@
|
|||
import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from typing import List, Union
|
||||
|
||||
import bottle
|
||||
import urllib3
|
||||
|
||||
import matrix
|
||||
from misc import dict_cls, log_except, request
|
||||
|
||||
|
||||
class AppService(bottle.Bottle):
|
||||
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
|
||||
super(AppService, self).__init__()
|
||||
|
||||
self.as_token = config["as_token"]
|
||||
self.hs_token = config["hs_token"]
|
||||
self.base_url = config["homeserver"]
|
||||
self.server_name = config["server_name"]
|
||||
self.user_id = f"@{config['user_id']}:{self.server_name}"
|
||||
self.http = http
|
||||
self.logger = logging.getLogger("appservice")
|
||||
|
||||
# TODO better method.
|
||||
# Map events to functions.
|
||||
self.mapping = {
|
||||
"m.room.member": "on_member",
|
||||
"m.room.message": "on_message",
|
||||
"m.room.redaction": "on_redaction",
|
||||
}
|
||||
|
||||
# Add route for bottle.
|
||||
self.route(
|
||||
"/transactions/<transaction>",
|
||||
callback=self.receive_event,
|
||||
method="PUT",
|
||||
)
|
||||
|
||||
def handle_event(self, event: dict) -> None:
|
||||
event_type = event.get("type")
|
||||
|
||||
if event_type == "m.room.member" or event_type == "m.room.message":
|
||||
obj = self.get_event_object(event)
|
||||
elif event_type == "m.room.redaction":
|
||||
obj = event
|
||||
else:
|
||||
self.logger.info(f"Unknown event type: {event_type}")
|
||||
return
|
||||
|
||||
func = getattr(self, self.mapping[event_type], None)
|
||||
|
||||
if not func:
|
||||
self.logger.warning(
|
||||
f"Function '{func}' not defined, ignoring event."
|
||||
)
|
||||
return
|
||||
|
||||
# We don't catch exceptions here as the homeserver will re-send us
|
||||
# the event in case of a failure.
|
||||
func(obj)
|
||||
|
||||
@log_except
|
||||
def receive_event(self, transaction: str) -> dict:
|
||||
"""
|
||||
Verify the homeserver's token and handle events.
|
||||
"""
|
||||
|
||||
hs_token = bottle.request.query.getone("access_token")
|
||||
|
||||
if not hs_token:
|
||||
bottle.response.status = 401
|
||||
return {"errcode": "APPSERVICE_UNAUTHORIZED"}
|
||||
|
||||
if hs_token != self.hs_token:
|
||||
bottle.response.status = 403
|
||||
return {"errcode": "APPSERVICE_FORBIDDEN"}
|
||||
|
||||
events = bottle.request.json.get("events")
|
||||
|
||||
for event in events:
|
||||
self.handle_event(event)
|
||||
|
||||
return {}
|
||||
|
||||
def get_event_object(self, event: dict) -> matrix.Event:
|
||||
event["author"] = dict_cls(
|
||||
self.get_profile(event["sender"]), matrix.User
|
||||
)
|
||||
|
||||
return matrix.Event(event)
|
||||
|
||||
def join_room(self, room_id: str, mxid: str = "") -> None:
|
||||
self.send(
|
||||
"POST",
|
||||
f"/join/{room_id}",
|
||||
params={"user_id": mxid} if mxid else {},
|
||||
)
|
||||
|
||||
def redact(self, event_id: str, room_id: str, mxid: str = "") -> None:
|
||||
self.send(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/redact/{event_id}/{uuid.uuid4()}",
|
||||
params={"user_id": mxid} if mxid else {},
|
||||
)
|
||||
|
||||
def get_profile(self, mxid: str) -> dict:
|
||||
# TODO handle failure, avoid querying this endpoint repeatedly.
|
||||
resp = self.send("GET", f"/profile/{mxid}")
|
||||
|
||||
avatar_url = resp.get("avatar_url", "")[6:].split("/")
|
||||
avatar_url = (
|
||||
(
|
||||
f"https://{self.server_name}/_matrix/media/r0/download/"
|
||||
f"{avatar_url[0]}/{avatar_url[1]}"
|
||||
)
|
||||
if len(avatar_url) > 1
|
||||
else None
|
||||
)
|
||||
|
||||
return {
|
||||
"avatar_url": avatar_url,
|
||||
"displayname": resp.get("displayname"),
|
||||
}
|
||||
|
||||
def get_members(self, room_id: str) -> List[str]:
|
||||
resp = self.send(
|
||||
"GET",
|
||||
f"/rooms/{room_id}/members",
|
||||
params={"membership": "join", "not_membership": "leave"},
|
||||
)
|
||||
|
||||
return [
|
||||
content["sender"]
|
||||
for content in resp["chunk"]
|
||||
if content["content"]["membership"] == "join"
|
||||
]
|
||||
|
||||
def get_room_id(self, alias: str) -> str:
|
||||
resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}")
|
||||
|
||||
# TODO cache ?
|
||||
|
||||
return resp["room_id"]
|
||||
|
||||
def upload(self, url: str) -> str:
|
||||
"""
|
||||
Upload a file to the homeserver and get the MXC url.
|
||||
"""
|
||||
|
||||
resp = self.http.request("GET", url)
|
||||
|
||||
resp = self.send(
|
||||
"POST",
|
||||
content=resp.data,
|
||||
content_type=resp.headers.get("Content-Type"),
|
||||
params={"filename": f"{uuid.uuid4()}"},
|
||||
endpoint="/_matrix/media/r0/upload",
|
||||
)
|
||||
|
||||
return resp["content_uri"]
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
room_id: str,
|
||||
content: dict,
|
||||
mxid: str = "",
|
||||
) -> str:
|
||||
resp = self.send(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/send/m.room.message/{uuid.uuid4()}",
|
||||
content,
|
||||
{"user_id": mxid} if mxid else {},
|
||||
)
|
||||
|
||||
return resp["event_id"]
|
||||
|
||||
def send_typing(
|
||||
self, room_id: str, mxid: str = "", timeout: int = 8000
|
||||
) -> None:
|
||||
self.send(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/typing/{mxid}",
|
||||
{"typing": True, "timeout": timeout},
|
||||
{"user_id": mxid} if mxid else {},
|
||||
)
|
||||
|
||||
def send_invite(self, room_id: str, mxid: str) -> None:
|
||||
self.send("POST", f"/rooms/{room_id}/invite", {"user_id": mxid})
|
||||
|
||||
@request
|
||||
def send(
|
||||
self,
|
||||
method: str,
|
||||
path: str = "",
|
||||
content: Union[bytes, dict] = {},
|
||||
params: dict = {},
|
||||
content_type: str = "application/json",
|
||||
endpoint: str = "/_matrix/client/r0",
|
||||
) -> dict:
|
||||
params["access_token"] = self.as_token
|
||||
headers = {"Content-Type": content_type}
|
||||
content = json.dumps(content) if isinstance(content, dict) else content
|
||||
endpoint = (
|
||||
f"{self.base_url}{endpoint}{path}?"
|
||||
f"{urllib.parse.urlencode(params)}"
|
||||
)
|
||||
|
||||
return self.http.request(
|
||||
method, endpoint, body=content, headers=headers
|
||||
)
|
132
appservice/db.py
Normal file
132
appservice/db.py
Normal file
|
@ -0,0 +1,132 @@
|
|||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import List
|
||||
|
||||
|
||||
class DataBase(object):
|
||||
def __init__(self, db_file) -> None:
|
||||
self.create(db_file)
|
||||
|
||||
# The database is accessed via both the threads.
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def create(self, db_file) -> None:
|
||||
"""
|
||||
Create a database with the relevant tables if it doesn't already exist.
|
||||
"""
|
||||
|
||||
exists = os.path.exists(db_file)
|
||||
|
||||
self.conn = sqlite3.connect(db_file, check_same_thread=False)
|
||||
self.conn.row_factory = self.dict_factory
|
||||
|
||||
self.cur = self.conn.cursor()
|
||||
|
||||
if exists:
|
||||
return
|
||||
|
||||
self.cur.execute(
|
||||
"CREATE TABLE bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);"
|
||||
)
|
||||
|
||||
self.cur.execute(
|
||||
"CREATE TABLE users(mxid TEXT PRIMARY KEY, "
|
||||
"avatar_url TEXT, username TEXT);"
|
||||
)
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def dict_factory(self, cursor, row):
|
||||
"""
|
||||
https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory
|
||||
"""
|
||||
|
||||
d = {}
|
||||
for idx, col in enumerate(cursor.description):
|
||||
d[col[0]] = row[idx]
|
||||
return d
|
||||
|
||||
def add_room(self, room_id: str, channel_id: str) -> None:
|
||||
"""
|
||||
Add a bridged room to the database.
|
||||
"""
|
||||
|
||||
with self.lock:
|
||||
self.cur.execute(
|
||||
"INSERT INTO bridge (room_id, channel_id) "
|
||||
f"VALUES ('{room_id}', '{channel_id}')"
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def add_user(self, mxid: str) -> None:
|
||||
with self.lock:
|
||||
self.cur.execute(f"INSERT INTO users (mxid) VALUES ('{mxid}')")
|
||||
self.conn.commit()
|
||||
|
||||
def add_avatar(self, avatar_url: str, mxid: str) -> None:
|
||||
with self.lock:
|
||||
self.cur.execute(
|
||||
f"UPDATE users SET avatar_url = '{avatar_url}'"
|
||||
f"WHERE mxid = '{mxid}'"
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def add_username(self, username: str, mxid: str) -> None:
|
||||
with self.lock:
|
||||
self.cur.execute(
|
||||
f"UPDATE users SET username = '{username}'"
|
||||
f"WHERE mxid = '{mxid}'"
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def get_channel(self, room_id: str) -> str:
|
||||
"""
|
||||
Get the corresponding channel ID for a given room ID.
|
||||
"""
|
||||
|
||||
with self.lock:
|
||||
self.cur.execute(
|
||||
"SELECT channel_id FROM bridge WHERE room_id = ?", [room_id]
|
||||
)
|
||||
|
||||
room = self.cur.fetchone()
|
||||
|
||||
# Return an empty string if nothing is bridged.
|
||||
return "" if not room else room["channel_id"]
|
||||
|
||||
def list_channels(self) -> List[str]:
|
||||
"""
|
||||
Get a list of all the bridged channels.
|
||||
"""
|
||||
|
||||
with self.lock:
|
||||
self.cur.execute("SELECT channel_id FROM bridge")
|
||||
|
||||
channels = self.cur.fetchall()
|
||||
|
||||
return [channel["channel_id"] for channel in channels]
|
||||
|
||||
def list_users(self) -> List[dict]:
|
||||
"""
|
||||
Get a dictionary of all the puppeted users.
|
||||
"""
|
||||
|
||||
with self.lock:
|
||||
self.cur.execute("SELECT * FROM users")
|
||||
|
||||
users = self.cur.fetchall()
|
||||
|
||||
return users
|
||||
|
||||
def query_user(self, mxid: str) -> bool:
|
||||
"""
|
||||
Check whether a puppet user has already been created for a given mxid.
|
||||
"""
|
||||
|
||||
with self.lock:
|
||||
self.cur.execute("SELECT mxid FROM users")
|
||||
|
||||
users = self.cur.fetchall()
|
||||
|
||||
return next((True for user in users if user["mxid"] == mxid), False)
|
175
appservice/discord.py
Normal file
175
appservice/discord.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
CDN_URL = "https://cdn.discordapp.com"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Channel(object):
|
||||
id: str
|
||||
type: str
|
||||
guild_id: str = ""
|
||||
name: str = ""
|
||||
topic: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Emote(object):
|
||||
animated: bool
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class User(object):
|
||||
def __init__(self, user: dict) -> None:
|
||||
self.discriminator = user["discriminator"]
|
||||
self.id = user["id"]
|
||||
self.mention = f"<@{self.id}>"
|
||||
self.username = user["username"]
|
||||
|
||||
avatar = user["avatar"]
|
||||
|
||||
if not avatar:
|
||||
# https://discord.com/developers/docs/reference#image-formatting
|
||||
self.avatar_url = (
|
||||
f"{CDN_URL}/embed/avatars/{int(self.discriminator) % 5}.png"
|
||||
)
|
||||
else:
|
||||
ext = "gif" if avatar.startswith("a_") else "png"
|
||||
self.avatar_url = f"{CDN_URL}/avatars/{self.id}/{avatar}.{ext}"
|
||||
|
||||
|
||||
class Message(object):
|
||||
def __init__(self, message: dict) -> None:
|
||||
self.attachments = message.get("attachments", [])
|
||||
self.channel_id = message["channel_id"]
|
||||
self.content = message.get("content", "")
|
||||
self.id = message["id"]
|
||||
self.reference = message.get("message_reference", {}).get(
|
||||
"message_id", ""
|
||||
)
|
||||
self.webhook_id = message.get("webhook_id", "")
|
||||
|
||||
self.mentions = [
|
||||
User(mention) for mention in message.get("mentions", [])
|
||||
]
|
||||
|
||||
author = message.get("author")
|
||||
|
||||
self.author = User(author) if author else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeletedMessage(object):
|
||||
channel_id: str
|
||||
id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Typing(object):
|
||||
user_id: str
|
||||
channel_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Webhook(object):
|
||||
id: str
|
||||
token: str
|
||||
|
||||
|
||||
class ChannelType(object):
|
||||
GUILD_TEXT = 0
|
||||
DM = 1
|
||||
GUILD_VOICE = 2
|
||||
GROUP_DM = 3
|
||||
GUILD_CATEGORY = 4
|
||||
GUILD_NEWS = 5
|
||||
GUILD_STORE = 6
|
||||
|
||||
|
||||
class InteractionResponseType(object):
|
||||
PONG = 0
|
||||
ACKNOWLEDGE = 1
|
||||
CHANNEL_MESSAGE = 2
|
||||
CHANNEL_MESSAGE_WITH_SOURCE = 4
|
||||
ACKNOWLEDGE_WITH_SOURCE = 5
|
||||
|
||||
|
||||
class GatewayIntents(object):
|
||||
def bitmask(bit: int) -> int:
|
||||
return 1 << bit
|
||||
|
||||
GUILDS = bitmask(0)
|
||||
GUILD_MEMBERS = bitmask(1)
|
||||
GUILD_BANS = bitmask(2)
|
||||
GUILD_EMOJIS = bitmask(3)
|
||||
GUILD_INTEGRATIONS = bitmask(4)
|
||||
GUILD_WEBHOOKS = bitmask(5)
|
||||
GUILD_INVITES = bitmask(6)
|
||||
GUILD_VOICE_STATES = bitmask(7)
|
||||
GUILD_PRESENCES = bitmask(8)
|
||||
GUILD_MESSAGES = bitmask(9)
|
||||
GUILD_MESSAGE_REACTIONS = bitmask(10)
|
||||
GUILD_MESSAGE_TYPING = bitmask(11)
|
||||
DIRECT_MESSAGES = bitmask(12)
|
||||
DIRECT_MESSAGE_REACTIONS = bitmask(13)
|
||||
DIRECT_MESSAGE_TYPING = bitmask(14)
|
||||
|
||||
|
||||
class GatewayOpCodes(object):
|
||||
DISPATCH = 0
|
||||
HEARTBEAT = 1
|
||||
IDENTIFY = 2
|
||||
PRESENCE_UPDATE = 3
|
||||
VOICE_STATE_UPDATE = 4
|
||||
RESUME = 6
|
||||
RECONNECT = 7
|
||||
REQUEST_GUILD_MEMBERS = 8
|
||||
INVALID_SESSION = 9
|
||||
HELLO = 10
|
||||
HEARTBEAT_ACK = 11
|
||||
|
||||
|
||||
class Payloads(object):
|
||||
def __init__(self, token: str) -> None:
|
||||
self.seq = self.session = None
|
||||
self.token = token
|
||||
|
||||
def HEARTBEAT(self) -> dict:
|
||||
return {"op": GatewayOpCodes.HEARTBEAT, "d": self.seq}
|
||||
|
||||
def IDENTIFY(self) -> dict:
|
||||
return {
|
||||
"op": GatewayOpCodes.IDENTIFY,
|
||||
"d": {
|
||||
"token": self.token,
|
||||
"intents": GatewayIntents.GUILDS
|
||||
| GatewayIntents.GUILD_MESSAGES
|
||||
| GatewayIntents.GUILD_MESSAGE_TYPING,
|
||||
"properties": {
|
||||
"$os": "discord",
|
||||
"$browser": "discord",
|
||||
"$device": "discord",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def QUERY(self, guild_id: str, query: str, limit: int = 1) -> dict:
|
||||
"""
|
||||
Return the Payload to query a member from a guild ID.
|
||||
Return only a single match if `limit` isn't specified.
|
||||
"""
|
||||
|
||||
return {
|
||||
"op": GatewayOpCodes.REQUEST_GUILD_MEMBERS,
|
||||
"d": {"guild_id": guild_id, "query": query, "limit": limit},
|
||||
}
|
||||
|
||||
def RESUME(self) -> dict:
|
||||
return {
|
||||
"op": GatewayOpCodes.RESUME,
|
||||
"d": {
|
||||
"token": self.token,
|
||||
"session_id": self.session,
|
||||
"seq": self.seq,
|
||||
},
|
||||
}
|
5
appservice/errors.py
Normal file
5
appservice/errors.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
class RequestError(Exception):
|
||||
def __init__(self, status: int, *args):
|
||||
super().__init__(*args)
|
||||
|
||||
self.status = status
|
299
appservice/gateway.py
Normal file
299
appservice/gateway.py
Normal file
|
@ -0,0 +1,299 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
from typing import List
|
||||
|
||||
import urllib3
|
||||
import websockets
|
||||
|
||||
import discord
|
||||
from misc import dict_cls, log_except, request, wrap_async
|
||||
|
||||
|
||||
class Gateway(object):
|
||||
def __init__(self, http: urllib3.PoolManager, token: str):
|
||||
self.http = http
|
||||
self.token = token
|
||||
self.logger = logging.getLogger("discord")
|
||||
self.cdn_url = "https://cdn.discordapp.com"
|
||||
self.Payloads = discord.Payloads(self.token)
|
||||
self.loop = self.websocket = None
|
||||
|
||||
self.query_cache = {}
|
||||
|
||||
@log_except
|
||||
async def run(self) -> None:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.query_ev = asyncio.Event()
|
||||
|
||||
self.heartbeat_task = None
|
||||
self.resume = False
|
||||
|
||||
while True:
|
||||
try:
|
||||
await self.gateway_handler(self.get_gateway_url())
|
||||
except websockets.ConnectionClosedError:
|
||||
# TODO reconnect ?
|
||||
self.logger.exception("Quitting, connection lost.")
|
||||
break
|
||||
|
||||
# Stop sending heartbeats until we reconnect.
|
||||
if self.heartbeat_task and not self.heartbeat_task.cancelled():
|
||||
self.heartbeat_task.cancel()
|
||||
|
||||
def get_gateway_url(self) -> str:
|
||||
resp = self.send("GET", "/gateway")
|
||||
|
||||
return resp["url"]
|
||||
|
||||
async def heartbeat_handler(self, interval_ms: int) -> None:
|
||||
while True:
|
||||
await asyncio.sleep(interval_ms / 1000)
|
||||
await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT()))
|
||||
|
||||
def query_handler(self, data: dict) -> None:
|
||||
members = data["members"]
|
||||
guild_id = data["guild_id"]
|
||||
|
||||
for member in members:
|
||||
user = member["user"]
|
||||
self.query_cache[guild_id].append(user)
|
||||
|
||||
self.query_ev.set()
|
||||
|
||||
def handle_otype(self, data: dict, otype: str) -> None:
|
||||
if data.get("embeds"):
|
||||
return # TODO embeds
|
||||
|
||||
if otype == "MESSAGE_CREATE" or otype == "MESSAGE_UPDATE":
|
||||
obj = discord.Message(data)
|
||||
elif otype == "MESSAGE_DELETE":
|
||||
obj = dict_cls(data, discord.DeletedMessage)
|
||||
elif otype == "TYPING_START":
|
||||
obj = dict_cls(data, discord.Typing)
|
||||
elif otype == "GUILD_MEMBERS_CHUNK":
|
||||
self.query_handler(data)
|
||||
return
|
||||
else:
|
||||
self.logger.info(f"Unknown OTYPE: {otype}")
|
||||
return
|
||||
|
||||
func = getattr(self, f"on_{otype.lower()}", None)
|
||||
|
||||
if not func:
|
||||
self.logger.warning(
|
||||
f"Function '{func}' not defined, ignoring message."
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
func(obj)
|
||||
except Exception:
|
||||
self.logger.exception(f"Ignoring exception in {func}:")
|
||||
|
||||
async def gateway_handler(self, gateway_url: str) -> None:
|
||||
async with websockets.connect(
|
||||
f"{gateway_url}/?v=8&encoding=json"
|
||||
) as websocket:
|
||||
self.websocket = websocket
|
||||
async for message in websocket:
|
||||
data = json.loads(message)
|
||||
data_dict = data.get("d")
|
||||
|
||||
opcode = data.get("op")
|
||||
|
||||
seq = data.get("s")
|
||||
if seq:
|
||||
self.Payloads.seq = seq
|
||||
|
||||
if opcode == discord.GatewayOpCodes.DISPATCH:
|
||||
otype = data.get("t")
|
||||
|
||||
if otype == "READY":
|
||||
self.Payloads.session = data_dict["session_id"]
|
||||
|
||||
self.logger.info("READY")
|
||||
|
||||
else:
|
||||
self.handle_otype(data_dict, otype)
|
||||
|
||||
elif opcode == discord.GatewayOpCodes.HELLO:
|
||||
heartbeat_interval = data_dict.get("heartbeat_interval")
|
||||
|
||||
self.logger.info(
|
||||
f"Heartbeat Interval: {heartbeat_interval}"
|
||||
)
|
||||
|
||||
# Send periodic hearbeats to gateway.
|
||||
self.heartbeat_task = asyncio.ensure_future(
|
||||
self.heartbeat_handler(heartbeat_interval)
|
||||
)
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
self.Payloads.RESUME()
|
||||
if self.resume
|
||||
else self.Payloads.IDENTIFY()
|
||||
)
|
||||
)
|
||||
|
||||
elif opcode == discord.GatewayOpCodes.RECONNECT:
|
||||
self.logger.info("Received RECONNECT.")
|
||||
|
||||
self.resume = True
|
||||
await websocket.close()
|
||||
|
||||
elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
|
||||
self.logger.info("Received INVALID_SESSION.")
|
||||
|
||||
self.resume = False
|
||||
await websocket.close()
|
||||
|
||||
elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
|
||||
# NOP
|
||||
pass
|
||||
|
||||
else:
|
||||
self.logger.info(
|
||||
f"Unknown OP code {opcode}:\n"
|
||||
f"{json.dumps(data, indent=4)}"
|
||||
)
|
||||
|
||||
@wrap_async
|
||||
async def query_member(self, guild_id: str, name: str) -> discord.User:
|
||||
"""
|
||||
Query the members for a given guild and return the first match.
|
||||
"""
|
||||
|
||||
self.query_ev.clear()
|
||||
|
||||
def query():
|
||||
if not self.query_cache.get(guild_id):
|
||||
self.query_cache[guild_id] = []
|
||||
|
||||
user = [
|
||||
user
|
||||
for user in self.query_cache[guild_id]
|
||||
if name.lower() in user["username"].lower()
|
||||
]
|
||||
|
||||
return None if not user else discord.User(user[0])
|
||||
|
||||
user = query()
|
||||
|
||||
if user:
|
||||
return user
|
||||
|
||||
if not self.websocket or self.websocket.closed:
|
||||
self.logger.warning("Not fetching members, websocket closed.")
|
||||
return
|
||||
|
||||
await self.websocket.send(
|
||||
json.dumps(self.Payloads.QUERY(guild_id, name))
|
||||
)
|
||||
|
||||
# Wait for our websocket to receive the chunk.
|
||||
await asyncio.wait_for(self.query_ev.wait(), timeout=5)
|
||||
|
||||
return query()
|
||||
|
||||
def get_channel(self, channel_id: str) -> discord.Channel:
|
||||
"""
|
||||
Get the channel object for a given channel ID.
|
||||
"""
|
||||
|
||||
resp = self.send("GET", f"/channels/{channel_id}")
|
||||
|
||||
return dict_cls(resp, discord.Channel)
|
||||
|
||||
def get_emotes(self, guild_id: str) -> List[discord.Emote]:
|
||||
"""
|
||||
Get all the emotes for a given guild.
|
||||
"""
|
||||
|
||||
resp = self.send("GET", f"/guilds/{guild_id}/emojis")
|
||||
|
||||
return [dict_cls(emote, discord.Emote) for emote in resp]
|
||||
|
||||
def get_members(self, guild_id: str) -> List[discord.User]:
|
||||
"""
|
||||
Get all the members for a given guild.
|
||||
"""
|
||||
|
||||
resp = self.send(
|
||||
"GET", f"/guilds/{guild_id}/members", params={"limit": 1000}
|
||||
)
|
||||
|
||||
return [discord.User(member["user"]) for member in resp]
|
||||
|
||||
def create_webhook(self, channel_id: str, name: str) -> discord.Webhook:
|
||||
"""
|
||||
Create a webhook with the specified name in a given channel.
|
||||
"""
|
||||
|
||||
resp = self.send(
|
||||
"POST", f"/channels/{channel_id}/webhooks", {"name": name}
|
||||
)
|
||||
|
||||
return dict_cls(resp, discord.Webhook)
|
||||
|
||||
def edit_webhook(
|
||||
self, content: str, message_id: str, webhook: discord.Webhook
|
||||
) -> None:
|
||||
self.send(
|
||||
"PATCH",
|
||||
f"/webhooks/{webhook.id}/{webhook.token}/messages/"
|
||||
f"{message_id}",
|
||||
{"content": content},
|
||||
)
|
||||
|
||||
def delete_webhook(
|
||||
self, message_id: str, webhook: discord.Webhook
|
||||
) -> None:
|
||||
self.send(
|
||||
"DELETE",
|
||||
f"/webhooks/{webhook.id}/{webhook.token}/messages/"
|
||||
f"{message_id}",
|
||||
)
|
||||
|
||||
def send_webhook(self, webhook: discord.Webhook, **kwargs) -> str:
|
||||
content = {
|
||||
**kwargs,
|
||||
# Disable 'everyone' and 'role' mentions.
|
||||
"allowed_mentions": {"parse": ["users"]},
|
||||
}
|
||||
|
||||
resp = self.send(
|
||||
"POST",
|
||||
f"/webhooks/{webhook.id}/{webhook.token}",
|
||||
content,
|
||||
{"wait": True},
|
||||
)
|
||||
|
||||
return resp["id"]
|
||||
|
||||
def send_message(self, message: str, channel_id: str) -> None:
|
||||
self.send(
|
||||
"POST", f"/channels/{channel_id}/messages", {"content": message}
|
||||
)
|
||||
|
||||
@request
|
||||
def send(
|
||||
self, method: str, path: str, content: dict = {}, params: dict = {}
|
||||
) -> dict:
|
||||
endpoint = (
|
||||
f"https://discord.com/api/v8{path}?"
|
||||
f"{urllib.parse.urlencode(params)}"
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bot {self.token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# 'body' being an empty dict breaks "GET" requests.
|
||||
content = json.dumps(content) if content else None
|
||||
|
||||
return self.http.request(
|
||||
method, endpoint, body=content, headers=headers
|
||||
)
|
656
appservice/main.py
Normal file
656
appservice/main.py
Normal file
|
@ -0,0 +1,656 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import urllib3
|
||||
|
||||
import discord
|
||||
import matrix
|
||||
from appservice import AppService
|
||||
from db import DataBase
|
||||
from errors import RequestError
|
||||
from gateway import Gateway
|
||||
from misc import dict_cls, except_deleted
|
||||
|
||||
# TODO should this be cleared periodically ?
|
||||
message_cache: Dict[str, Union[discord.Webhook, str]] = {}
|
||||
|
||||
|
||||
class MatrixClient(AppService):
|
||||
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
|
||||
super().__init__(config, http)
|
||||
|
||||
self.db = DataBase(config["database"])
|
||||
self.discord = DiscordClient(self, config, http)
|
||||
self.emote_cache: Dict[str, str] = {}
|
||||
self.format = "_discord_" # "{@,#}_discord_1234:localhost"
|
||||
|
||||
def to_return(self, event: matrix.Event) -> bool:
|
||||
return event.sender.startswith(("@_discord", self.user_id))
|
||||
|
||||
def handle_bridge(self, message: matrix.Event) -> None:
|
||||
# Ignore events that aren't for us.
|
||||
if message.sender.split(":")[
|
||||
-1
|
||||
] != self.server_name or not message.body.startswith("!bridge"):
|
||||
return
|
||||
|
||||
try:
|
||||
channel = message.body.split()[1]
|
||||
except IndexError:
|
||||
return
|
||||
|
||||
# Check if the given channel is valid.
|
||||
try:
|
||||
channel = self.discord.get_channel(channel)
|
||||
except RequestError as e:
|
||||
# The channel can be invalid or we may not have permission.
|
||||
self.logger.warning(f"Failed to fetch channel {channel}: {e}")
|
||||
return
|
||||
|
||||
if (
|
||||
channel.type != discord.ChannelType.GUILD_TEXT
|
||||
or channel.id in self.db.list_channels()
|
||||
):
|
||||
return
|
||||
|
||||
self.logger.info(f"Creating bridged room for channel {channel.id}.")
|
||||
|
||||
self.create_room(channel, message.sender)
|
||||
|
||||
def on_member(self, event: matrix.Event) -> None:
|
||||
# Ignore events that aren't for us.
|
||||
if (
|
||||
event.sender.split(":")[-1] != self.server_name
|
||||
or event.state_key != self.user_id
|
||||
or not event.is_direct
|
||||
):
|
||||
return
|
||||
|
||||
# Join the direct message room.
|
||||
self.logger.info(f"Joining direct message room {event.room_id}.")
|
||||
self.join_room(event.room_id)
|
||||
|
||||
def on_message(self, message: matrix.Event) -> None:
|
||||
if self.to_return(message):
|
||||
return
|
||||
|
||||
# Handle bridging commands.
|
||||
self.handle_bridge(message)
|
||||
|
||||
channel_id = self.db.get_channel(message.room_id)
|
||||
|
||||
if not channel_id:
|
||||
return
|
||||
|
||||
webhook = self.discord.get_webhook(channel_id, "matrix_bridge")
|
||||
|
||||
if message.relates_to and message.reltype == "m.replace":
|
||||
|
||||
relation = message_cache.get(message.relates_to)
|
||||
|
||||
if not message.new_body or not relation:
|
||||
return
|
||||
|
||||
message.new_body = self.process_message(
|
||||
channel_id, message.new_body
|
||||
)
|
||||
|
||||
except_deleted(self.discord.edit_webhook)(
|
||||
message.new_body, relation["message_id"], webhook
|
||||
)
|
||||
|
||||
else:
|
||||
if not message.body:
|
||||
return
|
||||
|
||||
message.body = self.process_message(channel_id, message.body)
|
||||
message_cache[message.event_id] = {
|
||||
"message_id": self.discord.send_webhook(
|
||||
webhook,
|
||||
avatar_url=message.author.avatar_url,
|
||||
content=message.body,
|
||||
username=message.author.displayname,
|
||||
),
|
||||
"webhook": webhook,
|
||||
}
|
||||
|
||||
@except_deleted
|
||||
def on_redaction(self, event: dict) -> None:
|
||||
redacts = event["redacts"]
|
||||
|
||||
event = message_cache.get(redacts)
|
||||
|
||||
if not event:
|
||||
return
|
||||
|
||||
self.discord.delete_webhook(event["message_id"], event["webhook"])
|
||||
|
||||
message_cache.pop(redacts)
|
||||
|
||||
def create_room(self, channel: discord.Channel, sender: str) -> None:
|
||||
"""
|
||||
Create a bridged room and invite the person who invoked the command.
|
||||
"""
|
||||
|
||||
content = {
|
||||
"room_alias_name": f"{self.format}{channel.id}",
|
||||
"name": channel.name,
|
||||
"topic": channel.topic,
|
||||
"visibility": "private",
|
||||
"invite": [sender],
|
||||
"creation_content": {"m.federate": True},
|
||||
"initial_state": [
|
||||
{
|
||||
"type": "m.room.join_rules",
|
||||
"content": {"join_rule": "public"},
|
||||
},
|
||||
{
|
||||
"type": "m.room.history_visibility",
|
||||
"content": {"history_visibility": "shared"},
|
||||
},
|
||||
],
|
||||
"power_level_content_override": {"users": {sender: 100}},
|
||||
}
|
||||
|
||||
resp = self.send("POST", "/createRoom", content)
|
||||
|
||||
self.db.add_room(resp["room_id"], channel.id)
|
||||
|
||||
def create_message_event(
|
||||
self, message: str, emotes: dict, edit: str = "", reply: str = ""
|
||||
) -> dict:
|
||||
content = {
|
||||
"body": message,
|
||||
"format": "org.matrix.custom.html",
|
||||
"msgtype": "m.text",
|
||||
"formatted_body": self.get_fmt(message, emotes),
|
||||
}
|
||||
|
||||
event = message_cache.get(reply)
|
||||
|
||||
if event:
|
||||
content = {
|
||||
**content,
|
||||
"m.relates_to": {
|
||||
"m.in_reply_to": {"event_id": event["event_id"]}
|
||||
},
|
||||
"formatted_body": f"""<mx-reply><blockquote>\
|
||||
<a href='https://matrix.to/#/{event["room_id"]}/{event["event_id"]}'>\
|
||||
In reply to</a><a href='https://matrix.to/#/{event["mxid"]}'>\
|
||||
{event["mxid"]}</a><br>{event["body"]}</blockquote></mx-reply>\
|
||||
{content["formatted_body"]}""",
|
||||
}
|
||||
|
||||
if edit:
|
||||
content = {
|
||||
**content,
|
||||
"body": f" * {content['body']}",
|
||||
"formatted_body": f" * {content['formatted_body']}",
|
||||
"m.relates_to": {"event_id": edit, "rel_type": "m.replace"},
|
||||
"m.new_content": {**content},
|
||||
}
|
||||
|
||||
return content
|
||||
|
||||
def get_fmt(self, message: str, emotes: dict) -> str:
|
||||
replace = [
|
||||
# Bold.
|
||||
("**", "<strong>", "</strong>"),
|
||||
# Code blocks.
|
||||
("```", "<pre><code>", "</code></pre>"),
|
||||
# Spoilers.
|
||||
("||", "<span data-mx-spoiler>", "</span>"),
|
||||
# Strikethrough.
|
||||
("~~", "<del>", "</del>"),
|
||||
]
|
||||
|
||||
for replace_ in replace:
|
||||
for i in range(1, message.count(replace_[0]) + 1):
|
||||
if i % 2:
|
||||
message = message.replace(replace_[0], replace_[1], 1)
|
||||
else:
|
||||
message = message.replace(replace_[0], replace_[2], 1)
|
||||
|
||||
# Upload emotes in multiple threads so that we don't
|
||||
# block the Discord bot for too long.
|
||||
upload_threads = [
|
||||
threading.Thread(
|
||||
target=self.upload_emote, args=(emote, emotes[emote])
|
||||
)
|
||||
for emote in emotes
|
||||
]
|
||||
|
||||
[thread.start() for thread in upload_threads]
|
||||
[thread.join() for thread in upload_threads]
|
||||
|
||||
for emote in emotes:
|
||||
emote_ = self.emote_cache.get(emote)
|
||||
|
||||
if emote_:
|
||||
emote = f":{emote}:"
|
||||
message = message.replace(
|
||||
emote,
|
||||
f"""<img alt=\"{emote}\" title=\"{emote}\" \
|
||||
height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
|
||||
)
|
||||
|
||||
return message
|
||||
|
||||
def process_message(self, channel_id: str, message: str) -> str:
|
||||
message = message[:2000] # Discord limit.
|
||||
|
||||
emotes = re.findall(r":(\w*):", message)
|
||||
mentions = re.findall(r"(@(\w*))", message)
|
||||
|
||||
# Remove the puppet user's username from replies.
|
||||
message = re.sub(f"<@{self.format}.+?>", "", message)
|
||||
|
||||
added_emotes = []
|
||||
for emote in emotes:
|
||||
# Don't replace emote names with IDs multiple times.
|
||||
if emote not in added_emotes:
|
||||
added_emotes.append(emote)
|
||||
emote_ = self.discord.emote_cache.get(emote)
|
||||
if emote_:
|
||||
message = message.replace(f":{emote}:", emote_)
|
||||
|
||||
# Don't unnecessarily fetch the channel.
|
||||
if mentions:
|
||||
guild_id = self.discord.get_channel(channel_id).guild_id
|
||||
|
||||
for mention in mentions:
|
||||
if not mention[1]:
|
||||
continue
|
||||
|
||||
try:
|
||||
member = self.discord.query_member(guild_id, mention[1])
|
||||
except (asyncio.TimeoutError, RuntimeError):
|
||||
continue
|
||||
|
||||
if member:
|
||||
message = message.replace(mention[0], member.mention)
|
||||
|
||||
return message
|
||||
|
||||
def upload_emote(self, emote_name: str, emote_id: str) -> None:
|
||||
# There won't be a race condition here, since only a unique
|
||||
# set of emotes are uploaded at a time.
|
||||
if emote_name in self.emote_cache:
|
||||
return
|
||||
|
||||
emote_url = f"{discord.CDN_URL}/emojis/{emote_id}"
|
||||
|
||||
# We don't want the message to be dropped entirely if an emote
|
||||
# fails to upload for some reason.
|
||||
try:
|
||||
self.emote_cache[emote_name] = self.upload(emote_url)
|
||||
except RequestError as e:
|
||||
self.logger.warning(f"Failed to upload emote {emote_id}: {e}")
|
||||
|
||||
def register(self, mxid: str) -> None:
|
||||
"""
|
||||
Register a dummy user on the homeserver.
|
||||
"""
|
||||
|
||||
content = {
|
||||
"type": "m.login.application_service",
|
||||
# "@test:localhost" -> "test" (Can't register with a full mxid.)
|
||||
"username": mxid[1:].split(":")[0],
|
||||
}
|
||||
|
||||
resp = self.send("POST", "/register", content)
|
||||
|
||||
self.db.add_user(resp["user_id"])
|
||||
|
||||
def set_avatar(self, avatar_url: str, mxid: str) -> None:
|
||||
avatar_uri = self.upload(avatar_url)
|
||||
|
||||
self.send(
|
||||
"PUT",
|
||||
f"/profile/{mxid}/avatar_url",
|
||||
{"avatar_url": avatar_uri},
|
||||
params={"user_id": mxid},
|
||||
)
|
||||
|
||||
self.db.add_avatar(avatar_url, mxid)
|
||||
|
||||
def set_nick(self, username: str, mxid: str) -> None:
|
||||
self.send(
|
||||
"PUT",
|
||||
f"/profile/{mxid}/displayname",
|
||||
{"displayname": username},
|
||||
params={"user_id": mxid},
|
||||
)
|
||||
|
||||
self.db.add_username(username, mxid)
|
||||
|
||||
|
||||
class DiscordClient(Gateway):
|
||||
def __init__(
|
||||
self, appservice: MatrixClient, config: dict, http: urllib3.PoolManager
|
||||
) -> None:
|
||||
super().__init__(http, config["discord_token"])
|
||||
|
||||
self.app = appservice
|
||||
self.emote_cache: Dict[str, str] = {}
|
||||
self.webhook_cache: Dict[str, discord.Webhook] = {}
|
||||
|
||||
async def sync(self) -> None:
|
||||
"""
|
||||
Periodically compare the usernames and avatar URLs with Discord
|
||||
and update if they differ. Also synchronise emotes.
|
||||
"""
|
||||
|
||||
def sync_emotes(guilds: set):
|
||||
# We could store the emotes once and update according
|
||||
# to gateway events but we're too lazy for that.
|
||||
emotes = []
|
||||
|
||||
for guild in guilds:
|
||||
[emotes.append(emote) for emote in (self.get_emotes(guild))]
|
||||
|
||||
self.emote_cache.clear() # Clears deleted/renamed emotes.
|
||||
|
||||
for emote in emotes:
|
||||
self.emote_cache[f"{emote.name}"] = (
|
||||
f"<{'a' if emote.animated else ''}:"
|
||||
f"{emote.name}:{emote.id}>"
|
||||
)
|
||||
|
||||
def sync_users(guilds: set):
|
||||
# TODO use websockets for this, using IDs from database.
|
||||
|
||||
users = []
|
||||
|
||||
for guild in guilds:
|
||||
[users.append(member) for member in self.get_members(guild)]
|
||||
|
||||
db_users = self.app.db.list_users()
|
||||
|
||||
# Convert a list of dicts:
|
||||
# [ { "avatar_url": ... } ]
|
||||
# to a dict that is indexable by Discord IDs:
|
||||
# { "discord_id": { "avatar_url": ... } }
|
||||
users_ = {}
|
||||
|
||||
for user in db_users:
|
||||
users_[user["mxid"].split("_")[-1].split(":")[0]] = {**user}
|
||||
|
||||
for user in users:
|
||||
user_ = users_.get(user.id)
|
||||
|
||||
if not user_:
|
||||
continue
|
||||
|
||||
mxid = user_["mxid"]
|
||||
username = f"{user.username}#{user.discriminator}"
|
||||
|
||||
if user.avatar_url != user_["avatar_url"]:
|
||||
self.logger.info(
|
||||
f"Updating avatar for Discord user {user.id}."
|
||||
)
|
||||
self.app.set_avatar(user.avatar_url, mxid)
|
||||
|
||||
if username != user_["username"]:
|
||||
self.logger.info(
|
||||
f"Updating username for Discord user {user.id}."
|
||||
)
|
||||
self.app.set_nick(username, mxid)
|
||||
|
||||
while True:
|
||||
guilds = set() # Avoid duplicates.
|
||||
|
||||
try:
|
||||
for channel in self.app.db.list_channels():
|
||||
guilds.add(self.get_channel(channel).guild_id)
|
||||
|
||||
sync_emotes(guilds)
|
||||
sync_users(guilds)
|
||||
# Don't let the background task die.
|
||||
except RequestError:
|
||||
self.logger.exception(
|
||||
"Ignoring exception during background sync:"
|
||||
)
|
||||
|
||||
await asyncio.sleep(120) # Check every 2 minutes.
|
||||
|
||||
async def start(self) -> None:
|
||||
asyncio.ensure_future(self.sync())
|
||||
|
||||
await self.run()
|
||||
|
||||
def to_return(self, message: discord.Message) -> bool:
|
||||
return (
|
||||
message.channel_id not in self.app.db.list_channels()
|
||||
or not message.author
|
||||
or message.author.discriminator == "0000"
|
||||
)
|
||||
|
||||
def matrixify(self, id: str, user: bool = False) -> str:
|
||||
return (
|
||||
f"{'@' if user else '#'}{self.app.format}{id}:"
|
||||
f"{self.app.server_name}"
|
||||
)
|
||||
|
||||
def wrap(self, message: discord.Message) -> Tuple[str, str]:
|
||||
"""
|
||||
Get the room ID and the puppet's mxid for a given channel ID and a
|
||||
Discord user.
|
||||
"""
|
||||
|
||||
mxid = self.matrixify(message.author.id, user=True)
|
||||
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
|
||||
|
||||
if not self.app.db.query_user(mxid):
|
||||
self.logger.info(
|
||||
f"Creating dummy user for Discord user {message.author.id}."
|
||||
)
|
||||
self.app.register(mxid)
|
||||
|
||||
self.app.set_nick(
|
||||
f"{message.author.username}#"
|
||||
f"{message.author.discriminator}",
|
||||
mxid,
|
||||
)
|
||||
|
||||
if message.author.avatar_url:
|
||||
self.app.set_avatar(message.author.avatar_url, mxid)
|
||||
|
||||
if mxid not in self.app.get_members(room_id):
|
||||
self.logger.info(f"Inviting user {mxid} to room {room_id}.")
|
||||
|
||||
self.app.send_invite(room_id, mxid)
|
||||
self.app.join_room(room_id, mxid)
|
||||
|
||||
return mxid, room_id
|
||||
|
||||
def on_message_create(self, message: discord.Message) -> None:
|
||||
if self.to_return(message):
|
||||
return
|
||||
|
||||
mxid, room_id = self.wrap(message)
|
||||
|
||||
content, emotes = self.process_message(message)
|
||||
|
||||
content = self.app.create_message_event(
|
||||
content, emotes, reply=message.reference
|
||||
)
|
||||
|
||||
message_cache[message.id] = {
|
||||
"body": content["body"],
|
||||
"event_id": self.app.send_message(room_id, content, mxid),
|
||||
"mxid": mxid,
|
||||
"room_id": room_id,
|
||||
}
|
||||
|
||||
def on_message_delete(self, message: discord.DeletedMessage) -> None:
|
||||
event = message_cache.get(message.id)
|
||||
|
||||
if not event:
|
||||
return
|
||||
|
||||
self.app.redact(event["event_id"], event["room_id"], event["mxid"])
|
||||
|
||||
message_cache.pop(message.id)
|
||||
|
||||
def on_message_update(self, message: dict) -> None:
|
||||
if self.to_return(message):
|
||||
return
|
||||
|
||||
event = message_cache.get(message.id)
|
||||
|
||||
if not event:
|
||||
return
|
||||
|
||||
content, emotes = self.process_message(message)
|
||||
|
||||
content = self.app.create_message_event(
|
||||
content, emotes, edit=event["event_id"]
|
||||
)
|
||||
|
||||
self.app.send_message(event["room_id"], content, event["mxid"])
|
||||
|
||||
def on_typing_start(self, typing: discord.Typing) -> None:
|
||||
if typing.channel_id not in self.app.db.list_channels():
|
||||
return
|
||||
|
||||
mxid = self.matrixify(typing.user_id, user=True)
|
||||
room_id = self.app.get_room_id(self.matrixify(typing.channel_id))
|
||||
|
||||
if mxid not in self.app.get_members(room_id):
|
||||
return
|
||||
|
||||
self.app.send_typing(room_id, mxid)
|
||||
|
||||
def get_webhook(self, channel_id: str, name: str) -> discord.Webhook:
|
||||
"""
|
||||
Get the webhook object for the first webhook that matches the specified
|
||||
name in a given channel, create the webhook if it doesn't exist.
|
||||
"""
|
||||
|
||||
# Check the cache first.
|
||||
webhook = self.webhook_cache.get(channel_id)
|
||||
|
||||
if webhook:
|
||||
return webhook
|
||||
|
||||
webhooks = self.send("GET", f"/channels/{channel_id}/webhooks")
|
||||
webhook = next(
|
||||
(
|
||||
dict_cls(webhook, discord.Webhook)
|
||||
for webhook in webhooks
|
||||
if webhook["name"] == name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not webhook:
|
||||
webhook = self.create_webhook(channel_id, name)
|
||||
|
||||
self.webhook_cache[channel_id] = webhook
|
||||
|
||||
return webhook
|
||||
|
||||
def process_message(self, message: discord.Message) -> Tuple[str, str]:
|
||||
content = message.content
|
||||
emotes = {}
|
||||
regex = r"<a?:(\w+):(\d+)>"
|
||||
|
||||
# Mentions can either be in the form of `<@1234>` or `<@!1234>`.
|
||||
for char in ("", "!"):
|
||||
for member in message.mentions:
|
||||
content = content.replace(
|
||||
f"<@{char}{member.id}>", f"@{member.username}"
|
||||
)
|
||||
|
||||
# `except_deleted` for invalid channels.
|
||||
for channel in re.findall(r"<#([0-9]+)>", content):
|
||||
channel_ = except_deleted(self.get_channel)(channel)
|
||||
content = content.replace(
|
||||
f"<#{channel}>",
|
||||
f"#{channel_.name}" if channel_ else "deleted-channel",
|
||||
)
|
||||
|
||||
# { "emote_name": "emote_id" }
|
||||
for emote in re.findall(regex, content):
|
||||
emotes[emote[0]] = emote[1]
|
||||
|
||||
# Replace emote IDs with names.
|
||||
content = re.sub(regex, r":\g<1>:", content)
|
||||
|
||||
# Append attachments to message.
|
||||
for attachment in message.attachments:
|
||||
content += f"\n{attachment['url']}"
|
||||
|
||||
return content, emotes
|
||||
|
||||
|
||||
def config_gen(basedir: str, config_file: str) -> dict:
|
||||
config_file = f"{basedir}/{config_file}"
|
||||
|
||||
config_dict = {
|
||||
"as_token": "my-secret-as-token",
|
||||
"hs_token": "my-secret-hs-token",
|
||||
"user_id": "appservice-discord",
|
||||
"homeserver": "http://127.0.0.1:8008",
|
||||
"server_name": "localhost",
|
||||
"discord_token": "my-secret-discord-token",
|
||||
"port": 5000,
|
||||
"database": f"{basedir}/bridge.db",
|
||||
}
|
||||
|
||||
if not os.path.exists(config_file):
|
||||
with open(config_file, "w") as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
print(f"Configuration dumped to '{config_file}'")
|
||||
sys.exit()
|
||||
|
||||
with open(config_file, "r") as f:
|
||||
return json.loads(f.read())
|
||||
|
||||
|
||||
def main() -> None:
|
||||
try:
|
||||
basedir = sys.argv[1]
|
||||
if not os.path.exists(basedir):
|
||||
print(f"Path '{basedir}' does not exist!")
|
||||
sys.exit(1)
|
||||
basedir = os.path.abspath(basedir)
|
||||
except IndexError:
|
||||
basedir = os.getcwd()
|
||||
|
||||
config = config_gen(basedir, "appservice.json")
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[
|
||||
logging.FileHandler(f"{basedir}/appservice.log"),
|
||||
],
|
||||
)
|
||||
|
||||
http = urllib3.PoolManager(maxsize=10)
|
||||
|
||||
app = MatrixClient(config, http)
|
||||
|
||||
# Start the bottle app in a separate thread.
|
||||
app_thread = threading.Thread(
|
||||
target=app.run, kwargs={"port": int(config["port"])}, daemon=True
|
||||
)
|
||||
app_thread.start()
|
||||
|
||||
try:
|
||||
asyncio.run(app.discord.start())
|
||||
except KeyboardInterrupt:
|
||||
sys.exit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
26
appservice/matrix.py
Normal file
26
appservice/matrix.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class User(object):
|
||||
avatar_url: str
|
||||
displayname: str
|
||||
|
||||
|
||||
class Event(object):
|
||||
def __init__(self, event: dict):
|
||||
content = event["content"]
|
||||
|
||||
self.author = event["author"]
|
||||
self.body = content.get("body", "")
|
||||
self.event_id = event["event_id"]
|
||||
self.is_direct = content.get("is_direct", False)
|
||||
self.room_id = event["room_id"]
|
||||
self.sender = event["sender"]
|
||||
self.state_key = event.get("state_key", "")
|
||||
|
||||
rel = content.get("m.relates_to", {})
|
||||
|
||||
self.relates_to = rel.get("event_id")
|
||||
self.reltype = rel.get("rel_type")
|
||||
self.new_body = content.get("m.new_content", {}).get("body", "")
|
89
appservice/misc.py
Normal file
89
appservice/misc.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
import asyncio
|
||||
import json
|
||||
from dataclasses import fields
|
||||
from typing import Any
|
||||
|
||||
import urllib3
|
||||
|
||||
from errors import RequestError
|
||||
|
||||
|
||||
def dict_cls(dict_var: dict, cls: Any) -> Any:
|
||||
"""
|
||||
Create a dataclass from a dictionary.
|
||||
"""
|
||||
|
||||
field_names = set(f.name for f in fields(cls))
|
||||
filtered_dict = {k: v for k, v in dict_var.items() if k in field_names}
|
||||
|
||||
return cls(**filtered_dict)
|
||||
|
||||
|
||||
def log_except(fn):
|
||||
"""
|
||||
Log unhandled exceptions to a logger instead of `stderr`.
|
||||
"""
|
||||
|
||||
def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
return fn(self, *args, **kwargs)
|
||||
except Exception:
|
||||
self.logger.exception(f"Exception in '{fn.__name__}':")
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def wrap_async(fn):
|
||||
"""
|
||||
Call an asynchronous function from a synchronous one.
|
||||
"""
|
||||
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not self.loop:
|
||||
raise RuntimeError("loop is None.")
|
||||
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
fn(self, *args, **kwargs), loop=self.loop
|
||||
).result()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def request(fn):
|
||||
"""
|
||||
Either return json data or raise a `RequestError` if the request was
|
||||
unsuccessful.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
resp = fn(*args, **kwargs)
|
||||
except urllib3.exceptions.HTTPError as e:
|
||||
raise RequestError(None, f"Failed to connect: {e}") from None
|
||||
|
||||
if resp.status < 200 or resp.status >= 300:
|
||||
raise RequestError(
|
||||
resp.status,
|
||||
f"Failed to get response from '{resp.geturl()}':\n{resp.data}",
|
||||
)
|
||||
|
||||
return {} if resp.status == 204 else json.loads(resp.data)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def except_deleted(fn):
|
||||
"""
|
||||
Ignore the `RequestError` on 404s, the message might have been
|
||||
deleted by someone else already.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except RequestError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
|
||||
return wrapper
|
3
appservice/requirements.txt
Normal file
3
appservice/requirements.txt
Normal file
|
@ -0,0 +1,3 @@
|
|||
bottle==0.12.19
|
||||
urllib3==1.26.3
|
||||
websockets==8.1
|
Loading…
Reference in a new issue