Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support LargeList in flatten #9110

Merged
merged 3 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,11 +547,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Flatten => {
fn get_base_type(data_type: &DataType) -> Result<DataType> {
match data_type {
DataType::List(field) => match field.data_type() {
DataType::List(_) => get_base_type(field.data_type()),
_ => Ok(data_type.to_owned()),
},
_ => internal_err!("Not reachable, data_type should be List"),
DataType::List(field) if matches!(field.data_type(), DataType::List(_)) => get_base_type(field.data_type()),
DataType::LargeList(field) if matches!(field.data_type(), DataType::LargeList(_)) => get_base_type(field.data_type()),
DataType::Null | DataType::List(_) | DataType::LargeList(_) => Ok(data_type.to_owned()),
_ => internal_err!("Not reachable, data_type should be List or LargeList"),
}
}

Expand Down
52 changes: 36 additions & 16 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2246,38 +2246,41 @@ fn generic_list_cardinality<O: OffsetSizeTrait>(
}

// Create new offsets that are euqiavlent to `flatten` the array.
fn get_offsets_for_flatten(
offsets: OffsetBuffer<i32>,
indexes: OffsetBuffer<i32>,
) -> OffsetBuffer<i32> {
fn get_offsets_for_flatten<O: OffsetSizeTrait>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

offsets: OffsetBuffer<O>,
indexes: OffsetBuffer<O>,
) -> OffsetBuffer<O> {
let buffer = offsets.into_inner();
let offsets: Vec<i32> = indexes.iter().map(|i| buffer[*i as usize]).collect();
let offsets: Vec<O> = indexes
.iter()
.map(|i| buffer[i.to_usize().unwrap()])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wondering what is cheaper....
just cast as usize, or to_size and then unwrap

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot cast it as usize directly because *i is OffsetSizeTrait

.collect();
OffsetBuffer::new(offsets.into())
}

fn flatten_internal(
array: &dyn Array,
indexes: Option<OffsetBuffer<i32>>,
) -> Result<ListArray> {
let list_arr = as_list_array(array)?;
fn flatten_internal<O: OffsetSizeTrait>(
list_arr: GenericListArray<O>,
indexes: Option<OffsetBuffer<O>>,
) -> Result<GenericListArray<O>> {
let (field, offsets, values, _) = list_arr.clone().into_parts();
let data_type = field.data_type();

match data_type {
// Recursively get the base offsets for flattened array
DataType::List(_) => {
DataType::List(_) | DataType::LargeList(_) => {
let sub_list = as_generic_list_array::<O>(&values)?;
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
flatten_internal(&values, Some(offsets))
flatten_internal::<O>(sub_list.clone(), Some(offsets))
} else {
flatten_internal(&values, Some(offsets))
flatten_internal::<O>(sub_list.clone(), Some(offsets))
}
}
// Reach the base level, create a new list array
_ => {
if let Some(indexes) = indexes {
let offsets = get_offsets_for_flatten(offsets, indexes);
let list_arr = ListArray::new(field, offsets, values, None);
let list_arr = GenericListArray::<O>::new(field, offsets, values, None);
Ok(list_arr)
} else {
Ok(list_arr.clone())
Expand All @@ -2292,8 +2295,25 @@ pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> {
return exec_err!("flatten expects one argument");
}

let flattened_array = flatten_internal(&args[0], None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
let array_type = args[0].data_type();
match array_type {
DataType::List(_) => {
let list_arr = as_list_array(&args[0])?;
let flattened_array = flatten_internal::<i32>(list_arr.clone(), None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
}
DataType::LargeList(_) => {
let list_arr = as_large_list_array(&args[0])?;
let flattened_array = flatten_internal::<i64>(list_arr.clone(), None)?;
Ok(Arc::new(flattened_array) as ArrayRef)
}
DataType::Null => Ok(args[0].clone()),
_ => {
exec_err!("flatten does not support type '{array_type:?}'")
}
}

// Ok(Arc::new(flattened_array) as ArrayRef)
}

/// Dispatch array length computation based on the offset type.
Expand Down
47 changes: 47 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,17 @@ AS VALUES
(make_array([1, 2], [3, 4], [5, 6]), make_array([[8]]), make_array([[[1,2]]], [[[3]]]), make_array([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]))
;

statement ok
CREATE TABLE large_flatten_table
AS
SELECT
arrow_cast(column1, 'LargeList(LargeList(Int64))') AS column1,
arrow_cast(column2, 'LargeList(LargeList(LargeList(Int64)))') AS column2,
arrow_cast(column3, 'LargeList(LargeList(LargeList(LargeList(Int64))))') AS column3,
arrow_cast(column4, 'LargeList(LargeList(Float64))') AS column4
FROM flatten_table
;

statement ok
CREATE TABLE array_has_table_1D
AS VALUES
Expand Down Expand Up @@ -5345,19 +5356,41 @@ select array_concat(column1, [7]) from arrays_values_v2;
[7]

# flatten
# follow DuckDB
query ?
select flatten(NULL);
----
NULL

# flatten with scalar values #1
query ???
select flatten(make_array(1, 2, 1, 3, 2)),
flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))),
flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]));
----
[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4]

query ???
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add test description

select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')),
flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'LargeList(LargeList(Int64))')),
flatten(arrow_cast(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]), 'LargeList(LargeList(LargeList(Float64)))'));
----
[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4]

# flatten with with column values
query ????
select column1, column2, column3, column4 from flatten_table;
Weijun-H marked this conversation as resolved.
Show resolved Hide resolved
----
[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]]
[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

query ????
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add test description

select column1, column2, column3, column4 from large_flatten_table;
----
[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]]
[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# flatten with column values
query ????
select flatten(column1),
flatten(column2),
Expand All @@ -5368,6 +5401,17 @@ from flatten_table;
[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

query ????
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

select flatten(column1),
flatten(column2),
flatten(column3),
flatten(column4)
from large_flatten_table;
----
[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4]
[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

## empty
# empty scalar function #1
query B
select empty(make_array(1));
Expand Down Expand Up @@ -5746,6 +5790,9 @@ drop table fixed_size_nested_arrays_with_repeating_elements;
statement ok
drop table flatten_table;

statement ok
drop table large_flatten_table;

statement ok
drop table arrays_values_without_nulls;

Expand Down
Loading