Make Discord id length variable and fix webhook hash (#16)
* prevent duplicate user for interactions * fix for older discord accounts * check channel mentions against full channel list * Fix compatibility for Python 3.7 Replace dict with Dict object from typing module * remove scary hashing * always expect guild_id * change hash to djb2 * Revert "always expect guild_id" This reverts commit dbcb3d1b9c97f6ceda0cf982b4bd7228926112c3. * guild_id warning, don't group bot created webhooks * fmt Co-authored-by: Friskygote <7283122+Friskygote@users.noreply.github.com> Co-authored-by: Wolf Gupta <e817509a-8ee9-4332-b0ad-3a6bdf9ab63f@aleeas.com>
This commit is contained in:
parent
bae2716aef
commit
865c1eb9f4
4 changed files with 73 additions and 37 deletions
|
@ -3,7 +3,6 @@ from dataclasses import dataclass
|
||||||
from misc import dict_cls
|
from misc import dict_cls
|
||||||
|
|
||||||
CDN_URL = "https://cdn.discordapp.com"
|
CDN_URL = "https://cdn.discordapp.com"
|
||||||
ID_LEN = 18
|
|
||||||
MESSAGE_LIMIT = 2000
|
MESSAGE_LIMIT = 2000
|
||||||
|
|
||||||
|
|
||||||
|
@ -107,6 +106,9 @@ class Message:
|
||||||
self.channel_id = message["channel_id"]
|
self.channel_id = message["channel_id"]
|
||||||
self.content = message.get("content", "")
|
self.content = message.get("content", "")
|
||||||
self.id = message["id"]
|
self.id = message["id"]
|
||||||
|
self.guild_id = message.get(
|
||||||
|
"guild_id", ""
|
||||||
|
) # Responses for sending webhook messages don't have guild_id
|
||||||
self.webhook_id = message.get("webhook_id", "")
|
self.webhook_id = message.get("webhook_id", "")
|
||||||
self.application_id = message.get("application_id", "")
|
self.application_id = message.get("application_id", "")
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import List
|
from typing import Dict, List
|
||||||
|
|
||||||
import urllib3
|
import urllib3
|
||||||
import websockets
|
import websockets
|
||||||
|
@ -148,6 +148,18 @@ class Gateway:
|
||||||
|
|
||||||
return dict_cls(resp, discord.Channel)
|
return dict_cls(resp, discord.Channel)
|
||||||
|
|
||||||
|
def get_channels(self, guild_id: str) -> Dict[str, discord.Channel]:
|
||||||
|
"""
|
||||||
|
Get all channels for a given guild ID.
|
||||||
|
"""
|
||||||
|
|
||||||
|
resp = self.send("GET", f"/guilds/{guild_id}/channels")
|
||||||
|
|
||||||
|
return {
|
||||||
|
channel["id"]: dict_cls(channel, discord.Channel)
|
||||||
|
for channel in resp
|
||||||
|
}
|
||||||
|
|
||||||
def get_emotes(self, guild_id: str) -> List[discord.Emote]:
|
def get_emotes(self, guild_id: str) -> List[discord.Emote]:
|
||||||
"""
|
"""
|
||||||
Get all the emotes for a given guild.
|
Get all the emotes for a given guild.
|
||||||
|
|
|
@ -5,11 +5,11 @@ import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
import urllib.parse
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import markdown
|
import markdown
|
||||||
import urllib3
|
import urllib3
|
||||||
import urllib.parse
|
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import matrix
|
import matrix
|
||||||
|
@ -28,7 +28,7 @@ class MatrixClient(AppService):
|
||||||
self.db = DataBase(config["database"])
|
self.db = DataBase(config["database"])
|
||||||
self.discord = DiscordClient(self, config, http)
|
self.discord = DiscordClient(self, config, http)
|
||||||
self.format = "_discord_" # "{@,#}_discord_1234:localhost"
|
self.format = "_discord_" # "{@,#}_discord_1234:localhost"
|
||||||
self.id_regex = f"[0-9]{{{discord.ID_LEN}}}"
|
self.id_regex = "[0-9]+" # Snowflakes may have variable length
|
||||||
|
|
||||||
# TODO Find a cleaner way to use these keys.
|
# TODO Find a cleaner way to use these keys.
|
||||||
for k in ("m_emotes", "m_members", "m_messages"):
|
for k in ("m_emotes", "m_members", "m_messages"):
|
||||||
|
@ -329,28 +329,40 @@ height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def mention_regex(self, encode: bool) -> str:
|
def mention_regex(self, encode: bool, id_as_group: bool) -> str:
|
||||||
mention = "@"
|
mention = "@"
|
||||||
colon = ":"
|
colon = ":"
|
||||||
|
snowflake = self.id_regex
|
||||||
|
|
||||||
if encode:
|
if encode:
|
||||||
mention = urllib.parse.quote(mention)
|
mention = urllib.parse.quote(mention)
|
||||||
colon = urllib.parse.quote(colon)
|
colon = urllib.parse.quote(colon)
|
||||||
|
|
||||||
return f"{mention}{self.format}{self.id_regex}{colon}{re.escape(self.server_name)}"
|
if id_as_group:
|
||||||
|
snowflake = f"({snowflake})"
|
||||||
|
|
||||||
|
hashed = f"(?:-{snowflake})?"
|
||||||
|
|
||||||
|
return f"{mention}{self.format}{snowflake}{hashed}{colon}{re.escape(self.server_name)}"
|
||||||
|
|
||||||
def process_message(self, event: matrix.Event) -> str:
|
def process_message(self, event: matrix.Event) -> str:
|
||||||
message = event.new_body if event.new_body else event.body
|
message = event.new_body if event.new_body else event.body
|
||||||
|
|
||||||
emotes = re.findall(r":(\w*):", message)
|
emotes = re.findall(r":(\w*):", message)
|
||||||
|
|
||||||
mentions = re.findall(
|
mentions = list(
|
||||||
self.mention_regex(encode=False), event.formatted_body
|
re.finditer(
|
||||||
|
self.mention_regex(encode=False, id_as_group=True),
|
||||||
|
event.formatted_body,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# For clients that properly encode mentions.
|
# For clients that properly encode mentions.
|
||||||
# 'https://matrix.to/#/%40_discord_...%3Adomain.tld'
|
# 'https://matrix.to/#/%40_discord_...%3Adomain.tld'
|
||||||
mentions.extend(
|
mentions.extend(
|
||||||
re.findall(self.mention_regex(encode=True), event.formatted_body)
|
re.finditer(
|
||||||
|
self.mention_regex(encode=True, id_as_group=True),
|
||||||
|
event.formatted_body,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
with Cache.lock:
|
with Cache.lock:
|
||||||
|
@ -361,19 +373,20 @@ height=\"32\" src=\"{emote_}\" data-mx-emoticon />""",
|
||||||
|
|
||||||
for mention in set(mentions):
|
for mention in set(mentions):
|
||||||
# Unquote just in-case we matched an encoded username.
|
# Unquote just in-case we matched an encoded username.
|
||||||
username = self.db.fetch_user(urllib.parse.unquote(mention)).get(
|
username = self.db.fetch_user(
|
||||||
"username"
|
urllib.parse.unquote(mention.group(0))
|
||||||
)
|
).get("username")
|
||||||
if username:
|
if username:
|
||||||
match = re.search(self.id_regex, mention)
|
if mention.group(2):
|
||||||
|
# Replace mention with plain text for hashed users (webhooks)
|
||||||
if match:
|
message = message.replace(mention.group(0), f"@{username}")
|
||||||
|
else:
|
||||||
# Replace the 'mention' so that the user is tagged
|
# Replace the 'mention' so that the user is tagged
|
||||||
# in the case of replies aswell.
|
# in the case of replies aswell.
|
||||||
# '> <@_discord_1234:localhost> Message'
|
# '> <@_discord_1234:localhost> Message'
|
||||||
for replace in (mention, username):
|
for replace in (mention.group(0), username):
|
||||||
message = message.replace(
|
message = message.replace(
|
||||||
replace, f"<@{match.group()}>"
|
replace, f"<@{mention.group(1)}>"
|
||||||
)
|
)
|
||||||
|
|
||||||
# We trim the message later as emotes take up extra characters too.
|
# We trim the message later as emotes take up extra characters too.
|
||||||
|
@ -456,9 +469,10 @@ class DiscordClient(Gateway):
|
||||||
or message.webhook_id in hook_ids
|
or message.webhook_id in hook_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
def matrixify(self, id: str, user: bool = False) -> str:
|
def matrixify(self, id: str, user: bool = False, hashed: str = "") -> str:
|
||||||
return (
|
return (
|
||||||
f"{'@' if user else '#'}{self.app.format}{id}:"
|
f"{'@' if user else '#'}{self.app.format}"
|
||||||
|
f"{id}{'-' + hashed if hashed else ''}:"
|
||||||
f"{self.app.server_name}"
|
f"{self.app.server_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -490,11 +504,11 @@ class DiscordClient(Gateway):
|
||||||
Discord user.
|
Discord user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if message.webhook_id and not message.application_id:
|
hashed = ""
|
||||||
hashed = hash_str(message.author.username)
|
if message.webhook_id and message.webhook_id != message.application_id:
|
||||||
message.author.id = str(int(message.author.id) + hashed)
|
hashed = str(hash_str(message.author.username))
|
||||||
|
|
||||||
mxid = self.matrixify(message.author.id, user=True)
|
mxid = self.matrixify(message.author.id, user=True, hashed=hashed)
|
||||||
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
|
room_id = self.app.get_room_id(self.matrixify(message.channel_id))
|
||||||
|
|
||||||
if not self.app.db.fetch_user(mxid):
|
if not self.app.db.fetch_user(mxid):
|
||||||
|
@ -666,12 +680,21 @@ class DiscordClient(Gateway):
|
||||||
f"<@{char}{member.id}>", f"@{member.username}"
|
f"<@{char}{member.id}>", f"@{member.username}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# `except_deleted` for invalid channels.
|
# Replace channel IDs with names.
|
||||||
# TODO can this block for too long ?
|
channels = re.findall("<#([0-9]+)>", content)
|
||||||
for channel in re.findall(f"<#([0-9]{{{discord.ID_LEN}}})>", content):
|
if channels:
|
||||||
discord_channel = except_deleted(self.get_channel)(channel)
|
if not message.guild_id:
|
||||||
|
self.logger.warning(
|
||||||
|
f"Message '{message.id}' in channel '{message.channel_id}' does not have a guild_id!"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
discord_channels = self.get_channels(message.guild_id)
|
||||||
|
for channel in channels:
|
||||||
|
discord_channel = discord_channels.get(channel)
|
||||||
name = (
|
name = (
|
||||||
discord_channel.name if discord_channel else "deleted-channel"
|
discord_channel.name
|
||||||
|
if discord_channel
|
||||||
|
else "deleted-channel"
|
||||||
)
|
)
|
||||||
content = content.replace(f"<#{channel}>", f"#{name}")
|
content = content.replace(f"<#{channel}>", f"#{name}")
|
||||||
|
|
||||||
|
|
|
@ -73,13 +73,12 @@ def except_deleted(fn):
|
||||||
|
|
||||||
def hash_str(string: str) -> int:
|
def hash_str(string: str) -> int:
|
||||||
"""
|
"""
|
||||||
Create the hash for a string (poorly).
|
Create the hash for a string
|
||||||
"""
|
"""
|
||||||
|
|
||||||
hashed = 0
|
hash = 5381
|
||||||
results = map(ord, string)
|
|
||||||
|
|
||||||
for result in results:
|
for ch in string:
|
||||||
hashed += result
|
hash = ((hash << 5) + hash) + ord(ch)
|
||||||
|
|
||||||
return hashed
|
return hash & 0xFFFFFFFF
|
||||||
|
|
Loading…
Reference in a new issue