Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #7299: dbt retry #7763

Merged
merged 15 commits into from
Jun 5, 2023
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230602-083302.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: dbt retry
time: 2023-06-02T08:33:02.410456-07:00
custom:
Author: stu-k aranke
Issue: "7299"
5 changes: 4 additions & 1 deletion core/dbt/cli/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,9 @@ def add_fn(x):

spinal_cased = k.replace("_", "-")

if v in (None, False):
if k == "macro" and command == CliCommand.RUN_OPERATION:
add_fn(v)
elif v in (None, False):
add_fn(f"--no-{spinal_cased}")
elif v is True:
add_fn(f"--{spinal_cased}")
Expand Down Expand Up @@ -384,6 +386,7 @@ def command_args(command: CliCommand) -> ArgsList:
CliCommand.SNAPSHOT: cli.snapshot,
CliCommand.SOURCE_FRESHNESS: cli.freshness,
CliCommand.TEST: cli.test,
CliCommand.RETRY: cli.retry,
}
click_cmd: Optional[ClickCommand] = CMD_DICT.get(command, None)
if click_cmd is None:
Expand Down
32 changes: 32 additions & 0 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from dbt.task.generate import GenerateTask
from dbt.task.init import InitTask
from dbt.task.list import ListTask
from dbt.task.retry import RetryTask
from dbt.task.run import RunTask
from dbt.task.run_operation import RunOperationTask
from dbt.task.seed import SeedTask
Expand Down Expand Up @@ -576,6 +577,36 @@ def run(ctx, **kwargs):
return results, success


# dbt retry
dbeatty10 marked this conversation as resolved.
Show resolved Hide resolved
@cli.command("retry")
@click.pass_context
@p.project_dir
@p.profiles_dir
@p.vars
@p.profile
@p.target
@p.state
@p.threads
@p.fail_fast
@requires.postflight
@requires.preflight
@requires.profile
@requires.project
@requires.runtime_config
@requires.manifest
def retry(ctx, **kwargs):
    """Retry the nodes that failed in the previous run."""
    # The task is built from the flags, runtime config and manifest stored on
    # the click context object.
    shared = ctx.obj
    retry_task = RetryTask(shared["flags"], shared["runtime_config"], shared["manifest"])

    # Return the raw results together with the success verdict derived from them.
    outcome = retry_task.run()
    return outcome, retry_task.interpret_results(outcome)


# dbt run operation
@cli.command("run-operation")
@click.pass_context
Expand All @@ -586,6 +617,7 @@ def run(ctx, **kwargs):
@p.project_dir
@p.target
@p.target_path
@p.threads
@p.vars
@requires.postflight
@requires.preflight
Expand Down
1 change: 1 addition & 0 deletions core/dbt/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Command(Enum):
SNAPSHOT = "snapshot"
SOURCE_FRESHNESS = "freshness"
TEST = "test"
RETRY = "retry"

@classmethod
def from_str(cls, s: str) -> "Command":
Expand Down
113 changes: 113 additions & 0 deletions core/dbt/task/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from pathlib import Path

from dbt.cli.flags import Flags
from dbt.cli.types import Command as CliCommand
from dbt.config import RuntimeConfig
from dbt.contracts.results import NodeStatus
from dbt.contracts.state import PreviousState
from dbt.exceptions import DbtRuntimeError
from dbt.graph import GraphQueue
from dbt.task.base import ConfiguredTask
from dbt.task.build import BuildTask
from dbt.task.compile import CompileTask
from dbt.task.generate import GenerateTask
from dbt.task.run import RunTask
from dbt.task.run_operation import RunOperationTask
from dbt.task.seed import SeedTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask

# Result statuses from the previous run that make a node eligible for retry.
# NOTE(review): RetryTask.__init__ adds NodeStatus.Warn to this module-level set
# when warn_error is set, so that addition is shared by everything that reads it
# afterwards in the same process.
RETRYABLE_STATUSES = {NodeStatus.Error, NodeStatus.Fail, NodeStatus.Skipped, NodeStatus.RuntimeErr}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing that @MichelleArk noticed while investigating #7744 (fix in #7767): If a model is cancelled because of --fail-fast, its downstream dependencies will not be recorded as "skipped" in run_results.json, and so it won't be retried by dbt retry.

(Repro case: model_a -> model_b (with error) -> model_c. dbt run -x && dbt retry. model_c never runs.)

Would it make sense to address that here, or to open as a new/separate issue?

Copy link
Contributor

@ChenyuLInx ChenyuLInx Jun 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have two options:

  1. record that in run_results.json.
  2. in retry, do something to select the downstream models.

I prefer the first one and if that's the solution we pick we should do it in a separate PR.


# Maps the command name stored under "which" in the previous run's args to the
# task class that RetryTask re-runs. Commands without an entry here resolve to
# None via TASK_DICT.get(...).
TASK_DICT = {
"build": BuildTask,
"compile": CompileTask,
"generate": GenerateTask,
"seed": SeedTask,
"snapshot": SnapshotTask,
"test": TestTask,
"run": RunTask,
"run-operation": RunOperationTask,
}

# Maps the previous command name to the CliCommand member that RetryTask.run
# passes to Flags.from_dict when rebuilding the flags for the retried command.
CMD_DICT = {
aranke marked this conversation as resolved.
Show resolved Hide resolved
"build": CliCommand.BUILD,
"compile": CliCommand.COMPILE,
"generate": CliCommand.DOCS_GENERATE,
"seed": CliCommand.SEED,
"snapshot": CliCommand.SNAPSHOT,
"test": CliCommand.TEST,
"run": CliCommand.RUN,
"run-operation": CliCommand.RUN_OPERATION,
}


class RetryTask(ConfiguredTask):
# Load the previous run's state and record which command (and task class) produced it.
# Raises DbtRuntimeError when no previous results are found at the state path.
def __init__(self, args, config, manifest):
super().__init__(args, config, manifest)

# Previous results are read from --state when given, otherwise from the target path.
state_path = self.args.state or self.config.target_path

# NOTE(review): this mutates the module-level RETRYABLE_STATUSES set rather than a
# per-instance copy, so the added Warn status outlives this RetryTask.
if self.args.warn_error:
RETRYABLE_STATUSES.add(NodeStatus.Warn)
Comment on lines +51 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is:

  • If the previous command passed --warn-error, then we will rerun tests or freshness checks that warned
  • Otherwise, we will not rerun those warn-status tests/checks

I buy that - let's be sure to document it.


self.previous_state = PreviousState(
aranke marked this conversation as resolved.
Show resolved Hide resolved
state_path=Path(state_path),
target_path=Path(self.config.target_path),
project_root=Path(self.config.project_root),
)

# Without previous results there is nothing to retry.
if not self.previous_state.results:
raise DbtRuntimeError(
f"Could not find previous run in '{state_path}' target directory"
)

self.previous_args = self.previous_state.results.args
self.previous_command_name = self.previous_args.get("which")
# None when the previous command has no entry in TASK_DICT.
self.task_class = TASK_DICT.get(self.previous_command_name)

def run(self):
    """Re-run the previous command, limited to the nodes that did not succeed.

    The previous command's args are turned back into flags and a runtime
    config, and the previous command's task class is subclassed so that its
    graph queue only contains the nodes selected for retry.

    Returns:
        Whatever the wrapped task's ``run()`` returns.

    Raises:
        DbtRuntimeError: if the previous command has no entry in the
            retryable-command lookup tables.
    """
    cli_command = CMD_DICT.get(self.previous_command_name)

    # Fail with a clear message instead of an opaque error further down
    # (Flags.from_dict(None, ...) or subclassing None).
    if self.task_class is None or cli_command is None:
        raise DbtRuntimeError(
            f"Cannot retry previous command '{self.previous_command_name}'"
        )

    # Decide per call whether Warn is retryable, using a private copy, so the
    # outcome depends only on this task's own warn_error flag and not on any
    # earlier mutation of the shared module-level set.
    retryable_statuses = set(RETRYABLE_STATUSES)
    retryable_statuses.discard(NodeStatus.Warn)
    if self.args.warn_error:
        retryable_statuses.add(NodeStatus.Warn)

    unique_ids = {
        result.unique_id
        for result in self.previous_state.results.results
        if result.status in retryable_statuses
    }

    # Remove these args when their default values are present, otherwise they'll raise an exception
    # ("show" is dropped whenever it is present, regardless of its value.)
    args_to_remove = {
        "show": lambda x: True,
        "resource_types": lambda x: x == [],
        "warn_error_options": lambda x: x == {"exclude": [], "include": []},
    }

    for arg_name, should_remove in args_to_remove.items():
        if arg_name in self.previous_args and should_remove(self.previous_args[arg_name]):
            del self.previous_args[arg_name]

    retry_flags = Flags.from_dict(cli_command, self.previous_args)
    retry_config = RuntimeConfig.from_args(args=retry_flags)

    class TaskWrapper(self.task_class):
        # Restrict the graph queue to the nodes selected for retry.
        def get_graph_queue(self):
            new_graph = self.graph.get_subset_graph(unique_ids)
            return GraphQueue(
                new_graph.graph,
                self.manifest,
                unique_ids,
            )

    task = TaskWrapper(
        retry_flags,
        retry_config,
        self.manifest,
    )

    return task.run()

def interpret_results(self, *args, **kwargs):
    """Delegate result interpretation to the task class of the retried command."""
    wrapped_task_class = self.task_class
    return wrapped_task_class.interpret_results(*args, **kwargs)
9 changes: 6 additions & 3 deletions core/dbt/task/run_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def _run_unsafe(self) -> agate.Table:
def run(self) -> RunResultsArtifact:
start = datetime.utcnow()
self.compile_manifest()

success = True

try:
self._run_unsafe()
except dbt.exceptions.Exception as exc:
Expand All @@ -59,8 +62,7 @@ def run(self) -> RunResultsArtifact:
fire_event(RunningOperationUncaughtError(exc=str(exc)))
fire_event(LogDebugStackTrace(exc_info=traceback.format_exc()))
success = False
else:
success = True

end = datetime.utcnow()

package_name, macro_name = self._get_macro_parts()
Expand Down Expand Up @@ -108,5 +110,6 @@ def run(self) -> RunResultsArtifact:

return results

def interpret_results(self, results):
@classmethod
aranke marked this conversation as resolved.
Show resolved Hide resolved
def interpret_results(cls, results):
return results.results[0].status == RunStatus.Success
47 changes: 47 additions & 0 deletions tests/functional/retry/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# SQL model fixtures for the retry functional tests.
# NOTE(review): `baz` is not defined anywhere in sample_model's query — presumably
# deliberate so the model errors and there is something to retry; confirm against
# the tests that use this fixture.
models__sample_model = """select 1 as id, baz as foo"""
models__second_model = """select 1 as id, 2 as bar"""

# Joins the two models above on id and exposes foo + bar as sum3.
models__union_model = """
select foo + bar as sum3 from {{ ref('sample_model') }}
left join {{ ref('second_model') }} on sample_model.id = second_model.id
"""

# Schema fixture: one accepted_values test per model column. The tests on
# sample_model and second_model are configured with severity "warn"; the test on
# union_model has no severity override. YAML nesting is significant, so the
# indentation inside the literal must be preserved.
schema_yml = """
models:
  - name: sample_model
    columns:
      - name: foo
        tests:
          - accepted_values:
              values: [3]
              quote: false
              config:
                severity: warn
  - name: second_model
    columns:
      - name: bar
        tests:
          - accepted_values:
              values: [3]
              quote: false
              config:
                severity: warn
  - name: union_model
    columns:
      - name: sum3
        tests:
          - accepted_values:
              values: [3]
              quote: false
"""

# Macro fixture: `alter_timezone` issues `SET TimeZone='<timezone>'` through
# run_query and then logs the value it set. The timezone argument defaults to
# 'America/Los_Angeles'.
macros__alter_timezone_sql = """
{% macro alter_timezone(timezone='America/Los_Angeles') %}
{% set sql %}
SET TimeZone='{{ timezone }}';
{% endset %}

{% do run_query(sql) %}
{% do log("Timezone set to: " + timezone, info=True) %}
{% endmacro %}
"""
Loading