diff --git a/main.py b/main.py index 024a35f..8414a41 100644 --- a/main.py +++ b/main.py @@ -29,25 +29,24 @@ def config_gen(config_file): config = config_gen("config.json") -logging.basicConfig(level=logging.INFO) -matrix_logger = logging.getLogger("matrix_logger") - -message_store = {} -channel_store = {} +message_store, channel_store = {}, {} class MatrixClient(nio.AsyncClient): async def create(self, discord_client): + self.logger = logging.getLogger("matrix_logger") + password = config["password"] timeout = 30000 - matrix_logger.info(await self.login(password)) + self.logger.info(await self.login(password)) - matrix_logger.info("Doing initial sync.") + self.logger.info("Doing initial sync.") await self.sync(timeout) # Set up event callbacks - callbacks = Callbacks(self, self.process_message) + callbacks = Callbacks(self) + self.add_event_callback( callbacks.message_callback, (nio.RoomMessageText, nio.RoomMessageMedia, @@ -64,7 +63,7 @@ class MatrixClient(nio.AsyncClient): await discord_client.wait_until_ready() - matrix_logger.info("Syncing forever.") + self.logger.info("Syncing forever.") await self.sync_forever(timeout=timeout) await self.close() @@ -141,29 +140,7 @@ class MatrixClient(nio.AsyncClient): message_store[event_id] = hook message_store[hook.id] = event_id except discord.errors.HTTPException as e: - matrix_logger.warning(f"Failed to send message {event_id}: {e}") - - async def process_message(self, message, channel_id): - mentions = re.findall(r"(^|\s)(@(\w*))", message) - emotes = re.findall(r":(\w*):", message) - - guild = channel_store[channel_id].guild - - added_emotes = [] - for emote in emotes: - if emote not in added_emotes: - added_emotes.append(emote) - emote_ = discord.utils.get(guild.emojis, name=emote) - if emote_: - message = message.replace(f":{emote}:", str(emote_)) - - for mention in mentions: - if mention[2] != "": - member = await guild.query_members(query=mention[2]) - if member: - message = message.replace(mention[1], member[0].mention) - - return message + self.logger.warning(f"Failed to send message {event_id}: {e}") class DiscordClient(discord.Client): @@ -171,7 +148,8 @@ class DiscordClient(discord.Client): super().__init__(*args, **kwargs) self.matrix_client = MatrixClient( - config["homeserver"], config["username"]) + config["homeserver"], config["username"] + ) self.bg_task = self.loop.create_task(self.matrix_client.create(self)) @@ -248,9 +226,8 @@ class DiscordClient(discord.Client): class Callbacks(object): - def __init__(self, client, process_message): - self.client = client - self.process_message = process_message + def __init__(self, matrix_client): + self.matrix_client = matrix_client def get_channel(self, room): channel_id = next( @@ -263,7 +240,7 @@ class Callbacks(object): async def message_callback(self, room, event): # Ignore messages from ourselves or other rooms if room.room_id not in config["bridge"].values() or \ - event.sender == self.client.user: + event.sender == self.matrix_client.user: return message = event.body @@ -292,7 +269,7 @@ class Callbacks(object): try: await webhook_message.edit(content=edited_content) except discord.errors.NotFound as e: - matrix_logger.warning( + self.matrix_client.logger.warning( f"Failed to edit message {edited_event}: {e}" ) @@ -331,14 +308,14 @@ class Callbacks(object): avatar = f"{url}/{homeserver}/{avatar}" break - await self.client.webhook_send( + await self.matrix_client.webhook_send( author, avatar, message, event.event_id, channel_id ) async def redaction_callback(self, room, event): # Ignore messages from ourselves or other rooms if room.room_id not in config["bridge"].values() or \ - event.sender == self.client.user: + event.sender == self.matrix_client.user: return # Redact webhook message @@ -346,7 +323,7 @@ class Callbacks(object): message = message_store[event.redacts] await message.delete() except discord.errors.NotFound as e: - matrix_logger.warning( + self.matrix_client.logger.warning( f"Failed to delete message {event.event_id}: {e}" ) except KeyError: @@ -360,7 +337,7 @@ class Callbacks(object): if room.typing_users: # Ignore events from ourselves if len(room.typing_users) == 1 \ - and room.typing_users[0] == self.client.user: + and room.typing_users[0] == self.matrix_client.user: return channel_id = self.get_channel(room) @@ -369,16 +346,40 @@ class Callbacks(object): async with channel_store[channel_id].typing(): pass + async def process_message(self, message, channel_id): + mentions = re.findall(r"(^|\s)(@(\w*))", message) + emotes = re.findall(r":(\w*):", message) + + guild = channel_store[channel_id].guild + + added_emotes = [] + for emote in emotes: + if emote not in added_emotes: + added_emotes.append(emote) + emote_ = discord.utils.get(guild.emojis, name=emote) + if emote_: + message = message.replace(f":{emote}:", str(emote_)) + + for mention in mentions: + if mention[2] != "": + member = await guild.query_members(query=mention[2]) + if member: + message = message.replace(mention[1], member[0].mention) + + return message + def main(): + logging.basicConfig(level=logging.INFO) + intents = discord.Intents.default() intents.members = True allowed_mentions = discord.AllowedMentions(everyone=False, roles=False) - DiscordClient(intents=intents, allowed_mentions=allowed_mentions).run( - config["token"] - ) + DiscordClient( + intents=intents, allowed_mentions=allowed_mentions + ).run(config["token"]) if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index 1adeabd..0b402a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -git+git://github.com/rapptz/discord.py@44dc7a8 +git+git://github.com/rapptz/discord.py@22cb4ef matrix-nio==0.15.2