diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index 369316c642..fd8fcc7d9c 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -186,7 +186,11 @@
 from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
 from flytekit.models.types import LiteralType
 from flytekit.types import directory, file, schema
-from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetType
+from flytekit.types.structured.structured_dataset import (
+    StructuredDataset,
+    StructuredDatasetFormat,
+    StructuredDatasetType,
+)
 
 __version__ = "0.0.0+develop"
 
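(Context for the new export, not part of the patch: a minimal sketch of where the format annotation that `StructuredDatasetFormat` names shows up in user code. The `@task` below is hypothetical; the bare `"parquet"` string is what `extract_cols_and_format` in the next file picks up as the format.)

```python
import pandas as pd
from typing_extensions import Annotated

from flytekit import StructuredDataset, kwtypes, task


@task
def t1(df: pd.DataFrame) -> Annotated[StructuredDataset, kwtypes(name=str, age=int), "parquet"]:
    # Columns and format ride along in the Annotated metadata; the engine
    # later selects an encoder keyed on (dataframe type, protocol, format).
    return StructuredDataset(dataframe=df)
```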
diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py
index 53e541e6da..a36e50b976 100644
--- a/flytekit/types/structured/structured_dataset.py
+++ b/flytekit/types/structured/structured_dataset.py
@@ -105,13 +105,17 @@ def all(self) -> DF:
         if self._dataframe_type is None:
             raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
         ctx = FlyteContextManager.current_context()
-        return FLYTE_DATASET_TRANSFORMER.open_as(ctx, self.literal, self._dataframe_type)
+        return FLYTE_DATASET_TRANSFORMER.open_as(
+            ctx, self.literal, self._dataframe_type, updated_metadata=self.metadata
+        )
 
     def iter(self) -> Generator[DF, None, None]:
         if self._dataframe_type is None:
             raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
         ctx = FlyteContextManager.current_context()
-        return FLYTE_DATASET_TRANSFORMER.iter_as(ctx, self.literal, self._dataframe_type)
+        return FLYTE_DATASET_TRANSFORMER.iter_as(
+            ctx, self.literal, self._dataframe_type, updated_metadata=self.metadata
+        )
 
 
 def extract_cols_and_format(
@@ -405,8 +409,7 @@ def to_literal(
     ) -> Literal:
         # Make a copy in case we need to hand off to encoders, since we can't be sure of mutations.
         # Check first to see if it's even an SD type. For backwards compatibility, we may be getting a FlyteSchema
-        if get_origin(python_type) is Annotated:
-            python_type = get_args(python_type)[0]
+        python_type, *attrs = extract_cols_and_format(python_type)
 
         # In case it's a FlyteSchema
         sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, None))
@@ -574,55 +577,82 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
             else:
                 return self.open_as(ctx, sd_literal, df_type=expected_python_type)
 
+        # Start handling StructuredDataset scalars; first look at the columns
+        incoming_columns = lv.scalar.structured_dataset.metadata.structured_dataset_type.columns
+
+        # If the incoming literal also doesn't have columns, we end up with an empty list, so initialize it here
+        final_dataset_columns = []
+        # If the current running task's input does not have columns defined, or has an empty list of columns
+        if column_dict is None or len(column_dict) == 0:
+            # but if the incoming literal does have columns, copy them over
+            if incoming_columns is not None and incoming_columns != []:
+                for c in incoming_columns:
+                    final_dataset_columns.append(c)
+        # If the current running task's input does have columns defined
+        else:
+            final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict)
+
+        new_sdt = StructuredDatasetType(
+            columns=final_dataset_columns,
+            format=lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
+            external_schema_type=lv.scalar.structured_dataset.metadata.structured_dataset_type.external_schema_type,
+            external_schema_bytes=lv.scalar.structured_dataset.metadata.structured_dataset_type.external_schema_bytes,
+        )
+        metad = StructuredDatasetMetadata(structured_dataset_type=new_sdt)
+
         # A StructuredDataset type, for example
         # t1(input_a: StructuredDataset)
         # or
         # t1(input_a: Annotated[StructuredDataset, my_cols])
         if issubclass(expected_python_type, StructuredDataset):
-            incoming_columns = lv.scalar.structured_dataset.metadata.structured_dataset_type.columns
-
-            # If the incoming literal, also doesn't have columns, then we just have an empty list, so initialize here
-            final_dataset_columns = []
-            # If the current running task's input does not have columns defined, or has an empty list of columns
-            if column_dict is None or len(column_dict) == 0:
-                # but if it does, then we just copy it over
-                if incoming_columns is not None and incoming_columns != []:
-                    for c in incoming_columns:
-                        final_dataset_columns.append(c)
-            # If the current running task's input does have columns defined
-            else:
-                final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict)
-
-            new_sdt = StructuredDatasetType(
-                columns=final_dataset_columns,
-                format=lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
-                external_schema_type=lv.scalar.structured_dataset.metadata.structured_dataset_type.external_schema_type,
-                external_schema_bytes=lv.scalar.structured_dataset.metadata.structured_dataset_type.external_schema_bytes,
-            )
             sd = expected_python_type(
                 dataframe=None,
                 # Note here that the type being passed in
-                metadata=StructuredDatasetMetadata(structured_dataset_type=new_sdt),
+                metadata=metad,
             )
             sd._literal_sd = lv.scalar.structured_dataset
             return sd
 
         # If the requested type was not a StructuredDataset, then it means it was a plain dataframe type, which means
         # we should do the opening/downloading and whatever else it might entail right now. No iteration option here.
-        return self.open_as(ctx, lv.scalar.structured_dataset, df_type=expected_python_type)
+        return self.open_as(ctx, lv.scalar.structured_dataset, df_type=expected_python_type, updated_metadata=metad)
 
-    def open_as(self, ctx: FlyteContext, sd: literals.StructuredDataset, df_type: Type[DF]) -> DF:
+    def open_as(
+        self,
+        ctx: FlyteContext,
+        sd: literals.StructuredDataset,
+        df_type: Type[DF],
+        updated_metadata: Optional[StructuredDatasetMetadata] = None,
+    ) -> DF:
+        """
+        :param ctx:
+        :param sd:
+        :param df_type:
+        :param updated_metadata: New metadata to use, since it might differ from the metadata in the literal.
+        :return:
+        """
         protocol = protocol_prefix(sd.uri)
         decoder = self.get_decoder(df_type, protocol, sd.metadata.structured_dataset_type.format)
+        # todo: revisit this, we probably should add a new field to the decoder interface
+        if updated_metadata:
+            sd._metadata = updated_metadata
         result = decoder.decode(ctx, sd)
         if isinstance(result, types.GeneratorType):
             raise ValueError(f"Decoder {decoder} returned iterator {result} but whole value requested from {sd}")
         return result
 
     def iter_as(
-        self, ctx: FlyteContext, sd: literals.StructuredDataset, df_type: Type[DF]
+        self,
+        ctx: FlyteContext,
+        sd: literals.StructuredDataset,
+        df_type: Type[DF],
+        updated_metadata: Optional[StructuredDatasetMetadata] = None,
     ) -> Generator[DF, None, None]:
         protocol = protocol_prefix(sd.uri)
         decoder = self.DECODERS[df_type][protocol][sd.metadata.structured_dataset_type.format]
+        # todo: revisit this, should we add a new field to the decoder interface
+        if updated_metadata:
+            sd._metadata = updated_metadata
         result = decoder.decode(ctx, sd)
         if not isinstance(result, types.GeneratorType):
             raise ValueError(f"Decoder {decoder} didn't return iterator {result} but should have from {sd}")
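(Illustrative rather than part of the diff: the end-to-end column subsetting this `updated_metadata` plumbing enables. Task names are hypothetical and a pandas decoder registered for the protocol is assumed; the pattern mirrors the `t1`/`t2` example in the transformer's docstring.)

```python
import pandas as pd
from typing_extensions import Annotated

from flytekit import StructuredDataset, kwtypes, task


@task
def upstream() -> Annotated[pd.DataFrame, kwtypes(name=str, age=int)]:
    # The literal produced here carries both columns in its metadata
    return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})


@task
def downstream(in_sd: Annotated[StructuredDataset, kwtypes(age=int)]) -> pd.DataFrame:
    # to_python_value computed the subsetted metadata ("age" only); all()
    # now forwards it to open_as via updated_metadata, so the decoder can
    # read back just that column.
    return in_sd.open(pd.DataFrame).all()
```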
diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py
index 59b20cc3b6..4e4309e292 100644
--- a/tests/flytekit/unit/core/test_structured_dataset.py
+++ b/tests/flytekit/unit/core/test_structured_dataset.py
@@ -49,7 +49,7 @@ def test_protocol():
 
 
 def generate_pandas() -> pd.DataFrame:
-    return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
+    return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})
 
 
 def test_types_pandas():
@@ -265,5 +265,60 @@ def test_convert_schema_type_to_structured_dataset_type():
         convert_schema_type_to_structured_dataset_type(int)
 
 
-def test_to_python_value():
-    ...
+def test_to_python_value_with_incoming_columns():
+    # make a literal with a type that has two columns
+    original_type = Annotated[pd.DataFrame, kwtypes(name=str, age=int)]
+    ctx = FlyteContextManager.current_context()
+    lt = TypeEngine.to_literal_type(original_type)
+    df = generate_pandas()
+    lit = FLYTE_DATASET_TRANSFORMER.to_literal(ctx, df, python_type=original_type, expected=lt)
+    assert len(lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 2
+
+    # declare a new type that only has one column
+    # get the dataframe, make sure it has the column that was asked for.
+    subset_sd_type = Annotated[StructuredDataset, kwtypes(age=int)]
+    sd = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, subset_sd_type)
+    assert sd.metadata.structured_dataset_type.columns[0].name == "age"
+    sub_df = sd.open(pd.DataFrame).all()
+    assert sub_df.shape[1] == 1
+
+    # check when columns are not specified, should pull both and add column information.
+    sd = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, StructuredDataset)
+    assert len(sd.metadata.structured_dataset_type.columns) == 2
+
+    # should also work if the subset type is just an annotated pd.DataFrame
+    subset_pd_type = Annotated[pd.DataFrame, kwtypes(age=int)]
+    sub_df = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, subset_pd_type)
+    assert sub_df.shape[1] == 1
+
+
+def test_to_python_value_without_incoming_columns():
+    # make a literal with a type with no columns
+    ctx = FlyteContextManager.current_context()
+    lt = TypeEngine.to_literal_type(pd.DataFrame)
+    df = generate_pandas()
+    lit = FLYTE_DATASET_TRANSFORMER.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
+    assert len(lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 0
+
+    # declare a new type that only has one column
+    # get the dataframe, make sure it has the column that was asked for.
+    subset_sd_type = Annotated[StructuredDataset, kwtypes(age=int)]
+    sd = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, subset_sd_type)
+    assert sd.metadata.structured_dataset_type.columns[0].name == "age"
+    sub_df = sd.open(pd.DataFrame).all()
+    assert sub_df.shape[1] == 1
+
+    # check when columns are not specified, should pull both and add column information.
+    # todo: see the todos in the open_as, and iter_as functions in StructuredDatasetTransformerEngine
+    # we have to recreate the literal because the test case above filled in the metadata
+    lit = FLYTE_DATASET_TRANSFORMER.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
+    sd = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, StructuredDataset)
+    assert sd.metadata.structured_dataset_type.columns == []
+    sub_df = sd.open(pd.DataFrame).all()
+    assert sub_df.shape[1] == 2
+
+    # should also work if the subset type is just an annotated pd.DataFrame
+    lit = FLYTE_DATASET_TRANSFORMER.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
+    subset_pd_type = Annotated[pd.DataFrame, kwtypes(age=int)]
+    sub_df = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, subset_pd_type)
+    assert sub_df.shape[1] == 1
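(A closing illustration of the caveat in the todos above, not part of the patch: because `open_as`/`iter_as` assign `sd._metadata` in place, a literal that has been opened once keeps the filled-in column metadata, which is why the tests rebuild `lit` between scenarios. Names mirror the test file; this is a sketch, not test code.)

```python
import pandas as pd
from typing_extensions import Annotated

from flytekit import kwtypes
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine
from flytekit.types.structured.structured_dataset import FLYTE_DATASET_TRANSFORMER, StructuredDataset

ctx = FlyteContextManager.current_context()
lt = TypeEngine.to_literal_type(pd.DataFrame)
df = pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})

lit = FLYTE_DATASET_TRANSFORMER.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
sd = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, Annotated[StructuredDataset, kwtypes(age=int)])
sd.open(pd.DataFrame).all()  # open_as() also assigns the subsetted metadata onto lit's dataset

# lit's metadata now mentions only "age"; rebuild the literal before reusing it with a
# different expected type, as test_to_python_value_without_incoming_columns does above.
```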