diff --git a/src/aioappsrv/app.py b/src/aioappsrv/app.py index 9ecc4a3..12c0bc7 100644 --- a/src/aioappsrv/app.py +++ b/src/aioappsrv/app.py @@ -7,19 +7,21 @@ from typing import Callable, Awaitable from aiohttp import ClientSession, web from .matrix import Event, EventType -from .utils import mx_message +from .utils import mx_message, fmt_mxid class AppService: _site: web.TCPSite _client: ClientSession _app: web.Application _callbacks: dict[EventType, dict[str, Callable]] + _CLIENT_API_ROOT: str = "/_matrix/client/r0" + _MEDIA_API_ROOT: str = "/_matrix/media/r0" as_token: str hs_token: str - base_url: str + homeserver: str user_id: str - server_name: str + use_http: bool logger: logging.Logger @@ -27,9 +29,9 @@ class AppService: self, as_token: str, hs_token: str, - base_url: str, + homeserver: str, user_id: str, - server_name: str, + use_http: bool = False, logger: logging.Logger | None = None, ): self._app = web.Application() @@ -38,9 +40,8 @@ class AppService: self.as_token = as_token self.hs_token = hs_token - self.base_url = base_url + self.homeserver = homeserver self.user_id = user_id - self.server_name = server_name self.logger = logger if logger is not None else logging.getLogger(__file__) @@ -76,62 +77,114 @@ class AppService: return func return wrapper + @property + def client_api(self) -> str: + if self.use_http: + return "http://" + self.homeserver + self._CLIENT_API_ROOT + else: + return "https://" + self.homeserver + self._CLIENT_API_ROOT + + @property + def media_api(self) -> str: + if self.use_http: + return "http://" + self.homeserver + self._MEDIA_API_ROOT + else: + return "https://" + self.homeserver + self._MEDIA_API_ROOT + + @property + def api_headers(self) -> dict[str, str]: + return { + "Authorization": f"Bearer {self.as_token}", + "Content-Type": "application/json", + } + + async def register_mxid(self, mxid: str) -> None: + bare_mxid = fmt_mxid(mxid, full=False) + async with self._client.request( + method="POST", + url=f"{self.client_api}/register", + headers=self.api_headers, + json={ + "type":"m.login.application_service", + "username": bare_mxid + }) as res: + res.raise_for_status() + doc = await res.json() + self.logger.debug("registered mxid %s", bare_mxid) + return doc["user_id"] + + async def set_avatar(self, mxid: str, avatar_url: str) -> None: + bare_mxid = fmt_mxid(mxid, full=False) + async with self._client.get(avatar_url) as res: + res.raise_for_status() + async with self._client.request( + method="POST", + url=f"{self.media_api}/upload", + headers={ + "Authorization": f"Bearer {self.as_token}", + "Content-Type": res.content_type, + }, + chunked=res.content.read(), + params={"filename":str(uuid.uuid4())}, + ) as res: + res.raise_for_status() + doc = await res.json() + avatar_uri = doc["content_uri"] + async with self._client.request( + method="PUT", + url=f"{self.client_api}/profile/{bare_mxid}/avatar_url", + headers=self.api_headers, + json={"avatar_url": avatar_uri}, + params={"user_id": mxid}, + ) as res: + res.raise_for_status() + self.logger.debug("updated avatar of %s to %s", mxid, avatar_url) + + async def set_nick(self, mxid: str, nick: str) -> None: + bare_mxid = fmt_mxid(mxid, full=False) + async with self._client.request( + method="PUT", + url=f"{self.client_api}/profile/{bare_mxid}/displayname", + headers=self.api_headers, + json={"displayname": nick}, + params={"user_id": mxid}, + ) as res: + res.raise_for_status() + self.logger.debug("updated nick of %s to %s", mxid, nick) + async def invite_to_room(self, room: str, mxid: str) -> None: async with self._client.request( method="POST", - url=f"{self.base_url}/_matrix/client/r0/rooms/{room}/invite", - headers={ - "Authorization": f"Bearer {self.as_token}", - "Content-Type": "application/json", - }, + url=f"{self.client_api}/rooms/{room}/invite", + headers=self.api_headers, params={"user_id": mxid}, ) as res: - if res.ok: - self.logger.debug("inviting to room %s with %s : %s", room, mxid, await res.json()) - else: - self.logger.error("failed inviting to room: %s", await res.text()) - res.raise_for_status() + res.raise_for_status() + self.logger.debug("invited %s to room %s : %s", mxid, room, await res.json()) async def join_room(self, room: str, mxid: str | None = None): async with self._client.request( method="POST", - url=f"{self.base_url}/_matrix/client/r0/join/{room}", - headers={ - "Authorization": f"Bearer {self.as_token}", - "Content-Type": "application/json", - }, + url=f"{self.client_api}/join/{room}", + headers=self.api_headers, params={"user_id": mxid} if mxid else {}, ) as res: - if res.ok: - self.logger.debug("joined room %s with %s : %s", room, mxid, await res.json()) - else: - self.logger.error("failed joining room: %s", await res.text()) - res.raise_for_status() - - async def leave_room(self): - raise NotImplementedError + res.raise_for_status() + self.logger.debug("joined room %s with %s : %s", room, mxid, await res.json()) async def send_message(self, room: str, text: str, mxid: str | None = None) -> str: async with self._client.request( method="PUT", - url=f"{self.base_url}/_matrix/client/r0/rooms/{room}/send/m.room.message/{uuid.uuid4()}", - headers={ - "Authorization": f"Bearer {self.as_token}", - "Content-Type": "application/json", - }, + url=f"{self.client_api}/rooms/{room}/send/m.room.message/{uuid.uuid4()}", + headers=self.api_headers, params={"user_id": mxid} if mxid else {}, json=mx_message(text), ) as res: - if res.ok: - doc = await res.json() - self.logger.debug("sent message %s to %s as %s : %s", text, room, mxid, doc) - return doc["event_id"] - else: - text = await res.text() - self.logger.error("failed sending message: %s", text) - res.raise_for_status() - return "" + res.raise_for_status() + doc = await res.json() + self.logger.debug("sent message %s to %s as %s : %s", text, room, mxid, doc) + return doc["event_id"] async def redact_message(self): raise NotImplementedError diff --git a/src/aioappsrv/utils.py b/src/aioappsrv/utils.py index f8d0cbd..9965b36 100644 --- a/src/aioappsrv/utils.py +++ b/src/aioappsrv/utils.py @@ -2,6 +2,8 @@ import asyncio from io import StringIO from html.parser import HTMLParser +DEFAULT_HOMESERVER = "" # TODO get from env? idk + # thanks [stackoverflow](https://stackoverflow.com/questions/753052/strip-html-from-strings-in-python) class MLStripper(HTMLParser): def __init__(self): @@ -34,3 +36,12 @@ def mx_message(text: str) -> dict: "format": "org.matrix.custom.html", "formatted_body": text, } + +def fmt_mxid(txt:str, full:bool = True, homeserver:str = DEFAULT_HOMESERVER) -> str: + bare_id = txt.split(":", 1)[0] + if bare_id.startswith("@"): + bare_id = bare_id[1:] + if full: + return f"@{bare_id}:{homeserver}" + else: + return bare_id