appservice: support messages from other webhooks

This commit is contained in:
git-bruh 2021-04-18 18:14:18 +05:30
parent 7e454a14a5
commit 78d85a2374
No known key found for this signature in database
GPG key ID: E1475C50075ADCE6
3 changed files with 60 additions and 58 deletions

View file

@ -107,26 +107,15 @@ class DataBase(object):
return [channel["channel_id"] for channel in channels] return [channel["channel_id"] for channel in channels]
def list_users(self) -> List[dict]: def fetch_user(self, mxid: str) -> dict:
""" """
Get a dictionary of all the puppeted users. Fetch the profile for a bridged user.
""" """
with self.lock: with self.lock:
self.cur.execute("SELECT * FROM users") self.cur.execute("SELECT * FROM users")
users = self.cur.fetchall() users = self.cur.fetchall()
return users user = [user for user in users if user["mxid"] == mxid]
def query_user(self, mxid: str) -> bool: return user[0] if user else {}
"""
Check whether a puppet user has already been created for a given mxid.
"""
with self.lock:
self.cur.execute("SELECT mxid FROM users")
users = self.cur.fetchall()
return next((True for user in users if user["mxid"] == mxid), False)

View file

@ -15,7 +15,7 @@ from appservice import AppService
from db import DataBase from db import DataBase
from errors import RequestError from errors import RequestError
from gateway import Gateway from gateway import Gateway
from misc import dict_cls, except_deleted from misc import dict_cls, except_deleted, hash_str
# TODO should this be cleared periodically ? # TODO should this be cleared periodically ?
message_cache: Dict[str, Union[discord.Webhook, str]] = {} message_cache: Dict[str, Union[discord.Webhook, str]] = {}
@ -261,6 +261,7 @@ height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
if mentions: if mentions:
guild_id = self.discord.get_channel(channel_id).guild_id guild_id = self.discord.get_channel(channel_id).guild_id
# TODO this can block for too long if a long list is to be fetched.
for mention in mentions: for mention in mentions:
if not mention[1]: if not mention[1]:
continue continue
@ -344,15 +345,15 @@ class DiscordClient(Gateway):
and update if they differ. Also synchronise emotes. and update if they differ. Also synchronise emotes.
""" """
# TODO use websocket events and requests.
def sync_emotes(guilds: set): def sync_emotes(guilds: set):
# We could store the emotes once and update according
# to gateway events but we're too lazy for that.
emotes = [] emotes = []
for guild in guilds: for guild in guilds:
[emotes.append(emote) for emote in (self.get_emotes(guild))] [emotes.append(emote) for emote in (self.get_emotes(guild))]
self.emote_cache.clear() # Clears deleted/renamed emotes. self.emote_cache.clear() # Clear deleted/renamed emotes.
for emote in emotes: for emote in emotes:
self.emote_cache[f"{emote.name}"] = ( self.emote_cache[f"{emote.name}"] = (
@ -361,44 +362,13 @@ class DiscordClient(Gateway):
) )
def sync_users(guilds: set): def sync_users(guilds: set):
# TODO use websockets for this, using IDs from database.
users = [] users = []
for guild in guilds: for guild in guilds:
[users.append(member) for member in self.get_members(guild)] [
self.sync_profile(user, self.matrixify(user.id, user=True))
db_users = self.app.db.list_users() for user in self.get_members(guild)
]
# Convert a list of dicts:
# [ { "avatar_url": ... } ]
# to a dict that is indexable by Discord IDs:
# { "discord_id": { "avatar_url": ... } }
users_ = {}
for user in db_users:
users_[user["mxid"].split("_")[-1].split(":")[0]] = {**user}
for user in users:
user_ = users_.get(user.id)
if not user_:
continue
mxid = user_["mxid"]
username = f"{user.username}#{user.discriminator}"
if user.avatar_url != user_["avatar_url"]:
self.logger.info(
f"Updating avatar for Discord user {user.id}."
)
self.app.set_avatar(user.avatar_url, mxid)
if username != user_["username"]:
self.logger.info(
f"Updating username for Discord user {user.id}."
)
self.app.set_nick(username, mxid)
while True: while True:
guilds = set() # Avoid duplicates. guilds = set() # Avoid duplicates.
@ -426,8 +396,9 @@ class DiscordClient(Gateway):
return ( return (
message.channel_id not in self.app.db.list_channels() message.channel_id not in self.app.db.list_channels()
or not message.content or not message.content
or not message.author or not message.author # Embeds can be weird sometimes.
or message.author.discriminator == "0000" or message.webhook_id
in [hook.id for hook in self.webhook_cache.values()]
) )
def matrixify(self, id: str, user: bool = False) -> str: def matrixify(self, id: str, user: bool = False) -> str:
@ -436,16 +407,40 @@ class DiscordClient(Gateway):
f"{self.app.server_name}" f"{self.app.server_name}"
) )
def sync_profile(self, user: discord.User, mxid: str) -> None:
"""
Sync the avatar and username for a puppeted user.
"""
profile = self.app.db.fetch_user(mxid)
# User doesn't exist.
if not profile:
return
username = f"{user.username}#{user.discriminator}"
if user.avatar_url != profile["avatar_url"]:
self.logger.info(f"Updating avatar for Discord user {user.id}")
self.app.set_avatar(user.avatar_url, mxid)
if username != profile["username"]:
self.logger.info(f"Updating username for Discord user {user.id}")
self.app.set_nick(username, mxid)
def wrap(self, message: discord.Message) -> Tuple[str, str]: def wrap(self, message: discord.Message) -> Tuple[str, str]:
""" """
Get the room ID and the puppet's mxid for a given channel ID and a Get the room ID and the puppet's mxid for a given channel ID and a
Discord user. Discord user.
""" """
if message.webhook_id:
hashed = hash_str(message.author.username)
message.author.id = str(int(message.author.id) + hashed)
mxid = self.matrixify(message.author.id, user=True) mxid = self.matrixify(message.author.id, user=True)
room_id = self.app.get_room_id(self.matrixify(message.channel_id)) room_id = self.app.get_room_id(self.matrixify(message.channel_id))
if not self.app.db.query_user(mxid): if not self.app.db.fetch_user(mxid):
self.logger.info( self.logger.info(
f"Creating dummy user for Discord user {message.author.id}." f"Creating dummy user for Discord user {message.author.id}."
) )
@ -466,6 +461,10 @@ class DiscordClient(Gateway):
self.app.send_invite(room_id, mxid) self.app.send_invite(room_id, mxid)
self.app.join_room(room_id, mxid) self.app.join_room(room_id, mxid)
if message.webhook_id:
# Sync webhooks here as they can't be accessed like guild members.
self.sync_profile(message.author, mxid)
return mxid, room_id return mxid, room_id
def on_message_create(self, message: discord.Message) -> None: def on_message_create(self, message: discord.Message) -> None:
@ -497,7 +496,7 @@ class DiscordClient(Gateway):
message_cache.pop(message.id) message_cache.pop(message.id)
def on_message_update(self, message: dict) -> None: def on_message_update(self, message: discord.Message) -> None:
if self.to_return(message): if self.to_return(message):
return return

View file

@ -87,3 +87,17 @@ def except_deleted(fn):
raise raise
return wrapper return wrapper
def hash_str(string: str) -> int:
"""
Create the hash for a string (poorly).
"""
hashed = 0
results = map(ord, string)
for result in results:
hashed += result
return hashed