chore: remove unneeded files
This commit is contained in:
parent
1f04e0a88f
commit
25bcf38492
7 changed files with 0 additions and 1478 deletions
|
@ -1,203 +0,0 @@
|
|||
import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from typing import Union
|
||||
|
||||
import bottle
|
||||
import urllib3
|
||||
|
||||
import matrix
|
||||
from cache import Cache
|
||||
from misc import log_except, request
|
||||
|
||||
|
||||
class AppService(bottle.Bottle):
|
||||
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
|
||||
super(AppService, self).__init__()
|
||||
|
||||
self.as_token = config["as_token"]
|
||||
self.hs_token = config["hs_token"]
|
||||
self.base_url = config["homeserver"]
|
||||
self.server_name = config["server_name"]
|
||||
self.user_id = f"@{config['user_id']}:{self.server_name}"
|
||||
self.http = http
|
||||
self.logger = logging.getLogger("appservice")
|
||||
|
||||
# Map events to functions.
|
||||
self.mapping = {
|
||||
"m.room.member": "on_member",
|
||||
"m.room.message": "on_message",
|
||||
"m.room.redaction": "on_redaction",
|
||||
}
|
||||
|
||||
# Add route for bottle.
|
||||
self.route(
|
||||
"/transactions/<transaction>",
|
||||
callback=self.receive_event,
|
||||
method="PUT",
|
||||
)
|
||||
|
||||
Cache.cache["m_rooms"] = {}
|
||||
|
||||
def handle_event(self, event: dict) -> None:
|
||||
event_type = event.get("type")
|
||||
|
||||
if event_type in (
|
||||
"m.room.member",
|
||||
"m.room.message",
|
||||
"m.room.redaction",
|
||||
):
|
||||
obj = matrix.Event(event)
|
||||
else:
|
||||
self.logger.info(f"Unknown event type: {event_type}")
|
||||
return
|
||||
|
||||
func = getattr(self, self.mapping[event_type], None)
|
||||
|
||||
if not func:
|
||||
self.logger.warning(
|
||||
f"Function '{func}' not defined, ignoring event."
|
||||
)
|
||||
return
|
||||
|
||||
# We don't catch exceptions here as the homeserver will re-send us
|
||||
# the event in case of a failure.
|
||||
func(obj)
|
||||
|
||||
@log_except
|
||||
def receive_event(self, transaction: str) -> dict:
|
||||
"""
|
||||
Verify the homeserver's token and handle events.
|
||||
"""
|
||||
|
||||
hs_token = bottle.request.query.getone("access_token")
|
||||
|
||||
if not hs_token:
|
||||
bottle.response.status = 401
|
||||
return {"errcode": "APPSERVICE_UNAUTHORIZED"}
|
||||
|
||||
if hs_token != self.hs_token:
|
||||
bottle.response.status = 403
|
||||
return {"errcode": "APPSERVICE_FORBIDDEN"}
|
||||
|
||||
events = bottle.request.json.get("events")
|
||||
|
||||
for event in events:
|
||||
self.handle_event(event)
|
||||
|
||||
return {}
|
||||
|
||||
def mxc_url(self, mxc: str) -> str:
|
||||
try:
|
||||
homeserver, media_id = mxc.replace("mxc://", "").split("/")
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
return (
|
||||
f"https://{self.server_name}/_matrix/media/r0/download/"
|
||||
f"{homeserver}/{media_id}"
|
||||
)
|
||||
|
||||
def join_room(self, room_id: str, mxid: str = "") -> None:
|
||||
self.send(
|
||||
"POST",
|
||||
f"/join/{room_id}",
|
||||
params={"user_id": mxid} if mxid else {},
|
||||
)
|
||||
|
||||
def redact(self, event_id: str, room_id: str, mxid: str = "") -> None:
|
||||
self.send(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/redact/{event_id}/{uuid.uuid4()}",
|
||||
params={"user_id": mxid} if mxid else {},
|
||||
)
|
||||
|
||||
def get_room_id(self, alias: str) -> str:
|
||||
with Cache.lock:
|
||||
room = Cache.cache["m_rooms"].get(alias)
|
||||
if room:
|
||||
return room
|
||||
|
||||
resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}")
|
||||
|
||||
room_id = resp["room_id"]
|
||||
|
||||
with Cache.lock:
|
||||
Cache.cache["m_rooms"][alias] = room_id
|
||||
|
||||
return room_id
|
||||
|
||||
def get_event(self, event_id: str, room_id: str) -> matrix.Event:
|
||||
resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}")
|
||||
|
||||
return matrix.Event(resp)
|
||||
|
||||
def upload(self, url: str) -> str:
|
||||
"""
|
||||
Upload a file to the homeserver and get the MXC url.
|
||||
"""
|
||||
|
||||
resp = self.http.request("GET", url)
|
||||
|
||||
resp = self.send(
|
||||
"POST",
|
||||
content=resp.data,
|
||||
content_type=resp.headers.get("Content-Type"),
|
||||
params={"filename": f"{uuid.uuid4()}"},
|
||||
endpoint="/_matrix/media/r0/upload",
|
||||
)
|
||||
|
||||
return resp["content_uri"]
|
||||
|
||||
def send_message(
|
||||
self,
|
||||
room_id: str,
|
||||
content: dict,
|
||||
mxid: str = "",
|
||||
) -> str:
|
||||
resp = self.send(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/send/m.room.message/{uuid.uuid4()}",
|
||||
content,
|
||||
{"user_id": mxid} if mxid else {},
|
||||
)
|
||||
|
||||
return resp["event_id"]
|
||||
|
||||
def send_typing(
|
||||
self, room_id: str, mxid: str = "", timeout: int = 8000
|
||||
) -> None:
|
||||
self.send(
|
||||
"PUT",
|
||||
f"/rooms/{room_id}/typing/{mxid}",
|
||||
{"typing": True, "timeout": timeout},
|
||||
{"user_id": mxid} if mxid else {},
|
||||
)
|
||||
|
||||
def send_invite(self, room_id: str, mxid: str) -> None:
|
||||
self.send("POST", f"/rooms/{room_id}/invite", {"user_id": mxid})
|
||||
|
||||
@request
|
||||
def send(
|
||||
self,
|
||||
method: str,
|
||||
path: str = "",
|
||||
content: Union[bytes, dict] = {},
|
||||
params: dict = {},
|
||||
content_type: str = "application/json",
|
||||
endpoint: str = "/_matrix/client/r0",
|
||||
) -> dict:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.as_token}",
|
||||
"Content-Type": content_type,
|
||||
}
|
||||
payload = json.dumps(content) if isinstance(content, dict) else content
|
||||
endpoint = (
|
||||
f"{self.base_url}{endpoint}{path}?"
|
||||
f"{urllib.parse.urlencode(params)}"
|
||||
)
|
||||
|
||||
return self.http.request(
|
||||
method, endpoint, body=payload, headers=headers
|
||||
)
|
|
@ -1,6 +0,0 @@
|
|||
import threading
|
||||
|
||||
|
||||
class Cache:
|
||||
cache = {}
|
||||
lock = threading.Lock()
|
|
@ -1,120 +0,0 @@
|
|||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import List
|
||||
|
||||
|
||||
class DataBase:
|
||||
def __init__(self, db_file) -> None:
|
||||
self.create(db_file)
|
||||
|
||||
# The database is accessed via multiple threads.
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def create(self, db_file) -> None:
|
||||
"""
|
||||
Create a database with the relevant tables if it doesn't already exist.
|
||||
"""
|
||||
|
||||
exists = os.path.exists(db_file)
|
||||
|
||||
self.conn = sqlite3.connect(db_file, check_same_thread=False)
|
||||
self.conn.row_factory = self.dict_factory
|
||||
|
||||
self.cur = self.conn.cursor()
|
||||
|
||||
if exists:
|
||||
return
|
||||
|
||||
self.cur.execute(
|
||||
"CREATE TABLE bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);"
|
||||
)
|
||||
|
||||
self.cur.execute(
|
||||
"CREATE TABLE users(mxid TEXT PRIMARY KEY, "
|
||||
"avatar_url TEXT, username TEXT);"
|
||||
)
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def dict_factory(self, cursor, row):
|
||||
"""
|
||||
https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory
|
||||
"""
|
||||
|
||||
d = {}
|
||||
for idx, col in enumerate(cursor.description):
|
||||
d[col[0]] = row[idx]
|
||||
return d
|
||||
|
||||
def add_room(self, room_id: str, channel_id: str) -> None:
|
||||
"""
|
||||
Add a bridged room to the database.
|
||||
"""
|
||||
|
||||
with self.lock:
|
||||
self.cur.execute(
|
||||
"INSERT INTO bridge (room_id, channel_id) VALUES (?, ?)",
|
||||
[room_id, channel_id],
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def add_user(self, mxid: str) -> None:
|
||||
with self.lock:
|
||||
self.cur.execute("INSERT INTO users (mxid) VALUES (?)", [mxid])
|
||||
self.conn.commit()
|
||||
|
||||
def add_avatar(self, avatar_url: str, mxid: str) -> None:
|
||||
with self.lock:
|
||||
self.cur.execute(
|
||||
"UPDATE users SET avatar_url = (?) WHERE mxid = (?)",
|
||||
[avatar_url, mxid],
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def add_username(self, username: str, mxid: str) -> None:
|
||||
with self.lock:
|
||||
self.cur.execute(
|
||||
"UPDATE users SET username = (?) WHERE mxid = (?)",
|
||||
[username, mxid],
|
||||
)
|
||||
self.conn.commit()
|
||||
|
||||
def get_channel(self, room_id: str) -> str:
|
||||
"""
|
||||
Get the corresponding channel ID for a given room ID.
|
||||
"""
|
||||
|
||||
with self.lock:
|
||||
self.cur.execute(
|
||||
"SELECT channel_id FROM bridge WHERE room_id = ?", [room_id]
|
||||
)
|
||||
|
||||
room = self.cur.fetchone()
|
||||
|
||||
# Return an empty string if the channel is not bridged.
|
||||
return "" if not room else room["channel_id"]
|
||||
|
||||
def list_channels(self) -> List[str]:
|
||||
"""
|
||||
Get a list of all the bridged channels.
|
||||
"""
|
||||
|
||||
with self.lock:
|
||||
self.cur.execute("SELECT channel_id FROM bridge")
|
||||
|
||||
channels = self.cur.fetchall()
|
||||
|
||||
return [channel["channel_id"] for channel in channels]
|
||||
|
||||
def fetch_user(self, mxid: str) -> dict:
|
||||
"""
|
||||
Fetch the profile for a bridged user.
|
||||
"""
|
||||
|
||||
with self.lock:
|
||||
self.cur.execute("SELECT * FROM users where mxid = ?", [mxid])
|
||||
|
||||
user = self.cur.fetchone()
|
||||
|
||||
return {} if not user else user
|
|
@ -1,5 +0,0 @@
|
|||
class RequestError(Exception):
|
||||
def __init__(self, status: int, *args):
|
||||
super().__init__(*args)
|
||||
|
||||
self.status = status
|
|
@ -1,260 +0,0 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
from typing import Dict, List
|
||||
|
||||
import urllib3
|
||||
import websockets
|
||||
|
||||
import discord
|
||||
from misc import dict_cls, log_except, request
|
||||
|
||||
|
||||
class Gateway:
|
||||
def __init__(self, http: urllib3.PoolManager, token: str):
|
||||
self.http = http
|
||||
self.token = token
|
||||
self.logger = logging.getLogger("discord")
|
||||
self.Payloads = discord.Payloads(self.token)
|
||||
self.websocket = None
|
||||
|
||||
@log_except
|
||||
async def run(self) -> None:
|
||||
self.heartbeat_task: asyncio.Future = None
|
||||
self.resume = False
|
||||
|
||||
gateway_url = self.get_gateway_url()
|
||||
|
||||
while True:
|
||||
try:
|
||||
await self.gateway_handler(gateway_url)
|
||||
except (
|
||||
websockets.ConnectionClosedError,
|
||||
websockets.InvalidMessage,
|
||||
):
|
||||
self.logger.exception("Connection lost, reconnecting.")
|
||||
|
||||
# Stop sending heartbeats until we reconnect.
|
||||
if self.heartbeat_task and not self.heartbeat_task.cancelled():
|
||||
self.heartbeat_task.cancel()
|
||||
|
||||
def get_gateway_url(self) -> str:
|
||||
resp = self.send("GET", "/gateway")
|
||||
|
||||
return resp["url"]
|
||||
|
||||
async def heartbeat_handler(self, interval_ms: int) -> None:
|
||||
while True:
|
||||
await asyncio.sleep(interval_ms / 1000)
|
||||
await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT()))
|
||||
|
||||
async def handle_resp(self, data: dict) -> None:
|
||||
data_dict = data["d"]
|
||||
|
||||
opcode = data["op"]
|
||||
|
||||
seq = data["s"]
|
||||
|
||||
if seq:
|
||||
self.Payloads.seq = seq
|
||||
|
||||
if opcode == discord.GatewayOpCodes.DISPATCH:
|
||||
otype = data["t"]
|
||||
|
||||
if otype == "READY":
|
||||
self.Payloads.session = data_dict["session_id"]
|
||||
|
||||
self.logger.info("READY")
|
||||
else:
|
||||
self.handle_otype(data_dict, otype)
|
||||
elif opcode == discord.GatewayOpCodes.HELLO:
|
||||
heartbeat_interval = data_dict.get("heartbeat_interval")
|
||||
|
||||
self.logger.info(f"Heartbeat Interval: {heartbeat_interval}")
|
||||
|
||||
# Send periodic hearbeats to gateway.
|
||||
self.heartbeat_task = asyncio.ensure_future(
|
||||
self.heartbeat_handler(heartbeat_interval)
|
||||
)
|
||||
|
||||
await self.websocket.send(
|
||||
json.dumps(
|
||||
self.Payloads.RESUME()
|
||||
if self.resume
|
||||
else self.Payloads.IDENTIFY()
|
||||
)
|
||||
)
|
||||
elif opcode == discord.GatewayOpCodes.RECONNECT:
|
||||
self.logger.info("Received RECONNECT.")
|
||||
|
||||
self.resume = True
|
||||
await self.websocket.close()
|
||||
elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
|
||||
self.logger.info("Received INVALID_SESSION.")
|
||||
|
||||
self.resume = False
|
||||
await self.websocket.close()
|
||||
elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
|
||||
# NOP
|
||||
pass
|
||||
else:
|
||||
self.logger.info(
|
||||
"Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}"
|
||||
)
|
||||
|
||||
def handle_otype(self, data: dict, otype: str) -> None:
|
||||
if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"):
|
||||
obj = discord.Message(data)
|
||||
elif otype == "TYPING_START":
|
||||
obj = dict_cls(data, discord.Typing)
|
||||
elif otype == "GUILD_CREATE":
|
||||
obj = discord.Guild(data)
|
||||
elif otype == "GUILD_MEMBER_UPDATE":
|
||||
obj = discord.GuildMemberUpdate(data)
|
||||
elif otype == "GUILD_EMOJIS_UPDATE":
|
||||
obj = discord.GuildEmojisUpdate(data)
|
||||
else:
|
||||
return
|
||||
|
||||
func = getattr(self, f"on_{otype.lower()}", None)
|
||||
|
||||
if not func:
|
||||
self.logger.warning(
|
||||
f"Function '{func}' not defined, ignoring message."
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
func(obj)
|
||||
except Exception:
|
||||
self.logger.exception(f"Ignoring exception in '{func.__name__}':")
|
||||
|
||||
async def gateway_handler(self, gateway_url: str) -> None:
|
||||
async with websockets.connect(
|
||||
f"{gateway_url}/?v=8&encoding=json"
|
||||
) as websocket:
|
||||
self.websocket = websocket
|
||||
|
||||
async for message in websocket:
|
||||
await self.handle_resp(json.loads(message))
|
||||
|
||||
def get_channel(self, channel_id: str) -> discord.Channel:
|
||||
"""
|
||||
Get the channel for a given channel ID.
|
||||
"""
|
||||
|
||||
resp = self.send("GET", f"/channels/{channel_id}")
|
||||
|
||||
return dict_cls(resp, discord.Channel)
|
||||
|
||||
def get_channels(self, guild_id: str) -> Dict[str, discord.Channel]:
|
||||
"""
|
||||
Get all channels for a given guild ID.
|
||||
"""
|
||||
|
||||
resp = self.send("GET", f"/guilds/{guild_id}/channels")
|
||||
|
||||
return {
|
||||
channel["id"]: dict_cls(channel, discord.Channel)
|
||||
for channel in resp
|
||||
}
|
||||
|
||||
def get_emotes(self, guild_id: str) -> List[discord.Emote]:
|
||||
"""
|
||||
Get all the emotes for a given guild.
|
||||
"""
|
||||
|
||||
resp = self.send("GET", f"/guilds/{guild_id}/emojis")
|
||||
|
||||
return [dict_cls(emote, discord.Emote) for emote in resp]
|
||||
|
||||
def get_members(self, guild_id: str) -> List[discord.User]:
|
||||
"""
|
||||
Get all the members for a given guild.
|
||||
"""
|
||||
|
||||
resp = self.send(
|
||||
"GET", f"/guilds/{guild_id}/members", params={"limit": 1000}
|
||||
)
|
||||
|
||||
return [discord.User(member["user"]) for member in resp]
|
||||
|
||||
def create_webhook(self, channel_id: str, name: str) -> discord.Webhook:
|
||||
"""
|
||||
Create a webhook with the specified name in a given channel.
|
||||
"""
|
||||
|
||||
resp = self.send(
|
||||
"POST", f"/channels/{channel_id}/webhooks", {"name": name}
|
||||
)
|
||||
|
||||
return dict_cls(resp, discord.Webhook)
|
||||
|
||||
def edit_webhook(
|
||||
self, content: str, message_id: str, webhook: discord.Webhook
|
||||
) -> None:
|
||||
self.send(
|
||||
"PATCH",
|
||||
f"/webhooks/{webhook.id}/{webhook.token}/messages/"
|
||||
f"{message_id}",
|
||||
{"content": content},
|
||||
)
|
||||
|
||||
def delete_webhook(
|
||||
self, message_id: str, webhook: discord.Webhook
|
||||
) -> None:
|
||||
self.send(
|
||||
"DELETE",
|
||||
f"/webhooks/{webhook.id}/{webhook.token}/messages/"
|
||||
f"{message_id}",
|
||||
)
|
||||
|
||||
def send_webhook(
|
||||
self,
|
||||
webhook: discord.Webhook,
|
||||
avatar_url: str,
|
||||
content: str,
|
||||
username: str,
|
||||
) -> discord.Message:
|
||||
payload = {
|
||||
"avatar_url": avatar_url,
|
||||
"content": content,
|
||||
"username": username,
|
||||
# Disable 'everyone' and 'role' mentions.
|
||||
"allowed_mentions": {"parse": ["users"]},
|
||||
}
|
||||
|
||||
resp = self.send(
|
||||
"POST",
|
||||
f"/webhooks/{webhook.id}/{webhook.token}",
|
||||
payload,
|
||||
{"wait": True},
|
||||
)
|
||||
|
||||
return discord.Message(resp)
|
||||
|
||||
def send_message(self, message: str, channel_id: str) -> None:
|
||||
self.send(
|
||||
"POST", f"/channels/{channel_id}/messages", {"content": message}
|
||||
)
|
||||
|
||||
@request
|
||||
def send(
|
||||
self, method: str, path: str, content: dict = {}, params: dict = {}
|
||||
) -> dict:
|
||||
endpoint = (
|
||||
f"https://discord.com/api/v8{path}?"
|
||||
f"{urllib.parse.urlencode(params)}"
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bot {self.token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# 'body' being an empty dict breaks "GET" requests.
|
||||
payload = json.dumps(content) if content else None
|
||||
|
||||
return self.http.request(
|
||||
method, endpoint, body=payload, headers=headers
|
||||
)
|
|
@ -1,800 +0,0 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
import urllib.parse
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import markdown
|
||||
import urllib3
|
||||
|
||||
import discord
|
||||
import matrix
|
||||
from appservice import AppService
|
||||
from cache import Cache
|
||||
from db import DataBase
|
||||
from errors import RequestError
|
||||
from gateway import Gateway
|
||||
from misc import dict_cls, except_deleted, hash_str
|
||||
|
||||
|
||||
class MatrixClient(AppService):
|
||||
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
|
||||
super().__init__(config, http)
|
||||
|
||||
self.db = DataBase(config["database"])
|
||||
self.discord = DiscordClient(self, config, http)
|
||||
self.format = "_discord_" # "{@,#}_discord_1234:localhost"
|
||||
self.id_regex = "[0-9]+" # Snowflakes may have variable length
|
||||
|
||||
# TODO Find a cleaner way to use these keys.
|
||||
for k in ("m_emotes", "m_members", "m_messages"):
|
||||
Cache.cache[k] = {}
|
||||
|
||||
def handle_bridge(self, message: matrix.Event) -> None:
|
||||
# Ignore events that aren't for us.
|
||||
if message.sender.split(":")[
|
||||
-1
|
||||
] != self.server_name or not message.body.startswith("!bridge"):
|
||||
return
|
||||
|
||||
# Get the channel ID.
|
||||
try:
|
||||
channel = message.body.split()[1]
|
||||
except IndexError:
|
||||
return
|
||||
|
||||
# Check if the given channel is valid.
|
||||
try:
|
||||
channel = self.discord.get_channel(channel)
|
||||
except RequestError as e:
|
||||
# The channel can be invalid or we may not have permissions.
|
||||
self.logger.warning(f"Failed to fetch channel {channel}: {e}")
|
||||
return
|
||||
|
||||
if (
|
||||
channel.type != discord.ChannelType.GUILD_TEXT
|
||||
or channel.id in self.db.list_channels()
|
||||
):
|
||||
return
|
||||
|
||||
self.logger.info(f"Creating bridged room for channel {channel.id}.")
|
||||
|
||||
self.create_room(channel, message.sender)
|
||||
|
||||
def on_member(self, event: matrix.Event) -> None:
|
||||
with Cache.lock:
|
||||
# Just lazily clear the whole member cache on
|
||||
# membership update events.
|
||||
if event.room_id in Cache.cache["m_members"]:
|
||||
self.logger.info(
|
||||
f"Clearing member cache for room '{event.room_id}'."
|
||||
)
|
||||
del Cache.cache["m_members"][event.room_id]
|
||||
|
||||
if (
|
||||
event.sender.split(":")[-1] != self.server_name
|
||||
or event.state_key != self.user_id
|
||||
or not event.is_direct
|
||||
):
|
||||
return
|
||||
|
||||
# Join the direct message room.
|
||||
self.logger.info(f"Joining direct message room '{event.room_id}'.")
|
||||
self.join_room(event.room_id)
|
||||
|
||||
def on_message(self, message: matrix.Event) -> None:
|
||||
if (
|
||||
message.sender.startswith((f"@{self.format}", self.user_id))
|
||||
or not message.body
|
||||
):
|
||||
return
|
||||
|
||||
# Handle bridging commands.
|
||||
self.handle_bridge(message)
|
||||
|
||||
channel_id = self.db.get_channel(message.room_id)
|
||||
|
||||
if not channel_id:
|
||||
return
|
||||
|
||||
author = self.get_members(message.room_id)[message.sender]
|
||||
|
||||
webhook = self.discord.get_webhook(
|
||||
channel_id, self.discord.webhook_name
|
||||
)
|
||||
|
||||
if message.relates_to and message.reltype == "m.replace":
|
||||
with Cache.lock:
|
||||
message_id = Cache.cache["m_messages"].get(message.relates_to)
|
||||
|
||||
# TODO validate if the original author sent the edit.
|
||||
|
||||
if not message_id or not message.new_body:
|
||||
return
|
||||
|
||||
message.new_body = self.process_message(message)
|
||||
|
||||
except_deleted(self.discord.edit_webhook)(
|
||||
message.new_body, message_id, webhook
|
||||
)
|
||||
else:
|
||||
message.body = (
|
||||
f"`{message.body}`: {self.mxc_url(message.attachment)}"
|
||||
if message.attachment
|
||||
else self.process_message(message)
|
||||
)
|
||||
|
||||
message_id = self.discord.send_webhook(
|
||||
webhook,
|
||||
self.mxc_url(author.avatar_url) if author.avatar_url else None,
|
||||
message.body,
|
||||
author.display_name if author.display_name else message.sender,
|
||||
).id
|
||||
|
||||
with Cache.lock:
|
||||
Cache.cache["m_messages"][message.id] = message_id
|
||||
|
||||
def on_redaction(self, event: matrix.Event) -> None:
|
||||
with Cache.lock:
|
||||
message_id = Cache.cache["m_messages"].get(event.redacts)
|
||||
|
||||
if not message_id:
|
||||
return
|
||||
|
||||
webhook = self.discord.get_webhook(
|
||||
self.db.get_channel(event.room_id), self.discord.webhook_name
|
||||
)
|
||||
|
||||
except_deleted(self.discord.delete_webhook)(message_id, webhook)
|
||||
|
||||
with Cache.lock:
|
||||
del Cache.cache["m_messages"][event.redacts]
|
||||
|
||||
def get_members(self, room_id: str) -> Dict[str, matrix.User]:
|
||||
with Cache.lock:
|
||||
cached = Cache.cache["m_members"].get(room_id)
|
||||
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
resp = self.send("GET", f"/rooms/{room_id}/joined_members")
|
||||
|
||||
joined = resp["joined"]
|
||||
|
||||
for k, v in joined.items():
|
||||
joined[k] = dict_cls(v, matrix.User)
|
||||
|
||||
with Cache.lock:
|
||||
Cache.cache["m_members"][room_id] = joined
|
||||
|
||||
return joined
|
||||
|
||||
def create_room(self, channel: discord.Channel, sender: str) -> None:
|
||||
"""
|
||||
Create a bridged room and invite the person who invoked the command.
|
||||
"""
|
||||
|
||||
content = {
|
||||
"room_alias_name": f"{self.format}{channel.id}",
|
||||
"name": channel.name,
|
||||
"topic": channel.topic if channel.topic else channel.name,
|
||||
"visibility": "private",
|
||||
"invite": [sender],
|
||||
"creation_content": {"m.federate": True},
|
||||
"initial_state": [
|
||||
{
|
||||
"type": "m.room.join_rules",
|
||||
"content": {"join_rule": "public"},
|
||||
},
|
||||
{
|
||||
"type": "m.room.history_visibility",
|
||||
"content": {"history_visibility": "shared"},
|
||||
},
|
||||
],
|
||||
"power_level_content_override": {
|
||||
"users": {sender: 100, self.user_id: 100}
|
||||
},
|
||||
}
|
||||
|
||||
resp = self.send("POST", "/createRoom", content)
|
||||
|
||||
self.db.add_room(resp["room_id"], channel.id)
|
||||
|
||||
def create_message_event(
|
||||
self,
|
||||
message: str,
|
||||
emotes: dict,
|
||||
edit: str = "",
|
||||
reference: discord.Message = None,
|
||||
) -> dict:
|
||||
content = {
|
||||
"body": message,
|
||||
"msgtype": "m.text",
|
||||
}
|
||||
|
||||
fmt = self.get_fmt(message, emotes)
|
||||
|
||||
if fmt != message:
|
||||
content = {
|
||||
**content,
|
||||
"format": "org.matrix.custom.html",
|
||||
"formatted_body": fmt,
|
||||
}
|
||||
|
||||
ref_id = None
|
||||
|
||||
if reference:
|
||||
# Reply to a Discord message.
|
||||
with Cache.lock:
|
||||
ref_id = Cache.cache["d_messages"].get(reference.id)
|
||||
|
||||
# Reply to a Matrix message. (maybe)
|
||||
if not ref_id:
|
||||
with Cache.lock:
|
||||
ref_id = [
|
||||
k
|
||||
for k, v in Cache.cache["m_messages"].items()
|
||||
if v == reference.id
|
||||
]
|
||||
ref_id = next(iter(ref_id), "")
|
||||
|
||||
if ref_id:
|
||||
event = except_deleted(self.get_event)(
|
||||
ref_id,
|
||||
self.get_room_id(self.discord.matrixify(reference.channel_id)),
|
||||
)
|
||||
if event:
|
||||
# Content with the reply fallbacks stripped.
|
||||
tmp = ""
|
||||
# We don't want to strip lines starting with "> " after
|
||||
# encountering a regular line, so we use this variable.
|
||||
got_fallback = True
|
||||
for line in event.body.split("\n"):
|
||||
if not line.startswith("> "):
|
||||
got_fallback = False
|
||||
if not got_fallback:
|
||||
tmp += line
|
||||
|
||||
event.body = tmp
|
||||
event.formatted_body = (
|
||||
# re.DOTALL allows the match to span newlines.
|
||||
re.sub(
|
||||
"<mx-reply.+?</mx-reply>",
|
||||
"",
|
||||
event.formatted_body,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
if event.formatted_body
|
||||
else event.body
|
||||
)
|
||||
|
||||
content = {
|
||||
**content,
|
||||
"body": (
|
||||
f"> <{event.sender}> {event.body}\n{content['body']}"
|
||||
),
|
||||
"m.relates_to": {"m.in_reply_to": {"event_id": event.id}},
|
||||
"format": "org.matrix.custom.html",
|
||||
"formatted_body": f"""<mx-reply><blockquote>\
|
||||
<a href="https://matrix.to/#/{event.room_id}/{event.id}">In reply to</a>\
|
||||
<a href="https://matrix.to/#/{event.sender}">{event.sender}</a>\
|
||||
<br />{event.formatted_body if event.formatted_body else event.body}\
|
||||
</blockquote></mx-reply>\
|
||||
{content.get("formatted_body", content['body'])}""",
|
||||
}
|
||||
|
||||
if edit:
|
||||
content = {
|
||||
**content,
|
||||
"body": f" * {content['body']}",
|
||||
"formatted_body": f" * {content.get('formatted_body', content['body'])}",
|
||||
"m.relates_to": {"event_id": edit, "rel_type": "m.replace"},
|
||||
"m.new_content": {**content},
|
||||
}
|
||||
|
||||
return content
|
||||
|
||||
def get_fmt(self, message: str, emotes: dict) -> str:
|
||||
message = (
|
||||
markdown.markdown(message)
|
||||
.replace("<p>", "")
|
||||
.replace("</p>", "")
|
||||
.replace("\n", "<br />")
|
||||
)
|
||||
|
||||
# Upload emotes in multiple threads so that we don't
|
||||
# block the Discord bot for too long.
|
||||
upload_threads = [
|
||||
threading.Thread(
|
||||
target=self.upload_emote, args=(emote, emotes[emote])
|
||||
)
|
||||
for emote in emotes
|
||||
]
|
||||
|
||||
# Acquire the lock before starting the threads to avoid resource
|
||||
# contention by tens of threads at once.
|
||||
with Cache.lock:
|
||||
for thread in upload_threads:
|
||||
thread.start()
|
||||
for thread in upload_threads:
|
||||
thread.join()
|
||||
|
||||
with Cache.lock:
|
||||
for emote in emotes:
|
||||
emote_ = Cache.cache["m_emotes"].get(emote)
|
||||
|
||||
if emote_:
|
||||
emote = f":{emote}:"
|
||||
message = message.replace(
|
||||
emote,
|
||||
f"""<img alt=\"{emote}\" title=\"{emote}\" \
|
||||
height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
|
||||
)
|
||||
|
||||
return message
|
||||
|
||||
def mention_regex(self, encode: bool, id_as_group: bool) -> str:
|
||||
mention = "@"
|
||||
colon = ":"
|
||||
snowflake = self.id_regex
|
||||
|
||||
if encode:
|
||||
mention = urllib.parse.quote(mention)
|
||||
colon = urllib.parse.quote(colon)
|
||||
|
||||
if id_as_group:
|
||||
snowflake = f"({snowflake})"
|
||||
|
||||
hashed = f"(?:-{snowflake})?"
|
||||
|
||||
return f"{mention}{self.format}{snowflake}{hashed}{colon}{re.escape(self.server_name)}"
|
||||
|
||||
def process_message(self, event: matrix.Event) -> str:
|
||||
message = event.new_body if event.new_body else event.body
|
||||
|
||||
emotes = re.findall(r":(\w*):", message)
|
||||
|
||||
mentions = list(
|
||||
re.finditer(
|
||||
self.mention_regex(encode=False, id_as_group=True),
|
||||
event.formatted_body,
|
||||
)
|
||||
)
|
||||
# For clients that properly encode mentions.
|
||||
# 'https://matrix.to/#/%40_discord_...%3Adomain.tld'
|
||||
mentions.extend(
|
||||
re.finditer(
|
||||
self.mention_regex(encode=True, id_as_group=True),
|
||||
event.formatted_body,
|
||||
)
|
||||
)
|
||||
|
||||
with Cache.lock:
|
||||
for emote in set(emotes):
|
||||
emote_ = Cache.cache["d_emotes"].get(emote)
|
||||
if emote_:
|
||||
message = message.replace(f":{emote}:", emote_)
|
||||
|
||||
for mention in set(mentions):
|
||||
# Unquote just in-case we matched an encoded username.
|
||||
username = self.db.fetch_user(
|
||||
urllib.parse.unquote(mention.group(0))
|
||||
).get("username")
|
||||
if username:
|
||||
if mention.group(2):
|
||||
# Replace mention with plain text for hashed users (webhooks)
|
||||
message = message.replace(mention.group(0), f"@{username}")
|
||||
else:
|
||||
# Replace the 'mention' so that the user is tagged
|
||||
# in the case of replies aswell.
|
||||
# '> <@_discord_1234:localhost> Message'
|
||||
for replace in (mention.group(0), username):
|
||||
message = message.replace(
|
||||
replace, f"<@{mention.group(1)}>"
|
||||
)
|
||||
|
||||
# We trim the message later as emotes take up extra characters too.
|
||||
return message[: discord.MESSAGE_LIMIT]
|
||||
|
||||
def upload_emote(self, emote_name: str, emote_id: str) -> None:
|
||||
# There won't be a race condition here, since only a unique
|
||||
# set of emotes are uploaded at a time.
|
||||
if emote_name in Cache.cache["m_emotes"]:
|
||||
return
|
||||
|
||||
emote_url = f"{discord.CDN_URL}/emojis/{emote_id}"
|
||||
|
||||
# We don't want the message to be dropped entirely if an emote
|
||||
# fails to upload for some reason.
|
||||
try:
|
||||
# TODO This is not thread safe, but we're protected by the GIL.
|
||||
Cache.cache["m_emotes"][emote_name] = self.upload(emote_url)
|
||||
except RequestError as e:
|
||||
self.logger.warning(f"Failed to upload emote {emote_id}: {e}")
|
||||
|
||||
def register(self, mxid: str) -> None:
|
||||
"""
|
||||
Register a dummy user on the homeserver.
|
||||
"""
|
||||
|
||||
content = {
|
||||
"type": "m.login.application_service",
|
||||
# "@test:localhost" -> "test" (Can't register with a full mxid.)
|
||||
"username": mxid[1:].split(":")[0],
|
||||
}
|
||||
|
||||
resp = self.send("POST", "/register", content)
|
||||
|
||||
self.db.add_user(resp["user_id"])
|
||||
|
||||
def set_avatar(self, avatar_url: str, mxid: str) -> None:
|
||||
avatar_uri = self.upload(avatar_url)
|
||||
|
||||
self.send(
|
||||
"PUT",
|
||||
f"/profile/{mxid}/avatar_url",
|
||||
{"avatar_url": avatar_uri},
|
||||
params={"user_id": mxid},
|
||||
)
|
||||
|
||||
self.db.add_avatar(avatar_url, mxid)
|
||||
|
||||
def set_nick(self, username: str, mxid: str) -> None:
|
||||
self.send(
|
||||
"PUT",
|
||||
f"/profile/{mxid}/displayname",
|
||||
{"displayname": username},
|
||||
params={"user_id": mxid},
|
||||
)
|
||||
|
||||
self.db.add_username(username, mxid)
|
||||
|
||||
|
||||
class DiscordClient(Gateway):
|
||||
def __init__(
|
||||
self, appservice: MatrixClient, config: dict, http: urllib3.PoolManager
|
||||
) -> None:
|
||||
super().__init__(http, config["discord_token"])
|
||||
|
||||
self.app = appservice
|
||||
self.webhook_name = "matrix_bridge"
|
||||
|
||||
# TODO Find a cleaner way to use these keys.
|
||||
for k in ("d_emotes", "d_messages", "d_webhooks"):
|
||||
Cache.cache[k] = {}
|
||||
|
||||
def to_return(self, message: discord.Message) -> bool:
|
||||
with Cache.lock:
|
||||
hook_ids = [hook.id for hook in Cache.cache["d_webhooks"].values()]
|
||||
|
||||
return (
|
||||
message.channel_id not in self.app.db.list_channels()
|
||||
or not message.author # Embeds can be weird sometimes.
|
||||
or message.webhook_id in hook_ids
|
||||
)
|
||||
|
||||
def matrixify(self, id: str, user: bool = False, hashed: str = "") -> str:
|
||||
return (
|
||||
f"{'@' if user else '#'}{self.app.format}"
|
||||
f"{id}{'-' + hashed if hashed else ''}:"
|
||||
f"{self.app.server_name}"
|
||||
)
|
||||
|
||||
def sync_profile(self, user: discord.User, hashed: str = "") -> None:
|
||||
"""
|
||||
Sync the avatar and username for a puppeted user.
|
||||
"""
|
||||
|
||||
mxid = self.matrixify(user.id, user=True, hashed=hashed)
|
||||
|
||||
profile = self.app.db.fetch_user(mxid)
|
||||
|
||||
# User doesn't exist.
|
||||
if not profile:
|
||||
return
|
||||
|
||||
username = f"{user.username}#{user.discriminator}"
|
||||
|
||||
if user.avatar_url != profile["avatar_url"]:
|
||||
self.logger.info(f"Updating avatar for Discord user '{user.id}'")
|
||||
self.app.set_avatar(user.avatar_url, mxid)
|
||||
if username != profile["username"]:
|
||||
self.logger.info(f"Updating username for Discord user '{user.id}'")
|
||||
self.app.set_nick(username, mxid)
|
||||
|
||||
def wrap(self, message: discord.Message) -> Tuple[str, str]:
|
||||
"""
|
||||
Get the room ID and the puppet's mxid for a given channel ID and a
|
||||
Discord user.
|
||||
"""
|
||||
|
||||
hashed = ""
|
||||
if message.webhook_id and message.webhook_id != message.application_id:
|
||||
hashed = str(hash_str(message.author.username))
|
||||
|
||||
mxid = self.matrixify(message.author.id, user=True, hashed=hashed)
|
||||
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
|
||||
|
||||
if not self.app.db.fetch_user(mxid):
|
||||
self.logger.info(
|
||||
f"Creating dummy user for Discord user {message.author.id}."
|
||||
)
|
||||
self.app.register(mxid)
|
||||
|
||||
self.app.set_nick(
|
||||
f"{message.author.username}#"
|
||||
f"{message.author.discriminator}",
|
||||
mxid,
|
||||
)
|
||||
|
||||
if message.author.avatar_url:
|
||||
self.app.set_avatar(message.author.avatar_url, mxid)
|
||||
|
||||
if mxid not in self.app.get_members(room_id):
|
||||
self.logger.info(f"Inviting user '{mxid}' to room '{room_id}'.")
|
||||
|
||||
self.app.send_invite(room_id, mxid)
|
||||
self.app.join_room(room_id, mxid)
|
||||
|
||||
if message.webhook_id:
|
||||
# Sync webhooks here as they can't be accessed like guild members.
|
||||
self.sync_profile(message.author, hashed=hashed)
|
||||
|
||||
return mxid, room_id
|
||||
|
||||
def cache_emotes(self, emotes: List[discord.Emote]):
|
||||
# TODO maybe "namespace" emotes by guild in the cache ?
|
||||
with Cache.lock:
|
||||
for emote in emotes:
|
||||
Cache.cache["d_emotes"][emote.name] = (
|
||||
f"<{'a' if emote.animated else ''}:"
|
||||
f"{emote.name}:{emote.id}>"
|
||||
)
|
||||
|
||||
def on_guild_create(self, guild: discord.Guild) -> None:
|
||||
for member in guild.members:
|
||||
self.sync_profile(member)
|
||||
|
||||
self.cache_emotes(guild.emojis)
|
||||
|
||||
def on_guild_emojis_update(
|
||||
self, update: discord.GuildEmojisUpdate
|
||||
) -> None:
|
||||
self.cache_emotes(update.emojis)
|
||||
|
||||
def on_guild_member_update(
|
||||
self, update: discord.GuildMemberUpdate
|
||||
) -> None:
|
||||
self.sync_profile(update.user)
|
||||
|
||||
def on_message_create(self, message: discord.Message) -> None:
|
||||
if self.to_return(message):
|
||||
return
|
||||
|
||||
mxid, room_id = self.wrap(message)
|
||||
|
||||
content_, emotes = self.process_message(message)
|
||||
|
||||
content = self.app.create_message_event(
|
||||
content_, emotes, reference=message.referenced_message
|
||||
)
|
||||
|
||||
with Cache.lock:
|
||||
Cache.cache["d_messages"][message.id] = self.app.send_message(
|
||||
room_id, content, mxid
|
||||
)
|
||||
|
||||
def on_message_delete(self, message: discord.Message) -> None:
|
||||
with Cache.lock:
|
||||
event_id = Cache.cache["d_messages"].get(message.id)
|
||||
|
||||
if not event_id:
|
||||
return
|
||||
|
||||
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
|
||||
event = except_deleted(self.app.get_event)(event_id, room_id)
|
||||
|
||||
if event:
|
||||
self.app.redact(event.id, event.room_id, event.sender)
|
||||
|
||||
with Cache.lock:
|
||||
del Cache.cache["d_messages"][message.id]
|
||||
|
||||
def on_message_update(self, message: discord.Message) -> None:
|
||||
if self.to_return(message):
|
||||
return
|
||||
|
||||
with Cache.lock:
|
||||
event_id = Cache.cache["d_messages"].get(message.id)
|
||||
|
||||
if not event_id:
|
||||
return
|
||||
|
||||
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
|
||||
mxid = self.matrixify(message.author.id, user=True)
|
||||
|
||||
# It is possible that a webhook edit's it's own old message
|
||||
# after changing it's name, hence we generate a new mxid from
|
||||
# the hashed username, but that mxid hasn't been registered before,
|
||||
# so the request fails with:
|
||||
# M_FORBIDDEN: Application service has not registered this user
|
||||
if not self.app.db.fetch_user(mxid):
|
||||
return
|
||||
|
||||
content_, emotes = self.process_message(message)
|
||||
|
||||
content = self.app.create_message_event(
|
||||
content_, emotes, edit=event_id
|
||||
)
|
||||
|
||||
self.app.send_message(room_id, content, mxid)
|
||||
|
||||
def on_typing_start(self, typing: discord.Typing) -> None:
|
||||
if typing.channel_id not in self.app.db.list_channels():
|
||||
return
|
||||
|
||||
mxid = self.matrixify(typing.user_id, user=True)
|
||||
room_id = self.app.get_room_id(self.matrixify(typing.channel_id))
|
||||
|
||||
if mxid not in self.app.get_members(room_id):
|
||||
return
|
||||
|
||||
self.app.send_typing(room_id, mxid)
|
||||
|
||||
def get_webhook(self, channel_id: str, name: str) -> discord.Webhook:
|
||||
"""
|
||||
Get the webhook object for the first webhook that matches the specified
|
||||
name in a given channel, create the webhook if it doesn't exist.
|
||||
"""
|
||||
|
||||
# Check the cache first.
|
||||
with Cache.lock:
|
||||
webhook = Cache.cache["d_webhooks"].get(channel_id)
|
||||
|
||||
if webhook:
|
||||
return webhook
|
||||
|
||||
webhooks = self.send("GET", f"/channels/{channel_id}/webhooks")
|
||||
webhook = next(
|
||||
(
|
||||
dict_cls(webhook, discord.Webhook)
|
||||
for webhook in webhooks
|
||||
if webhook["name"] == name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not webhook:
|
||||
webhook = self.create_webhook(channel_id, name)
|
||||
|
||||
with Cache.lock:
|
||||
Cache.cache["d_webhooks"][channel_id] = webhook
|
||||
|
||||
return webhook
|
||||
|
||||
def process_message(self, message: discord.Message) -> Tuple[str, Dict]:
|
||||
content = message.content
|
||||
emotes = {}
|
||||
regex = r"<a?:(\w+):(\d+)>"
|
||||
|
||||
# Mentions can either be in the form of `<@1234>` or `<@!1234>`.
|
||||
for member in message.mentions:
|
||||
for char in ("", "!"):
|
||||
content = content.replace(
|
||||
f"<@{char}{member.id}>", f"@{member.username}"
|
||||
)
|
||||
|
||||
# Replace channel IDs with names.
|
||||
channels = re.findall("<#([0-9]+)>", content)
|
||||
if channels:
|
||||
if not message.guild_id:
|
||||
self.logger.warning(
|
||||
f"Message '{message.id}' in channel '{message.channel_id}' does not have a guild_id!"
|
||||
)
|
||||
else:
|
||||
discord_channels = self.get_channels(message.guild_id)
|
||||
for channel in channels:
|
||||
discord_channel = discord_channels.get(channel)
|
||||
name = (
|
||||
discord_channel.name
|
||||
if discord_channel
|
||||
else "deleted-channel"
|
||||
)
|
||||
content = content.replace(f"<#{channel}>", f"#{name}")
|
||||
|
||||
# { "emote_name": "emote_id" }
|
||||
for emote in re.findall(regex, content):
|
||||
emotes[emote[0]] = emote[1]
|
||||
|
||||
# Replace emote IDs with names.
|
||||
content = re.sub(regex, r":\g<1>:", content)
|
||||
|
||||
# Append attachments to message.
|
||||
for attachment in message.attachments:
|
||||
content += f"\n{attachment['url']}"
|
||||
|
||||
# Append stickers to message.
|
||||
for sticker in message.stickers:
|
||||
if sticker.format_type != 3: # 3 == Lottie format.
|
||||
content += f"\n{discord.CDN_URL}/stickers/{sticker.id}.png"
|
||||
|
||||
return content, emotes
|
||||
|
||||
|
||||
def config_gen(basedir: str, config_file: str) -> dict:
|
||||
config_file = f"{basedir}/{config_file}"
|
||||
|
||||
config_dict = {
|
||||
"as_token": "my-secret-as-token",
|
||||
"hs_token": "my-secret-hs-token",
|
||||
"user_id": "appservice-discord",
|
||||
"homeserver": "http://127.0.0.1:8008",
|
||||
"server_name": "localhost",
|
||||
"discord_token": "my-secret-discord-token",
|
||||
"port": 5000,
|
||||
"database": f"{basedir}/bridge.db",
|
||||
}
|
||||
|
||||
if not os.path.exists(config_file):
|
||||
with open(config_file, "w") as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
print(f"Configuration dumped to '{config_file}'")
|
||||
sys.exit()
|
||||
|
||||
with open(config_file, "r") as f:
|
||||
return json.loads(f.read())
|
||||
|
||||
|
||||
def excepthook(exc_type, exc_value, exc_traceback):
|
||||
if issubclass(exc_type, KeyboardInterrupt):
|
||||
sys.__excepthook__(exc_type, exc_value, exc_traceback)
|
||||
return
|
||||
|
||||
logging.critical(
|
||||
"Unknown exception:", exc_info=(exc_type, exc_value, exc_traceback)
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
try:
|
||||
basedir = sys.argv[1]
|
||||
if not os.path.exists(basedir):
|
||||
print(f"Path '{basedir}' does not exist!")
|
||||
sys.exit(1)
|
||||
basedir = os.path.abspath(basedir)
|
||||
except IndexError:
|
||||
basedir = os.getcwd()
|
||||
|
||||
config = config_gen(basedir, "appservice.json")
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(name)s:%(levelname)s:%(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[
|
||||
logging.FileHandler(f"{basedir}/appservice.log"),
|
||||
],
|
||||
)
|
||||
|
||||
sys.excepthook = excepthook
|
||||
|
||||
app = MatrixClient(config, urllib3.PoolManager(maxsize=10))
|
||||
|
||||
# Start the bottle app in a separate thread.
|
||||
app_thread = threading.Thread(
|
||||
target=app.run, kwargs={"port": int(config["port"])}, daemon=True
|
||||
)
|
||||
app_thread.start()
|
||||
|
||||
try:
|
||||
asyncio.run(app.discord.run())
|
||||
except KeyboardInterrupt:
|
||||
sys.exit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -1,84 +0,0 @@
|
|||
import json
|
||||
from dataclasses import fields
|
||||
from typing import Any
|
||||
|
||||
import urllib3
|
||||
|
||||
from errors import RequestError
|
||||
|
||||
|
||||
def dict_cls(d: dict, cls: Any) -> Any:
|
||||
"""
|
||||
Create a dataclass from a dictionary.
|
||||
"""
|
||||
|
||||
field_names = set(f.name for f in fields(cls))
|
||||
filtered_dict = {k: v for k, v in d.items() if k in field_names}
|
||||
|
||||
return cls(**filtered_dict)
|
||||
|
||||
|
||||
def log_except(fn):
|
||||
"""
|
||||
Log unhandled exceptions to a logger instead of `stderr`.
|
||||
"""
|
||||
|
||||
def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
return fn(self, *args, **kwargs)
|
||||
except Exception:
|
||||
self.logger.exception(f"Exception in '{fn.__name__}':")
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def request(fn):
|
||||
"""
|
||||
Either return json data or raise a `RequestError` if the request was
|
||||
unsuccessful.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
resp = fn(*args, **kwargs)
|
||||
except urllib3.exceptions.HTTPError as e:
|
||||
raise RequestError(None, f"Failed to connect: {e}") from None
|
||||
|
||||
if resp.status < 200 or resp.status >= 300:
|
||||
raise RequestError(
|
||||
resp.status,
|
||||
f"Failed to get response from '{resp.geturl()}':\n{resp.data}",
|
||||
)
|
||||
|
||||
return {} if resp.status == 204 else json.loads(resp.data)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def except_deleted(fn):
|
||||
"""
|
||||
Ignore the `RequestError` on 404s, the content might have been removed.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except RequestError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def hash_str(string: str) -> int:
|
||||
"""
|
||||
Create the hash for a string
|
||||
"""
|
||||
|
||||
hash = 5381
|
||||
|
||||
for ch in string:
|
||||
hash = ((hash << 5) + hash) + ord(ch)
|
||||
|
||||
return hash & 0xFFFFFFFF
|
Loading…
Reference in a new issue