chore: retab, delete discord file
This commit is contained in:
parent
461127a4ac
commit
947217abe3
9 changed files with 1106 additions and 1324 deletions
|
@ -13,191 +13,191 @@ from misc import log_except, request
|
||||||
|
|
||||||
|
|
||||||
class AppService(bottle.Bottle):
|
class AppService(bottle.Bottle):
|
||||||
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
|
def __init__(self, config: dict, http: urllib3.PoolManager) -> None:
|
||||||
super(AppService, self).__init__()
|
super(AppService, self).__init__()
|
||||||
|
|
||||||
self.as_token = config["as_token"]
|
self.as_token = config["as_token"]
|
||||||
self.hs_token = config["hs_token"]
|
self.hs_token = config["hs_token"]
|
||||||
self.base_url = config["homeserver"]
|
self.base_url = config["homeserver"]
|
||||||
self.server_name = config["server_name"]
|
self.server_name = config["server_name"]
|
||||||
self.user_id = f"@{config['user_id']}:{self.server_name}"
|
self.user_id = f"@{config['user_id']}:{self.server_name}"
|
||||||
self.http = http
|
self.http = http
|
||||||
self.logger = logging.getLogger("appservice")
|
self.logger = logging.getLogger("appservice")
|
||||||
|
|
||||||
# Map events to functions.
|
# Map events to functions.
|
||||||
self.mapping = {
|
self.mapping = {
|
||||||
"m.room.member": "on_member",
|
"m.room.member": "on_member",
|
||||||
"m.room.message": "on_message",
|
"m.room.message": "on_message",
|
||||||
"m.room.redaction": "on_redaction",
|
"m.room.redaction": "on_redaction",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add route for bottle.
|
# Add route for bottle.
|
||||||
self.route(
|
self.route(
|
||||||
"/transactions/<transaction>",
|
"/transactions/<transaction>",
|
||||||
callback=self.receive_event,
|
callback=self.receive_event,
|
||||||
method="PUT",
|
method="PUT",
|
||||||
)
|
)
|
||||||
|
|
||||||
Cache.cache["m_rooms"] = {}
|
Cache.cache["m_rooms"] = {}
|
||||||
|
|
||||||
def handle_event(self, event: dict) -> None:
|
def handle_event(self, event: dict) -> None:
|
||||||
event_type = event.get("type")
|
event_type = event.get("type")
|
||||||
|
|
||||||
if event_type in (
|
if event_type in (
|
||||||
"m.room.member",
|
"m.room.member",
|
||||||
"m.room.message",
|
"m.room.message",
|
||||||
"m.room.redaction",
|
"m.room.redaction",
|
||||||
):
|
):
|
||||||
obj = matrix.Event(event)
|
obj = matrix.Event(event)
|
||||||
else:
|
else:
|
||||||
self.logger.info(f"Unknown event type: {event_type}")
|
self.logger.info(f"Unknown event type: {event_type}")
|
||||||
return
|
return
|
||||||
|
|
||||||
func = getattr(self, self.mapping[event_type], None)
|
func = getattr(self, self.mapping[event_type], None)
|
||||||
|
|
||||||
if not func:
|
if not func:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
f"Function '{func}' not defined, ignoring event."
|
f"Function '{func}' not defined, ignoring event."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# We don't catch exceptions here as the homeserver will re-send us
|
# We don't catch exceptions here as the homeserver will re-send us
|
||||||
# the event in case of a failure.
|
# the event in case of a failure.
|
||||||
func(obj)
|
func(obj)
|
||||||
|
|
||||||
@log_except
|
@log_except
|
||||||
def receive_event(self, transaction: str) -> dict:
|
def receive_event(self, transaction: str) -> dict:
|
||||||
"""
|
"""
|
||||||
Verify the homeserver's token and handle events.
|
Verify the homeserver's token and handle events.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
hs_token = bottle.request.query.getone("access_token")
|
hs_token = bottle.request.query.getone("access_token")
|
||||||
|
|
||||||
if not hs_token:
|
if not hs_token:
|
||||||
bottle.response.status = 401
|
bottle.response.status = 401
|
||||||
return {"errcode": "APPSERVICE_UNAUTHORIZED"}
|
return {"errcode": "APPSERVICE_UNAUTHORIZED"}
|
||||||
|
|
||||||
if hs_token != self.hs_token:
|
if hs_token != self.hs_token:
|
||||||
bottle.response.status = 403
|
bottle.response.status = 403
|
||||||
return {"errcode": "APPSERVICE_FORBIDDEN"}
|
return {"errcode": "APPSERVICE_FORBIDDEN"}
|
||||||
|
|
||||||
events = bottle.request.json.get("events")
|
events = bottle.request.json.get("events")
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
self.handle_event(event)
|
self.handle_event(event)
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def mxc_url(self, mxc: str) -> str:
|
def mxc_url(self, mxc: str) -> str:
|
||||||
try:
|
try:
|
||||||
homeserver, media_id = mxc.replace("mxc://", "").split("/")
|
homeserver, media_id = mxc.replace("mxc://", "").split("/")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
return (
|
return (
|
||||||
f"https://{self.server_name}/_matrix/media/r0/download/"
|
f"https://{self.server_name}/_matrix/media/r0/download/"
|
||||||
f"{homeserver}/{media_id}"
|
f"{homeserver}/{media_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def join_room(self, room_id: str, mxid: str = "") -> None:
|
def join_room(self, room_id: str, mxid: str = "") -> None:
|
||||||
self.send(
|
self.send(
|
||||||
"POST",
|
"POST",
|
||||||
f"/join/{room_id}",
|
f"/join/{room_id}",
|
||||||
params={"user_id": mxid} if mxid else {},
|
params={"user_id": mxid} if mxid else {},
|
||||||
)
|
)
|
||||||
|
|
||||||
def redact(self, event_id: str, room_id: str, mxid: str = "") -> None:
|
def redact(self, event_id: str, room_id: str, mxid: str = "") -> None:
|
||||||
self.send(
|
self.send(
|
||||||
"PUT",
|
"PUT",
|
||||||
f"/rooms/{room_id}/redact/{event_id}/{uuid.uuid4()}",
|
f"/rooms/{room_id}/redact/{event_id}/{uuid.uuid4()}",
|
||||||
params={"user_id": mxid} if mxid else {},
|
params={"user_id": mxid} if mxid else {},
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_room_id(self, alias: str) -> str:
|
def get_room_id(self, alias: str) -> str:
|
||||||
with Cache.lock:
|
with Cache.lock:
|
||||||
room = Cache.cache["m_rooms"].get(alias)
|
room = Cache.cache["m_rooms"].get(alias)
|
||||||
if room:
|
if room:
|
||||||
return room
|
return room
|
||||||
|
|
||||||
resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}")
|
resp = self.send("GET", f"/directory/room/{urllib.parse.quote(alias)}")
|
||||||
|
|
||||||
room_id = resp["room_id"]
|
room_id = resp["room_id"]
|
||||||
|
|
||||||
with Cache.lock:
|
with Cache.lock:
|
||||||
Cache.cache["m_rooms"][alias] = room_id
|
Cache.cache["m_rooms"][alias] = room_id
|
||||||
|
|
||||||
return room_id
|
return room_id
|
||||||
|
|
||||||
def get_event(self, event_id: str, room_id: str) -> matrix.Event:
|
def get_event(self, event_id: str, room_id: str) -> matrix.Event:
|
||||||
resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}")
|
resp = self.send("GET", f"/rooms/{room_id}/event/{event_id}")
|
||||||
|
|
||||||
return matrix.Event(resp)
|
return matrix.Event(resp)
|
||||||
|
|
||||||
def upload(self, url: str) -> str:
|
def upload(self, url: str) -> str:
|
||||||
"""
|
"""
|
||||||
Upload a file to the homeserver and get the MXC url.
|
Upload a file to the homeserver and get the MXC url.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
resp = self.http.request("GET", url)
|
resp = self.http.request("GET", url)
|
||||||
|
|
||||||
resp = self.send(
|
resp = self.send(
|
||||||
"POST",
|
"POST",
|
||||||
content=resp.data,
|
content=resp.data,
|
||||||
content_type=resp.headers.get("Content-Type"),
|
content_type=resp.headers.get("Content-Type"),
|
||||||
params={"filename": f"{uuid.uuid4()}"},
|
params={"filename": f"{uuid.uuid4()}"},
|
||||||
endpoint="/_matrix/media/r0/upload",
|
endpoint="/_matrix/media/r0/upload",
|
||||||
)
|
)
|
||||||
|
|
||||||
return resp["content_uri"]
|
return resp["content_uri"]
|
||||||
|
|
||||||
def send_message(
|
def send_message(
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
content: dict,
|
content: dict,
|
||||||
mxid: str = "",
|
mxid: str = "",
|
||||||
) -> str:
|
) -> str:
|
||||||
resp = self.send(
|
resp = self.send(
|
||||||
"PUT",
|
"PUT",
|
||||||
f"/rooms/{room_id}/send/m.room.message/{uuid.uuid4()}",
|
f"/rooms/{room_id}/send/m.room.message/{uuid.uuid4()}",
|
||||||
content,
|
content,
|
||||||
{"user_id": mxid} if mxid else {},
|
{"user_id": mxid} if mxid else {},
|
||||||
)
|
)
|
||||||
|
|
||||||
return resp["event_id"]
|
return resp["event_id"]
|
||||||
|
|
||||||
def send_typing(
|
def send_typing(
|
||||||
self, room_id: str, mxid: str = "", timeout: int = 8000
|
self, room_id: str, mxid: str = "", timeout: int = 8000
|
||||||
) -> None:
|
) -> None:
|
||||||
self.send(
|
self.send(
|
||||||
"PUT",
|
"PUT",
|
||||||
f"/rooms/{room_id}/typing/{mxid}",
|
f"/rooms/{room_id}/typing/{mxid}",
|
||||||
{"typing": True, "timeout": timeout},
|
{"typing": True, "timeout": timeout},
|
||||||
{"user_id": mxid} if mxid else {},
|
{"user_id": mxid} if mxid else {},
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_invite(self, room_id: str, mxid: str) -> None:
|
def send_invite(self, room_id: str, mxid: str) -> None:
|
||||||
self.send("POST", f"/rooms/{room_id}/invite", {"user_id": mxid})
|
self.send("POST", f"/rooms/{room_id}/invite", {"user_id": mxid})
|
||||||
|
|
||||||
@request
|
@request
|
||||||
def send(
|
def send(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
path: str = "",
|
path: str = "",
|
||||||
content: Union[bytes, dict] = {},
|
content: Union[bytes, dict] = {},
|
||||||
params: dict = {},
|
params: dict = {},
|
||||||
content_type: str = "application/json",
|
content_type: str = "application/json",
|
||||||
endpoint: str = "/_matrix/client/r0",
|
endpoint: str = "/_matrix/client/r0",
|
||||||
) -> dict:
|
) -> dict:
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.as_token}",
|
"Authorization": f"Bearer {self.as_token}",
|
||||||
"Content-Type": content_type,
|
"Content-Type": content_type,
|
||||||
}
|
}
|
||||||
payload = json.dumps(content) if isinstance(content, dict) else content
|
payload = json.dumps(content) if isinstance(content, dict) else content
|
||||||
endpoint = (
|
endpoint = (
|
||||||
f"{self.base_url}{endpoint}{path}?"
|
f"{self.base_url}{endpoint}{path}?"
|
||||||
f"{urllib.parse.urlencode(params)}"
|
f"{urllib.parse.urlencode(params)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.http.request(
|
return self.http.request(
|
||||||
method, endpoint, body=payload, headers=headers
|
method, endpoint, body=payload, headers=headers
|
||||||
)
|
)
|
||||||
|
|
|
@ -2,5 +2,5 @@ import threading
|
||||||
|
|
||||||
|
|
||||||
class Cache:
|
class Cache:
|
||||||
cache = {}
|
cache = {}
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
|
|
|
@ -5,116 +5,116 @@ from typing import List
|
||||||
|
|
||||||
|
|
||||||
class DataBase:
|
class DataBase:
|
||||||
def __init__(self, db_file) -> None:
|
def __init__(self, db_file) -> None:
|
||||||
self.create(db_file)
|
self.create(db_file)
|
||||||
|
|
||||||
# The database is accessed via multiple threads.
|
# The database is accessed via multiple threads.
|
||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
def create(self, db_file) -> None:
|
def create(self, db_file) -> None:
|
||||||
"""
|
"""
|
||||||
Create a database with the relevant tables if it doesn't already exist.
|
Create a database with the relevant tables if it doesn't already exist.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
exists = os.path.exists(db_file)
|
exists = os.path.exists(db_file)
|
||||||
|
|
||||||
self.conn = sqlite3.connect(db_file, check_same_thread=False)
|
self.conn = sqlite3.connect(db_file, check_same_thread=False)
|
||||||
self.conn.row_factory = self.dict_factory
|
self.conn.row_factory = self.dict_factory
|
||||||
|
|
||||||
self.cur = self.conn.cursor()
|
self.cur = self.conn.cursor()
|
||||||
|
|
||||||
if exists:
|
if exists:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.cur.execute(
|
self.cur.execute(
|
||||||
"CREATE TABLE bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);"
|
"CREATE TABLE bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cur.execute(
|
self.cur.execute(
|
||||||
"CREATE TABLE users(mxid TEXT PRIMARY KEY, "
|
"CREATE TABLE users(mxid TEXT PRIMARY KEY, "
|
||||||
"avatar_url TEXT, username TEXT);"
|
"avatar_url TEXT, username TEXT);"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
def dict_factory(self, cursor, row):
|
def dict_factory(self, cursor, row):
|
||||||
"""
|
"""
|
||||||
https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory
|
https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory
|
||||||
"""
|
"""
|
||||||
|
|
||||||
d = {}
|
d = {}
|
||||||
for idx, col in enumerate(cursor.description):
|
for idx, col in enumerate(cursor.description):
|
||||||
d[col[0]] = row[idx]
|
d[col[0]] = row[idx]
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def add_room(self, room_id: str, channel_id: str) -> None:
|
def add_room(self, room_id: str, channel_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Add a bridged room to the database.
|
Add a bridged room to the database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.cur.execute(
|
self.cur.execute(
|
||||||
"INSERT INTO bridge (room_id, channel_id) VALUES (?, ?)",
|
"INSERT INTO bridge (room_id, channel_id) VALUES (?, ?)",
|
||||||
[room_id, channel_id],
|
[room_id, channel_id],
|
||||||
)
|
)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
def add_user(self, mxid: str) -> None:
|
def add_user(self, mxid: str) -> None:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.cur.execute("INSERT INTO users (mxid) VALUES (?)", [mxid])
|
self.cur.execute("INSERT INTO users (mxid) VALUES (?)", [mxid])
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
def add_avatar(self, avatar_url: str, mxid: str) -> None:
|
def add_avatar(self, avatar_url: str, mxid: str) -> None:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.cur.execute(
|
self.cur.execute(
|
||||||
"UPDATE users SET avatar_url = (?) WHERE mxid = (?)",
|
"UPDATE users SET avatar_url = (?) WHERE mxid = (?)",
|
||||||
[avatar_url, mxid],
|
[avatar_url, mxid],
|
||||||
)
|
)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
def add_username(self, username: str, mxid: str) -> None:
|
def add_username(self, username: str, mxid: str) -> None:
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.cur.execute(
|
self.cur.execute(
|
||||||
"UPDATE users SET username = (?) WHERE mxid = (?)",
|
"UPDATE users SET username = (?) WHERE mxid = (?)",
|
||||||
[username, mxid],
|
[username, mxid],
|
||||||
)
|
)
|
||||||
self.conn.commit()
|
self.conn.commit()
|
||||||
|
|
||||||
def get_channel(self, room_id: str) -> str:
|
def get_channel(self, room_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
Get the corresponding channel ID for a given room ID.
|
Get the corresponding channel ID for a given room ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.cur.execute(
|
self.cur.execute(
|
||||||
"SELECT channel_id FROM bridge WHERE room_id = ?", [room_id]
|
"SELECT channel_id FROM bridge WHERE room_id = ?", [room_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
room = self.cur.fetchone()
|
room = self.cur.fetchone()
|
||||||
|
|
||||||
# Return an empty string if the channel is not bridged.
|
# Return an empty string if the channel is not bridged.
|
||||||
return "" if not room else room["channel_id"]
|
return "" if not room else room["channel_id"]
|
||||||
|
|
||||||
def list_channels(self) -> List[str]:
|
def list_channels(self) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get a list of all the bridged channels.
|
Get a list of all the bridged channels.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.cur.execute("SELECT channel_id FROM bridge")
|
self.cur.execute("SELECT channel_id FROM bridge")
|
||||||
|
|
||||||
channels = self.cur.fetchall()
|
channels = self.cur.fetchall()
|
||||||
|
|
||||||
return [channel["channel_id"] for channel in channels]
|
return [channel["channel_id"] for channel in channels]
|
||||||
|
|
||||||
def fetch_user(self, mxid: str) -> dict:
|
def fetch_user(self, mxid: str) -> dict:
|
||||||
"""
|
"""
|
||||||
Fetch the profile for a bridged user.
|
Fetch the profile for a bridged user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.cur.execute("SELECT * FROM users where mxid = ?", [mxid])
|
self.cur.execute("SELECT * FROM users where mxid = ?", [mxid])
|
||||||
|
|
||||||
user = self.cur.fetchone()
|
user = self.cur.fetchone()
|
||||||
|
|
||||||
return {} if not user else user
|
return {} if not user else user
|
||||||
|
|
|
@ -1,218 +0,0 @@
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from misc import dict_cls
|
|
||||||
|
|
||||||
CDN_URL = "https://cdn.discordapp.com"
|
|
||||||
MESSAGE_LIMIT = 2000
|
|
||||||
|
|
||||||
|
|
||||||
def bitmask(bit: int) -> int:
|
|
||||||
return 1 << bit
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Channel:
|
|
||||||
id: str
|
|
||||||
type: str
|
|
||||||
guild_id: str = ""
|
|
||||||
name: str = ""
|
|
||||||
topic: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Emote:
|
|
||||||
animated: bool
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MessageReference:
|
|
||||||
message_id: str
|
|
||||||
channel_id: str
|
|
||||||
guild_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Sticker:
|
|
||||||
name: str
|
|
||||||
id: str
|
|
||||||
format_type: int
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Typing:
|
|
||||||
user_id: str
|
|
||||||
channel_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Webhook:
|
|
||||||
id: str
|
|
||||||
token: str
|
|
||||||
|
|
||||||
|
|
||||||
class User:
|
|
||||||
def __init__(self, user: dict) -> None:
|
|
||||||
self.discriminator = user["discriminator"]
|
|
||||||
self.id = user["id"]
|
|
||||||
self.mention = f"<@{self.id}>"
|
|
||||||
self.username = user["username"]
|
|
||||||
|
|
||||||
avatar = user["avatar"]
|
|
||||||
|
|
||||||
if not avatar:
|
|
||||||
# https://discord.com/developers/docs/reference#image-formatting
|
|
||||||
self.avatar_url = (
|
|
||||||
f"{CDN_URL}/embed/avatars/{int(self.discriminator) % 5}.png"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ext = "gif" if avatar.startswith("a_") else "png"
|
|
||||||
self.avatar_url = f"{CDN_URL}/avatars/{self.id}/{avatar}.{ext}"
|
|
||||||
|
|
||||||
|
|
||||||
class Guild:
|
|
||||||
def __init__(self, guild: dict) -> None:
|
|
||||||
self.guild_id = guild["id"]
|
|
||||||
self.channels = [dict_cls(c, Channel) for c in guild["channels"]]
|
|
||||||
self.emojis = [dict_cls(e, Emote) for e in guild["emojis"]]
|
|
||||||
members = [member["user"] for member in guild["members"]]
|
|
||||||
self.members = [User(m) for m in members]
|
|
||||||
|
|
||||||
|
|
||||||
class GuildEmojisUpdate:
|
|
||||||
def __init__(self, update: dict) -> None:
|
|
||||||
self.guild_id = update["guild_id"]
|
|
||||||
self.emojis = [dict_cls(e, Emote) for e in update["emojis"]]
|
|
||||||
|
|
||||||
|
|
||||||
class GuildMembersChunk:
|
|
||||||
def __init__(self, chunk: dict) -> None:
|
|
||||||
self.chunk_index = chunk["chunk_index"]
|
|
||||||
self.chunk_count = chunk["chunk_count"]
|
|
||||||
self.guild_id = chunk["guild_id"]
|
|
||||||
self.members = [User(m) for m in chunk["members"]]
|
|
||||||
|
|
||||||
|
|
||||||
class GuildMemberUpdate:
|
|
||||||
def __init__(self, update: dict) -> None:
|
|
||||||
self.guild_id = update["guild_id"]
|
|
||||||
self.user = User(update["user"])
|
|
||||||
|
|
||||||
|
|
||||||
class Message:
|
|
||||||
def __init__(self, message: dict) -> None:
|
|
||||||
self.attachments = message.get("attachments", [])
|
|
||||||
self.channel_id = message["channel_id"]
|
|
||||||
self.content = message.get("content", "")
|
|
||||||
self.id = message["id"]
|
|
||||||
self.guild_id = message.get(
|
|
||||||
"guild_id", ""
|
|
||||||
) # Responses for sending webhook messages don't have guild_id
|
|
||||||
self.webhook_id = message.get("webhook_id", "")
|
|
||||||
self.application_id = message.get("application_id", "")
|
|
||||||
|
|
||||||
self.mentions = [
|
|
||||||
User(mention) for mention in message.get("mentions", [])
|
|
||||||
]
|
|
||||||
|
|
||||||
ref = message.get("referenced_message")
|
|
||||||
|
|
||||||
self.referenced_message = Message(ref) if ref else None
|
|
||||||
|
|
||||||
author = message.get("author")
|
|
||||||
|
|
||||||
self.author = User(author) if author else None
|
|
||||||
|
|
||||||
self.stickers = [
|
|
||||||
dict_cls(sticker, Sticker)
|
|
||||||
for sticker in message.get("sticker_items", [])
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelType:
|
|
||||||
GUILD_TEXT = 0
|
|
||||||
DM = 1
|
|
||||||
GUILD_VOICE = 2
|
|
||||||
GROUP_DM = 3
|
|
||||||
GUILD_CATEGORY = 4
|
|
||||||
GUILD_NEWS = 5
|
|
||||||
GUILD_STORE = 6
|
|
||||||
|
|
||||||
|
|
||||||
class InteractionResponseType:
|
|
||||||
PONG = 0
|
|
||||||
ACKNOWLEDGE = 1
|
|
||||||
CHANNEL_MESSAGE = 2
|
|
||||||
CHANNEL_MESSAGE_WITH_SOURCE = 4
|
|
||||||
ACKNOWLEDGE_WITH_SOURCE = 5
|
|
||||||
|
|
||||||
|
|
||||||
class GatewayIntents:
|
|
||||||
GUILDS = bitmask(0)
|
|
||||||
GUILD_MEMBERS = bitmask(1)
|
|
||||||
GUILD_BANS = bitmask(2)
|
|
||||||
GUILD_EMOJIS = bitmask(3)
|
|
||||||
GUILD_INTEGRATIONS = bitmask(4)
|
|
||||||
GUILD_WEBHOOKS = bitmask(5)
|
|
||||||
GUILD_INVITES = bitmask(6)
|
|
||||||
GUILD_VOICE_STATES = bitmask(7)
|
|
||||||
GUILD_PRESENCES = bitmask(8)
|
|
||||||
GUILD_MESSAGES = bitmask(9)
|
|
||||||
GUILD_MESSAGE_REACTIONS = bitmask(10)
|
|
||||||
GUILD_MESSAGE_TYPING = bitmask(11)
|
|
||||||
DIRECT_MESSAGES = bitmask(12)
|
|
||||||
DIRECT_MESSAGE_REACTIONS = bitmask(13)
|
|
||||||
DIRECT_MESSAGE_TYPING = bitmask(14)
|
|
||||||
|
|
||||||
|
|
||||||
class GatewayOpCodes:
|
|
||||||
DISPATCH = 0
|
|
||||||
HEARTBEAT = 1
|
|
||||||
IDENTIFY = 2
|
|
||||||
PRESENCE_UPDATE = 3
|
|
||||||
VOICE_STATE_UPDATE = 4
|
|
||||||
RESUME = 6
|
|
||||||
RECONNECT = 7
|
|
||||||
REQUEST_GUILD_MEMBERS = 8
|
|
||||||
INVALID_SESSION = 9
|
|
||||||
HELLO = 10
|
|
||||||
HEARTBEAT_ACK = 11
|
|
||||||
|
|
||||||
|
|
||||||
class Payloads:
|
|
||||||
def __init__(self, token: str) -> None:
|
|
||||||
self.seq = self.session = None
|
|
||||||
self.token = token
|
|
||||||
|
|
||||||
def HEARTBEAT(self) -> dict:
|
|
||||||
return {"op": GatewayOpCodes.HEARTBEAT, "d": self.seq}
|
|
||||||
|
|
||||||
def IDENTIFY(self) -> dict:
|
|
||||||
return {
|
|
||||||
"op": GatewayOpCodes.IDENTIFY,
|
|
||||||
"d": {
|
|
||||||
"token": self.token,
|
|
||||||
"intents": GatewayIntents.GUILDS
|
|
||||||
| GatewayIntents.GUILD_EMOJIS
|
|
||||||
| GatewayIntents.GUILD_MEMBERS
|
|
||||||
| GatewayIntents.GUILD_MESSAGES
|
|
||||||
| GatewayIntents.GUILD_MESSAGE_TYPING
|
|
||||||
| GatewayIntents.GUILD_PRESENCES,
|
|
||||||
"properties": {
|
|
||||||
"$os": "discord",
|
|
||||||
"$browser": "Discord Client",
|
|
||||||
"$device": "discord",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def RESUME(self) -> dict:
|
|
||||||
return {
|
|
||||||
"op": GatewayOpCodes.RESUME,
|
|
||||||
"d": {
|
|
||||||
"token": self.token,
|
|
||||||
"session_id": self.session,
|
|
||||||
"seq": self.seq,
|
|
||||||
},
|
|
||||||
}
|
|
|
@ -1,5 +1,5 @@
|
||||||
class RequestError(Exception):
|
class RequestError(Exception):
|
||||||
def __init__(self, status: int, *args):
|
def __init__(self, status: int, *args):
|
||||||
super().__init__(*args)
|
super().__init__(*args)
|
||||||
|
|
||||||
self.status = status
|
self.status = status
|
||||||
|
|
|
@ -12,249 +12,249 @@ from misc import dict_cls, log_except, request
|
||||||
|
|
||||||
|
|
||||||
class Gateway:
|
class Gateway:
|
||||||
def __init__(self, http: urllib3.PoolManager, token: str):
|
def __init__(self, http: urllib3.PoolManager, token: str):
|
||||||
self.http = http
|
self.http = http
|
||||||
self.token = token
|
self.token = token
|
||||||
self.logger = logging.getLogger("discord")
|
self.logger = logging.getLogger("discord")
|
||||||
self.Payloads = discord.Payloads(self.token)
|
self.Payloads = discord.Payloads(self.token)
|
||||||
self.websocket = None
|
self.websocket = None
|
||||||
|
|
||||||
@log_except
|
@log_except
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
self.heartbeat_task: asyncio.Future = None
|
self.heartbeat_task: asyncio.Future = None
|
||||||
self.resume = False
|
self.resume = False
|
||||||
|
|
||||||
gateway_url = self.get_gateway_url()
|
gateway_url = self.get_gateway_url()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
await self.gateway_handler(gateway_url)
|
await self.gateway_handler(gateway_url)
|
||||||
except (
|
except (
|
||||||
websockets.ConnectionClosedError,
|
websockets.ConnectionClosedError,
|
||||||
websockets.InvalidMessage,
|
websockets.InvalidMessage,
|
||||||
):
|
):
|
||||||
self.logger.exception("Connection lost, reconnecting.")
|
self.logger.exception("Connection lost, reconnecting.")
|
||||||
|
|
||||||
# Stop sending heartbeats until we reconnect.
|
# Stop sending heartbeats until we reconnect.
|
||||||
if self.heartbeat_task and not self.heartbeat_task.cancelled():
|
if self.heartbeat_task and not self.heartbeat_task.cancelled():
|
||||||
self.heartbeat_task.cancel()
|
self.heartbeat_task.cancel()
|
||||||
|
|
||||||
def get_gateway_url(self) -> str:
|
def get_gateway_url(self) -> str:
|
||||||
resp = self.send("GET", "/gateway")
|
resp = self.send("GET", "/gateway")
|
||||||
|
|
||||||
return resp["url"]
|
return resp["url"]
|
||||||
|
|
||||||
async def heartbeat_handler(self, interval_ms: int) -> None:
|
async def heartbeat_handler(self, interval_ms: int) -> None:
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(interval_ms / 1000)
|
await asyncio.sleep(interval_ms / 1000)
|
||||||
await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT()))
|
await self.websocket.send(json.dumps(self.Payloads.HEARTBEAT()))
|
||||||
|
|
||||||
async def handle_resp(self, data: dict) -> None:
|
async def handle_resp(self, data: dict) -> None:
|
||||||
data_dict = data["d"]
|
data_dict = data["d"]
|
||||||
|
|
||||||
opcode = data["op"]
|
opcode = data["op"]
|
||||||
|
|
||||||
seq = data["s"]
|
seq = data["s"]
|
||||||
|
|
||||||
if seq:
|
if seq:
|
||||||
self.Payloads.seq = seq
|
self.Payloads.seq = seq
|
||||||
|
|
||||||
if opcode == discord.GatewayOpCodes.DISPATCH:
|
if opcode == discord.GatewayOpCodes.DISPATCH:
|
||||||
otype = data["t"]
|
otype = data["t"]
|
||||||
|
|
||||||
if otype == "READY":
|
if otype == "READY":
|
||||||
self.Payloads.session = data_dict["session_id"]
|
self.Payloads.session = data_dict["session_id"]
|
||||||
|
|
||||||
self.logger.info("READY")
|
self.logger.info("READY")
|
||||||
else:
|
else:
|
||||||
self.handle_otype(data_dict, otype)
|
self.handle_otype(data_dict, otype)
|
||||||
elif opcode == discord.GatewayOpCodes.HELLO:
|
elif opcode == discord.GatewayOpCodes.HELLO:
|
||||||
heartbeat_interval = data_dict.get("heartbeat_interval")
|
heartbeat_interval = data_dict.get("heartbeat_interval")
|
||||||
|
|
||||||
self.logger.info(f"Heartbeat Interval: {heartbeat_interval}")
|
self.logger.info(f"Heartbeat Interval: {heartbeat_interval}")
|
||||||
|
|
||||||
# Send periodic hearbeats to gateway.
|
# Send periodic hearbeats to gateway.
|
||||||
self.heartbeat_task = asyncio.ensure_future(
|
self.heartbeat_task = asyncio.ensure_future(
|
||||||
self.heartbeat_handler(heartbeat_interval)
|
self.heartbeat_handler(heartbeat_interval)
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.websocket.send(
|
await self.websocket.send(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
self.Payloads.RESUME()
|
self.Payloads.RESUME()
|
||||||
if self.resume
|
if self.resume
|
||||||
else self.Payloads.IDENTIFY()
|
else self.Payloads.IDENTIFY()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif opcode == discord.GatewayOpCodes.RECONNECT:
|
elif opcode == discord.GatewayOpCodes.RECONNECT:
|
||||||
self.logger.info("Received RECONNECT.")
|
self.logger.info("Received RECONNECT.")
|
||||||
|
|
||||||
self.resume = True
|
self.resume = True
|
||||||
await self.websocket.close()
|
await self.websocket.close()
|
||||||
elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
|
elif opcode == discord.GatewayOpCodes.INVALID_SESSION:
|
||||||
self.logger.info("Received INVALID_SESSION.")
|
self.logger.info("Received INVALID_SESSION.")
|
||||||
|
|
||||||
self.resume = False
|
self.resume = False
|
||||||
await self.websocket.close()
|
await self.websocket.close()
|
||||||
elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
|
elif opcode == discord.GatewayOpCodes.HEARTBEAT_ACK:
|
||||||
# NOP
|
# NOP
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}"
|
"Unknown OP code: {opcode}\n{json.dumps(data, indent=4)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle_otype(self, data: dict, otype: str) -> None:
|
def handle_otype(self, data: dict, otype: str) -> None:
|
||||||
if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"):
|
if otype in ("MESSAGE_CREATE", "MESSAGE_UPDATE", "MESSAGE_DELETE"):
|
||||||
obj = discord.Message(data)
|
obj = discord.Message(data)
|
||||||
elif otype == "TYPING_START":
|
elif otype == "TYPING_START":
|
||||||
obj = dict_cls(data, discord.Typing)
|
obj = dict_cls(data, discord.Typing)
|
||||||
elif otype == "GUILD_CREATE":
|
elif otype == "GUILD_CREATE":
|
||||||
obj = discord.Guild(data)
|
obj = discord.Guild(data)
|
||||||
elif otype == "GUILD_MEMBER_UPDATE":
|
elif otype == "GUILD_MEMBER_UPDATE":
|
||||||
obj = discord.GuildMemberUpdate(data)
|
obj = discord.GuildMemberUpdate(data)
|
||||||
elif otype == "GUILD_EMOJIS_UPDATE":
|
elif otype == "GUILD_EMOJIS_UPDATE":
|
||||||
obj = discord.GuildEmojisUpdate(data)
|
obj = discord.GuildEmojisUpdate(data)
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
func = getattr(self, f"on_{otype.lower()}", None)
|
func = getattr(self, f"on_{otype.lower()}", None)
|
||||||
|
|
||||||
if not func:
|
if not func:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
f"Function '{func}' not defined, ignoring message."
|
f"Function '{func}' not defined, ignoring message."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
func(obj)
|
func(obj)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.logger.exception(f"Ignoring exception in '{func.__name__}':")
|
self.logger.exception(f"Ignoring exception in '{func.__name__}':")
|
||||||
|
|
||||||
async def gateway_handler(self, gateway_url: str) -> None:
|
async def gateway_handler(self, gateway_url: str) -> None:
|
||||||
async with websockets.connect(
|
async with websockets.connect(
|
||||||
f"{gateway_url}/?v=8&encoding=json"
|
f"{gateway_url}/?v=8&encoding=json"
|
||||||
) as websocket:
|
) as websocket:
|
||||||
self.websocket = websocket
|
self.websocket = websocket
|
||||||
|
|
||||||
async for message in websocket:
|
async for message in websocket:
|
||||||
await self.handle_resp(json.loads(message))
|
await self.handle_resp(json.loads(message))
|
||||||
|
|
||||||
def get_channel(self, channel_id: str) -> discord.Channel:
|
def get_channel(self, channel_id: str) -> discord.Channel:
|
||||||
"""
|
"""
|
||||||
Get the channel for a given channel ID.
|
Get the channel for a given channel ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
resp = self.send("GET", f"/channels/{channel_id}")
|
resp = self.send("GET", f"/channels/{channel_id}")
|
||||||
|
|
||||||
return dict_cls(resp, discord.Channel)
|
return dict_cls(resp, discord.Channel)
|
||||||
|
|
||||||
def get_channels(self, guild_id: str) -> Dict[str, discord.Channel]:
|
def get_channels(self, guild_id: str) -> Dict[str, discord.Channel]:
|
||||||
"""
|
"""
|
||||||
Get all channels for a given guild ID.
|
Get all channels for a given guild ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
resp = self.send("GET", f"/guilds/{guild_id}/channels")
|
resp = self.send("GET", f"/guilds/{guild_id}/channels")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
channel["id"]: dict_cls(channel, discord.Channel)
|
channel["id"]: dict_cls(channel, discord.Channel)
|
||||||
for channel in resp
|
for channel in resp
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_emotes(self, guild_id: str) -> List[discord.Emote]:
|
def get_emotes(self, guild_id: str) -> List[discord.Emote]:
|
||||||
"""
|
"""
|
||||||
Get all the emotes for a given guild.
|
Get all the emotes for a given guild.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
resp = self.send("GET", f"/guilds/{guild_id}/emojis")
|
resp = self.send("GET", f"/guilds/{guild_id}/emojis")
|
||||||
|
|
||||||
return [dict_cls(emote, discord.Emote) for emote in resp]
|
return [dict_cls(emote, discord.Emote) for emote in resp]
|
||||||
|
|
||||||
def get_members(self, guild_id: str) -> List[discord.User]:
|
def get_members(self, guild_id: str) -> List[discord.User]:
|
||||||
"""
|
"""
|
||||||
Get all the members for a given guild.
|
Get all the members for a given guild.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
resp = self.send(
|
resp = self.send(
|
||||||
"GET", f"/guilds/{guild_id}/members", params={"limit": 1000}
|
"GET", f"/guilds/{guild_id}/members", params={"limit": 1000}
|
||||||
)
|
)
|
||||||
|
|
||||||
return [discord.User(member["user"]) for member in resp]
|
return [discord.User(member["user"]) for member in resp]
|
||||||
|
|
||||||
def create_webhook(self, channel_id: str, name: str) -> discord.Webhook:
|
def create_webhook(self, channel_id: str, name: str) -> discord.Webhook:
|
||||||
"""
|
"""
|
||||||
Create a webhook with the specified name in a given channel.
|
Create a webhook with the specified name in a given channel.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
resp = self.send(
|
resp = self.send(
|
||||||
"POST", f"/channels/{channel_id}/webhooks", {"name": name}
|
"POST", f"/channels/{channel_id}/webhooks", {"name": name}
|
||||||
)
|
)
|
||||||
|
|
||||||
return dict_cls(resp, discord.Webhook)
|
return dict_cls(resp, discord.Webhook)
|
||||||
|
|
||||||
def edit_webhook(
|
def edit_webhook(
|
||||||
self, content: str, message_id: str, webhook: discord.Webhook
|
self, content: str, message_id: str, webhook: discord.Webhook
|
||||||
) -> None:
|
) -> None:
|
||||||
self.send(
|
self.send(
|
||||||
"PATCH",
|
"PATCH",
|
||||||
f"/webhooks/{webhook.id}/{webhook.token}/messages/"
|
f"/webhooks/{webhook.id}/{webhook.token}/messages/"
|
||||||
f"{message_id}",
|
f"{message_id}",
|
||||||
{"content": content},
|
{"content": content},
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_webhook(
|
def delete_webhook(
|
||||||
self, message_id: str, webhook: discord.Webhook
|
self, message_id: str, webhook: discord.Webhook
|
||||||
) -> None:
|
) -> None:
|
||||||
self.send(
|
self.send(
|
||||||
"DELETE",
|
"DELETE",
|
||||||
f"/webhooks/{webhook.id}/{webhook.token}/messages/"
|
f"/webhooks/{webhook.id}/{webhook.token}/messages/"
|
||||||
f"{message_id}",
|
f"{message_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_webhook(
|
def send_webhook(
|
||||||
self,
|
self,
|
||||||
webhook: discord.Webhook,
|
webhook: discord.Webhook,
|
||||||
avatar_url: str,
|
avatar_url: str,
|
||||||
content: str,
|
content: str,
|
||||||
username: str,
|
username: str,
|
||||||
) -> discord.Message:
|
) -> discord.Message:
|
||||||
payload = {
|
payload = {
|
||||||
"avatar_url": avatar_url,
|
"avatar_url": avatar_url,
|
||||||
"content": content,
|
"content": content,
|
||||||
"username": username,
|
"username": username,
|
||||||
# Disable 'everyone' and 'role' mentions.
|
# Disable 'everyone' and 'role' mentions.
|
||||||
"allowed_mentions": {"parse": ["users"]},
|
"allowed_mentions": {"parse": ["users"]},
|
||||||
}
|
}
|
||||||
|
|
||||||
resp = self.send(
|
resp = self.send(
|
||||||
"POST",
|
"POST",
|
||||||
f"/webhooks/{webhook.id}/{webhook.token}",
|
f"/webhooks/{webhook.id}/{webhook.token}",
|
||||||
payload,
|
payload,
|
||||||
{"wait": True},
|
{"wait": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
return discord.Message(resp)
|
return discord.Message(resp)
|
||||||
|
|
||||||
def send_message(self, message: str, channel_id: str) -> None:
|
def send_message(self, message: str, channel_id: str) -> None:
|
||||||
self.send(
|
self.send(
|
||||||
"POST", f"/channels/{channel_id}/messages", {"content": message}
|
"POST", f"/channels/{channel_id}/messages", {"content": message}
|
||||||
)
|
)
|
||||||
|
|
||||||
@request
|
@request
|
||||||
def send(
|
def send(
|
||||||
self, method: str, path: str, content: dict = {}, params: dict = {}
|
self, method: str, path: str, content: dict = {}, params: dict = {}
|
||||||
) -> dict:
|
) -> dict:
|
||||||
endpoint = (
|
endpoint = (
|
||||||
f"https://discord.com/api/v8{path}?"
|
f"https://discord.com/api/v8{path}?"
|
||||||
f"{urllib.parse.urlencode(params)}"
|
f"{urllib.parse.urlencode(params)}"
|
||||||
)
|
)
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bot {self.token}",
|
"Authorization": f"Bot {self.token}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 'body' being an empty dict breaks "GET" requests.
|
# 'body' being an empty dict breaks "GET" requests.
|
||||||
payload = json.dumps(content) if content else None
|
payload = json.dumps(content) if content else None
|
||||||
|
|
||||||
return self.http.request(
|
return self.http.request(
|
||||||
method, endpoint, body=payload, headers=headers
|
method, endpoint, body=payload, headers=headers
|
||||||
)
|
)
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -3,26 +3,26 @@ from dataclasses import dataclass
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class User:
|
class User:
|
||||||
avatar_url: str = ""
|
avatar_url: str = ""
|
||||||
display_name: str = ""
|
display_name: str = ""
|
||||||
|
|
||||||
|
|
||||||
class Event:
|
class Event:
|
||||||
def __init__(self, event: dict):
|
def __init__(self, event: dict):
|
||||||
content = event.get("content", {})
|
content = event.get("content", {})
|
||||||
|
|
||||||
self.attachment = content.get("url")
|
self.attachment = content.get("url")
|
||||||
self.body = content.get("body", "").strip()
|
self.body = content.get("body", "").strip()
|
||||||
self.formatted_body = content.get("formatted_body", "")
|
self.formatted_body = content.get("formatted_body", "")
|
||||||
self.id = event["event_id"]
|
self.id = event["event_id"]
|
||||||
self.is_direct = content.get("is_direct", False)
|
self.is_direct = content.get("is_direct", False)
|
||||||
self.redacts = event.get("redacts", "")
|
self.redacts = event.get("redacts", "")
|
||||||
self.room_id = event["room_id"]
|
self.room_id = event["room_id"]
|
||||||
self.sender = event["sender"]
|
self.sender = event["sender"]
|
||||||
self.state_key = event.get("state_key", "")
|
self.state_key = event.get("state_key", "")
|
||||||
|
|
||||||
rel = content.get("m.relates_to", {})
|
rel = content.get("m.relates_to", {})
|
||||||
|
|
||||||
self.relates_to = rel.get("event_id")
|
self.relates_to = rel.get("event_id")
|
||||||
self.reltype = rel.get("rel_type")
|
self.reltype = rel.get("rel_type")
|
||||||
self.new_body = content.get("m.new_content", {}).get("body", "")
|
self.new_body = content.get("m.new_content", {}).get("body", "")
|
||||||
|
|
|
@ -8,77 +8,77 @@ from errors import RequestError
|
||||||
|
|
||||||
|
|
||||||
def dict_cls(d: dict, cls: Any) -> Any:
|
def dict_cls(d: dict, cls: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
Create a dataclass from a dictionary.
|
Create a dataclass from a dictionary.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
field_names = set(f.name for f in fields(cls))
|
field_names = set(f.name for f in fields(cls))
|
||||||
filtered_dict = {k: v for k, v in d.items() if k in field_names}
|
filtered_dict = {k: v for k, v in d.items() if k in field_names}
|
||||||
|
|
||||||
return cls(**filtered_dict)
|
return cls(**filtered_dict)
|
||||||
|
|
||||||
|
|
||||||
def log_except(fn):
|
def log_except(fn):
|
||||||
"""
|
"""
|
||||||
Log unhandled exceptions to a logger instead of `stderr`.
|
Log unhandled exceptions to a logger instead of `stderr`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(self, *args, **kwargs):
|
def wrapper(self, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
return fn(self, *args, **kwargs)
|
return fn(self, *args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.logger.exception(f"Exception in '{fn.__name__}':")
|
self.logger.exception(f"Exception in '{fn.__name__}':")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def request(fn):
|
def request(fn):
|
||||||
"""
|
"""
|
||||||
Either return json data or raise a `RequestError` if the request was
|
Either return json data or raise a `RequestError` if the request was
|
||||||
unsuccessful.
|
unsuccessful.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
resp = fn(*args, **kwargs)
|
resp = fn(*args, **kwargs)
|
||||||
except urllib3.exceptions.HTTPError as e:
|
except urllib3.exceptions.HTTPError as e:
|
||||||
raise RequestError(None, f"Failed to connect: {e}") from None
|
raise RequestError(None, f"Failed to connect: {e}") from None
|
||||||
|
|
||||||
if resp.status < 200 or resp.status >= 300:
|
if resp.status < 200 or resp.status >= 300:
|
||||||
raise RequestError(
|
raise RequestError(
|
||||||
resp.status,
|
resp.status,
|
||||||
f"Failed to get response from '{resp.geturl()}':\n{resp.data}",
|
f"Failed to get response from '{resp.geturl()}':\n{resp.data}",
|
||||||
)
|
)
|
||||||
|
|
||||||
return {} if resp.status == 204 else json.loads(resp.data)
|
return {} if resp.status == 204 else json.loads(resp.data)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def except_deleted(fn):
|
def except_deleted(fn):
|
||||||
"""
|
"""
|
||||||
Ignore the `RequestError` on 404s, the content might have been removed.
|
Ignore the `RequestError` on 404s, the content might have been removed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
try:
|
try:
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
except RequestError as e:
|
except RequestError as e:
|
||||||
if e.status != 404:
|
if e.status != 404:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def hash_str(string: str) -> int:
|
def hash_str(string: str) -> int:
|
||||||
"""
|
"""
|
||||||
Create the hash for a string
|
Create the hash for a string
|
||||||
"""
|
"""
|
||||||
|
|
||||||
hash = 5381
|
hash = 5381
|
||||||
|
|
||||||
for ch in string:
|
for ch in string:
|
||||||
hash = ((hash << 5) + hash) + ord(ch)
|
hash = ((hash << 5) + hash) + ord(ch)
|
||||||
|
|
||||||
return hash & 0xFFFFFFFF
|
return hash & 0xFFFFFFFF
|
||||||
|
|
Loading…
Reference in a new issue