feat: add register, avatar and nick apis

also refactored a little requests
This commit is contained in:
əlemi 2024-01-29 22:01:41 +01:00
parent 075bc52cbe
commit 22d72d631e
Signed by: alemi
GPG key ID: A4895B84D311642C
2 changed files with 108 additions and 44 deletions

View file

@ -7,19 +7,21 @@ from typing import Callable, Awaitable
from aiohttp import ClientSession, web
from .matrix import Event, EventType
from .utils import mx_message
from .utils import mx_message, fmt_mxid
class AppService:
_site: web.TCPSite
_client: ClientSession
_app: web.Application
_callbacks: dict[EventType, dict[str, Callable]]
_CLIENT_API_ROOT: str = "/_matrix/client/r0"
_MEDIA_API_ROOT: str = "/_matrix/media/r0"
as_token: str
hs_token: str
base_url: str
homeserver: str
user_id: str
server_name: str
use_http: bool
logger: logging.Logger
@ -27,9 +29,9 @@ class AppService:
self,
as_token: str,
hs_token: str,
base_url: str,
homeserver: str,
user_id: str,
server_name: str,
use_http: bool = False,
logger: logging.Logger | None = None,
):
self._app = web.Application()
@ -38,9 +40,8 @@ class AppService:
self.as_token = as_token
self.hs_token = hs_token
self.base_url = base_url
self.homeserver = homeserver
self.user_id = user_id
self.server_name = server_name
self.logger = logger if logger is not None else logging.getLogger(__file__)
@ -76,62 +77,114 @@ class AppService:
return func
return wrapper
@property
def client_api(self) -> str:
if self.use_http:
return "http://" + self.homeserver + self._CLIENT_API_ROOT
else:
return "https://" + self.homeserver + self._CLIENT_API_ROOT
@property
def media_api(self) -> str:
if self.use_http:
return "http://" + self.homeserver + self._MEDIA_API_ROOT
else:
return "https://" + self.homeserver + self._MEDIA_API_ROOT
@property
def api_headers(self) -> dict[str, str]:
return {
"Authorization": f"Bearer {self.as_token}",
"Content-Type": "application/json",
}
async def register_mxid(self, mxid: str) -> None:
bare_mxid = fmt_mxid(mxid, full=False)
async with self._client.request(
method="POST",
url=f"{self.client_api}/register",
headers=self.api_headers,
json={
"type":"m.login.application_service",
"username": bare_mxid
}) as res:
res.raise_for_status()
doc = await res.json()
self.logger.debug("registered mxid %s", bare_mxid)
return doc["user_id"]
async def set_avatar(self, mxid: str, avatar_url: str) -> None:
bare_mxid = fmt_mxid(mxid, full=False)
async with self._client.get(avatar_url) as res:
res.raise_for_status()
async with self._client.request(
method="POST",
url=f"{self.media_api}/upload",
headers={
"Authorization": f"Bearer {self.as_token}",
"Content-Type": res.content_type,
},
chunked=res.content.read(),
params={"filename":str(uuid.uuid4())},
) as res:
res.raise_for_status()
doc = await res.json()
avatar_uri = doc["content_uri"]
async with self._client.request(
method="PUT",
url=f"{self.client_api}/profile/{bare_mxid}/avatar_url",
headers=self.api_headers,
json={"avatar_url": avatar_uri},
params={"user_id": mxid},
) as res:
res.raise_for_status()
self.logger.debug("updated avatar of %s to %s", mxid, avatar_url)
async def set_nick(self, mxid: str, nick: str) -> None:
bare_mxid = fmt_mxid(mxid, full=False)
async with self._client.request(
method="PUT",
url=f"{self.client_api}/profile/{bare_mxid}/displayname",
headers=self.api_headers,
json={"displayname": nick},
params={"user_id": mxid},
) as res:
res.raise_for_status()
self.logger.debug("updated nick of %s to %s", mxid, nick)
async def invite_to_room(self, room: str, mxid: str) -> None:
async with self._client.request(
method="POST",
url=f"{self.base_url}/_matrix/client/r0/rooms/{room}/invite",
headers={
"Authorization": f"Bearer {self.as_token}",
"Content-Type": "application/json",
},
url=f"{self.client_api}/rooms/{room}/invite",
headers=self.api_headers,
params={"user_id": mxid},
) as res:
if res.ok:
self.logger.debug("inviting to room %s with %s : %s", room, mxid, await res.json())
else:
self.logger.error("failed inviting to room: %s", await res.text())
res.raise_for_status()
self.logger.debug("invited %s to room %s : %s", mxid, room, await res.json())
async def join_room(self, room: str, mxid: str | None = None):
async with self._client.request(
method="POST",
url=f"{self.base_url}/_matrix/client/r0/join/{room}",
headers={
"Authorization": f"Bearer {self.as_token}",
"Content-Type": "application/json",
},
url=f"{self.client_api}/join/{room}",
headers=self.api_headers,
params={"user_id": mxid} if mxid else {},
) as res:
if res.ok:
self.logger.debug("joined room %s with %s : %s", room, mxid, await res.json())
else:
self.logger.error("failed joining room: %s", await res.text())
res.raise_for_status()
async def leave_room(self):
raise NotImplementedError
self.logger.debug("joined room %s with %s : %s", room, mxid, await res.json())
async def send_message(self, room: str, text: str, mxid: str | None = None) -> str:
async with self._client.request(
method="PUT",
url=f"{self.base_url}/_matrix/client/r0/rooms/{room}/send/m.room.message/{uuid.uuid4()}",
headers={
"Authorization": f"Bearer {self.as_token}",
"Content-Type": "application/json",
},
url=f"{self.client_api}/rooms/{room}/send/m.room.message/{uuid.uuid4()}",
headers=self.api_headers,
params={"user_id": mxid} if mxid else {},
json=mx_message(text),
) as res:
if res.ok:
res.raise_for_status()
doc = await res.json()
self.logger.debug("sent message %s to %s as %s : %s", text, room, mxid, doc)
return doc["event_id"]
else:
text = await res.text()
self.logger.error("failed sending message: %s", text)
res.raise_for_status()
return ""
async def redact_message(self):
raise NotImplementedError

View file

@ -2,6 +2,8 @@ import asyncio
from io import StringIO
from html.parser import HTMLParser
DEFAULT_HOMESERVER = "" # TODO get from env? idk
# thanks [stackoverflow](https://stackoverflow.com/questions/753052/strip-html-from-strings-in-python)
class MLStripper(HTMLParser):
def __init__(self):
@ -34,3 +36,12 @@ def mx_message(text: str) -> dict:
"format": "org.matrix.custom.html",
"formatted_body": text,
}
def fmt_mxid(txt:str, full:bool = True, homeserver:str = DEFAULT_HOMESERVER) -> str:
bare_id = txt.split(":", 1)[0]
if bare_id.startswith("@"):
bare_id = bare_id[1:]
if full:
return f"@{bare_id}:{homeserver}"
else:
return bare_id