From 78d85a237410578066dae78b729937a53a726682 Mon Sep 17 00:00:00 2001 From: git-bruh Date: Sun, 18 Apr 2021 18:14:18 +0530 Subject: [PATCH] appservice: support messages from other webhooks --- appservice/db.py | 19 +++-------- appservice/main.py | 85 +++++++++++++++++++++++----------------------- appservice/misc.py | 14 ++++++++ 3 files changed, 60 insertions(+), 58 deletions(-) diff --git a/appservice/db.py b/appservice/db.py index 1598260..fcb5033 100644 --- a/appservice/db.py +++ b/appservice/db.py @@ -107,26 +107,15 @@ class DataBase(object): return [channel["channel_id"] for channel in channels] - def list_users(self) -> List[dict]: + def fetch_user(self, mxid: str) -> dict: """ - Get a dictionary of all the puppeted users. + Fetch the profile for a bridged user. """ with self.lock: self.cur.execute("SELECT * FROM users") - users = self.cur.fetchall() - return users + user = [user for user in users if user["mxid"] == mxid] - def query_user(self, mxid: str) -> bool: - """ - Check whether a puppet user has already been created for a given mxid. - """ - - with self.lock: - self.cur.execute("SELECT mxid FROM users") - - users = self.cur.fetchall() - - return next((True for user in users if user["mxid"] == mxid), False) + return user[0] if user else {} diff --git a/appservice/main.py b/appservice/main.py index 8818c45..574287e 100644 --- a/appservice/main.py +++ b/appservice/main.py @@ -15,7 +15,7 @@ from appservice import AppService from db import DataBase from errors import RequestError from gateway import Gateway -from misc import dict_cls, except_deleted +from misc import dict_cls, except_deleted, hash_str # TODO should this be cleared periodically ? message_cache: Dict[str, Union[discord.Webhook, str]] = {} @@ -261,6 +261,7 @@ height=\"32\" src=\"{emote_}\" data-mx-emoticon />""", if mentions: guild_id = self.discord.get_channel(channel_id).guild_id + # TODO this can block for too long if a long list is to be fetched. for mention in mentions: if not mention[1]: continue @@ -344,15 +345,15 @@ class DiscordClient(Gateway): and update if they differ. Also synchronise emotes. """ + # TODO use websocket events and requests. + def sync_emotes(guilds: set): - # We could store the emotes once and update according - # to gateway events but we're too lazy for that. emotes = [] for guild in guilds: [emotes.append(emote) for emote in (self.get_emotes(guild))] - self.emote_cache.clear() # Clears deleted/renamed emotes. + self.emote_cache.clear() # Clear deleted/renamed emotes. for emote in emotes: self.emote_cache[f"{emote.name}"] = ( @@ -361,44 +362,13 @@ class DiscordClient(Gateway): ) def sync_users(guilds: set): - # TODO use websockets for this, using IDs from database. - users = [] for guild in guilds: - [users.append(member) for member in self.get_members(guild)] - - db_users = self.app.db.list_users() - - # Convert a list of dicts: - # [ { "avatar_url": ... } ] - # to a dict that is indexable by Discord IDs: - # { "discord_id": { "avatar_url": ... } } - users_ = {} - - for user in db_users: - users_[user["mxid"].split("_")[-1].split(":")[0]] = {**user} - - for user in users: - user_ = users_.get(user.id) - - if not user_: - continue - - mxid = user_["mxid"] - username = f"{user.username}#{user.discriminator}" - - if user.avatar_url != user_["avatar_url"]: - self.logger.info( - f"Updating avatar for Discord user {user.id}." - ) - self.app.set_avatar(user.avatar_url, mxid) - - if username != user_["username"]: - self.logger.info( - f"Updating username for Discord user {user.id}." - ) - self.app.set_nick(username, mxid) + [ + self.sync_profile(user, self.matrixify(user.id, user=True)) + for user in self.get_members(guild) + ] while True: guilds = set() # Avoid duplicates. @@ -426,8 +396,9 @@ class DiscordClient(Gateway): return ( message.channel_id not in self.app.db.list_channels() or not message.content - or not message.author - or message.author.discriminator == "0000" + or not message.author # Embeds can be weird sometimes. + or message.webhook_id + in [hook.id for hook in self.webhook_cache.values()] ) def matrixify(self, id: str, user: bool = False) -> str: @@ -436,16 +407,40 @@ class DiscordClient(Gateway): f"{self.app.server_name}" ) + def sync_profile(self, user: discord.User, mxid: str) -> None: + """ + Sync the avatar and username for a puppeted user. + """ + + profile = self.app.db.fetch_user(mxid) + + # User doesn't exist. + if not profile: + return + + username = f"{user.username}#{user.discriminator}" + + if user.avatar_url != profile["avatar_url"]: + self.logger.info(f"Updating avatar for Discord user {user.id}") + self.app.set_avatar(user.avatar_url, mxid) + if username != profile["username"]: + self.logger.info(f"Updating username for Discord user {user.id}") + self.app.set_nick(username, mxid) + def wrap(self, message: discord.Message) -> Tuple[str, str]: """ Get the room ID and the puppet's mxid for a given channel ID and a Discord user. """ + if message.webhook_id: + hashed = hash_str(message.author.username) + message.author.id = str(int(message.author.id) + hashed) + mxid = self.matrixify(message.author.id, user=True) room_id = self.app.get_room_id(self.matrixify(message.channel_id)) - if not self.app.db.query_user(mxid): + if not self.app.db.fetch_user(mxid): self.logger.info( f"Creating dummy user for Discord user {message.author.id}." ) @@ -466,6 +461,10 @@ class DiscordClient(Gateway): self.app.send_invite(room_id, mxid) self.app.join_room(room_id, mxid) + if message.webhook_id: + # Sync webhooks here as they can't be accessed like guild members. + self.sync_profile(message.author, mxid) + return mxid, room_id def on_message_create(self, message: discord.Message) -> None: @@ -497,7 +496,7 @@ class DiscordClient(Gateway): message_cache.pop(message.id) - def on_message_update(self, message: dict) -> None: + def on_message_update(self, message: discord.Message) -> None: if self.to_return(message): return diff --git a/appservice/misc.py b/appservice/misc.py index 45f7c67..3befa4a 100644 --- a/appservice/misc.py +++ b/appservice/misc.py @@ -87,3 +87,17 @@ def except_deleted(fn): raise return wrapper + + +def hash_str(string: str) -> int: + """ + Create the hash for a string (poorly). + """ + + hashed = 0 + results = map(ord, string) + + for result in results: + hashed += result + + return hashed