diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs index aa246ac95b8b..5d00f300e960 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs @@ -93,21 +93,10 @@ impl GroupColumn fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { // Perf: skip null check (by short circuit) if input is not nullable if NULLABLE { - // In nullable path, we should check if both `exist row` and `input row` - // are null/not null - let is_exist_null = self.nulls.is_null(lhs_row); - let null_match = is_exist_null == array.is_null(rhs_row); - if !null_match { - // If `is_null`s in `exist row` and `input row` don't match, return not equal to - return false; - } else if is_exist_null { - // If `is_null`s in `exist row` and `input row` match, and they are `null`s, - // return equal to - // - // NOTICE: we should not check their values when they are `null`s, because they are - // meaningless actually, and not ensured to be same - // - return true; + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; } // Otherwise, we need to check their values } @@ -224,9 +213,14 @@ where where B: ByteArrayType, { - let arr = array.as_bytes::(); - self.nulls.is_null(lhs_row) == arr.is_null(rhs_row) - && self.value(lhs_row) == (arr.value(rhs_row).as_ref() as &[u8]) + let array = array.as_bytes::(); + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + // Otherwise, we need to check their values + self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8]) } /// return the current value of the specified row irrespective of null @@ -382,6 +376,20 @@ where } } +/// Determines if the nullability of the existing and new input array can be used +/// to short-circuit the comparison of the two values. +/// +/// Returns `Some(result)` if the result of the comparison can be determined +/// from the nullness of the two values, and `None` if the comparison must be +/// done on the values themselves. +fn nulls_equal_to(lhs_null: bool, rhs_null: bool) -> Option { + match (lhs_null, rhs_null) { + (true, true) => Some(true), + (false, true) | (true, false) => Some(false), + _ => None, + } +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -468,13 +476,14 @@ mod tests { builder.append_val(&builder_array, 5); // Define input array - let (_, values, _) = + let (_nulls, values, _) = Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)]) .into_parts(); + // explicitly build a boolean buffer where one of the null values also happens to match let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); boolean_buffer_builder.append(true); - boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); // this sets Some(2) to null above boolean_buffer_builder.append(false); boolean_buffer_builder.append(false); boolean_buffer_builder.append(true); @@ -511,4 +520,63 @@ mod tests { assert!(builder.equal_to(0, &input_array, 0)); assert!(!builder.equal_to(1, &input_array, 1)); } + + #[test] + fn test_byte_array_equal_to() { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define PrimitiveGroupValueBuilder + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + let builder_array = Arc::new(StringArray::from(vec![ + None, + None, + None, + Some("foo"), + Some("bar"), + Some("baz"), + ])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + builder.append_val(&builder_array, 2); + builder.append_val(&builder_array, 3); + builder.append_val(&builder_array, 4); + builder.append_val(&builder_array, 5); + + // Define input array + let (offsets, buffer, _nulls) = StringArray::from(vec![ + Some("foo"), + Some("bar"), + None, + None, + Some("foo"), + Some("baz"), + ]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some("bar") to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = + Arc::new(StringArray::new(offsets, buffer, Some(nulls))) as ArrayRef; + + // Check + assert!(!builder.equal_to(0, &input_array, 0)); + assert!(builder.equal_to(1, &input_array, 1)); + assert!(builder.equal_to(2, &input_array, 2)); + assert!(!builder.equal_to(3, &input_array, 3)); + assert!(!builder.equal_to(4, &input_array, 4)); + assert!(builder.equal_to(5, &input_array, 5)); + } }