[pr into #822] Final update to structured dataset column subsetting #828

Merged 2 commits on Jan 25, 2022
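This PR is the final piece of the structured dataset column subsetting work, merged into the branch for #822: when a task declares a subset of columns on a `StructuredDataset` (or annotated dataframe) input, that subset is now carried through `to_python_value` into the metadata handed to the decoder. A minimal sketch of the user-facing behavior; the task and column names are illustrative:

```python
import pandas as pd
from typing_extensions import Annotated

from flytekit import StructuredDataset, kwtypes, task


@task
def consume_subset(ds: Annotated[StructuredDataset, kwtypes(age=int)]) -> pd.DataFrame:
    # Only the "age" column should be materialized when the input is opened.
    return ds.open(pd.DataFrame).all()
```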
6 changes: 5 additions & 1 deletion flytekit/__init__.py
@@ -186,7 +186,11 @@
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types import directory, file, schema
-from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetType
+from flytekit.types.structured.structured_dataset import (
+    StructuredDataset,
+    StructuredDatasetFormat,
+    StructuredDatasetType,
+)

__version__ = "0.0.0+develop"
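The import block is expanded so that `StructuredDatasetFormat` is re-exported at the package root alongside the existing names. A quick check, assuming nothing beyond this diff:

```python
# All three structured dataset names are importable from the top level after
# this change (previously StructuredDatasetFormat was not re-exported).
from flytekit import StructuredDataset, StructuredDatasetFormat, StructuredDatasetType
```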

86 changes: 58 additions & 28 deletions flytekit/types/structured/structured_dataset.py
@@ -105,13 +105,17 @@ def all(self) -> DF:
        if self._dataframe_type is None:
            raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
        ctx = FlyteContextManager.current_context()
-        return FLYTE_DATASET_TRANSFORMER.open_as(ctx, self.literal, self._dataframe_type)
+        return FLYTE_DATASET_TRANSFORMER.open_as(
+            ctx, self.literal, self._dataframe_type, updated_metadata=self.metadata
+        )

    def iter(self) -> Generator[DF, None, None]:
        if self._dataframe_type is None:
            raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
        ctx = FlyteContextManager.current_context()
-        return FLYTE_DATASET_TRANSFORMER.iter_as(ctx, self.literal, self._dataframe_type)
+        return FLYTE_DATASET_TRANSFORMER.iter_as(
+            ctx, self.literal, self._dataframe_type, updated_metadata=self.metadata
+        )
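Both accessors now forward the wrapper's own metadata, which may carry a narrower column list than the stored literal. A sketch of the streaming path under the same assumption as above; `stream_subset` is illustrative:

```python
import pandas as pd
from typing_extensions import Annotated

from flytekit import StructuredDataset, kwtypes, task


@task
def stream_subset(ds: Annotated[StructuredDataset, kwtypes(age=int)]) -> None:
    # iter() forwards the same updated metadata, so each decoded chunk
    # should carry only the requested "age" column.
    for chunk in ds.open(pd.DataFrame).iter():
        print(chunk.shape)
```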


def extract_cols_and_format(
@@ -405,8 +409,7 @@ def to_literal(
    ) -> Literal:
        # Make a copy in case we need to hand off to encoders, since we can't be sure of mutations.
        # Check first to see if it's even an SD type. For backwards compatibility, we may be getting a FlyteSchema
-        if get_origin(python_type) is Annotated:
-            python_type = get_args(python_type)[0]
+        python_type, *attrs = extract_cols_and_format(python_type)
        # In case it's a FlyteSchema
        sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, None))
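The hand-rolled unwrapping is replaced with `extract_cols_and_format`, which both strips `Annotated` and captures the column/format annotations for later use. A rough sketch of just the unwrapping step it subsumes (`unwrap_annotated` is a stand-in, not the real helper):

```python
from typing import Any, List, Tuple

from typing_extensions import Annotated, get_args, get_origin


def unwrap_annotated(t: Any) -> Tuple[Any, List[Any]]:
    # Return the bare type plus any annotations. extract_cols_and_format goes
    # further and interprets those annotations as columns and a storage format.
    if get_origin(t) is Annotated:
        base, *annotations = get_args(t)
        return base, annotations
    return t, []
```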

@@ -574,55 +577,82 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
        else:
            return self.open_as(ctx, sd_literal, df_type=expected_python_type)

+        # Start handling for StructuredDataset scalars; first look at the columns
+        incoming_columns = lv.scalar.structured_dataset.metadata.structured_dataset_type.columns
+
+        # If the incoming literal also doesn't have columns, then we just have an empty list, so initialize here
+        final_dataset_columns = []
+        # If the current running task's input does not have columns defined, or has an empty list of columns
+        if column_dict is None or len(column_dict) == 0:
+            # but if the incoming literal does, then we just copy them over
+            if incoming_columns is not None and incoming_columns != []:
+                for c in incoming_columns:
+                    final_dataset_columns.append(c)
+        # If the current running task's input does have columns defined
+        else:
+            final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict)
+
+        new_sdt = StructuredDatasetType(
+            columns=final_dataset_columns,
+            format=lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
+            external_schema_type=lv.scalar.structured_dataset.metadata.structured_dataset_type.external_schema_type,
+            external_schema_bytes=lv.scalar.structured_dataset.metadata.structured_dataset_type.external_schema_bytes,
+        )
+        metad = StructuredDatasetMetadata(structured_dataset_type=new_sdt)

        # A StructuredDataset type, for example
        # t1(input_a: StructuredDataset) # or
        # t1(input_a: Annotated[StructuredDataset, my_cols])
        if issubclass(expected_python_type, StructuredDataset):
-            incoming_columns = lv.scalar.structured_dataset.metadata.structured_dataset_type.columns
-
-            # If the incoming literal, also doesn't have columns, then we just have an empty list, so initialize here
-            final_dataset_columns = []
-            # If the current running task's input does not have columns defined, or has an empty list of columns
-            if column_dict is None or len(column_dict) == 0:
-                # but if it does, then we just copy it over
-                if incoming_columns is not None and incoming_columns != []:
-                    for c in incoming_columns:
-                        final_dataset_columns.append(c)
-            # If the current running task's input does have columns defined
-            else:
-                final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict)
-
-            new_sdt = StructuredDatasetType(
-                columns=final_dataset_columns,
-                format=lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
-                external_schema_type=lv.scalar.structured_dataset.metadata.structured_dataset_type.external_schema_type,
-                external_schema_bytes=lv.scalar.structured_dataset.metadata.structured_dataset_type.external_schema_bytes,
-            )
            sd = expected_python_type(
                dataframe=None,
                # Note here that the type being passed in
-                metadata=StructuredDatasetMetadata(structured_dataset_type=new_sdt),
+                metadata=metad,
            )
            sd._literal_sd = lv.scalar.structured_dataset
            return sd

        # If the requested type was not a StructuredDataset, then it means it was a plain dataframe type, which means
        # we should do the opening/downloading and whatever else it might entail right now. No iteration option here.
-        return self.open_as(ctx, lv.scalar.structured_dataset, df_type=expected_python_type)
+        return self.open_as(ctx, lv.scalar.structured_dataset, df_type=expected_python_type, updated_metadata=metad)
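The column merge is hoisted above the `issubclass` branch so both return paths share the same `metad`. The precedence rule in isolation, as a standalone sketch (`merge_columns` is illustrative, not engine code):

```python
from typing import List, Optional


def merge_columns(declared: Optional[List], incoming: Optional[List]) -> List:
    # Mirrors the precedence above: columns declared on the task's input type
    # win; if none are declared, copy whatever the incoming literal carries.
    if declared:
        return list(declared)
    return list(incoming or [])
```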

-    def open_as(self, ctx: FlyteContext, sd: literals.StructuredDataset, df_type: Type[DF]) -> DF:
+    def open_as(
+        self,
+        ctx: FlyteContext,
+        sd: literals.StructuredDataset,
+        df_type: Type[DF],
+        updated_metadata: Optional[StructuredDatasetMetadata] = None,
+    ) -> DF:
+        """
+        :param ctx:
+        :param sd:
+        :param df_type:
+        :param updated_metadata: New metadata, since it might be different from the metadata in the literal.
+        :return:
+        """
        protocol = protocol_prefix(sd.uri)
        decoder = self.get_decoder(df_type, protocol, sd.metadata.structured_dataset_type.format)
+        # todo: revisit this, we probably should add a new field to the decoder interface
+        if updated_metadata:
+            sd._metadata = updated_metadata
        result = decoder.decode(ctx, sd)
        if isinstance(result, types.GeneratorType):
            raise ValueError(f"Decoder {decoder} returned iterator {result} but whole value requested from {sd}")
        return result
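Writing `updated_metadata` back onto the literal (flagged by the todo as provisional) is what lets a decoder see the narrowed column list. A hypothetical decoder body showing how that list could be honored; `decode_parquet_subset` and its signature are inventions for this sketch:

```python
import pandas as pd


def decode_parquet_subset(uri: str, metadata) -> pd.DataFrame:
    # Hypothetical decoder body: an empty column list means "no subset
    # requested", so fall back to reading every column.
    cols = [c.name for c in metadata.structured_dataset_type.columns]
    return pd.read_parquet(uri, columns=cols or None)
```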

    def iter_as(
-        self, ctx: FlyteContext, sd: literals.StructuredDataset, df_type: Type[DF]
+        self,
+        ctx: FlyteContext,
+        sd: literals.StructuredDataset,
+        df_type: Type[DF],
+        updated_metadata: Optional[StructuredDatasetMetadata] = None,
    ) -> Generator[DF, None, None]:
        protocol = protocol_prefix(sd.uri)
        decoder = self.DECODERS[df_type][protocol][sd.metadata.structured_dataset_type.format]
+        # todo: revisit this, should we add a new field to the decoder interface
+        if updated_metadata:
+            sd._metadata = updated_metadata
        result = decoder.decode(ctx, sd)
        if not isinstance(result, types.GeneratorType):
            raise ValueError(f"Decoder {decoder} didn't return iterator {result} but should have from {sd}")
61 changes: 58 additions & 3 deletions tests/flytekit/unit/core/test_structured_dataset.py
@@ -49,7 +49,7 @@ def test_protocol():


def generate_pandas() -> pd.DataFrame:
return pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})


def test_types_pandas():
@@ -265,5 +265,60 @@ def test_convert_schema_type_to_structured_dataset_type():
        convert_schema_type_to_structured_dataset_type(int)


-def test_to_python_value():
-    ...
+def test_to_python_value_with_incoming_columns():
+    # make a literal with a type that has two columns
+    original_type = Annotated[pd.DataFrame, kwtypes(name=str, age=int)]
+    ctx = FlyteContextManager.current_context()
+    lt = TypeEngine.to_literal_type(original_type)
+    df = generate_pandas()
+    lit = FLYTE_DATASET_TRANSFORMER.to_literal(ctx, df, python_type=original_type, expected=lt)
+    assert len(lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 2
+
+    # declare a new type that only has one column
+    # get the dataframe, make sure it has the column that was asked for
+    subset_sd_type = Annotated[StructuredDataset, kwtypes(age=int)]
+    sd = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, subset_sd_type)
+    assert sd.metadata.structured_dataset_type.columns[0].name == "age"
+    sub_df = sd.open(pd.DataFrame).all()
+    assert sub_df.shape[1] == 1
+
+    # check when columns are not specified; should pull both and add column information
+    sd = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, StructuredDataset)
+    assert sd.metadata.structured_dataset_type.columns[0].name == "age"
+
+    # should also work if the subset type is just an annotated pd.DataFrame
+    subset_pd_type = Annotated[pd.DataFrame, kwtypes(age=int)]
+    sub_df = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, subset_pd_type)
+    assert sub_df.shape[1] == 1
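This test exercises the same path two chained tasks would hit at execution time: the upstream literal carries both columns and the downstream annotation narrows them. A hedged sketch of that pipeline shape (task and workflow names are illustrative):

```python
import pandas as pd
from typing_extensions import Annotated

from flytekit import StructuredDataset, kwtypes, task, workflow


@task
def produce() -> Annotated[pd.DataFrame, kwtypes(name=str, age=int)]:
    return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})


@task
def count_cols(ds: Annotated[StructuredDataset, kwtypes(age=int)]) -> int:
    # With the subsetting in place, this should see a single column.
    return int(ds.open(pd.DataFrame).all().shape[1])


@workflow
def wf() -> int:
    return count_cols(ds=produce())
```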


+def test_to_python_value_without_incoming_columns():
+    # make a literal with a type with no columns
+    ctx = FlyteContextManager.current_context()
+    lt = TypeEngine.to_literal_type(pd.DataFrame)
+    df = generate_pandas()
+    lit = FLYTE_DATASET_TRANSFORMER.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
+    assert len(lit.scalar.structured_dataset.metadata.structured_dataset_type.columns) == 0
+
+    # declare a new type that only has one column
+    # get the dataframe, make sure it has the column that was asked for
+    subset_sd_type = Annotated[StructuredDataset, kwtypes(age=int)]
+    sd = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, subset_sd_type)
+    assert sd.metadata.structured_dataset_type.columns[0].name == "age"
+    sub_df = sd.open(pd.DataFrame).all()
+    assert sub_df.shape[1] == 1
+
+    # check when columns are not specified; should pull both and add column information
+    # todo: see the todos in the open_as and iter_as functions in StructuredDatasetTransformerEngine
+    # we have to recreate the literal because the test case above filled in the metadata
+    lit = FLYTE_DATASET_TRANSFORMER.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
+    sd = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, StructuredDataset)
+    assert sd.metadata.structured_dataset_type.columns == []
+    sub_df = sd.open(pd.DataFrame).all()
+    assert sub_df.shape[1] == 2
+
+    # should also work if the subset type is just an annotated pd.DataFrame
+    lit = FLYTE_DATASET_TRANSFORMER.to_literal(ctx, df, python_type=pd.DataFrame, expected=lt)
+    subset_pd_type = Annotated[pd.DataFrame, kwtypes(age=int)]
+    sub_df = FLYTE_DATASET_TRANSFORMER.to_python_value(ctx, lit, subset_pd_type)
+    assert sub_df.shape[1] == 1