diff --git a/appservice/discord.py b/appservice/discord.py index 2404d27..441559a 100644 --- a/appservice/discord.py +++ b/appservice/discord.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from misc import dict_cls CDN_URL = "https://cdn.discordapp.com" -ID_LEN = 18 MESSAGE_LIMIT = 2000 @@ -107,6 +106,9 @@ class Message: self.channel_id = message["channel_id"] self.content = message.get("content", "") self.id = message["id"] + self.guild_id = message.get( + "guild_id", "" + ) # Responses for sending webhook messages don't have guild_id self.webhook_id = message.get("webhook_id", "") self.application_id = message.get("application_id", "") diff --git a/appservice/gateway.py b/appservice/gateway.py index 5610d48..8d53e36 100644 --- a/appservice/gateway.py +++ b/appservice/gateway.py @@ -2,7 +2,7 @@ import asyncio import json import logging import urllib.parse -from typing import List +from typing import Dict, List import urllib3 import websockets @@ -148,6 +148,18 @@ class Gateway: return dict_cls(resp, discord.Channel) + def get_channels(self, guild_id: str) -> Dict[str, discord.Channel]: + """ + Get all channels for a given guild ID. + """ + + resp = self.send("GET", f"/guilds/{guild_id}/channels") + + return { + channel["id"]: dict_cls(channel, discord.Channel) + for channel in resp + } + def get_emotes(self, guild_id: str) -> List[discord.Emote]: """ Get all the emotes for a given guild. diff --git a/appservice/main.py b/appservice/main.py index b73a233..daabdfa 100644 --- a/appservice/main.py +++ b/appservice/main.py @@ -5,11 +5,11 @@ import os import re import sys import threading +import urllib.parse from typing import Dict, List, Tuple import markdown import urllib3 -import urllib.parse import discord import matrix @@ -28,7 +28,7 @@ class MatrixClient(AppService): self.db = DataBase(config["database"]) self.discord = DiscordClient(self, config, http) self.format = "_discord_" # "{@,#}_discord_1234:localhost" - self.id_regex = f"[0-9]{{{discord.ID_LEN}}}" + self.id_regex = "[0-9]+" # Snowflakes may have variable length # TODO Find a cleaner way to use these keys. for k in ("m_emotes", "m_members", "m_messages"): @@ -329,28 +329,40 @@ height=\"32\" src=\"{emote_}\" data-mx-emoticon />""", return message - def mention_regex(self, encode: bool) -> str: + def mention_regex(self, encode: bool, id_as_group: bool) -> str: mention = "@" colon = ":" + snowflake = self.id_regex if encode: mention = urllib.parse.quote(mention) colon = urllib.parse.quote(colon) - return f"{mention}{self.format}{self.id_regex}{colon}{re.escape(self.server_name)}" + if id_as_group: + snowflake = f"({snowflake})" + + hashed = f"(?:-{snowflake})?" + + return f"{mention}{self.format}{snowflake}{hashed}{colon}{re.escape(self.server_name)}" def process_message(self, event: matrix.Event) -> str: message = event.new_body if event.new_body else event.body emotes = re.findall(r":(\w*):", message) - mentions = re.findall( - self.mention_regex(encode=False), event.formatted_body + mentions = list( + re.finditer( + self.mention_regex(encode=False, id_as_group=True), + event.formatted_body, + ) ) # For clients that properly encode mentions. # 'https://matrix.to/#/%40_discord_...%3Adomain.tld' mentions.extend( - re.findall(self.mention_regex(encode=True), event.formatted_body) + re.finditer( + self.mention_regex(encode=True, id_as_group=True), + event.formatted_body, + ) ) with Cache.lock: @@ -361,19 +373,20 @@ height=\"32\" src=\"{emote_}\" data-mx-emoticon />""", for mention in set(mentions): # Unquote just in-case we matched an encoded username. - username = self.db.fetch_user(urllib.parse.unquote(mention)).get( - "username" - ) + username = self.db.fetch_user( + urllib.parse.unquote(mention.group(0)) + ).get("username") if username: - match = re.search(self.id_regex, mention) - - if match: + if mention.group(2): + # Replace mention with plain text for hashed users (webhooks) + message = message.replace(mention.group(0), f"@{username}") + else: # Replace the 'mention' so that the user is tagged # in the case of replies aswell. # '> <@_discord_1234:localhost> Message' - for replace in (mention, username): + for replace in (mention.group(0), username): message = message.replace( - replace, f"<@{match.group()}>" + replace, f"<@{mention.group(1)}>" ) # We trim the message later as emotes take up extra characters too. @@ -456,9 +469,10 @@ class DiscordClient(Gateway): or message.webhook_id in hook_ids ) - def matrixify(self, id: str, user: bool = False) -> str: + def matrixify(self, id: str, user: bool = False, hashed: str = "") -> str: return ( - f"{'@' if user else '#'}{self.app.format}{id}:" + f"{'@' if user else '#'}{self.app.format}" + f"{id}{'-' + hashed if hashed else ''}:" f"{self.app.server_name}" ) @@ -490,11 +504,11 @@ class DiscordClient(Gateway): Discord user. """ - if message.webhook_id and not message.application_id: - hashed = hash_str(message.author.username) - message.author.id = str(int(message.author.id) + hashed) + hashed = "" + if message.webhook_id and message.webhook_id != message.application_id: + hashed = str(hash_str(message.author.username)) - mxid = self.matrixify(message.author.id, user=True) + mxid = self.matrixify(message.author.id, user=True, hashed=hashed) room_id = self.app.get_room_id(self.matrixify(message.channel_id)) if not self.app.db.fetch_user(mxid): @@ -666,14 +680,23 @@ class DiscordClient(Gateway): f"<@{char}{member.id}>", f"@{member.username}" ) - # `except_deleted` for invalid channels. - # TODO can this block for too long ? - for channel in re.findall(f"<#([0-9]{{{discord.ID_LEN}}})>", content): - discord_channel = except_deleted(self.get_channel)(channel) - name = ( - discord_channel.name if discord_channel else "deleted-channel" - ) - content = content.replace(f"<#{channel}>", f"#{name}") + # Replace channel IDs with names. + channels = re.findall("<#([0-9]+)>", content) + if channels: + if not message.guild_id: + self.logger.warning( + f"Message '{message.id}' in channel '{message.channel_id}' does not have a guild_id!" + ) + else: + discord_channels = self.get_channels(message.guild_id) + for channel in channels: + discord_channel = discord_channels.get(channel) + name = ( + discord_channel.name + if discord_channel + else "deleted-channel" + ) + content = content.replace(f"<#{channel}>", f"#{name}") # { "emote_name": "emote_id" } for emote in re.findall(regex, content): diff --git a/appservice/misc.py b/appservice/misc.py index 7c69459..f2e500d 100644 --- a/appservice/misc.py +++ b/appservice/misc.py @@ -73,13 +73,12 @@ def except_deleted(fn): def hash_str(string: str) -> int: """ - Create the hash for a string (poorly). + Create the hash for a string """ - hashed = 0 - results = map(ord, string) + hash = 5381 - for result in results: - hashed += result + for ch in string: + hash = ((hash << 5) + hash) + ord(ch) - return hashed + return hash & 0xFFFFFFFF