
chore: Refactoring and Cleaning Apache Providers #24219

Merged · 15 commits · Jun 6, 2022
6 changes: 1 addition & 5 deletions airflow/providers/apache/beam/operators/beam.py
@@ -469,11 +469,7 @@ def execute(self, context: 'Context'):
process_line_callback=process_line_callback,
)
if dataflow_job_name and self.dataflow_config.location:
- multiple_jobs = (
-     self.dataflow_config.multiple_jobs
-     if self.dataflow_config.multiple_jobs
-     else False
- )
+ multiple_jobs = self.dataflow_config.multiple_jobs or False
DataflowJobLink.persist(
self,
context,
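
For context, the change above swaps a verbose conditional expression for Python's `or` fallback. A minimal standalone sketch of the idiom, using an illustrative FakeDataflowConfig rather than the real DataflowConfiguration:

class FakeDataflowConfig:
    # Illustrative stand-in for the real DataflowConfiguration object.
    def __init__(self, multiple_jobs=None):
        self.multiple_jobs = multiple_jobs

def resolve_multiple_jobs(config):
    # `x or False` evaluates to x when x is truthy, otherwise False,
    # which is exactly what the longer conditional expression spelled out.
    return config.multiple_jobs or False

assert resolve_multiple_jobs(FakeDataflowConfig(True)) is True
assert resolve_multiple_jobs(FakeDataflowConfig(None)) is False

Note that `or` also treats 0, '' and other falsy values as missing; that is acceptable here because the flag is expected to be a bool or None.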
7 changes: 3 additions & 4 deletions airflow/providers/apache/cassandra/hooks/cassandra.py
@@ -169,9 +169,8 @@ def get_lb_policy(policy_name: str, policy_args: Dict[str, Any]) -> Policy:
child_policy_args = policy_args.get('child_load_balancing_policy_args', {})
if child_policy_name not in allowed_child_policies:
return TokenAwarePolicy(RoundRobinPolicy())
- else:
-     child_policy = CassandraHook.get_lb_policy(child_policy_name, child_policy_args)
-     return TokenAwarePolicy(child_policy)
+ child_policy = CassandraHook.get_lb_policy(child_policy_name, child_policy_args)
+ return TokenAwarePolicy(child_policy)

# Fallback to default RoundRobinPolicy
return RoundRobinPolicy()
@@ -200,7 +199,7 @@ def record_exists(self, table: str, keys: Dict[str, str]) -> bool:
keyspace = self.keyspace
if '.' in table:
keyspace, table = table.split('.', 1)
ks_str = " AND ".join(f"{key}=%({key})s" for key in keys.keys())
ks_str = " AND ".join(f"{key}=%({key})s" for key in keys)
query = f"SELECT * FROM {keyspace}.{table} WHERE {ks_str}"
try:
result = self.get_conn().execute(query, keys)
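
Two idioms appear in the cassandra hunks: dropping an `else:` that follows a `return`, and iterating a dict directly instead of via `.keys()`. A small sketch with made-up names (not the provider's actual classes):

def build_where_clause(keys):
    # Iterating a dict yields its keys, so `for key in keys` matches
    # the old `for key in keys.keys()` exactly.
    return " AND ".join(f"{key}=%({key})s" for key in keys)

def choose_policy(policy_name, allowed):
    # Once the `if` branch returns, the `else:` only added indentation,
    # so the remaining code can sit at the outer level.
    if policy_name not in allowed:
        return "TokenAware(RoundRobin)"
    return f"TokenAware({policy_name})"

print(build_where_clause({"id": 1, "name": "x"}))   # id=%(id)s AND name=%(name)s
print(choose_policy("DCAwareRoundRobinPolicy", {"DCAwareRoundRobinPolicy"}))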
2 changes: 1 addition & 1 deletion airflow/providers/apache/drill/hooks/drill.py
@@ -69,7 +69,7 @@ def get_uri(self) -> str:
host = conn_md.host
if conn_md.port is not None:
host += f':{conn_md.port}'
- conn_type = 'drill' if not conn_md.conn_type else conn_md.conn_type
+ conn_type = conn_md.conn_type or 'drill'
dialect_driver = conn_md.extra_dejson.get('dialect_driver', 'drill+sadrill')
storage_plugin = conn_md.extra_dejson.get('storage_plugin', 'dfs')
return f'{conn_type}://{host}/{storage_plugin}?dialect_driver={dialect_driver}'
4 changes: 2 additions & 2 deletions airflow/providers/apache/druid/hooks/druid.py
@@ -63,7 +63,7 @@ def get_conn_url(self) -> str:
conn = self.get_connection(self.druid_ingest_conn_id)
host = conn.host
port = conn.port
- conn_type = 'http' if not conn.conn_type else conn.conn_type
+ conn_type = conn.conn_type or 'http'
endpoint = conn.extra_dejson.get('endpoint', '')
return f"{conn_type}://{host}:{port}/{endpoint}"

@@ -163,7 +163,7 @@ def get_uri(self) -> str:
host = conn.host
if conn.port is not None:
host += f':{conn.port}'
- conn_type = 'druid' if not conn.conn_type else conn.conn_type
+ conn_type = conn.conn_type or 'druid'
endpoint = conn.extra_dejson.get('endpoint', 'druid/v2/sql')
return f'{conn_type}://{host}/{endpoint}'

9 changes: 3 additions & 6 deletions airflow/providers/apache/hdfs/hooks/webhdfs.py
@@ -98,12 +98,9 @@ def _get_client(self, namenode: str, port: int, login: str, extra_dejson: dict)
session.verify = extra_dejson.get('verify', True)

if _kerberos_security_mode:
-     client = KerberosClient(connection_str, session=session)
- else:
-     proxy_user = self.proxy_user or login
-     client = InsecureClient(connection_str, user=proxy_user, session=session)
-
- return client
+     return KerberosClient(connection_str, session=session)
+ proxy_user = self.proxy_user or login
+ return InsecureClient(connection_str, user=proxy_user, session=session)

def check_for_path(self, hdfs_path: str) -> bool:
"""
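
The webhdfs change converts assign-then-return into early returns. A standalone sketch with stub client classes (the real code uses the hdfs library's KerberosClient and InsecureClient):

class KerberosStub:
    def __init__(self, url):
        self.url = url

class InsecureStub:
    def __init__(self, url, user):
        self.url, self.user = url, user

def make_client(url, kerberos, proxy_user, login):
    # Returning from each branch directly removes the temporary
    # `client` variable and the trailing `return client`.
    if kerberos:
        return KerberosStub(url)
    return InsecureStub(url, user=proxy_user or login)

print(type(make_client("http://namenode:9870", True, None, "airflow")).__name__)   # KerberosStub
print(type(make_client("http://namenode:9870", False, None, "airflow")).__name__)  # InsecureStub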
6 changes: 3 additions & 3 deletions airflow/providers/apache/hive/operators/hive_stats.py
@@ -100,15 +100,15 @@ def get_default_exprs(self, col: str, col_type: str) -> Dict[Any, Any]:
if col in self.excluded_columns:
return {}
exp = {(col, 'non_null'): f"COUNT({col})"}
- if col_type in ['double', 'int', 'bigint', 'float']:
+ if col_type in {'double', 'int', 'bigint', 'float'}:
exp[(col, 'sum')] = f'SUM({col})'
exp[(col, 'min')] = f'MIN({col})'
exp[(col, 'max')] = f'MAX({col})'
exp[(col, 'avg')] = f'AVG({col})'
elif col_type == 'boolean':
exp[(col, 'true')] = f'SUM(CASE WHEN {col} THEN 1 ELSE 0 END)'
exp[(col, 'false')] = f'SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)'
- elif col_type in ['string']:
+ elif col_type == 'string':
exp[(col, 'len')] = f'SUM(CAST(LENGTH({col}) AS BIGINT))'
exp[(col, 'approx_distinct')] = f'APPROX_DISTINCT({col})'

@@ -130,7 +130,7 @@ def execute(self, context: "Context") -> None:
exprs.update(assign_exprs)
exprs.update(self.extra_exprs)
exprs = OrderedDict(exprs)
exprs_str = ",\n ".join(v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items())
exprs_str = ",\n ".join(f"{v} AS {k[0]}__{k[1]}" for k, v in exprs.items())

where_clause_ = [f"{k} = '{v}'" for k, v in self.partition.items()]
where_clause = " AND\n ".join(where_clause_)
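
The hive_stats hunks switch literal membership tests from lists to sets and build SQL fragments with f-strings. A rough sketch with illustrative helper names:

NUMERIC_TYPES = {'double', 'int', 'bigint', 'float'}   # set literal: hash-based membership test

def default_exprs(col, col_type):
    exprs = {(col, 'non_null'): f"COUNT({col})"}
    if col_type in NUMERIC_TYPES:
        exprs[(col, 'sum')] = f"SUM({col})"
        exprs[(col, 'avg')] = f"AVG({col})"
    return exprs

def render(exprs):
    # One f-string per expression instead of chained `+` concatenation.
    return ",\n        ".join(f"{v} AS {k[0]}__{k[1]}" for k, v in exprs.items())

print(render(default_exprs("price", "double")))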
@@ -76,7 +76,7 @@ def parse_partition_name(partition: str) -> Tuple[Any, ...]:
schema, table_partition = first_split
second_split = table_partition.split('/', 1)
if len(second_split) == 1:
- raise ValueError('Could not parse ' + partition + 'into table, partition')
+ raise ValueError(f'Could not parse {partition}into table, partition')
else:
table, partition = second_split
return schema, table, partition
4 changes: 1 addition & 3 deletions airflow/providers/apache/hive/transfers/mssql_to_hive.py
@@ -115,9 +115,7 @@ def execute(self, context: "Context"):
with NamedTemporaryFile("w") as tmp_file:
csv_writer = csv.writer(tmp_file, delimiter=self.delimiter, encoding='utf-8')
field_dict = OrderedDict()
- col_count = 0
- for field in cursor.description:
-     col_count += 1
+ for col_count, field in enumerate(cursor.description, start=1):
col_position = f"Column{col_count}"
field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
csv_writer.writerows(cursor)
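
The transfer operators replace a manual counter with enumerate(..., start=1). A self-contained sketch using a fake cursor description (the real type_map lives on the operator):

description = [("id", 3), ("", 1), ("name", 1)]   # fake DB-API cursor.description rows

def type_map(type_code):
    # Illustrative stand-in for the operator's real type_map().
    return {1: "STRING", 3: "INT"}.get(type_code, "STRING")

field_dict = {}
for col_count, field in enumerate(description, start=1):
    # enumerate(..., start=1) supplies the running 1-based counter, so the
    # manual `col_count = 0` / `col_count += 1` bookkeeping disappears.
    col_position = f"Column{col_count}"
    field_dict[col_position if field[0] == '' else field[0]] = type_map(field[1])

print(field_dict)   # {'id': 'INT', 'Column2': 'STRING', 'name': 'STRING'}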
9 changes: 4 additions & 5 deletions airflow/providers/apache/hive/transfers/s3_to_hive.py
@@ -142,11 +142,11 @@ def execute(self, context: 'Context'):
if not s3_hook.check_for_wildcard_key(self.s3_key):
raise AirflowException(f"No key matches {self.s3_key}")
s3_key_object = s3_hook.get_wildcard_key(self.s3_key)
- else:
-     if not s3_hook.check_for_key(self.s3_key):
-         raise AirflowException(f"The key {self.s3_key} does not exists")
+ elif s3_hook.check_for_key(self.s3_key):
      s3_key_object = s3_hook.get_key(self.s3_key)

+ else:
+     raise AirflowException(f"The key {self.s3_key} does not exists")
_, file_ext = os.path.splitext(s3_key_object.key)
if self.select_expression and self.input_compressed and file_ext.lower() != '.gz':
raise AirflowException("GZIP is the only compression format Amazon S3 Select supports")
@@ -227,8 +227,7 @@ def execute(self, context: 'Context'):
def _get_top_row_as_list(self, file_name):
with open(file_name) as file:
header_line = file.readline().strip()
- header_list = header_line.split(self.delimiter)
- return header_list
+ return header_line.split(self.delimiter)

def _match_headers(self, header_list):
if not header_list:
4 changes: 1 addition & 3 deletions airflow/providers/apache/hive/transfers/vertica_to_hive.py
@@ -117,9 +117,7 @@ def execute(self, context: 'Context'):
with NamedTemporaryFile("w") as f:
csv_writer = csv.writer(f, delimiter=self.delimiter, encoding='utf-8')
field_dict = OrderedDict()
- col_count = 0
- for field in cursor.description:
-     col_count += 1
+ for col_count, field in enumerate(cursor.description, start=1):
col_position = f"Column{col_count}"
field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
csv_writer.writerows(cursor.iterate())
22 changes: 10 additions & 12 deletions airflow/providers/apache/kylin/hooks/kylin.py
@@ -48,16 +48,15 @@ def get_conn(self):
conn = self.get_connection(self.kylin_conn_id)
if self.dsn:
return kylinpy.create_kylin(self.dsn)
- else:
-     self.project = self.project if self.project else conn.schema
-     return kylinpy.Kylin(
-         conn.host,
-         username=conn.login,
-         password=conn.password,
-         port=conn.port,
-         project=self.project,
-         **conn.extra_dejson,
-     )
+ self.project = self.project or conn.schema
+ return kylinpy.Kylin(
+     conn.host,
+     username=conn.login,
+     password=conn.password,
+     port=conn.port,
+     project=self.project,
+     **conn.extra_dejson,
+ )

def cube_run(self, datasource_name, op, **op_args):
"""
@@ -70,8 +69,7 @@ def cube_run(self, datasource_name, op, **op_args):
"""
cube_source = self.get_conn().get_datasource(datasource_name)
try:
- response = cube_source.invoke_command(op, **op_args)
- return response
+ return cube_source.invoke_command(op, **op_args)
except exceptions.KylinError as err:
raise AirflowException(f"Cube operation {op} error , Message: {err}")

6 changes: 2 additions & 4 deletions airflow/providers/apache/pinot/hooks/pinot.py
@@ -203,9 +203,7 @@ def run_cli(self, cmd: List[str], verbose: bool = True) -> str:
:param cmd: List of command going to be run by pinot-admin.sh script
:param verbose:
"""
- command = [self.cmd_path]
- command.extend(cmd)
-
+ command = [self.cmd_path, *cmd]
env = None
if self.pinot_admin_system_exit:
env = os.environ.copy()
@@ -273,7 +271,7 @@ def get_uri(self) -> str:
host = conn.host
if conn.port is not None:
host += f':{conn.port}'
- conn_type = 'http' if not conn.conn_type else conn.conn_type
+ conn_type = conn.conn_type or 'http'
endpoint = conn.extra_dejson.get('endpoint', 'query/sql')
return f'{conn_type}://{host}/{endpoint}'

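
The pinot hook now builds its command list with star-unpacking inside a list literal. A minimal sketch; the script path and arguments here are placeholders:

def build_command(cmd_path, cmd):
    # [first, *rest] splices the arguments into one list literal,
    # replacing the create-then-extend two-step.
    return [cmd_path, *cmd]

print(build_command("/opt/pinot/bin/pinot-admin.sh",
                    ["AddTable", "-tableConfigFile", "table.json"]))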
2 changes: 1 addition & 1 deletion airflow/providers/apache/spark/hooks/spark_jdbc.py
@@ -220,7 +220,7 @@ def _build_jdbc_application_arguments(self, jdbc_conn: Dict[str, Any]) -> Any:
def submit_jdbc_job(self) -> None:
"""Submit Spark JDBC job"""
self._application_args = self._build_jdbc_application_arguments(self._jdbc_connection)
- self.submit(application=os.path.dirname(os.path.abspath(__file__)) + "/spark_jdbc_script.py")
+ self.submit(application=f"{os.path.dirname(os.path.abspath(__file__))}/spark_jdbc_script.py")

def get_conn(self) -> Any:
pass
26 changes: 10 additions & 16 deletions airflow/providers/apache/spark/hooks/spark_submit.py
@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
#
+ import contextlib
import os
import re
import subprocess
@@ -28,10 +29,8 @@
from airflow.security.kerberos import renew_from_kt
from airflow.utils.log.logging_mixin import LoggingMixin

- try:
+ with contextlib.suppress(ImportError, NameError):
    from airflow.kubernetes import kube_client
- except (ImportError, NameError):
-     pass


class SparkSubmitHook(BaseHook, LoggingMixin):
@@ -355,9 +354,7 @@ def _build_track_driver_status_command(self) -> List[str]:
self.log.info(connection_cmd)

# The driver id so we can poll for its status
- if self._driver_id:
-     pass
- else:
+ if not self._driver_id:
raise AirflowException(
"Invalid status: attempted to poll driver status but no driver id is known. Giving up."
)
@@ -607,17 +604,14 @@ def on_kill(self) -> None:
"""Kill Spark submit command"""
self.log.debug("Kill Command is being called")

- if self._should_track_driver_status:
-     if self._driver_id:
-         self.log.info('Killing driver %s on cluster', self._driver_id)
+ if self._should_track_driver_status and self._driver_id:
+     self.log.info('Killing driver %s on cluster', self._driver_id)

-         kill_cmd = self._build_spark_driver_kill_command()
-         with subprocess.Popen(
-             kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
-         ) as driver_kill:
-             self.log.info(
-                 "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait()
-             )
+     kill_cmd = self._build_spark_driver_kill_command()
+     with subprocess.Popen(kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as driver_kill:
+         self.log.info(
+             "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait()
+         )

if self._submit_sp and self._submit_sp.poll() is None:
self.log.info('Sending kill signal to %s', self._connection['spark_binary'])
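
The spark_submit hook now guards its optional kubernetes import with contextlib.suppress. A hedged sketch of the same pattern against a hypothetical optional module (not the real airflow import):

import contextlib

# If the import raises ImportError (or NameError), the `with` block simply
# ends and execution continues; the optional name is left unbound.
with contextlib.suppress(ImportError, NameError):
    from some_optional_dependency import client   # hypothetical optional module

if 'client' in globals():
    print("optional dependency available")
else:
    print("optional dependency missing; continuing without it")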
4 changes: 1 addition & 3 deletions tests/providers/apache/hive/transfers/test_mssql_to_hive.py
@@ -113,9 +113,7 @@ def test_execute_empty_description_field(self, mock_hive_hook, mock_mssql_hook,
mssql_to_hive_transfer.execute(context={})

field_dict = OrderedDict()
- col_count = 0
- for field in mock_mssql_hook_cursor.return_value.description:
-     col_count += 1
+ for col_count, field in enumerate(mock_mssql_hook_cursor.return_value.description, start=1):
col_position = f"Column{col_count}"
field_dict[col_position] = mssql_to_hive_transfer.type_map(field[1])
mock_hive_hook.return_value.load_file.assert_called_once_with(
5 changes: 3 additions & 2 deletions tests/providers/apache/spark/hooks/test_spark_jdbc_script.py
@@ -164,7 +164,8 @@ def test_spark_write_to_jdbc(self, mock_writer_save):
# Given
arguments = _parse_arguments(self.jdbc_arguments)
spark_session = _create_spark_session(arguments)
spark_session.sql("CREATE TABLE IF NOT EXISTS " + arguments.metastore_table + " (key INT)")
spark_session.sql(f"CREATE TABLE IF NOT EXISTS {arguments.metastore_table} (key INT)")

# When

spark_write_to_jdbc(
@@ -191,7 +192,7 @@ def test_spark_read_from_jdbc(self, mock_reader_load):
# Given
arguments = _parse_arguments(self.jdbc_arguments)
spark_session = _create_spark_session(arguments)
spark_session.sql("CREATE TABLE IF NOT EXISTS " + arguments.metastore_table + " (key INT)")
spark_session.sql(f"CREATE TABLE IF NOT EXISTS {arguments.metastore_table} (key INT)")

# When
spark_read_from_jdbc(