chore: retab, delete discord file

This commit is contained in:
əlemi 2024-01-29 03:29:16 +01:00
parent 461127a4ac
commit 947217abe3
Signed by: alemi
GPG key ID: A4895B84D311642C
9 changed files with 1106 additions and 1324 deletions

View file

@ -13,191 +13,191 @@ from misc import log_except, request
class AppService(bottle.Bottle): class AppService(bottle.Bottle):
def __init__(self, config: dict, http: urllib3.PoolManager) -> None: def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
super(AppService, self).__init__() super(AppService, self).__init__()
self.as_token = config["as_token"] self.as_token = config["as_token"]
self.hs_token = config["hs_token"] self.hs_token = config["hs_token"]
self.base_url = config["homeserver"] self.base_url = config["homeserver"]
self.server_name = config["server_name"] self.server_name = config["server_name"]
self.user_id = f"@{config['user_id']}:{self.server_name}" self.user_id = f"@{config['user_id']}:{self.server_name}"
self.http = http self.http = http
self.logger = logging.getLogger("appservice") self.logger = logging.getLogger("appservice")
# Map events to functions. # Map events to functions.
self.mapping = { self.mapping = {
"m.room.member": "on_member", "m.room.member": "on_member",
"m.room.message": "on_message", "m.room.message": "on_message",
"m.room.redaction": "on_redaction", "m.room.redaction": "on_redaction",
} }
# Add route for bottle. # Add route for bottle.
self.route( self.route(
"/transactions/<transaction>", "/transactions/<transaction>",
callback=self.receive_event, callback=self.receive_event,
method="PUT", method="PUT",
) )
Cache.cache["m_rooms"] = {} Cache.cache["m_rooms"] = {}
def handle_event(self, event: dict) -> None: def handle_event(self, event: dict) -> None:
event_type = event.get("type") event_type = event.get("type")
if event_type in ( if event_type in (
"m.room.member", "m.room.member",
"m.room.message", "m.room.message",
"m.room.redaction", "m.room.redaction",
): ):
obj = matrix.Event(event) obj = matrix.Event(event)
else: else:
self.logger.info(f"Unknown event type: {event_type}") self.logger.info(f"Unknown event type: {event_type}")
return return
func = getattr(self, self.mapping[event_type], None) func = getattr(self, self.mapping[event_type], None)
if not func: if not func:
self.logger.warning( self.logger.warning(
f"Function '{func}' not defined, ignoring event." f"Function '{func}' not defined, ignoring event."
) )
return return
# We don't catch exceptions here as the homeserver will re-send us # We don't catch exceptions here as the homeserver will re-send us
# the event in case of a failure. # the event in case of a failure.
func(obj) func(obj)
@log_except @log_except
def receive_event(self, transaction: str) -> dict: def receive_event(self, transaction: str) -> dict:
""" """
Verify the homeserver's token and handle events. Verify the homeserver's token and handle events.
""" """
hs_token = bottle.request.query.getone("access_token") hs_token = bottle.request.query.getone("access_token")
if not hs_token: if not hs_token:
bottle.response.status = 401 bottle.response.status = 401
return {"errcode": "APPSERVICE_UNAUTHORIZED"} return {"errcode": "APPSERVICE_UNAUTHORIZED"}
if hs_token != self.hs_token: if hs_token != self.hs_token:
bottle.response.status = 403 bottle.response.status = 403
return {"errcode": "APPSERVICE_FORBIDDEN"} return {"errcode": "APPSERVICE_FORBIDDEN"}
events = bottle.request.json.get("events") events = bottle.request.json.get("events")
for event in events: for event in events:
self.handle_event(event) self.handle_event(event)
return {} return {}
def mxc_url(self, mxc: str) -> str: def mxc_url(self, mxc: str) -> str:
try: try:
homeserver, media_id = mxc.replace("mxc://", "").split("/") homeserver, media_id = mxc.replace("mxc://", "").split("/")
except ValueError: except ValueError:
return "" return ""
return ( return (
f"https://{self.server_name}/_matrix/media/r0/download/" f"https://{self.server_name}/_matrix/media/r0/download/"
f"{homeserver}/{media_id}" f"{homeserver}/{media_id}"
) )
def join_room(self, room_id: str, mxid: str = "") -> None: def join_room(self, room_id: str, mxid: str = "") -> None:
self.send( self.send(
"POST", "POST",
f"/join/{room_id}", f"/join/{room_id}",
params={"user_id": mxid} if mxid else {}, params={"user_id": mxid} if mxid else {},
) )
def redact(self, event_id: str, room_id: str, mxid: str = "") -> None: def redact(self, event_id: str, room_id: str, mxid: str = "") -> None:
self.send( self.send(
"PUT", "PUT",
f"/rooms/{room_id}/redact/{event_id}/{uuid.uuid4()}", f"/rooms/{room_id}/redact/{event_id}/{uuid.uuid4()}",
params={"user_id": mxid} if mxid else {}, params={"user_id": mxid} if mxid else {},
) )
def get_room_id(self, alias: str) -> str: def get_room_id(self, alias: str) -> str:
with Cache.lock: with Cache.lock:
room = Cache.cache["m_rooms"].get(alias) room = Cache.cache["m_rooms"].get(alias)
if room: if room:
return room return room
resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}") resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}")
room_id = resp["room_id"] room_id = resp["room_id"]
with Cache.lock: with Cache.lock:
Cache.cache["m_rooms"][alias] = room_id Cache.cache["m_rooms"][alias] = room_id
return room_id return room_id
def get_event(self, event_id: str, room_id: str) -> matrix.Event: def get_event(self, event_id: str, room_id: str) -> matrix.Event:
resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}") resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}")
return matrix.Event(resp) return matrix.Event(resp)
def upload(self, url: str) -> str: def upload(self, url: str) -> str:
""" """
Upload a file to the homeserver and get the MXC url. Upload a file to the homeserver and get the MXC url.
""" """
resp = self.http.request("GET", url) resp = self.http.request("GET", url)
resp = self.send( resp = self.send(
"POST", "POST",
content=resp.data, content=resp.data,
content_type=resp.headers.get("Content-Type"), content_type=resp.headers.get("Content-Type"),
params={"filename": f"{uuid.uuid4()}"}, params={"filename": f"{uuid.uuid4()}"},
endpoint="/_matrix/media/r0/upload", endpoint="/_matrix/media/r0/upload",
) )
return resp["content_uri"] return resp["content_uri"]
def send_message( def send_message(
self, self,
room_id: str, room_id: str,
content: dict, content: dict,
mxid: str = "", mxid: str = "",
) -> str: ) -> str:
resp = self.send( resp = self.send(
"PUT", "PUT",
f"/rooms/{room_id}/send/m.room.message/{uuid.uuid4()}", f"/rooms/{room_id}/send/m.room.message/{uuid.uuid4()}",
content, content,
{"user_id": mxid} if mxid else {}, {"user_id": mxid} if mxid else {},
) )
return resp["event_id"] return resp["event_id"]
def send_typing( def send_typing(
self, room_id: str, mxid: str = "", timeout: int = 8000 self, room_id: str, mxid: str = "", timeout: int = 8000
) -> None: ) -> None:
self.send( self.send(
"PUT", "PUT",
f"/rooms/{room_id}/typing/{mxid}", f"/rooms/{room_id}/typing/{mxid}",
{"typing": True, "timeout": timeout}, {"typing": True, "timeout": timeout},
{"user_id": mxid} if mxid else {}, {"user_id": mxid} if mxid else {},
) )
def send_invite(self, room_id: str, mxid: str) -> None: def send_invite(self, room_id: str, mxid: str) -> None:
self.send("POST", f"/rooms/{room_id}/invite", {"user_id": mxid}) self.send("POST", f"/rooms/{room_id}/invite", {"user_id": mxid})
@request @request
def send( def send(
self, self,
method: str, method: str,
path: str = "", path: str = "",
content: Union[bytes, dict] = {}, content: Union[bytes, dict] = {},
params: dict = {}, params: dict = {},
content_type: str = "application/json", content_type: str = "application/json",
endpoint: str = "/_matrix/client/r0", endpoint: str = "/_matrix/client/r0",
) -> dict: ) -> dict:
headers = { headers = {
"Authorization": f"Bearer {self.as_token}", "Authorization": f"Bearer {self.as_token}",
"Content-Type": content_type, "Content-Type": content_type,
} }
payload = json.dumps(content) if isinstance(content, dict) else content payload = json.dumps(content) if isinstance(content, dict) else content
endpoint = ( endpoint = (
f"{self.base_url}{endpoint}{path}?" f"{self.base_url}{endpoint}{path}?"
f"{urllib.parse.urlencode(params)}" f"{urllib.parse.urlencode(params)}"
) )
return self.http.request( return self.http.request(
method, endpoint, body=payload, headers=headers method, endpoint, body=payload, headers=headers
) )

View file

@ -2,5 +2,5 @@ import threading
class Cache: class Cache:
cache = {} cache = {}
lock = threading.Lock() lock = threading.Lock()

View file

@ -5,116 +5,116 @@ from typing import List
class DataBase: class DataBase:
def __init__(self, db_file) -> None: def __init__(self, db_file) -> None:
self.create(db_file) self.create(db_file)
# The database is accessed via multiple threads. # The database is accessed via multiple threads.
self.lock = threading.Lock() self.lock = threading.Lock()
def create(self, db_file) -> None: def create(self, db_file) -> None:
""" """
Create a database with the relevant tables if it doesn't already exist. Create a database with the relevant tables if it doesn't already exist.
""" """
exists = os.path.exists(db_file) exists = os.path.exists(db_file)
self.conn = sqlite3.connect(db_file, check_same_thread=False) self.conn = sqlite3.connect(db_file, check_same_thread=False)
self.conn.row_factory = self.dict_factory self.conn.row_factory = self.dict_factory
self.cur = self.conn.cursor() self.cur = self.conn.cursor()
if exists: if exists:
return return
self.cur.execute( self.cur.execute(
"CREATE TABLE bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);" "CREATE TABLE bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);"
) )
self.cur.execute( self.cur.execute(
"CREATE TABLE users(mxid TEXT PRIMARY KEY, " "CREATE TABLE users(mxid TEXT PRIMARY KEY, "
"avatar_url TEXT, username TEXT);" "avatar_url TEXT, username TEXT);"
) )
self.conn.commit() self.conn.commit()
def dict_factory(self, cursor, row): def dict_factory(self, cursor, row):
""" """
https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory
""" """
d = {} d = {}
for idx, col in enumerate(cursor.description): for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx] d[col[0]] = row[idx]
return d return d
def add_room(self, room_id: str, channel_id: str) -> None: def add_room(self, room_id: str, channel_id: str) -> None:
""" """
Add a bridged room to the database. Add a bridged room to the database.
""" """
with self.lock: with self.lock:
self.cur.execute( self.cur.execute(
"INSERT INTO bridge (room_id, channel_id) VALUES (?, ?)", "INSERT INTO bridge (room_id, channel_id) VALUES (?, ?)",
[room_id, channel_id], [room_id, channel_id],
) )
self.conn.commit() self.conn.commit()
def add_user(self, mxid: str) -> None: def add_user(self, mxid: str) -> None:
with self.lock: with self.lock:
self.cur.execute("INSERT INTO users (mxid) VALUES (?)", [mxid]) self.cur.execute("INSERT INTO users (mxid) VALUES (?)", [mxid])
self.conn.commit() self.conn.commit()
def add_avatar(self, avatar_url: str, mxid: str) -> None: def add_avatar(self, avatar_url: str, mxid: str) -> None:
with self.lock: with self.lock:
self.cur.execute( self.cur.execute(
"UPDATE users SET avatar_url = (?) WHERE mxid = (?)", "UPDATE users SET avatar_url = (?) WHERE mxid = (?)",
[avatar_url, mxid], [avatar_url, mxid],
) )
self.conn.commit() self.conn.commit()
def add_username(self, username: str, mxid: str) -> None: def add_username(self, username: str, mxid: str) -> None:
with self.lock: with self.lock:
self.cur.execute( self.cur.execute(
"UPDATE users SET username = (?) WHERE mxid = (?)", "UPDATE users SET username = (?) WHERE mxid = (?)",
[username, mxid], [username, mxid],
) )
self.conn.commit() self.conn.commit()
def get_channel(self, room_id: str) -> str: def get_channel(self, room_id: str) -> str:
""" """
Get the corresponding channel ID for a given room ID. Get the corresponding channel ID for a given room ID.
""" """
with self.lock: with self.lock:
self.cur.execute( self.cur.execute(
"SELECT channel_id FROM bridge WHERE room_id = ?", [room_id] "SELECT channel_id FROM bridge WHERE room_id = ?", [room_id]
) )
room = self.cur.fetchone() room = self.cur.fetchone()
# Return an empty string if the channel is not bridged. # Return an empty string if the channel is not bridged.
return "" if not room else room["channel_id"] return "" if not room else room["channel_id"]
def list_channels(self) -> List[str]: def list_channels(self) -> List[str]:
""" """
Get a list of all the bridged channels. Get a list of all the bridged channels.
""" """
with self.lock: with self.lock:
self.cur.execute("SELECT channel_id FROM bridge") self.cur.execute("SELECT channel_id FROM bridge")
channels = self.cur.fetchall() channels = self.cur.fetchall()
return [channel["channel_id"] for channel in channels] return [channel["channel_id"] for channel in channels]
def fetch_user(self, mxid: str) -> dict: def fetch_user(self, mxid: str) -> dict:
""" """
Fetch the profile for a bridged user. Fetch the profile for a bridged user.
""" """
with self.lock: with self.lock:
self.cur.execute("SELECT * FROM users where mxid = ?", [mxid]) self.cur.execute("SELECT * FROM users where mxid = ?", [mxid])
user = self.cur.fetchone() user = self.cur.fetchone()
return {} if not user else user return {} if not user else user

View file

@ -1,218 +0,0 @@
from dataclasses import dataclass
from misc import dict_cls
CDN_URL = "https://cdn.discordapp.com"
MESSAGE_LIMIT = 2000
def bitmask(bit: int) -> int:
return 1 << bit
@dataclass
class Channel:
id: str
type: str
guild_id: str = ""
name: str = ""
topic: str = ""
@dataclass
class Emote:
animated: bool
id: str
name: str
@dataclass
class MessageReference:
message_id: str
channel_id: str
guild_id: str
@dataclass
class Sticker:
name: str
id: str
format_type: int
@dataclass
class Typing:
user_id: str
channel_id: str
@dataclass
class Webhook:
id: str
token: str
class User:
def __init__(self, user: dict) -> None:
self.discriminator = user["discriminator"]
self.id = user["id"]
self.mention = f"<@{self.id}>"
self.username = user["username"]
avatar = user["avatar"]
if not avatar:
# https://discord.com/developers/docs/reference#image-formatting
self.avatar_url = (
f"{CDN_URL}/embed/avatars/{int(self.discriminator) % 5}.png"
)
else:
ext = "gif" if avatar.startswith("a_") else "png"
self.avatar_url = f"{CDN_URL}/avatars/{self.id}/{avatar}.{ext}"
class Guild:
def __init__(self, guild: dict) -> None:
self.guild_id = guild["id"]
self.channels = [dict_cls(c, Channel) for c in guild["channels"]]
self.emojis = [dict_cls(e, Emote) for e in guild["emojis"]]
members = [member["user"] for member in guild["members"]]
self.members = [User(m) for m in members]
class GuildEmojisUpdate:
def __init__(self, update: dict) -> None:
self.guild_id = update["guild_id"]
self.emojis = [dict_cls(e, Emote) for e in update["emojis"]]
class GuildMembersChunk:
def __init__(self, chunk: dict) -> None:
self.chunk_index = chunk["chunk_index"]
self.chunk_count = chunk["chunk_count"]
self.guild_id = chunk["guild_id"]
self.members = [User(m) for m in chunk["members"]]
class GuildMemberUpdate:
def __init__(self, update: dict) -> None:
self.guild_id = update["guild_id"]
self.user = User(update["user"])
class Message:
def __init__(self, message: dict) -> None:
self.attachments = message.get("attachments", [])
self.channel_id = message["channel_id"]
self.content = message.get("content", "")
self.id = message["id"]
self.guild_id = message.get(
"guild_id", ""
) # Responses for sending webhook messages don't have guild_id
self.webhook_id = message.get("webhook_id", "")
self.application_id = message.get("application_id", "")
self.mentions = [
User(mention) for mention in message.get("mentions", [])
]
ref = message.get("referenced_message")
self.referenced_message = Message(ref) if ref else None
author = message.get("author")
self.author = User(author) if author else None
self.stickers = [
dict_cls(sticker, Sticker)
for sticker in message.get("sticker_items", [])
]
class ChannelType:
GUILD_TEXT = 0
DM = 1
GUILD_VOICE = 2
GROUP_DM = 3
GUILD_CATEGORY = 4
GUILD_NEWS = 5
GUILD_STORE = 6
class InteractionResponseType:
PONG = 0
ACKNOWLEDGE = 1
CHANNEL_MESSAGE = 2
CHANNEL_MESSAGE_WITH_SOURCE = 4
ACKNOWLEDGE_WITH_SOURCE = 5
class GatewayIntents:
GUILDS = bitmask(0)
GUILD_MEMBERS = bitmask(1)
GUILD_BANS = bitmask(2)
GUILD_EMOJIS = bitmask(3)
GUILD_INTEGRATIONS = bitmask(4)
GUILD_WEBHOOKS = bitmask(5)
GUILD_INVITES = bitmask(6)
GUILD_VOICE_STATES = bitmask(7)
GUILD_PRESENCES = bitmask(8)
GUILD_MESSAGES = bitmask(9)
GUILD_MESSAGE_REACTIONS = bitmask(10)
GUILD_MESSAGE_TYPING = bitmask(11)
DIRECT_MESSAGES = bitmask(12)
DIRECT_MESSAGE_REACTIONS = bitmask(13)
DIRECT_MESSAGE_TYPING = bitmask(14)
class GatewayOpCodes:
DISPATCH = 0
HEARTBEAT = 1
IDENTIFY = 2
PRESENCE_UPDATE = 3
VOICE_STATE_UPDATE = 4
RESUME = 6
RECONNECT = 7
REQUEST_GUILD_MEMBERS = 8
INVALID_SESSION = 9
HELLO = 10
HEARTBEAT_ACK = 11
class Payloads:
def __init__(self, token: str) -> None:
self.seq = self.session = None
self.token = token
def HEARTBEAT(self) -> dict:
return {"op": GatewayOpCodes.HEARTBEAT, "d": self.seq}
def IDENTIFY(self) -> dict:
return {
"op": GatewayOpCodes.IDENTIFY,
"d": {
"token": self.token,
"intents": GatewayIntents.GUILDS
| GatewayIntents.GUILD_EMOJIS
| GatewayIntents.GUILD_MEMBERS
| GatewayIntents.GUILD_MESSAGES
| GatewayIntents.GUILD_MESSAGE_TYPING
| GatewayIntents.GUILD_PRESENCES,
"properties": {
"$os": "discord",
"$browser": "Discord Client",
"$device": "discord",
},
},
}
def RESUME(self) -> dict:
return {
"op": GatewayOpCodes.RESUME,
"d": {
"token": self.token,
"session_id": self.session,
"seq": self.seq,
},
}

View file

@ -1,5 +1,5 @@
class RequestError(Exception): class RequestError(Exception):
def __init__(self, status: int, *args): def __init__(self, status: int, *args):
super().__init__(*args) super().__init__(*args)
self.status = status self.status = status

View file

@ -12,249 +12,249 @@ from misc import dict_cls, log_except, request
class Gateway: class Gateway:
def __init__(self, http: urllib3.PoolManager, token: str): def __init__(self, http: urllib3.PoolManager, token: str):
self.http = http self.http = http
self.token = token self.token = token
self.logger = logging.getLogger("discord") self.logger = logging.getLogger("discord")
self.Payloads = discord.Payloads(self.token) self.Payloads = discord.Payloads(self.token)
self.websocket = None self.websocket = None
@log_except @log_except
async def run(self) -> None: async def run(self) -> None:
self.heartbeat_task: asyncio.Future = None self.heartbeat_task: asyncio.Future = None
self.resume = False self.resume = False
gateway_url = self.get_gateway_url() gateway_url = self.get_gateway_url()
while True: while True:
try: try:
await self.gateway_handler(gateway_url) await self.gateway_handler(gateway_url)
except ( except (
websockets.ConnectionClosedError, websockets.ConnectionClosedError,
websockets.InvalidMessage, websockets.InvalidMessage,
): ):
self.logger.exception("Connection lost, reconnecting.") self.logger.exception("Connection lost, reconnecting.")
# Stop sending heartbeats until we reconnect. # Stop sending heartbeats until we reconnect.
if self.heartbeat_task and not self.heartbeat_task.cancelled(): if self.heartbeat_task and not self.heartbeat_task.cancelled():
self.heartbeat_task.cancel() self.heartbeat_task.cancel()
def get_gateway_url(self) -> str: def get_gateway_url(self) -> str:
resp = self.send("GET", "/gateway") resp = self.send("GET", "/gateway")
return resp["url"] return resp["url"]
async def heartbeat_handler(self, interval_ms: int) -> None: async def heartbeat_handler(self, interval_ms: int) -> None:
while True: while True:
await asyncio.sleep(interval_ms / 1000) await asyncio.sleep(interval_ms / 1000)
await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT())) await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT()))
async def handle_resp(self, data: dict) -> None: async def handle_resp(self, data: dict) -> None:
data_dict = data["d"] data_dict = data["d"]
opcode = data["op"] opcode = data["op"]
seq = data["s"] seq = data["s"]
if seq: if seq:
self.Payloads.seq = seq self.Payloads.seq = seq
if opcode == discord.GatewayOpCodes.DISPATCH: if opcode == discord.GatewayOpCodes.DISPATCH:
otype = data["t"] otype = data["t"]
if otype == "READY": if otype == "READY":
self.Payloads.session = data_dict["session_id"] self.Payloads.session = data_dict["session_id"]
self.logger.info("READY") self.logger.info("READY")
else: else:
self.handle_otype(data_dict, otype) self.handle_otype(data_dict, otype)
elif opcode == discord.GatewayOpCodes.HELLO: elif opcode == discord.GatewayOpCodes.HELLO:
heartbeat_interval = data_dict.get("heartbeat_interval") heartbeat_interval = data_dict.get("heartbeat_interval")
self.logger.info(f"Heartbeat Interval: {heartbeat_interval}") self.logger.info(f"Heartbeat Interval: {heartbeat_interval}")
# Send periodic hearbeats to gateway. # Send periodic hearbeats to gateway.
self.heartbeat_task = asyncio.ensure_future( self.heartbeat_task = asyncio.ensure_future(
self.heartbeat_handler(heartbeat_interval) self.heartbeat_handler(heartbeat_interval)
) )
await self.websocket.send( await self.websocket.send(
json.dumps( json.dumps(
self.Payloads.RESUME() self.Payloads.RESUME()
if self.resume if self.resume
else self.Payloads.IDENTIFY() else self.Payloads.IDENTIFY()
) )
) )
elif opcode == discord.GatewayOpCodes.RECONNECT: elif opcode == discord.GatewayOpCodes.RECONNECT:
self.logger.info("Received RECONNECT.") self.logger.info("Received RECONNECT.")
self.resume = True self.resume = True
await self.websocket.close() await self.websocket.close()
elif opcode == discord.GatewayOpCodes.INVALID_SESSION: elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
self.logger.info("Received INVALID_SESSION.") self.logger.info("Received INVALID_SESSION.")
self.resume = False self.resume = False
await self.websocket.close() await self.websocket.close()
elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK: elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
# NOP # NOP
pass pass
else: else:
self.logger.info( self.logger.info(
"Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}" "Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}"
) )
def handle_otype(self, data: dict, otype: str) -> None: def handle_otype(self, data: dict, otype: str) -> None:
if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"): if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"):
obj = discord.Message(data) obj = discord.Message(data)
elif otype == "TYPING_START": elif otype == "TYPING_START":
obj = dict_cls(data, discord.Typing) obj = dict_cls(data, discord.Typing)
elif otype == "GUILD_CREATE": elif otype == "GUILD_CREATE":
obj = discord.Guild(data) obj = discord.Guild(data)
elif otype == "GUILD_MEMBER_UPDATE": elif otype == "GUILD_MEMBER_UPDATE":
obj = discord.GuildMemberUpdate(data) obj = discord.GuildMemberUpdate(data)
elif otype == "GUILD_EMOJIS_UPDATE": elif otype == "GUILD_EMOJIS_UPDATE":
obj = discord.GuildEmojisUpdate(data) obj = discord.GuildEmojisUpdate(data)
else: else:
return return
func = getattr(self, f"on_{otype.lower()}", None) func = getattr(self, f"on_{otype.lower()}", None)
if not func: if not func:
self.logger.warning( self.logger.warning(
f"Function '{func}' not defined, ignoring message." f"Function '{func}' not defined, ignoring message."
) )
return return
try: try:
func(obj) func(obj)
except Exception: except Exception:
self.logger.exception(f"Ignoring exception in '{func.__name__}':") self.logger.exception(f"Ignoring exception in '{func.__name__}':")
async def gateway_handler(self, gateway_url: str) -> None: async def gateway_handler(self, gateway_url: str) -> None:
async with websockets.connect( async with websockets.connect(
f"{gateway_url}/?v=8&encoding=json" f"{gateway_url}/?v=8&encoding=json"
) as websocket: ) as websocket:
self.websocket = websocket self.websocket = websocket
async for message in websocket: async for message in websocket:
await self.handle_resp(json.loads(message)) await self.handle_resp(json.loads(message))
def get_channel(self, channel_id: str) -> discord.Channel: def get_channel(self, channel_id: str) -> discord.Channel:
""" """
Get the channel for a given channel ID. Get the channel for a given channel ID.
""" """
resp = self.send("GET", f"/channels/{channel_id}") resp = self.send("GET", f"/channels/{channel_id}")
return dict_cls(resp, discord.Channel) return dict_cls(resp, discord.Channel)
def get_channels(self, guild_id: str) -> Dict[str, discord.Channel]: def get_channels(self, guild_id: str) -> Dict[str, discord.Channel]:
""" """
Get all channels for a given guild ID. Get all channels for a given guild ID.
""" """
resp = self.send("GET", f"/guilds/{guild_id}/channels") resp = self.send("GET", f"/guilds/{guild_id}/channels")
return { return {
channel["id"]: dict_cls(channel, discord.Channel) channel["id"]: dict_cls(channel, discord.Channel)
for channel in resp for channel in resp
} }
def get_emotes(self, guild_id: str) -> List[discord.Emote]: def get_emotes(self, guild_id: str) -> List[discord.Emote]:
""" """
Get all the emotes for a given guild. Get all the emotes for a given guild.
""" """
resp = self.send("GET", f"/guilds/{guild_id}/emojis") resp = self.send("GET", f"/guilds/{guild_id}/emojis")
return [dict_cls(emote, discord.Emote) for emote in resp] return [dict_cls(emote, discord.Emote) for emote in resp]
def get_members(self, guild_id: str) -> List[discord.User]: def get_members(self, guild_id: str) -> List[discord.User]:
""" """
Get all the members for a given guild. Get all the members for a given guild.
""" """
resp = self.send( resp = self.send(
"GET", f"/guilds/{guild_id}/members", params={"limit": 1000} "GET", f"/guilds/{guild_id}/members", params={"limit": 1000}
) )
return [discord.User(member["user"]) for member in resp] return [discord.User(member["user"]) for member in resp]
def create_webhook(self, channel_id: str, name: str) -> discord.Webhook: def create_webhook(self, channel_id: str, name: str) -> discord.Webhook:
""" """
Create a webhook with the specified name in a given channel. Create a webhook with the specified name in a given channel.
""" """
resp = self.send( resp = self.send(
"POST", f"/channels/{channel_id}/webhooks", {"name": name} "POST", f"/channels/{channel_id}/webhooks", {"name": name}
) )
return dict_cls(resp, discord.Webhook) return dict_cls(resp, discord.Webhook)
def edit_webhook( def edit_webhook(
self, content: str, message_id: str, webhook: discord.Webhook self, content: str, message_id: str, webhook: discord.Webhook
) -> None: ) -> None:
self.send( self.send(
"PATCH", "PATCH",
f"/webhooks/{webhook.id}/{webhook.token}/messages/" f"/webhooks/{webhook.id}/{webhook.token}/messages/"
f"{message_id}", f"{message_id}",
{"content": content}, {"content": content},
) )
def delete_webhook( def delete_webhook(
self, message_id: str, webhook: discord.Webhook self, message_id: str, webhook: discord.Webhook
) -> None: ) -> None:
self.send( self.send(
"DELETE", "DELETE",
f"/webhooks/{webhook.id}/{webhook.token}/messages/" f"/webhooks/{webhook.id}/{webhook.token}/messages/"
f"{message_id}", f"{message_id}",
) )
def send_webhook( def send_webhook(
self, self,
webhook: discord.Webhook, webhook: discord.Webhook,
avatar_url: str, avatar_url: str,
content: str, content: str,
username: str, username: str,
) -> discord.Message: ) -> discord.Message:
payload = { payload = {
"avatar_url": avatar_url, "avatar_url": avatar_url,
"content": content, "content": content,
"username": username, "username": username,
# Disable 'everyone' and 'role' mentions. # Disable 'everyone' and 'role' mentions.
"allowed_mentions": {"parse": ["users"]}, "allowed_mentions": {"parse": ["users"]},
} }
resp = self.send( resp = self.send(
"POST", "POST",
f"/webhooks/{webhook.id}/{webhook.token}", f"/webhooks/{webhook.id}/{webhook.token}",
payload, payload,
{"wait": True}, {"wait": True},
) )
return discord.Message(resp) return discord.Message(resp)
def send_message(self, message: str, channel_id: str) -> None: def send_message(self, message: str, channel_id: str) -> None:
self.send( self.send(
"POST", f"/channels/{channel_id}/messages", {"content": message} "POST", f"/channels/{channel_id}/messages", {"content": message}
) )
@request @request
def send( def send(
self, method: str, path: str, content: dict = {}, params: dict = {} self, method: str, path: str, content: dict = {}, params: dict = {}
) -> dict: ) -> dict:
endpoint = ( endpoint = (
f"https://discord.com/api/v8{path}?" f"https://discord.com/api/v8{path}?"
f"{urllib.parse.urlencode(params)}" f"{urllib.parse.urlencode(params)}"
) )
headers = { headers = {
"Authorization": f"Bot {self.token}", "Authorization": f"Bot {self.token}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
# 'body' being an empty dict breaks "GET" requests. # 'body' being an empty dict breaks "GET" requests.
payload = json.dumps(content) if content else None payload = json.dumps(content) if content else None
return self.http.request( return self.http.request(
method, endpoint, body=payload, headers=headers method, endpoint, body=payload, headers=headers
) )

File diff suppressed because it is too large Load diff

View file

@ -3,26 +3,26 @@ from dataclasses import dataclass
@dataclass @dataclass
class User: class User:
avatar_url: str = "" avatar_url: str = ""
display_name: str = "" display_name: str = ""
class Event: class Event:
def __init__(self, event: dict): def __init__(self, event: dict):
content = event.get("content", {}) content = event.get("content", {})
self.attachment = content.get("url") self.attachment = content.get("url")
self.body = content.get("body", "").strip() self.body = content.get("body", "").strip()
self.formatted_body = content.get("formatted_body", "") self.formatted_body = content.get("formatted_body", "")
self.id = event["event_id"] self.id = event["event_id"]
self.is_direct = content.get("is_direct", False) self.is_direct = content.get("is_direct", False)
self.redacts = event.get("redacts", "") self.redacts = event.get("redacts", "")
self.room_id = event["room_id"] self.room_id = event["room_id"]
self.sender = event["sender"] self.sender = event["sender"]
self.state_key = event.get("state_key", "") self.state_key = event.get("state_key", "")
rel = content.get("m.relates_to", {}) rel = content.get("m.relates_to", {})
self.relates_to = rel.get("event_id") self.relates_to = rel.get("event_id")
self.reltype = rel.get("rel_type") self.reltype = rel.get("rel_type")
self.new_body = content.get("m.new_content", {}).get("body", "") self.new_body = content.get("m.new_content", {}).get("body", "")

View file

@ -8,77 +8,77 @@ from errors import RequestError
def dict_cls(d: dict, cls: Any) -> Any: def dict_cls(d: dict, cls: Any) -> Any:
""" """
Create a dataclass from a dictionary. Create a dataclass from a dictionary.
""" """
field_names = set(f.name for f in fields(cls)) field_names = set(f.name for f in fields(cls))
filtered_dict = {k: v for k, v in d.items() if k in field_names} filtered_dict = {k: v for k, v in d.items() if k in field_names}
return cls(**filtered_dict) return cls(**filtered_dict)
def log_except(fn): def log_except(fn):
""" """
Log unhandled exceptions to a logger instead of `stderr`. Log unhandled exceptions to a logger instead of `stderr`.
""" """
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
try: try:
return fn(self, *args, **kwargs) return fn(self, *args, **kwargs)
except Exception: except Exception:
self.logger.exception(f"Exception in '{fn.__name__}':") self.logger.exception(f"Exception in '{fn.__name__}':")
raise raise
return wrapper return wrapper
def request(fn): def request(fn):
""" """
Either return json data or raise a `RequestError` if the request was Either return json data or raise a `RequestError` if the request was
unsuccessful. unsuccessful.
""" """
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
try: try:
resp = fn(*args, **kwargs) resp = fn(*args, **kwargs)
except urllib3.exceptions.HTTPError as e: except urllib3.exceptions.HTTPError as e:
raise RequestError(None, f"Failed to connect: {e}") from None raise RequestError(None, f"Failed to connect: {e}") from None
if resp.status < 200 or resp.status >= 300: if resp.status < 200 or resp.status >= 300:
raise RequestError( raise RequestError(
resp.status, resp.status,
f"Failed to get response from '{resp.geturl()}':\n{resp.data}", f"Failed to get response from '{resp.geturl()}':\n{resp.data}",
) )
return {} if resp.status == 204 else json.loads(resp.data) return {} if resp.status == 204 else json.loads(resp.data)
return wrapper return wrapper
def except_deleted(fn): def except_deleted(fn):
""" """
Ignore the `RequestError` on 404s, the content might have been removed. Ignore the `RequestError` on 404s, the content might have been removed.
""" """
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
try: try:
return fn(*args, **kwargs) return fn(*args, **kwargs)
except RequestError as e: except RequestError as e:
if e.status != 404: if e.status != 404:
raise raise
return wrapper return wrapper
def hash_str(string: str) -> int: def hash_str(string: str) -> int:
""" """
Create the hash for a string Create the hash for a string
""" """
hash = 5381 hash = 5381
for ch in string: for ch in string:
hash = ((hash << 5) + hash) + ord(ch) hash = ((hash << 5) + hash) + ord(ch)
return hash & 0xFFFFFFFF return hash & 0xFFFFFFFF