diff --git a/src/treepuncher/__main__.py b/src/treepuncher/__main__.py index f99c980..7008b0f 100644 --- a/src/treepuncher/__main__.py +++ b/src/treepuncher/__main__.py @@ -10,8 +10,8 @@ import inspect from pathlib import Path from importlib import import_module import traceback -from typing import List, Type, Set, get_type_hints -from dataclasses import dataclass, MISSING, fields +from typing import Type, Set, get_type_hints +from dataclasses import MISSING, fields from setproctitle import setproctitle diff --git a/src/treepuncher/addon.py b/src/treepuncher/addon.py index e0ae08d..2f90eb7 100644 --- a/src/treepuncher/addon.py +++ b/src/treepuncher/addon.py @@ -4,6 +4,8 @@ import logging from typing import TYPE_CHECKING, Dict, Any, Optional, Union, List, Callable, get_type_hints, get_args, get_origin from dataclasses import dataclass, MISSING, fields +from treepuncher.storage import AddonStorage + from .scaffold import ConfigObject if TYPE_CHECKING: @@ -45,6 +47,7 @@ def parse_with_hint(val:str, hint:Any) -> Any: class Addon: name: str config: ConfigObject + storage: AddonStorage logger: logging.Logger _client: 'Treepuncher' @@ -81,6 +84,7 @@ class Addon: else: # not really necessary since it's a dataclass but whatever opts[field.name] = default self.config = self.Options(**opts) + self.storage = client.storage.addon_storage(self.name) self.logger = self._client.logger.getChild(self.name) self.register() diff --git a/src/treepuncher/storage.py b/src/treepuncher/storage.py index 142052e..51e348c 100644 --- a/src/treepuncher/storage.py +++ b/src/treepuncher/storage.py @@ -20,11 +20,31 @@ class AuthenticatorState: token : Dict[str, Any] legacy : bool = False +class AddonStorage: + db: sqlite3.Connection + name: str + + def __init__(self, db:sqlite3.Connection, name:str): + self.db = db + self.name = name + self.db.cursor().execute('CREATE TABLE IF NOT EXISTS documents (name TEXT PRIMARY KEY, value TEXT)') + self.db.commit() + + # fstrings in queries are evil but if you go to this length to fuck up you kinda deserve it :) + def get(self, key:str) -> Optional[Any]: + res = self.db.cursor().execute(f"SELECT * FROM documents_{self.name} WHERE name = ?", (key,)).fetchall() + return json.loads(res[0][1]) + + def put(self, key:str, val:Any) -> None: + cur = self.db.cursor() + cur.execute("DELETE FROM documents WHERE name = ?", (key,)) + cur.execute(f"INSERT INTO documents_{self.name} VALUES (?, ?)", (key, json.dumps(val, default=str),)) + self.db.commit() + class Storage: name : str db : sqlite3.Connection - def __init__(self, name:str): self.name = name init = not os.path.isfile(name) @@ -57,6 +77,9 @@ class Storage: cur.execute('INSERT INTO authenticator VALUES (?, ?, ?)', (state.date.strftime(__DATE_FORMAT__), json.dumps(state.token), state.legacy)) self.db.commit() + def addon_storage(self, name:str) -> AddonStorage: + return AddonStorage(self.db, name) + def system(self) -> Optional[SystemState]: cur = self.db.cursor() val = cur.execute('SELECT * FROM system').fetchall() @@ -78,10 +101,9 @@ class Storage: token=json.loads(val[0][1]), legacy=val[0][2] or False ) - + def get(self, key:str) -> Optional[Any]: - cur = self.db.cursor() - val = cur.execute("SELECT * FROM documents WHERE name = ?", (key,)).fetchall() + val = self.db.cursor().execute("SELECT * FROM documents WHERE name = ?", (key,)).fetchall() return json.loads(val[0][1]) if val else None def put(self, key:str, val:Any) -> None: @@ -90,3 +112,4 @@ class Storage: cur.execute("INSERT INTO documents VALUES (?, ?)", (key, json.dumps(val, default=str))) self.db.commit() +