121 lines
3.3 KiB
Python
121 lines
3.3 KiB
Python
import os
|
|
import sqlite3
|
|
import threading
|
|
from typing import List
|
|
|
|
|
|
class DataBase(object):
|
|
def __init__(self, db_file) -> None:
|
|
self.create(db_file)
|
|
|
|
# The database is accessed via both the threads.
|
|
self.lock = threading.Lock()
|
|
|
|
def create(self, db_file) -> None:
|
|
"""
|
|
Create a database with the relevant tables if it doesn't already exist.
|
|
"""
|
|
|
|
exists = os.path.exists(db_file)
|
|
|
|
self.conn = sqlite3.connect(db_file, check_same_thread=False)
|
|
self.conn.row_factory = self.dict_factory
|
|
|
|
self.cur = self.conn.cursor()
|
|
|
|
if exists:
|
|
return
|
|
|
|
self.cur.execute(
|
|
"CREATE TABLE bridge(room_id TEXT PRIMARY KEY, channel_id TEXT);"
|
|
)
|
|
|
|
self.cur.execute(
|
|
"CREATE TABLE users(mxid TEXT PRIMARY KEY, "
|
|
"avatar_url TEXT, username TEXT);"
|
|
)
|
|
|
|
self.conn.commit()
|
|
|
|
def dict_factory(self, cursor, row):
|
|
"""
|
|
https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.row_factory
|
|
"""
|
|
|
|
d = {}
|
|
for idx, col in enumerate(cursor.description):
|
|
d[col[0]] = row[idx]
|
|
return d
|
|
|
|
def add_room(self, room_id: str, channel_id: str) -> None:
|
|
"""
|
|
Add a bridged room to the database.
|
|
"""
|
|
|
|
with self.lock:
|
|
self.cur.execute(
|
|
"INSERT INTO bridge (room_id, channel_id) "
|
|
f"VALUES ('{room_id}', '{channel_id}')"
|
|
)
|
|
self.conn.commit()
|
|
|
|
def add_user(self, mxid: str) -> None:
|
|
with self.lock:
|
|
self.cur.execute(f"INSERT INTO users (mxid) VALUES ('{mxid}')")
|
|
self.conn.commit()
|
|
|
|
def add_avatar(self, avatar_url: str, mxid: str) -> None:
|
|
with self.lock:
|
|
self.cur.execute(
|
|
f"UPDATE users SET avatar_url = '{avatar_url}'"
|
|
f"WHERE mxid = '{mxid}'"
|
|
)
|
|
self.conn.commit()
|
|
|
|
def add_username(self, username: str, mxid: str) -> None:
|
|
with self.lock:
|
|
self.cur.execute(
|
|
f"UPDATE users SET username = '{username}'"
|
|
f"WHERE mxid = '{mxid}'"
|
|
)
|
|
self.conn.commit()
|
|
|
|
def get_channel(self, room_id: str) -> str:
|
|
"""
|
|
Get the corresponding channel ID for a given room ID.
|
|
"""
|
|
|
|
with self.lock:
|
|
self.cur.execute(
|
|
"SELECT channel_id FROM bridge WHERE room_id = ?", [room_id]
|
|
)
|
|
|
|
room = self.cur.fetchone()
|
|
|
|
# Return an empty string if nothing is bridged.
|
|
return "" if not room else room["channel_id"]
|
|
|
|
def list_channels(self) -> List[str]:
|
|
"""
|
|
Get a list of all the bridged channels.
|
|
"""
|
|
|
|
with self.lock:
|
|
self.cur.execute("SELECT channel_id FROM bridge")
|
|
|
|
channels = self.cur.fetchall()
|
|
|
|
return [channel["channel_id"] for channel in channels]
|
|
|
|
def fetch_user(self, mxid: str) -> dict:
|
|
"""
|
|
Fetch the profile for a bridged user.
|
|
"""
|
|
|
|
with self.lock:
|
|
self.cur.execute("SELECT * FROM users")
|
|
users = self.cur.fetchall()
|
|
|
|
user = [user for user in users if user["mxid"] == mxid]
|
|
|
|
return user[0] if user else {}
|