cleanup
* Use websocket events to sync usernames/avatars instead of periodic syncing. * Use caches for fetched room_ids, room state (for usernames and avatars). Also switch to using a single cache with locks. * Don't store full message objects in cache, just store the relation of matrix event IDs to discord message IDs and vice-versa. Content can be fetched from the server instead. * Don't rely on websocket events for mentioning Discord users, mentions are now done by mentioning the dummy matrix user. The ID to be mentioned is extracted from the MXID instead. * General clean-ups.
This commit is contained in:
parent
b21c82ccd0
commit
4713a00016
10 changed files with 454 additions and 440 deletions
|
@ -49,20 +49,23 @@ A path can optionally be passed as the first argument to `main.py`. This path wi
|
|||
Eg. Running `python3 main.py /path/to/my/dir` will store the database and logs in `/path/to/my/dir`.
|
||||
`$PWD` is used by default if no path is specified.
|
||||
|
||||
After setting up the bridge, send a direct message to `@appservice-discord:domain.tld` containing the channel ID to be bridged (`!bridge 123456`).
|
||||
|
||||
This bridge is written with:
|
||||
* `bottle`: Receiving events from the homeserver.
|
||||
* `urllib3`: Sending requests, thread safety.
|
||||
* `websockets`: Connecting to Discord. (Big thanks to an anonymous person "nesslersreagent" for figuring out the initial connection mess.)
|
||||
|
||||
## NOTES
|
||||
|
||||
* A basic sqlite database is used for keeping track of bridged rooms.
|
||||
|
||||
* Logs are saved to the `appservice.log` file in `$PWD` or the specified directory.
|
||||
|
||||
* For avatars to show up on Discord, you must have a [reverse proxy](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf) set up on your homeserver as the bridge does not specify the homeserver port when passing the avatar url.
|
||||
|
||||
* It is not possible to add normal Discord bot functionality like commands as this bridge does not use `discord.py`.
|
||||
* It is not possible to add "normal" Discord bot functionality like commands as this bridge does not use `discord.py`.
|
||||
|
||||
* [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) must be enabled for your Discord bot.
|
||||
* [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) for members and presence must be enabled for your Discord bot.
|
||||
|
||||
* This Appservice might not work well for bridging a large number of rooms since it is mostly synchronous. However, it wouldn't take much effort to port it to `asyncio` and `aiohttp` if desired.
|
||||
|
|
|
@ -2,13 +2,14 @@ import json
|
|||
import logging
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from typing import List, Union
|
||||
from typing import Union
|
||||
|
||||
import bottle
|
||||
import urllib3
|
||||
|
||||
import matrix
|
||||
from misc import dict_cls, except_deleted, log_except, request
|
||||
from cache import Cache
|
||||
from misc import log_except, request
|
||||
|
||||
|
||||
class AppService(bottle.Bottle):
|
||||
|
@ -37,13 +38,17 @@ class AppService(bottle.Bottle):
|
|||
method="PUT",
|
||||
)
|
||||
|
||||
Cache.cache["m_rooms"] = {}
|
||||
|
||||
def handle_event(self, event: dict) -> None:
|
||||
event_type = event.get("type")
|
||||
|
||||
if event_type == "m.room.member" or event_type == "m.room.message":
|
||||
obj = self.get_event_object(event)
|
||||
elif event_type == "m.room.redaction":
|
||||
obj = event
|
||||
if event_type in (
|
||||
"m.room.member",
|
||||
"m.room.message",
|
||||
"m.room.redaction",
|
||||
):
|
||||
obj = matrix.Event(event)
|
||||
else:
|
||||
self.logger.info(f"Unknown event type: {event_type}")
|
||||
return
|
||||
|
@ -86,22 +91,13 @@ class AppService(bottle.Bottle):
|
|||
def mxc_url(self, mxc: str) -> str:
|
||||
try:
|
||||
homeserver, media_id = mxc.replace("mxc://", "").split("/")
|
||||
converted = (
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
return (
|
||||
f"https://{self.server_name}/_matrix/media/r0/download/"
|
||||
f"{homeserver}/{media_id}"
|
||||
)
|
||||
except ValueError:
|
||||
converted = ""
|
||||
|
||||
return converted
|
||||
|
||||
def get_event_object(self, event: dict) -> matrix.Event:
|
||||
# TODO use caching and invalidate old cache on member events.
|
||||
event["author"] = dict_cls(
|
||||
self.get_profile(event["sender"]), matrix.User
|
||||
)
|
||||
|
||||
return matrix.Event(event)
|
||||
|
||||
def join_room(self, room_id: str, mxid: str = "") -> None:
|
||||
self.send(
|
||||
|
@ -117,42 +113,25 @@ class AppService(bottle.Bottle):
|
|||
params={"user_id": mxid} if mxid else {},
|
||||
)
|
||||
|
||||
def get_profile(self, mxid: str) -> dict:
|
||||
resp = except_deleted(self.send)("GET", f"/profile/{mxid}")
|
||||
|
||||
# No profile exists for the user.
|
||||
if not resp:
|
||||
return {}
|
||||
|
||||
avatar_url = resp.get("avatar_url")
|
||||
|
||||
if avatar_url:
|
||||
avatar_url = self.mxc_url(avatar_url)
|
||||
|
||||
return {
|
||||
"avatar_url": avatar_url,
|
||||
"displayname": resp.get("displayname"),
|
||||
}
|
||||
|
||||
def get_members(self, room_id: str) -> List[str]:
|
||||
resp = self.send(
|
||||
"GET",
|
||||
f"/rooms/{room_id}/members",
|
||||
params={"membership": "join", "not_membership": "leave"},
|
||||
)
|
||||
|
||||
return [
|
||||
content["sender"]
|
||||
for content in resp["chunk"]
|
||||
if content["content"]["membership"] == "join"
|
||||
]
|
||||
|
||||
def get_room_id(self, alias: str) -> str:
|
||||
with Cache.lock:
|
||||
room = Cache.cache["m_rooms"].get(alias)
|
||||
if room:
|
||||
return room
|
||||
|
||||
resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}")
|
||||
|
||||
# TODO cache ?
|
||||
room_id = resp["room_id"]
|
||||
|
||||
return resp["room_id"]
|
||||
with Cache.lock:
|
||||
Cache.cache["m_rooms"][alias] = room_id
|
||||
|
||||
return room_id
|
||||
|
||||
def get_event(self, event_id: str, room_id: str) -> matrix.Event:
|
||||
resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}")
|
||||
|
||||
return matrix.Event(resp)
|
||||
|
||||
def upload(self, url: str) -> str:
|
||||
"""
|
||||
|
@ -211,12 +190,12 @@ class AppService(bottle.Bottle):
|
|||
) -> dict:
|
||||
params["access_token"] = self.as_token
|
||||
headers = {"Content-Type": content_type}
|
||||
content = json.dumps(content) if isinstance(content, dict) else content
|
||||
payload = json.dumps(content) if isinstance(content, dict) else content
|
||||
endpoint = (
|
||||
f"{self.base_url}{endpoint}{path}?"
|
||||
f"{urllib.parse.urlencode(params)}"
|
||||
)
|
||||
|
||||
return self.http.request(
|
||||
method, endpoint, body=content, headers=headers
|
||||
method, endpoint, body=payload, headers=headers
|
||||
)
|
||||
|
|
6
appservice/cache.py
Normal file
6
appservice/cache.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
import threading
|
||||
|
||||
|
||||
class Cache:
|
||||
cache = {}
|
||||
lock = threading.Lock()
|
|
@ -4,11 +4,11 @@ import threading
|
|||
from typing import List
|
||||
|
||||
|
||||
class DataBase(object):
|
||||
class DataBase:
|
||||
def __init__(self, db_file) -> None:
|
||||
self.create(db_file)
|
||||
|
||||
# The database is accessed via both the threads.
|
||||
# The database is accessed via multiple threads.
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def create(self, db_file) -> None:
|
||||
|
@ -92,7 +92,7 @@ class DataBase(object):
|
|||
|
||||
room = self.cur.fetchone()
|
||||
|
||||
# Return an empty string if nothing is bridged.
|
||||
# Return an empty string if the channel is not bridged.
|
||||
return "" if not room else room["channel_id"]
|
||||
|
||||
def list_channels(self) -> List[str]:
|
||||
|
@ -116,6 +116,8 @@ class DataBase(object):
|
|||
self.cur.execute("SELECT * FROM users")
|
||||
users = self.cur.fetchall()
|
||||
|
||||
user = [user for user in users if user["mxid"] == mxid]
|
||||
user: dict = next(
|
||||
iter([user for user in users if user["mxid"] == mxid]), {}
|
||||
)
|
||||
|
||||
return user[0] if user else {}
|
||||
return user
|
||||
|
|
|
@ -1,10 +1,16 @@
|
|||
from dataclasses import dataclass
|
||||
from misc import dict_cls
|
||||
|
||||
CDN_URL = "https://cdn.discordapp.com"
|
||||
ID_LEN = 18
|
||||
|
||||
|
||||
def bitmask(bit: int) -> int:
|
||||
return 1 << bit
|
||||
|
||||
|
||||
@dataclass
|
||||
class Channel(object):
|
||||
class Channel:
|
||||
id: str
|
||||
type: str
|
||||
guild_id: str = ""
|
||||
|
@ -13,13 +19,32 @@ class Channel(object):
|
|||
|
||||
|
||||
@dataclass
|
||||
class Emote(object):
|
||||
class Emote:
|
||||
animated: bool
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class User(object):
|
||||
@dataclass
|
||||
class MessageReference:
|
||||
message_id: str
|
||||
channel_id: str
|
||||
guild_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Typing:
|
||||
user_id: str
|
||||
channel_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Webhook:
|
||||
id: str
|
||||
token: str
|
||||
|
||||
|
||||
class User:
|
||||
def __init__(self, user: dict) -> None:
|
||||
self.discriminator = user["discriminator"]
|
||||
self.id = user["id"]
|
||||
|
@ -38,45 +63,57 @@ class User(object):
|
|||
self.avatar_url = f"{CDN_URL}/avatars/{self.id}/{avatar}.{ext}"
|
||||
|
||||
|
||||
class Message(object):
|
||||
class Guild:
|
||||
def __init__(self, guild: dict) -> None:
|
||||
self.guild_id = guild["id"]
|
||||
self.channels = [dict_cls(c, Channel) for c in guild["channels"]]
|
||||
self.emojis = [dict_cls(e, Emote) for e in guild["emojis"]]
|
||||
members = [member["user"] for member in guild["members"]]
|
||||
self.members = [User(m) for m in members]
|
||||
|
||||
|
||||
class GuildEmojisUpdate:
|
||||
def __init__(self, update: dict) -> None:
|
||||
self.guild_id = update["guild_id"]
|
||||
self.emojis = [dict_cls(e, Emote) for e in update["emojis"]]
|
||||
|
||||
|
||||
class GuildMembersChunk:
|
||||
def __init__(self, chunk: dict) -> None:
|
||||
self.chunk_index = chunk["chunk_index"]
|
||||
self.chunk_count = chunk["chunk_count"]
|
||||
self.guild_id = chunk["guild_id"]
|
||||
self.members = [User(m) for m in chunk["members"]]
|
||||
|
||||
|
||||
class GuildMemberUpdate:
|
||||
def __init__(self, update: dict) -> None:
|
||||
self.guild_id = update["guild_id"]
|
||||
self.user = User(update["user"])
|
||||
|
||||
|
||||
class Message:
|
||||
def __init__(self, message: dict) -> None:
|
||||
self.attachments = message.get("attachments", [])
|
||||
self.channel_id = message["channel_id"]
|
||||
self.content = message.get("content", "")
|
||||
self.id = message["id"]
|
||||
self.reference = message.get("message_reference", {}).get(
|
||||
"message_id", ""
|
||||
)
|
||||
self.webhook_id = message.get("webhook_id", "")
|
||||
|
||||
self.mentions = [
|
||||
User(mention) for mention in message.get("mentions", [])
|
||||
]
|
||||
|
||||
ref = message.get("message_reference")
|
||||
|
||||
self.reference = dict_cls(ref, MessageReference) if ref else None
|
||||
|
||||
author = message.get("author")
|
||||
|
||||
self.author = User(author) if author else None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeletedMessage(object):
|
||||
channel_id: str
|
||||
id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Typing(object):
|
||||
user_id: str
|
||||
channel_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Webhook(object):
|
||||
id: str
|
||||
token: str
|
||||
|
||||
|
||||
class ChannelType(object):
|
||||
class ChannelType:
|
||||
GUILD_TEXT = 0
|
||||
DM = 1
|
||||
GUILD_VOICE = 2
|
||||
|
@ -86,7 +123,7 @@ class ChannelType(object):
|
|||
GUILD_STORE = 6
|
||||
|
||||
|
||||
class InteractionResponseType(object):
|
||||
class InteractionResponseType:
|
||||
PONG = 0
|
||||
ACKNOWLEDGE = 1
|
||||
CHANNEL_MESSAGE = 2
|
||||
|
@ -94,10 +131,7 @@ class InteractionResponseType(object):
|
|||
ACKNOWLEDGE_WITH_SOURCE = 5
|
||||
|
||||
|
||||
class GatewayIntents(object):
|
||||
def bitmask(bit: int) -> int:
|
||||
return 1 << bit
|
||||
|
||||
class GatewayIntents:
|
||||
GUILDS = bitmask(0)
|
||||
GUILD_MEMBERS = bitmask(1)
|
||||
GUILD_BANS = bitmask(2)
|
||||
|
@ -115,7 +149,7 @@ class GatewayIntents(object):
|
|||
DIRECT_MESSAGE_TYPING = bitmask(14)
|
||||
|
||||
|
||||
class GatewayOpCodes(object):
|
||||
class GatewayOpCodes:
|
||||
DISPATCH = 0
|
||||
HEARTBEAT = 1
|
||||
IDENTIFY = 2
|
||||
|
@ -129,7 +163,7 @@ class GatewayOpCodes(object):
|
|||
HEARTBEAT_ACK = 11
|
||||
|
||||
|
||||
class Payloads(object):
|
||||
class Payloads:
|
||||
def __init__(self, token: str) -> None:
|
||||
self.seq = self.session = None
|
||||
self.token = token
|
||||
|
@ -143,27 +177,19 @@ class Payloads(object):
|
|||
"d": {
|
||||
"token": self.token,
|
||||
"intents": GatewayIntents.GUILDS
|
||||
| GatewayIntents.GUILD_EMOJIS
|
||||
| GatewayIntents.GUILD_MEMBERS
|
||||
| GatewayIntents.GUILD_MESSAGES
|
||||
| GatewayIntents.GUILD_MESSAGE_TYPING,
|
||||
| GatewayIntents.GUILD_MESSAGE_TYPING
|
||||
| GatewayIntents.GUILD_PRESENCES,
|
||||
"properties": {
|
||||
"$os": "discord",
|
||||
"$browser": "discord",
|
||||
"$browser": "Discord Client",
|
||||
"$device": "discord",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def QUERY(self, guild_id: str, query: str, limit: int = 1) -> dict:
|
||||
"""
|
||||
Return the Payload to query a member from a guild ID.
|
||||
Return only a single match if `limit` isn't specified.
|
||||
"""
|
||||
|
||||
return {
|
||||
"op": GatewayOpCodes.REQUEST_GUILD_MEMBERS,
|
||||
"d": {"guild_id": guild_id, "query": query, "limit": limit},
|
||||
}
|
||||
|
||||
def RESUME(self) -> dict:
|
||||
return {
|
||||
"op": GatewayOpCodes.RESUME,
|
||||
|
|
|
@ -8,31 +8,27 @@ import urllib3
|
|||
import websockets
|
||||
|
||||
import discord
|
||||
from misc import dict_cls, log_except, request, wrap_async
|
||||
from misc import dict_cls, log_except, request
|
||||
|
||||
|
||||
class Gateway(object):
|
||||
class Gateway:
|
||||
def __init__(self, http: urllib3.PoolManager, token: str):
|
||||
self.http = http
|
||||
self.token = token
|
||||
self.logger = logging.getLogger("discord")
|
||||
self.cdn_url = "https://cdn.discordapp.com"
|
||||
self.Payloads = discord.Payloads(self.token)
|
||||
self.loop = self.websocket = None
|
||||
|
||||
self.query_cache = {}
|
||||
self.websocket = None
|
||||
|
||||
@log_except
|
||||
async def run(self) -> None:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.query_ev = asyncio.Event()
|
||||
|
||||
self.heartbeat_task = None
|
||||
self.heartbeat_task: asyncio.Future = None
|
||||
self.resume = False
|
||||
|
||||
gateway_url = self.get_gateway_url()
|
||||
|
||||
while True:
|
||||
try:
|
||||
await self.gateway_handler(self.get_gateway_url())
|
||||
await self.gateway_handler(gateway_url)
|
||||
except (
|
||||
websockets.ConnectionClosedError,
|
||||
websockets.InvalidMessage,
|
||||
|
@ -55,26 +51,71 @@ class Gateway(object):
|
|||
await asyncio.sleep(interval_ms / 1000)
|
||||
await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT()))
|
||||
|
||||
def query_handler(self, data: dict) -> None:
|
||||
members = data["members"]
|
||||
guild_id = data["guild_id"]
|
||||
async def handle_resp(self, data: dict) -> None:
|
||||
data_dict = data["d"]
|
||||
|
||||
for member in members:
|
||||
user = member["user"]
|
||||
self.query_cache[guild_id].append(user)
|
||||
opcode = data["op"]
|
||||
|
||||
self.query_ev.set()
|
||||
seq = data["s"]
|
||||
|
||||
if seq:
|
||||
self.Payloads.seq = seq
|
||||
|
||||
if opcode == discord.GatewayOpCodes.DISPATCH:
|
||||
otype = data["t"]
|
||||
|
||||
if otype == "READY":
|
||||
self.Payloads.session = data_dict["session_id"]
|
||||
|
||||
self.logger.info("READY")
|
||||
else:
|
||||
self.handle_otype(data_dict, otype)
|
||||
elif opcode == discord.GatewayOpCodes.HELLO:
|
||||
heartbeat_interval = data_dict.get("heartbeat_interval")
|
||||
|
||||
self.logger.info(f"Heartbeat Interval: {heartbeat_interval}")
|
||||
|
||||
# Send periodic hearbeats to gateway.
|
||||
self.heartbeat_task = asyncio.ensure_future(
|
||||
self.heartbeat_handler(heartbeat_interval)
|
||||
)
|
||||
|
||||
await self.websocket.send(
|
||||
json.dumps(
|
||||
self.Payloads.RESUME()
|
||||
if self.resume
|
||||
else self.Payloads.IDENTIFY()
|
||||
)
|
||||
)
|
||||
elif opcode == discord.GatewayOpCodes.RECONNECT:
|
||||
self.logger.info("Received RECONNECT.")
|
||||
|
||||
self.resume = True
|
||||
await self.websocket.close()
|
||||
elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
|
||||
self.logger.info("Received INVALID_SESSION.")
|
||||
|
||||
self.resume = False
|
||||
await self.websocket.close()
|
||||
elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
|
||||
# NOP
|
||||
pass
|
||||
else:
|
||||
self.logger.info(
|
||||
"Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}"
|
||||
)
|
||||
|
||||
def handle_otype(self, data: dict, otype: str) -> None:
|
||||
if otype == "MESSAGE_CREATE" or otype == "MESSAGE_UPDATE":
|
||||
if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"):
|
||||
obj = discord.Message(data)
|
||||
elif otype == "MESSAGE_DELETE":
|
||||
obj = dict_cls(data, discord.DeletedMessage)
|
||||
elif otype == "TYPING_START":
|
||||
obj = dict_cls(data, discord.Typing)
|
||||
elif otype == "GUILD_MEMBERS_CHUNK":
|
||||
self.query_handler(data)
|
||||
return
|
||||
elif otype == "GUILD_CREATE":
|
||||
obj = discord.Guild(data)
|
||||
elif otype == "GUILD_MEMBER_UPDATE":
|
||||
obj = discord.GuildMemberUpdate(data)
|
||||
elif otype == "GUILD_EMOJIS_UPDATE":
|
||||
obj = discord.GuildEmojisUpdate(data)
|
||||
else:
|
||||
self.logger.info(f"Unknown OTYPE: {otype}")
|
||||
return
|
||||
|
@ -90,119 +131,20 @@ class Gateway(object):
|
|||
try:
|
||||
func(obj)
|
||||
except Exception:
|
||||
self.logger.exception(f"Ignoring exception in {func}:")
|
||||
self.logger.exception(f"Ignoring exception in '{func.__name__}':")
|
||||
|
||||
async def gateway_handler(self, gateway_url: str) -> None:
|
||||
async with websockets.connect(
|
||||
f"{gateway_url}/?v=8&encoding=json"
|
||||
) as websocket:
|
||||
self.websocket = websocket
|
||||
|
||||
async for message in websocket:
|
||||
data = json.loads(message)
|
||||
data_dict = data.get("d")
|
||||
|
||||
opcode = data.get("op")
|
||||
|
||||
seq = data.get("s")
|
||||
if seq:
|
||||
self.Payloads.seq = seq
|
||||
|
||||
if opcode == discord.GatewayOpCodes.DISPATCH:
|
||||
otype = data.get("t")
|
||||
|
||||
if otype == "READY":
|
||||
self.Payloads.session = data_dict["session_id"]
|
||||
|
||||
self.logger.info("READY")
|
||||
|
||||
else:
|
||||
self.handle_otype(data_dict, otype)
|
||||
|
||||
elif opcode == discord.GatewayOpCodes.HELLO:
|
||||
heartbeat_interval = data_dict.get("heartbeat_interval")
|
||||
|
||||
self.logger.info(
|
||||
f"Heartbeat Interval: {heartbeat_interval}"
|
||||
)
|
||||
|
||||
# Send periodic hearbeats to gateway.
|
||||
self.heartbeat_task = asyncio.ensure_future(
|
||||
self.heartbeat_handler(heartbeat_interval)
|
||||
)
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
self.Payloads.RESUME()
|
||||
if self.resume
|
||||
else self.Payloads.IDENTIFY()
|
||||
)
|
||||
)
|
||||
|
||||
elif opcode == discord.GatewayOpCodes.RECONNECT:
|
||||
self.logger.info("Received RECONNECT.")
|
||||
|
||||
self.resume = True
|
||||
await websocket.close()
|
||||
|
||||
elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
|
||||
self.logger.info("Received INVALID_SESSION.")
|
||||
|
||||
self.resume = False
|
||||
await websocket.close()
|
||||
|
||||
elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
|
||||
# NOP
|
||||
pass
|
||||
|
||||
else:
|
||||
self.logger.info(
|
||||
f"Unknown OP code {opcode}:\n"
|
||||
f"{json.dumps(data, indent=4)}"
|
||||
)
|
||||
|
||||
@wrap_async
|
||||
async def query_member(self, guild_id: str, name: str) -> discord.User:
|
||||
"""
|
||||
Query the members for a given guild and return the first match.
|
||||
"""
|
||||
|
||||
self.query_ev.clear()
|
||||
|
||||
def query():
|
||||
if not self.query_cache.get(guild_id):
|
||||
self.query_cache[guild_id] = []
|
||||
|
||||
user = [
|
||||
user
|
||||
for user in self.query_cache[guild_id]
|
||||
if name.lower() in user["username"].lower()
|
||||
]
|
||||
|
||||
return None if not user else discord.User(user[0])
|
||||
|
||||
user = query()
|
||||
|
||||
if user:
|
||||
return user
|
||||
|
||||
if not self.websocket or self.websocket.closed:
|
||||
self.logger.warning("Not fetching members, websocket closed.")
|
||||
return
|
||||
|
||||
await self.websocket.send(
|
||||
json.dumps(self.Payloads.QUERY(guild_id, name))
|
||||
)
|
||||
|
||||
# TODO clean this mess.
|
||||
|
||||
# Wait for our websocket to receive the chunk.
|
||||
await asyncio.wait_for(self.query_ev.wait(), timeout=5)
|
||||
|
||||
return query()
|
||||
await self.handle_resp(json.loads(message))
|
||||
|
||||
def get_channel(self, channel_id: str) -> discord.Channel:
|
||||
"""
|
||||
Get the channel object for a given channel ID.
|
||||
Get the channel for a given channel ID.
|
||||
"""
|
||||
|
||||
resp = self.send("GET", f"/channels/{channel_id}")
|
||||
|
@ -259,9 +201,17 @@ class Gateway(object):
|
|||
f"{message_id}",
|
||||
)
|
||||
|
||||
def send_webhook(self, webhook: discord.Webhook, **kwargs) -> str:
|
||||
content = {
|
||||
**kwargs,
|
||||
def send_webhook(
|
||||
self,
|
||||
webhook: discord.Webhook,
|
||||
avatar_url: str,
|
||||
content: str,
|
||||
username: str,
|
||||
) -> discord.Message:
|
||||
payload = {
|
||||
"avatar_url": avatar_url,
|
||||
"content": content,
|
||||
"username": username,
|
||||
# Disable 'everyone' and 'role' mentions.
|
||||
"allowed_mentions": {"parse": ["users"]},
|
||||
}
|
||||
|
@ -269,11 +219,11 @@ class Gateway(object):
|
|||
resp = self.send(
|
||||
"POST",
|
||||
f"/webhooks/{webhook.id}/{webhook.token}",
|
||||
content,
|
||||
payload,
|
||||
{"wait": True},
|
||||
)
|
||||
|
||||
return resp["id"]
|
||||
return discord.Message(resp)
|
||||
|
||||
def send_message(self, message: str, channel_id: str) -> None:
|
||||
self.send(
|
||||
|
@ -294,8 +244,8 @@ class Gateway(object):
|
|||
}
|
||||
|
||||
# 'body' being an empty dict breaks "GET" requests.
|
||||
content = json.dumps(content) if content else None
|
||||
payload = json.dumps(content) if content else None
|
||||
|
||||
return self.http.request(
|
||||
method, endpoint, body=content, headers=headers
|
||||
method, endpoint, body=payload, headers=headers
|
||||
)
|
||||
|
|
|
@ -5,21 +5,19 @@ import os
|
|||
import re
|
||||
import sys
|
||||
import threading
|
||||
from typing import Dict, Tuple, Union
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import urllib3
|
||||
|
||||
import discord
|
||||
import matrix
|
||||
from appservice import AppService
|
||||
from cache import Cache
|
||||
from db import DataBase
|
||||
from errors import RequestError
|
||||
from gateway import Gateway
|
||||
from misc import dict_cls, except_deleted, hash_str
|
||||
|
||||
# TODO should this be cleared periodically ?
|
||||
message_cache: Dict[str, Union[discord.Webhook, str]] = {}
|
||||
|
||||
|
||||
class MatrixClient(AppService):
|
||||
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
|
||||
|
@ -27,9 +25,11 @@ class MatrixClient(AppService):
|
|||
|
||||
self.db = DataBase(config["database"])
|
||||
self.discord = DiscordClient(self, config, http)
|
||||
self.emote_cache: Dict[str, str] = {}
|
||||
self.format = "_discord_" # "{@,#}_discord_1234:localhost"
|
||||
|
||||
for k in ("m_emotes", "m_members", "m_messages"):
|
||||
Cache.cache[k] = {}
|
||||
|
||||
def handle_bridge(self, message: matrix.Event) -> None:
|
||||
# Ignore events that aren't for us.
|
||||
if message.sender.split(":")[
|
||||
|
@ -37,6 +37,7 @@ class MatrixClient(AppService):
|
|||
] != self.server_name or not message.body.startswith("!bridge"):
|
||||
return
|
||||
|
||||
# Get the channel ID.
|
||||
try:
|
||||
channel = message.body.split()[1]
|
||||
except IndexError:
|
||||
|
@ -46,7 +47,7 @@ class MatrixClient(AppService):
|
|||
try:
|
||||
channel = self.discord.get_channel(channel)
|
||||
except RequestError as e:
|
||||
# The channel can be invalid or we may not have permission.
|
||||
# The channel can be invalid or we may not have permissions.
|
||||
self.logger.warning(f"Failed to fetch channel {channel}: {e}")
|
||||
return
|
||||
|
||||
|
@ -61,7 +62,15 @@ class MatrixClient(AppService):
|
|||
self.create_room(channel, message.sender)
|
||||
|
||||
def on_member(self, event: matrix.Event) -> None:
|
||||
# Ignore events that aren't for us.
|
||||
with Cache.lock:
|
||||
# Just lazily clear the whole member cache on
|
||||
# membership update events.
|
||||
if event.room_id in Cache.cache["m_members"]:
|
||||
self.logger.info(
|
||||
f"Clearing member cache for room '{event.room_id}'."
|
||||
)
|
||||
del Cache.cache["m_members"][event.room_id]
|
||||
|
||||
if (
|
||||
event.sender.split(":")[-1] != self.server_name
|
||||
or event.state_key != self.user_id
|
||||
|
@ -70,7 +79,7 @@ class MatrixClient(AppService):
|
|||
return
|
||||
|
||||
# Join the direct message room.
|
||||
self.logger.info(f"Joining direct message room {event.room_id}.")
|
||||
self.logger.info(f"Joining direct message room '{event.room_id}'.")
|
||||
self.join_room(event.room_id)
|
||||
|
||||
def on_message(self, message: matrix.Event) -> None:
|
||||
|
@ -88,51 +97,78 @@ class MatrixClient(AppService):
|
|||
if not channel_id:
|
||||
return
|
||||
|
||||
webhook = self.discord.get_webhook(channel_id, "matrix_bridge")
|
||||
author = self.get_members(message.room_id)[message.sender]
|
||||
|
||||
if not author.display_name:
|
||||
author.display_name = message.sender
|
||||
|
||||
webhook = self.discord.get_webhook(
|
||||
channel_id, self.discord.webhook_name
|
||||
)
|
||||
|
||||
if message.relates_to and message.reltype == "m.replace":
|
||||
relation = message_cache.get(message.relates_to)
|
||||
with Cache.lock:
|
||||
message_id = Cache.cache["m_messages"].get(message.relates_to)
|
||||
|
||||
if not message.new_body or not relation:
|
||||
if not message_id or not message.new_body:
|
||||
return
|
||||
|
||||
message.new_body = self.process_message(
|
||||
channel_id, message.new_body
|
||||
)
|
||||
message.new_body = self.process_message(message)
|
||||
|
||||
except_deleted(self.discord.edit_webhook)(
|
||||
message.new_body, relation["message_id"], webhook
|
||||
message.new_body, message_id, webhook
|
||||
)
|
||||
|
||||
else:
|
||||
message.body = (
|
||||
f"`{message.body}`: {self.mxc_url(message.attachment)}"
|
||||
if message.attachment
|
||||
else self.process_message(channel_id, message.body)
|
||||
else self.process_message(message)
|
||||
)
|
||||
|
||||
message_cache[message.event_id] = {
|
||||
"message_id": self.discord.send_webhook(
|
||||
message_id = self.discord.send_webhook(
|
||||
webhook,
|
||||
avatar_url=message.author.avatar_url,
|
||||
content=message.body,
|
||||
username=message.author.displayname,
|
||||
),
|
||||
"webhook": webhook,
|
||||
}
|
||||
self.mxc_url(author.avatar_url),
|
||||
message.body,
|
||||
author.display_name,
|
||||
).id
|
||||
|
||||
@except_deleted
|
||||
def on_redaction(self, event: dict) -> None:
|
||||
redacts = event["redacts"]
|
||||
with Cache.lock:
|
||||
Cache.cache["m_messages"][message.id] = message_id
|
||||
|
||||
event = message_cache.get(redacts)
|
||||
def on_redaction(self, event: matrix.Event) -> None:
|
||||
with Cache.lock:
|
||||
message_id = Cache.cache["m_messages"].get(event.redacts)
|
||||
|
||||
if not event:
|
||||
if not message_id:
|
||||
return
|
||||
|
||||
self.discord.delete_webhook(event["message_id"], event["webhook"])
|
||||
webhook = self.discord.get_webhook(
|
||||
self.db.get_channel(event.room_id), self.discord.webhook_name
|
||||
)
|
||||
|
||||
message_cache.pop(redacts)
|
||||
except_deleted(self.discord.delete_webhook)(message_id, webhook)
|
||||
|
||||
with Cache.lock:
|
||||
del Cache.cache["m_messages"][event.redacts]
|
||||
|
||||
def get_members(self, room_id: str) -> Dict[str, matrix.User]:
|
||||
with Cache.lock:
|
||||
cached = Cache.cache["m_members"].get(room_id)
|
||||
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
resp = self.send("GET", f"/rooms/{room_id}/joined_members")
|
||||
|
||||
joined = resp["joined"]
|
||||
|
||||
for k, v in joined.items():
|
||||
joined[k] = dict_cls(v, matrix.User)
|
||||
|
||||
with Cache.lock:
|
||||
Cache.cache["m_members"][room_id] = joined
|
||||
|
||||
return joined
|
||||
|
||||
def create_room(self, channel: discord.Channel, sender: str) -> None:
|
||||
"""
|
||||
|
@ -164,7 +200,11 @@ class MatrixClient(AppService):
|
|||
self.db.add_room(resp["room_id"], channel.id)
|
||||
|
||||
def create_message_event(
|
||||
self, message: str, emotes: dict, edit: str = "", reply: str = ""
|
||||
self,
|
||||
message: str,
|
||||
emotes: dict,
|
||||
edit: str = "",
|
||||
reference: discord.MessageReference = None,
|
||||
) -> dict:
|
||||
content = {
|
||||
"body": message,
|
||||
|
@ -173,18 +213,37 @@ class MatrixClient(AppService):
|
|||
"formatted_body": self.get_fmt(message, emotes),
|
||||
}
|
||||
|
||||
event = message_cache.get(reply)
|
||||
if reference:
|
||||
# Reply to a Discord message.
|
||||
with Cache.lock:
|
||||
event_id = Cache.cache["d_messages"].get(reference.message_id)
|
||||
|
||||
# Reply to a Matrix message. (maybe)
|
||||
if not event_id:
|
||||
with Cache.lock:
|
||||
event_id = [
|
||||
k
|
||||
for k, v in Cache.cache["m_messages"].items()
|
||||
if v == reference.message_id
|
||||
]
|
||||
event_id = next(iter(event_id), "")
|
||||
|
||||
if reference and event_id:
|
||||
event = except_deleted(self.get_event)(
|
||||
event_id,
|
||||
self.get_room_id(self.discord.matrixify(reference.channel_id)),
|
||||
)
|
||||
if event:
|
||||
content = {
|
||||
**content,
|
||||
"m.relates_to": {
|
||||
"m.in_reply_to": {"event_id": event["event_id"]}
|
||||
},
|
||||
"body": (
|
||||
f"> <{event.sender}> {event.body}\n{content['body']}"
|
||||
),
|
||||
"m.relates_to": {"m.in_reply_to": {"event_id": event.id}},
|
||||
"formatted_body": f"""<mx-reply><blockquote>\
|
||||
<a href='https://matrix.to/#/{event["room_id"]}/{event["event_id"]}'>\
|
||||
In reply to</a><a href='https://matrix.to/#/{event["mxid"]}'>\
|
||||
{event["mxid"]}</a><br>{event["body"]}</blockquote></mx-reply>\
|
||||
<a href='https://matrix.to/#/{event.room_id}/{event.id}'>\
|
||||
In reply to</a><a href='https://matrix.to/#/{event.sender}'>\
|
||||
{event.sender}</a><br>{event.formatted_body}</blockquote></mx-reply>\
|
||||
{content["formatted_body"]}""",
|
||||
}
|
||||
|
||||
|
@ -227,11 +286,17 @@ In reply to</a><a href='https://matrix.to/#/{event["mxid"]}'>\
|
|||
for emote in emotes
|
||||
]
|
||||
|
||||
[thread.start() for thread in upload_threads]
|
||||
[thread.join() for thread in upload_threads]
|
||||
# Acquire the lock before starting the threads to avoid resource
|
||||
# contention by tens of threads at once.
|
||||
with Cache.lock:
|
||||
for thread in upload_threads:
|
||||
thread.start()
|
||||
for thread in upload_threads:
|
||||
thread.join()
|
||||
|
||||
with Cache.lock:
|
||||
for emote in emotes:
|
||||
emote_ = self.emote_cache.get(emote)
|
||||
emote_ = Cache.cache["m_emotes"].get(emote)
|
||||
|
||||
if emote_:
|
||||
emote = f":{emote}:"
|
||||
|
@ -243,47 +308,45 @@ height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
|
|||
|
||||
return message
|
||||
|
||||
def process_message(self, channel_id: str, message: str) -> str:
|
||||
def process_message(self, event: matrix.Event) -> str:
|
||||
message = event.new_body if event.new_body else event.body
|
||||
|
||||
message = message[:2000] # Discord limit.
|
||||
|
||||
id_regex = f"[0-9]{{{discord.ID_LEN}}}"
|
||||
|
||||
emotes = re.findall(r":(\w*):", message)
|
||||
mentions = re.findall(r"(@(\w*))", message)
|
||||
mentions = re.findall(
|
||||
f"@{self.format}{id_regex}:{re.escape(self.server_name)}",
|
||||
event.formatted_body,
|
||||
)
|
||||
|
||||
# Remove the puppet user's username from replies.
|
||||
message = re.sub(f"<@{self.format}.+?>", "", message)
|
||||
|
||||
added_emotes = []
|
||||
for emote in emotes:
|
||||
# Don't replace emote names with IDs multiple times.
|
||||
if emote not in added_emotes:
|
||||
added_emotes.append(emote)
|
||||
emote_ = self.discord.emote_cache.get(emote)
|
||||
with Cache.lock:
|
||||
for emote in set(emotes):
|
||||
emote_ = Cache.cache["d_emotes"].get(emote)
|
||||
if emote_:
|
||||
message = message.replace(f":{emote}:", emote_)
|
||||
|
||||
# Don't unnecessarily fetch the channel.
|
||||
if mentions:
|
||||
guild_id = self.discord.get_channel(channel_id).guild_id
|
||||
for mention in set(mentions):
|
||||
username = self.db.fetch_user(mention).get("username")
|
||||
if username:
|
||||
match = re.search(id_regex, mention)
|
||||
|
||||
# TODO this can block for too long if a long list is to be fetched.
|
||||
for mention in mentions:
|
||||
if not mention[1]:
|
||||
continue
|
||||
|
||||
try:
|
||||
member = self.discord.query_member(guild_id, mention[1])
|
||||
except (asyncio.TimeoutError, RuntimeError):
|
||||
continue
|
||||
|
||||
if member:
|
||||
message = message.replace(mention[0], member.mention)
|
||||
if match:
|
||||
# Replace the 'mention' so that the user is tagged
|
||||
# in the case of replies aswell.
|
||||
# '> <@_discord_1234:localhost> Message'
|
||||
for replace in (mention, username):
|
||||
message = message.replace(
|
||||
replace, f"<@{match.group()}>"
|
||||
)
|
||||
|
||||
return message
|
||||
|
||||
def upload_emote(self, emote_name: str, emote_id: str) -> None:
|
||||
# There won't be a race condition here, since only a unique
|
||||
# set of emotes are uploaded at a time.
|
||||
if emote_name in self.emote_cache:
|
||||
if emote_name in Cache.cache["m_emotes"]:
|
||||
return
|
||||
|
||||
emote_url = f"{discord.CDN_URL}/emojis/{emote_id}"
|
||||
|
@ -291,7 +354,8 @@ height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
|
|||
# We don't want the message to be dropped entirely if an emote
|
||||
# fails to upload for some reason.
|
||||
try:
|
||||
self.emote_cache[emote_name] = self.upload(emote_url)
|
||||
# TODO This is not thread safe, but we're protected by the GIL.
|
||||
Cache.cache["m_emotes"][emote_name] = self.upload(emote_url)
|
||||
except RequestError as e:
|
||||
self.logger.warning(f"Failed to upload emote {emote_id}: {e}")
|
||||
|
||||
|
@ -340,66 +404,19 @@ class DiscordClient(Gateway):
|
|||
super().__init__(http, config["discord_token"])
|
||||
|
||||
self.app = appservice
|
||||
self.emote_cache: Dict[str, str] = {}
|
||||
self.webhook_cache: Dict[str, discord.Webhook] = {}
|
||||
self.webhook_name = "matrix_bridge"
|
||||
|
||||
async def sync(self) -> None:
|
||||
"""
|
||||
Periodically compare the usernames and avatar URLs with Discord
|
||||
and update if they differ. Also synchronise emotes.
|
||||
"""
|
||||
|
||||
# TODO use websocket events and requests.
|
||||
|
||||
def sync_emotes(guilds: set):
|
||||
emotes = []
|
||||
|
||||
for guild in guilds:
|
||||
[emotes.append(emote) for emote in (self.get_emotes(guild))]
|
||||
|
||||
self.emote_cache.clear() # Clear deleted/renamed emotes.
|
||||
|
||||
for emote in emotes:
|
||||
self.emote_cache[f"{emote.name}"] = (
|
||||
f"<{'a' if emote.animated else ''}:"
|
||||
f"{emote.name}:{emote.id}>"
|
||||
)
|
||||
|
||||
def sync_users(guilds: set):
|
||||
for guild in guilds:
|
||||
[
|
||||
self.sync_profile(user, self.matrixify(user.id, user=True))
|
||||
for user in self.get_members(guild)
|
||||
]
|
||||
|
||||
while True:
|
||||
guilds = set() # Avoid duplicates.
|
||||
|
||||
try:
|
||||
for channel in self.app.db.list_channels():
|
||||
guilds.add(self.get_channel(channel).guild_id)
|
||||
|
||||
sync_emotes(guilds)
|
||||
sync_users(guilds)
|
||||
# Don't let the background task die.
|
||||
except RequestError:
|
||||
self.logger.exception(
|
||||
"Ignoring exception during background sync:"
|
||||
)
|
||||
|
||||
await asyncio.sleep(120) # Check every 2 minutes.
|
||||
|
||||
async def start(self) -> None:
|
||||
asyncio.ensure_future(self.sync())
|
||||
|
||||
await self.run()
|
||||
for k in ("d_emotes", "d_messages", "d_webhooks"):
|
||||
Cache.cache[k] = {}
|
||||
|
||||
def to_return(self, message: discord.Message) -> bool:
|
||||
with Cache.lock:
|
||||
hook_ids = [hook.id for hook in Cache.cache["d_webhooks"].values()]
|
||||
|
||||
return (
|
||||
message.channel_id not in self.app.db.list_channels()
|
||||
or not message.author # Embeds can be weird sometimes.
|
||||
or message.webhook_id
|
||||
in [hook.id for hook in self.webhook_cache.values()]
|
||||
or message.webhook_id in hook_ids
|
||||
)
|
||||
|
||||
def matrixify(self, id: str, user: bool = False) -> str:
|
||||
|
@ -408,11 +425,13 @@ class DiscordClient(Gateway):
|
|||
f"{self.app.server_name}"
|
||||
)
|
||||
|
||||
def sync_profile(self, user: discord.User, mxid: str) -> None:
|
||||
def sync_profile(self, user: discord.User) -> None:
|
||||
"""
|
||||
Sync the avatar and username for a puppeted user.
|
||||
"""
|
||||
|
||||
mxid = self.matrixify(user.id, user=True)
|
||||
|
||||
profile = self.app.db.fetch_user(mxid)
|
||||
|
||||
# User doesn't exist.
|
||||
|
@ -422,10 +441,10 @@ class DiscordClient(Gateway):
|
|||
username = f"{user.username}#{user.discriminator}"
|
||||
|
||||
if user.avatar_url != profile["avatar_url"]:
|
||||
self.logger.info(f"Updating avatar for Discord user {user.id}")
|
||||
self.logger.info(f"Updating avatar for Discord user '{user.id}'")
|
||||
self.app.set_avatar(user.avatar_url, mxid)
|
||||
if username != profile["username"]:
|
||||
self.logger.info(f"Updating username for Discord user {user.id}")
|
||||
self.logger.info(f"Updating username for Discord user '{user.id}'")
|
||||
self.app.set_nick(username, mxid)
|
||||
|
||||
def wrap(self, message: discord.Message) -> Tuple[str, str]:
|
||||
|
@ -457,62 +476,95 @@ class DiscordClient(Gateway):
|
|||
self.app.set_avatar(message.author.avatar_url, mxid)
|
||||
|
||||
if mxid not in self.app.get_members(room_id):
|
||||
self.logger.info(f"Inviting user {mxid} to room {room_id}.")
|
||||
self.logger.info(f"Inviting user '{mxid}' to room '{room_id}'.")
|
||||
|
||||
self.app.send_invite(room_id, mxid)
|
||||
self.app.join_room(room_id, mxid)
|
||||
|
||||
if message.webhook_id:
|
||||
# Sync webhooks here as they can't be accessed like guild members.
|
||||
self.sync_profile(message.author, mxid)
|
||||
self.sync_profile(message.author)
|
||||
|
||||
return mxid, room_id
|
||||
|
||||
def cache_emotes(self, emotes: List[discord.Emote]):
|
||||
# TODO maybe "namespace" emotes by guild in the cache ?
|
||||
with Cache.lock:
|
||||
for emote in emotes:
|
||||
Cache.cache["d_emotes"][emote.name] = (
|
||||
f"<{'a' if emote.animated else ''}:"
|
||||
f"{emote.name}:{emote.id}>"
|
||||
)
|
||||
|
||||
def on_guild_create(self, guild: discord.Guild) -> None:
|
||||
for member in guild.members:
|
||||
self.sync_profile(member)
|
||||
|
||||
self.cache_emotes(guild.emojis)
|
||||
|
||||
def on_guild_emojis_update(
|
||||
self, update: discord.GuildEmojisUpdate
|
||||
) -> None:
|
||||
self.cache_emotes(update.emojis)
|
||||
|
||||
def on_guild_member_update(
|
||||
self, update: discord.GuildMemberUpdate
|
||||
) -> None:
|
||||
self.sync_profile(update.user)
|
||||
|
||||
def on_message_create(self, message: discord.Message) -> None:
|
||||
if self.to_return(message):
|
||||
return
|
||||
|
||||
mxid, room_id = self.wrap(message)
|
||||
|
||||
content, emotes = self.process_message(message)
|
||||
content_, emotes = self.process_message(message)
|
||||
|
||||
content = self.app.create_message_event(
|
||||
content, emotes, reply=message.reference
|
||||
content_, emotes, reference=message.reference
|
||||
)
|
||||
|
||||
message_cache[message.id] = {
|
||||
"body": content["body"],
|
||||
"event_id": self.app.send_message(room_id, content, mxid),
|
||||
"mxid": mxid,
|
||||
"room_id": room_id,
|
||||
}
|
||||
with Cache.lock:
|
||||
Cache.cache["d_messages"][message.id] = self.app.send_message(
|
||||
room_id, content, mxid
|
||||
)
|
||||
|
||||
def on_message_delete(self, message: discord.DeletedMessage) -> None:
|
||||
event = message_cache.get(message.id)
|
||||
def on_message_delete(self, message: discord.Message) -> None:
|
||||
with Cache.lock:
|
||||
event_id = Cache.cache["d_messages"].get(message.id)
|
||||
|
||||
if not event:
|
||||
if not event_id:
|
||||
return
|
||||
|
||||
self.app.redact(event["event_id"], event["room_id"], event["mxid"])
|
||||
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
|
||||
event = except_deleted(self.app.get_event)(event_id, room_id)
|
||||
|
||||
message_cache.pop(message.id)
|
||||
if event:
|
||||
self.app.redact(event.id, event.room_id, event.sender)
|
||||
|
||||
with Cache.lock:
|
||||
del Cache.cache["d_messages"][message.id]
|
||||
|
||||
def on_message_update(self, message: discord.Message) -> None:
|
||||
if self.to_return(message):
|
||||
return
|
||||
|
||||
event = message_cache.get(message.id)
|
||||
with Cache.lock:
|
||||
event_id = Cache.cache["d_messages"].get(message.id)
|
||||
|
||||
if not event:
|
||||
if not event_id:
|
||||
return
|
||||
|
||||
content, emotes = self.process_message(message)
|
||||
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
|
||||
mxid = self.matrixify(message.author.id, user=True)
|
||||
|
||||
content_, emotes = self.process_message(message)
|
||||
|
||||
content = self.app.create_message_event(
|
||||
content, emotes, edit=event["event_id"]
|
||||
content_, emotes, edit=event_id
|
||||
)
|
||||
|
||||
self.app.send_message(event["room_id"], content, event["mxid"])
|
||||
self.app.send_message(room_id, content, mxid)
|
||||
|
||||
def on_typing_start(self, typing: discord.Typing) -> None:
|
||||
if typing.channel_id not in self.app.db.list_channels():
|
||||
|
@ -533,7 +585,8 @@ class DiscordClient(Gateway):
|
|||
"""
|
||||
|
||||
# Check the cache first.
|
||||
webhook = self.webhook_cache.get(channel_id)
|
||||
with Cache.lock:
|
||||
webhook = Cache.cache["d_webhooks"].get(channel_id)
|
||||
|
||||
if webhook:
|
||||
return webhook
|
||||
|
@ -551,24 +604,26 @@ class DiscordClient(Gateway):
|
|||
if not webhook:
|
||||
webhook = self.create_webhook(channel_id, name)
|
||||
|
||||
self.webhook_cache[channel_id] = webhook
|
||||
with Cache.lock:
|
||||
Cache.cache["d_webhooks"][channel_id] = webhook
|
||||
|
||||
return webhook
|
||||
|
||||
def process_message(self, message: discord.Message) -> Tuple[str, str]:
|
||||
def process_message(self, message: discord.Message) -> Tuple[str, Dict]:
|
||||
content = message.content
|
||||
emotes = {}
|
||||
regex = r"<a?:(\w+):(\d+)>"
|
||||
|
||||
# Mentions can either be in the form of `<@1234>` or `<@!1234>`.
|
||||
for char in ("", "!"):
|
||||
for member in message.mentions:
|
||||
for char in ("", "!"):
|
||||
content = content.replace(
|
||||
f"<@{char}{member.id}>", f"@{member.username}"
|
||||
)
|
||||
|
||||
# `except_deleted` for invalid channels.
|
||||
for channel in re.findall(r"<#([0-9]+)>", content):
|
||||
# TODO can this block for too long ?
|
||||
for channel in re.findall(r"<#([0-9]{{{discord.ID_LEN}}})>", content):
|
||||
channel_ = except_deleted(self.get_channel)(channel)
|
||||
content = content.replace(
|
||||
f"<#{channel}>",
|
||||
|
@ -613,6 +668,16 @@ def config_gen(basedir: str, config_file: str) -> dict:
|
|||
return json.loads(f.read())
|
||||
|
||||
|
||||
def excepthook(exc_type, exc_value, exc_traceback):
|
||||
if issubclass(exc_type, KeyboardInterrupt):
|
||||
sys.__excepthook__(exc_type, exc_value, exc_traceback)
|
||||
return
|
||||
|
||||
logging.critical(
|
||||
"Unknown exception:", exc_info=(exc_type, exc_value, exc_traceback)
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
try:
|
||||
basedir = sys.argv[1]
|
||||
|
@ -634,9 +699,9 @@ def main() -> None:
|
|||
],
|
||||
)
|
||||
|
||||
http = urllib3.PoolManager(maxsize=10)
|
||||
sys.excepthook = excepthook
|
||||
|
||||
app = MatrixClient(config, http)
|
||||
app = MatrixClient(config, urllib3.PoolManager(maxsize=10))
|
||||
|
||||
# Start the bottle app in a separate thread.
|
||||
app_thread = threading.Thread(
|
||||
|
@ -645,7 +710,7 @@ def main() -> None:
|
|||
app_thread.start()
|
||||
|
||||
try:
|
||||
asyncio.run(app.discord.start())
|
||||
asyncio.run(app.discord.run())
|
||||
except KeyboardInterrupt:
|
||||
sys.exit()
|
||||
|
||||
|
|
|
@ -2,20 +2,21 @@ from dataclasses import dataclass
|
|||
|
||||
|
||||
@dataclass
|
||||
class User(object):
|
||||
class User:
|
||||
avatar_url: str = ""
|
||||
displayname: str = ""
|
||||
display_name: str = ""
|
||||
|
||||
|
||||
class Event(object):
|
||||
class Event:
|
||||
def __init__(self, event: dict):
|
||||
content = event["content"]
|
||||
content = event.get("content", {})
|
||||
|
||||
self.attachment = content.get("url")
|
||||
self.author = event["author"]
|
||||
self.body = content.get("body", "").strip()
|
||||
self.event_id = event["event_id"]
|
||||
self.formatted_body = content.get("formatted_body", "")
|
||||
self.id = event["event_id"]
|
||||
self.is_direct = content.get("is_direct", False)
|
||||
self.redacts = event.get("redacts", "")
|
||||
self.room_id = event["room_id"]
|
||||
self.sender = event["sender"]
|
||||
self.state_key = event.get("state_key", "")
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import asyncio
|
||||
import json
|
||||
from dataclasses import fields
|
||||
from typing import Any
|
||||
|
@ -8,13 +7,13 @@ import urllib3
|
|||
from errors import RequestError
|
||||
|
||||
|
||||
def dict_cls(dict_var: dict, cls: Any) -> Any:
|
||||
def dict_cls(d: dict, cls: Any) -> Any:
|
||||
"""
|
||||
Create a dataclass from a dictionary.
|
||||
"""
|
||||
|
||||
field_names = set(f.name for f in fields(cls))
|
||||
filtered_dict = {k: v for k, v in dict_var.items() if k in field_names}
|
||||
filtered_dict = {k: v for k, v in d.items() if k in field_names}
|
||||
|
||||
return cls(**filtered_dict)
|
||||
|
||||
|
@ -34,22 +33,6 @@ def log_except(fn):
|
|||
return wrapper
|
||||
|
||||
|
||||
def wrap_async(fn):
|
||||
"""
|
||||
Call an asynchronous function from a synchronous one.
|
||||
"""
|
||||
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not self.loop:
|
||||
raise RuntimeError("loop is None.")
|
||||
|
||||
return asyncio.run_coroutine_threadsafe(
|
||||
fn(self, *args, **kwargs), loop=self.loop
|
||||
).result()
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def request(fn):
|
||||
"""
|
||||
Either return json data or raise a `RequestError` if the request was
|
||||
|
@ -75,8 +58,7 @@ def request(fn):
|
|||
|
||||
def except_deleted(fn):
|
||||
"""
|
||||
Ignore the `RequestError` on 404s, the message might have been
|
||||
deleted by someone else already.
|
||||
Ignore the `RequestError` on 404s, the content might have been removed.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
bottle==0.12.19
|
||||
urllib3==1.26.3
|
||||
websockets==8.1
|
||||
bottle
|
||||
urllib3
|
||||
websockets
|
||||
|
|
Loading…
Reference in a new issue