diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py
index e7f7af5e236a5..2ee6d180de840 100644
--- a/airflow/providers/apache/beam/operators/beam.py
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -469,11 +469,7 @@ def execute(self, context: 'Context'):
                             process_line_callback=process_line_callback,
                         )
                     if dataflow_job_name and self.dataflow_config.location:
-                        multiple_jobs = (
-                            self.dataflow_config.multiple_jobs
-                            if self.dataflow_config.multiple_jobs
-                            else False
-                        )
+                        multiple_jobs = self.dataflow_config.multiple_jobs or False
                         DataflowJobLink.persist(
                             self,
                             context,
diff --git a/airflow/providers/apache/cassandra/hooks/cassandra.py b/airflow/providers/apache/cassandra/hooks/cassandra.py
index 3d250741d2fc8..71c360ec472b8 100644
--- a/airflow/providers/apache/cassandra/hooks/cassandra.py
+++ b/airflow/providers/apache/cassandra/hooks/cassandra.py
@@ -169,9 +169,8 @@ def get_lb_policy(policy_name: str, policy_args: Dict[str, Any]) -> Policy:
             child_policy_args = policy_args.get('child_load_balancing_policy_args', {})
             if child_policy_name not in allowed_child_policies:
                 return TokenAwarePolicy(RoundRobinPolicy())
-            else:
-                child_policy = CassandraHook.get_lb_policy(child_policy_name, child_policy_args)
-                return TokenAwarePolicy(child_policy)
+            child_policy = CassandraHook.get_lb_policy(child_policy_name, child_policy_args)
+            return TokenAwarePolicy(child_policy)
 
         # Fallback to default RoundRobinPolicy
         return RoundRobinPolicy()
@@ -200,7 +199,7 @@ def record_exists(self, table: str, keys: Dict[str, str]) -> bool:
         keyspace = self.keyspace
         if '.' in table:
             keyspace, table = table.split('.', 1)
-        ks_str = " AND ".join(f"{key}=%({key})s" for key in keys.keys())
+        ks_str = " AND ".join(f"{key}=%({key})s" for key in keys)
         query = f"SELECT * FROM {keyspace}.{table} WHERE {ks_str}"
         try:
             result = self.get_conn().execute(query, keys)
diff --git a/airflow/providers/apache/drill/hooks/drill.py b/airflow/providers/apache/drill/hooks/drill.py
index a15658e9e38ce..5baf1f9ecb47b 100644
--- a/airflow/providers/apache/drill/hooks/drill.py
+++ b/airflow/providers/apache/drill/hooks/drill.py
@@ -69,7 +69,7 @@ def get_uri(self) -> str:
         host = conn_md.host
         if conn_md.port is not None:
             host += f':{conn_md.port}'
-        conn_type = 'drill' if not conn_md.conn_type else conn_md.conn_type
+        conn_type = conn_md.conn_type or 'drill'
         dialect_driver = conn_md.extra_dejson.get('dialect_driver', 'drill+sadrill')
         storage_plugin = conn_md.extra_dejson.get('storage_plugin', 'dfs')
         return f'{conn_type}://{host}/{storage_plugin}?dialect_driver={dialect_driver}'
diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py
index 671c914be604f..a10519eea476f 100644
--- a/airflow/providers/apache/druid/hooks/druid.py
+++ b/airflow/providers/apache/druid/hooks/druid.py
@@ -63,7 +63,7 @@ def get_conn_url(self) -> str:
         conn = self.get_connection(self.druid_ingest_conn_id)
         host = conn.host
         port = conn.port
-        conn_type = 'http' if not conn.conn_type else conn.conn_type
+        conn_type = conn.conn_type or 'http'
         endpoint = conn.extra_dejson.get('endpoint', '')
         return f"{conn_type}://{host}:{port}/{endpoint}"
 
@@ -163,7 +163,7 @@ def get_uri(self) -> str:
         host = conn.host
         if conn.port is not None:
             host += f':{conn.port}'
-        conn_type = 'druid' if not conn.conn_type else conn.conn_type
+        conn_type = conn.conn_type or 'druid'
         endpoint = conn.extra_dejson.get('endpoint', 'druid/v2/sql')
         return f'{conn_type}://{host}/{endpoint}'
 
diff --git a/airflow/providers/apache/hdfs/hooks/webhdfs.py b/airflow/providers/apache/hdfs/hooks/webhdfs.py
index a32206ba9bef8..c8e1caa9bfa67 100644
--- a/airflow/providers/apache/hdfs/hooks/webhdfs.py
+++ b/airflow/providers/apache/hdfs/hooks/webhdfs.py
@@ -98,12 +98,9 @@ def _get_client(self, namenode: str, port: int, login: str, extra_dejson: dict)
             session.verify = extra_dejson.get('verify', True)
 
         if _kerberos_security_mode:
-            client = KerberosClient(connection_str, session=session)
-        else:
-            proxy_user = self.proxy_user or login
-            client = InsecureClient(connection_str, user=proxy_user, session=session)
-
-        return client
+            return KerberosClient(connection_str, session=session)
+        proxy_user = self.proxy_user or login
+        return InsecureClient(connection_str, user=proxy_user, session=session)
 
     def check_for_path(self, hdfs_path: str) -> bool:
         """
diff --git a/airflow/providers/apache/hive/operators/hive_stats.py b/airflow/providers/apache/hive/operators/hive_stats.py
index a1b5539622321..7cf2002a418f7 100644
--- a/airflow/providers/apache/hive/operators/hive_stats.py
+++ b/airflow/providers/apache/hive/operators/hive_stats.py
@@ -100,7 +100,7 @@ def get_default_exprs(self, col: str, col_type: str) -> Dict[Any, Any]:
         if col in self.excluded_columns:
             return {}
         exp = {(col, 'non_null'): f"COUNT({col})"}
-        if col_type in ['double', 'int', 'bigint', 'float']:
+        if col_type in {'double', 'int', 'bigint', 'float'}:
             exp[(col, 'sum')] = f'SUM({col})'
             exp[(col, 'min')] = f'MIN({col})'
             exp[(col, 'max')] = f'MAX({col})'
@@ -108,7 +108,7 @@ def get_default_exprs(self, col: str, col_type: str) -> Dict[Any, Any]:
         elif col_type == 'boolean':
             exp[(col, 'true')] = f'SUM(CASE WHEN {col} THEN 1 ELSE 0 END)'
             exp[(col, 'false')] = f'SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)'
-        elif col_type in ['string']:
+        elif col_type == 'string':
             exp[(col, 'len')] = f'SUM(CAST(LENGTH({col}) AS BIGINT))'
             exp[(col, 'approx_distinct')] = f'APPROX_DISTINCT({col})'
 
@@ -130,7 +130,7 @@ def execute(self, context: "Context") -> None:
             exprs.update(assign_exprs)
         exprs.update(self.extra_exprs)
         exprs = OrderedDict(exprs)
-        exprs_str = ",\n ".join(v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items())
+        exprs_str = ",\n ".join(f"{v} AS {k[0]}__{k[1]}" for k, v in exprs.items())
 
         where_clause_ = [f"{k} = '{v}'" for k, v in self.partition.items()]
         where_clause = " AND\n ".join(where_clause_)
diff --git a/airflow/providers/apache/hive/sensors/named_hive_partition.py b/airflow/providers/apache/hive/sensors/named_hive_partition.py
index 9535bcdab0219..61c902e95db25 100644
--- a/airflow/providers/apache/hive/sensors/named_hive_partition.py
+++ b/airflow/providers/apache/hive/sensors/named_hive_partition.py
@@ -76,7 +76,7 @@ def parse_partition_name(partition: str) -> Tuple[Any, ...]:
             schema, table_partition = first_split
         second_split = table_partition.split('/', 1)
         if len(second_split) == 1:
-            raise ValueError('Could not parse ' + partition + 'into table, partition')
+            raise ValueError(f'Could not parse {partition}into table, partition')
         else:
             table, partition = second_split
         return schema, table, partition
diff --git a/airflow/providers/apache/hive/transfers/mssql_to_hive.py b/airflow/providers/apache/hive/transfers/mssql_to_hive.py
index 912c2a58a36bd..b1b6e0b3c4907 100644
--- a/airflow/providers/apache/hive/transfers/mssql_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/mssql_to_hive.py
@@ -115,9 +115,7 @@ def execute(self, context: "Context"):
         with NamedTemporaryFile("w") as tmp_file:
             csv_writer = csv.writer(tmp_file, delimiter=self.delimiter, encoding='utf-8')
             field_dict = OrderedDict()
-            col_count = 0
-            for field in cursor.description:
-                col_count += 1
+            for col_count, field in enumerate(cursor.description, start=1):
                 col_position = f"Column{col_count}"
                 field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
             csv_writer.writerows(cursor)
diff --git a/airflow/providers/apache/hive/transfers/s3_to_hive.py b/airflow/providers/apache/hive/transfers/s3_to_hive.py
index cc189303e0584..9f3827c56f2a5 100644
--- a/airflow/providers/apache/hive/transfers/s3_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/s3_to_hive.py
@@ -142,11 +142,11 @@ def execute(self, context: 'Context'):
             if not s3_hook.check_for_wildcard_key(self.s3_key):
                 raise AirflowException(f"No key matches {self.s3_key}")
             s3_key_object = s3_hook.get_wildcard_key(self.s3_key)
-        else:
-            if not s3_hook.check_for_key(self.s3_key):
-                raise AirflowException(f"The key {self.s3_key} does not exists")
+        elif s3_hook.check_for_key(self.s3_key):
             s3_key_object = s3_hook.get_key(self.s3_key)
+        else:
+            raise AirflowException(f"The key {self.s3_key} does not exists")
 
         _, file_ext = os.path.splitext(s3_key_object.key)
         if self.select_expression and self.input_compressed and file_ext.lower() != '.gz':
             raise AirflowException("GZIP is the only compression format Amazon S3 Select supports")
@@ -227,8 +227,7 @@ def execute(self, context: 'Context'):
     def _get_top_row_as_list(self, file_name):
         with open(file_name) as file:
             header_line = file.readline().strip()
-            header_list = header_line.split(self.delimiter)
-            return header_list
+            return header_line.split(self.delimiter)
 
     def _match_headers(self, header_list):
         if not header_list:
diff --git a/airflow/providers/apache/hive/transfers/vertica_to_hive.py b/airflow/providers/apache/hive/transfers/vertica_to_hive.py
index 7e53638b0c70a..e3f432d9ad70d 100644
--- a/airflow/providers/apache/hive/transfers/vertica_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/vertica_to_hive.py
@@ -117,9 +117,7 @@ def execute(self, context: 'Context'):
         with NamedTemporaryFile("w") as f:
             csv_writer = csv.writer(f, delimiter=self.delimiter, encoding='utf-8')
             field_dict = OrderedDict()
-            col_count = 0
-            for field in cursor.description:
-                col_count += 1
+            for col_count, field in enumerate(cursor.description, start=1):
                 col_position = f"Column{col_count}"
                 field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
             csv_writer.writerows(cursor.iterate())
diff --git a/airflow/providers/apache/kylin/hooks/kylin.py b/airflow/providers/apache/kylin/hooks/kylin.py
index 032b15c7e5bbf..53a8bf5d909c8 100644
--- a/airflow/providers/apache/kylin/hooks/kylin.py
+++ b/airflow/providers/apache/kylin/hooks/kylin.py
@@ -48,16 +48,15 @@ def get_conn(self):
         conn = self.get_connection(self.kylin_conn_id)
         if self.dsn:
             return kylinpy.create_kylin(self.dsn)
-        else:
-            self.project = self.project if self.project else conn.schema
-            return kylinpy.Kylin(
-                conn.host,
-                username=conn.login,
-                password=conn.password,
-                port=conn.port,
-                project=self.project,
-                **conn.extra_dejson,
-            )
+        self.project = self.project or conn.schema
+        return kylinpy.Kylin(
+            conn.host,
+            username=conn.login,
+            password=conn.password,
+            port=conn.port,
+            project=self.project,
+            **conn.extra_dejson,
+        )
 
     def cube_run(self, datasource_name, op, **op_args):
         """
@@ -70,8 +69,7 @@ def cube_run(self, datasource_name, op, **op_args):
         """
         cube_source = self.get_conn().get_datasource(datasource_name)
         try:
-            response = cube_source.invoke_command(op, **op_args)
-            return response
+            return cube_source.invoke_command(op, **op_args)
         except exceptions.KylinError as err:
             raise AirflowException(f"Cube operation {op} error , Message: {err}")
 
diff --git a/airflow/providers/apache/pinot/hooks/pinot.py b/airflow/providers/apache/pinot/hooks/pinot.py
index 55ddce0bcc0cc..4943b37adcb7c 100644
--- a/airflow/providers/apache/pinot/hooks/pinot.py
+++ b/airflow/providers/apache/pinot/hooks/pinot.py
@@ -203,9 +203,7 @@ def run_cli(self, cmd: List[str], verbose: bool = True) -> str:
         :param cmd: List of command going to be run by pinot-admin.sh script
         :param verbose:
         """
-        command = [self.cmd_path]
-        command.extend(cmd)
-
+        command = [self.cmd_path, *cmd]
         env = None
         if self.pinot_admin_system_exit:
             env = os.environ.copy()
@@ -273,7 +271,7 @@ def get_uri(self) -> str:
         host = conn.host
         if conn.port is not None:
             host += f':{conn.port}'
-        conn_type = 'http' if not conn.conn_type else conn.conn_type
+        conn_type = conn.conn_type or 'http'
         endpoint = conn.extra_dejson.get('endpoint', 'query/sql')
         return f'{conn_type}://{host}/{endpoint}'
 
diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc.py b/airflow/providers/apache/spark/hooks/spark_jdbc.py
index df9d715be0dcc..e95fae3dbbabc 100644
--- a/airflow/providers/apache/spark/hooks/spark_jdbc.py
+++ b/airflow/providers/apache/spark/hooks/spark_jdbc.py
@@ -220,7 +220,7 @@ def _build_jdbc_application_arguments(self, jdbc_conn: Dict[str, Any]) -> Any:
     def submit_jdbc_job(self) -> None:
         """Submit Spark JDBC job"""
         self._application_args = self._build_jdbc_application_arguments(self._jdbc_connection)
-        self.submit(application=os.path.dirname(os.path.abspath(__file__)) + "/spark_jdbc_script.py")
+        self.submit(application=f"{os.path.dirname(os.path.abspath(__file__))}/spark_jdbc_script.py")
 
     def get_conn(self) -> Any:
         pass
diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py
index 0f5dc2f7307cc..fd23e96958c54 100644
--- a/airflow/providers/apache/spark/hooks/spark_submit.py
+++ b/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+import contextlib
 import os
 import re
 import subprocess
@@ -28,10 +29,8 @@
 from airflow.security.kerberos import renew_from_kt
 from airflow.utils.log.logging_mixin import LoggingMixin
 
-try:
+with contextlib.suppress(ImportError, NameError):
     from airflow.kubernetes import kube_client
-except (ImportError, NameError):
-    pass
 
 
 class SparkSubmitHook(BaseHook, LoggingMixin):
@@ -355,9 +354,7 @@ def _build_track_driver_status_command(self) -> List[str]:
         self.log.info(connection_cmd)
 
         # The driver id so we can poll for its status
-        if self._driver_id:
-            pass
-        else:
+        if not self._driver_id:
             raise AirflowException(
                 "Invalid status: attempted to poll driver status but no driver id is known. Giving up."
             )
 
@@ -607,17 +604,14 @@ def on_kill(self) -> None:
         """Kill Spark submit command"""
         self.log.debug("Kill Command is being called")
 
-        if self._should_track_driver_status:
-            if self._driver_id:
-                self.log.info('Killing driver %s on cluster', self._driver_id)
+        if self._should_track_driver_status and self._driver_id:
+            self.log.info('Killing driver %s on cluster', self._driver_id)
 
-                kill_cmd = self._build_spark_driver_kill_command()
-                with subprocess.Popen(
-                    kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
-                ) as driver_kill:
-                    self.log.info(
-                        "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait()
-                    )
+            kill_cmd = self._build_spark_driver_kill_command()
+            with subprocess.Popen(kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as driver_kill:
+                self.log.info(
+                    "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait()
+                )
 
         if self._submit_sp and self._submit_sp.poll() is None:
             self.log.info('Sending kill signal to %s', self._connection['spark_binary'])
diff --git a/tests/providers/apache/hive/transfers/test_mssql_to_hive.py b/tests/providers/apache/hive/transfers/test_mssql_to_hive.py
index 01973ae3b754a..cfab662f35635 100644
--- a/tests/providers/apache/hive/transfers/test_mssql_to_hive.py
+++ b/tests/providers/apache/hive/transfers/test_mssql_to_hive.py
@@ -113,9 +113,7 @@ def test_execute_empty_description_field(self, mock_hive_hook, mock_mssql_hook,
         mssql_to_hive_transfer.execute(context={})
 
         field_dict = OrderedDict()
-        col_count = 0
-        for field in mock_mssql_hook_cursor.return_value.description:
-            col_count += 1
+        for col_count, field in enumerate(mock_mssql_hook_cursor.return_value.description, start=1):
             col_position = f"Column{col_count}"
             field_dict[col_position] = mssql_to_hive_transfer.type_map(field[1])
         mock_hive_hook.return_value.load_file.assert_called_once_with(
diff --git a/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py b/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py
index 18bb75f192a3a..c0ac0130ccab6 100644
--- a/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py
+++ b/tests/providers/apache/spark/hooks/test_spark_jdbc_script.py
@@ -164,7 +164,8 @@ def test_spark_write_to_jdbc(self, mock_writer_save):
         # Given
         arguments = _parse_arguments(self.jdbc_arguments)
         spark_session = _create_spark_session(arguments)
-        spark_session.sql("CREATE TABLE IF NOT EXISTS " + arguments.metastore_table + " (key INT)")
+        spark_session.sql(f"CREATE TABLE IF NOT EXISTS {arguments.metastore_table} (key INT)")
+
 
         # When
         spark_write_to_jdbc(
@@ -191,7 +192,7 @@ def test_spark_read_from_jdbc(self, mock_reader_load):
         # Given
         arguments = _parse_arguments(self.jdbc_arguments)
         spark_session = _create_spark_session(arguments)
-        spark_session.sql("CREATE TABLE IF NOT EXISTS " + arguments.metastore_table + " (key INT)")
+        spark_session.sql(f"CREATE TABLE IF NOT EXISTS {arguments.metastore_table} (key INT)")
 
         # When
         spark_read_from_jdbc(