cleanup
* Use websocket events to sync usernames/avatars instead of periodic syncing. * Use caches for fetched room_ids, room state (for usernames and avatars). Also switch to using a single cache with locks. * Don't store full message objects in cache, just store the relation of matrix event IDs to discord message IDs and vice-versa. Content can be fetched from the server instead. * Don't rely on websocket events for mentioning Discord users, mentions are now done by mentioning the dummy matrix user. The ID to be mentioned is extracted from the MXID instead. * General clean-ups.
This commit is contained in:
parent
b21c82ccd0
commit
4713a00016
10 changed files with 454 additions and 440 deletions
|
@ -49,20 +49,23 @@ A path can optionally be passed as the first argument to `main.py`. This path wi
|
||||||
Eg. Running `python3 main.py /path/to/my/dir` will store the database and logs in `/path/to/my/dir`.
|
Eg. Running `python3 main.py /path/to/my/dir` will store the database and logs in `/path/to/my/dir`.
|
||||||
`$PWD` is used by default if no path is specified.
|
`$PWD` is used by default if no path is specified.
|
||||||
|
|
||||||
|
After setting up the bridge, send a direct message to `@appservice-discord:domain.tld` containing the channel ID to be bridged (`!bridge 123456`).
|
||||||
|
|
||||||
This bridge is written with:
|
This bridge is written with:
|
||||||
* `bottle`: Receiving events from the homeserver.
|
* `bottle`: Receiving events from the homeserver.
|
||||||
* `urllib3`: Sending requests, thread safety.
|
* `urllib3`: Sending requests, thread safety.
|
||||||
* `websockets`: Connecting to Discord. (Big thanks to an anonymous person "nesslersreagent" for figuring out the initial connection mess.)
|
* `websockets`: Connecting to Discord. (Big thanks to an anonymous person "nesslersreagent" for figuring out the initial connection mess.)
|
||||||
|
|
||||||
## NOTES
|
## NOTES
|
||||||
|
|
||||||
* A basic sqlite database is used for keeping track of bridged rooms.
|
* A basic sqlite database is used for keeping track of bridged rooms.
|
||||||
|
|
||||||
* Logs are saved to the `appservice.log` file in `$PWD` or the specified directory.
|
* Logs are saved to the `appservice.log` file in `$PWD` or the specified directory.
|
||||||
|
|
||||||
* For avatars to show up on Discord, you must have a [reverse proxy](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf) set up on your homeserver as the bridge does not specify the homeserver port when passing the avatar url.
|
* For avatars to show up on Discord, you must have a [reverse proxy](https://github.com/matrix-org/dendrite/blob/master/docs/nginx/monolith-sample.conf) set up on your homeserver as the bridge does not specify the homeserver port when passing the avatar url.
|
||||||
|
|
||||||
* It is not possible to add normal Discord bot functionality like commands as this bridge does not use `discord.py`.
|
* It is not possible to add "normal" Discord bot functionality like commands as this bridge does not use `discord.py`.
|
||||||
|
|
||||||
* [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) must be enabled for your Discord bot.
|
* [Privileged Intents](https://discordpy.readthedocs.io/en/latest/intents.html#privileged-intents) for members and presence must be enabled for your Discord bot.
|
||||||
|
|
||||||
* This Appservice might not work well for bridging a large number of rooms since it is mostly synchronous. However, it wouldn't take much effort to port it to `asyncio` and `aiohttp` if desired.
|
* This Appservice might not work well for bridging a large number of rooms since it is mostly synchronous. However, it wouldn't take much effort to port it to `asyncio` and `aiohttp` if desired.
|
||||||
|
|
|
@ -2,13 +2,14 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Union
|
from typing import Union
|
||||||
|
|
||||||
import bottle
|
import bottle
|
||||||
import urllib3
|
import urllib3
|
||||||
|
|
||||||
import matrix
|
import matrix
|
||||||
from misc import dict_cls, except_deleted, log_except, request
|
from cache import Cache
|
||||||
|
from misc import log_except, request
|
||||||
|
|
||||||
|
|
||||||
class AppService(bottle.Bottle):
|
class AppService(bottle.Bottle):
|
||||||
|
@ -37,13 +38,17 @@ class AppService(bottle.Bottle):
|
||||||
method="PUT",
|
method="PUT",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Cache.cache["m_rooms"] = {}
|
||||||
|
|
||||||
def handle_event(self, event: dict) -> None:
|
def handle_event(self, event: dict) -> None:
|
||||||
event_type = event.get("type")
|
event_type = event.get("type")
|
||||||
|
|
||||||
if event_type == "m.room.member" or event_type == "m.room.message":
|
if event_type in (
|
||||||
obj = self.get_event_object(event)
|
"m.room.member",
|
||||||
elif event_type == "m.room.redaction":
|
"m.room.message",
|
||||||
obj = event
|
"m.room.redaction",
|
||||||
|
):
|
||||||
|
obj = matrix.Event(event)
|
||||||
else:
|
else:
|
||||||
self.logger.info(f"Unknown event type: {event_type}")
|
self.logger.info(f"Unknown event type: {event_type}")
|
||||||
return
|
return
|
||||||
|
@ -86,23 +91,14 @@ class AppService(bottle.Bottle):
|
||||||
def mxc_url(self, mxc: str) -> str:
|
def mxc_url(self, mxc: str) -> str:
|
||||||
try:
|
try:
|
||||||
homeserver, media_id = mxc.replace("mxc://", "").split("/")
|
homeserver, media_id = mxc.replace("mxc://", "").split("/")
|
||||||
converted = (
|
|
||||||
f"https://{self.server_name}/_matrix/media/r0/download/"
|
|
||||||
f"{homeserver}/{media_id}"
|
|
||||||
)
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
converted = ""
|
return ""
|
||||||
|
|
||||||
return converted
|
return (
|
||||||
|
f"https://{self.server_name}/_matrix/media/r0/download/"
|
||||||
def get_event_object(self, event: dict) -> matrix.Event:
|
f"{homeserver}/{media_id}"
|
||||||
# TODO use caching and invalidate old cache on member events.
|
|
||||||
event["author"] = dict_cls(
|
|
||||||
self.get_profile(event["sender"]), matrix.User
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return matrix.Event(event)
|
|
||||||
|
|
||||||
def join_room(self, room_id: str, mxid: str = "") -> None:
|
def join_room(self, room_id: str, mxid: str = "") -> None:
|
||||||
self.send(
|
self.send(
|
||||||
"POST",
|
"POST",
|
||||||
|
@ -117,42 +113,25 @@ class AppService(bottle.Bottle):
|
||||||
params={"user_id": mxid} if mxid else {},
|
params={"user_id": mxid} if mxid else {},
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_profile(self, mxid: str) -> dict:
|
|
||||||
resp = except_deleted(self.send)("GET", f"/profile/{mxid}")
|
|
||||||
|
|
||||||
# No profile exists for the user.
|
|
||||||
if not resp:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
avatar_url = resp.get("avatar_url")
|
|
||||||
|
|
||||||
if avatar_url:
|
|
||||||
avatar_url = self.mxc_url(avatar_url)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"avatar_url": avatar_url,
|
|
||||||
"displayname": resp.get("displayname"),
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_members(self, room_id: str) -> List[str]:
|
|
||||||
resp = self.send(
|
|
||||||
"GET",
|
|
||||||
f"/rooms/{room_id}/members",
|
|
||||||
params={"membership": "join", "not_membership": "leave"},
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
|
||||||
content["sender"]
|
|
||||||
for content in resp["chunk"]
|
|
||||||
if content["content"]["membership"] == "join"
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_room_id(self, alias: str) -> str:
|
def get_room_id(self, alias: str) -> str:
|
||||||
|
with Cache.lock:
|
||||||
|
room = Cache.cache["m_rooms"].get(alias)
|
||||||
|
if room:
|
||||||
|
return room
|
||||||
|
|
||||||
resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}")
|
resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}")
|
||||||
|
|
||||||
# TODO cache ?
|
room_id = resp["room_id"]
|
||||||
|
|
||||||
return resp["room_id"]
|
with Cache.lock:
|
||||||
|
Cache.cache["m_rooms"][alias] = room_id
|
||||||
|
|
||||||
|
return room_id
|
||||||
|
|
||||||
|
def get_event(self, event_id: str, room_id: str) -> matrix.Event:
|
||||||
|
resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}")
|
||||||
|
|
||||||
|
return matrix.Event(resp)
|
||||||
|
|
||||||
def upload(self, url: str) -> str:
|
def upload(self, url: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -211,12 +190,12 @@ class AppService(bottle.Bottle):
|
||||||
) -> dict:
|
) -> dict:
|
||||||
params["access_token"] = self.as_token
|
params["access_token"] = self.as_token
|
||||||
headers = {"Content-Type": content_type}
|
headers = {"Content-Type": content_type}
|
||||||
content = json.dumps(content) if isinstance(content, dict) else content
|
payload = json.dumps(content) if isinstance(content, dict) else content
|
||||||
endpoint = (
|
endpoint = (
|
||||||
f"{self.base_url}{endpoint}{path}?"
|
f"{self.base_url}{endpoint}{path}?"
|
||||||
f"{urllib.parse.urlencode(params)}"
|
f"{urllib.parse.urlencode(params)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.http.request(
|
return self.http.request(
|
||||||
method, endpoint, body=content, headers=headers
|
method, endpoint, body=payload, headers=headers
|
||||||
)
|
)
|
||||||
|
|
6
appservice/cache.py
Normal file
6
appservice/cache.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
class Cache:
|
||||||
|
cache = {}
|
||||||
|
lock = threading.Lock()
|
|
@ -4,11 +4,11 @@ import threading
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
class DataBase(object):
|
class DataBase:
|
||||||
def __init__(self, db_file) -> None:
|
def __init__(self, db_file) -> None:
|
||||||
self.create(db_file)
|
self.create(db_file)
|
||||||
|
|
||||||
# The database is accessed via both the threads.
|
# The database is accessed via multiple threads.
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
def create(self, db_file) -> None:
|
def create(self, db_file) -> None:
|
||||||
|
@ -92,7 +92,7 @@ class DataBase(object):
|
||||||
|
|
||||||
room = self.cur.fetchone()
|
room = self.cur.fetchone()
|
||||||
|
|
||||||
# Return an empty string if nothing is bridged.
|
# Return an empty string if the channel is not bridged.
|
||||||
return "" if not room else room["channel_id"]
|
return "" if not room else room["channel_id"]
|
||||||
|
|
||||||
def list_channels(self) -> List[str]:
|
def list_channels(self) -> List[str]:
|
||||||
|
@ -116,6 +116,8 @@ class DataBase(object):
|
||||||
self.cur.execute("SELECT * FROM users")
|
self.cur.execute("SELECT * FROM users")
|
||||||
users = self.cur.fetchall()
|
users = self.cur.fetchall()
|
||||||
|
|
||||||
user = [user for user in users if user["mxid"] == mxid]
|
user: dict = next(
|
||||||
|
iter([user for user in users if user["mxid"] == mxid]), {}
|
||||||
|
)
|
||||||
|
|
||||||
return user[0] if user else {}
|
return user
|
||||||
|
|
|
@ -1,10 +1,16 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from misc import dict_cls
|
||||||
|
|
||||||
CDN_URL = "https://cdn.discordapp.com"
|
CDN_URL = "https://cdn.discordapp.com"
|
||||||
|
ID_LEN = 18
|
||||||
|
|
||||||
|
|
||||||
|
def bitmask(bit: int) -> int:
|
||||||
|
return 1 << bit
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Channel(object):
|
class Channel:
|
||||||
id: str
|
id: str
|
||||||
type: str
|
type: str
|
||||||
guild_id: str = ""
|
guild_id: str = ""
|
||||||
|
@ -13,13 +19,32 @@ class Channel(object):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Emote(object):
|
class Emote:
|
||||||
animated: bool
|
animated: bool
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class User(object):
|
@dataclass
|
||||||
|
class MessageReference:
|
||||||
|
message_id: str
|
||||||
|
channel_id: str
|
||||||
|
guild_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Typing:
|
||||||
|
user_id: str
|
||||||
|
channel_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Webhook:
|
||||||
|
id: str
|
||||||
|
token: str
|
||||||
|
|
||||||
|
|
||||||
|
class User:
|
||||||
def __init__(self, user: dict) -> None:
|
def __init__(self, user: dict) -> None:
|
||||||
self.discriminator = user["discriminator"]
|
self.discriminator = user["discriminator"]
|
||||||
self.id = user["id"]
|
self.id = user["id"]
|
||||||
|
@ -38,45 +63,57 @@ class User(object):
|
||||||
self.avatar_url = f"{CDN_URL}/avatars/{self.id}/{avatar}.{ext}"
|
self.avatar_url = f"{CDN_URL}/avatars/{self.id}/{avatar}.{ext}"
|
||||||
|
|
||||||
|
|
||||||
class Message(object):
|
class Guild:
|
||||||
|
def __init__(self, guild: dict) -> None:
|
||||||
|
self.guild_id = guild["id"]
|
||||||
|
self.channels = [dict_cls(c, Channel) for c in guild["channels"]]
|
||||||
|
self.emojis = [dict_cls(e, Emote) for e in guild["emojis"]]
|
||||||
|
members = [member["user"] for member in guild["members"]]
|
||||||
|
self.members = [User(m) for m in members]
|
||||||
|
|
||||||
|
|
||||||
|
class GuildEmojisUpdate:
|
||||||
|
def __init__(self, update: dict) -> None:
|
||||||
|
self.guild_id = update["guild_id"]
|
||||||
|
self.emojis = [dict_cls(e, Emote) for e in update["emojis"]]
|
||||||
|
|
||||||
|
|
||||||
|
class GuildMembersChunk:
|
||||||
|
def __init__(self, chunk: dict) -> None:
|
||||||
|
self.chunk_index = chunk["chunk_index"]
|
||||||
|
self.chunk_count = chunk["chunk_count"]
|
||||||
|
self.guild_id = chunk["guild_id"]
|
||||||
|
self.members = [User(m) for m in chunk["members"]]
|
||||||
|
|
||||||
|
|
||||||
|
class GuildMemberUpdate:
|
||||||
|
def __init__(self, update: dict) -> None:
|
||||||
|
self.guild_id = update["guild_id"]
|
||||||
|
self.user = User(update["user"])
|
||||||
|
|
||||||
|
|
||||||
|
class Message:
|
||||||
def __init__(self, message: dict) -> None:
|
def __init__(self, message: dict) -> None:
|
||||||
self.attachments = message.get("attachments", [])
|
self.attachments = message.get("attachments", [])
|
||||||
self.channel_id = message["channel_id"]
|
self.channel_id = message["channel_id"]
|
||||||
self.content = message.get("content", "")
|
self.content = message.get("content", "")
|
||||||
self.id = message["id"]
|
self.id = message["id"]
|
||||||
self.reference = message.get("message_reference", {}).get(
|
|
||||||
"message_id", ""
|
|
||||||
)
|
|
||||||
self.webhook_id = message.get("webhook_id", "")
|
self.webhook_id = message.get("webhook_id", "")
|
||||||
|
|
||||||
self.mentions = [
|
self.mentions = [
|
||||||
User(mention) for mention in message.get("mentions", [])
|
User(mention) for mention in message.get("mentions", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
ref = message.get("message_reference")
|
||||||
|
|
||||||
|
self.reference = dict_cls(ref, MessageReference) if ref else None
|
||||||
|
|
||||||
author = message.get("author")
|
author = message.get("author")
|
||||||
|
|
||||||
self.author = User(author) if author else None
|
self.author = User(author) if author else None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class ChannelType:
|
||||||
class DeletedMessage(object):
|
|
||||||
channel_id: str
|
|
||||||
id: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Typing(object):
|
|
||||||
user_id: str
|
|
||||||
channel_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Webhook(object):
|
|
||||||
id: str
|
|
||||||
token: str
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelType(object):
|
|
||||||
GUILD_TEXT = 0
|
GUILD_TEXT = 0
|
||||||
DM = 1
|
DM = 1
|
||||||
GUILD_VOICE = 2
|
GUILD_VOICE = 2
|
||||||
|
@ -86,7 +123,7 @@ class ChannelType(object):
|
||||||
GUILD_STORE = 6
|
GUILD_STORE = 6
|
||||||
|
|
||||||
|
|
||||||
class InteractionResponseType(object):
|
class InteractionResponseType:
|
||||||
PONG = 0
|
PONG = 0
|
||||||
ACKNOWLEDGE = 1
|
ACKNOWLEDGE = 1
|
||||||
CHANNEL_MESSAGE = 2
|
CHANNEL_MESSAGE = 2
|
||||||
|
@ -94,10 +131,7 @@ class InteractionResponseType(object):
|
||||||
ACKNOWLEDGE_WITH_SOURCE = 5
|
ACKNOWLEDGE_WITH_SOURCE = 5
|
||||||
|
|
||||||
|
|
||||||
class GatewayIntents(object):
|
class GatewayIntents:
|
||||||
def bitmask(bit: int) -> int:
|
|
||||||
return 1 << bit
|
|
||||||
|
|
||||||
GUILDS = bitmask(0)
|
GUILDS = bitmask(0)
|
||||||
GUILD_MEMBERS = bitmask(1)
|
GUILD_MEMBERS = bitmask(1)
|
||||||
GUILD_BANS = bitmask(2)
|
GUILD_BANS = bitmask(2)
|
||||||
|
@ -115,7 +149,7 @@ class GatewayIntents(object):
|
||||||
DIRECT_MESSAGE_TYPING = bitmask(14)
|
DIRECT_MESSAGE_TYPING = bitmask(14)
|
||||||
|
|
||||||
|
|
||||||
class GatewayOpCodes(object):
|
class GatewayOpCodes:
|
||||||
DISPATCH = 0
|
DISPATCH = 0
|
||||||
HEARTBEAT = 1
|
HEARTBEAT = 1
|
||||||
IDENTIFY = 2
|
IDENTIFY = 2
|
||||||
|
@ -129,7 +163,7 @@ class GatewayOpCodes(object):
|
||||||
HEARTBEAT_ACK = 11
|
HEARTBEAT_ACK = 11
|
||||||
|
|
||||||
|
|
||||||
class Payloads(object):
|
class Payloads:
|
||||||
def __init__(self, token: str) -> None:
|
def __init__(self, token: str) -> None:
|
||||||
self.seq = self.session = None
|
self.seq = self.session = None
|
||||||
self.token = token
|
self.token = token
|
||||||
|
@ -143,27 +177,19 @@ class Payloads(object):
|
||||||
"d": {
|
"d": {
|
||||||
"token": self.token,
|
"token": self.token,
|
||||||
"intents": GatewayIntents.GUILDS
|
"intents": GatewayIntents.GUILDS
|
||||||
|
| GatewayIntents.GUILD_EMOJIS
|
||||||
|
| GatewayIntents.GUILD_MEMBERS
|
||||||
| GatewayIntents.GUILD_MESSAGES
|
| GatewayIntents.GUILD_MESSAGES
|
||||||
| GatewayIntents.GUILD_MESSAGE_TYPING,
|
| GatewayIntents.GUILD_MESSAGE_TYPING
|
||||||
|
| GatewayIntents.GUILD_PRESENCES,
|
||||||
"properties": {
|
"properties": {
|
||||||
"$os": "discord",
|
"$os": "discord",
|
||||||
"$browser": "discord",
|
"$browser": "Discord Client",
|
||||||
"$device": "discord",
|
"$device": "discord",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def QUERY(self, guild_id: str, query: str, limit: int = 1) -> dict:
|
|
||||||
"""
|
|
||||||
Return the Payload to query a member from a guild ID.
|
|
||||||
Return only a single match if `limit` isn't specified.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return {
|
|
||||||
"op": GatewayOpCodes.REQUEST_GUILD_MEMBERS,
|
|
||||||
"d": {"guild_id": guild_id, "query": query, "limit": limit},
|
|
||||||
}
|
|
||||||
|
|
||||||
def RESUME(self) -> dict:
|
def RESUME(self) -> dict:
|
||||||
return {
|
return {
|
||||||
"op": GatewayOpCodes.RESUME,
|
"op": GatewayOpCodes.RESUME,
|
||||||
|
|
|
@ -8,31 +8,27 @@ import urllib3
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from misc import dict_cls, log_except, request, wrap_async
|
from misc import dict_cls, log_except, request
|
||||||
|
|
||||||
|
|
||||||
class Gateway(object):
|
class Gateway:
|
||||||
def __init__(self, http: urllib3.PoolManager, token: str):
|
def __init__(self, http: urllib3.PoolManager, token: str):
|
||||||
self.http = http
|
self.http = http
|
||||||
self.token = token
|
self.token = token
|
||||||
self.logger = logging.getLogger("discord")
|
self.logger = logging.getLogger("discord")
|
||||||
self.cdn_url = "https://cdn.discordapp.com"
|
|
||||||
self.Payloads = discord.Payloads(self.token)
|
self.Payloads = discord.Payloads(self.token)
|
||||||
self.loop = self.websocket = None
|
self.websocket = None
|
||||||
|
|
||||||
self.query_cache = {}
|
|
||||||
|
|
||||||
@log_except
|
@log_except
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
self.loop = asyncio.get_running_loop()
|
self.heartbeat_task: asyncio.Future = None
|
||||||
self.query_ev = asyncio.Event()
|
|
||||||
|
|
||||||
self.heartbeat_task = None
|
|
||||||
self.resume = False
|
self.resume = False
|
||||||
|
|
||||||
|
gateway_url = self.get_gateway_url()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await self.gateway_handler(self.get_gateway_url())
|
await self.gateway_handler(gateway_url)
|
||||||
except (
|
except (
|
||||||
websockets.ConnectionClosedError,
|
websockets.ConnectionClosedError,
|
||||||
websockets.InvalidMessage,
|
websockets.InvalidMessage,
|
||||||
|
@ -55,26 +51,71 @@ class Gateway(object):
|
||||||
await asyncio.sleep(interval_ms / 1000)
|
await asyncio.sleep(interval_ms / 1000)
|
||||||
await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT()))
|
await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT()))
|
||||||
|
|
||||||
def query_handler(self, data: dict) -> None:
|
async def handle_resp(self, data: dict) -> None:
|
||||||
members = data["members"]
|
data_dict = data["d"]
|
||||||
guild_id = data["guild_id"]
|
|
||||||
|
|
||||||
for member in members:
|
opcode = data["op"]
|
||||||
user = member["user"]
|
|
||||||
self.query_cache[guild_id].append(user)
|
|
||||||
|
|
||||||
self.query_ev.set()
|
seq = data["s"]
|
||||||
|
|
||||||
|
if seq:
|
||||||
|
self.Payloads.seq = seq
|
||||||
|
|
||||||
|
if opcode == discord.GatewayOpCodes.DISPATCH:
|
||||||
|
otype = data["t"]
|
||||||
|
|
||||||
|
if otype == "READY":
|
||||||
|
self.Payloads.session = data_dict["session_id"]
|
||||||
|
|
||||||
|
self.logger.info("READY")
|
||||||
|
else:
|
||||||
|
self.handle_otype(data_dict, otype)
|
||||||
|
elif opcode == discord.GatewayOpCodes.HELLO:
|
||||||
|
heartbeat_interval = data_dict.get("heartbeat_interval")
|
||||||
|
|
||||||
|
self.logger.info(f"Heartbeat Interval: {heartbeat_interval}")
|
||||||
|
|
||||||
|
# Send periodic hearbeats to gateway.
|
||||||
|
self.heartbeat_task = asyncio.ensure_future(
|
||||||
|
self.heartbeat_handler(heartbeat_interval)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.websocket.send(
|
||||||
|
json.dumps(
|
||||||
|
self.Payloads.RESUME()
|
||||||
|
if self.resume
|
||||||
|
else self.Payloads.IDENTIFY()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif opcode == discord.GatewayOpCodes.RECONNECT:
|
||||||
|
self.logger.info("Received RECONNECT.")
|
||||||
|
|
||||||
|
self.resume = True
|
||||||
|
await self.websocket.close()
|
||||||
|
elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
|
||||||
|
self.logger.info("Received INVALID_SESSION.")
|
||||||
|
|
||||||
|
self.resume = False
|
||||||
|
await self.websocket.close()
|
||||||
|
elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
|
||||||
|
# NOP
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.logger.info(
|
||||||
|
"Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}"
|
||||||
|
)
|
||||||
|
|
||||||
def handle_otype(self, data: dict, otype: str) -> None:
|
def handle_otype(self, data: dict, otype: str) -> None:
|
||||||
if otype == "MESSAGE_CREATE" or otype == "MESSAGE_UPDATE":
|
if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"):
|
||||||
obj = discord.Message(data)
|
obj = discord.Message(data)
|
||||||
elif otype == "MESSAGE_DELETE":
|
|
||||||
obj = dict_cls(data, discord.DeletedMessage)
|
|
||||||
elif otype == "TYPING_START":
|
elif otype == "TYPING_START":
|
||||||
obj = dict_cls(data, discord.Typing)
|
obj = dict_cls(data, discord.Typing)
|
||||||
elif otype == "GUILD_MEMBERS_CHUNK":
|
elif otype == "GUILD_CREATE":
|
||||||
self.query_handler(data)
|
obj = discord.Guild(data)
|
||||||
return
|
elif otype == "GUILD_MEMBER_UPDATE":
|
||||||
|
obj = discord.GuildMemberUpdate(data)
|
||||||
|
elif otype == "GUILD_EMOJIS_UPDATE":
|
||||||
|
obj = discord.GuildEmojisUpdate(data)
|
||||||
else:
|
else:
|
||||||
self.logger.info(f"Unknown OTYPE: {otype}")
|
self.logger.info(f"Unknown OTYPE: {otype}")
|
||||||
return
|
return
|
||||||
|
@ -90,119 +131,20 @@ class Gateway(object):
|
||||||
try:
|
try:
|
||||||
func(obj)
|
func(obj)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.logger.exception(f"Ignoring exception in {func}:")
|
self.logger.exception(f"Ignoring exception in '{func.__name__}':")
|
||||||
|
|
||||||
async def gateway_handler(self, gateway_url: str) -> None:
|
async def gateway_handler(self, gateway_url: str) -> None:
|
||||||
async with websockets.connect(
|
async with websockets.connect(
|
||||||
f"{gateway_url}/?v=8&encoding=json"
|
f"{gateway_url}/?v=8&encoding=json"
|
||||||
) as websocket:
|
) as websocket:
|
||||||
self.websocket = websocket
|
self.websocket = websocket
|
||||||
|
|
||||||
async for message in websocket:
|
async for message in websocket:
|
||||||
data = json.loads(message)
|
await self.handle_resp(json.loads(message))
|
||||||
data_dict = data.get("d")
|
|
||||||
|
|
||||||
opcode = data.get("op")
|
|
||||||
|
|
||||||
seq = data.get("s")
|
|
||||||
if seq:
|
|
||||||
self.Payloads.seq = seq
|
|
||||||
|
|
||||||
if opcode == discord.GatewayOpCodes.DISPATCH:
|
|
||||||
otype = data.get("t")
|
|
||||||
|
|
||||||
if otype == "READY":
|
|
||||||
self.Payloads.session = data_dict["session_id"]
|
|
||||||
|
|
||||||
self.logger.info("READY")
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.handle_otype(data_dict, otype)
|
|
||||||
|
|
||||||
elif opcode == discord.GatewayOpCodes.HELLO:
|
|
||||||
heartbeat_interval = data_dict.get("heartbeat_interval")
|
|
||||||
|
|
||||||
self.logger.info(
|
|
||||||
f"Heartbeat Interval: {heartbeat_interval}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Send periodic hearbeats to gateway.
|
|
||||||
self.heartbeat_task = asyncio.ensure_future(
|
|
||||||
self.heartbeat_handler(heartbeat_interval)
|
|
||||||
)
|
|
||||||
|
|
||||||
await websocket.send(
|
|
||||||
json.dumps(
|
|
||||||
self.Payloads.RESUME()
|
|
||||||
if self.resume
|
|
||||||
else self.Payloads.IDENTIFY()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif opcode == discord.GatewayOpCodes.RECONNECT:
|
|
||||||
self.logger.info("Received RECONNECT.")
|
|
||||||
|
|
||||||
self.resume = True
|
|
||||||
await websocket.close()
|
|
||||||
|
|
||||||
elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
|
|
||||||
self.logger.info("Received INVALID_SESSION.")
|
|
||||||
|
|
||||||
self.resume = False
|
|
||||||
await websocket.close()
|
|
||||||
|
|
||||||
elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
|
|
||||||
# NOP
|
|
||||||
pass
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.logger.info(
|
|
||||||
f"Unknown OP code {opcode}:\n"
|
|
||||||
f"{json.dumps(data, indent=4)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@wrap_async
|
|
||||||
async def query_member(self, guild_id: str, name: str) -> discord.User:
|
|
||||||
"""
|
|
||||||
Query the members for a given guild and return the first match.
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.query_ev.clear()
|
|
||||||
|
|
||||||
def query():
|
|
||||||
if not self.query_cache.get(guild_id):
|
|
||||||
self.query_cache[guild_id] = []
|
|
||||||
|
|
||||||
user = [
|
|
||||||
user
|
|
||||||
for user in self.query_cache[guild_id]
|
|
||||||
if name.lower() in user["username"].lower()
|
|
||||||
]
|
|
||||||
|
|
||||||
return None if not user else discord.User(user[0])
|
|
||||||
|
|
||||||
user = query()
|
|
||||||
|
|
||||||
if user:
|
|
||||||
return user
|
|
||||||
|
|
||||||
if not self.websocket or self.websocket.closed:
|
|
||||||
self.logger.warning("Not fetching members, websocket closed.")
|
|
||||||
return
|
|
||||||
|
|
||||||
await self.websocket.send(
|
|
||||||
json.dumps(self.Payloads.QUERY(guild_id, name))
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO clean this mess.
|
|
||||||
|
|
||||||
# Wait for our websocket to receive the chunk.
|
|
||||||
await asyncio.wait_for(self.query_ev.wait(), timeout=5)
|
|
||||||
|
|
||||||
return query()
|
|
||||||
|
|
||||||
def get_channel(self, channel_id: str) -> discord.Channel:
|
def get_channel(self, channel_id: str) -> discord.Channel:
|
||||||
"""
|
"""
|
||||||
Get the channel object for a given channel ID.
|
Get the channel for a given channel ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
resp = self.send("GET", f"/channels/{channel_id}")
|
resp = self.send("GET", f"/channels/{channel_id}")
|
||||||
|
@ -259,9 +201,17 @@ class Gateway(object):
|
||||||
f"{message_id}",
|
f"{message_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_webhook(self, webhook: discord.Webhook, **kwargs) -> str:
|
def send_webhook(
|
||||||
content = {
|
self,
|
||||||
**kwargs,
|
webhook: discord.Webhook,
|
||||||
|
avatar_url: str,
|
||||||
|
content: str,
|
||||||
|
username: str,
|
||||||
|
) -> discord.Message:
|
||||||
|
payload = {
|
||||||
|
"avatar_url": avatar_url,
|
||||||
|
"content": content,
|
||||||
|
"username": username,
|
||||||
# Disable 'everyone' and 'role' mentions.
|
# Disable 'everyone' and 'role' mentions.
|
||||||
"allowed_mentions": {"parse": ["users"]},
|
"allowed_mentions": {"parse": ["users"]},
|
||||||
}
|
}
|
||||||
|
@ -269,11 +219,11 @@ class Gateway(object):
|
||||||
resp = self.send(
|
resp = self.send(
|
||||||
"POST",
|
"POST",
|
||||||
f"/webhooks/{webhook.id}/{webhook.token}",
|
f"/webhooks/{webhook.id}/{webhook.token}",
|
||||||
content,
|
payload,
|
||||||
{"wait": True},
|
{"wait": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
return resp["id"]
|
return discord.Message(resp)
|
||||||
|
|
||||||
def send_message(self, message: str, channel_id: str) -> None:
|
def send_message(self, message: str, channel_id: str) -> None:
|
||||||
self.send(
|
self.send(
|
||||||
|
@ -294,8 +244,8 @@ class Gateway(object):
|
||||||
}
|
}
|
||||||
|
|
||||||
# 'body' being an empty dict breaks "GET" requests.
|
# 'body' being an empty dict breaks "GET" requests.
|
||||||
content = json.dumps(content) if content else None
|
payload = json.dumps(content) if content else None
|
||||||
|
|
||||||
return self.http.request(
|
return self.http.request(
|
||||||
method, endpoint, body=content, headers=headers
|
method, endpoint, body=payload, headers=headers
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,21 +5,19 @@ import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from typing import Dict, Tuple, Union
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import urllib3
|
import urllib3
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import matrix
|
import matrix
|
||||||
from appservice import AppService
|
from appservice import AppService
|
||||||
|
from cache import Cache
|
||||||
from db import DataBase
|
from db import DataBase
|
||||||
from errors import RequestError
|
from errors import RequestError
|
||||||
from gateway import Gateway
|
from gateway import Gateway
|
||||||
from misc import dict_cls, except_deleted, hash_str
|
from misc import dict_cls, except_deleted, hash_str
|
||||||
|
|
||||||
# TODO should this be cleared periodically ?
|
|
||||||
message_cache: Dict[str, Union[discord.Webhook, str]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
class MatrixClient(AppService):
|
class MatrixClient(AppService):
|
||||||
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
|
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
|
||||||
|
@ -27,9 +25,11 @@ class MatrixClient(AppService):
|
||||||
|
|
||||||
self.db = DataBase(config["database"])
|
self.db = DataBase(config["database"])
|
||||||
self.discord = DiscordClient(self, config, http)
|
self.discord = DiscordClient(self, config, http)
|
||||||
self.emote_cache: Dict[str, str] = {}
|
|
||||||
self.format = "_discord_" # "{@,#}_discord_1234:localhost"
|
self.format = "_discord_" # "{@,#}_discord_1234:localhost"
|
||||||
|
|
||||||
|
for k in ("m_emotes", "m_members", "m_messages"):
|
||||||
|
Cache.cache[k] = {}
|
||||||
|
|
||||||
def handle_bridge(self, message: matrix.Event) -> None:
|
def handle_bridge(self, message: matrix.Event) -> None:
|
||||||
# Ignore events that aren't for us.
|
# Ignore events that aren't for us.
|
||||||
if message.sender.split(":")[
|
if message.sender.split(":")[
|
||||||
|
@ -37,6 +37,7 @@ class MatrixClient(AppService):
|
||||||
] != self.server_name or not message.body.startswith("!bridge"):
|
] != self.server_name or not message.body.startswith("!bridge"):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Get the channel ID.
|
||||||
try:
|
try:
|
||||||
channel = message.body.split()[1]
|
channel = message.body.split()[1]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
|
@ -46,7 +47,7 @@ class MatrixClient(AppService):
|
||||||
try:
|
try:
|
||||||
channel = self.discord.get_channel(channel)
|
channel = self.discord.get_channel(channel)
|
||||||
except RequestError as e:
|
except RequestError as e:
|
||||||
# The channel can be invalid or we may not have permission.
|
# The channel can be invalid or we may not have permissions.
|
||||||
self.logger.warning(f"Failed to fetch channel {channel}: {e}")
|
self.logger.warning(f"Failed to fetch channel {channel}: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -61,7 +62,15 @@ class MatrixClient(AppService):
|
||||||
self.create_room(channel, message.sender)
|
self.create_room(channel, message.sender)
|
||||||
|
|
||||||
def on_member(self, event: matrix.Event) -> None:
|
def on_member(self, event: matrix.Event) -> None:
|
||||||
# Ignore events that aren't for us.
|
with Cache.lock:
|
||||||
|
# Just lazily clear the whole member cache on
|
||||||
|
# membership update events.
|
||||||
|
if event.room_id in Cache.cache["m_members"]:
|
||||||
|
self.logger.info(
|
||||||
|
f"Clearing member cache for room '{event.room_id}'."
|
||||||
|
)
|
||||||
|
del Cache.cache["m_members"][event.room_id]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
event.sender.split(":")[-1] != self.server_name
|
event.sender.split(":")[-1] != self.server_name
|
||||||
or event.state_key != self.user_id
|
or event.state_key != self.user_id
|
||||||
|
@ -70,7 +79,7 @@ class MatrixClient(AppService):
|
||||||
return
|
return
|
||||||
|
|
||||||
# Join the direct message room.
|
# Join the direct message room.
|
||||||
self.logger.info(f"Joining direct message room {event.room_id}.")
|
self.logger.info(f"Joining direct message room '{event.room_id}'.")
|
||||||
self.join_room(event.room_id)
|
self.join_room(event.room_id)
|
||||||
|
|
||||||
def on_message(self, message: matrix.Event) -> None:
|
def on_message(self, message: matrix.Event) -> None:
|
||||||
|
@ -88,51 +97,78 @@ class MatrixClient(AppService):
|
||||||
if not channel_id:
|
if not channel_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
webhook = self.discord.get_webhook(channel_id, "matrix_bridge")
|
author = self.get_members(message.room_id)[message.sender]
|
||||||
|
|
||||||
|
if not author.display_name:
|
||||||
|
author.display_name = message.sender
|
||||||
|
|
||||||
|
webhook = self.discord.get_webhook(
|
||||||
|
channel_id, self.discord.webhook_name
|
||||||
|
)
|
||||||
|
|
||||||
if message.relates_to and message.reltype == "m.replace":
|
if message.relates_to and message.reltype == "m.replace":
|
||||||
relation = message_cache.get(message.relates_to)
|
with Cache.lock:
|
||||||
|
message_id = Cache.cache["m_messages"].get(message.relates_to)
|
||||||
|
|
||||||
if not message.new_body or not relation:
|
if not message_id or not message.new_body:
|
||||||
return
|
return
|
||||||
|
|
||||||
message.new_body = self.process_message(
|
message.new_body = self.process_message(message)
|
||||||
channel_id, message.new_body
|
|
||||||
)
|
|
||||||
|
|
||||||
except_deleted(self.discord.edit_webhook)(
|
except_deleted(self.discord.edit_webhook)(
|
||||||
message.new_body, relation["message_id"], webhook
|
message.new_body, message_id, webhook
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
message.body = (
|
message.body = (
|
||||||
f"`{message.body}`: {self.mxc_url(message.attachment)}"
|
f"`{message.body}`: {self.mxc_url(message.attachment)}"
|
||||||
if message.attachment
|
if message.attachment
|
||||||
else self.process_message(channel_id, message.body)
|
else self.process_message(message)
|
||||||
)
|
)
|
||||||
|
|
||||||
message_cache[message.event_id] = {
|
message_id = self.discord.send_webhook(
|
||||||
"message_id": self.discord.send_webhook(
|
webhook,
|
||||||
webhook,
|
self.mxc_url(author.avatar_url),
|
||||||
avatar_url=message.author.avatar_url,
|
message.body,
|
||||||
content=message.body,
|
author.display_name,
|
||||||
username=message.author.displayname,
|
).id
|
||||||
),
|
|
||||||
"webhook": webhook,
|
|
||||||
}
|
|
||||||
|
|
||||||
@except_deleted
|
with Cache.lock:
|
||||||
def on_redaction(self, event: dict) -> None:
|
Cache.cache["m_messages"][message.id] = message_id
|
||||||
redacts = event["redacts"]
|
|
||||||
|
|
||||||
event = message_cache.get(redacts)
|
def on_redaction(self, event: matrix.Event) -> None:
|
||||||
|
with Cache.lock:
|
||||||
|
message_id = Cache.cache["m_messages"].get(event.redacts)
|
||||||
|
|
||||||
if not event:
|
if not message_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.discord.delete_webhook(event["message_id"], event["webhook"])
|
webhook = self.discord.get_webhook(
|
||||||
|
self.db.get_channel(event.room_id), self.discord.webhook_name
|
||||||
|
)
|
||||||
|
|
||||||
message_cache.pop(redacts)
|
except_deleted(self.discord.delete_webhook)(message_id, webhook)
|
||||||
|
|
||||||
|
with Cache.lock:
|
||||||
|
del Cache.cache["m_messages"][event.redacts]
|
||||||
|
|
||||||
|
def get_members(self, room_id: str) -> Dict[str, matrix.User]:
|
||||||
|
with Cache.lock:
|
||||||
|
cached = Cache.cache["m_members"].get(room_id)
|
||||||
|
|
||||||
|
if cached:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
resp = self.send("GET", f"/rooms/{room_id}/joined_members")
|
||||||
|
|
||||||
|
joined = resp["joined"]
|
||||||
|
|
||||||
|
for k, v in joined.items():
|
||||||
|
joined[k] = dict_cls(v, matrix.User)
|
||||||
|
|
||||||
|
with Cache.lock:
|
||||||
|
Cache.cache["m_members"][room_id] = joined
|
||||||
|
|
||||||
|
return joined
|
||||||
|
|
||||||
def create_room(self, channel: discord.Channel, sender: str) -> None:
|
def create_room(self, channel: discord.Channel, sender: str) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -164,7 +200,11 @@ class MatrixClient(AppService):
|
||||||
self.db.add_room(resp["room_id"], channel.id)
|
self.db.add_room(resp["room_id"], channel.id)
|
||||||
|
|
||||||
def create_message_event(
|
def create_message_event(
|
||||||
self, message: str, emotes: dict, edit: str = "", reply: str = ""
|
self,
|
||||||
|
message: str,
|
||||||
|
emotes: dict,
|
||||||
|
edit: str = "",
|
||||||
|
reference: discord.MessageReference = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
content = {
|
content = {
|
||||||
"body": message,
|
"body": message,
|
||||||
|
@ -173,20 +213,39 @@ class MatrixClient(AppService):
|
||||||
"formatted_body": self.get_fmt(message, emotes),
|
"formatted_body": self.get_fmt(message, emotes),
|
||||||
}
|
}
|
||||||
|
|
||||||
event = message_cache.get(reply)
|
if reference:
|
||||||
|
# Reply to a Discord message.
|
||||||
|
with Cache.lock:
|
||||||
|
event_id = Cache.cache["d_messages"].get(reference.message_id)
|
||||||
|
|
||||||
if event:
|
# Reply to a Matrix message. (maybe)
|
||||||
content = {
|
if not event_id:
|
||||||
**content,
|
with Cache.lock:
|
||||||
"m.relates_to": {
|
event_id = [
|
||||||
"m.in_reply_to": {"event_id": event["event_id"]}
|
k
|
||||||
},
|
for k, v in Cache.cache["m_messages"].items()
|
||||||
"formatted_body": f"""<mx-reply><blockquote>\
|
if v == reference.message_id
|
||||||
<a href='https://matrix.to/#/{event["room_id"]}/{event["event_id"]}'>\
|
]
|
||||||
In reply to</a><a href='https://matrix.to/#/{event["mxid"]}'>\
|
event_id = next(iter(event_id), "")
|
||||||
{event["mxid"]}</a><br>{event["body"]}</blockquote></mx-reply>\
|
|
||||||
|
if reference and event_id:
|
||||||
|
event = except_deleted(self.get_event)(
|
||||||
|
event_id,
|
||||||
|
self.get_room_id(self.discord.matrixify(reference.channel_id)),
|
||||||
|
)
|
||||||
|
if event:
|
||||||
|
content = {
|
||||||
|
**content,
|
||||||
|
"body": (
|
||||||
|
f"> <{event.sender}> {event.body}\n{content['body']}"
|
||||||
|
),
|
||||||
|
"m.relates_to": {"m.in_reply_to": {"event_id": event.id}},
|
||||||
|
"formatted_body": f"""<mx-reply><blockquote>\
|
||||||
|
<a href='https://matrix.to/#/{event.room_id}/{event.id}'>\
|
||||||
|
In reply to</a><a href='https://matrix.to/#/{event.sender}'>\
|
||||||
|
{event.sender}</a><br>{event.formatted_body}</blockquote></mx-reply>\
|
||||||
{content["formatted_body"]}""",
|
{content["formatted_body"]}""",
|
||||||
}
|
}
|
||||||
|
|
||||||
if edit:
|
if edit:
|
||||||
content = {
|
content = {
|
||||||
|
@ -227,63 +286,67 @@ In reply to</a><a href='https://matrix.to/#/{event["mxid"]}'>\
|
||||||
for emote in emotes
|
for emote in emotes
|
||||||
]
|
]
|
||||||
|
|
||||||
[thread.start() for thread in upload_threads]
|
# Acquire the lock before starting the threads to avoid resource
|
||||||
[thread.join() for thread in upload_threads]
|
# contention by tens of threads at once.
|
||||||
|
with Cache.lock:
|
||||||
|
for thread in upload_threads:
|
||||||
|
thread.start()
|
||||||
|
for thread in upload_threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
for emote in emotes:
|
with Cache.lock:
|
||||||
emote_ = self.emote_cache.get(emote)
|
for emote in emotes:
|
||||||
|
emote_ = Cache.cache["m_emotes"].get(emote)
|
||||||
|
|
||||||
if emote_:
|
if emote_:
|
||||||
emote = f":{emote}:"
|
emote = f":{emote}:"
|
||||||
message = message.replace(
|
message = message.replace(
|
||||||
emote,
|
emote,
|
||||||
f"""<img alt=\"{emote}\" title=\"{emote}\" \
|
f"""<img alt=\"{emote}\" title=\"{emote}\" \
|
||||||
height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
|
height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
|
||||||
)
|
)
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def process_message(self, channel_id: str, message: str) -> str:
|
def process_message(self, event: matrix.Event) -> str:
|
||||||
|
message = event.new_body if event.new_body else event.body
|
||||||
|
|
||||||
message = message[:2000] # Discord limit.
|
message = message[:2000] # Discord limit.
|
||||||
|
|
||||||
|
id_regex = f"[0-9]{{{discord.ID_LEN}}}"
|
||||||
|
|
||||||
emotes = re.findall(r":(\w*):", message)
|
emotes = re.findall(r":(\w*):", message)
|
||||||
mentions = re.findall(r"(@(\w*))", message)
|
mentions = re.findall(
|
||||||
|
f"@{self.format}{id_regex}:{re.escape(self.server_name)}",
|
||||||
|
event.formatted_body,
|
||||||
|
)
|
||||||
|
|
||||||
# Remove the puppet user's username from replies.
|
with Cache.lock:
|
||||||
message = re.sub(f"<@{self.format}.+?>", "", message)
|
for emote in set(emotes):
|
||||||
|
emote_ = Cache.cache["d_emotes"].get(emote)
|
||||||
added_emotes = []
|
|
||||||
for emote in emotes:
|
|
||||||
# Don't replace emote names with IDs multiple times.
|
|
||||||
if emote not in added_emotes:
|
|
||||||
added_emotes.append(emote)
|
|
||||||
emote_ = self.discord.emote_cache.get(emote)
|
|
||||||
if emote_:
|
if emote_:
|
||||||
message = message.replace(f":{emote}:", emote_)
|
message = message.replace(f":{emote}:", emote_)
|
||||||
|
|
||||||
# Don't unnecessarily fetch the channel.
|
for mention in set(mentions):
|
||||||
if mentions:
|
username = self.db.fetch_user(mention).get("username")
|
||||||
guild_id = self.discord.get_channel(channel_id).guild_id
|
if username:
|
||||||
|
match = re.search(id_regex, mention)
|
||||||
|
|
||||||
# TODO this can block for too long if a long list is to be fetched.
|
if match:
|
||||||
for mention in mentions:
|
# Replace the 'mention' so that the user is tagged
|
||||||
if not mention[1]:
|
# in the case of replies aswell.
|
||||||
continue
|
# '> <@_discord_1234:localhost> Message'
|
||||||
|
for replace in (mention, username):
|
||||||
try:
|
message = message.replace(
|
||||||
member = self.discord.query_member(guild_id, mention[1])
|
replace, f"<@{match.group()}>"
|
||||||
except (asyncio.TimeoutError, RuntimeError):
|
)
|
||||||
continue
|
|
||||||
|
|
||||||
if member:
|
|
||||||
message = message.replace(mention[0], member.mention)
|
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def upload_emote(self, emote_name: str, emote_id: str) -> None:
|
def upload_emote(self, emote_name: str, emote_id: str) -> None:
|
||||||
# There won't be a race condition here, since only a unique
|
# There won't be a race condition here, since only a unique
|
||||||
# set of emotes are uploaded at a time.
|
# set of emotes are uploaded at a time.
|
||||||
if emote_name in self.emote_cache:
|
if emote_name in Cache.cache["m_emotes"]:
|
||||||
return
|
return
|
||||||
|
|
||||||
emote_url = f"{discord.CDN_URL}/emojis/{emote_id}"
|
emote_url = f"{discord.CDN_URL}/emojis/{emote_id}"
|
||||||
|
@ -291,7 +354,8 @@ height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
|
||||||
# We don't want the message to be dropped entirely if an emote
|
# We don't want the message to be dropped entirely if an emote
|
||||||
# fails to upload for some reason.
|
# fails to upload for some reason.
|
||||||
try:
|
try:
|
||||||
self.emote_cache[emote_name] = self.upload(emote_url)
|
# TODO This is not thread safe, but we're protected by the GIL.
|
||||||
|
Cache.cache["m_emotes"][emote_name] = self.upload(emote_url)
|
||||||
except RequestError as e:
|
except RequestError as e:
|
||||||
self.logger.warning(f"Failed to upload emote {emote_id}: {e}")
|
self.logger.warning(f"Failed to upload emote {emote_id}: {e}")
|
||||||
|
|
||||||
|
@ -340,66 +404,19 @@ class DiscordClient(Gateway):
|
||||||
super().__init__(http, config["discord_token"])
|
super().__init__(http, config["discord_token"])
|
||||||
|
|
||||||
self.app = appservice
|
self.app = appservice
|
||||||
self.emote_cache: Dict[str, str] = {}
|
self.webhook_name = "matrix_bridge"
|
||||||
self.webhook_cache: Dict[str, discord.Webhook] = {}
|
|
||||||
|
|
||||||
async def sync(self) -> None:
|
for k in ("d_emotes", "d_messages", "d_webhooks"):
|
||||||
"""
|
Cache.cache[k] = {}
|
||||||
Periodically compare the usernames and avatar URLs with Discord
|
|
||||||
and update if they differ. Also synchronise emotes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# TODO use websocket events and requests.
|
|
||||||
|
|
||||||
def sync_emotes(guilds: set):
|
|
||||||
emotes = []
|
|
||||||
|
|
||||||
for guild in guilds:
|
|
||||||
[emotes.append(emote) for emote in (self.get_emotes(guild))]
|
|
||||||
|
|
||||||
self.emote_cache.clear() # Clear deleted/renamed emotes.
|
|
||||||
|
|
||||||
for emote in emotes:
|
|
||||||
self.emote_cache[f"{emote.name}"] = (
|
|
||||||
f"<{'a' if emote.animated else ''}:"
|
|
||||||
f"{emote.name}:{emote.id}>"
|
|
||||||
)
|
|
||||||
|
|
||||||
def sync_users(guilds: set):
|
|
||||||
for guild in guilds:
|
|
||||||
[
|
|
||||||
self.sync_profile(user, self.matrixify(user.id, user=True))
|
|
||||||
for user in self.get_members(guild)
|
|
||||||
]
|
|
||||||
|
|
||||||
while True:
|
|
||||||
guilds = set() # Avoid duplicates.
|
|
||||||
|
|
||||||
try:
|
|
||||||
for channel in self.app.db.list_channels():
|
|
||||||
guilds.add(self.get_channel(channel).guild_id)
|
|
||||||
|
|
||||||
sync_emotes(guilds)
|
|
||||||
sync_users(guilds)
|
|
||||||
# Don't let the background task die.
|
|
||||||
except RequestError:
|
|
||||||
self.logger.exception(
|
|
||||||
"Ignoring exception during background sync:"
|
|
||||||
)
|
|
||||||
|
|
||||||
await asyncio.sleep(120) # Check every 2 minutes.
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
|
||||||
asyncio.ensure_future(self.sync())
|
|
||||||
|
|
||||||
await self.run()
|
|
||||||
|
|
||||||
def to_return(self, message: discord.Message) -> bool:
|
def to_return(self, message: discord.Message) -> bool:
|
||||||
|
with Cache.lock:
|
||||||
|
hook_ids = [hook.id for hook in Cache.cache["d_webhooks"].values()]
|
||||||
|
|
||||||
return (
|
return (
|
||||||
message.channel_id not in self.app.db.list_channels()
|
message.channel_id not in self.app.db.list_channels()
|
||||||
or not message.author # Embeds can be weird sometimes.
|
or not message.author # Embeds can be weird sometimes.
|
||||||
or message.webhook_id
|
or message.webhook_id in hook_ids
|
||||||
in [hook.id for hook in self.webhook_cache.values()]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def matrixify(self, id: str, user: bool = False) -> str:
|
def matrixify(self, id: str, user: bool = False) -> str:
|
||||||
|
@ -408,11 +425,13 @@ class DiscordClient(Gateway):
|
||||||
f"{self.app.server_name}"
|
f"{self.app.server_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def sync_profile(self, user: discord.User, mxid: str) -> None:
|
def sync_profile(self, user: discord.User) -> None:
|
||||||
"""
|
"""
|
||||||
Sync the avatar and username for a puppeted user.
|
Sync the avatar and username for a puppeted user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
mxid = self.matrixify(user.id, user=True)
|
||||||
|
|
||||||
profile = self.app.db.fetch_user(mxid)
|
profile = self.app.db.fetch_user(mxid)
|
||||||
|
|
||||||
# User doesn't exist.
|
# User doesn't exist.
|
||||||
|
@ -422,10 +441,10 @@ class DiscordClient(Gateway):
|
||||||
username = f"{user.username}#{user.discriminator}"
|
username = f"{user.username}#{user.discriminator}"
|
||||||
|
|
||||||
if user.avatar_url != profile["avatar_url"]:
|
if user.avatar_url != profile["avatar_url"]:
|
||||||
self.logger.info(f"Updating avatar for Discord user {user.id}")
|
self.logger.info(f"Updating avatar for Discord user '{user.id}'")
|
||||||
self.app.set_avatar(user.avatar_url, mxid)
|
self.app.set_avatar(user.avatar_url, mxid)
|
||||||
if username != profile["username"]:
|
if username != profile["username"]:
|
||||||
self.logger.info(f"Updating username for Discord user {user.id}")
|
self.logger.info(f"Updating username for Discord user '{user.id}'")
|
||||||
self.app.set_nick(username, mxid)
|
self.app.set_nick(username, mxid)
|
||||||
|
|
||||||
def wrap(self, message: discord.Message) -> Tuple[str, str]:
|
def wrap(self, message: discord.Message) -> Tuple[str, str]:
|
||||||
|
@ -457,62 +476,95 @@ class DiscordClient(Gateway):
|
||||||
self.app.set_avatar(message.author.avatar_url, mxid)
|
self.app.set_avatar(message.author.avatar_url, mxid)
|
||||||
|
|
||||||
if mxid not in self.app.get_members(room_id):
|
if mxid not in self.app.get_members(room_id):
|
||||||
self.logger.info(f"Inviting user {mxid} to room {room_id}.")
|
self.logger.info(f"Inviting user '{mxid}' to room '{room_id}'.")
|
||||||
|
|
||||||
self.app.send_invite(room_id, mxid)
|
self.app.send_invite(room_id, mxid)
|
||||||
self.app.join_room(room_id, mxid)
|
self.app.join_room(room_id, mxid)
|
||||||
|
|
||||||
if message.webhook_id:
|
if message.webhook_id:
|
||||||
# Sync webhooks here as they can't be accessed like guild members.
|
# Sync webhooks here as they can't be accessed like guild members.
|
||||||
self.sync_profile(message.author, mxid)
|
self.sync_profile(message.author)
|
||||||
|
|
||||||
return mxid, room_id
|
return mxid, room_id
|
||||||
|
|
||||||
|
def cache_emotes(self, emotes: List[discord.Emote]):
|
||||||
|
# TODO maybe "namespace" emotes by guild in the cache ?
|
||||||
|
with Cache.lock:
|
||||||
|
for emote in emotes:
|
||||||
|
Cache.cache["d_emotes"][emote.name] = (
|
||||||
|
f"<{'a' if emote.animated else ''}:"
|
||||||
|
f"{emote.name}:{emote.id}>"
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_guild_create(self, guild: discord.Guild) -> None:
|
||||||
|
for member in guild.members:
|
||||||
|
self.sync_profile(member)
|
||||||
|
|
||||||
|
self.cache_emotes(guild.emojis)
|
||||||
|
|
||||||
|
def on_guild_emojis_update(
|
||||||
|
self, update: discord.GuildEmojisUpdate
|
||||||
|
) -> None:
|
||||||
|
self.cache_emotes(update.emojis)
|
||||||
|
|
||||||
|
def on_guild_member_update(
|
||||||
|
self, update: discord.GuildMemberUpdate
|
||||||
|
) -> None:
|
||||||
|
self.sync_profile(update.user)
|
||||||
|
|
||||||
def on_message_create(self, message: discord.Message) -> None:
|
def on_message_create(self, message: discord.Message) -> None:
|
||||||
if self.to_return(message):
|
if self.to_return(message):
|
||||||
return
|
return
|
||||||
|
|
||||||
mxid, room_id = self.wrap(message)
|
mxid, room_id = self.wrap(message)
|
||||||
|
|
||||||
content, emotes = self.process_message(message)
|
content_, emotes = self.process_message(message)
|
||||||
|
|
||||||
content = self.app.create_message_event(
|
content = self.app.create_message_event(
|
||||||
content, emotes, reply=message.reference
|
content_, emotes, reference=message.reference
|
||||||
)
|
)
|
||||||
|
|
||||||
message_cache[message.id] = {
|
with Cache.lock:
|
||||||
"body": content["body"],
|
Cache.cache["d_messages"][message.id] = self.app.send_message(
|
||||||
"event_id": self.app.send_message(room_id, content, mxid),
|
room_id, content, mxid
|
||||||
"mxid": mxid,
|
)
|
||||||
"room_id": room_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
def on_message_delete(self, message: discord.DeletedMessage) -> None:
|
def on_message_delete(self, message: discord.Message) -> None:
|
||||||
event = message_cache.get(message.id)
|
with Cache.lock:
|
||||||
|
event_id = Cache.cache["d_messages"].get(message.id)
|
||||||
|
|
||||||
if not event:
|
if not event_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.app.redact(event["event_id"], event["room_id"], event["mxid"])
|
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
|
||||||
|
event = except_deleted(self.app.get_event)(event_id, room_id)
|
||||||
|
|
||||||
message_cache.pop(message.id)
|
if event:
|
||||||
|
self.app.redact(event.id, event.room_id, event.sender)
|
||||||
|
|
||||||
|
with Cache.lock:
|
||||||
|
del Cache.cache["d_messages"][message.id]
|
||||||
|
|
||||||
def on_message_update(self, message: discord.Message) -> None:
|
def on_message_update(self, message: discord.Message) -> None:
|
||||||
if self.to_return(message):
|
if self.to_return(message):
|
||||||
return
|
return
|
||||||
|
|
||||||
event = message_cache.get(message.id)
|
with Cache.lock:
|
||||||
|
event_id = Cache.cache["d_messages"].get(message.id)
|
||||||
|
|
||||||
if not event:
|
if not event_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
content, emotes = self.process_message(message)
|
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
|
||||||
|
mxid = self.matrixify(message.author.id, user=True)
|
||||||
|
|
||||||
|
content_, emotes = self.process_message(message)
|
||||||
|
|
||||||
content = self.app.create_message_event(
|
content = self.app.create_message_event(
|
||||||
content, emotes, edit=event["event_id"]
|
content_, emotes, edit=event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
self.app.send_message(event["room_id"], content, event["mxid"])
|
self.app.send_message(room_id, content, mxid)
|
||||||
|
|
||||||
def on_typing_start(self, typing: discord.Typing) -> None:
|
def on_typing_start(self, typing: discord.Typing) -> None:
|
||||||
if typing.channel_id not in self.app.db.list_channels():
|
if typing.channel_id not in self.app.db.list_channels():
|
||||||
|
@ -533,7 +585,8 @@ class DiscordClient(Gateway):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Check the cache first.
|
# Check the cache first.
|
||||||
webhook = self.webhook_cache.get(channel_id)
|
with Cache.lock:
|
||||||
|
webhook = Cache.cache["d_webhooks"].get(channel_id)
|
||||||
|
|
||||||
if webhook:
|
if webhook:
|
||||||
return webhook
|
return webhook
|
||||||
|
@ -551,24 +604,26 @@ class DiscordClient(Gateway):
|
||||||
if not webhook:
|
if not webhook:
|
||||||
webhook = self.create_webhook(channel_id, name)
|
webhook = self.create_webhook(channel_id, name)
|
||||||
|
|
||||||
self.webhook_cache[channel_id] = webhook
|
with Cache.lock:
|
||||||
|
Cache.cache["d_webhooks"][channel_id] = webhook
|
||||||
|
|
||||||
return webhook
|
return webhook
|
||||||
|
|
||||||
def process_message(self, message: discord.Message) -> Tuple[str, str]:
|
def process_message(self, message: discord.Message) -> Tuple[str, Dict]:
|
||||||
content = message.content
|
content = message.content
|
||||||
emotes = {}
|
emotes = {}
|
||||||
regex = r"<a?:(\w+):(\d+)>"
|
regex = r"<a?:(\w+):(\d+)>"
|
||||||
|
|
||||||
# Mentions can either be in the form of `<@1234>` or `<@!1234>`.
|
# Mentions can either be in the form of `<@1234>` or `<@!1234>`.
|
||||||
for char in ("", "!"):
|
for member in message.mentions:
|
||||||
for member in message.mentions:
|
for char in ("", "!"):
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
f"<@{char}{member.id}>", f"@{member.username}"
|
f"<@{char}{member.id}>", f"@{member.username}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# `except_deleted` for invalid channels.
|
# `except_deleted` for invalid channels.
|
||||||
for channel in re.findall(r"<#([0-9]+)>", content):
|
# TODO can this block for too long ?
|
||||||
|
for channel in re.findall(r"<#([0-9]{{{discord.ID_LEN}}})>", content):
|
||||||
channel_ = except_deleted(self.get_channel)(channel)
|
channel_ = except_deleted(self.get_channel)(channel)
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
f"<#{channel}>",
|
f"<#{channel}>",
|
||||||
|
@ -613,6 +668,16 @@ def config_gen(basedir: str, config_file: str) -> dict:
|
||||||
return json.loads(f.read())
|
return json.loads(f.read())
|
||||||
|
|
||||||
|
|
||||||
|
def excepthook(exc_type, exc_value, exc_traceback):
|
||||||
|
if issubclass(exc_type, KeyboardInterrupt):
|
||||||
|
sys.__excepthook__(exc_type, exc_value, exc_traceback)
|
||||||
|
return
|
||||||
|
|
||||||
|
logging.critical(
|
||||||
|
"Unknown exception:", exc_info=(exc_type, exc_value, exc_traceback)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
try:
|
try:
|
||||||
basedir = sys.argv[1]
|
basedir = sys.argv[1]
|
||||||
|
@ -634,9 +699,9 @@ def main() -> None:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
http = urllib3.PoolManager(maxsize=10)
|
sys.excepthook = excepthook
|
||||||
|
|
||||||
app = MatrixClient(config, http)
|
app = MatrixClient(config, urllib3.PoolManager(maxsize=10))
|
||||||
|
|
||||||
# Start the bottle app in a separate thread.
|
# Start the bottle app in a separate thread.
|
||||||
app_thread = threading.Thread(
|
app_thread = threading.Thread(
|
||||||
|
@ -645,7 +710,7 @@ def main() -> None:
|
||||||
app_thread.start()
|
app_thread.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
asyncio.run(app.discord.start())
|
asyncio.run(app.discord.run())
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
|
|
|
@ -2,20 +2,21 @@ from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class User(object):
|
class User:
|
||||||
avatar_url: str = ""
|
avatar_url: str = ""
|
||||||
displayname: str = ""
|
display_name: str = ""
|
||||||
|
|
||||||
|
|
||||||
class Event(object):
|
class Event:
|
||||||
def __init__(self, event: dict):
|
def __init__(self, event: dict):
|
||||||
content = event["content"]
|
content = event.get("content", {})
|
||||||
|
|
||||||
self.attachment = content.get("url")
|
self.attachment = content.get("url")
|
||||||
self.author = event["author"]
|
|
||||||
self.body = content.get("body", "").strip()
|
self.body = content.get("body", "").strip()
|
||||||
self.event_id = event["event_id"]
|
self.formatted_body = content.get("formatted_body", "")
|
||||||
|
self.id = event["event_id"]
|
||||||
self.is_direct = content.get("is_direct", False)
|
self.is_direct = content.get("is_direct", False)
|
||||||
|
self.redacts = event.get("redacts", "")
|
||||||
self.room_id = event["room_id"]
|
self.room_id = event["room_id"]
|
||||||
self.sender = event["sender"]
|
self.sender = event["sender"]
|
||||||
self.state_key = event.get("state_key", "")
|
self.state_key = event.get("state_key", "")
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
from dataclasses import fields
|
from dataclasses import fields
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
@ -8,13 +7,13 @@ import urllib3
|
||||||
from errors import RequestError
|
from errors import RequestError
|
||||||
|
|
||||||
|
|
||||||
def dict_cls(dict_var: dict, cls: Any) -> Any:
|
def dict_cls(d: dict, cls: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
Create a dataclass from a dictionary.
|
Create a dataclass from a dictionary.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
field_names = set(f.name for f in fields(cls))
|
field_names = set(f.name for f in fields(cls))
|
||||||
filtered_dict = {k: v for k, v in dict_var.items() if k in field_names}
|
filtered_dict = {k: v for k, v in d.items() if k in field_names}
|
||||||
|
|
||||||
return cls(**filtered_dict)
|
return cls(**filtered_dict)
|
||||||
|
|
||||||
|
@ -34,22 +33,6 @@ def log_except(fn):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def wrap_async(fn):
|
|
||||||
"""
|
|
||||||
Call an asynchronous function from a synchronous one.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def wrapper(self, *args, **kwargs):
|
|
||||||
if not self.loop:
|
|
||||||
raise RuntimeError("loop is None.")
|
|
||||||
|
|
||||||
return asyncio.run_coroutine_threadsafe(
|
|
||||||
fn(self, *args, **kwargs), loop=self.loop
|
|
||||||
).result()
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def request(fn):
|
def request(fn):
|
||||||
"""
|
"""
|
||||||
Either return json data or raise a `RequestError` if the request was
|
Either return json data or raise a `RequestError` if the request was
|
||||||
|
@ -75,8 +58,7 @@ def request(fn):
|
||||||
|
|
||||||
def except_deleted(fn):
|
def except_deleted(fn):
|
||||||
"""
|
"""
|
||||||
Ignore the `RequestError` on 404s, the message might have been
|
Ignore the `RequestError` on 404s, the content might have been removed.
|
||||||
deleted by someone else already.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
bottle==0.12.19
|
bottle
|
||||||
urllib3==1.26.3
|
urllib3
|
||||||
websockets==8.1
|
websockets
|
||||||
|
|
Loading…
Reference in a new issue