Skip to content

Commit

Permalink
Fix equal_to in ByteGroupValueBuilder (#12770)
Browse files Browse the repository at this point in the history
* Fix `equal_to` in `ByteGroupValueBuilder`

* refactor null_equal_to

* Update datafusion/physical-plan/src/aggregates/group_values/group_column.rs
  • Loading branch information
alamb authored Oct 6, 2024
1 parent 6f8c74c commit 18f9201
Showing 1 changed file with 88 additions and 20 deletions.
108 changes: 88 additions & 20 deletions datafusion/physical-plan/src/aggregates/group_values/group_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,10 @@ impl<T: ArrowPrimitiveType, const NULLABLE: bool> GroupColumn
fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool {
// Perf: skip null check (by short circuit) if input is not nullable
if NULLABLE {
// In nullable path, we should check if both `exist row` and `input row`
// are null/not null
let is_exist_null = self.nulls.is_null(lhs_row);
let null_match = is_exist_null == array.is_null(rhs_row);
if !null_match {
// If `is_null`s in `exist row` and `input row` don't match, return not equal to
return false;
} else if is_exist_null {
// If `is_null`s in `exist row` and `input row` match, and they are `null`s,
// return equal to
//
// NOTICE: we should not check their values when they are `null`s, because they are
// meaningless actually, and not ensured to be same
//
return true;
let exist_null = self.nulls.is_null(lhs_row);
let input_null = array.is_null(rhs_row);
if let Some(result) = nulls_equal_to(exist_null, input_null) {
return result;
}
// Otherwise, we need to check their values
}
Expand Down Expand Up @@ -224,9 +213,14 @@ where
where
B: ByteArrayType,
{
let arr = array.as_bytes::<B>();
self.nulls.is_null(lhs_row) == arr.is_null(rhs_row)
&& self.value(lhs_row) == (arr.value(rhs_row).as_ref() as &[u8])
let array = array.as_bytes::<B>();
let exist_null = self.nulls.is_null(lhs_row);
let input_null = array.is_null(rhs_row);
if let Some(result) = nulls_equal_to(exist_null, input_null) {
return result;
}
// Otherwise, we need to check their values
self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8])
}

/// return the current value of the specified row irrespective of null
Expand Down Expand Up @@ -382,6 +376,20 @@ where
}
}

/// Tries to decide equality of two rows purely from their null flags.
///
/// * `lhs_null` - whether the already-stored (existing) value is null
/// * `rhs_null` - whether the incoming (input) value is null
///
/// Returns:
/// * `Some(true)` when both sides are null,
/// * `Some(false)` when exactly one side is null,
/// * `None` when neither side is null, meaning the caller still has to
///   compare the underlying values.
fn nulls_equal_to(lhs_null: bool, rhs_null: bool) -> Option<bool> {
    if !lhs_null && !rhs_null {
        // Both sides carry real values; nullness alone settles nothing.
        return None;
    }
    // At least one side is null here, so the rows match only if both are.
    Some(lhs_null && rhs_null)
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
Expand Down Expand Up @@ -468,13 +476,14 @@ mod tests {
builder.append_val(&builder_array, 5);

// Define input array
let (_, values, _) =
let (_nulls, values, _) =
Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)])
.into_parts();

// explicitly build a boolean buffer where one of the null values also happens to match
let mut boolean_buffer_builder = BooleanBufferBuilder::new(6);
boolean_buffer_builder.append(true);
boolean_buffer_builder.append(false);
boolean_buffer_builder.append(false); // this sets Some(2) to null above
boolean_buffer_builder.append(false);
boolean_buffer_builder.append(false);
boolean_buffer_builder.append(true);
Expand Down Expand Up @@ -511,4 +520,63 @@ mod tests {
assert!(builder.equal_to(0, &input_array, 0));
assert!(!builder.equal_to(1, &input_array, 1));
}

#[test]
fn test_byte_array_equal_to() {
    // Each row pairs one stored value with one input value:
    //   row 0: exist null,     input not null
    //   row 1: exist null,     input null; underlying values differ
    //   row 2: exist null,     input null; underlying values equal
    //   row 3: exist not null, input null
    //   row 4: exist not null, input not null; values differ
    //   row 5: exist not null, input not null; values equal

    // Populate the ByteGroupValueBuilder under test with six rows.
    let builder_array = Arc::new(StringArray::from(vec![
        None,
        None,
        None,
        Some("foo"),
        Some("bar"),
        Some("baz"),
    ])) as ArrayRef;
    let mut builder = ByteGroupValueBuilder::<i32>::new(OutputType::Utf8);
    for row in 0..6 {
        builder.append_val(&builder_array, row);
    }

    // Take the offsets and value bytes of a string array, discarding its own
    // null buffer so a hand-written one can be attached below.
    let (offsets, buffer, _nulls) = StringArray::from(vec![
        Some("foo"),
        Some("bar"),
        None,
        None,
        Some("foo"),
        Some("baz"),
    ])
    .into_parts();

    // Hand-written validity bits: the `false` at index 1 marks Some("bar") as
    // null even though its bytes are still present in `buffer`, so a null row
    // with leftover value bytes is part of the input.
    let mut validity_builder = BooleanBufferBuilder::new(6);
    for is_valid in [true, false, false, false, true, true] {
        validity_builder.append(is_valid);
    }
    let nulls = NullBuffer::new(validity_builder.finish());
    let input_array =
        Arc::new(StringArray::new(offsets, buffer, Some(nulls))) as ArrayRef;

    // Compare row by row, in the order listed at the top.
    assert!(!builder.equal_to(0, &input_array, 0));
    assert!(builder.equal_to(1, &input_array, 1));
    assert!(builder.equal_to(2, &input_array, 2));
    assert!(!builder.equal_to(3, &input_array, 3));
    assert!(!builder.equal_to(4, &input_array, 4));
    assert!(builder.equal_to(5, &input_array, 5));
}
}

0 comments on commit 18f9201

Please sign in to comment.