diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index e2d22a0d3328..c0f6c67263a7 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1748,70 +1748,27 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -macro_rules! non_list_contains { - ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let sub_array = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); - let mut boolean_builder = BooleanArray::builder($ARRAY.len()); - - for (arr, elem) in $ARRAY.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(elem)) = (arr, elem) { - let arr = downcast_arg!(arr, $ARRAY_TYPE); - let res = arr.iter().dedup().flatten().any(|x| x == elem); - boolean_builder.append_value(res); - } - } - Ok(Arc::new(boolean_builder.finish())) - }}; -} - /// Array_has SQL function pub fn array_has(args: &[ArrayRef]) -> Result { let array = as_list_array(&args[0])?; let element = &args[1]; check_datatypes("array_has", &[array.values(), element])?; - match element.data_type() { - DataType::List(_) => { - let sub_array = as_list_array(element)?; - let mut boolean_builder = BooleanArray::builder(array.len()); - - for (arr, elem) in array.iter().zip(sub_array.iter()) { - if let (Some(arr), Some(elem)) = (arr, elem) { - let list_arr = as_list_array(&arr)?; - let res = list_arr.iter().dedup().flatten().any(|x| *x == *elem); - boolean_builder.append_value(res); - } - } - Ok(Arc::new(boolean_builder.finish())) - } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - non_list_contains!(array, element, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) - } - } -} - -macro_rules! array_has_any_non_list_check { - ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); + let mut boolean_builder = BooleanArray::builder(array.len()); - let mut res = false; - for elem in sub_arr.iter().dedup() { - if let Some(elem) = elem { - res |= arr.iter().dedup().flatten().any(|x| x == elem); - } else { - return internal_err!( - "array_has_any does not support Null type for element in sub_array" - ); - } + let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; + let r_values = converter.convert_columns(&[element.clone()])?; + for (row_idx, arr) in array.iter().enumerate() { + if let Some(arr) = arr { + let arr_values = converter.convert_columns(&[arr])?; + let res = arr_values + .iter() + .dedup() + .any(|x| x == r_values.row(row_idx)); + boolean_builder.append_value(res); } - res - }}; + } + Ok(Arc::new(boolean_builder.finish())) } /// Array_has_any SQL function @@ -1820,55 +1777,27 @@ pub fn array_has_any(args: &[ArrayRef]) -> Result { let array = as_list_array(&args[0])?; let sub_array = as_list_array(&args[1])?; - let mut boolean_builder = BooleanArray::builder(array.len()); + + let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let res = match arr.data_type() { - DataType::List(_) => { - let arr = downcast_arg!(arr, ListArray); - let sub_arr = downcast_arg!(sub_arr, ListArray); - - let mut res = false; - for elem in sub_arr.iter().dedup().flatten() { - res |= arr.iter().dedup().flatten().any(|x| *x == *elem); - } - res + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = converter.convert_columns(&[sub_arr])?; + + let mut res = false; + for elem in sub_arr_values.iter().dedup() { + res |= arr_values.iter().dedup().any(|x| x == elem); + if res { + break; } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - array_has_any_non_list_check!(arr, sub_arr, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) - } - }; + } boolean_builder.append_value(res); } } Ok(Arc::new(boolean_builder.finish())) } -macro_rules! array_has_all_non_list_check { - ($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE); - - let mut res = true; - for elem in sub_arr.iter().dedup() { - if let Some(elem) = elem { - res &= arr.iter().dedup().flatten().any(|x| x == elem); - } else { - return internal_err!( - "array_has_all does not support Null type for element in sub_array" - ); - } - } - res - }}; -} - /// Array_has_all SQL function pub fn array_has_all(args: &[ArrayRef]) -> Result { check_datatypes("array_has_all", &[&args[0], &args[1]])?; @@ -1877,28 +1806,20 @@ pub fn array_has_all(args: &[ArrayRef]) -> Result { let sub_array = as_list_array(&args[1])?; let mut boolean_builder = BooleanArray::builder(array.len()); + + let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; for (arr, sub_arr) in array.iter().zip(sub_array.iter()) { if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let res = match arr.data_type() { - DataType::List(_) => { - let arr = downcast_arg!(arr, ListArray); - let sub_arr = downcast_arg!(sub_arr, ListArray); - - let mut res = true; - for elem in sub_arr.iter().dedup().flatten() { - res &= arr.iter().dedup().flatten().any(|x| *x == *elem); - } - res - } - data_type => { - macro_rules! array_function { - ($ARRAY_TYPE:ident) => { - array_has_all_non_list_check!(arr, sub_arr, $ARRAY_TYPE) - }; - } - call_array_function!(data_type, false) + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = converter.convert_columns(&[sub_arr])?; + + let mut res = true; + for elem in sub_arr_values.iter().dedup() { + res &= arr_values.iter().dedup().any(|x| x == elem); + if !res { + break; } - }; + } boolean_builder.append_value(res); } }