Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added impersonation_chain for dataflow operators #24046

Merged
merged 1 commit into from
Jun 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions airflow/providers/google/cloud/hooks/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,16 +972,31 @@ def start_sql_job(
:param on_new_job_callback: Callback called when the job is known.
:return: the new job object
"""
gcp_options = [
f"--project={project_id}",
"--format=value(job.id)",
f"--job-name={job_name}",
f"--region={location}",
]

if self.impersonation_chain:
if isinstance(self.impersonation_chain, str):
impersonation_account = self.impersonation_chain
elif len(self.impersonation_chain) == 1:
impersonation_account = self.impersonation_chain[0]
else:
raise AirflowException(
"Chained list of accounts is not supported, please specify only one service account"
)
Comment on lines +988 to +990
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don’t support multiple accounts, why do we accept a list in the first place? Can we not simply accept the str case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently in the several operators we have the same implementation if we don't support multiple accounts and also to be consistent in the parameters for Google operators.

gcp_options.append(f"--impersonate-service-account={impersonation_account}")

cmd = [
"gcloud",
"dataflow",
"sql",
"query",
query,
f"--project={project_id}",
"--format=value(job.id)",
f"--job-name={job_name}",
f"--region={location}",
*gcp_options,
*(beam_options_to_args(options)),
]
self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd))
Expand Down
22 changes: 22 additions & 0 deletions airflow/providers/google/cloud/operators/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,14 @@ class DataflowStartFlexTemplateOperator(BaseOperator):

If you in your pipeline do not call the wait_for_pipeline method, and pass wait_until_finish=False
to the operator, the second loop will check once is job not in terminal state and exit the loop.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields: Sequence[str] = ("body", "location", "project_id", "gcp_conn_id")
Expand All @@ -742,6 +750,7 @@ def __init__(
drain_pipeline: bool = False,
cancel_timeout: Optional[int] = 10 * 60,
wait_until_finished: Optional[bool] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
*args,
**kwargs,
) -> None:
Expand All @@ -756,6 +765,7 @@ def __init__(
self.wait_until_finished = wait_until_finished
self.job = None
self.hook: Optional[DataflowHook] = None
self.impersonation_chain = impersonation_chain

def execute(self, context: 'Context'):
self.hook = DataflowHook(
Expand All @@ -764,6 +774,7 @@ def execute(self, context: 'Context'):
drain_pipeline=self.drain_pipeline,
cancel_timeout=self.cancel_timeout,
wait_until_finished=self.wait_until_finished,
impersonation_chain=self.impersonation_chain,
)

def set_current_job(current_job):
Expand Down Expand Up @@ -821,6 +832,14 @@ class DataflowStartSqlJobOperator(BaseOperator):
:param drain_pipeline: Optional, set to True if want to stop streaming job by draining it
instead of canceling during killing task instance. See:
https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields: Sequence[str] = (
Expand All @@ -843,6 +862,7 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
drain_pipeline: bool = False,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
*args,
**kwargs,
) -> None:
Expand All @@ -855,6 +875,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.drain_pipeline = drain_pipeline
self.impersonation_chain = impersonation_chain
self.job = None
self.hook: Optional[DataflowHook] = None

Expand All @@ -863,6 +884,7 @@ def execute(self, context: 'Context'):
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
drain_pipeline=self.drain_pipeline,
impersonation_chain=self.impersonation_chain,
)

def set_current_job(current_job):
Expand Down
17 changes: 3 additions & 14 deletions tests/providers/google/cloud/hooks/test_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,20 +173,10 @@ def test_fn(self, *args, **kwargs):
FixtureFallback().test_fn({'project': "TEST"}, "TEST2")


def mock_init(
self,
gcp_conn_id,
delegate_to=None,
impersonation_chain=None,
):
pass


class TestDataflowHook(unittest.TestCase):
def setUp(self):
with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
self.dataflow_hook = DataflowHook(gcp_conn_id='test')
self.dataflow_hook.beam_hook = MagicMock()
self.dataflow_hook = DataflowHook(gcp_conn_id='google_cloud_default')
self.dataflow_hook.beam_hook = MagicMock()

@mock.patch("airflow.providers.google.cloud.hooks.dataflow.DataflowHook._authorize")
@mock.patch("airflow.providers.google.cloud.hooks.dataflow.build")
Expand Down Expand Up @@ -792,8 +782,7 @@ def test_wait_for_done(self, mock_conn, mock_dataflowjob):

class TestDataflowTemplateHook(unittest.TestCase):
def setUp(self):
with mock.patch(BASE_STRING.format('GoogleBaseHook.__init__'), new=mock_init):
self.dataflow_hook = DataflowHook(gcp_conn_id='test')
self.dataflow_hook = DataflowHook(gcp_conn_id='google_cloud_default')

@mock.patch(DATAFLOW_STRING.format('uuid.uuid4'), return_value=MOCK_UUID)
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
Expand Down
13 changes: 12 additions & 1 deletion tests/providers/google/cloud/operators/test_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,14 @@ def test_execute(self, mock_dataflow):
location=TEST_LOCATION,
)
start_flex_template.execute(mock.MagicMock())
mock_dataflow.assert_called_once_with(
gcp_conn_id='google_cloud_default',
delegate_to=None,
drain_pipeline=False,
cancel_timeout=600,
wait_until_finished=None,
impersonation_chain=None,
)
mock_dataflow.return_value.start_flex_template.assert_called_once_with(
body={"launchParameter": TEST_FLEX_PARAMETERS},
location=TEST_LOCATION,
Expand Down Expand Up @@ -533,7 +541,10 @@ def test_execute(self, mock_hook):

start_sql.execute(mock.MagicMock())
mock_hook.assert_called_once_with(
gcp_conn_id='google_cloud_default', delegate_to=None, drain_pipeline=False
gcp_conn_id='google_cloud_default',
delegate_to=None,
drain_pipeline=False,
impersonation_chain=None,
)
mock_hook.return_value.start_sql_job.assert_called_once_with(
job_name=TEST_SQL_JOB_NAME,
Expand Down