Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace array_has/array_has_all/array_has_any macro to remove duplicate code #8263

Merged
merged 3 commits into from
Nov 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 35 additions & 114 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1748,70 +1748,27 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}

// Expands to a block that builds a boolean array answering, pair by pair,
// "does this row of `$ARRAY` contain the matching entry of `$SUB_ARRAY`?".
//
// * `$ARRAY`      - iterated with `.iter()`; each `Some` item is downcast to `$ARRAY_TYPE`.
// * `$SUB_ARRAY`  - downcast once to `$ARRAY_TYPE`; supplies one candidate element per pair.
// * `$ARRAY_TYPE` - concrete array type handed to `downcast_arg!`.
//
// The block evaluates to `Ok(Arc::new(..))` wrapping the finished builder.
//
// NOTE(review): a pair in which either side is `None` appends nothing to the
// builder, so the result can hold fewer entries than `$ARRAY.len()` - confirm
// callers expect that.
macro_rules! non_list_contains {
($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{
let sub_array = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE);
// `$ARRAY.len()` is passed to the builder - presumably as a capacity hint; TODO confirm.
let mut boolean_builder = BooleanArray::builder($ARRAY.len());

// `zip` stops at the shorter of the two iterators.
for (arr, elem) in $ARRAY.iter().zip(sub_array.iter()) {
if let (Some(arr), Some(elem)) = (arr, elem) {
let arr = downcast_arg!(arr, $ARRAY_TYPE);
// `dedup()` is not a std iterator method - presumably itertools, collapsing
// consecutive equal items; `flatten()` then drops the `None` entries before
// the membership test.
let res = arr.iter().dedup().flatten().any(|x| x == elem);
boolean_builder.append_value(res);
}
}
Ok(Arc::new(boolean_builder.finish()))
}};
}

/// Array_has SQL function
///
/// Takes the list column from `args[0]` (via `as_list_array`) and the
/// candidate elements from `args[1]`, and returns a boolean array with one
/// value appended for every pair in which both the list row and the element
/// are non-null.
///
/// Both arguments are read by direct indexing, so `args` must hold at least
/// two entries.
///
/// # Errors
///
/// Propagates any error from `as_list_array` or `check_datatypes`.
pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {
let array = as_list_array(&args[0])?;
let element = &args[1];

// Validates the list's values array against the element array; the exact rule
// lives in `check_datatypes`, which is not visible here.
check_datatypes("array_has", &[array.values(), element])?;
match element.data_type() {
// The element is itself a list: each row of `array` is treated as a list of
// lists and searched for the candidate.
DataType::List(_) => {
let sub_array = as_list_array(element)?;
let mut boolean_builder = BooleanArray::builder(array.len());

for (arr, elem) in array.iter().zip(sub_array.iter()) {
// NOTE(review): pairs with a null on either side append nothing, so the
// output can be shorter than `array.len()` - confirm this is intended.
if let (Some(arr), Some(elem)) = (arr, elem) {
let list_arr = as_list_array(&arr)?;
// Compares the dereferenced values rather than the references themselves.
let res = list_arr.iter().dedup().flatten().any(|x| *x == *elem);
boolean_builder.append_value(res);
}
}
Ok(Arc::new(boolean_builder.finish()))
}
// Any other element type: define the local `array_function!` macro, which
// presumably `call_array_function!` invokes with the concrete array type
// chosen from `data_type` - TODO confirm against that macro's definition.
data_type => {
macro_rules! array_function {
($ARRAY_TYPE:ident) => {
non_list_contains!(array, element, $ARRAY_TYPE)
};
}
call_array_function!(data_type, false)
}
}
}

macro_rules! array_has_any_non_list_check {
($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{
let arr = downcast_arg!($ARRAY, $ARRAY_TYPE);
let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE);
let mut boolean_builder = BooleanArray::builder(array.len());

let mut res = false;
for elem in sub_arr.iter().dedup() {
if let Some(elem) = elem {
res |= arr.iter().dedup().flatten().any(|x| x == elem);
} else {
return internal_err!(
"array_has_any does not support Null type for element in sub_array"
);
}
let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
let r_values = converter.convert_columns(&[element.clone()])?;
for (row_idx, arr) in array.iter().enumerate() {
if let Some(arr) = arr {
let arr_values = converter.convert_columns(&[arr])?;
let res = arr_values
.iter()
.dedup()
.any(|x| x == r_values.row(row_idx));
boolean_builder.append_value(res);
}
res
}};
}
Ok(Arc::new(boolean_builder.finish()))
}

/// Array_has_any SQL function
Expand All @@ -1820,55 +1777,27 @@ pub fn array_has_any(args: &[ArrayRef]) -> Result<ArrayRef> {

let array = as_list_array(&args[0])?;
let sub_array = as_list_array(&args[1])?;

let mut boolean_builder = BooleanArray::builder(array.len());

let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
for (arr, sub_arr) in array.iter().zip(sub_array.iter()) {
if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
let res = match arr.data_type() {
DataType::List(_) => {
let arr = downcast_arg!(arr, ListArray);
let sub_arr = downcast_arg!(sub_arr, ListArray);

let mut res = false;
for elem in sub_arr.iter().dedup().flatten() {
res |= arr.iter().dedup().flatten().any(|x| *x == *elem);
}
res
let arr_values = converter.convert_columns(&[arr])?;
let sub_arr_values = converter.convert_columns(&[sub_arr])?;

let mut res = false;
for elem in sub_arr_values.iter().dedup() {
res |= arr_values.iter().dedup().any(|x| x == elem);
if res {
break;
}
data_type => {
macro_rules! array_function {
($ARRAY_TYPE:ident) => {
array_has_any_non_list_check!(arr, sub_arr, $ARRAY_TYPE)
};
}
call_array_function!(data_type, false)
}
};
}
boolean_builder.append_value(res);
}
}
Ok(Arc::new(boolean_builder.finish()))
}

// Expands to a block evaluating to a `bool`: `true` when every non-null entry
// of `$SUB_ARRAY` is found among the non-null entries of `$ARRAY`.
//
// * `$ARRAY`      - haystack, downcast to `$ARRAY_TYPE`.
// * `$SUB_ARRAY`  - needles, downcast to `$ARRAY_TYPE`.
// * `$ARRAY_TYPE` - concrete array type handed to `downcast_arg!`.
//
// A null entry in `$SUB_ARRAY` makes the expansion `return` an
// `internal_err!` from the function the macro is expanded in, so it can only
// be used inside a function whose return type accepts that error.
macro_rules! array_has_all_non_list_check {
($ARRAY:expr, $SUB_ARRAY:expr, $ARRAY_TYPE:ident) => {{
let arr = downcast_arg!($ARRAY, $ARRAY_TYPE);
let sub_arr = downcast_arg!($SUB_ARRAY, $ARRAY_TYPE);

// Starts as `true`, so an empty `$SUB_ARRAY` yields `true`.
let mut res = true;
// `dedup()` is not a std iterator method - presumably itertools, collapsing
// consecutive equal items; TODO confirm.
for elem in sub_arr.iter().dedup() {
if let Some(elem) = elem {
// No early exit: once `res` is false the remaining needles are still scanned.
res &= arr.iter().dedup().flatten().any(|x| x == elem);
} else {
return internal_err!(
"array_has_all does not support Null type for element in sub_array"
);
}
}
res
}};
}

/// Array_has_all SQL function
pub fn array_has_all(args: &[ArrayRef]) -> Result<ArrayRef> {
check_datatypes("array_has_all", &[&args[0], &args[1]])?;
Expand All @@ -1877,28 +1806,20 @@ pub fn array_has_all(args: &[ArrayRef]) -> Result<ArrayRef> {
let sub_array = as_list_array(&args[1])?;

let mut boolean_builder = BooleanArray::builder(array.len());

let converter = RowConverter::new(vec![SortField::new(array.value_type())])?;
for (arr, sub_arr) in array.iter().zip(sub_array.iter()) {
if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) {
let res = match arr.data_type() {
DataType::List(_) => {
let arr = downcast_arg!(arr, ListArray);
let sub_arr = downcast_arg!(sub_arr, ListArray);

let mut res = true;
for elem in sub_arr.iter().dedup().flatten() {
res &= arr.iter().dedup().flatten().any(|x| *x == *elem);
}
res
}
data_type => {
macro_rules! array_function {
($ARRAY_TYPE:ident) => {
array_has_all_non_list_check!(arr, sub_arr, $ARRAY_TYPE)
};
}
call_array_function!(data_type, false)
let arr_values = converter.convert_columns(&[arr])?;
let sub_arr_values = converter.convert_columns(&[sub_arr])?;

let mut res = true;
for elem in sub_arr_values.iter().dedup() {
res &= arr_values.iter().dedup().any(|x| x == elem);
if !res {
break;
}
};
}
boolean_builder.append_value(res);
}
}
Expand Down