diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index bc398ab7a7..1a6e56cd90 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -1,5 +1,5 @@ -import typing from datetime import datetime as _datetime +from typing import Optional import pytz as _pytz from flyteidl.core import literals_pb2 as _literals_pb2 @@ -549,7 +549,7 @@ def from_flyte_idl(cls, pb2_object): class StructuredDatasetMetadata(_common.FlyteIdlEntity): - def __init__(self, structured_dataset_type: StructuredDatasetType = None): + def __init__(self, structured_dataset_type: Optional[StructuredDatasetType] = None): self._structured_dataset_type = structured_dataset_type @property @@ -571,7 +571,7 @@ def from_flyte_idl(cls, pb2_object: _literals_pb2.StructuredDatasetMetadata) -> class StructuredDataset(_common.FlyteIdlEntity): - def __init__(self, uri: str, metadata: typing.Optional[StructuredDatasetMetadata] = None): + def __init__(self, uri: str, metadata: Optional[StructuredDatasetMetadata] = None): """ A strongly typed schema that defines the interface of data retrieved from the underlying storage medium. """ diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 49b2f13ed9..ff9d692cec 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -55,14 +55,13 @@ def decode( self, ctx: FlyteContext, flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, ) -> pd.DataFrame: path = flyte_value.uri local_dir = ctx.file_access.get_random_local_directory() ctx.file_access.get_data(path, local_dir, is_multipart=True) - if flyte_value.metadata.structured_dataset_type.columns: - columns = [] - for c in flyte_value.metadata.structured_dataset_type.columns: - columns.append(c.name) + if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: + columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] return pd.read_parquet(local_dir, columns=columns) return pd.read_parquet(local_dir) @@ -94,14 +93,13 @@ def decode( self, ctx: FlyteContext, flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, ) -> pa.Table: path = flyte_value.uri local_dir = ctx.file_access.get_random_local_directory() ctx.file_access.get_data(path, local_dir, is_multipart=True) - if flyte_value.metadata.structured_dataset_type.columns: - columns = [] - for c in flyte_value.metadata.structured_dataset_type.columns: - columns.append(c.name) + if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: + columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] return pq.read_table(local_dir, columns=columns) return pq.read_table(local_dir) diff --git a/flytekit/types/structured/bigquery.py b/flytekit/types/structured/bigquery.py index 923ea06e9e..aa0ef42f6b 100644 --- a/flytekit/types/structured/bigquery.py +++ b/flytekit/types/structured/bigquery.py @@ -11,7 +11,6 @@ from flytekit.models.types import StructuredDatasetType from flytekit.types.structured.structured_dataset import ( BIGQUERY, - DF, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -29,7 +28,9 @@ def _write_to_bq(structured_dataset: StructuredDataset): client.load_table_from_dataframe(df, table_id) -def _read_from_bq(flyte_value: literals.StructuredDataset) -> pd.DataFrame: +def _read_from_bq( + flyte_value: literals.StructuredDataset, current_task_metadata: StructuredDatasetMetadata +) -> pd.DataFrame: path = flyte_value.uri _, project_id, dataset_id, table_id = re.split("\\.|://|:", path) client = bigquery_storage.BigQueryReadClient() @@ -37,10 +38,8 @@ def _read_from_bq(flyte_value: literals.StructuredDataset) -> pd.DataFrame: parent = "projects/{}".format(project_id) read_options = None - if flyte_value.metadata.structured_dataset_type.columns: - columns = [] - for c in flyte_value.metadata.structured_dataset_type.columns: - columns.append(c.name) + if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: + columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] read_options = types.ReadSession.TableReadOptions(selected_fields=columns) requested_session = types.ReadSession(table=table, data_format=types.DataFormat.ARROW, read_options=read_options) @@ -78,8 +77,9 @@ def decode( self, ctx: FlyteContext, flyte_value: literals.StructuredDataset, - ) -> typing.Union[DF, typing.Generator[DF, None, None]]: - return _read_from_bq(flyte_value) + current_task_metadata: StructuredDatasetMetadata, + ) -> pd.DataFrame: + return _read_from_bq(flyte_value, current_task_metadata) class ArrowToBQEncodingHandlers(StructuredDatasetEncoder): @@ -106,7 +106,8 @@ def decode( self, ctx: FlyteContext, flyte_value: literals.StructuredDataset, - ) -> typing.Union[DF, typing.Generator[DF, None, None]]: + current_task_metadata: StructuredDatasetMetadata, + ) -> pa.Table: return pa.Table.from_pandas(_read_from_bq(flyte_value)) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 819bc012cc..fa57d7b5de 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -105,9 +105,7 @@ def all(self) -> DF: if self._dataframe_type is None: raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.") ctx = FlyteContextManager.current_context() - return flyte_dataset_transformer.open_as( - ctx, self.literal, self._dataframe_type, updated_metadata=self.metadata - ) + return flyte_dataset_transformer.open_as(ctx, self.literal, self._dataframe_type, self.metadata) def iter(self) -> Generator[DF, None, None]: if self._dataframe_type is None: @@ -261,14 +259,17 @@ def decode( self, ctx: FlyteContext, flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, ) -> Union[DF, Generator[DF, None, None]]: """ This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal value into a Python instance. - :param ctx: + :param ctx: A FlyteContext, useful in accessing the filesystem and other attributes :param flyte_value: This will be a Flyte IDL StructuredDataset Literal - do not confuse this with the StructuredDataset class defined also in this module. + :param current_task_metadata: Metadata object containing the type (and columns if any) for the currently + executing task. This type may have more or less information than the type information bundled inside the incoming flyte_value. :return: This function can either return an instance of the dataframe that this decoder handles, or an iterator of those dataframes. """ @@ -585,7 +586,7 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ... sd._literal_sd = sd_literal return sd else: - return self.open_as(ctx, sd_literal, df_type=expected_python_type) + return self.open_as(ctx, sd_literal, expected_python_type, metad) # Start handling for StructuredDataset scalars, first look at the columns incoming_columns = lv.scalar.structured_dataset.metadata.structured_dataset_type.columns @@ -596,8 +597,7 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ... if column_dict is None or len(column_dict) == 0: # but if it does, then we just copy it over if incoming_columns is not None and incoming_columns != []: - for c in incoming_columns: - final_dataset_columns.append(c) + final_dataset_columns = incoming_columns.copy() # If the current running task's input does have columns defined else: final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict) @@ -631,22 +631,18 @@ def open_as( ctx: FlyteContext, sd: literals.StructuredDataset, df_type: Type[DF], - updated_metadata: Optional[StructuredDatasetMetadata] = None, + updated_metadata: StructuredDatasetMetadata, ) -> DF: """ - - :param ctx: + :param ctx: A FlyteContext, useful in accessing the filesystem and other attributes :param sd: :param df_type: - :param meta: New metadata type, since it might be different from the metadata in the literal. - :return: + :param updated_metadata: New metadata type, since it might be different from the metadata in the literal. + :return: dataframe. It could be pandas dataframe or arrow table, etc. """ protocol = protocol_prefix(sd.uri) decoder = self.get_decoder(df_type, protocol, sd.metadata.structured_dataset_type.format) - # todo: revisit this, we probably should add a new field to the decoder interface - if updated_metadata: - sd._metadata = updated_metadata - result = decoder.decode(ctx, sd) + result = decoder.decode(ctx, sd, updated_metadata) if isinstance(result, types.GeneratorType): raise ValueError(f"Decoder {decoder} returned iterator {result} but whole value requested from {sd}") return result @@ -656,14 +652,11 @@ def iter_as( ctx: FlyteContext, sd: literals.StructuredDataset, df_type: Type[DF], - updated_metadata: Optional[StructuredDatasetMetadata] = None, + updated_metadata: StructuredDatasetMetadata, ) -> Generator[DF, None, None]: protocol = protocol_prefix(sd.uri) decoder = self.DECODERS[df_type][protocol][sd.metadata.structured_dataset_type.format] - # todo: revisit this, should we add a new field to the decoder interface - if updated_metadata: - sd._metadata = updated_metadata - result = decoder.decode(ctx, sd) + result = decoder.decode(ctx, sd, updated_metadata) if not isinstance(result, types.GeneratorType): raise ValueError(f"Decoder {decoder} didn't return iterator {result} but should have from {sd}") return result diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py index c17e2fc8bd..d47318666f 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/arrow.py @@ -47,6 +47,7 @@ def decode( self, ctx: FlyteContext, flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, ) -> pa.Table: uri = flyte_value.uri if not ctx.file_access.is_remote(uri): @@ -54,10 +55,8 @@ def decode( _, path = split_protocol(uri) columns = None - if flyte_value.metadata.structured_dataset_type.columns: - columns = [] - for c in flyte_value.metadata.structured_dataset_type.columns: - columns.append(c.name) + if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: + columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] try: fs = FSSpecPersistence.get_filesystem(uri) return pq.read_table(path, filesystem=fs, columns=columns) diff --git a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py index 52bcc4522a..07b58d243a 100644 --- a/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py +++ b/plugins/flytekit-data-fsspec/flytekitplugins/fsspec/pandas.py @@ -58,14 +58,13 @@ def decode( self, ctx: FlyteContext, flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, ) -> pd.DataFrame: uri = flyte_value.uri columns = None kwargs = get_storage_options(uri) - if flyte_value.metadata.structured_dataset_type.columns: - columns = [] - for c in flyte_value.metadata.structured_dataset_type.columns: - columns.append(c.name) + if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: + columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] try: return pd.read_parquet(uri, columns=columns, storage_options=kwargs) except NoCredentialsError: diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py index 2466d3fc13..cd451fa080 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py @@ -39,8 +39,12 @@ def decode( self, ctx: FlyteContext, flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, ) -> DataFrame: user_ctx = FlyteContext.current_context().user_space_params + if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: + columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] + return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns) return user_ctx.spark_session.read.parquet(flyte_value.uri) diff --git a/plugins/flytekit-spark/tests/test_wf.py b/plugins/flytekit-spark/tests/test_wf.py index a0a624fec7..8c42a6162f 100644 --- a/plugins/flytekit-spark/tests/test_wf.py +++ b/plugins/flytekit-spark/tests/test_wf.py @@ -6,6 +6,11 @@ from flytekit import kwtypes, task, workflow from flytekit.types.schema import FlyteSchema +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated + def test_wf1_with_spark(): @task(task_config=Spark()) @@ -53,27 +58,6 @@ def my_wf() -> my_schema: assert df2 is not None -def test_ddwf1_with_spark(): - @task(task_config=Spark()) - def my_spark(a: int) -> (int, str): - session = flytekit.current_context().spark_session - assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local" - return a + 2, "world" - - @task - def t2(a: str, b: str) -> str: - return b + a - - @workflow - def my_wf(a: int, b: str) -> (int, str): - x, y = my_spark(a=a) - d = t2(a=y, b=b) - return x, d - - x = my_wf(a=5, b="hello ") - assert x == (7, "hello world") - - def test_fs_sd_compatibility(): my_schema = FlyteSchema[kwtypes(name=str, age=int)] @@ -108,7 +92,6 @@ def test_spark_dataframe_return(): def my_spark(a: int) -> my_schema: session = flytekit.current_context().spark_session df = session.createDataFrame([("Alice", a)], my_schema.column_names()) - print(type(df)) return df @workflow @@ -120,3 +103,19 @@ def my_wf(a: int) -> my_schema: df2 = reader.all() result_df = df2.reset_index(drop=True) == pd.DataFrame(data={"name": ["Alice"], "age": [5]}).reset_index(drop=True) assert result_df.all().all() + + +def test_read_spark_subset_columns(): + @task + def t1() -> pd.DataFrame: + return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + + @task(task_config=Spark()) + def t2(df: Annotated[pyspark.sql.DataFrame, kwtypes(Name=str)]) -> int: + return len(df.columns) + + @workflow() + def wf() -> int: + return t2(df=t1()) + + assert wf() == 1 diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py index a7ef1ea953..8efaecfcc9 100644 --- a/tests/flytekit/unit/core/test_structured_dataset.py +++ b/tests/flytekit/unit/core/test_structured_dataset.py @@ -221,6 +221,7 @@ def decode( self, ctx: FlyteContext, flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, ) -> typing.Union[typing.Generator[pd.DataFrame, None, None]]: yield pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) @@ -241,8 +242,9 @@ def decode( self, ctx: FlyteContext, flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, ) -> pd.DataFrame: - pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) StructuredDatasetTransformerEngine.register( MockPandasDecodingHandlers(pd.DataFrame, "tmpfs"), default_for_type=False, override=True @@ -288,7 +290,7 @@ def test_to_python_value_with_incoming_columns(): # check when columns are not specified, should pull both and add column information. sd = fdt.to_python_value(ctx, lit, StructuredDataset) - assert sd.metadata.structured_dataset_type.columns[0].name == "age" + assert len(sd.metadata.structured_dataset_type.columns) == 2 # should also work if subset type is just an annotated pd.DataFrame subset_pd_type = Annotated[pd.DataFrame, kwtypes(age=int)] diff --git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py index 0d755ab78b..ada7483a0f 100644 --- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py +++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py @@ -6,6 +6,7 @@ from flytekit.core import context_manager from flytekit.core.base_task import kwtypes +from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType from flytekit.types.structured import basic_dfs from flytekit.types.structured.structured_dataset import ( @@ -26,12 +27,11 @@ def test_pandas(): decoder = basic_dfs.ParquetToPandasDecodingHandler("/") ctx = context_manager.FlyteContextManager.current_context() - sd = StructuredDataset( - dataframe=df, - ) - sd_lit = encoder.encode(ctx, sd, StructuredDatasetType(format="parquet")) + sd = StructuredDataset(dataframe=df) + sd_type = StructuredDatasetType(format="parquet") + sd_lit = encoder.encode(ctx, sd, sd_type) - df2 = decoder.decode(ctx, sd_lit) + df2 = decoder.decode(ctx, sd_lit, StructuredDatasetMetadata(sd_type)) assert df.equals(df2) diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py index d911f971e4..0fc0bd976c 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset_workflow.py @@ -54,6 +54,7 @@ def decode( self, ctx: FlyteContext, flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, ) -> pd.DataFrame: return pd_df @@ -86,6 +87,7 @@ def decode( self, ctx: FlyteContext, flyte_value: literals.StructuredDataset, + current_task_metadata: StructuredDatasetMetadata, ) -> typing.Union[DF, typing.Generator[DF, None, None]]: path = flyte_value.uri local_dir = ctx.file_access.get_random_local_directory()