Commit

chore: rename `raises` to `raise_on_failure`
Signed-off-by: Guilhem Barthes <guilhem.barthes@owkin.com>
guilhem-barthes committed Jun 29, 2023
1 parent 2f5aff0 commit f05e9ad
Showing 6 changed files with 50 additions and 50 deletions.
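The change is mechanical: every call site swaps the `raises` keyword for `raise_on_failure`, with unchanged behavior. A minimal sketch of the renamed keyword as these tests use it (`client` and `spec` stand in for the pytest fixtures of the files below, so this is an illustration rather than a runnable test):

# Assumed: a configured substra Client (`client`) and a task spec (`spec`),
# as provided by the pytest fixtures in the test modules below.
traintask = client.add_task(spec)

# After this commit: raise an exception if the task does not finish successfully.
traintask = client.wait_task(traintask.key, raise_on_failure=True)

# Before this commit, the same call was spelled:
# traintask = client.wait_task(traintask.key, raises=True)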
10 changes: 5 additions & 5 deletions tests/test_data_samples_order.py
@@ -208,7 +208,7 @@ def test_task_data_samples_relative_order(factory, client, dataset, worker):
i.asset_key for i in traintask.inputs if i.identifier == InputIdentifiers.datasamples
] == dataset.train_data_sample_keys
# `raise_on_failure = True`, will raise if the task is not successful
-client.wait_task(traintask.key, raises=True)
+client.wait_task(traintask.key, raise_on_failure=True)

predict_input_models = FLTaskInputGenerator.train_to_predict(traintask.key)
predicttask_spec = factory.create_predicttask(
@@ -225,7 +225,7 @@ def test_task_data_samples_relative_order(factory, client, dataset, worker):

# Assert order is correct in the metric. If not, wait_task() will fail.
# `raise_on_failure = True`, will raise if the task is not successful
-client.wait_task(testtask.key, raises=True)
+client.wait_task(testtask.key, raise_on_failure=True)


def test_composite_traintask_data_samples_relative_order(factory, client, dataset, worker):
@@ -265,7 +265,7 @@ def test_composite_traintask_data_samples_relative_order(factory, client, dataset, worker):
i.asset_key for i in composite_traintask.inputs if i.identifier == InputIdentifiers.datasamples
] == dataset.train_data_sample_keys
# `raise_on_failure = True`, will raise if the task is not successful
-client.wait_task(composite_traintask.key, raises=True)
+client.wait_task(composite_traintask.key, raise_on_failure=True)

predict_input_models = FLTaskInputGenerator.composite_to_predict(composite_traintask.key)

@@ -283,7 +283,7 @@ def test_composite_traintask_data_samples_relative_order(factory, client, dataset, worker):

# Assert order is correct in the metric. If not, wait_task() will fail.
# `raise_on_failure = True`, will raise if the task is not successful
-client.wait_task(testtask.key, raises=True)
+client.wait_task(testtask.key, raise_on_failure=True)


@pytest.mark.slow
@@ -357,4 +357,4 @@ def save_predictions(predictions, path):
traintask = client.add_task(spec)

# `raise_on_failure = True`, will raise if the task is not successful
-client.wait_task(traintask.key, raises=True)
+client.wait_task(traintask.key, raise_on_failure=True)
2 changes: 1 addition & 1 deletion tests/test_docker_image_build.py
@@ -32,4 +32,4 @@ def test_base_substra_tools_image(factory, cfg, client, default_dataset, worker)
)
traintask = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-client.wait_task(traintask.key, raises=True)
+client.wait_task(traintask.key, raise_on_failure=True)
70 changes: 35 additions & 35 deletions tests/test_execution.py
@@ -38,7 +38,7 @@ def get_traintask_spec() -> TaskSpec:

spec = get_traintask_spec()
traintask = client.add_task(spec)
-traintask = client.wait_task(traintask.key, raises=True)
+traintask = client.wait_task(traintask.key, raise_on_failure=True)
assert traintask.error_type is None
assert traintask.metadata == {"foo": "bar"}
assert len(traintask.outputs) == 1
@@ -62,7 +62,7 @@ def get_traintask_spec() -> TaskSpec:

predicttask = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-predicttask = client.wait_task(predicttask.key, raises=True)
+predicttask = client.wait_task(predicttask.key, raise_on_failure=True)
assert predicttask.error_type is None

spec = factory.create_testtask(
@@ -72,7 +72,7 @@ def get_traintask_spec() -> TaskSpec:
)
testtask = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-testtask = client.wait_task(testtask.key, raises=True)
+testtask = client.wait_task(testtask.key, raise_on_failure=True)
assert testtask.error_type is None
performance = client.get_task_output_asset(testtask.key, OutputIdentifiers.performance)
assert performance.asset == pytest.approx(2)
@@ -87,7 +87,7 @@ def get_traintask_spec() -> TaskSpec:
)
traintask = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-traintask = client.wait_task(traintask.key, raises=True)
+traintask = client.wait_task(traintask.key, raise_on_failure=True)
assert traintask.error_type is None
assert traintask.metadata == {}

@@ -122,7 +122,7 @@ def test_federated_learning_workflow(factory, client, default_datasets, workers)
worker=workers[index],
)
traintask = client.add_task(spec)
-traintask = client.wait_task(traintask.key, raises=True)
+traintask = client.wait_task(traintask.key, raise_on_failure=True)
out_models = client.get_task_models(traintask.key)
assert traintask.error_type is None
assert len(traintask.outputs) == 1
@@ -169,7 +169,7 @@ def test_tasks_execution_on_different_organizations(
)
traintask = client_1.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-traintask = client_1.wait_task(traintask.key, raises=True)
+traintask = client_1.wait_task(traintask.key, raise_on_failure=True)
assert traintask.error_type is None
assert len(traintask.outputs) == 1
# Raises an exception if the output asset has not been created
@@ -184,7 +184,7 @@ def test_tasks_execution_on_different_organizations(
)
predicttask = client_1.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-predicttask = client_1.wait_task(predicttask.key, raises=True)
+predicttask = client_1.wait_task(predicttask.key, raise_on_failure=True)
assert predicttask.error_type is None
assert predicttask.worker == client_1.organization_id

@@ -195,7 +195,7 @@ def test_tasks_execution_on_different_organizations(
)
testtask = client_1.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-testtask = client_1.wait_task(testtask.key, raises=True)
+testtask = client_1.wait_task(testtask.key, raise_on_failure=True)
assert testtask.error_type is None
assert testtask.worker == client_1.organization_id
performance = client_2.get_task_output_asset(testtask.key, OutputIdentifiers.performance)
@@ -221,7 +221,7 @@ def test_function_build_failure(factory, network, default_dataset_1, worker):
network.clients[0].add_task(spec)
else:
traintask = network.clients[0].add_task(spec)
-traintask = network.clients[0].wait_task(traintask.key, raises=False)
+traintask = network.clients[0].wait_task(traintask.key, raise_on_failure=False)

assert traintask.status == Status.failed
assert traintask.error_type == substra.sdk.models.TaskErrorType.build
@@ -248,7 +248,7 @@ def test_task_execution_failure(factory, network, default_dataset_1, worker):
network.clients[0].add_task(spec)
else:
traintask = network.clients[0].add_task(spec)
-traintask = network.clients[0].wait_task(traintask.key, raises=False)
+traintask = network.clients[0].wait_task(traintask.key, raise_on_failure=False)

assert traintask.status == Status.failed
assert traintask.error_type == substra.sdk.models.TaskErrorType.execution
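The failure tests exercise the other mode: with `raise_on_failure=False`, `wait_task` returns the finished task object even when the task failed, so the test inspects the failure itself instead of catching an exception. A condensed sketch of that pattern, taken from test_task_execution_failure above (the `network` fixture and `spec` are assumed; the Status import path is inferred from usage in these tests):

import substra
from substra.sdk.models import Status  # import path inferred from usage above

traintask = network.clients[0].add_task(spec)
# Wait without raising, then assert on the recorded failure.
traintask = network.clients[0].wait_task(traintask.key, raise_on_failure=False)

assert traintask.status == Status.failed
assert traintask.error_type == substra.sdk.models.TaskErrorType.execution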
@@ -296,7 +296,7 @@ def score(inputs, outputs, task_properties):
)
traintask = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-traintask = client.wait_task(traintask.key, raises=True)
+traintask = client.wait_task(traintask.key, raise_on_failure=True)

# add predicttask
spec = factory.create_predicttask(
@@ -306,7 +306,7 @@ def score(inputs, outputs, task_properties):
)
predicttask = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-predicttask = client.wait_task(predicttask.key, raises=True)
+predicttask = client.wait_task(predicttask.key, raise_on_failure=True)

# add metric function
spec = factory.create_function(category=FunctionCategory.metric, py_script=custom_metric_script)
@@ -329,7 +329,7 @@ def score(inputs, outputs, task_properties):

testtask = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-testtask = client.wait_task(testtask.key, raises=True)
+testtask = client.wait_task(testtask.key, raise_on_failure=True)
assert testtask.error_type is None
output_1 = client.get_task_output_asset(testtask.key, identifier_1)
output_2 = client.get_task_output_asset(testtask.key, identifier_2)
@@ -363,7 +363,7 @@ def test_composite_traintask_execution_failure(factory, client, default_dataset,
)
if client.backend_mode == substra.BackendType.REMOTE:
composite_traintask = client.add_task(spec)
-composite_traintask = client.wait_task(composite_traintask.key, raises=False)
+composite_traintask = client.wait_task(composite_traintask.key, raise_on_failure=False)

assert composite_traintask.status == Status.failed
assert composite_traintask.error_type == substra.sdk.models.TaskErrorType.execution
@@ -409,7 +409,7 @@ def test_aggregatetask_execution_failure(factory, client, default_dataset, worke

if client.backend_mode == substra.BackendType.REMOTE:
aggregatetask = client.add_task(spec)
-aggregatetask = client.wait_task(aggregatetask.key, raises=False)
+aggregatetask = client.wait_task(aggregatetask.key, raise_on_failure=False)

for composite_traintask_key in composite_traintask_keys:
composite_traintask = client.get_task(composite_traintask_key)
@@ -448,7 +448,7 @@ def test_composite_traintasks_execution(factory, client, default_dataset, defaul
)
composite_traintask_1 = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-composite_traintask_1 = client.wait_task(composite_traintask_1.key, raises=True)
+composite_traintask_1 = client.wait_task(composite_traintask_1.key, raise_on_failure=True)
assert composite_traintask_1.error_type is None
assert len(composite_traintask_1.outputs) == 2

@@ -461,7 +461,7 @@ def test_composite_traintasks_execution(factory, client, default_dataset, defaul
)
composite_traintask_2 = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-composite_traintask_2 = client.wait_task(composite_traintask_2.key, raises=True)
+composite_traintask_2 = client.wait_task(composite_traintask_2.key, raise_on_failure=True)
assert composite_traintask_2.error_type is None
assert len(composite_traintask_2.outputs) == 2

@@ -473,7 +473,7 @@ def test_composite_traintasks_execution(factory, client, default_dataset, defaul
)
predicttask = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-predicttask = client.wait_task(predicttask.key, raises=True)
+predicttask = client.wait_task(predicttask.key, raise_on_failure=True)
assert predicttask.status == Status.done
assert predicttask.error_type is None

@@ -483,7 +483,7 @@ def test_composite_traintasks_execution(factory, client, default_dataset, defaul
worker=worker,
)
testtask = client.add_task(spec)
-testtask = client.wait_task(testtask.key, raises=True)
+testtask = client.wait_task(testtask.key, raise_on_failure=True)
assert testtask.error_type is None
performance = client.get_task_output_asset(testtask.key, OutputIdentifiers.performance)
assert performance.asset == pytest.approx(32)
@@ -538,7 +538,7 @@ def test_aggregatetask(factory, client, default_metric, default_dataset, worker)
)
predicttask = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-client.wait_task(predicttask.key, raises=True)
+client.wait_task(predicttask.key, raise_on_failure=True)

spec = factory.create_testtask(
function=default_metric,
@@ -547,7 +547,7 @@ def test_aggregatetask(factory, client, default_metric, default_dataset, worker)
)
testtask = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-client.wait_task(testtask.key, raises=True)
+client.wait_task(testtask.key, raise_on_failure=True)


@pytest.mark.slow
@@ -594,7 +594,7 @@ def test_aggregatetask_chained(factory, client, default_dataset, worker):

aggregatetask_2 = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-client.wait_task(aggregatetask_2.key, raises=True)
+client.wait_task(aggregatetask_2.key, raise_on_failure=True)
assert aggregatetask_2.error_type is None
assert len([i for i in aggregatetask_2.inputs if i.identifier == InputIdentifiers.models]) == 1

@@ -643,7 +643,7 @@ def test_aggregatetask_traintask(factory, client, default_dataset, worker):

traintask_2 = client.add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-traintask_2 = client.wait_task(traintask_2.key, raises=True)
+traintask_2 = client.wait_task(traintask_2.key, raise_on_failure=True)

assert traintask_2.status == Status.done
assert traintask_2.error_type is None
@@ -687,7 +687,7 @@ def test_composite_traintask_2_organizations_to_composite_traintask(factory, cli
)
composite_traintask = clients[0].add_task(spec)
# `raise_on_failure = True`, will raise if the task is not successful
-clients[0].wait_task(composite_traintask.key, raises=True)
+clients[0].wait_task(composite_traintask.key, raise_on_failure=True)


@pytest.mark.slow
@@ -760,7 +760,7 @@ def test_aggregate_composite_traintasks(factory, network, clients, default_datas
)

t = clients[0].add_task(spec)
-clients[0].wait_task(t.key, raises=True)
+clients[0].wait_task(t.key, raise_on_failure=True)
composite_traintask_keys.append(t.key)

# create aggregate on its organization
@@ -770,7 +770,7 @@ def test_aggregate_composite_traintasks(factory, network, clients, default_datas
inputs=FLTaskInputGenerator.composites_to_aggregate(composite_traintask_keys),
)
aggregatetask = clients[0].add_task(spec)
-clients[0].wait_task(aggregatetask.key, raises=True)
+clients[0].wait_task(aggregatetask.key, raise_on_failure=True)

# save state of round
previous_aggregatetask_key = aggregatetask.key
@@ -786,15 +786,15 @@ def test_aggregate_composite_traintasks(factory, network, clients, default_datas
worker=workers[index],
)
predicttask = clients[0].add_task(spec)
-clients[0].wait_task(predicttask.key, raises=True)
+clients[0].wait_task(predicttask.key, raise_on_failure=True)

spec = factory.create_testtask(
function=metric,
inputs=dataset.test_data_inputs + FLTaskInputGenerator.predict_to_test(predicttask.key),
worker=workers[index],
)
testtask = clients[0].add_task(spec)
-clients[0].wait_task(testtask.key, raises=True)
+clients[0].wait_task(testtask.key, raise_on_failure=True)
# y_true: [20], y_pred: [52.0], result: 32.0
performance = clients[0].get_task_output_asset(testtask.key, OutputIdentifiers.performance)
assert performance.asset == pytest.approx(32 + index)
@@ -806,15 +806,15 @@ def test_aggregate_composite_traintasks(factory, network, clients, default_datas
worker=workers[0],
)
predicttask = clients[0].add_task(spec)
-clients[0].wait_task(predicttask.key, raises=True)
+clients[0].wait_task(predicttask.key, raise_on_failure=True)

spec = factory.create_testtask(
function=default_metrics[0],
inputs=default_datasets[0].test_data_inputs + FLTaskInputGenerator.predict_to_test(predicttask.key),
worker=workers[0],
)
testtask = clients[0].add_task(spec)
-clients[0].wait_task(testtask.key, raises=True)
+clients[0].wait_task(testtask.key, raise_on_failure=True)
# y_true: [20], y_pred: [28.0], result: 8.0
performance = clients[0].get_task_output_asset(testtask.key, OutputIdentifiers.performance)
assert performance.asset == pytest.approx(8)
@@ -875,7 +875,7 @@ def test_use_data_sample_located_in_shared_path(factory, network, client, organi

spec = factory.create_traintask(function=function, inputs=dataset.train_data_inputs, worker=worker)
traintask = client.add_task(spec)
-traintask = client.wait_task(traintask.key, raises=True)
+traintask = client.wait_task(traintask.key, raise_on_failure=True)
assert traintask.error_type is None

# Raises an exception if the output asset has not been created
@@ -886,14 +886,14 @@ def test_use_data_sample_located_in_shared_path(factory, network, client, organi
function=predict_function, traintask=traintask, dataset=dataset, data_samples=[data_sample_key], worker=worker
)
predicttask = client.add_task(spec)
-predicttask = client.wait_task(predicttask.key, raises=True)
+predicttask = client.wait_task(predicttask.key, raise_on_failure=True)
assert predicttask.error_type is None

spec = factory.create_testtask(
function=default_metric, predicttask=predicttask, dataset=dataset, data_samples=[data_sample_key], worker=worker
)
testtask = client.add_task(spec)
-testtask = client.wait_task(testtask.key, raises=True)
+testtask = client.wait_task(testtask.key, raise_on_failure=True)
assert testtask.error_type is None
performance = client.get_task_output_asset(testtask.key, OutputIdentifiers.performance)
assert performance.asset == pytest.approx(2)
@@ -945,7 +945,7 @@ def save_predictions(predictions, path):
function = client.add_function(spec)
spec = factory.create_traintask(function=function, inputs=default_dataset.train_data_inputs, worker=worker)
traintask = client.add_task(spec)
-client.wait_task(traintask.key, raises=True)
+client.wait_task(traintask.key, raise_on_failure=True)


WRITE_TO_HOME_DIRECTORY_FUNCTION = f"""
@@ -996,6 +996,6 @@ def test_write_to_home_directory(factory, client, default_dataset, worker):
function = client.add_function(spec)
spec = factory.create_traintask(function=function, inputs=default_dataset.train_data_inputs, worker=worker)
traintask = client.add_task(spec)
-traintask = client.wait_task(traintask.key, raises=True)
+traintask = client.wait_task(traintask.key, raise_on_failure=True)

assert traintask.error_type is None
6 changes: 3 additions & 3 deletions tests/test_execution_compute_plan.py
@@ -436,7 +436,7 @@ def test_compute_plan_single_client_failure(factory, client, default_dataset, de

# Submit compute plan and wait for it to complete
cp_added = client.add_compute_plan(cp_spec)
-cp = client.wait_compute_plan(cp_added.key, raises=False)
+cp = client.wait_compute_plan(cp_added.key, raise_on_failure=False)

assert cp.status == "PLAN_STATUS_FAILED"
assert cp.end_date is not None
@@ -649,11 +649,11 @@ def test_execution_compute_plan_canceled(factory, client, default_dataset, cfg,
# and tasks are scheduled in the celery workers
first_traintask = [t for t in client.list_compute_plan_tasks(cp.key) if t.rank == 0][0]
# `raise_on_failure = True`, will raise if the task is not successful
-client.wait_task(first_traintask.key, raises=True)
+client.wait_task(first_traintask.key, raise_on_failure=True)

client.cancel_compute_plan(cp.key)
# as cancel requests do not directly update localrep, we need to wait for the sync
-cp = client.wait_compute_plan(cp.key, raises=False, timeout=cfg.options.organization_sync_timeout)
+cp = client.wait_compute_plan(cp.key, raise_on_failure=False, timeout=cfg.options.organization_sync_timeout)
assert cp.status == models.ComputePlanStatus.canceled
assert cp.end_date is not None
assert cp.duration is not None
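This commit only updates call sites in the test suite; it does not show how the SDK itself handled the transition. For illustration, a hypothetical keyword alias (not part of this commit, and not necessarily what substra's SDK does) that would let old `raises=` call sites keep working with a deprecation warning during such a rename:

import functools
import warnings


def renamed_kwarg(old: str, new: str):
    """Map a deprecated keyword argument onto its new name, with a warning."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old in kwargs:
                if new in kwargs:
                    raise TypeError(f"got values for both {old!r} and {new!r}")
                warnings.warn(
                    f"{old!r} is deprecated; use {new!r} instead",
                    DeprecationWarning,
                    stacklevel=2,
                )
                # Forward the old keyword's value under the new name.
                kwargs[new] = kwargs.pop(old)
            return func(*args, **kwargs)

        return wrapper

    return decorator


# Hypothetical use on a wait_task-like function; names are illustrative only.
@renamed_kwarg("raises", "raise_on_failure")
def wait_task(key, raise_on_failure=False):
    ...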