diff --git a/cobald_tests/daemon/test_service.py b/cobald_tests/daemon/test_service.py index c1bc3833..c9642f7d 100644 --- a/cobald_tests/daemon/test_service.py +++ b/cobald_tests/daemon/test_service.py @@ -5,6 +5,8 @@ import asyncio import contextlib import logging +import signal +import os import pytest @@ -25,6 +27,7 @@ def accept(payload: ServiceRunner, name=None): ) thread.start() if not payload.running.wait(1): + payload.shutdown() raise RuntimeError("%s failed to start" % payload) try: yield @@ -33,24 +36,30 @@ def accept(payload: ServiceRunner, name=None): thread.join() -class TestServiceRunner(object): - def test_no_tainting(self): - """Assert that no payloads may be scheduled before starting""" +def sync_raise(what): + logging.info(f"raising {what}") + raise what - def payload(): - return - runner = ServiceRunner() - runner._meta_runner.register_payload(payload, flavour=threading) - with pytest.raises(RuntimeError): - runner.accept() +async def async_raise(what): + sync_raise(what) + + +def sync_raise_signal(what): + logging.info(f"signal {what}") + os.kill(os.getpid(), what) + + +async def async_raise_signal(what): + sync_raise_signal(what) + +class TestServiceRunner(object): def test_unique_reaper(self): """Assert that no two runners may fetch services""" with accept(ServiceRunner(accept_delay=0.1), name="outer"): with pytest.raises(RuntimeError): - with accept(ServiceRunner(accept_delay=0.1), name="inner"): - pass + ServiceRunner(accept_delay=0.1).accept() def test_service(self): """Test running service classes automatically""" @@ -133,3 +142,36 @@ async def co_pingpong(what=default): break else: assert len(reply_store) == 9 + + @pytest.mark.parametrize( + "flavour, do_sleep, do_raise", + ( + (asyncio, asyncio.sleep, async_raise), + (trio, trio.sleep, async_raise), + (threading, time.sleep, sync_raise), + ), + ) + def test_error_reporting(self, flavour, do_sleep, do_raise): + """Test that fatal errors do not pass silently""" + # errors should fail the 
entire runtime + runner = ServiceRunner(accept_delay=0.1) + runner.adopt(do_sleep, 5, flavour=flavour) + runner.adopt(do_raise, LookupError, flavour=flavour) + with pytest.raises(RuntimeError): + runner.accept() + + @pytest.mark.parametrize( + "flavour, do_sleep, do_raise", + ( + (asyncio, asyncio.sleep, async_raise_signal), + (trio, trio.sleep, async_raise_signal), + (threading, time.sleep, sync_raise_signal), + ), + ) + def test_interrupt(self, flavour, do_sleep, do_raise): + """Test that KeyboardInterrupt/^C is graceful shutdown""" + runner = ServiceRunner(accept_delay=0.1) + runner.adopt(do_sleep, 5, flavour=flavour) + # signal.SIGINT == KeyboardInterrupt + runner.adopt(do_raise, signal.SIGINT, flavour=flavour) + runner.accept() diff --git a/cobald_tests/utility/concurrent/test_meta_runner.py b/cobald_tests/utility/concurrent/test_meta_runner.py index 85dbe392..c317000a 100644 --- a/cobald_tests/utility/concurrent/test_meta_runner.py +++ b/cobald_tests/utility/concurrent/test_meta_runner.py @@ -2,6 +2,7 @@ import pytest import time import asyncio +import contextlib import trio @@ -13,36 +14,22 @@ class TerminateRunner(Exception): pass -def run_in_thread(payload, name, daemon=True): - thread = threading.Thread(target=payload, name=name, daemon=daemon) +@contextlib.contextmanager +def threaded_run(name=None): + runner = MetaRunner() + thread = threading.Thread(target=runner.run, name=name or str(runner), daemon=True) thread.start() - time.sleep(0.0) + if not runner.running.wait(1): + runner.stop() + raise RuntimeError("%s failed to start" % runner) + try: + yield runner + finally: + runner.stop() + thread.join() class TestMetaRunner(object): - def test_bool_payloads(self): - def subroutine(): - time.sleep(0.5) - - async def a_coroutine(): - await asyncio.sleep(0.5) - - async def t_coroutine(): - await trio.sleep(0.5) - - for flavour, payload in ( - (threading, subroutine), - (asyncio, a_coroutine), - (trio, t_coroutine), - ): - runner = MetaRunner() - assert not 
bool(runner) - runner.register_payload(payload, flavour=flavour) - assert bool(runner) - run_in_thread(runner.run, name="test_bool_payloads %s" % flavour) - assert bool(runner) - runner.stop() - @pytest.mark.parametrize("flavour", (threading,)) def test_run_subroutine(self, flavour): """Test executing a subroutine""" @@ -53,11 +40,11 @@ def with_return(): def with_raise(): raise KeyError("expected exception") - runner = MetaRunner() - result = runner.run_payload(with_return, flavour=flavour) - assert result == with_return() - with pytest.raises(KeyError): - runner.run_payload(with_raise, flavour=flavour) + with threaded_run("test_run_subroutine") as runner: + result = runner.run_payload(with_return, flavour=flavour) + assert result == with_return() + with pytest.raises(KeyError): + runner.run_payload(with_raise, flavour=flavour) @pytest.mark.parametrize("flavour", (asyncio, trio)) def test_run_coroutine(self, flavour): @@ -69,13 +56,11 @@ async def with_return(): async def with_raise(): raise KeyError("expected exception") - runner = MetaRunner() - run_in_thread(runner.run, name="test_run_coroutine %s" % flavour) - result = runner.run_payload(with_return, flavour=flavour) - assert result == trio.run(with_return) - with pytest.raises(KeyError): - runner.run_payload(with_raise, flavour=flavour) - runner.stop() + with threaded_run("test_run_coroutine") as runner: + result = runner.run_payload(with_return, flavour=flavour) + assert result == trio.run(with_return) + with pytest.raises(KeyError): + runner.run_payload(with_raise, flavour=flavour) @pytest.mark.parametrize("flavour", (threading,)) def test_return_subroutine(self, flavour): @@ -151,7 +136,6 @@ async def loop(): await flavour.sleep(0) runner = MetaRunner() - runner.register_payload(noop, loop, flavour=flavour) runner.register_payload(abort, flavour=flavour) with pytest.raises(RuntimeError) as exc: diff --git a/src/cobald/daemon/runners/_compat.py b/src/cobald/daemon/runners/_compat.py new file mode 100644 
index 00000000..0243c25a --- /dev/null +++ b/src/cobald/daemon/runners/_compat.py @@ -0,0 +1,51 @@ +import sys +import asyncio +import inspect + + +if sys.version_info >= (3, 7): + asyncio_run = asyncio.run +else: + # almost literal backport of asyncio.run + def asyncio_run(main, *, debug=None): + assert inspect.iscoroutine(main) + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + if debug is not None: + loop.set_debug(debug) + return loop.run_until_complete(main) + finally: + try: + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + loop.close() + + def _cancel_all_tasks(loop): + to_cancel = asyncio.Task.all_tasks(loop) + if not to_cancel: + return + for task in to_cancel: + task.cancel() + loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) + + +if sys.version_info >= (3, 7): + asyncio_current_task = asyncio.current_task +else: + + def asyncio_current_task() -> asyncio.Task: + return asyncio.Task.current_task() diff --git a/src/cobald/daemon/runners/async_tools.py b/src/cobald/daemon/runners/async_tools.py deleted file mode 100644 index e2f29f2b..00000000 --- a/src/cobald/daemon/runners/async_tools.py +++ /dev/null @@ -1,41 +0,0 @@ -import threading -from .base_runner import OrphanedReturn - - -async def raise_return(payload): - """Wrapper to raise exception on unhandled return values""" - value = await payload() - if value is not None: - raise OrphanedReturn(payload, value) - - -class AsyncExecution(object): - def __init__(self, payload): - self.payload = payload - self._result = None - self._done = threading.Event() - self._done.clear() - - # explicit coroutine for libraries that type check - async 
def coroutine(self): - await self - - def __await__(self): - try: - value = yield from self.payload().__await__() - except Exception as err: - self._result = None, err - else: - self._result = value, None - self._done.set() - - def wait(self): - self._done.wait() - value, exception = self._result - if exception is None: - return value - else: - raise exception - - def __repr__(self): - return "%s(%s)" % (self.__class__.__name__, self.payload) diff --git a/src/cobald/daemon/runners/asyncio_runner.py b/src/cobald/daemon/runners/asyncio_runner.py index f71697bc..a63e000b 100644 --- a/src/cobald/daemon/runners/asyncio_runner.py +++ b/src/cobald/daemon/runners/asyncio_runner.py @@ -1,81 +1,69 @@ +from typing import Callable, Awaitable, Coroutine import asyncio -from functools import partial -from .base_runner import BaseRunner -from .async_tools import raise_return, AsyncExecution +from .base_runner import BaseRunner, OrphanedReturn +from ._compat import asyncio_current_task class AsyncioRunner(BaseRunner): - """Runner for coroutines with :py:mod:`asyncio`""" + """ + Runner for coroutines with :py:mod:`asyncio` + + All active payloads are actively cancelled when the runner is closed. + """ flavour = asyncio - def __init__(self): - super().__init__() - self.event_loop = asyncio.new_event_loop() + # This runner directly uses asyncio.Task to run payloads. + # To detect errors, each payload is wrapped; errors and unexpected return values + # are pushed to a queue from which the main task re-raises. + # Tasks are registered in a container to allow cancelling them. The payload wrapper + # takes care of adding/removing tasks. 
+ def __init__(self, asyncio_loop: asyncio.AbstractEventLoop): + super().__init__(asyncio_loop) self._tasks = set() + self._payload_failure = asyncio_loop.create_future() - def register_payload(self, payload): - super().register_payload(partial(raise_return, payload)) + def register_payload(self, payload: Callable[[], Awaitable]): + self.asyncio_loop.call_soon_threadsafe(self._setup_payload, payload) - def run_payload(self, payload): - execution = AsyncExecution(payload) - super().register_payload(execution.coroutine) - return execution.wait() + def run_payload(self, payload: Callable[[], Coroutine]): + future = asyncio.run_coroutine_threadsafe(payload(), self.asyncio_loop) + return future.result() - def _run(self): - asyncio.set_event_loop(self.event_loop) - self.event_loop.run_until_complete(self._run_payloads()) + def _setup_payload(self, payload: Callable[[], Awaitable]): + task = self.asyncio_loop.create_task(self._monitor_payload(payload)) + self._tasks.add(task) - async def _run_payloads(self): - """Async component of _run""" - delay = 0.0 + async def _monitor_payload(self, payload: Callable[[], Awaitable]): try: - while self.running.is_set(): - await self._start_payloads() - await self._reap_payloads() - await asyncio.sleep(delay) - delay = min(delay + 0.1, 1.0) - except Exception: - await self._cancel_payloads() + result = await payload() + except (asyncio.CancelledError, KeyboardInterrupt): raise + except BaseException as e: + failure = e + else: + if result is None: + return + failure = OrphanedReturn(payload, result) + finally: + self._tasks.discard(asyncio_current_task()) + if not self._payload_failure.done(): + self._payload_failure.set_exception(failure) - async def _start_payloads(self): - """Start all queued payloads""" - with self._lock: - for coroutine in self._payloads: - task = self.event_loop.create_task(coroutine()) - self._tasks.add(task) - self._payloads.clear() - await asyncio.sleep(0) - - async def _reap_payloads(self): - """Clean up all 
finished payloads""" - for task in self._tasks.copy(): - if task.done(): - self._tasks.remove(task) - if task.exception() is not None: - raise task.exception() - await asyncio.sleep(0) - - async def _cancel_payloads(self): - """Cancel all remaining payloads""" - for task in self._tasks: - task.cancel() - await asyncio.sleep(0) - for task in self._tasks: - while not task.done(): - await asyncio.sleep(0.1) - task.cancel() + async def manage_payloads(self): + await self._payload_failure - def stop(self): - if not self.running.wait(0.2): + async def aclose(self): + if self._stopped.is_set() and not self._tasks: return - self._logger.debug("runner disabled: %s", self) - with self._lock: - self.running.clear() - for task in self._tasks: - task.cancel() - self._stopped.wait() - self.event_loop.stop() - self.event_loop.close() + # let the manage task wake up and exit + if not self._payload_failure.done(): + self._payload_failure.set_result(None) + while self._tasks: + for task in self._tasks.copy(): + if task.done(): + self._tasks.discard(task) + else: + task.cancel() + await asyncio.sleep(0.1) diff --git a/src/cobald/daemon/runners/asyncio_watcher.py b/src/cobald/daemon/runners/asyncio_watcher.py deleted file mode 100644 index f859198f..00000000 --- a/src/cobald/daemon/runners/asyncio_watcher.py +++ /dev/null @@ -1,41 +0,0 @@ -import asyncio -import threading -import sys - -from .base_runner import BaseRunner -from .thread_runner import CapturingThread - - -async def awaitable_runner(runner: BaseRunner): - """Execute a runner without blocking the event loop""" - runner_thread = CapturingThread(target=runner.run) - runner_thread.start() - delay = 0.0 - while not runner_thread.join(timeout=0): - await asyncio.sleep(delay) - delay = min(delay + 0.1, 1.0) - - -def asyncio_main_run(root_runner: BaseRunner): - """ - Create an ``asyncio`` event loop running in the main thread and watching runners - - Using ``asyncio`` to handle subprocesses requires a specific loop type - to run 
in the main thread. - This function sets up and runs the correct loop in a portable way. - In addition, it runs a single :py:class:`~.BaseRunner` until completion - or failure. - - .. seealso:: The `issue #8 `_ - for details. - """ - assert ( - threading.current_thread() == threading.main_thread() - ), "only main thread can accept asyncio subprocesses" - if sys.platform == "win32": - event_loop = asyncio.ProactorEventLoop() - asyncio.set_event_loop(event_loop) - else: - event_loop = asyncio.get_event_loop() - asyncio.get_child_watcher().attach_loop(event_loop) - event_loop.run_until_complete(awaitable_runner(root_runner)) diff --git a/src/cobald/daemon/runners/base_runner.py b/src/cobald/daemon/runners/base_runner.py index 2bff4a88..825df75f 100644 --- a/src/cobald/daemon/runners/base_runner.py +++ b/src/cobald/daemon/runners/base_runner.py @@ -1,86 +1,107 @@ +from typing import Any +from abc import abstractmethod, ABCMeta +import asyncio import logging import threading -from typing import Any -from cobald.daemon.debug import NameRepr +from ..debug import NameRepr -class BaseRunner(object): +class BaseRunner(metaclass=ABCMeta): + """Concurrency backend on top of `asyncio`""" + flavour = None # type: Any - def __init__(self): + def __init__(self, asyncio_loop: asyncio.AbstractEventLoop): + self.asyncio_loop = asyncio_loop self._logger = logging.getLogger( "cobald.runtime.runner.%s" % NameRepr(self.flavour) ) - self._payloads = [] - self._lock = threading.Lock() - #: signal that runner should keep in running - self.running = threading.Event() - #: signal that runner has stopped self._stopped = threading.Event() - self.running.clear() self._stopped.set() - def __bool__(self): - with self._lock: - return bool(self._payloads) or self.running.is_set() - + @abstractmethod def register_payload(self, payload): """ - Register ``payload`` for asynchronous execution + Register ``payload`` for background execution in a threadsafe manner This runs ``payload`` as an orphaned 
background task as soon as possible. It is an error for ``payload`` to return or raise anything without handling it. """ - with self._lock: - self._payloads.append(payload) + raise NotImplementedError + @abstractmethod def run_payload(self, payload): """ - Register ``payload`` for synchronous execution + Execute ``payload`` and return its result in a threadsafe manner This runs ``payload`` as soon as possible, blocking until completion. Should ``payload`` return or raise anything, it is propagated to the caller. """ raise NotImplementedError - def run(self): + async def ready(self): + """Wait until the runner is ready to accept payloads""" + assert ( + not self._stopped.is_set() + ), "runner must be .run before waiting until it is ready" + # Most runners are ready when instantiated, simply queueing payloads + # until they get a chance to run them. Only override this method when + # the runner has to do some `async` setup before being ready. + + async def run(self): """ - Execute all current and future payloads + Execute all current and future payloads in an `asyncio` coroutine + + This method will continuously execute payloads sent to the runner. + It only returns when :py:meth:`stop` is called + or if any orphaned payload returns or raises. + In the latter case, :py:exc:`~.OrphanedReturn` or the raised exception + is re-raised by this method. - Blocks and executes payloads until :py:meth:`stop` is called. - It is an error for any orphaned payload to return or raise. + Implementations should override :py:meth:`~.manage_payloads` + to customize their specific parts. 
""" self._logger.info("runner started: %s", self) + self._stopped.clear() try: - with self._lock: - assert not self.running.is_set() and self._stopped.is_set(), ( - "cannot re-run: %s" % self - ) - self.running.set() - self._stopped.clear() - self._run() - except Exception: + await self.manage_payloads() + except asyncio.CancelledError: + self._logger.info("runner cancelled: %s", self) + raise + except BaseException: self._logger.exception("runner aborted: %s", self) raise else: self._logger.info("runner stopped: %s", self) finally: - with self._lock: - self.running.clear() - self._stopped.set() + self._stopped.set() + + @abstractmethod + async def manage_payloads(self): + """ + Implementation of managing payloads when :py:meth:`~.run` + + This method must continuously execute payloads sent to the runner. + It may only return when :py:meth:`stop` is called + or if any orphaned payload return or raise. + In the latter case, :py:exc:`~.OrphanedReturn` or the raised exception + must re-raised by this method. 
+ """ + raise NotImplementedError - def _run(self): + @abstractmethod + async def aclose(self): + """Shut down this runner""" raise NotImplementedError def stop(self): - """Stop execution of all current and future payloads""" - if not self.running.wait(0.2): + """Stop execution of all current and future payloads and block until success""" + if self._stopped.is_set(): return - self._logger.debug("runner disabled: %s", self) - with self._lock: - self.running.clear() - self._stopped.wait() + # the loop exists independently of all runners, we can use it to shut down + closed = asyncio.run_coroutine_threadsafe(self.aclose(), self.asyncio_loop) + closed.result() class OrphanedReturn(Exception): diff --git a/src/cobald/daemon/runners/meta_runner.py b/src/cobald/daemon/runners/meta_runner.py index f6627c75..7ebad5bc 100644 --- a/src/cobald/daemon/runners/meta_runner.py +++ b/src/cobald/daemon/runners/meta_runner.py @@ -1,17 +1,17 @@ +from typing import Dict, List, Any import logging import threading -import trio - +import warnings +import asyncio from types import ModuleType from .base_runner import BaseRunner from .trio_runner import TrioRunner from .asyncio_runner import AsyncioRunner from .thread_runner import ThreadRunner -from .asyncio_watcher import asyncio_main_run - +from ._compat import asyncio_run -from cobald.daemon.debug import NameRepr +from ..debug import NameRepr class MetaRunner(object): @@ -23,97 +23,110 @@ class MetaRunner(object): def __init__(self): self._logger = logging.getLogger("cobald.runtime.runner.meta") - self.runners = { - runner.flavour: runner() for runner in self.runner_types - } # type: dict[ModuleType, BaseRunner] - self._lock = threading.Lock() + self._runners: Dict[ModuleType, BaseRunner] = {} + # queue to store payloads submitted before the runner is started + self._runner_queues: Dict[ModuleType, Any] = {} self.running = threading.Event() - self.running.clear() - def __bool__(self): - return any(bool(runner) for runner in 
self.runners.values()) + @property + def runners(self): + warnings.warn( + DeprecationWarning( + "Accessing 'MetaRunner.runners' directly is deprecated. " + "Use register_payload or run_payload with the correct flavour instead." + ) + ) + return self._runners def register_payload(self, *payloads, flavour: ModuleType): - """Queue one or more payload for execution after its runner is started""" - for payload in payloads: - self._logger.debug( - "registering payload %s (%s)", NameRepr(payload), NameRepr(flavour) - ) - self.runners[flavour].register_payload(payload) + """Queue one or more payloads for execution after its runner is started""" + try: + runner = self._runners[flavour] + except KeyError: + if self.running.is_set(): + raise RuntimeError(f"unknown runner {NameRepr(flavour)}") from None + self._runner_queues.setdefault(flavour, []).extend(payloads) + else: + for payload in payloads: + self._logger.debug( + "registering payload %s (%s)", NameRepr(payload), NameRepr(flavour) + ) + runner.register_payload(payload) def run_payload(self, payload, *, flavour: ModuleType): - """Execute one payload after its runner is started and return its output""" - return self.runners[flavour].run_payload(payload) + """ + Execute one payload and return its output + + This method will block until the payload is completed. + It is an error to call it during initialisation before the runners are started. 
+ """ + return self._runners[flavour].run_payload(payload) def run(self): """Run all runners, blocking until completion or error""" self._logger.info("starting all runners") try: - with self._lock: - assert not self.running.set(), "cannot re-run: %s" % self - self.running.set() - thread_runner = self.runners[threading] - for runner in self.runners.values(): - if runner is not thread_runner: - thread_runner.register_payload(runner.run) - if threading.current_thread() == threading.main_thread(): - asyncio_main_run(root_runner=thread_runner) - else: - thread_runner.run() + asyncio_run(self._manage_runners()) except KeyboardInterrupt: self._logger.info("runner interrupted") except Exception as err: self._logger.exception("runner terminated: %s", err) - raise RuntimeError from err + raise RuntimeError("background task failed") from err finally: - self._stop_runners() self._logger.info("stopped all runners") - self.running.clear() def stop(self): """Stop all runners""" - self._stop_runners() - - def _stop_runners(self): - for runner in self.runners.values(): - if runner.flavour == threading: - continue + self._logger.debug("stop all runners") + for runner in self._runners.values(): runner.stop() - self.runners[threading].stop() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - import time - import asyncio - - runner = MetaRunner() - async def trio_sleeper(): - for i in range(3): - print("trio\t", i) - await trio.sleep(0.1) - - runner.register_payload(trio_sleeper, flavour=trio) - - async def asyncio_sleeper(): - for i in range(3): - print("asyncio\t", i) - await asyncio.sleep(0.1) - - runner.register_payload(asyncio_sleeper, flavour=asyncio) - - def thread_sleeper(): - for i in range(3): - print("thread\t", i) - time.sleep(0.1) - - runner.register_payload(thread_sleeper, flavour=threading) - - async def teardown(): - await trio.sleep(5) - raise SystemExit("Abort from trio runner") - - runner.register_payload(teardown, flavour=trio) + async def 
_manage_runners(self): + """Manage all runners inside the current `asyncio` event loop""" + runner_tasks = await self._launch_runners() + self.running.set() + try: + # wait for all runners to either stop gracefully or propagate errors + # we only unqueue payloads *while* watching runners as payloads could + # cause the runners to fail – we need to stop unqueueing them as well. + await asyncio.gather(*runner_tasks, self._unqueue_payloads()) + except KeyboardInterrupt: + # KeyboardInterrupt in a runner task immediately kills the event loop. + # When we get resurrected, the exception has already been handled! + # Just clean up... + await asyncio.shield(self._aclose_runners(runner_tasks)) + except BaseException: + await asyncio.shield(self._aclose_runners(runner_tasks)) + raise + finally: + self.running.clear() - runner.run() + async def _launch_runners(self) -> List[asyncio.Task]: + """Launch all runners inside the current `asyncio` event loop""" + asyncio_loop = asyncio.get_event_loop() + self._runners = {} + runner_tasks = [] + for runner_type in self.runner_types: + runner = self._runners[runner_type.flavour] = runner_type(asyncio_loop) + runner_tasks.append(asyncio_loop.create_task(runner.run())) + for runner in self._runners.values(): + await runner.ready() + return runner_tasks + + async def _unqueue_payloads(self) -> None: + """Register payloads once runners are started""" + # Unqueue when we are running so that payloads do not get requeued. + # This also provides checking that the queued flavours correspond to a runner. 
+ assert self.running.is_set(), "runners must be launched before unqueueing" + # runners are started, so re-registering payloads does not queue them again + for flavour, queue in self._runner_queues.items(): + self.register_payload(*queue, flavour=flavour) + queue.clear() + self._runner_queues.clear() + + async def _aclose_runners(self, runner_tasks): + for runner in self._runners.values(): + await runner.aclose() + # wait until runners are closed + await asyncio.gather(*runner_tasks, return_exceptions=True) + self._runners.clear() diff --git a/src/cobald/daemon/runners/service.py b/src/cobald/daemon/runners/service.py index fa10ecc0..eeccd4d7 100644 --- a/src/cobald/daemon/runners/service.py +++ b/src/cobald/daemon/runners/service.py @@ -110,6 +110,16 @@ def __new_service__(cls, *args, **kwargs): class ServiceRunner(object): """ Runner for coroutines, subroutines and services + + The service runner prevents silent failures by tracking concurrent tasks + and therefore provides safer concurrency. + If any task fails with an exception or provides + unexpected output values, this is registered as an error; the runner will + gracefully shut down all tasks in this case. + + To provide ``async`` concurrency, the runner also manages common + ``async`` event loops and tracks them for failures as well. As a result, + ``async`` code should usually use the "current" event loop directly. """ def __init__(self, accept_delay: float = 1): @@ -117,6 +127,7 @@ def __init__(self, accept_delay: float = 1): self._meta_runner = MetaRunner() self._must_shutdown = False self._is_shutdown = threading.Event() + self._is_shutdown.set() self.running = threading.Event() self.accept_delay = accept_delay @@ -150,8 +161,6 @@ def accept(self): Since services are globally defined, only one :py:class:`ServiceRunner` may :py:meth:`accept` payloads at any time. 
""" - if self._meta_runner: - raise RuntimeError("payloads scheduled for %s before being started" % self) self._must_shutdown = False self._logger.info("%s starting", self.__class__.__name__) # force collecting objects so that defunct, @@ -177,7 +186,9 @@ async def _accept_services(self): self._adopt_services() await trio.sleep(delay) delay = min(delay + increase, max_delay) - except Exception: + except trio.Cancelled: + self._logger.info("%s cancelled", self.__class__.__name__) + except BaseException: self._logger.exception("%s aborted", self.__class__.__name__) raise else: diff --git a/src/cobald/daemon/runners/thread_runner.py b/src/cobald/daemon/runners/thread_runner.py index df5b6641..1cf3dc84 100644 --- a/src/cobald/daemon/runners/thread_runner.py +++ b/src/cobald/daemon/runners/thread_runner.py @@ -1,81 +1,58 @@ import threading -import time +import asyncio -from ..debug import NameRepr from .base_runner import BaseRunner, OrphanedReturn -class CapturingThread(threading.Thread): - """ - Daemonic threads that capture any return value or exception from their ``target`` +class ThreadRunner(BaseRunner): """ + Runner for subroutines with :py:mod:`threading` - def __init__(self, **kwargs): - super().__init__(**kwargs, daemon=True) - self._exception = None - self._name = str(NameRepr(self._target)) - - def join(self, timeout=None): - super().join(timeout=timeout) - if self._started.is_set() and not self.is_alive(): - if self._exception is not None: - raise self._exception - return not self.is_alive() - - def run(self): - """Modified ``run`` that captures return value and exceptions from ``target``""" - try: - if self._target: - return_value = self._target(*self._args, **self._kwargs) - if return_value is not None: - self._exception = OrphanedReturn(self, return_value) - except BaseException as err: - self._exception = err - finally: - # Avoid a refcycle if the thread is running a function with - # an argument that has a member that points to the thread. 
- del self._target, self._args, self._kwargs - - -class ThreadRunner(BaseRunner): - """Runner for subroutines with :py:mod:`threading`""" + Active payloads are *not* cancelled when the runner is closed. + Only program termination forcefully cancels leftover payloads. + """ flavour = threading - def __init__(self): - super().__init__() - self._threads = set() + # This runner directly uses threading.Thread to run payloads. + # To detect errors, each payload is wrapped; errors and unexpected return values + # are pushed to a queue from which the main task re-raises. + def __init__(self, asyncio_loop: asyncio.AbstractEventLoop): + super().__init__(asyncio_loop) + self._payload_failure = asyncio_loop.create_future() + + def register_payload(self, payload): + thread = threading.Thread( + target=self._monitor_payload, args=(payload,), daemon=True + ) + thread.start() def run_payload(self, payload): - # - run_payload has to block until payload is done - # instead of running payload in a thread and blocking this one, - # we just block this thread by running payload directly + # The method has to block until payload is done. + # Instead of running payload in a thread and blocking this one, + # this thread is blocked by running the payload directly. 
class TrioRunner(BaseRunner):
    """
    Runner for coroutines with :py:mod:`trio`

    All active payloads are actively cancelled when the runner is closed.
    """

    flavour = trio

    # This runner uses a trio loop in a separate thread to run payloads.
    # Tracking payloads and errors is handled by a trio nursery. A queue ("channel")
    # is used to move payloads into the trio loop.
    # Since the trio loop runs in its own thread, all public methods have to move
    # payloads/tasks into that thread.
    def __init__(self, asyncio_loop: asyncio.AbstractEventLoop):
        super().__init__(asyncio_loop)
        # set (thread-safely) once the trio loop is up and can accept payloads
        self._ready = asyncio.Event()
        # token identifying the trio loop, required by trio.from_thread.run;
        # both are populated by _manage_payloads_trio inside the trio thread
        self._trio_token: Optional[trio.lowlevel.TrioToken] = None
        self._submit_tasks: Optional[trio.MemorySendChannel] = None

    def register_payload(self, payload: Callable[[], Awaitable]):
        """Queue ``payload`` for execution inside the trio loop"""
        assert self._trio_token is not None and self._submit_tasks is not None
        try:
            trio.from_thread.run(
                self._submit_tasks.send, payload, trio_token=self._trio_token
            )
        except (trio.RunFinishedError, trio.Cancelled):
            # the trio loop is already winding down — best effort only
            self._logger.warning(f"discarding payload {payload} during shutdown")
            return

    def run_payload(self, payload: Callable[[], Coroutine]):
        """Run ``payload`` inside the trio loop, blocking until it is done"""
        assert self._trio_token is not None and self._submit_tasks is not None
        return trio.from_thread.run(payload, trio_token=self._trio_token)

    async def ready(self):
        """Wait until the trio loop is running and payloads may be submitted"""
        await self._ready.wait()

    async def manage_payloads(self):
        """Drive the blocking trio loop from asyncio via an executor thread"""
        try:
            await self.asyncio_loop.run_in_executor(None, self._run_trio_blocking)
        except asyncio.CancelledError:
            # asyncio cancelled us — shut down the trio side before propagating
            await self.aclose()
            raise

    def _run_trio_blocking(self):
        """Executor-thread entry point: run the trio loop until it finishes"""
        return trio.run(self._manage_payloads_trio)

    async def _manage_payloads_trio(self):
        """Main task of the trio loop: receive payloads and run them in a nursery"""
        self._trio_token = trio.lowlevel.current_trio_token()
        # buffer of 256 is somewhat arbitrary but should be large enough to rarely stall
        # and small enough to smooth out explosive backlog.
        self._submit_tasks, receive_tasks = trio.open_memory_channel(256)
        # signal readiness on the asyncio loop's own thread
        self.asyncio_loop.call_soon_threadsafe(self._ready.set)
        async with trio.open_nursery() as nursery:
            # the loop ends when the send side is closed by _aclose_trio
            async for task in receive_tasks:
                nursery.start_soon(self._monitor_payload, task)
            # shutting down: cancel the scope to cancel all payloads
            nursery.cancel_scope.cancel()

    async def _monitor_payload(self, payload: Callable[[], Awaitable]):
        """Wrapper for awaitables and to raise exception on unhandled return values"""
        value = await payload()
        if value is not None:
            raise OrphanedReturn(payload, value)

    async def _aclose_trio(self):
        """Close the payload channel inside the trio loop to trigger shutdown"""
        # suppress trio cancellation to avoid raising an error in aclose
        try:
            await self._submit_tasks.aclose()
        except trio.Cancelled:
            pass

    async def aclose(self):
        """Shut down the trio loop, cancelling all active payloads"""
        if self._stopped.is_set():
            return
        # Trio only allows a *synchronously blocking* call into its loop from
        # other threads. Use an executor thread to make that *asynchronously*
        # blocking for asyncio.
        try:
            await self.asyncio_loop.run_in_executor(
                None,
                partial(
                    trio.from_thread.run, self._aclose_trio, trio_token=self._trio_token
                ),
            )
        except (trio.RunFinishedError, trio.Cancelled):
            # trio already finished in its own thread
            return