diff --git a/src/aioappsrv/appservice.py b/src/aioappsrv/appservice.py index 99265a1..8dec182 100644 --- a/src/aioappsrv/appservice.py +++ b/src/aioappsrv/appservice.py @@ -13,191 +13,191 @@ from misc import log_except, request class AppService(bottle.Bottle): - def __init__(self, config: dict, http: urllib3.PoolManager) -> None: - super(AppService, self).__init__() + def __init__(self, config: dict, http: urllib3.PoolManager) -> None: + super(AppService, self).__init__() - self.as_token = config["as_token"] - self.hs_token = config["hs_token"] - self.base_url = config["homeserver"] - self.server_name = config["server_name"] - self.user_id = f"@{config['user_id']}:{self.server_name}" - self.http = http - self.logger = logging.getLogger("appservice") + self.as_token = config["as_token"] + self.hs_token = config["hs_token"] + self.base_url = config["homeserver"] + self.server_name = config["server_name"] + self.user_id = f"@{config['user_id']}:{self.server_name}" + self.http = http + self.logger = logging.getLogger("appservice") - # Map events to functions. - self.mapping = { - "m.room.member": "on_member", - "m.room.message": "on_message", - "m.room.redaction": "on_redaction", - } + # Map events to functions. + self.mapping = { + "m.room.member": "on_member", + "m.room.message": "on_message", + "m.room.redaction": "on_redaction", + } - # Add route for bottle. - self.route( - "/transactions/", - callback=self.receive_event, - method="PUT", - ) + # Add route for bottle. + self.route( + "/transactions/", + callback=self.receive_event, + method="PUT", + ) - Cache.cache["m_rooms"] = {} + Cache.cache["m_rooms"] = {} - def handle_event(self, event: dict) -> None: - event_type = event.get("type") + def handle_event(self, event: dict) -> None: + event_type = event.get("type") - if event_type in ( - "m.room.member", - "m.room.message", - "m.room.redaction", - ): - obj = matrix.Event(event) - else: - self.logger.info(f"Unknown event type: {event_type}") - return + if event_type in ( + "m.room.member", + "m.room.message", + "m.room.redaction", + ): + obj = matrix.Event(event) + else: + self.logger.info(f"Unknown event type: {event_type}") + return - func = getattr(self, self.mapping[event_type], None) + func = getattr(self, self.mapping[event_type], None) - if not func: - self.logger.warning( - f"Function '{func}' not defined, ignoring event." - ) - return + if not func: + self.logger.warning( + f"Function '{func}' not defined, ignoring event." + ) + return - # We don't catch exceptions here as the homeserver will re-send us - # the event in case of a failure. - func(obj) + # We don't catch exceptions here as the homeserver will re-send us + # the event in case of a failure. + func(obj) - @log_except - def receive_event(self, transaction: str) -> dict: - """ - Verify the homeserver's token and handle events. - """ + @log_except + def receive_event(self, transaction: str) -> dict: + """ + Verify the homeserver's token and handle events. + """ - hs_token = bottle.request.query.getone("access_token") + hs_token = bottle.request.query.getone("access_token") - if not hs_token: - bottle.response.status = 401 - return {"errcode": "APPSERVICE_UNAUTHORIZED"} + if not hs_token: + bottle.response.status = 401 + return {"errcode": "APPSERVICE_UNAUTHORIZED"} - if hs_token != self.hs_token: - bottle.response.status = 403 - return {"errcode": "APPSERVICE_FORBIDDEN"} + if hs_token != self.hs_token: + bottle.response.status = 403 + return {"errcode": "APPSERVICE_FORBIDDEN"} - events = bottle.request.json.get("events") + events = bottle.request.json.get("events") - for event in events: - self.handle_event(event) + for event in events: + self.handle_event(event) - return {} + return {} - def mxc_url(self, mxc: str) -> str: - try: - homeserver, media_id = mxc.replace("mxc://", "").split("/") - except ValueError: - return "" + def mxc_url(self, mxc: str) -> str: + try: + homeserver, media_id = mxc.replace("mxc://", "").split("/") + except ValueError: + return "" - return ( - f"https://{self.server_name}/_matrix/media/r0/download/" - f"{homeserver}/{media_id}" - ) + return ( + f"https://{self.server_name}/_matrix/media/r0/download/" + f"{homeserver}/{media_id}" + ) - def join_room(self, room_id: str, mxid: str = "") -> None: - self.send( - "POST", - f"/join/{room_id}", - params={"user_id": mxid} if mxid else {}, - ) + def join_room(self, room_id: str, mxid: str = "") -> None: + self.send( + "POST", + f"/join/{room_id}", + params={"user_id": mxid} if mxid else {}, + ) - def redact(self, event_id: str, room_id: str, mxid: str = "") -> None: - self.send( - "PUT", - f"/rooms/{room_id}/redact/{event_id}/{uuid.uuid4()}", - params={"user_id": mxid} if mxid else {}, - ) + def redact(self, event_id: str, room_id: str, mxid: str = "") -> None: + self.send( + "PUT", + f"/rooms/{room_id}/redact/{event_id}/{uuid.uuid4()}", + params={"user_id": mxid} if mxid else {}, + ) - def get_room_id(self, alias: str) -> str: - with Cache.lock: - room = Cache.cache["m_rooms"].get(alias) - if room: - return room + def get_room_id(self, alias: str) -> str: + with Cache.lock: + room = Cache.cache["m_rooms"].get(alias) + if room: + return room - resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}") + resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}") - room_id = resp["room_id"] + room_id = resp["room_id"] - with Cache.lock: - Cache.cache["m_rooms"][alias] = room_id + with Cache.lock: + Cache.cache["m_rooms"][alias] = room_id - return room_id + return room_id - def get_event(self, event_id: str, room_id: str) -> matrix.Event: - resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}") + def get_event(self, event_id: str, room_id: str) -> matrix.Event: + resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}") - return matrix.Event(resp) + return matrix.Event(resp) - def upload(self, url: str) -> str: - """ - Upload a file to the homeserver and get the MXC url. - """ + def upload(self, url: str) -> str: + """ + Upload a file to the homeserver and get the MXC url. + """ - resp = self.http.request("GET", url) + resp = self.http.request("GET", url) - resp = self.send( - "POST", - content=resp.data, - content_type=resp.headers.get("Content-Type"), - params={"filename": f"{uuid.uuid4()}"}, - endpoint="/_matrix/media/r0/upload", - ) + resp = self.send( + "POST", + content=resp.data, + content_type=resp.headers.get("Content-Type"), + params={"filename": f"{uuid.uuid4()}"}, + endpoint="/_matrix/media/r0/upload", + ) - return resp["content_uri"] + return resp["content_uri"] - def send_message( - self, - room_id: str, - content: dict, - mxid: str = "", - ) -> str: - resp = self.send( - "PUT", - f"/rooms/{room_id}/send/m.room.message/{uuid.uuid4()}", - content, - {"user_id": mxid} if mxid else {}, - ) + def send_message( + self, + room_id: str, + content: dict, + mxid: str = "", + ) -> str: + resp = self.send( + "PUT", + f"/rooms/{room_id}/send/m.room.message/{uuid.uuid4()}", + content, + {"user_id": mxid} if mxid else {}, + ) - return resp["event_id"] + return resp["event_id"] - def send_typing( - self, room_id: str, mxid: str = "", timeout: int = 8000 - ) -> None: - self.send( - "PUT", - f"/rooms/{room_id}/typing/{mxid}", - {"typing": True, "timeout": timeout}, - {"user_id": mxid} if mxid else {}, - ) + def send_typing( + self, room_id: str, mxid: str = "", timeout: int = 8000 + ) -> None: + self.send( + "PUT", + f"/rooms/{room_id}/typing/{mxid}", + {"typing": True, "timeout": timeout}, + {"user_id": mxid} if mxid else {}, + ) - def send_invite(self, room_id: str, mxid: str) -> None: - self.send("POST", f"/rooms/{room_id}/invite", {"user_id": mxid}) + def send_invite(self, room_id: str, mxid: str) -> None: + self.send("POST", f"/rooms/{room_id}/invite", {"user_id": mxid}) - @request - def send( - self, - method: str, - path: str = "", - content: Union[bytes, dict] = {}, - params: dict = {}, - content_type: str = "application/json", - endpoint: str = "/_matrix/client/r0", - ) -> dict: - headers = { - "Authorization": f"Bearer {self.as_token}", - "Content-Type": content_type, - } - payload = json.dumps(content) if isinstance(content, dict) else content - endpoint = ( - f"{self.base_url}{endpoint}{path}?" - f"{urllib.parse.urlencode(params)}" - ) + @request + def send( + self, + method: str, + path: str = "", + content: Union[bytes, dict] = {}, + params: dict = {}, + content_type: str = "application/json", + endpoint: str = "/_matrix/client/r0", + ) -> dict: + headers = { + "Authorization": f"Bearer {self.as_token}", + "Content-Type": content_type, + } + payload = json.dumps(content) if isinstance(content, dict) else content + endpoint = ( + f"{self.base_url}{endpoint}{path}?" + f"{urllib.parse.urlencode(params)}" + ) - return self.http.request( - method, endpoint, body=payload, headers=headers - ) + return self.http.request( + method, endpoint, body=payload, headers=headers + ) diff --git a/src/aioappsrv/cache.py b/src/aioappsrv/cache.py index 8d14443..536887b 100644 --- a/src/aioappsrv/cache.py +++ b/src/aioappsrv/cache.py @@ -2,5 +2,5 @@ import threading class Cache: - cache = {} - lock = threading.Lock() + cache = {} + lock = threading.Lock() diff --git a/src/aioappsrv/db.py b/src/aioappsrv/db.py index 7e3c4d9..a699f61 100644 --- a/src/aioappsrv/db.py +++ b/src/aioappsrv/db.py @@ -5,116 +5,116 @@ from typing import List class DataBase: - def __init__(self, db_file) -> None: - self.create(db_file) + def __init__(self, db_file) -> None: + self.create(db_file) - # The database is accessed via multiple threads. - self.lock = threading.Lock() + # The database is accessed via multiple threads. + self.lock = threading.Lock() - def create(self, db_file) -> None: - """ - Create a database with the relevant tables if it doesn't already exist. - """ + def create(self, db_file) -> None: + """ + Create a database with the relevant tables if it doesn't already exist. + """ - exists = os.path.exists(db_file) + exists = os.path.exists(db_file) - self.conn = sqlite3.connect(db_file, check_same_thread=False) - self.conn.row_factory = self.dict_factory + self.conn = sqlite3.connect(db_file, check_same_thread=False) + self.conn.row_factory = self.dict_factory - self.cur = self.conn.cursor() + self.cur = self.conn.cursor() - if exists: - return + if exists: + return - self.cur.execute( - "CREATE TABLE bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);" - ) + self.cur.execute( + "CREATE TABLE bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);" + ) - self.cur.execute( - "CREATE TABLE users(mxid TEXT PRIMARY KEY, " - "avatar_url TEXT, username TEXT);" - ) + self.cur.execute( + "CREATE TABLE users(mxid TEXT PRIMARY KEY, " + "avatar_url TEXT, username TEXT);" + ) - self.conn.commit() + self.conn.commit() - def dict_factory(self, cursor, row): - """ - https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory - """ + def dict_factory(self, cursor, row): + """ + https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory + """ - d = {} - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d + d = {} + for idx, col in enumerate(cursor.description): + d[col[0]] = row[idx] + return d - def add_room(self, room_id: str, channel_id: str) -> None: - """ - Add a bridged room to the database. - """ + def add_room(self, room_id: str, channel_id: str) -> None: + """ + Add a bridged room to the database. + """ - with self.lock: - self.cur.execute( - "INSERT INTO bridge (room_id, channel_id) VALUES (?, ?)", - [room_id, channel_id], - ) - self.conn.commit() + with self.lock: + self.cur.execute( + "INSERT INTO bridge (room_id, channel_id) VALUES (?, ?)", + [room_id, channel_id], + ) + self.conn.commit() - def add_user(self, mxid: str) -> None: - with self.lock: - self.cur.execute("INSERT INTO users (mxid) VALUES (?)", [mxid]) - self.conn.commit() + def add_user(self, mxid: str) -> None: + with self.lock: + self.cur.execute("INSERT INTO users (mxid) VALUES (?)", [mxid]) + self.conn.commit() - def add_avatar(self, avatar_url: str, mxid: str) -> None: - with self.lock: - self.cur.execute( - "UPDATE users SET avatar_url = (?) WHERE mxid = (?)", - [avatar_url, mxid], - ) - self.conn.commit() + def add_avatar(self, avatar_url: str, mxid: str) -> None: + with self.lock: + self.cur.execute( + "UPDATE users SET avatar_url = (?) WHERE mxid = (?)", + [avatar_url, mxid], + ) + self.conn.commit() - def add_username(self, username: str, mxid: str) -> None: - with self.lock: - self.cur.execute( - "UPDATE users SET username = (?) WHERE mxid = (?)", - [username, mxid], - ) - self.conn.commit() + def add_username(self, username: str, mxid: str) -> None: + with self.lock: + self.cur.execute( + "UPDATE users SET username = (?) WHERE mxid = (?)", + [username, mxid], + ) + self.conn.commit() - def get_channel(self, room_id: str) -> str: - """ - Get the corresponding channel ID for a given room ID. - """ + def get_channel(self, room_id: str) -> str: + """ + Get the corresponding channel ID for a given room ID. + """ - with self.lock: - self.cur.execute( - "SELECT channel_id FROM bridge WHERE room_id = ?", [room_id] - ) + with self.lock: + self.cur.execute( + "SELECT channel_id FROM bridge WHERE room_id = ?", [room_id] + ) - room = self.cur.fetchone() + room = self.cur.fetchone() - # Return an empty string if the channel is not bridged. - return "" if not room else room["channel_id"] + # Return an empty string if the channel is not bridged. + return "" if not room else room["channel_id"] - def list_channels(self) -> List[str]: - """ - Get a list of all the bridged channels. - """ + def list_channels(self) -> List[str]: + """ + Get a list of all the bridged channels. + """ - with self.lock: - self.cur.execute("SELECT channel_id FROM bridge") + with self.lock: + self.cur.execute("SELECT channel_id FROM bridge") - channels = self.cur.fetchall() + channels = self.cur.fetchall() - return [channel["channel_id"] for channel in channels] + return [channel["channel_id"] for channel in channels] - def fetch_user(self, mxid: str) -> dict: - """ - Fetch the profile for a bridged user. - """ + def fetch_user(self, mxid: str) -> dict: + """ + Fetch the profile for a bridged user. + """ - with self.lock: - self.cur.execute("SELECT * FROM users where mxid = ?", [mxid]) + with self.lock: + self.cur.execute("SELECT * FROM users where mxid = ?", [mxid]) - user = self.cur.fetchone() + user = self.cur.fetchone() - return {} if not user else user + return {} if not user else user diff --git a/src/aioappsrv/discord.py b/src/aioappsrv/discord.py deleted file mode 100644 index 441559a..0000000 --- a/src/aioappsrv/discord.py +++ /dev/null @@ -1,218 +0,0 @@ -from dataclasses import dataclass - -from misc import dict_cls - -CDN_URL = "https://cdn.discordapp.com" -MESSAGE_LIMIT = 2000 - - -def bitmask(bit: int) -> int: - return 1 << bit - - -@dataclass -class Channel: - id: str - type: str - guild_id: str = "" - name: str = "" - topic: str = "" - - -@dataclass -class Emote: - animated: bool - id: str - name: str - - -@dataclass -class MessageReference: - message_id: str - channel_id: str - guild_id: str - - -@dataclass -class Sticker: - name: str - id: str - format_type: int - - -@dataclass -class Typing: - user_id: str - channel_id: str - - -@dataclass -class Webhook: - id: str - token: str - - -class User: - def __init__(self, user: dict) -> None: - self.discriminator = user["discriminator"] - self.id = user["id"] - self.mention = f"<@{self.id}>" - self.username = user["username"] - - avatar = user["avatar"] - - if not avatar: - # https://discord.com/developers/docs/reference#image-formatting - self.avatar_url = ( - f"{CDN_URL}/embed/avatars/{int(self.discriminator) % 5}.png" - ) - else: - ext = "gif" if avatar.startswith("a_") else "png" - self.avatar_url = f"{CDN_URL}/avatars/{self.id}/{avatar}.{ext}" - - -class Guild: - def __init__(self, guild: dict) -> None: - self.guild_id = guild["id"] - self.channels = [dict_cls(c, Channel) for c in guild["channels"]] - self.emojis = [dict_cls(e, Emote) for e in guild["emojis"]] - members = [member["user"] for member in guild["members"]] - self.members = [User(m) for m in members] - - -class GuildEmojisUpdate: - def __init__(self, update: dict) -> None: - self.guild_id = update["guild_id"] - self.emojis = [dict_cls(e, Emote) for e in update["emojis"]] - - -class GuildMembersChunk: - def __init__(self, chunk: dict) -> None: - self.chunk_index = chunk["chunk_index"] - self.chunk_count = chunk["chunk_count"] - self.guild_id = chunk["guild_id"] - self.members = [User(m) for m in chunk["members"]] - - -class GuildMemberUpdate: - def __init__(self, update: dict) -> None: - self.guild_id = update["guild_id"] - self.user = User(update["user"]) - - -class Message: - def __init__(self, message: dict) -> None: - self.attachments = message.get("attachments", []) - self.channel_id = message["channel_id"] - self.content = message.get("content", "") - self.id = message["id"] - self.guild_id = message.get( - "guild_id", "" - ) # Responses for sending webhook messages don't have guild_id - self.webhook_id = message.get("webhook_id", "") - self.application_id = message.get("application_id", "") - - self.mentions = [ - User(mention) for mention in message.get("mentions", []) - ] - - ref = message.get("referenced_message") - - self.referenced_message = Message(ref) if ref else None - - author = message.get("author") - - self.author = User(author) if author else None - - self.stickers = [ - dict_cls(sticker, Sticker) - for sticker in message.get("sticker_items", []) - ] - - -class ChannelType: - GUILD_TEXT = 0 - DM = 1 - GUILD_VOICE = 2 - GROUP_DM = 3 - GUILD_CATEGORY = 4 - GUILD_NEWS = 5 - GUILD_STORE = 6 - - -class InteractionResponseType: - PONG = 0 - ACKNOWLEDGE = 1 - CHANNEL_MESSAGE = 2 - CHANNEL_MESSAGE_WITH_SOURCE = 4 - ACKNOWLEDGE_WITH_SOURCE = 5 - - -class GatewayIntents: - GUILDS = bitmask(0) - GUILD_MEMBERS = bitmask(1) - GUILD_BANS = bitmask(2) - GUILD_EMOJIS = bitmask(3) - GUILD_INTEGRATIONS = bitmask(4) - GUILD_WEBHOOKS = bitmask(5) - GUILD_INVITES = bitmask(6) - GUILD_VOICE_STATES = bitmask(7) - GUILD_PRESENCES = bitmask(8) - GUILD_MESSAGES = bitmask(9) - GUILD_MESSAGE_REACTIONS = bitmask(10) - GUILD_MESSAGE_TYPING = bitmask(11) - DIRECT_MESSAGES = bitmask(12) - DIRECT_MESSAGE_REACTIONS = bitmask(13) - DIRECT_MESSAGE_TYPING = bitmask(14) - - -class GatewayOpCodes: - DISPATCH = 0 - HEARTBEAT = 1 - IDENTIFY = 2 - PRESENCE_UPDATE = 3 - VOICE_STATE_UPDATE = 4 - RESUME = 6 - RECONNECT = 7 - REQUEST_GUILD_MEMBERS = 8 - INVALID_SESSION = 9 - HELLO = 10 - HEARTBEAT_ACK = 11 - - -class Payloads: - def __init__(self, token: str) -> None: - self.seq = self.session = None - self.token = token - - def HEARTBEAT(self) -> dict: - return {"op": GatewayOpCodes.HEARTBEAT, "d": self.seq} - - def IDENTIFY(self) -> dict: - return { - "op": GatewayOpCodes.IDENTIFY, - "d": { - "token": self.token, - "intents": GatewayIntents.GUILDS - | GatewayIntents.GUILD_EMOJIS - | GatewayIntents.GUILD_MEMBERS - | GatewayIntents.GUILD_MESSAGES - | GatewayIntents.GUILD_MESSAGE_TYPING - | GatewayIntents.GUILD_PRESENCES, - "properties": { - "$os": "discord", - "$browser": "Discord Client", - "$device": "discord", - }, - }, - } - - def RESUME(self) -> dict: - return { - "op": GatewayOpCodes.RESUME, - "d": { - "token": self.token, - "session_id": self.session, - "seq": self.seq, - }, - } diff --git a/src/aioappsrv/errors.py b/src/aioappsrv/errors.py index 4200f50..502b06a 100644 --- a/src/aioappsrv/errors.py +++ b/src/aioappsrv/errors.py @@ -1,5 +1,5 @@ class RequestError(Exception): - def __init__(self, status: int, *args): - super().__init__(*args) + def __init__(self, status: int, *args): + super().__init__(*args) - self.status = status + self.status = status diff --git a/src/aioappsrv/gateway.py b/src/aioappsrv/gateway.py index 8d53e36..16686d4 100644 --- a/src/aioappsrv/gateway.py +++ b/src/aioappsrv/gateway.py @@ -12,249 +12,249 @@ from misc import dict_cls, log_except, request class Gateway: - def __init__(self, http: urllib3.PoolManager, token: str): - self.http = http - self.token = token - self.logger = logging.getLogger("discord") - self.Payloads = discord.Payloads(self.token) - self.websocket = None + def __init__(self, http: urllib3.PoolManager, token: str): + self.http = http + self.token = token + self.logger = logging.getLogger("discord") + self.Payloads = discord.Payloads(self.token) + self.websocket = None - @log_except - async def run(self) -> None: - self.heartbeat_task: asyncio.Future = None - self.resume = False + @log_except + async def run(self) -> None: + self.heartbeat_task: asyncio.Future = None + self.resume = False - gateway_url = self.get_gateway_url() + gateway_url = self.get_gateway_url() - while True: - try: - await self.gateway_handler(gateway_url) - except ( - websockets.ConnectionClosedError, - websockets.InvalidMessage, - ): - self.logger.exception("Connection lost, reconnecting.") + while True: + try: + await self.gateway_handler(gateway_url) + except ( + websockets.ConnectionClosedError, + websockets.InvalidMessage, + ): + self.logger.exception("Connection lost, reconnecting.") - # Stop sending heartbeats until we reconnect. - if self.heartbeat_task and not self.heartbeat_task.cancelled(): - self.heartbeat_task.cancel() + # Stop sending heartbeats until we reconnect. + if self.heartbeat_task and not self.heartbeat_task.cancelled(): + self.heartbeat_task.cancel() - def get_gateway_url(self) -> str: - resp = self.send("GET", "/gateway") + def get_gateway_url(self) -> str: + resp = self.send("GET", "/gateway") - return resp["url"] + return resp["url"] - async def heartbeat_handler(self, interval_ms: int) -> None: - while True: - await asyncio.sleep(interval_ms / 1000) - await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT())) + async def heartbeat_handler(self, interval_ms: int) -> None: + while True: + await asyncio.sleep(interval_ms / 1000) + await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT())) - async def handle_resp(self, data: dict) -> None: - data_dict = data["d"] + async def handle_resp(self, data: dict) -> None: + data_dict = data["d"] - opcode = data["op"] + opcode = data["op"] - seq = data["s"] + seq = data["s"] - if seq: - self.Payloads.seq = seq + if seq: + self.Payloads.seq = seq - if opcode == discord.GatewayOpCodes.DISPATCH: - otype = data["t"] + if opcode == discord.GatewayOpCodes.DISPATCH: + otype = data["t"] - if otype == "READY": - self.Payloads.session = data_dict["session_id"] + if otype == "READY": + self.Payloads.session = data_dict["session_id"] - self.logger.info("READY") - else: - self.handle_otype(data_dict, otype) - elif opcode == discord.GatewayOpCodes.HELLO: - heartbeat_interval = data_dict.get("heartbeat_interval") + self.logger.info("READY") + else: + self.handle_otype(data_dict, otype) + elif opcode == discord.GatewayOpCodes.HELLO: + heartbeat_interval = data_dict.get("heartbeat_interval") - self.logger.info(f"Heartbeat Interval: {heartbeat_interval}") + self.logger.info(f"Heartbeat Interval: {heartbeat_interval}") - # Send periodic hearbeats to gateway. - self.heartbeat_task = asyncio.ensure_future( - self.heartbeat_handler(heartbeat_interval) - ) + # Send periodic hearbeats to gateway. + self.heartbeat_task = asyncio.ensure_future( + self.heartbeat_handler(heartbeat_interval) + ) - await self.websocket.send( - json.dumps( - self.Payloads.RESUME() - if self.resume - else self.Payloads.IDENTIFY() - ) - ) - elif opcode == discord.GatewayOpCodes.RECONNECT: - self.logger.info("Received RECONNECT.") + await self.websocket.send( + json.dumps( + self.Payloads.RESUME() + if self.resume + else self.Payloads.IDENTIFY() + ) + ) + elif opcode == discord.GatewayOpCodes.RECONNECT: + self.logger.info("Received RECONNECT.") - self.resume = True - await self.websocket.close() - elif opcode == discord.GatewayOpCodes.INVALID_SESSION: - self.logger.info("Received INVALID_SESSION.") + self.resume = True + await self.websocket.close() + elif opcode == discord.GatewayOpCodes.INVALID_SESSION: + self.logger.info("Received INVALID_SESSION.") - self.resume = False - await self.websocket.close() - elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK: - # NOP - pass - else: - self.logger.info( - "Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}" - ) + self.resume = False + await self.websocket.close() + elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK: + # NOP + pass + else: + self.logger.info( + "Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}" + ) - def handle_otype(self, data: dict, otype: str) -> None: - if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"): - obj = discord.Message(data) - elif otype == "TYPING_START": - obj = dict_cls(data, discord.Typing) - elif otype == "GUILD_CREATE": - obj = discord.Guild(data) - elif otype == "GUILD_MEMBER_UPDATE": - obj = discord.GuildMemberUpdate(data) - elif otype == "GUILD_EMOJIS_UPDATE": - obj = discord.GuildEmojisUpdate(data) - else: - return + def handle_otype(self, data: dict, otype: str) -> None: + if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"): + obj = discord.Message(data) + elif otype == "TYPING_START": + obj = dict_cls(data, discord.Typing) + elif otype == "GUILD_CREATE": + obj = discord.Guild(data) + elif otype == "GUILD_MEMBER_UPDATE": + obj = discord.GuildMemberUpdate(data) + elif otype == "GUILD_EMOJIS_UPDATE": + obj = discord.GuildEmojisUpdate(data) + else: + return - func = getattr(self, f"on_{otype.lower()}", None) + func = getattr(self, f"on_{otype.lower()}", None) - if not func: - self.logger.warning( - f"Function '{func}' not defined, ignoring message." - ) - return + if not func: + self.logger.warning( + f"Function '{func}' not defined, ignoring message." + ) + return - try: - func(obj) - except Exception: - self.logger.exception(f"Ignoring exception in '{func.__name__}':") + try: + func(obj) + except Exception: + self.logger.exception(f"Ignoring exception in '{func.__name__}':") - async def gateway_handler(self, gateway_url: str) -> None: - async with websockets.connect( - f"{gateway_url}/?v=8&encoding=json" - ) as websocket: - self.websocket = websocket + async def gateway_handler(self, gateway_url: str) -> None: + async with websockets.connect( + f"{gateway_url}/?v=8&encoding=json" + ) as websocket: + self.websocket = websocket - async for message in websocket: - await self.handle_resp(json.loads(message)) + async for message in websocket: + await self.handle_resp(json.loads(message)) - def get_channel(self, channel_id: str) -> discord.Channel: - """ - Get the channel for a given channel ID. - """ + def get_channel(self, channel_id: str) -> discord.Channel: + """ + Get the channel for a given channel ID. + """ - resp = self.send("GET", f"/channels/{channel_id}") + resp = self.send("GET", f"/channels/{channel_id}") - return dict_cls(resp, discord.Channel) + return dict_cls(resp, discord.Channel) - def get_channels(self, guild_id: str) -> Dict[str, discord.Channel]: - """ - Get all channels for a given guild ID. - """ + def get_channels(self, guild_id: str) -> Dict[str, discord.Channel]: + """ + Get all channels for a given guild ID. + """ - resp = self.send("GET", f"/guilds/{guild_id}/channels") + resp = self.send("GET", f"/guilds/{guild_id}/channels") - return { - channel["id"]: dict_cls(channel, discord.Channel) - for channel in resp - } + return { + channel["id"]: dict_cls(channel, discord.Channel) + for channel in resp + } - def get_emotes(self, guild_id: str) -> List[discord.Emote]: - """ - Get all the emotes for a given guild. - """ + def get_emotes(self, guild_id: str) -> List[discord.Emote]: + """ + Get all the emotes for a given guild. + """ - resp = self.send("GET", f"/guilds/{guild_id}/emojis") + resp = self.send("GET", f"/guilds/{guild_id}/emojis") - return [dict_cls(emote, discord.Emote) for emote in resp] + return [dict_cls(emote, discord.Emote) for emote in resp] - def get_members(self, guild_id: str) -> List[discord.User]: - """ - Get all the members for a given guild. - """ + def get_members(self, guild_id: str) -> List[discord.User]: + """ + Get all the members for a given guild. + """ - resp = self.send( - "GET", f"/guilds/{guild_id}/members", params={"limit": 1000} - ) + resp = self.send( + "GET", f"/guilds/{guild_id}/members", params={"limit": 1000} + ) - return [discord.User(member["user"]) for member in resp] + return [discord.User(member["user"]) for member in resp] - def create_webhook(self, channel_id: str, name: str) -> discord.Webhook: - """ - Create a webhook with the specified name in a given channel. - """ + def create_webhook(self, channel_id: str, name: str) -> discord.Webhook: + """ + Create a webhook with the specified name in a given channel. + """ - resp = self.send( - "POST", f"/channels/{channel_id}/webhooks", {"name": name} - ) + resp = self.send( + "POST", f"/channels/{channel_id}/webhooks", {"name": name} + ) - return dict_cls(resp, discord.Webhook) + return dict_cls(resp, discord.Webhook) - def edit_webhook( - self, content: str, message_id: str, webhook: discord.Webhook - ) -> None: - self.send( - "PATCH", - f"/webhooks/{webhook.id}/{webhook.token}/messages/" - f"{message_id}", - {"content": content}, - ) + def edit_webhook( + self, content: str, message_id: str, webhook: discord.Webhook + ) -> None: + self.send( + "PATCH", + f"/webhooks/{webhook.id}/{webhook.token}/messages/" + f"{message_id}", + {"content": content}, + ) - def delete_webhook( - self, message_id: str, webhook: discord.Webhook - ) -> None: - self.send( - "DELETE", - f"/webhooks/{webhook.id}/{webhook.token}/messages/" - f"{message_id}", - ) + def delete_webhook( + self, message_id: str, webhook: discord.Webhook + ) -> None: + self.send( + "DELETE", + f"/webhooks/{webhook.id}/{webhook.token}/messages/" + f"{message_id}", + ) - def send_webhook( - self, - webhook: discord.Webhook, - avatar_url: str, - content: str, - username: str, - ) -> discord.Message: - payload = { - "avatar_url": avatar_url, - "content": content, - "username": username, - # Disable 'everyone' and 'role' mentions. - "allowed_mentions": {"parse": ["users"]}, - } + def send_webhook( + self, + webhook: discord.Webhook, + avatar_url: str, + content: str, + username: str, + ) -> discord.Message: + payload = { + "avatar_url": avatar_url, + "content": content, + "username": username, + # Disable 'everyone' and 'role' mentions. + "allowed_mentions": {"parse": ["users"]}, + } - resp = self.send( - "POST", - f"/webhooks/{webhook.id}/{webhook.token}", - payload, - {"wait": True}, - ) + resp = self.send( + "POST", + f"/webhooks/{webhook.id}/{webhook.token}", + payload, + {"wait": True}, + ) - return discord.Message(resp) + return discord.Message(resp) - def send_message(self, message: str, channel_id: str) -> None: - self.send( - "POST", f"/channels/{channel_id}/messages", {"content": message} - ) + def send_message(self, message: str, channel_id: str) -> None: + self.send( + "POST", f"/channels/{channel_id}/messages", {"content": message} + ) - @request - def send( - self, method: str, path: str, content: dict = {}, params: dict = {} - ) -> dict: - endpoint = ( - f"https://discord.com/api/v8{path}?" - f"{urllib.parse.urlencode(params)}" - ) - headers = { - "Authorization": f"Bot {self.token}", - "Content-Type": "application/json", - } + @request + def send( + self, method: str, path: str, content: dict = {}, params: dict = {} + ) -> dict: + endpoint = ( + f"https://discord.com/api/v8{path}?" + f"{urllib.parse.urlencode(params)}" + ) + headers = { + "Authorization": f"Bot {self.token}", + "Content-Type": "application/json", + } - # 'body' being an empty dict breaks "GET" requests. - payload = json.dumps(content) if content else None + # 'body' being an empty dict breaks "GET" requests. + payload = json.dumps(content) if content else None - return self.http.request( - method, endpoint, body=payload, headers=headers - ) + return self.http.request( + method, endpoint, body=payload, headers=headers + ) diff --git a/src/aioappsrv/main.py b/src/aioappsrv/main.py index 4c635fe..fb9fe75 100644 --- a/src/aioappsrv/main.py +++ b/src/aioappsrv/main.py @@ -22,779 +22,779 @@ from misc import dict_cls, except_deleted, hash_str class MatrixClient(AppService): - def __init__(self, config: dict, http: urllib3.PoolManager) -> None: - super().__init__(config, http) + def __init__(self, config: dict, http: urllib3.PoolManager) -> None: + super().__init__(config, http) - self.db = DataBase(config["database"]) - self.discord = DiscordClient(self, config, http) - self.format = "_discord_" # "{@,#}_discord_1234:localhost" - self.id_regex = "[0-9]+" # Snowflakes may have variable length + self.db = DataBase(config["database"]) + self.discord = DiscordClient(self, config, http) + self.format = "_discord_" # "{@,#}_discord_1234:localhost" + self.id_regex = "[0-9]+" # Snowflakes may have variable length - # TODO Find a cleaner way to use these keys. - for k in ("m_emotes", "m_members", "m_messages"): - Cache.cache[k] = {} + # TODO Find a cleaner way to use these keys. + for k in ("m_emotes", "m_members", "m_messages"): + Cache.cache[k] = {} - def handle_bridge(self, message: matrix.Event) -> None: - # Ignore events that aren't for us. - if message.sender.split(":")[ - -1 - ] != self.server_name or not message.body.startswith("!bridge"): - return + def handle_bridge(self, message: matrix.Event) -> None: + # Ignore events that aren't for us. + if message.sender.split(":")[ + -1 + ] != self.server_name or not message.body.startswith("!bridge"): + return - # Get the channel ID. - try: - channel = message.body.split()[1] - except IndexError: - return + # Get the channel ID. + try: + channel = message.body.split()[1] + except IndexError: + return - # Check if the given channel is valid. - try: - channel = self.discord.get_channel(channel) - except RequestError as e: - # The channel can be invalid or we may not have permissions. - self.logger.warning(f"Failed to fetch channel {channel}: {e}") - return + # Check if the given channel is valid. + try: + channel = self.discord.get_channel(channel) + except RequestError as e: + # The channel can be invalid or we may not have permissions. + self.logger.warning(f"Failed to fetch channel {channel}: {e}") + return - if ( - channel.type != discord.ChannelType.GUILD_TEXT - or channel.id in self.db.list_channels() - ): - return + if ( + channel.type != discord.ChannelType.GUILD_TEXT + or channel.id in self.db.list_channels() + ): + return - self.logger.info(f"Creating bridged room for channel {channel.id}.") + self.logger.info(f"Creating bridged room for channel {channel.id}.") - self.create_room(channel, message.sender) + self.create_room(channel, message.sender) - def on_member(self, event: matrix.Event) -> None: - with Cache.lock: - # Just lazily clear the whole member cache on - # membership update events. - if event.room_id in Cache.cache["m_members"]: - self.logger.info( - f"Clearing member cache for room '{event.room_id}'." - ) - del Cache.cache["m_members"][event.room_id] + def on_member(self, event: matrix.Event) -> None: + with Cache.lock: + # Just lazily clear the whole member cache on + # membership update events. + if event.room_id in Cache.cache["m_members"]: + self.logger.info( + f"Clearing member cache for room '{event.room_id}'." + ) + del Cache.cache["m_members"][event.room_id] - if ( - event.sender.split(":")[-1] != self.server_name - or event.state_key != self.user_id - or not event.is_direct - ): - return + if ( + event.sender.split(":")[-1] != self.server_name + or event.state_key != self.user_id + or not event.is_direct + ): + return - # Join the direct message room. - self.logger.info(f"Joining direct message room '{event.room_id}'.") - self.join_room(event.room_id) + # Join the direct message room. + self.logger.info(f"Joining direct message room '{event.room_id}'.") + self.join_room(event.room_id) - def on_message(self, message: matrix.Event) -> None: - if ( - message.sender.startswith((f"@{self.format}", self.user_id)) - or not message.body - ): - return + def on_message(self, message: matrix.Event) -> None: + if ( + message.sender.startswith((f"@{self.format}", self.user_id)) + or not message.body + ): + return - # Handle bridging commands. - self.handle_bridge(message) + # Handle bridging commands. + self.handle_bridge(message) - channel_id = self.db.get_channel(message.room_id) + channel_id = self.db.get_channel(message.room_id) - if not channel_id: - return + if not channel_id: + return - author = self.get_members(message.room_id)[message.sender] + author = self.get_members(message.room_id)[message.sender] - webhook = self.discord.get_webhook( - channel_id, self.discord.webhook_name - ) + webhook = self.discord.get_webhook( + channel_id, self.discord.webhook_name + ) - if message.relates_to and message.reltype == "m.replace": - with Cache.lock: - message_id = Cache.cache["m_messages"].get(message.relates_to) + if message.relates_to and message.reltype == "m.replace": + with Cache.lock: + message_id = Cache.cache["m_messages"].get(message.relates_to) - # TODO validate if the original author sent the edit. + # TODO validate if the original author sent the edit. - if not message_id or not message.new_body: - return + if not message_id or not message.new_body: + return - message.new_body = self.process_message(message) + message.new_body = self.process_message(message) - except_deleted(self.discord.edit_webhook)( - message.new_body, message_id, webhook - ) - else: - message.body = ( - f"`{message.body}`: {self.mxc_url(message.attachment)}" - if message.attachment - else self.process_message(message) - ) + except_deleted(self.discord.edit_webhook)( + message.new_body, message_id, webhook + ) + else: + message.body = ( + f"`{message.body}`: {self.mxc_url(message.attachment)}" + if message.attachment + else self.process_message(message) + ) - message_id = self.discord.send_webhook( - webhook, - self.mxc_url(author.avatar_url) if author.avatar_url else None, - message.body, - author.display_name if author.display_name else message.sender, - ).id + message_id = self.discord.send_webhook( + webhook, + self.mxc_url(author.avatar_url) if author.avatar_url else None, + message.body, + author.display_name if author.display_name else message.sender, + ).id - with Cache.lock: - Cache.cache["m_messages"][message.id] = message_id + with Cache.lock: + Cache.cache["m_messages"][message.id] = message_id - def on_redaction(self, event: matrix.Event) -> None: - with Cache.lock: - message_id = Cache.cache["m_messages"].get(event.redacts) + def on_redaction(self, event: matrix.Event) -> None: + with Cache.lock: + message_id = Cache.cache["m_messages"].get(event.redacts) - if not message_id: - return + if not message_id: + return - webhook = self.discord.get_webhook( - self.db.get_channel(event.room_id), self.discord.webhook_name - ) + webhook = self.discord.get_webhook( + self.db.get_channel(event.room_id), self.discord.webhook_name + ) - except_deleted(self.discord.delete_webhook)(message_id, webhook) + except_deleted(self.discord.delete_webhook)(message_id, webhook) - with Cache.lock: - del Cache.cache["m_messages"][event.redacts] + with Cache.lock: + del Cache.cache["m_messages"][event.redacts] - def get_members(self, room_id: str) -> Dict[str, matrix.User]: - with Cache.lock: - cached = Cache.cache["m_members"].get(room_id) + def get_members(self, room_id: str) -> Dict[str, matrix.User]: + with Cache.lock: + cached = Cache.cache["m_members"].get(room_id) - if cached: - return cached + if cached: + return cached - resp = self.send("GET", f"/rooms/{room_id}/joined_members") + resp = self.send("GET", f"/rooms/{room_id}/joined_members") - joined = resp["joined"] + joined = resp["joined"] - for k, v in joined.items(): - joined[k] = dict_cls(v, matrix.User) + for k, v in joined.items(): + joined[k] = dict_cls(v, matrix.User) - with Cache.lock: - Cache.cache["m_members"][room_id] = joined + with Cache.lock: + Cache.cache["m_members"][room_id] = joined - return joined + return joined - def create_room(self, channel: discord.Channel, sender: str) -> None: - """ - Create a bridged room and invite the person who invoked the command. - """ + def create_room(self, channel: discord.Channel, sender: str) -> None: + """ + Create a bridged room and invite the person who invoked the command. + """ - content = { - "room_alias_name": f"{self.format}{channel.id}", - "name": channel.name, - "topic": channel.topic if channel.topic else channel.name, - "visibility": "private", - "invite": [sender], - "creation_content": {"m.federate": True}, - "initial_state": [ - { - "type": "m.room.join_rules", - "content": {"join_rule": "public"}, - }, - { - "type": "m.room.history_visibility", - "content": {"history_visibility": "shared"}, - }, - ], - "power_level_content_override": { - "users": {sender: 100, self.user_id: 100} - }, - } + content = { + "room_alias_name": f"{self.format}{channel.id}", + "name": channel.name, + "topic": channel.topic if channel.topic else channel.name, + "visibility": "private", + "invite": [sender], + "creation_content": {"m.federate": True}, + "initial_state": [ + { + "type": "m.room.join_rules", + "content": {"join_rule": "public"}, + }, + { + "type": "m.room.history_visibility", + "content": {"history_visibility": "shared"}, + }, + ], + "power_level_content_override": { + "users": {sender: 100, self.user_id: 100} + }, + } - resp = self.send("POST", "/createRoom", content) + resp = self.send("POST", "/createRoom", content) - self.db.add_room(resp["room_id"], channel.id) + self.db.add_room(resp["room_id"], channel.id) - def create_message_event( - self, - message: str, - emotes: dict, - edit: str = "", - reference: discord.Message = None, - ) -> dict: - content = { - "body": message, - "msgtype": "m.text", - } + def create_message_event( + self, + message: str, + emotes: dict, + edit: str = "", + reference: discord.Message = None, + ) -> dict: + content = { + "body": message, + "msgtype": "m.text", + } - fmt = self.get_fmt(message, emotes) + fmt = self.get_fmt(message, emotes) - if fmt != message: - content = { - **content, - "format": "org.matrix.custom.html", - "formatted_body": fmt, - } + if fmt != message: + content = { + **content, + "format": "org.matrix.custom.html", + "formatted_body": fmt, + } - ref_id = None + ref_id = None - if reference: - # Reply to a Discord message. - with Cache.lock: - ref_id = Cache.cache["d_messages"].get(reference.id) + if reference: + # Reply to a Discord message. + with Cache.lock: + ref_id = Cache.cache["d_messages"].get(reference.id) - # Reply to a Matrix message. (maybe) - if not ref_id: - with Cache.lock: - ref_id = [ - k - for k, v in Cache.cache["m_messages"].items() - if v == reference.id - ] - ref_id = next(iter(ref_id), "") + # Reply to a Matrix message. (maybe) + if not ref_id: + with Cache.lock: + ref_id = [ + k + for k, v in Cache.cache["m_messages"].items() + if v == reference.id + ] + ref_id = next(iter(ref_id), "") - if ref_id: - event = except_deleted(self.get_event)( - ref_id, - self.get_room_id(self.discord.matrixify(reference.channel_id)), - ) - if event: - # Content with the reply fallbacks stripped. - tmp = "" - # We don't want to strip lines starting with "> " after - # encountering a regular line, so we use this variable. - got_fallback = True - for line in event.body.split("\n"): - if not line.startswith("> "): - got_fallback = False - if not got_fallback: - tmp += line + if ref_id: + event = except_deleted(self.get_event)( + ref_id, + self.get_room_id(self.discord.matrixify(reference.channel_id)), + ) + if event: + # Content with the reply fallbacks stripped. + tmp = "" + # We don't want to strip lines starting with "> " after + # encountering a regular line, so we use this variable. + got_fallback = True + for line in event.body.split("\n"): + if not line.startswith("> "): + got_fallback = False + if not got_fallback: + tmp += line - event.body = tmp - event.formatted_body = ( - # re.DOTALL allows the match to span newlines. - re.sub( - "", - "", - event.formatted_body, - flags=re.DOTALL, - ) - if event.formatted_body - else event.body - ) + event.body = tmp + event.formatted_body = ( + # re.DOTALL allows the match to span newlines. + re.sub( + "", + "", + event.formatted_body, + flags=re.DOTALL, + ) + if event.formatted_body + else event.body + ) - content = { - **content, - "body": ( - f"> <{event.sender}> {event.body}\n{content['body']}" - ), - "m.relates_to": {"m.in_reply_to": {"event_id": event.id}}, - "format": "org.matrix.custom.html", - "formatted_body": f"""
\ + content = { + **content, + "body": ( + f"> <{event.sender}> {event.body}\n{content['body']}" + ), + "m.relates_to": {"m.in_reply_to": {"event_id": event.id}}, + "format": "org.matrix.custom.html", + "formatted_body": f"""
\ In reply to\ {event.sender}\
{event.formatted_body if event.formatted_body else event.body}\
\ {content.get("formatted_body", content['body'])}""", - } + } - if edit: - content = { - **content, - "body": f" * {content['body']}", - "formatted_body": f" * {content.get('formatted_body', content['body'])}", - "m.relates_to": {"event_id": edit, "rel_type": "m.replace"}, - "m.new_content": {**content}, - } + if edit: + content = { + **content, + "body": f" * {content['body']}", + "formatted_body": f" * {content.get('formatted_body', content['body'])}", + "m.relates_to": {"event_id": edit, "rel_type": "m.replace"}, + "m.new_content": {**content}, + } - return content + return content - def get_fmt(self, message: str, emotes: dict) -> str: - message = ( - markdown.markdown(message) - .replace("

", "") - .replace("

", "") - .replace("\n", "
") - ) + def get_fmt(self, message: str, emotes: dict) -> str: + message = ( + markdown.markdown(message) + .replace("

", "") + .replace("

", "") + .replace("\n", "
") + ) - # Upload emotes in multiple threads so that we don't - # block the Discord bot for too long. - upload_threads = [ - threading.Thread( - target=self.upload_emote, args=(emote, emotes[emote]) - ) - for emote in emotes - ] + # Upload emotes in multiple threads so that we don't + # block the Discord bot for too long. + upload_threads = [ + threading.Thread( + target=self.upload_emote, args=(emote, emotes[emote]) + ) + for emote in emotes + ] - # Acquire the lock before starting the threads to avoid resource - # contention by tens of threads at once. - with Cache.lock: - for thread in upload_threads: - thread.start() - for thread in upload_threads: - thread.join() + # Acquire the lock before starting the threads to avoid resource + # contention by tens of threads at once. + with Cache.lock: + for thread in upload_threads: + thread.start() + for thread in upload_threads: + thread.join() - with Cache.lock: - for emote in emotes: - emote_ = Cache.cache["m_emotes"].get(emote) + with Cache.lock: + for emote in emotes: + emote_ = Cache.cache["m_emotes"].get(emote) - if emote_: - emote = f":{emote}:" - message = message.replace( - emote, - f"""\"{emote}\"""", - ) + ) - return message + return message - def mention_regex(self, encode: bool, id_as_group: bool) -> str: - mention = "@" - colon = ":" - snowflake = self.id_regex + def mention_regex(self, encode: bool, id_as_group: bool) -> str: + mention = "@" + colon = ":" + snowflake = self.id_regex - if encode: - mention = urllib.parse.quote(mention) - colon = urllib.parse.quote(colon) + if encode: + mention = urllib.parse.quote(mention) + colon = urllib.parse.quote(colon) - if id_as_group: - snowflake = f"({snowflake})" + if id_as_group: + snowflake = f"({snowflake})" - hashed = f"(?:-{snowflake})?" + hashed = f"(?:-{snowflake})?" - return f"{mention}{self.format}{snowflake}{hashed}{colon}{re.escape(self.server_name)}" + return f"{mention}{self.format}{snowflake}{hashed}{colon}{re.escape(self.server_name)}" - def process_message(self, event: matrix.Event) -> str: - message = event.new_body if event.new_body else event.body + def process_message(self, event: matrix.Event) -> str: + message = event.new_body if event.new_body else event.body - emotes = re.findall(r":(\w*):", message) + emotes = re.findall(r":(\w*):", message) - mentions = list( - re.finditer( - self.mention_regex(encode=False, id_as_group=True), - event.formatted_body, - ) - ) - # For clients that properly encode mentions. - # 'https://matrix.to/#/%40_discord_...%3Adomain.tld' - mentions.extend( - re.finditer( - self.mention_regex(encode=True, id_as_group=True), - event.formatted_body, - ) - ) + mentions = list( + re.finditer( + self.mention_regex(encode=False, id_as_group=True), + event.formatted_body, + ) + ) + # For clients that properly encode mentions. + # 'https://matrix.to/#/%40_discord_...%3Adomain.tld' + mentions.extend( + re.finditer( + self.mention_regex(encode=True, id_as_group=True), + event.formatted_body, + ) + ) - with Cache.lock: - for emote in set(emotes): - emote_ = Cache.cache["d_emotes"].get(emote) - if emote_: - message = message.replace(f":{emote}:", emote_) + with Cache.lock: + for emote in set(emotes): + emote_ = Cache.cache["d_emotes"].get(emote) + if emote_: + message = message.replace(f":{emote}:", emote_) - for mention in set(mentions): - # Unquote just in-case we matched an encoded username. - username = self.db.fetch_user( - urllib.parse.unquote(mention.group(0)) - ).get("username") - if username: - if mention.group(2): - # Replace mention with plain text for hashed users (webhooks) - message = message.replace(mention.group(0), f"@{username}") - else: - # Replace the 'mention' so that the user is tagged - # in the case of replies aswell. - # '> <@_discord_1234:localhost> Message' - for replace in (mention.group(0), username): - message = message.replace( - replace, f"<@{mention.group(1)}>" - ) + for mention in set(mentions): + # Unquote just in-case we matched an encoded username. + username = self.db.fetch_user( + urllib.parse.unquote(mention.group(0)) + ).get("username") + if username: + if mention.group(2): + # Replace mention with plain text for hashed users (webhooks) + message = message.replace(mention.group(0), f"@{username}") + else: + # Replace the 'mention' so that the user is tagged + # in the case of replies aswell. + # '> <@_discord_1234:localhost> Message' + for replace in (mention.group(0), username): + message = message.replace( + replace, f"<@{mention.group(1)}>" + ) - # We trim the message later as emotes take up extra characters too. - return message[: discord.MESSAGE_LIMIT] + # We trim the message later as emotes take up extra characters too. + return message[: discord.MESSAGE_LIMIT] - def upload_emote(self, emote_name: str, emote_id: str) -> None: - # There won't be a race condition here, since only a unique - # set of emotes are uploaded at a time. - if emote_name in Cache.cache["m_emotes"]: - return + def upload_emote(self, emote_name: str, emote_id: str) -> None: + # There won't be a race condition here, since only a unique + # set of emotes are uploaded at a time. + if emote_name in Cache.cache["m_emotes"]: + return - emote_url = f"{discord.CDN_URL}/emojis/{emote_id}" + emote_url = f"{discord.CDN_URL}/emojis/{emote_id}" - # We don't want the message to be dropped entirely if an emote - # fails to upload for some reason. - try: - # TODO This is not thread safe, but we're protected by the GIL. - Cache.cache["m_emotes"][emote_name] = self.upload(emote_url) - except RequestError as e: - self.logger.warning(f"Failed to upload emote {emote_id}: {e}") + # We don't want the message to be dropped entirely if an emote + # fails to upload for some reason. + try: + # TODO This is not thread safe, but we're protected by the GIL. + Cache.cache["m_emotes"][emote_name] = self.upload(emote_url) + except RequestError as e: + self.logger.warning(f"Failed to upload emote {emote_id}: {e}") - def register(self, mxid: str) -> None: - """ - Register a dummy user on the homeserver. - """ + def register(self, mxid: str) -> None: + """ + Register a dummy user on the homeserver. + """ - content = { - "type": "m.login.application_service", - # "@test:localhost" -> "test" (Can't register with a full mxid.) - "username": mxid[1:].split(":")[0], - } + content = { + "type": "m.login.application_service", + # "@test:localhost" -> "test" (Can't register with a full mxid.) + "username": mxid[1:].split(":")[0], + } - resp = self.send("POST", "/register", content) + resp = self.send("POST", "/register", content) - self.db.add_user(resp["user_id"]) + self.db.add_user(resp["user_id"]) - def set_avatar(self, avatar_url: str, mxid: str) -> None: - avatar_uri = self.upload(avatar_url) + def set_avatar(self, avatar_url: str, mxid: str) -> None: + avatar_uri = self.upload(avatar_url) - self.send( - "PUT", - f"/profile/{mxid}/avatar_url", - {"avatar_url": avatar_uri}, - params={"user_id": mxid}, - ) + self.send( + "PUT", + f"/profile/{mxid}/avatar_url", + {"avatar_url": avatar_uri}, + params={"user_id": mxid}, + ) - self.db.add_avatar(avatar_url, mxid) + self.db.add_avatar(avatar_url, mxid) - def set_nick(self, username: str, mxid: str) -> None: - self.send( - "PUT", - f"/profile/{mxid}/displayname", - {"displayname": username}, - params={"user_id": mxid}, - ) + def set_nick(self, username: str, mxid: str) -> None: + self.send( + "PUT", + f"/profile/{mxid}/displayname", + {"displayname": username}, + params={"user_id": mxid}, + ) - self.db.add_username(username, mxid) + self.db.add_username(username, mxid) class DiscordClient(Gateway): - def __init__( - self, appservice: MatrixClient, config: dict, http: urllib3.PoolManager - ) -> None: - super().__init__(http, config["discord_token"]) + def __init__( + self, appservice: MatrixClient, config: dict, http: urllib3.PoolManager + ) -> None: + super().__init__(http, config["discord_token"]) - self.app = appservice - self.webhook_name = "matrix_bridge" + self.app = appservice + self.webhook_name = "matrix_bridge" - # TODO Find a cleaner way to use these keys. - for k in ("d_emotes", "d_messages", "d_webhooks"): - Cache.cache[k] = {} + # TODO Find a cleaner way to use these keys. + for k in ("d_emotes", "d_messages", "d_webhooks"): + Cache.cache[k] = {} - def to_return(self, message: discord.Message) -> bool: - with Cache.lock: - hook_ids = [hook.id for hook in Cache.cache["d_webhooks"].values()] + def to_return(self, message: discord.Message) -> bool: + with Cache.lock: + hook_ids = [hook.id for hook in Cache.cache["d_webhooks"].values()] - return ( - message.channel_id not in self.app.db.list_channels() - or not message.author # Embeds can be weird sometimes. - or message.webhook_id in hook_ids - ) + return ( + message.channel_id not in self.app.db.list_channels() + or not message.author # Embeds can be weird sometimes. + or message.webhook_id in hook_ids + ) - def matrixify(self, id: str, user: bool = False, hashed: str = "") -> str: - return ( - f"{'@' if user else '#'}{self.app.format}" - f"{id}{'-' + hashed if hashed else ''}:" - f"{self.app.server_name}" - ) + def matrixify(self, id: str, user: bool = False, hashed: str = "") -> str: + return ( + f"{'@' if user else '#'}{self.app.format}" + f"{id}{'-' + hashed if hashed else ''}:" + f"{self.app.server_name}" + ) - def sync_profile(self, user: discord.User, hashed: str = "") -> None: - """ - Sync the avatar and username for a puppeted user. - """ + def sync_profile(self, user: discord.User, hashed: str = "") -> None: + """ + Sync the avatar and username for a puppeted user. + """ - mxid = self.matrixify(user.id, user=True, hashed=hashed) + mxid = self.matrixify(user.id, user=True, hashed=hashed) - profile = self.app.db.fetch_user(mxid) + profile = self.app.db.fetch_user(mxid) - # User doesn't exist. - if not profile: - return + # User doesn't exist. + if not profile: + return - username = f"{user.username}#{user.discriminator}" + username = f"{user.username}#{user.discriminator}" - if user.avatar_url != profile["avatar_url"]: - self.logger.info(f"Updating avatar for Discord user '{user.id}'") - self.app.set_avatar(user.avatar_url, mxid) - if username != profile["username"]: - self.logger.info(f"Updating username for Discord user '{user.id}'") - self.app.set_nick(username, mxid) + if user.avatar_url != profile["avatar_url"]: + self.logger.info(f"Updating avatar for Discord user '{user.id}'") + self.app.set_avatar(user.avatar_url, mxid) + if username != profile["username"]: + self.logger.info(f"Updating username for Discord user '{user.id}'") + self.app.set_nick(username, mxid) - def wrap(self, message: discord.Message) -> Tuple[str, str]: - """ - Get the room ID and the puppet's mxid for a given channel ID and a - Discord user. - """ + def wrap(self, message: discord.Message) -> Tuple[str, str]: + """ + Get the room ID and the puppet's mxid for a given channel ID and a + Discord user. + """ - hashed = "" - if message.webhook_id and message.webhook_id != message.application_id: - hashed = str(hash_str(message.author.username)) + hashed = "" + if message.webhook_id and message.webhook_id != message.application_id: + hashed = str(hash_str(message.author.username)) - mxid = self.matrixify(message.author.id, user=True, hashed=hashed) - room_id = self.app.get_room_id(self.matrixify(message.channel_id)) + mxid = self.matrixify(message.author.id, user=True, hashed=hashed) + room_id = self.app.get_room_id(self.matrixify(message.channel_id)) - if not self.app.db.fetch_user(mxid): - self.logger.info( - f"Creating dummy user for Discord user {message.author.id}." - ) - self.app.register(mxid) + if not self.app.db.fetch_user(mxid): + self.logger.info( + f"Creating dummy user for Discord user {message.author.id}." + ) + self.app.register(mxid) - self.app.set_nick( - f"{message.author.username}#" - f"{message.author.discriminator}", - mxid, - ) + self.app.set_nick( + f"{message.author.username}#" + f"{message.author.discriminator}", + mxid, + ) - if message.author.avatar_url: - self.app.set_avatar(message.author.avatar_url, mxid) + if message.author.avatar_url: + self.app.set_avatar(message.author.avatar_url, mxid) - if mxid not in self.app.get_members(room_id): - self.logger.info(f"Inviting user '{mxid}' to room '{room_id}'.") + if mxid not in self.app.get_members(room_id): + self.logger.info(f"Inviting user '{mxid}' to room '{room_id}'.") - self.app.send_invite(room_id, mxid) - self.app.join_room(room_id, mxid) + self.app.send_invite(room_id, mxid) + self.app.join_room(room_id, mxid) - if message.webhook_id: - # Sync webhooks here as they can't be accessed like guild members. - self.sync_profile(message.author, hashed=hashed) + if message.webhook_id: + # Sync webhooks here as they can't be accessed like guild members. + self.sync_profile(message.author, hashed=hashed) - return mxid, room_id + return mxid, room_id - def cache_emotes(self, emotes: List[discord.Emote]): - # TODO maybe "namespace" emotes by guild in the cache ? - with Cache.lock: - for emote in emotes: - Cache.cache["d_emotes"][emote.name] = ( - f"<{'a' if emote.animated else ''}:" - f"{emote.name}:{emote.id}>" - ) + def cache_emotes(self, emotes: List[discord.Emote]): + # TODO maybe "namespace" emotes by guild in the cache ? + with Cache.lock: + for emote in emotes: + Cache.cache["d_emotes"][emote.name] = ( + f"<{'a' if emote.animated else ''}:" + f"{emote.name}:{emote.id}>" + ) - def on_guild_create(self, guild: discord.Guild) -> None: - for member in guild.members: - self.sync_profile(member) + def on_guild_create(self, guild: discord.Guild) -> None: + for member in guild.members: + self.sync_profile(member) - self.cache_emotes(guild.emojis) + self.cache_emotes(guild.emojis) - def on_guild_emojis_update( - self, update: discord.GuildEmojisUpdate - ) -> None: - self.cache_emotes(update.emojis) + def on_guild_emojis_update( + self, update: discord.GuildEmojisUpdate + ) -> None: + self.cache_emotes(update.emojis) - def on_guild_member_update( - self, update: discord.GuildMemberUpdate - ) -> None: - self.sync_profile(update.user) + def on_guild_member_update( + self, update: discord.GuildMemberUpdate + ) -> None: + self.sync_profile(update.user) - def on_message_create(self, message: discord.Message) -> None: - if self.to_return(message): - return + def on_message_create(self, message: discord.Message) -> None: + if self.to_return(message): + return - mxid, room_id = self.wrap(message) + mxid, room_id = self.wrap(message) - content_, emotes = self.process_message(message) + content_, emotes = self.process_message(message) - content = self.app.create_message_event( - content_, emotes, reference=message.referenced_message - ) + content = self.app.create_message_event( + content_, emotes, reference=message.referenced_message + ) - with Cache.lock: - Cache.cache["d_messages"][message.id] = self.app.send_message( - room_id, content, mxid - ) + with Cache.lock: + Cache.cache["d_messages"][message.id] = self.app.send_message( + room_id, content, mxid + ) - def on_message_delete(self, message: discord.Message) -> None: - with Cache.lock: - event_id = Cache.cache["d_messages"].get(message.id) + def on_message_delete(self, message: discord.Message) -> None: + with Cache.lock: + event_id = Cache.cache["d_messages"].get(message.id) - if not event_id: - return + if not event_id: + return - room_id = self.app.get_room_id(self.matrixify(message.channel_id)) - event = except_deleted(self.app.get_event)(event_id, room_id) + room_id = self.app.get_room_id(self.matrixify(message.channel_id)) + event = except_deleted(self.app.get_event)(event_id, room_id) - if event: - self.app.redact(event.id, event.room_id, event.sender) + if event: + self.app.redact(event.id, event.room_id, event.sender) - with Cache.lock: - del Cache.cache["d_messages"][message.id] + with Cache.lock: + del Cache.cache["d_messages"][message.id] - def on_message_update(self, message: discord.Message) -> None: - if self.to_return(message): - return + def on_message_update(self, message: discord.Message) -> None: + if self.to_return(message): + return - with Cache.lock: - event_id = Cache.cache["d_messages"].get(message.id) + with Cache.lock: + event_id = Cache.cache["d_messages"].get(message.id) - if not event_id: - return + if not event_id: + return - room_id = self.app.get_room_id(self.matrixify(message.channel_id)) - mxid = self.matrixify(message.author.id, user=True) + room_id = self.app.get_room_id(self.matrixify(message.channel_id)) + mxid = self.matrixify(message.author.id, user=True) - # It is possible that a webhook edit's it's own old message - # after changing it's name, hence we generate a new mxid from - # the hashed username, but that mxid hasn't been registered before, - # so the request fails with: - # M_FORBIDDEN: Application service has not registered this user - if not self.app.db.fetch_user(mxid): - return + # It is possible that a webhook edit's it's own old message + # after changing it's name, hence we generate a new mxid from + # the hashed username, but that mxid hasn't been registered before, + # so the request fails with: + # M_FORBIDDEN: Application service has not registered this user + if not self.app.db.fetch_user(mxid): + return - content_, emotes = self.process_message(message) + content_, emotes = self.process_message(message) - content = self.app.create_message_event( - content_, emotes, edit=event_id - ) + content = self.app.create_message_event( + content_, emotes, edit=event_id + ) - self.app.send_message(room_id, content, mxid) + self.app.send_message(room_id, content, mxid) - def on_typing_start(self, typing: discord.Typing) -> None: - if typing.channel_id not in self.app.db.list_channels(): - return + def on_typing_start(self, typing: discord.Typing) -> None: + if typing.channel_id not in self.app.db.list_channels(): + return - mxid = self.matrixify(typing.user_id, user=True) - room_id = self.app.get_room_id(self.matrixify(typing.channel_id)) + mxid = self.matrixify(typing.user_id, user=True) + room_id = self.app.get_room_id(self.matrixify(typing.channel_id)) - if mxid not in self.app.get_members(room_id): - return + if mxid not in self.app.get_members(room_id): + return - self.app.send_typing(room_id, mxid) + self.app.send_typing(room_id, mxid) - def get_webhook(self, channel_id: str, name: str) -> discord.Webhook: - """ - Get the webhook object for the first webhook that matches the specified - name in a given channel, create the webhook if it doesn't exist. - """ + def get_webhook(self, channel_id: str, name: str) -> discord.Webhook: + """ + Get the webhook object for the first webhook that matches the specified + name in a given channel, create the webhook if it doesn't exist. + """ - # Check the cache first. - with Cache.lock: - webhook = Cache.cache["d_webhooks"].get(channel_id) + # Check the cache first. + with Cache.lock: + webhook = Cache.cache["d_webhooks"].get(channel_id) - if webhook: - return webhook + if webhook: + return webhook - webhooks = self.send("GET", f"/channels/{channel_id}/webhooks") - webhook = next( - ( - dict_cls(webhook, discord.Webhook) - for webhook in webhooks - if webhook["name"] == name - ), - None, - ) + webhooks = self.send("GET", f"/channels/{channel_id}/webhooks") + webhook = next( + ( + dict_cls(webhook, discord.Webhook) + for webhook in webhooks + if webhook["name"] == name + ), + None, + ) - if not webhook: - webhook = self.create_webhook(channel_id, name) + if not webhook: + webhook = self.create_webhook(channel_id, name) - with Cache.lock: - Cache.cache["d_webhooks"][channel_id] = webhook + with Cache.lock: + Cache.cache["d_webhooks"][channel_id] = webhook - return webhook + return webhook - def process_message(self, message: discord.Message) -> Tuple[str, Dict]: - content = message.content - emotes = {} - regex = r"" + def process_message(self, message: discord.Message) -> Tuple[str, Dict]: + content = message.content + emotes = {} + regex = r"" - # Mentions can either be in the form of `<@1234>` or `<@!1234>`. - for member in message.mentions: - for char in ("", "!"): - content = content.replace( - f"<@{char}{member.id}>", f"@{member.username}" - ) + # Mentions can either be in the form of `<@1234>` or `<@!1234>`. + for member in message.mentions: + for char in ("", "!"): + content = content.replace( + f"<@{char}{member.id}>", f"@{member.username}" + ) - # Replace channel IDs with names. - channels = re.findall("<#([0-9]+)>", content) - if channels: - if not message.guild_id: - self.logger.warning( - f"Message '{message.id}' in channel '{message.channel_id}' does not have a guild_id!" - ) - else: - discord_channels = self.get_channels(message.guild_id) - for channel in channels: - discord_channel = discord_channels.get(channel) - name = ( - discord_channel.name - if discord_channel - else "deleted-channel" - ) - content = content.replace(f"<#{channel}>", f"#{name}") + # Replace channel IDs with names. + channels = re.findall("<#([0-9]+)>", content) + if channels: + if not message.guild_id: + self.logger.warning( + f"Message '{message.id}' in channel '{message.channel_id}' does not have a guild_id!" + ) + else: + discord_channels = self.get_channels(message.guild_id) + for channel in channels: + discord_channel = discord_channels.get(channel) + name = ( + discord_channel.name + if discord_channel + else "deleted-channel" + ) + content = content.replace(f"<#{channel}>", f"#{name}") - # { "emote_name": "emote_id" } - for emote in re.findall(regex, content): - emotes[emote[0]] = emote[1] + # { "emote_name": "emote_id" } + for emote in re.findall(regex, content): + emotes[emote[0]] = emote[1] - # Replace emote IDs with names. - content = re.sub(regex, r":\g<1>:", content) + # Replace emote IDs with names. + content = re.sub(regex, r":\g<1>:", content) - # Append attachments to message. - for attachment in message.attachments: - content += f"\n{attachment['url']}" + # Append attachments to message. + for attachment in message.attachments: + content += f"\n{attachment['url']}" - # Append stickers to message. - for sticker in message.stickers: - if sticker.format_type != 3: # 3 == Lottie format. - content += f"\n{discord.CDN_URL}/stickers/{sticker.id}.png" + # Append stickers to message. + for sticker in message.stickers: + if sticker.format_type != 3: # 3 == Lottie format. + content += f"\n{discord.CDN_URL}/stickers/{sticker.id}.png" - return content, emotes + return content, emotes def config_gen(basedir: str, config_file: str) -> dict: - config_file = f"{basedir}/{config_file}" + config_file = f"{basedir}/{config_file}" - config_dict = { - "as_token": "my-secret-as-token", - "hs_token": "my-secret-hs-token", - "user_id": "appservice-discord", - "homeserver": "http://127.0.0.1:8008", - "server_name": "localhost", - "discord_token": "my-secret-discord-token", - "port": 5000, - "database": f"{basedir}/bridge.db", - } + config_dict = { + "as_token": "my-secret-as-token", + "hs_token": "my-secret-hs-token", + "user_id": "appservice-discord", + "homeserver": "http://127.0.0.1:8008", + "server_name": "localhost", + "discord_token": "my-secret-discord-token", + "port": 5000, + "database": f"{basedir}/bridge.db", + } - if not os.path.exists(config_file): - with open(config_file, "w") as f: - json.dump(config_dict, f, indent=4) - print(f"Configuration dumped to '{config_file}'") - sys.exit() + if not os.path.exists(config_file): + with open(config_file, "w") as f: + json.dump(config_dict, f, indent=4) + print(f"Configuration dumped to '{config_file}'") + sys.exit() - with open(config_file, "r") as f: - return json.loads(f.read()) + with open(config_file, "r") as f: + return json.loads(f.read()) def excepthook(exc_type, exc_value, exc_traceback): - if issubclass(exc_type, KeyboardInterrupt): - sys.__excepthook__(exc_type, exc_value, exc_traceback) - return + if issubclass(exc_type, KeyboardInterrupt): + sys.__excepthook__(exc_type, exc_value, exc_traceback) + return - logging.critical( - "Unknown exception:", exc_info=(exc_type, exc_value, exc_traceback) - ) + logging.critical( + "Unknown exception:", exc_info=(exc_type, exc_value, exc_traceback) + ) def main() -> None: - try: - basedir = sys.argv[1] - if not os.path.exists(basedir): - print(f"Path '{basedir}' does not exist!") - sys.exit(1) - basedir = os.path.abspath(basedir) - except IndexError: - basedir = os.getcwd() + try: + basedir = sys.argv[1] + if not os.path.exists(basedir): + print(f"Path '{basedir}' does not exist!") + sys.exit(1) + basedir = os.path.abspath(basedir) + except IndexError: + basedir = os.getcwd() - config = config_gen(basedir, "appservice.json") + config = config_gen(basedir, "appservice.json") - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(name)s:%(levelname)s:%(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - handlers=[ - logging.FileHandler(f"{basedir}/appservice.log"), - ], - ) + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(name)s:%(levelname)s:%(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + handlers=[ + logging.FileHandler(f"{basedir}/appservice.log"), + ], + ) - sys.excepthook = excepthook + sys.excepthook = excepthook - app = MatrixClient(config, urllib3.PoolManager(maxsize=10)) + app = MatrixClient(config, urllib3.PoolManager(maxsize=10)) - # Start the bottle app in a separate thread. - app_thread = threading.Thread( - target=app.run, kwargs={"port": int(config["port"])}, daemon=True - ) - app_thread.start() + # Start the bottle app in a separate thread. + app_thread = threading.Thread( + target=app.run, kwargs={"port": int(config["port"])}, daemon=True + ) + app_thread.start() - try: - asyncio.run(app.discord.run()) - except KeyboardInterrupt: - sys.exit() + try: + asyncio.run(app.discord.run()) + except KeyboardInterrupt: + sys.exit() if __name__ == "__main__": - main() + main() diff --git a/src/aioappsrv/matrix.py b/src/aioappsrv/matrix.py index ab10124..3f3f66f 100644 --- a/src/aioappsrv/matrix.py +++ b/src/aioappsrv/matrix.py @@ -3,26 +3,26 @@ from dataclasses import dataclass @dataclass class User: - avatar_url: str = "" - display_name: str = "" + avatar_url: str = "" + display_name: str = "" class Event: - def __init__(self, event: dict): - content = event.get("content", {}) + def __init__(self, event: dict): + content = event.get("content", {}) - self.attachment = content.get("url") - self.body = content.get("body", "").strip() - self.formatted_body = content.get("formatted_body", "") - self.id = event["event_id"] - self.is_direct = content.get("is_direct", False) - self.redacts = event.get("redacts", "") - self.room_id = event["room_id"] - self.sender = event["sender"] - self.state_key = event.get("state_key", "") + self.attachment = content.get("url") + self.body = content.get("body", "").strip() + self.formatted_body = content.get("formatted_body", "") + self.id = event["event_id"] + self.is_direct = content.get("is_direct", False) + self.redacts = event.get("redacts", "") + self.room_id = event["room_id"] + self.sender = event["sender"] + self.state_key = event.get("state_key", "") - rel = content.get("m.relates_to", {}) + rel = content.get("m.relates_to", {}) - self.relates_to = rel.get("event_id") - self.reltype = rel.get("rel_type") - self.new_body = content.get("m.new_content", {}).get("body", "") + self.relates_to = rel.get("event_id") + self.reltype = rel.get("rel_type") + self.new_body = content.get("m.new_content", {}).get("body", "") diff --git a/src/aioappsrv/misc.py b/src/aioappsrv/misc.py index f2e500d..be77c9d 100644 --- a/src/aioappsrv/misc.py +++ b/src/aioappsrv/misc.py @@ -8,77 +8,77 @@ from errors import RequestError def dict_cls(d: dict, cls: Any) -> Any: - """ - Create a dataclass from a dictionary. - """ + """ + Create a dataclass from a dictionary. + """ - field_names = set(f.name for f in fields(cls)) - filtered_dict = {k: v for k, v in d.items() if k in field_names} + field_names = set(f.name for f in fields(cls)) + filtered_dict = {k: v for k, v in d.items() if k in field_names} - return cls(**filtered_dict) + return cls(**filtered_dict) def log_except(fn): - """ - Log unhandled exceptions to a logger instead of `stderr`. - """ + """ + Log unhandled exceptions to a logger instead of `stderr`. + """ - def wrapper(self, *args, **kwargs): - try: - return fn(self, *args, **kwargs) - except Exception: - self.logger.exception(f"Exception in '{fn.__name__}':") - raise + def wrapper(self, *args, **kwargs): + try: + return fn(self, *args, **kwargs) + except Exception: + self.logger.exception(f"Exception in '{fn.__name__}':") + raise - return wrapper + return wrapper def request(fn): - """ - Either return json data or raise a `RequestError` if the request was - unsuccessful. - """ + """ + Either return json data or raise a `RequestError` if the request was + unsuccessful. + """ - def wrapper(*args, **kwargs): - try: - resp = fn(*args, **kwargs) - except urllib3.exceptions.HTTPError as e: - raise RequestError(None, f"Failed to connect: {e}") from None + def wrapper(*args, **kwargs): + try: + resp = fn(*args, **kwargs) + except urllib3.exceptions.HTTPError as e: + raise RequestError(None, f"Failed to connect: {e}") from None - if resp.status < 200 or resp.status >= 300: - raise RequestError( - resp.status, - f"Failed to get response from '{resp.geturl()}':\n{resp.data}", - ) + if resp.status < 200 or resp.status >= 300: + raise RequestError( + resp.status, + f"Failed to get response from '{resp.geturl()}':\n{resp.data}", + ) - return {} if resp.status == 204 else json.loads(resp.data) + return {} if resp.status == 204 else json.loads(resp.data) - return wrapper + return wrapper def except_deleted(fn): - """ - Ignore the `RequestError` on 404s, the content might have been removed. - """ + """ + Ignore the `RequestError` on 404s, the content might have been removed. + """ - def wrapper(*args, **kwargs): - try: - return fn(*args, **kwargs) - except RequestError as e: - if e.status != 404: - raise + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except RequestError as e: + if e.status != 404: + raise - return wrapper + return wrapper def hash_str(string: str) -> int: - """ - Create the hash for a string - """ + """ + Create the hash for a string + """ - hash = 5381 + hash = 5381 - for ch in string: - hash = ((hash << 5) + hash) + ord(ch) + for ch in string: + hash = ((hash << 5) + hash) + ord(ch) - return hash & 0xFFFFFFFF + return hash & 0xFFFFFFFF