feat: fixes in the runtime, added block_on

Former-commit-id: 282251232e15fbca4f7d6f591293cfc010bc63de
This commit is contained in:
cschen 2024-08-20 12:06:46 +02:00
parent 974afb98f1
commit 62ed439b41
6 changed files with 67 additions and 99 deletions

View file

@ -28,7 +28,7 @@ package_logger.propagate = False
logger = logging.getLogger(__name__)
TEXT_LISTENER = None
rt.dispatch(inner_logger.listen(), "codemp-logger")
# rt.dispatch(inner_logger.listen(), "codemp-logger")
# Initialisation and Deinitialisation
@ -48,8 +48,6 @@ def plugin_loaded():
def disconnect_client():
global TEXT_LISTENER
# rt.stop_all()
if TEXT_LISTENER is not None:
safe_listener_detach(TEXT_LISTENER)
@ -66,8 +64,6 @@ def plugin_unloaded():
disconnect_client()
rt.stop_loop()
# tm.release(False)
# Listeners
##############################################################################
@ -161,7 +157,7 @@ class CodempClientTextChangeListener(sublime_plugin.TextChangeListener):
vbuff = client.get_buffer(self.buffer.primary_view())
if vbuff is not None:
vbuff.send_buffer_change(changes)
rt.dispatch(vbuff.send_buffer_change(changes))
# Commands:
@ -235,7 +231,10 @@ async def JoinCommand(client: VirtualClient, workspace_id: str, buffer_id: str):
except Exception as e:
raise e
assert vws is not None
if vws is None:
logger.warning("The client returned a void workspace.")
return
vws.materialize()
if buffer_id != "":
@ -244,7 +243,6 @@ async def JoinCommand(client: VirtualClient, workspace_id: str, buffer_id: str):
class CodempJoinCommand(sublime_plugin.WindowCommand):
def run(self, workspace_id, buffer_id):
print(workspace_id, buffer_id)
if buffer_id == "* Don't Join Any":
buffer_id = ""
rt.dispatch(JoinCommand(client, workspace_id, buffer_id))

View file

@ -94,7 +94,7 @@ class VirtualBuffer:
logger.error(f"buffer worker '{self.codemp_id}' crashed:\n{e}")
raise
def send_buffer_change(self, changes):
async def send_buffer_change(self, changes):
# we do not do any index checking, and trust sublime with providing the correct
# sequential indexing, assuming the changes are applied in the order they are received.
for change in changes:
@ -104,7 +104,7 @@ class VirtualBuffer:
region.begin(), region.end(), change.str
)
)
self.buffctl.send(region.begin(), region.end(), change.str)
await self.buffctl.send(region.begin(), region.end(), change.str)
def send_cursor(self, vws): # pyright: ignore # noqa: F821
# TODO: only the last placed cursor/selection.

View file

@ -31,10 +31,10 @@ class VirtualClient:
)
return
id = self.handle.user_id()
id = self.handle.user_id() # pyright: ignore
logger.debug(f"Connected to '{host}' with user {user} and id: {id}")
async def join_workspace(
def join_workspace(
self,
workspace_id: str,
) -> VirtualWorkspace | None:
@ -43,7 +43,7 @@ class VirtualClient:
logger.info(f"Joining workspace: '{workspace_id}'")
try:
workspace = await self.handle.join_workspace(workspace_id)
workspace = self.handle.join_workspace(workspace_id)
except Exception as e:
logger.error(f"Could not join workspace '{workspace_id}'.\n\nerror: {e}")
sublime.error_message(f"Could not join workspace '{workspace_id}'")

View file

@ -18,7 +18,8 @@ class CodempLogger:
# initialize only once
self.internal_logger = PyLogger(self.level == logging.DEBUG)
except Exception:
pass
if self.internal_logger is None:
raise
async def listen(self):
if self.started:
@ -29,7 +30,11 @@ class CodempLogger:
assert self.internal_logger is not None
try:
while msg := await self.internal_logger.listen():
self.logger.log(self.level, msg)
if msg is not None:
logger.log(logging.DEBUG, msg)
else:
logger.log(logging.DEBUG, "logger sender dropped.")
break
except CancelledError:
self.logger.debug("inner logger stopped.")
self.started = False

View file

@ -1,4 +1,6 @@
from typing import Optional, Callable, Any
from asyncio.coroutines import functools
import sublime
import logging
import asyncio
@ -61,38 +63,71 @@ class Runtime:
self.thread = threading.Thread(
target=self.loop.run_forever, name="codemp-asyncio-loop"
)
logger.debug("spinning up even loop in its own thread.")
self.thread.start()
def __del__(self):
logger.debug("closing down the event loop")
for task in self.tasks:
logger.debug("closing down the event loop.")
for task in asyncio.all_tasks(self.loop):
task.cancel()
self.stop_loop()
try:
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
except Exception as e:
logger.error(f"Unexpected crash while shutting down event loop: {e}")
self.stop_loop()
self.thread.join()
def stop_loop(self):
logger.debug("stopping event loop.")
self.loop.call_soon_threadsafe(lambda: asyncio.get_running_loop().stop())
self.thread.join()
def run_blocking(self, fut, *args, **kwargs):
return self.loop.run_in_executor(None, fut, *args, **kwargs)
def dispatch(self, coro, name=None):
logging.debug("dispatching coroutine...")
"""
Dispatch a task on the event loop and returns the task itself.
Similar to `run_coroutine_threadsafe` but returns the
actual task running and not the result of the coroutine.
def make_task():
logging.debug("creating task on the loop.")
task = self.loop.create_task(coro)
task.set_name(name)
self.tasks.append(task)
`run_coroutine_threadsafe` returns a concurrent.futures.Future
which has a blocking .result so not really suited for long running
coroutines
"""
logger.debug("dispatching coroutine...")
self.loop.call_soon_threadsafe(make_task)
def make_task(fut):
logger.debug("creating task on the loop.")
try:
fut.set_result(self.loop.create_task(coro))
except Exception as e:
fut.set_exception(e)
# create the future to populate with the task
# we use the concurrent.futures.Future since it is thread safe
# and the .result() call is blocking.
fut = concurrent.futures.Future()
self.loop.call_soon_threadsafe(make_task, fut)
task = fut.result(None) # wait for the task to be created
task.set_name(name)
self.tasks.append(task) # save the reference
return task
def block_on(self, coro, timeout=None):
fut = asyncio.run_coroutine_threadsafe(coro, self.loop)
try:
return fut.result(timeout)
except asyncio.CancelledError:
logger.debug("future got cancelled.")
raise
except TimeoutError:
logger.debug("future took too long to finish.")
raise
except Exception as e:
raise e
def get_task(self, name) -> Optional[asyncio.Task]:
return next((t for t in self.tasks if t.get_name() == name), None)
@ -110,76 +145,6 @@ class Runtime:
return
# class TaskManager:
# def __init__(self):
# self.tasks = []
# self.runtime = rt
# self.exit_handler_id = None
# def acquire(self, exit_handler):
# if self.exit_handler_id is None:
# # don't allow multiple exit handlers
# self.exit_handler_id = self.runtime.acquire(exit_handler)
# return self.exit_handler_id
# def release(self, at_exit):
# self.runtime.release(at_exit=at_exit, exit_handler_id=self.exit_handler_id)
# self.exit_handler_id = None
# def dispatch(self, coro, name=None):
# self.runtime.dispatch(coro, self.store_named_lambda(name))
# def sync(self, coro):
# return self.runtime.sync(coro)
# def remove_stopped(self):
# self.tasks = list(filter(lambda T: not T.cancelled(), self.tasks))
# def store(self, task, name=None):
# if name is not None:
# task.set_name(name)
# self.tasks.append(task)
# self.remove_stopped()
# def store_named_lambda(self, name=None):
# def _store(task):
# self.store(task, name)
# return _store
# def get_task(self, name) -> Optional[asyncio.Task]:
# return next((t for t in self.tasks if t.get_name() == name), None)
# def get_task_idx(self, name) -> Optional[int]:
# return next(
# (i for (i, t) in enumerate(self.tasks) if t.get_name() == name), None
# )
# def pop_task(self, name) -> Optional[asyncio.Task]:
# idx = self.get_task_idx(name)
# if id is not None:
# return self.tasks.pop(idx)
# return None
# async def _stop(self, task):
# task.cancel() # cancelling a task, merely requests a cancellation.
# try:
# await task
# except asyncio.CancelledError:
# return
# def stop(self, name):
# t = self.get_task(name)
# if t is not None:
# self.runtime.dispatch(self._stop(t))
# def stop_all(self):
# for task in self.tasks:
# self.runtime.dispatch(self._stop(task))
# # singleton instance
# tm = TaskManager()
# store a global in the module so it acts as a singleton
# (modules are loaded only once)
rt = Runtime()

View file

@ -246,7 +246,7 @@ class VirtualWorkspace:
[reg],
flags=reg_flags,
scope=g.REGIONS_COLORS[user_hash % len(g.REGIONS_COLORS)],
annotations=[cursor_event.user],
annotations=[cursor_event.user], # pyright: ignore
annotation_color=g.PALETTE[user_hash % len(g.PALETTE)],
)