fix: Correct results for grouping sets when columns contain nulls #12571

Merged (7 commits) on Oct 7, 2024
17 changes: 17 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
@@ -535,9 +535,26 @@ impl DataFrame {
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<DataFrame> {
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
let aggr_expr_len = aggr_expr.len();
let plan = LogicalPlanBuilder::from(self.plan)
.aggregate(group_expr, aggr_expr)?
.build()?;
let plan = if is_grouping_set {
let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len;
// For grouping sets we add a projection so the internal grouping id is not exposed
let exprs = plan
.schema()
.columns()
.into_iter()
.enumerate()
.filter(|(idx, _)| *idx != grouping_id_pos)
.map(|(_, column)| Expr::Column(column))
.collect::<Vec<_>>();
LogicalPlanBuilder::from(plan).project(exprs)?.build()?
} else {
plan
};
Ok(DataFrame {
session_state: self.session_state,
plan,
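The projection above removes the internal column purely by index: the aggregate output is laid out as [group columns..., __grouping_id, aggregate columns...], so the grouping id sits just before the aggregates. A minimal standalone sketch of that index arithmetic (plain Rust with illustrative column names, not the DataFusion API):

```rust
// Sketch only: the internal __grouping_id column sits immediately before the
// aggregate columns, so its position is total_fields - 1 - number_of_aggregates.
fn grouping_id_pos(total_fields: usize, aggr_expr_len: usize) -> usize {
    total_fields - 1 - aggr_expr_len
}

fn main() {
    // e.g. GROUPING SETS ((a), (b)) with one aggregate:
    // output schema = [a, b, __grouping_id, count(c)] -> internal column at index 2
    assert_eq!(grouping_id_pos(4, 1), 2);

    // Keep every column except the internal one, mirroring the filter above.
    let columns = ["a", "b", "__grouping_id", "count(c)"];
    let visible: Vec<&str> = columns
        .iter()
        .enumerate()
        .filter(|(idx, _)| *idx != grouping_id_pos(columns.len(), 1))
        .map(|(_, c)| *c)
        .collect();
    assert_eq!(visible, ["a", "b", "count(c)"]);
}
```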
14 changes: 2 additions & 12 deletions datafusion/core/src/physical_planner.rs
@@ -692,10 +692,6 @@ impl DefaultPhysicalPlanner {
physical_input_schema.clone(),
)?);

// update group column indices based on partial aggregate plan evaluation
let final_group: Vec<Arc<dyn PhysicalExpr>> =
initial_aggr.output_group_expr();

let can_repartition = !groups.is_empty()
&& session_state.config().target_partitions() > 1
&& session_state.config().repartition_aggregations();
@@ -716,13 +712,7 @@
AggregateMode::Final
};

let final_grouping_set = PhysicalGroupBy::new_single(
final_group
.iter()
.enumerate()
.map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone()))
.collect(),
);
let final_grouping_set = initial_aggr.group_expr().as_final();

Arc::new(AggregateExec::try_new(
next_partition_mode,
@@ -2345,7 +2335,7 @@ mod tests {
.expect("hash aggregate");
assert_eq!(
"sum(aggregate_test_100.c3)",
final_hash_agg.schema().field(2).name()
final_hash_agg.schema().field(3).name()
);
// we need access to the input to the partial aggregate so that other projects can
// implement serde
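The test change from field(2) to field(3) reflects the same layout shift: presumably because this test aggregates over a grouping set, the internal grouping id now occupies a slot between the group columns and the aggregates, pushing the aggregate column one position to the right. A small sketch, assuming an illustrative schema of two group columns plus one aggregate:

```rust
fn main() {
    // Before this change: [c1, c2, sum(c3)]                -> aggregate at field index 2
    // After this change:  [c1, c2, __grouping_id, sum(c3)] -> aggregate at field index 3
    let schema = ["c1", "c2", "__grouping_id", "sum(aggregate_test_100.c3)"];
    assert_eq!(schema[3], "sum(aggregate_test_100.c3)");
}
```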
56 changes: 55 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
@@ -21,7 +21,7 @@ use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use super::dml::CopyTo;
use super::DdlStatement;
@@ -2965,6 +2965,15 @@ impl Aggregate {
.into_iter()
.map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into()))
.collect::<Vec<_>>();
qualified_fields.push((
None,
Field::new(
Self::INTERNAL_GROUPING_ID,
Self::grouping_id_type(qualified_fields.len()),
false,
)
.into(),
));
}

qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?);
@@ -3016,9 +3025,19 @@ impl Aggregate {
})
}

fn is_grouping_set(&self) -> bool {
matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)])
}

/// Get the output expressions.
fn output_expressions(&self) -> Result<Vec<&Expr>> {
static INTERNAL_ID_EXPR: OnceLock<Expr> = OnceLock::new();
let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
if self.is_grouping_set() {
exprs.push(INTERNAL_ID_EXPR.get_or_init(|| {
Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID))
}));
}
exprs.extend(self.aggr_expr.iter());
debug_assert!(exprs.len() == self.schema.fields().len());
Ok(exprs)
@@ -3030,6 +3049,41 @@
pub fn group_expr_len(&self) -> Result<usize> {
grouping_set_expr_count(&self.group_expr)
}

/// Returns the data type of the grouping id.
/// The grouping ID value is a bitmask where each set bit
/// indicates that the corresponding grouping expression is excluded
/// from the grouping set (and therefore output as null).
pub fn grouping_id_type(group_exprs: usize) -> DataType {
if group_exprs <= 8 {
DataType::UInt8
} else if group_exprs <= 16 {
DataType::UInt16
} else if group_exprs <= 32 {
DataType::UInt32
} else {
DataType::UInt64
}
}

/// Internal column used when the aggregation is a grouping set.
///
/// This column contains a bitmask where each bit represents a grouping
/// expression. The least significant bit corresponds to the rightmost
/// grouping expression. A bit value of 0 indicates that the corresponding
/// column is included in the grouping set, while a value of 1 means it is excluded.
///
/// For example, for the grouping expressions CUBE(a, b), the grouping ID
/// column will have the following values:
/// 0b00: Both `a` and `b` are included
/// 0b01: `b` is excluded
/// 0b10: `a` is excluded
/// 0b11: Both `a` and `b` are excluded
///
/// This internal column is necessary because excluded columns are replaced
/// with `NULL` values. To handle these cases correctly, we must distinguish
/// between an actual `NULL` value in a column and a column being excluded from the set.
pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id";
}

// Manual implementation needed because of `schema` field. Comparison excludes this field.
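To make the bitmask described above concrete, here is a minimal standalone sketch (plain Rust, not the DataFusion implementation) that reproduces the CUBE(a, b) values from the doc comment, mapping the least significant bit to the rightmost grouping expression:

```rust
// Hedged sketch: compute a grouping id where bit i (least significant bit =
// rightmost grouping expression) is 1 when that expression is excluded from
// the current grouping set.
fn grouping_id(group_exprs: &[&str], excluded: &[&str]) -> u64 {
    group_exprs
        .iter()
        .rev() // rightmost expression -> least significant bit
        .enumerate()
        .fold(0u64, |id, (bit, expr)| {
            if excluded.contains(expr) {
                id | (1 << bit)
            } else {
                id
            }
        })
}

fn main() {
    let exprs = ["a", "b"];
    assert_eq!(grouping_id(&exprs, &[]), 0b00);          // both `a` and `b` included
    assert_eq!(grouping_id(&exprs, &["b"]), 0b01);       // `b` excluded
    assert_eq!(grouping_id(&exprs, &["a"]), 0b10);       // `a` excluded
    assert_eq!(grouping_id(&exprs, &["a", "b"]), 0b11);  // both excluded
}
```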
12 changes: 11 additions & 1 deletion datafusion/expr/src/utils.rs
@@ -61,7 +61,17 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result
/// Count the number of distinct exprs in a list of group by expressions. If the
/// first element is a `GroupingSet` expression then it must be the only expr.
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
if group_expr.len() > 1 {
return plan_err!(
"Invalid group by expressions, GroupingSet must be the only expression"
);
}
// Grouping sets have an additional internal column for the grouping id
Ok(grouping_set.distinct_expr().len() + 1)
} else {
grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
}
}

/// The [power set] (or powerset) of a set S is the set of all subsets of S, \
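A self-contained sketch of the counting rule above (the function and names here are illustrative, not the DataFusion helper): the group-by width is the number of distinct grouping expressions plus one extra slot for the internal __grouping_id column:

```rust
use std::collections::HashSet;

// Sketch of the counting rule: distinct grouping expressions + 1 for __grouping_id.
fn grouping_set_output_columns(grouping_sets: &[Vec<&str>]) -> usize {
    let distinct: HashSet<&str> = grouping_sets.iter().flatten().copied().collect();
    distinct.len() + 1
}

fn main() {
    // GROUPING SETS ((a), (b), (a, b)) -> columns a, b, __grouping_id
    let sets = [vec!["a"], vec!["b"], vec!["a", "b"]];
    assert_eq!(grouping_set_output_columns(&sets), 3);
}
```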
6 changes: 3 additions & 3 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -355,7 +355,7 @@ mod tests {
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
@@ -373,7 +373,7 @@
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
@@ -392,7 +392,7 @@
.build()?;

// Should not be optimized
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
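The __grouping_id:UInt8 in these expected plans follows from the width rule introduced in plan.rs: two grouping expressions fit in a u8. A standalone sketch of that selection rule (plain Rust strings instead of DataFusion's DataType):

```rust
// Sketch of the width selection: pick the smallest unsigned integer type that
// can hold one bit per grouping expression.
fn grouping_id_width(group_exprs: usize) -> &'static str {
    if group_exprs <= 8 {
        "UInt8"
    } else if group_exprs <= 16 {
        "UInt16"
    } else if group_exprs <= 32 {
        "UInt32"
    } else {
        "UInt64"
    }
}

fn main() {
    // CUBE(test.a, test.b) / ROLLUP(test.a, test.b) -> 2 grouping expressions
    assert_eq!(grouping_id_width(2), "UInt8");
    assert_eq!(grouping_id_width(20), "UInt32");
}
```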