feat: add raises=True to tests to support changes in default
Signed-off-by: Guilhem Barthes <guilhem.barthes@owkin.com>
guilhem-barthes committed Jun 27, 2023
1 parent e36d633 commit 2b77292
Showing 6 changed files with 69 additions and 66 deletions.
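The change is mechanical across the changed files: instead of waiting for a task and then asserting `status == Status.done`, the tests pass `raises=True` so that `wait_task` itself fails the test when the task does not finish successfully. A minimal, self-contained sketch of the before/after semantics — `Status`, `Task`, `TaskFailedError`, and `wait_task` below are illustrative stand-ins, not the real substra SDK types:

from dataclasses import dataclass
from enum import Enum


class Status(Enum):
    done = "done"
    failed = "failed"


@dataclass
class Task:
    key: str
    status: Status


class TaskFailedError(Exception):
    """Stand-in for the error the SDK raises when a waited-on task fails."""


def wait_task(task: Task, raises: bool = False) -> Task:
    # The real SDK polls the backend until the task reaches a final state;
    # here we only model the final check that `raises=True` adds.
    if raises and task.status is not Status.done:
        raise TaskFailedError(f"task {task.key} ended with status {task.status.name}")
    return task


# Old pattern: wait, then assert on the returned task's status.
traintask = wait_task(Task("t1", Status.done))
assert traintask.status is Status.done

# New pattern: the wait call itself raises on failure, so the explicit
# status assertion becomes redundant and can be dropped.
traintask = wait_task(Task("t1", Status.done), raises=True)

Note that the diff keeps the `traintask = client.wait_task(...)` assignment only where later assertions inspect the refreshed task object (e.g. `error_type`, `metadata`, `outputs`); where nothing downstream needs it, the return value is simply discarded.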
18 changes: 11 additions & 7 deletions tests/test_data_samples_order.py
@@ -1,5 +1,4 @@
 import pytest
-from substra.sdk.models import Status
 
 import substratest as sbt
 from substratest.factory import DEFAULT_DATA_SAMPLE_FILENAME
@@ -208,7 +207,8 @@ def test_task_data_samples_relative_order(factory, client, dataset, worker):
     assert [
         i.asset_key for i in traintask.inputs if i.identifier == InputIdentifiers.datasamples
     ] == dataset.train_data_sample_keys
-    client.wait_task(traintask.key)
+    # `raises=True` will fail if the task is not successful
+    client.wait_task(traintask.key, raises=True)
 
     predict_input_models = FLTaskInputGenerator.train_to_predict(traintask.key)
     predicttask_spec = factory.create_predicttask(
@@ -224,7 +224,8 @@ def test_task_data_samples_relative_order(factory, client, dataset, worker):
     testtask = client.add_task(testtask_spec)
 
     # Assert order is correct in the metric. If not, wait_task() will fail.
-    client.wait_task(testtask.key)
+    # `raises=True` will fail if the task is not successful
+    client.wait_task(testtask.key, raises=True)
 
 
 def test_composite_traintask_data_samples_relative_order(factory, client, dataset, worker):
@@ -263,7 +264,8 @@ def test_composite_traintask_data_samples_relative_order(factory, client, dataset, worker):
     assert [
         i.asset_key for i in composite_traintask.inputs if i.identifier == InputIdentifiers.datasamples
     ] == dataset.train_data_sample_keys
-    client.wait_task(composite_traintask.key)
+    # `raises=True` will fail if the task is not successful
+    client.wait_task(composite_traintask.key, raises=True)
 
     predict_input_models = FLTaskInputGenerator.composite_to_predict(composite_traintask.key)
 
@@ -280,7 +282,8 @@ def test_composite_traintask_data_samples_relative_order(factory, client, dataset, worker):
     testtask = client.add_task(testtask_spec)
 
     # Assert order is correct in the metric. If not, _wait() will fail.
-    client.wait_task(testtask.key)
+    # `raises=True` will fail if the task is not successful
+    client.wait_task(testtask.key, raises=True)
 
 
 @pytest.mark.slow
@@ -352,5 +355,6 @@ def save_predictions(predictions, path):
         worker=worker,
     )
     traintask = client.add_task(spec)
-    traintask = client.wait_task(traintask.key)
-    assert traintask.status == Status.done
+
+    # `raises=True` will fail if the task is not successful
+    client.wait_task(traintask.key, raises=True)
3 changes: 2 additions & 1 deletion tests/test_docker_image_build.py
@@ -31,4 +31,5 @@ def test_base_substra_tools_image(factory, cfg, client, default_dataset, worker):
         worker=worker,
     )
     traintask = client.add_task(spec)
-    traintask = client.wait_task(traintask.key)
+    # `raises=True` will fail if the task is not successful
+    client.wait_task(traintask.key, raises=True)
98 changes: 48 additions & 50 deletions tests/test_execution.py
@@ -38,8 +38,7 @@ def get_traintask_spec() -> TaskSpec:
 
     spec = get_traintask_spec()
     traintask = client.add_task(spec)
-    traintask = client.wait_task(traintask.key)
-    assert traintask.status == Status.done
+    traintask = client.wait_task(traintask.key, raises=True)
     assert traintask.error_type is None
     assert traintask.metadata == {"foo": "bar"}
     assert len(traintask.outputs) == 1
@@ -62,8 +61,8 @@ def get_traintask_spec() -> TaskSpec:
     )
 
     predicttask = client.add_task(spec)
-    predicttask = client.wait_task(predicttask.key)
-    assert predicttask.status == Status.done
+    # `raises=True` will fail if the task is not successful
+    predicttask = client.wait_task(predicttask.key, raises=True)
     assert predicttask.error_type is None
 
     spec = factory.create_testtask(
@@ -72,8 +71,8 @@ def get_traintask_spec() -> TaskSpec:
         worker=worker,
     )
     testtask = client.add_task(spec)
-    testtask = client.wait_task(testtask.key)
-    assert testtask.status == Status.done
+    # `raises=True` will fail if the task is not successful
+    testtask = client.wait_task(testtask.key, raises=True)
     assert testtask.error_type is None
     performance = client.get_task_output_asset(testtask.key, OutputIdentifiers.performance)
     assert performance.asset == pytest.approx(2)
@@ -87,8 +86,8 @@ def get_traintask_spec() -> TaskSpec:
         worker=worker,
     )
     traintask = client.add_task(spec)
-    traintask = client.wait_task(traintask.key)
-    assert traintask.status == Status.done
+    # `raises=True` will fail if the task is not successful
+    traintask = client.wait_task(traintask.key, raises=True)
     assert testtask.error_type is None
     assert traintask.metadata == {}
 
@@ -123,9 +122,8 @@ def test_federated_learning_workflow(factory, client, default_datasets, workers):
             worker=workers[index],
         )
         traintask = client.add_task(spec)
-        traintask = client.wait_task(traintask.key)
+        traintask = client.wait_task(traintask.key, raises=True)
         out_models = client.get_task_models(traintask.key)
-        assert traintask.status == Status.done
         assert traintask.error_type is None
         assert len(traintask.outputs) == 1
         assert len(out_models) == 1
@@ -170,8 +168,8 @@ def test_tasks_execution_on_different_organizations(
         worker=workers[1],
     )
     traintask = client_1.add_task(spec)
-    traintask = client_1.wait_task(traintask.key)
-    assert traintask.status == Status.done
+    # `raises=True` will fail if the task is not successful
+    traintask = client_1.wait_task(traintask.key, raises=True)
     assert traintask.error_type is None
     assert len(traintask.outputs) == 1
     # Raises an exception if the output asset has not been created
@@ -185,8 +183,8 @@ def test_tasks_execution_on_different_organizations(
         worker=workers[0],
     )
     predicttask = client_1.add_task(spec)
-    predicttask = client_1.wait_task(predicttask.key)
-    assert predicttask.status == Status.done
+    # `raises=True` will fail if the task is not successful
+    predicttask = client_1.wait_task(predicttask.key, raises=True)
     assert predicttask.error_type is None
     assert predicttask.worker == client_1.organization_id
 
@@ -196,8 +194,8 @@ def test_tasks_execution_on_different_organizations(
         worker=workers[0],
     )
     testtask = client_1.add_task(spec)
-    testtask = client_1.wait_task(testtask.key)
-    assert testtask.status == Status.done
+    # `raises=True` will fail if the task is not successful
+    testtask = client_1.wait_task(testtask.key, raises=True)
     assert testtask.error_type is None
     assert testtask.worker == client_1.organization_id
     performance = client_2.get_task_output_asset(testtask.key, OutputIdentifiers.performance)
@@ -297,7 +295,8 @@ def score(inputs, outputs, task_properties):
         worker=worker,
     )
     traintask = client.add_task(spec)
-    traintask = client.wait_task(traintask.key)
+    # `raises=True` will fail if the task is not successful
+    traintask = client.wait_task(traintask.key, raises=True)
 
     # add predicttask
     spec = factory.create_predicttask(
@@ -306,7 +305,8 @@ def score(inputs, outputs, task_properties):
         worker=worker,
     )
     predicttask = client.add_task(spec)
-    predicttask = client.wait_task(predicttask.key)
+    # `raises=True` will fail if the task is not successful
+    predicttask = client.wait_task(predicttask.key, raises=True)
 
     # add metric function
     spec = factory.create_function(category=FunctionCategory.metric, py_script=custom_metric_script)
@@ -328,8 +328,8 @@ def score(inputs, outputs, task_properties):
     )
 
     testtask = client.add_task(spec)
-    testtask = client.wait_task(testtask.key)
-    assert testtask.status == Status.done
+    # `raises=True` will fail if the task is not successful
+    testtask = client.wait_task(testtask.key, raises=True)
     assert testtask.error_type is None
     output_1 = client.get_task_output_asset(testtask.key, identifier_1)
     output_2 = client.get_task_output_asset(testtask.key, identifier_2)
@@ -447,8 +447,8 @@ def test_composite_traintasks_execution(factory, client, default_dataset, defaul
         worker=worker,
     )
     composite_traintask_1 = client.add_task(spec)
-    composite_traintask_1 = client.wait_task(composite_traintask_1.key)
-    assert composite_traintask_1.status == Status.done
+    # `raises=True` will fail if the task is not successful
+    composite_traintask_1 = client.wait_task(composite_traintask_1.key, raises=True)
     assert composite_traintask_1.error_type is None
     assert len(composite_traintask_1.outputs) == 2
 
@@ -460,8 +460,8 @@ def test_composite_traintasks_execution(factory, client, default_dataset, defaul
         worker=worker,
    )
     composite_traintask_2 = client.add_task(spec)
-    composite_traintask_2 = client.wait_task(composite_traintask_2.key)
-    assert composite_traintask_2.status == Status.done
+    # `raises=True` will fail if the task is not successful
+    composite_traintask_2 = client.wait_task(composite_traintask_2.key, raises=True)
     assert composite_traintask_2.error_type is None
     assert len(composite_traintask_2.outputs) == 2
 
@@ -472,7 +472,8 @@ def test_composite_traintasks_execution(factory, client, default_dataset, defaul
         worker=worker,
     )
     predicttask = client.add_task(spec)
-    predicttask = client.wait_task(predicttask.key)
+    # `raises=True` will fail if the task is not successful
+    predicttask = client.wait_task(predicttask.key, raises=True)
     assert predicttask.status == Status.done
     assert predicttask.error_type is None
 
@@ -482,8 +483,7 @@ def test_composite_traintasks_execution(factory, client, default_dataset, defaul
         worker=worker,
     )
     testtask = client.add_task(spec)
-    testtask = client.wait_task(testtask.key)
-    assert testtask.status == Status.done
+    testtask = client.wait_task(testtask.key, raises=True)
     assert testtask.error_type is None
     performance = client.get_task_output_asset(testtask.key, OutputIdentifiers.performance)
     assert performance.asset == pytest.approx(32)
@@ -537,15 +537,17 @@ def test_aggregatetask(factory, client, default_metric, default_dataset, worker):
         worker=worker,
     )
     predicttask = client.add_task(spec)
-    predicttask = client.wait_task(predicttask.key)
+    # `raises=True` will fail if the task is not successful
+    client.wait_task(predicttask.key, raises=True)
 
     spec = factory.create_testtask(
         function=default_metric,
         inputs=default_dataset.test_data_inputs + FLTaskInputGenerator.predict_to_test(predicttask.key),
         worker=worker,
     )
     testtask = client.add_task(spec)
-    testtask = client.wait_task(testtask.key)
+    # `raises=True` will fail if the task is not successful
+    client.wait_task(testtask.key, raises=True)
 
 
 @pytest.mark.slow
@@ -591,8 +593,8 @@ def test_aggregatetask_chained(factory, client, default_dataset, worker):
     )
 
     aggregatetask_2 = client.add_task(spec)
-    aggregatetask_2 = client.wait_task(aggregatetask_2.key)
-    assert aggregatetask_2.status == Status.done
+    # `raises=True` will fail if the task is not successful
+    client.wait_task(aggregatetask_2.key, raises=True)
     assert aggregatetask_2.error_type is None
     assert len([i for i in aggregatetask_2.inputs if i.identifier == InputIdentifiers.models]) == 1
 
@@ -640,7 +642,8 @@ def test_aggregatetask_traintask(factory, client, default_dataset, worker):
     )
 
     traintask_2 = client.add_task(spec)
-    traintask_2 = client.wait_task(traintask_2.key)
+    # `raises=True` will fail if the task is not successful
+    traintask_2 = client.wait_task(traintask_2.key, raises=True)
 
     assert traintask_2.status == Status.done
     assert traintask_2.error_type is None
@@ -683,9 +686,8 @@ def test_composite_traintask_2_organizations_to_composite_traintask(factory, cli
         worker=workers[0],
     )
     composite_traintask = clients[0].add_task(spec)
-    composite_traintask = clients[0].wait_task(composite_traintask.key)
-
-    assert composite_traintask.status == Status.done
+    # `raises=True` will fail if the task is not successful
+    clients[0].wait_task(composite_traintask.key, raises=True)
 
 
 @pytest.mark.slow
@@ -758,7 +760,7 @@ def test_aggregate_composite_traintasks(factory, network, clients, default_datas
         )
 
         t = clients[0].add_task(spec)
-        t = clients[0].wait_task(t.key)
+        clients[0].wait_task(t.key, raises=True)
         composite_traintask_keys.append(t.key)
 
         # create aggregate on its organization
@@ -768,7 +770,7 @@ def test_aggregate_composite_traintasks(factory, network, clients, default_datas
            inputs=FLTaskInputGenerator.composites_to_aggregate(composite_traintask_keys),
        )
        aggregatetask = clients[0].add_task(spec)
-       aggregatetask = clients[0].wait_task(aggregatetask.key)
+       clients[0].wait_task(aggregatetask.key, raises=True)
 
        # save state of round
        previous_aggregatetask_key = aggregatetask.key
@@ -784,15 +786,15 @@ def test_aggregate_composite_traintasks(factory, network, clients, default_datas
             worker=workers[index],
         )
         predicttask = clients[0].add_task(spec)
-        predicttask = clients[0].wait_task(predicttask.key)
+        clients[0].wait_task(predicttask.key, raises=True)
 
         spec = factory.create_testtask(
             function=metric,
             inputs=dataset.test_data_inputs + FLTaskInputGenerator.predict_to_test(predicttask.key),
             worker=workers[index],
         )
         testtask = clients[0].add_task(spec)
-        testtask = clients[0].wait_task(testtask.key)
+        clients[0].wait_task(testtask.key, raises=True)
         # y_true: [20], y_pred: [52.0], result: 32.0
         performance = clients[0].get_task_output_asset(testtask.key, OutputIdentifiers.performance)
         assert performance.asset == pytest.approx(32 + index)
@@ -804,15 +806,15 @@ def test_aggregate_composite_traintasks(factory, network, clients, default_datas
         worker=workers[0],
     )
     predicttask = clients[0].add_task(spec)
-    predicttask = clients[0].wait_task(predicttask.key)
+    clients[0].wait_task(predicttask.key, raises=True)
 
     spec = factory.create_testtask(
         function=default_metrics[0],
         inputs=default_datasets[0].test_data_inputs + FLTaskInputGenerator.predict_to_test(predicttask.key),
         worker=workers[0],
     )
     testtask = clients[0].add_task(spec)
-    testtask = clients[0].wait_task(testtask.key)
+    clients[0].wait_task(testtask.key, raises=True)
     # y_true: [20], y_pred: [28.0], result: 8.0
     performance = clients[0].get_task_output_asset(testtask.key, OutputIdentifiers.performance)
     assert performance.asset == pytest.approx(8)
@@ -873,8 +875,7 @@ def test_use_data_sample_located_in_shared_path(factory, network, client, organi
 
     spec = factory.create_traintask(function=function, inputs=dataset.train_data_inputs, worker=worker)
     traintask = client.add_task(spec)
-    traintask = client.wait_task(traintask.key)
-    assert traintask.status == Status.done
+    traintask = client.wait_task(traintask.key, raises=True)
     assert traintask.error_type is None
 
     # Raises an exception if the output asset has not been created
@@ -885,16 +886,14 @@ def test_use_data_sample_located_in_shared_path(factory, network, client, organi
         function=predict_function, traintask=traintask, dataset=dataset, data_samples=[data_sample_key], worker=worker
     )
     predicttask = client.add_task(spec)
-    predicttask = client.wait_task(predicttask.key)
-    assert predicttask.status == Status.done
+    predicttask = client.wait_task(predicttask.key, raises=True)
     assert predicttask.error_type is None
 
     spec = factory.create_testtask(
         function=default_metric, predicttask=predicttask, dataset=dataset, data_samples=[data_sample_key], worker=worker
     )
     testtask = client.add_task(spec)
-    testtask = client.wait_task(testtask.key)
-    assert testtask.status == Status.done
+    testtask = client.wait_task(testtask.key, raises=True)
     assert testtask.error_type is None
     performance = client.get_task_output_asset(testtask.key, OutputIdentifiers.performance)
     assert performance.asset == pytest.approx(2)
@@ -946,7 +945,7 @@ def save_predictions(predictions, path):
     function = client.add_function(spec)
     spec = factory.create_traintask(function=function, inputs=default_dataset.train_data_inputs, worker=worker)
     traintask = client.add_task(spec)
-    client.wait_task(traintask.key)
+    client.wait_task(traintask.key, raises=True)
 
 
 WRITE_TO_HOME_DIRECTORY_FUNCTION = f"""
@@ -997,7 +996,6 @@ def test_write_to_home_directory(factory, client, default_dataset, worker):
     function = client.add_function(spec)
     spec = factory.create_traintask(function=function, inputs=default_dataset.train_data_inputs, worker=worker)
     traintask = client.add_task(spec)
-    traintask = client.wait_task(traintask.key)
+    traintask = client.wait_task(traintask.key, raises=True)
 
-    assert traintask.status == Status.done
     assert traintask.error_type is None
4 changes: 2 additions & 2 deletions tests/test_execution_compute_plan.py
@@ -648,8 +648,8 @@ def test_execution_compute_plan_canceled(factory, client, default_dataset, cfg,
     # wait the first traintask to be executed to ensure that the compute plan is launched
     # and tasks are scheduled in the celery workers
     first_traintask = [t for t in client.list_compute_plan_tasks(cp.key) if t.rank == 0][0]
-    first_traintask = client.wait_task(first_traintask.key)
-    assert first_traintask.status == models.Status.done
+    # `raises=True` will fail if the task is not successful
+    client.wait_task(first_traintask.key, raises=True)
 
     client.cancel_compute_plan(cp.key)
     # as the cancel request does not directly update localrep, we need to wait for the sync
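Where a test expects a task to fail rather than succeed, the same flag composes with `pytest.raises`. A hedged sketch using the same stand-in `wait_task` semantics as the example above — the exception class the real SDK raises is an assumption here, not confirmed by this commit:

import pytest


class TaskFailedError(Exception):
    """Stand-in for the SDK's failure exception."""


def wait_task(status: str, raises: bool = False) -> str:
    # Stand-in for client.wait_task: raise only when asked to, mirroring `raises=True`.
    if raises and status != "done":
        raise TaskFailedError(f"task ended with status {status}")
    return status


def test_failed_task_surfaces_at_the_wait_call():
    # The failure surfaces at the wait call itself instead of a later status assertion.
    with pytest.raises(TaskFailedError):
        wait_task("failed", raises=True)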