Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add run-operation + snapshot to the RPC server (#1875) #1878

Merged
merged 3 commits into from
Nov 1, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions core/dbt/contracts/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,15 @@ class GCParameters(RPCParameters):
will be applied to the task manager before GC starts. By default the
existing gc settings remain.
"""
task_ids: Optional[List[TaskID]]
before: Optional[datetime]
settings: Optional[GCSettings]
task_ids: Optional[List[TaskID]] = None
before: Optional[datetime] = None
settings: Optional[GCSettings] = None


@dataclass
class RPCRunOperationParameters(RPCParameters):
macro: str
args: Dict[str, Any] = field(default_factory=dict)


# Outputs
Expand Down Expand Up @@ -161,6 +167,11 @@ class ResultTable(JsonSchemaMixin):
rows: List[Any]


@dataclass
class RemoteRunOperationResult(RemoteResult):
success: bool


@dataclass
class RemoteRunResult(RemoteCompileResult):
table: ResultTable
Expand Down Expand Up @@ -431,6 +442,31 @@ def from_result(
)


@dataclass
class PollRunOperationCompleteResult(RemoteRunOperationResult, PollResult):
state: TaskHandlerState = field(
metadata=restrict_to(TaskHandlerState.Success,
TaskHandlerState.Failed),
)

@classmethod
def from_result(
cls: Type['PollRunOperationCompleteResult'],
base: RemoteRunOperationResult,
tags: TaskTags,
timing: TaskTiming,
) -> 'PollRunOperationCompleteResult':
return cls(
success=base.success,
logs=base.logs,
tags=tags,
state=timing.state,
start=timing.start,
end=timing.end,
elapsed=timing.elapsed,
)


@dataclass
class PollCatalogCompleteResult(RemoteCatalogResults, PollResult):
state: TaskHandlerState = field(
Expand Down
4 changes: 2 additions & 2 deletions core/dbt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def _build_snapshot_subparser(subparsers, base_subparser):
'''
)
sub.set_defaults(cls=snapshot_task.SnapshotTask, which='snapshot',
rpc_method=None)
rpc_method='snapshot')
return sub


Expand Down Expand Up @@ -707,7 +707,7 @@ def _build_run_operation_subparser(subparsers, base_subparser):
'''
)
sub.set_defaults(cls=run_operation_task.RunOperationTask,
which='run-operation', rpc_method=None)
which='run-operation', rpc_method='run-operation')
return sub


Expand Down
5 changes: 5 additions & 0 deletions core/dbt/rpc/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RemoteCompileResult,
RemoteCatalogResults,
RemoteEmptyResult,
RemoteRunOperationResult,
PollParameters,
PollResult,
PollInProgressResult,
Expand All @@ -30,6 +31,7 @@
PollCompileCompleteResult,
PollCatalogCompleteResult,
PollRemoteEmptyCompleteResult,
PollRunOperationCompleteResult,
TaskHandlerState,
TaskTiming,
)
Expand Down Expand Up @@ -141,6 +143,7 @@ def poll_complete(
PollCompileCompleteResult,
PollCatalogCompleteResult,
PollRemoteEmptyCompleteResult,
PollRunOperationCompleteResult,
]]

if isinstance(result, RemoteExecutionResult):
Expand All @@ -154,6 +157,8 @@ def poll_complete(
cls = PollCatalogCompleteResult
elif isinstance(result, RemoteEmptyResult):
cls = PollRemoteEmptyCompleteResult
elif isinstance(result, RemoteRunOperationResult):
cls = PollRunOperationCompleteResult
else:
raise dbt.exceptions.InternalException(
'got invalid result in poll_complete: {}'.format(result)
Expand Down
41 changes: 41 additions & 0 deletions core/dbt/task/rpc/project_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@
from dbt.contracts.rpc import (
RPCCompileParameters,
RPCDocsGenerateParameters,
RPCRunOperationParameters,
RPCSeedParameters,
RPCTestParameters,
RemoteCatalogResults,
RemoteExecutionResult,
RemoteRunOperationResult,
)
from dbt.rpc.method import (
Parameters,
)
from dbt.task.compile import CompileTask
from dbt.task.generate import GenerateTask
from dbt.task.run import RunTask
from dbt.task.run_operation import RunOperationTask
from dbt.task.seed import SeedTask
from dbt.task.snapshot import SnapshotTask
from dbt.task.test import TestTask

from .base import RPCTask
Expand Down Expand Up @@ -97,3 +101,40 @@ def get_catalog_results(
_compile_results=compile_results,
logs=[],
)


class RemoteRunOperationTask(
RPCTask[RPCRunOperationParameters],
HasCLI[RPCRunOperationParameters, RemoteRunOperationResult],
RunOperationTask,
):
METHOD_NAME = 'run-operation'

def set_args(self, params: RPCRunOperationParameters) -> None:
self.args.macro = params.macro
self.args.args = params.args

def _get_kwargs(self):
if isinstance(self.args.args, dict):
return self.args.args
else:
return RunOperationTask._get_kwargs(self)

def _runtime_initialize(self):
return RunOperationTask._runtime_initialize(self)

def handle_request(self) -> RemoteRunOperationResult:
success, _ = RunOperationTask.run(self)
result = RemoteRunOperationResult(logs=[], success=success)
return result

def interpret_results(self, results):
return results.success


class RemoteSnapshotTask(RPCCommandTask[RPCCompileParameters], SnapshotTask):
METHOD_NAME = 'snapshot'

def set_args(self, params: RPCCompileParameters) -> None:
self.args.models = self._listify(params.models)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the input to snapshot on the CLI is -s --select, not --models - can we update this accordingly?

self.args.exclude = self._listify(params.exclude)
4 changes: 4 additions & 0 deletions core/dbt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,10 @@ def default(self, obj):
return float(obj)
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
return obj.isoformat()
if hasattr(obj, 'to_dict'):
# if we have a to_dict we should try to serialize the result of
# that!
obj = obj.to_dict()
return super().default(obj)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pickle
import os


class TestRpcExecuteReturnsResults(DBTIntegrationTest):

@property
Expand All @@ -26,7 +27,7 @@ def test_pickle(self, agate_table):

pickle.dumps(table)

def test_file(self, filename):
def do_test_file(self, filename):
file_path = os.path.join("sql", filename)
with open(file_path) as fh:
query = fh.read()
Expand All @@ -39,12 +40,12 @@ def test_file(self, filename):

@use_profile('bigquery')
def test__bigquery_fetch_and_serialize(self):
self.test_file('bigquery.sql')
self.do_test_file('bigquery.sql')

@use_profile('snowflake')
def test__snowflake_fetch_and_serialize(self):
self.test_file('snowflake.sql')
self.do_test_file('snowflake.sql')

@use_profile('redshift')
def test__redshift_fetch_and_serialize(self):
self.test_file('redshift.sql')
self.do_test_file('redshift.sql')
52 changes: 43 additions & 9 deletions test/rpc/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,19 @@ def run(
method='run', params=params, request_id=request_id
)

def run_operation(
self,
macro: str,
args: Optional[Dict[str, Any]],
request_id: int = 1,
):
params = {'macro': macro}
if args is not None:
params['args'] = args
return self.request(
method='run-operation', params=params, request_id=request_id
)

def seed(self, show: bool = None, request_id: int = 1):
params = {}
if show is not None:
Expand All @@ -224,6 +237,21 @@ def seed(self, show: bool = None, request_id: int = 1):
method='seed', params=params, request_id=request_id
)

def snapshot(
self,
models: Optional[Union[str, List[str]]] = None,
exclude: Optional[Union[str, List[str]]] = None,
request_id: int = 1,
):
params = {}
if models is not None:
params['models'] = models
if exclude is not None:
params['exclude'] = exclude
return self.request(
method='snapshot', params=params, request_id=request_id
)

def test(
self,
models: Optional[Union[str, List[str]]] = None,
Expand All @@ -241,7 +269,7 @@ def test(
params['data'] = data
if schema is not None:
params['schema'] = schema
return self.requuest(
return self.request(
method='test', params=params, request_id=request_id
)

Expand Down Expand Up @@ -419,6 +447,7 @@ def __init__(
packages=None,
models=None,
macros=None,
snapshots=None,
):
self.project = {
'name': name,
Expand All @@ -430,6 +459,7 @@ def __init__(
self.packages = packages
self.models = models
self.macros = macros
self.snapshots = snapshots

def _write_recursive(self, path, inputs):
for name, value in inputs.items():
Expand Down Expand Up @@ -460,19 +490,22 @@ def write_config(self, project_dir, remove=False):
cfg.remove()
cfg.write(yaml.safe_dump(self.project))

def write_models(self, project_dir, remove=False):
def _write_values(self, project_dir, remove, name, value):
if remove:
project_dir.join('models').remove()
project_dir.join(name).remove()

if value is not None:
self._write_recursive(project_dir.mkdir(name), value)

if self.models is not None:
self._write_recursive(project_dir.mkdir('models'), self.models)

def write_models(self, project_dir, remove=False):
self._write_values(project_dir, remove, 'models', self.models)

def write_macros(self, project_dir, remove=False):
if remove:
project_dir.join('macros').remove()
self._write_values(project_dir, remove, 'macros', self.macros)

if self.macros is not None:
self._write_recursive(project_dir.mkdir('macros'), self.macros)
def write_snapshots(self, project_dir, remove=False):
self._write_values(project_dir, remove, 'snapshots', self.snapshots)

def write_to(self, project_dir, remove=False):
if remove:
Expand All @@ -482,6 +515,7 @@ def write_to(self, project_dir, remove=False):
self.write_config(project_dir)
self.write_models(project_dir)
self.write_macros(project_dir)
self.write_snapshots(project_dir)


class TestArgs:
Expand Down
Loading