From a9d50749d59f11bb4e767d704e9f35d1087d644d Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Tue, 21 Nov 2023 10:49:45 +0100 Subject: [PATCH 1/8] support largelist for arrow_cast --- datafusion/common/src/scalar.rs | 80 +++++++++++++++---- datafusion/sql/src/expr/arrow_cast.rs | 14 ++++ .../sqllogictest/test_files/arrow_typeof.slt | 18 +++++ 3 files changed, 96 insertions(+), 16 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 21cd50dea8c7..ba5d758f0c3c 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -104,6 +104,8 @@ pub enum ScalarValue { /// /// The array must be a ListArray with length 1. List(ArrayRef), + /// The array must be a ListArray with length 1. + LargeList(ArrayRef), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 @@ -205,6 +207,8 @@ impl PartialEq for ScalarValue { (FixedSizeList(_), _) => false, (List(v1), List(v2)) => v1.eq(v2), (List(_), _) => false, + (LargeList(v1), LargeList(v2)) => v1.eq(v2), + (LargeList(_), _) => false, (Date32(v1), Date32(v2)) => v1.eq(v2), (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), @@ -343,7 +347,38 @@ impl PartialOrd for ScalarValue { None } } + (LargeList(arr1), LargeList(arr2)) => { + if arr1.data_type() == arr2.data_type() { + let list_arr1 = as_large_list_array(arr1); + let list_arr2 = as_large_list_array(arr2); + if list_arr1.len() != list_arr2.len() { + return None; + } + for i in 0..list_arr1.len() { + let arr1 = list_arr1.value(i); + let arr2 = list_arr2.value(i); + + let lt_res = + arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?; + let eq_res = + arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return 
Some(Ordering::Greater); + } + } + } + Some(Ordering::Equal) + } else { + None + } + } (List(_), _) => None, + (LargeList(_), _) => None, (FixedSizeList(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, @@ -461,7 +496,7 @@ impl std::hash::Hash for ScalarValue { Binary(v) => v.hash(state), FixedSizeBinary(_, v) => v.hash(state), LargeBinary(v) => v.hash(state), - List(arr) | FixedSizeList(arr) => { + List(arr) | LargeList(arr) | FixedSizeList(arr) => { let arrays = vec![arr.to_owned()]; let hashes_buffer = &mut vec![0; arr.len()]; let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); @@ -872,9 +907,9 @@ impl ScalarValue { ScalarValue::Binary(_) => DataType::Binary, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { - arr.data_type().to_owned() - } + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -1065,9 +1100,9 @@ impl ScalarValue { ScalarValue::LargeBinary(v) => v.is_none(), // arr.len() should be 1 for a list scalar, but we don't seem to // enforce that anywhere, so we still check against array length. 
- ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { - arr.len() == arr.null_count() - } + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -1889,7 +1924,9 @@ impl ScalarValue { .collect::(), ), }, - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { let arrays = std::iter::repeat(arr.as_ref()) .take(size) .collect::>(); @@ -2162,6 +2199,14 @@ impl ScalarValue { ScalarValue::List(arr) } + DataType::LargeList(_) => { + let list_array = as_large_list_array(array); + let nested_array = list_array.value(index); + // Produces a single element `ListArray` with the value at `index`. + let arr = Arc::new(array_into_list_array(nested_array)); + + ScalarValue::LargeList(arr) + } // TODO: There is no test for FixedSizeList now, add it later DataType::FixedSizeList(_, _) => { let list_array = as_fixed_size_list_array(array)?; @@ -2436,7 +2481,9 @@ impl ScalarValue { ScalarValue::LargeBinary(val) => { eq_array_primitive!(array, index, LargeBinaryArray, val)? 
} - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { let right = array.slice(index, 1); arr == &right } @@ -2562,9 +2609,9 @@ impl ScalarValue { | ScalarValue::LargeBinary(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { - arr.get_array_memory_size() - } + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(vals, fields) => { vals.as_ref() .map(|vals| { @@ -2932,7 +2979,9 @@ impl fmt::Display for ScalarValue { )?, None => write!(f, "NULL")?, }, - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { // ScalarValue List should always have a single element assert_eq!(arr.len(), 1); let options = FormatOptions::default().with_display_error(true); @@ -3015,9 +3064,8 @@ impl fmt::Debug for ScalarValue { ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"), ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"), ScalarValue::FixedSizeList(_) => write!(f, "FixedSizeList({self})"), - ScalarValue::List(_) => { - write!(f, "List({self})") - } + ScalarValue::List(_) => write!(f, "List({self})"), + ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"), diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 8c0184b6d119..ade8b96b5cc2 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -149,6 +149,7 @@ impl<'a> Parser<'a> { Token::Decimal256 => self.parse_decimal_256(), Token::Dictionary => 
self.parse_dictionary(), Token::List => self.parse_list(), + Token::LargeList => self.parse_large_list(), tok => Err(make_error( self.val, &format!("finding next type, got unexpected '{tok}'"), @@ -166,6 +167,16 @@ impl<'a> Parser<'a> { )))) } + /// Parses the LargeList type + fn parse_large_list(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let data_type = self.parse_next_type()?; + self.expect_token(Token::RParen)?; + Ok(DataType::LargeList(Arc::new(Field::new( + "item", data_type, true, + )))) + } + /// Parses the next timeunit fn parse_time_unit(&mut self, context: &str) -> Result { match self.next_token()? { @@ -496,6 +507,7 @@ impl<'a> Tokenizer<'a> { "Date64" => Token::SimpleType(DataType::Date64), "List" => Token::List, + "LargeList" => Token::LargeList, "Second" => Token::TimeUnit(TimeUnit::Second), "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond), @@ -585,6 +597,7 @@ enum Token { Integer(i64), DoubleQuotedString(String), List, + LargeList, } impl Display for Token { @@ -592,6 +605,7 @@ impl Display for Token { match self { Token::SimpleType(t) => write!(f, "{t}"), Token::List => write!(f, "List"), + Token::LargeList => write!(f, "LargeList"), Token::Timestamp => write!(f, "Timestamp"), Token::Time32 => write!(f, "Time32"), Token::Time64 => write!(f, "Time64"), diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index e485251b7342..77b4de80023c 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -338,3 +338,21 @@ select arrow_cast(timestamp '2000-01-01T00:00:00Z', 'Timestamp(Nanosecond, Some( statement error Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone select arrow_cast(timestamp '2000-01-01T00:00:00', 'Timestamp(Nanosecond, Some( "+25:00" ))'); + + +## List + + +query ? +select arrow_cast('1', 'List(Int64)'); +---- +[1] + + +## LargeList + + +query ? 
+select arrow_cast('1', 'LargeList(Int64)'); +---- +[1] From 774fd71fae4a2739dc501c07cae02dd3bd056a94 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Tue, 21 Nov 2023 11:48:56 +0100 Subject: [PATCH 2/8] fix cli --- datafusion/common/src/scalar.rs | 6 +++--- datafusion/common/src/utils.rs | 14 +++++++++++++- datafusion/proto/src/logical_plan/to_proto.rs | 9 ++++++++- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index ba5d758f0c3c..42798aaba463 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -30,7 +30,7 @@ use crate::cast::{ }; use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; use crate::hash_utils::create_hashes; -use crate::utils::array_into_list_array; +use crate::utils::{array_into_large_list_array, array_into_list_array}; use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute::kernels::numeric::*; use arrow::datatypes::{i256, Fields, SchemaBuilder}; @@ -2202,8 +2202,8 @@ impl ScalarValue { DataType::LargeList(_) => { let list_array = as_large_list_array(array); let nested_array = list_array.value(index); - // Produces a single element `ListArray` with the value at `index`. - let arr = Arc::new(array_into_list_array(nested_array)); + // Produces a single element `LargeListArray` with the value at `index`. 
+ let arr = Arc::new(array_into_large_list_array(nested_array)); ScalarValue::LargeList(arr) } diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index f031f7880436..12d4f516b4d0 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -25,7 +25,7 @@ use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; use arrow::datatypes::{Field, SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; -use arrow_array::{Array, ListArray}; +use arrow_array::{Array, LargeListArray, ListArray}; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -349,6 +349,18 @@ pub fn array_into_list_array(arr: ArrayRef) -> ListArray { ) } +/// Wrap an array into a single element `LargeListArray`. +/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { + let offsets = OffsetBuffer::from_lengths([arr.len()]); + LargeListArray::new( + Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + offsets, + arr, + None, + ) +} + /// Wrap arrays into a single element `ListArray`. 
/// /// Example: diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index cf66e3ddd5b5..e45402730590 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1140,7 +1140,9 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { } // ScalarValue::List and ScalarValue::FixedSizeList are serialized using // Arrow IPC messages as a single column RecordBatch - ScalarValue::List(arr) | ScalarValue::FixedSizeList(arr) => { + ScalarValue::List(arr) + | ScalarValue::LargeList(arr) + | ScalarValue::FixedSizeList(arr) => { // Wrap in a "field_name" column let batch = RecordBatch::try_from_iter(vec![( "field_name", @@ -1174,6 +1176,11 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { scalar_list_value, )), }), + ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue( + scalar_list_value, + )), + }), ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { value: Some(protobuf::scalar_value::Value::FixedSizeListValue( scalar_list_value, From 5a5d096e6b1197f03a24873ba0cbc1c5742c507e Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Tue, 21 Nov 2023 15:18:47 +0100 Subject: [PATCH 3/8] update tests; --- .../sqllogictest/test_files/arrow_typeof.slt | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 77b4de80023c..3fad4d0f61b9 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -348,6 +348,16 @@ select arrow_cast('1', 'List(Int64)'); ---- [1] +query ? 
+select arrow_cast(make_array(1, 2, 3), 'List(Int64)'); +---- +[1, 2, 3] + +query T +select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'List(Int64)')); +---- +List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + ## LargeList @@ -356,3 +366,13 @@ query ? select arrow_cast('1', 'LargeList(Int64)'); ---- [1] + +query ? +select arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'); +---- +[1, 2, 3] + +query T +select arrow_typeof(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')); +---- +LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) \ No newline at end of file From 831d01258be793b294a20d5117a912de5a9a9621 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Tue, 21 Nov 2023 15:53:57 +0100 Subject: [PATCH 4/8] add new_large_list in ScalarValue --- datafusion/common/src/scalar.rs | 65 +++++++++++++++++++ .../tests/cases/roundtrip_logical_plan.rs | 30 +++++++++ 2 files changed, 95 insertions(+) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 42798aaba463..87931125c425 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -1797,6 +1797,41 @@ impl ScalarValue { Arc::new(array_into_list_array(values)) } + /// Converts `Vec` where each element has type corresponding to + /// `data_type`, to a [`ListArray`]. 
+ /// + /// Example + /// ``` + /// use datafusion_common::ScalarValue; + /// use arrow::array::{ListArray, Int32Array}; + /// use arrow::datatypes::{DataType, Int32Type}; + /// use datafusion_common::cast::as_large_list_array; + /// + /// let scalars = vec![ + /// ScalarValue::Int32(Some(1)), + /// ScalarValue::Int32(None), + /// ScalarValue::Int32(Some(2)) + /// ]; + /// + /// let array = ScalarValue::new_large_list(&scalars, &DataType::Int32); + /// let result = as_large_list_array(&array).unwrap(); + /// + /// let expected = ListArray::from_iter_primitive::( + /// vec![ + /// Some(vec![Some(1), None, Some(2)]) + /// ]); + /// + /// assert_eq!(result, &expected); + /// ``` + pub fn new_large_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef { + let values = if values.is_empty() { + new_empty_array(data_type) + } else { + Self::iter_to_array(values.iter().cloned()).unwrap() + }; + Arc::new(array_into_large_list_array(values)) + } + /// Converts a scalar value into an array of `size` rows. 
/// /// # Errors @@ -3692,6 +3727,15 @@ mod tests { assert_eq!(list_array.values().len(), 0); } + #[test] + fn scalar_large_list_null_to_array() { + let list_array_ref = ScalarValue::new_large_list(&[], &DataType::UInt64); + let list_array = as_large_list_array(&list_array_ref); + + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 0); + } + #[test] fn scalar_list_to_array() -> Result<()> { let values = vec![ @@ -3713,6 +3757,27 @@ mod tests { Ok(()) } + #[test] + fn scalar_large_list_to_array() -> Result<()> { + let values = vec![ + ScalarValue::UInt64(Some(100)), + ScalarValue::UInt64(None), + ScalarValue::UInt64(Some(101)), + ]; + let list_array_ref = ScalarValue::new_large_list(&values, &DataType::UInt64); + let list_array = as_large_list_array(&list_array_ref); + assert_eq!(list_array.len(), 1); + assert_eq!(list_array.values().len(), 3); + + let prim_array_ref = list_array.value(0); + let prim_array = as_uint64_array(&prim_array_ref)?; + assert_eq!(prim_array.len(), 3); + assert_eq!(prim_array.value(0), 100); + assert!(prim_array.is_null(1)); + assert_eq!(prim_array.value(2), 101); + Ok(()) + } + /// Creates array directly and via ScalarValue and ensures they are the same macro_rules! 
check_scalar_iter { ($SCALAR_T:ident, $ARRAYTYPE:ident, $INPUT:expr) => {{ diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 2d56967ecffa..c532b58d902b 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -574,6 +574,7 @@ fn round_trip_scalar_values() { ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), ScalarValue::List(ScalarValue::new_list(&[], &DataType::Boolean)), + ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)), ScalarValue::Date32(None), ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(false)), @@ -674,6 +675,16 @@ fn round_trip_scalar_values() { ], &DataType::Float32, )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), ScalarValue::List(ScalarValue::new_list( &[ ScalarValue::List(ScalarValue::new_list(&[], &DataType::Float32)), @@ -690,6 +701,25 @@ fn round_trip_scalar_values() { ], &DataType::List(new_arc_field("item", DataType::Float32, true)), )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + &DataType::Float32, + )), + ScalarValue::LargeList(ScalarValue::new_large_list( + &[ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ], + &DataType::Float32, + )), + ], + &DataType::List(new_arc_field("item", DataType::Float32, true)), + )), ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::from_iter_primitive::< Int32Type, _, From 565318721bd6c61a0819597b6408db09413357c5 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Tue, 21 Nov 2023 16:02:58 
+0100 Subject: [PATCH 5/8] fix ci --- datafusion/common/src/scalar.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 87931125c425..84ae7145f2a8 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -1798,12 +1798,12 @@ impl ScalarValue { } /// Converts `Vec` where each element has type corresponding to - /// `data_type`, to a [`ListArray`]. + /// `data_type`, to a [`LargeListArray`]. /// /// Example /// ``` /// use datafusion_common::ScalarValue; - /// use arrow::array::{ListArray, Int32Array}; + /// use arrow::array::{LargeListArray, Int32Array}; /// use arrow::datatypes::{DataType, Int32Type}; /// use datafusion_common::cast::as_large_list_array; /// @@ -1816,7 +1816,7 @@ impl ScalarValue { /// let array = ScalarValue::new_large_list(&scalars, &DataType::Int32); /// let result = as_large_list_array(&array).unwrap(); /// - /// let expected = ListArray::from_iter_primitive::( + /// let expected = LargeListArray::from_iter_primitive::( /// vec![ /// Some(vec![Some(1), None, Some(2)]) /// ]); From 0f1f9caf57fb8b6e4edc19b1c4b56751dda6a8e0 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Tue, 21 Nov 2023 18:00:34 +0100 Subject: [PATCH 6/8] support LargeList in scalar --- datafusion/common/src/scalar.rs | 314 ++++++++++++++++-- .../tests/cases/roundtrip_logical_plan.rs | 2 +- 2 files changed, 292 insertions(+), 24 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 84ae7145f2a8..0f98371ef281 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -1314,10 +1314,10 @@ impl ScalarValue { } macro_rules! 
build_array_list_primitive { - ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident) => {{ - Ok::(Arc::new(ListArray::from_iter_primitive::<$ARRAY_TY, _, _>( - scalars.into_iter().map(|x| match x { - ScalarValue::List(arr) => { + ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident, $LIST_TY:ident, $SCALAR_LIST:pat) => {{ + Ok::(Arc::new($LIST_TY::from_iter_primitive::<$ARRAY_TY, _, _>( + scalars.into_iter().map(|x| match x{ + ScalarValue::List(arr) if matches!(x, $SCALAR_LIST) => { // `ScalarValue::List` contains a single element `ListArray`. let list_arr = as_list_array(&arr); if list_arr.is_null(0) { @@ -1330,6 +1330,19 @@ impl ScalarValue { )) } } + ScalarValue::LargeList(arr) if matches!(x, $SCALAR_LIST) =>{ + // `ScalarValue::LargeList` contains a single element `LargeListArray`. + let list_arr = as_large_list_array(&arr); + if list_arr.is_null(0) { + Ok(None) + } else { + let primitive_arr = + list_arr.values().as_primitive::<$ARRAY_TY>(); + Ok(Some( + primitive_arr.into_iter().collect::>>(), + )) + } + } sv => _internal_err!( "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", @@ -1342,11 +1355,11 @@ impl ScalarValue { } macro_rules! build_array_list_string { - ($BUILDER:ident, $STRING_ARRAY:ident) => {{ - let mut builder = ListBuilder::new($BUILDER::new()); + ($BUILDER:ident, $STRING_ARRAY:ident,$LIST_BUILDER:ident,$SCALAR_LIST:pat) => {{ + let mut builder = $LIST_BUILDER::new($BUILDER::new()); for scalar in scalars.into_iter() { match scalar { - ScalarValue::List(arr) => { + ScalarValue::List(arr) if matches!(scalar, $SCALAR_LIST) => { // `ScalarValue::List` contains a single element `ListArray`. let list_arr = as_list_array(&arr); @@ -1366,6 +1379,26 @@ impl ScalarValue { } builder.append(true); } + ScalarValue::LargeList(arr) if matches!(scalar, $SCALAR_LIST) => { + // `ScalarValue::LargeList` contains a single element `LargeListArray`.
+ let list_arr = as_large_list_array(&arr); + + if list_arr.is_null(0) { + builder.append(false); + continue; + } + + let string_arr = $STRING_ARRAY(list_arr.values()); + + for v in string_arr.iter() { + if let Some(v) = v { + builder.values().append_value(v); + } else { + builder.values().append_null(); + } + } + builder.append(true); + } sv => { return _internal_err!( "Inconsistent types in ScalarValue::iter_to_array. \ @@ -1454,46 +1487,227 @@ impl ScalarValue { build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) } DataType::List(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!(Int8Type, Int8, i8)? + build_array_list_primitive!( + Int8Type, + Int8, + i8, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!(Int16Type, Int16, i16)? + build_array_list_primitive!( + Int16Type, + Int16, + i16, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!(Int32Type, Int32, i32)? + build_array_list_primitive!( + Int32Type, + Int32, + i32, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!(Int64Type, Int64, i64)? + build_array_list_primitive!( + Int64Type, + Int64, + i64, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!(UInt8Type, UInt8, u8)? + build_array_list_primitive!( + UInt8Type, + UInt8, + u8, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!(UInt16Type, UInt16, u16)? + build_array_list_primitive!( + UInt16Type, + UInt16, + u16, + ListArray, + ScalarValue::List(_) + )? 
} DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!(UInt32Type, UInt32, u32)? + build_array_list_primitive!( + UInt32Type, + UInt32, + u32, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!(UInt64Type, UInt64, u64)? + build_array_list_primitive!( + UInt64Type, + UInt64, + u64, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!(Float32Type, Float32, f32)? + build_array_list_primitive!( + Float32Type, + Float32, + f32, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!(Float64Type, Float64, f64)? + build_array_list_primitive!( + Float64Type, + Float64, + f64, + ListArray, + ScalarValue::List(_) + )? } DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!(StringBuilder, as_string_array) + build_array_list_string!( + StringBuilder, + as_string_array, + ListBuilder, + ScalarValue::List(_) + ) } DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!(LargeStringBuilder, as_largestring_array) + build_array_list_string!( + LargeStringBuilder, + as_largestring_array, + ListBuilder, + ScalarValue::List(_) + ) } DataType::List(_) => { // Fallback case handling homogeneous lists with any ScalarValue element type let list_array = ScalarValue::iter_to_array_list(scalars)?; Arc::new(list_array) } + DataType::LargeList(fields) if fields.data_type() == &DataType::Int8 => { + build_array_list_primitive!( + Int8Type, + Int8, + i8, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::Int16 => { + build_array_list_primitive!( + Int16Type, + Int16, + i16, + LargeListArray, + ScalarValue::LargeList(_) + )? 
+ } + DataType::LargeList(fields) if fields.data_type() == &DataType::Int32 => { + build_array_list_primitive!( + Int32Type, + Int32, + i32, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::Int64 => { + build_array_list_primitive!( + Int64Type, + Int64, + i64, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::UInt8 => { + build_array_list_primitive!( + UInt8Type, + UInt8, + u8, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::UInt16 => { + build_array_list_primitive!( + UInt16Type, + UInt16, + u16, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::UInt32 => { + build_array_list_primitive!( + UInt32Type, + UInt32, + u32, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::UInt64 => { + build_array_list_primitive!( + UInt64Type, + UInt64, + u64, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::Float32 => { + build_array_list_primitive!( + Float32Type, + Float32, + f32, + LargeListArray, + ScalarValue::LargeList(_) + )? + } + DataType::LargeList(fields) if fields.data_type() == &DataType::Float64 => { + build_array_list_primitive!( + Float64Type, + Float64, + f64, + LargeListArray, + ScalarValue::LargeList(_) + )? 
+ } + DataType::LargeList(fields) if fields.data_type() == &DataType::Utf8 => { + build_array_list_string!( + StringBuilder, + as_string_array, + LargeListBuilder, + ScalarValue::LargeList(_) + ) + } + DataType::LargeList(fields) if fields.data_type() == &DataType::LargeUtf8 => { + build_array_list_string!( + LargeStringBuilder, + as_largestring_array, + LargeListBuilder, + ScalarValue::LargeList(_) + ) + } + DataType::LargeList(_) => { + // Fallback case handling homogeneous lists with any ScalarValue element type + let list_array = ScalarValue::iter_to_large_array_list(scalars)?; + Arc::new(list_array) + } DataType::Struct(fields) => { // Initialize a Vector to store the ScalarValues for each column let mut columns: Vec> = @@ -1606,7 +1820,6 @@ impl ScalarValue { | DataType::Time64(TimeUnit::Millisecond) | DataType::Duration(_) | DataType::FixedSizeList(_, _) - | DataType::LargeList(_) | DataType::Union(_, _) | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) => { @@ -1674,10 +1887,10 @@ impl ScalarValue { Ok(array) } - /// This function build with nulls with nulls buffer. + /// This function build ListArray with nulls with nulls buffer. fn iter_to_array_list( scalars: impl IntoIterator, - ) -> Result> { + ) -> Result { let mut elements: Vec = vec![]; let mut valid = BooleanBufferBuilder::new(0); let mut offsets = vec![]; @@ -1721,7 +1934,62 @@ impl ScalarValue { let list_array = ListArray::new( Arc::new(Field::new("item", flat_array.data_type().clone(), true)), - OffsetBuffer::::from_lengths(offsets), + OffsetBuffer::from_lengths(offsets), + flat_array, + Some(NullBuffer::new(buffer)), + ); + + Ok(list_array) + } + + /// This function build LargeListArray with nulls with nulls buffer. 
+ fn iter_to_large_array_list( + scalars: impl IntoIterator, + ) -> Result { + let mut elements: Vec = vec![]; + let mut valid = BooleanBufferBuilder::new(0); + let mut offsets = vec![]; + + for scalar in scalars { + if let ScalarValue::LargeList(arr) = scalar { + // `ScalarValue::LargeList` contains a single element `LargeListArray`. + let list_arr = as_large_list_array(&arr); + + if list_arr.is_null(0) { + // A null entry contributes a zero-length slot (`offsets` holds lengths) + offsets.push(0); + + // Element is null + valid.append(false); + } else { + let arr = list_arr.values().to_owned(); + offsets.push(arr.len()); + elements.push(arr); + + // Element is valid + valid.append(true); + } + } else { + return _internal_err!( + "Expected ScalarValue::LargeList element. Received {scalar:?}" + ); + } + } + + // Concatenate element arrays to create single flat array + let element_arrays: Vec<&dyn Array> = + elements.iter().map(|a| a.as_ref()).collect(); + + let flat_array = match arrow::compute::concat(&element_arrays) { + Ok(flat_array) => flat_array, + Err(err) => return Err(DataFusionError::ArrowError(err)), + }; + + let buffer = valid.finish(); + + let list_array = LargeListArray::new( + Arc::new(Field::new("item", flat_array.data_type().clone(), true)), + OffsetBuffer::from_lengths(offsets), + flat_array, + Some(NullBuffer::new(buffer)), + ); diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index c532b58d902b..acc7f07bfa9f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -718,7 +718,7 @@ fn round_trip_scalar_values() { &DataType::Float32, )), ], - &DataType::List(new_arc_field("item", DataType::Float32, true)), + &DataType::LargeList(new_arc_field("item", DataType::Float32, true)), )), ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::from_iter_primitive::< Int32Type, From 21146808613d8c1087f01c2591685b39537af543 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Tue, 21
Nov 2023 18:01:25 +0100 Subject: [PATCH 7/8] modify comment --- datafusion/common/src/scalar.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 0f98371ef281..ffa8ab50f862 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -104,7 +104,7 @@ pub enum ScalarValue { /// /// The array must be a ListArray with length 1. List(ArrayRef), - /// The array must be a ListArray with length 1. + /// The array must be a LargeListArray with length 1. LargeList(ArrayRef), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), From 991f3af49002ed0a3e98bafd970cf7679db7d3bc Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Tue, 21 Nov 2023 19:16:10 +0100 Subject: [PATCH 8/8] support largelist for proto --- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 14 ++++++++++++++ datafusion/proto/src/generated/prost.rs | 4 +++- datafusion/proto/src/logical_plan/from_proto.rs | 5 ++++- datafusion/proto/src/logical_plan/to_proto.rs | 2 +- 5 files changed, 23 insertions(+), 3 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9197343d749e..d43d19f85842 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -984,6 +984,7 @@ message ScalarValue{ // Literal Date32 value always has a unit of day int32 date_32_value = 14; ScalarTime32Value time32_value = 15; + ScalarListValue large_list_value = 16; ScalarListValue list_value = 17; ScalarListValue fixed_size_list_value = 18; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 8a6360023794..133bbbee8920 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -21965,6 +21965,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::Time32Value(v) 
=> { struct_ser.serialize_field("time32Value", v)?; } + scalar_value::Value::LargeListValue(v) => { + struct_ser.serialize_field("largeListValue", v)?; + } scalar_value::Value::ListValue(v) => { struct_ser.serialize_field("listValue", v)?; } @@ -22074,6 +22077,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "date32Value", "time32_value", "time32Value", + "large_list_value", + "largeListValue", "list_value", "listValue", "fixed_size_list_value", @@ -22132,6 +22137,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Float64Value, Date32Value, Time32Value, + LargeListValue, ListValue, FixedSizeListValue, Decimal128Value, @@ -22188,6 +22194,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "float64Value" | "float64_value" => Ok(GeneratedField::Float64Value), "date32Value" | "date_32_value" => Ok(GeneratedField::Date32Value), "time32Value" | "time32_value" => Ok(GeneratedField::Time32Value), + "largeListValue" | "large_list_value" => Ok(GeneratedField::LargeListValue), "listValue" | "list_value" => Ok(GeneratedField::ListValue), "fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue), "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), @@ -22325,6 +22332,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("time32Value")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Time32Value) +; + } + GeneratedField::LargeListValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("largeListValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::LargeListValue) ; } GeneratedField::ListValue => { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 4fb8e1599e4b..503c4b6c73f1 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1200,7 +1200,7 @@ pub 
struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 32, 34" )] pub value: ::core::option::Option, } @@ -1244,6 +1244,8 @@ pub mod scalar_value { Date32Value(i32), #[prost(message, tag = "15")] Time32Value(super::ScalarTime32Value), + #[prost(message, tag = "16")] + LargeListValue(super::ScalarListValue), #[prost(message, tag = "17")] ListValue(super::ScalarListValue), #[prost(message, tag = "18")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 4ae45fa52162..8069e017f797 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -660,7 +660,9 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Float64Value(v) => Self::Float64(Some(*v)), Value::Date32Value(v) => Self::Date32(Some(*v)), // ScalarValue::List is serialized using arrow IPC format - Value::ListValue(scalar_list) | Value::FixedSizeListValue(scalar_list) => { + Value::ListValue(scalar_list) + | Value::FixedSizeListValue(scalar_list) + | Value::LargeListValue(scalar_list) => { let protobuf::ScalarListValue { ipc_message, arrow_data, @@ -703,6 +705,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { let arr = record_batch.column(0); match value { Value::ListValue(_) => Self::List(arr.to_owned()), + Value::LargeListValue(_) => Self::LargeList(arr.to_owned()), Value::FixedSizeListValue(_) => Self::FixedSizeList(arr.to_owned()), _ => unreachable!(), } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index e45402730590..750eb03e8347 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs 
+++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1177,7 +1177,7 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { )), }), ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( + value: Some(protobuf::scalar_value::Value::LargeListValue( scalar_list_value, )), }),