From bf145c255581fb3079496c8f41fc0ca6f4435bb4 Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Fri, 2 Feb 2024 18:54:31 +0800 Subject: [PATCH 1/3] support FixedSizeList in flatten --- datafusion/expr/src/built_in_function.rs | 7 ++- .../physical-expr/src/array_expressions.rs | 52 +++++++++++++------ datafusion/sqllogictest/test_files/array.slt | 43 +++++++++++++++ 3 files changed, 85 insertions(+), 17 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 20b7df46e387..bcc4f8d3c864 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -551,7 +551,12 @@ impl BuiltinScalarFunction { DataType::List(_) => get_base_type(field.data_type()), _ => Ok(data_type.to_owned()), }, - _ => internal_err!("Not reachable, data_type should be List"), + DataType::LargeList(field) => match field.data_type() { + DataType::LargeList(_) => get_base_type(field.data_type()), + _ => Ok(data_type.to_owned()), + }, + DataType::Null => Ok(data_type.to_owned()), + _ => internal_err!("Not reachable, data_type should be List or LargeList"), } } diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 844dae0917c7..0709e66a35c9 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -2246,38 +2246,41 @@ fn generic_list_cardinality( } // Create new offsets that are euqiavlent to `flatten` the array. -fn get_offsets_for_flatten( - offsets: OffsetBuffer, - indexes: OffsetBuffer, -) -> OffsetBuffer { +fn get_offsets_for_flatten( + offsets: OffsetBuffer, + indexes: OffsetBuffer, +) -> OffsetBuffer { let buffer = offsets.into_inner(); - let offsets: Vec = indexes.iter().map(|i| buffer[*i as usize]).collect(); + let offsets: Vec = indexes + .iter() + .map(|i| buffer[i.to_usize().unwrap()]) + .collect(); OffsetBuffer::new(offsets.into()) } -fn flatten_internal( - array: &dyn Array, - indexes: Option>, -) -> Result { - let list_arr = as_list_array(array)?; +fn flatten_internal( + list_arr: GenericListArray, + indexes: Option>, +) -> Result> { let (field, offsets, values, _) = list_arr.clone().into_parts(); let data_type = field.data_type(); match data_type { // Recursively get the base offsets for flattened array - DataType::List(_) => { + DataType::List(_) | DataType::LargeList(_) => { + let sub_list = as_generic_list_array::(&values)?; if let Some(indexes) = indexes { let offsets = get_offsets_for_flatten(offsets, indexes); - flatten_internal(&values, Some(offsets)) + flatten_internal::(sub_list.clone(), Some(offsets)) } else { - flatten_internal(&values, Some(offsets)) + flatten_internal::(sub_list.clone(), Some(offsets)) } } // Reach the base level, create a new list array _ => { if let Some(indexes) = indexes { let offsets = get_offsets_for_flatten(offsets, indexes); - let list_arr = ListArray::new(field, offsets, values, None); + let list_arr = GenericListArray::::new(field, offsets, values, None); Ok(list_arr) } else { Ok(list_arr.clone()) @@ -2292,8 +2295,25 @@ pub fn flatten(args: &[ArrayRef]) -> Result { return exec_err!("flatten expects one argument"); } - let flattened_array = flatten_internal(&args[0], None)?; - Ok(Arc::new(flattened_array) as ArrayRef) + let array_type = args[0].data_type(); + match array_type { + DataType::List(_) => { + let list_arr = as_list_array(&args[0])?; + let flattened_array = flatten_internal::(list_arr.clone(), None)?; + Ok(Arc::new(flattened_array) as ArrayRef) + } + DataType::LargeList(_) => { + let list_arr = as_large_list_array(&args[0])?; + let flattened_array = flatten_internal::(list_arr.clone(), None)?; + Ok(Arc::new(flattened_array) as ArrayRef) + } + DataType::Null => Ok(args[0].clone()), + _ => { + exec_err!("flatten does not support type '{array_type:?}'") + } + } + + // Ok(Arc::new(flattened_array) as ArrayRef) } /// Dispatch array length computation based on the offset type. diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 4fdc428d7a9c..69a1ae8cf47d 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -202,6 +202,17 @@ AS VALUES (make_array([1, 2], [3, 4], [5, 6]), make_array([[8]]), make_array([[[1,2]]], [[[3]]]), make_array([1.0, 2.0], [3.0, 4.0], [5.0, 6.0])) ; +statement ok +CREATE TABLE large_flatten_table +AS + SELECT + arrow_cast(column1, 'LargeList(LargeList(Int64))') AS column1, + arrow_cast(column2, 'LargeList(LargeList(LargeList(Int64)))') AS column2, + arrow_cast(column3, 'LargeList(LargeList(LargeList(LargeList(Int64))))') AS column3, + arrow_cast(column4, 'LargeList(LargeList(Float64))') AS column4 + FROM flatten_table +; + statement ok CREATE TABLE array_has_table_1D AS VALUES @@ -5345,6 +5356,12 @@ select array_concat(column1, [7]) from arrays_values_v2; [7] # flatten +# follow DuckDB +query ? +select flatten(NULL); +---- +NULL + query ??? select flatten(make_array(1, 2, 1, 3, 2)), flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))), @@ -5352,12 +5369,25 @@ select flatten(make_array(1, 2, 1, 3, 2)), ---- [1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4] +query ??? +select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')), + flatten(arrow_cast(make_array([1], [2, 3], [null], make_array(4, null, 5)), 'LargeList(LargeList(Int64))')), + flatten(arrow_cast(make_array([[1.1]], [[2.2]], [[3.3], [4.4]]), 'LargeList(LargeList(LargeList(Float64)))')); +---- +[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4] + query ???? select column1, column2, column3, column4 from flatten_table; ---- [[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]] [[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] +query ???? +select column1, column2, column3, column4 from large_flatten_table; +---- +[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]] +[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + query ???? select flatten(column1), flatten(column2), @@ -5368,6 +5398,16 @@ from flatten_table; [1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] [1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +query ???? +select flatten(column1), + flatten(column2), + flatten(column3), + flatten(column4) +from large_flatten_table; +---- +[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + # empty scalar function #1 query B select empty(make_array(1)); @@ -5746,6 +5786,9 @@ drop table fixed_size_nested_arrays_with_repeating_elements; statement ok drop table flatten_table; +statement ok +drop table large_flatten_table; + statement ok drop table arrays_values_without_nulls; From c8fbfc6fbb316c0d2cd6fc946f94a8786d0a2eed Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Sat, 3 Feb 2024 10:24:24 +0800 Subject: [PATCH 2/3] Refactor flatten function and add test cases --- datafusion/expr/src/built_in_function.rs | 12 +++--------- datafusion/sqllogictest/test_files/array.slt | 4 ++++ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index bcc4f8d3c864..44ef86baf994 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -547,15 +547,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Flatten => { fn get_base_type(data_type: &DataType) -> Result { match data_type { - DataType::List(field) => match field.data_type() { - DataType::List(_) => get_base_type(field.data_type()), - _ => Ok(data_type.to_owned()), - }, - DataType::LargeList(field) => match field.data_type() { - DataType::LargeList(_) => get_base_type(field.data_type()), - _ => Ok(data_type.to_owned()), - }, - DataType::Null => Ok(data_type.to_owned()), + DataType::List(field) if matches!(field.data_type(), DataType::List(_)) => get_base_type(field.data_type()), + DataType::LargeList(field) if matches!(field.data_type(), DataType::LargeList(_)) => get_base_type(field.data_type()), + DataType::Null | DataType::List(_) | DataType::LargeList(_) => Ok(data_type.to_owned()), _ => internal_err!("Not reachable, data_type should be List or LargeList"), } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 69a1ae8cf47d..c9303489e916 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5362,6 +5362,7 @@ select flatten(NULL); ---- NULL +# flatten with scalar values #1 query ??? select flatten(make_array(1, 2, 1, 3, 2)), flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))), @@ -5376,6 +5377,7 @@ select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')), ---- [1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4] +# flatten with with column values query ???? select column1, column2, column3, column4 from flatten_table; ---- @@ -5388,6 +5390,7 @@ select column1, column2, column3, column4 from large_flatten_table; [[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]] [[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] +# flatten with column values query ???? select flatten(column1), flatten(column2), @@ -5408,6 +5411,7 @@ from large_flatten_table; [1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] [1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] +## empty # empty scalar function #1 query B select empty(make_array(1)); From dfce271a444fe042ba17cf54baf4d2d6671999db Mon Sep 17 00:00:00 2001 From: Weijun-H Date: Sun, 4 Feb 2024 10:37:14 +0800 Subject: [PATCH 3/3] remove redundant tests --- datafusion/sqllogictest/test_files/array.slt | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index c9303489e916..36a656eb7f9e 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5377,19 +5377,6 @@ select flatten(arrow_cast(make_array(1, 2, 1, 3, 2), 'LargeList(Int64)')), ---- [1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4] -# flatten with with column values -query ???? -select column1, column2, column3, column4 from flatten_table; ----- -[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]] -[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] - -query ???? -select column1, column2, column3, column4 from large_flatten_table; ----- -[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]] -[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] - # flatten with column values query ???? select flatten(column1),