Skip to content

Commit

Permalink
Merge pull request #137 from maxfischer2781/maintenance/test-egroup
Browse files Browse the repository at this point in the history
handle exception groups in tests
  • Loading branch information
maxfischer2781 authored Apr 16, 2024
2 parents 464f02e + 65e0342 commit c46246e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
17 changes: 13 additions & 4 deletions cobald_tests/utility/concurrent/test_meta_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Iterator
import threading
import pytest
import time
Expand All @@ -6,6 +7,7 @@
import gc

import trio
from exceptiongroup import ExceptionGroup

from cobald.daemon.runners.base_runner import OrphanedReturn
from cobald.daemon.runners.meta_runner import MetaRunner
Expand All @@ -15,8 +17,15 @@ class TerminateRunner(Exception):
pass


def unwrap_cause(exc: BaseException) -> BaseException:
if isinstance(exc.__cause__, ExceptionGroup):
assert len(exc.__cause__.exceptions) == 1
return exc.__cause__.exceptions[0]
return exc.__cause__


@contextlib.contextmanager
def threaded_run(name):
def threaded_run(name: str) -> Iterator[MetaRunner]:
gc.collect()
runner = MetaRunner()
thread = threading.Thread(target=runner.run, name=name, daemon=True)
Expand Down Expand Up @@ -90,7 +99,7 @@ async def with_return():
runner.register_payload(with_return, flavour=flavour)
with pytest.raises(RuntimeError) as exc:
runner.run()
assert isinstance(exc.value.__cause__, OrphanedReturn)
assert isinstance(unwrap_cause(exc.value), OrphanedReturn)

@pytest.mark.parametrize("flavour", (threading,))
def test_abort_subroutine(self, flavour):
Expand Down Expand Up @@ -130,7 +139,7 @@ async def abort():
runner.register_payload(abort, flavour=flavour)
with pytest.raises(RuntimeError) as exc:
runner.run()
assert isinstance(exc.value.__cause__, TerminateRunner)
assert isinstance(unwrap_cause(exc.value), TerminateRunner)

async def noop():
return
Expand All @@ -144,4 +153,4 @@ async def loop():
runner.register_payload(abort, flavour=flavour)
with pytest.raises(RuntimeError) as exc:
runner.run()
assert isinstance(exc.value.__cause__, TerminateRunner)
assert isinstance(unwrap_cause(exc.value), TerminateRunner)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
with open(os.path.join(repo_base_dir, "README.rst"), "r") as README:
long_description = README.read()

TESTS_REQUIRE = ["pytest>=4.3.0", "pytest-timeout"]
TESTS_REQUIRE = ["pytest>=4.3.0", "pytest-timeout", "exceptiongroup"]

if __name__ == "__main__":
setup(
Expand Down

0 comments on commit c46246e

Please sign in to comment.