Skip to content

Commit

Permalink
Fix bug in TopK aggregates (#12766)
Browse files Browse the repository at this point in the history
Fix bug in TopK aggregates (#12766)
  • Loading branch information
avantgardnerio authored Oct 8, 2024
1 parent 7d36059 commit c412c74
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 20 deletions.
46 changes: 26 additions & 20 deletions datafusion/physical-optimizer/src/topk_aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
use std::sync::Arc;

use datafusion_physical_plan::aggregates::AggregateExec;
use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion_physical_plan::filter::FilterExec;
use datafusion_physical_plan::repartition::RepartitionExec;
use datafusion_physical_plan::sorts::sort::SortExec;
use datafusion_physical_plan::ExecutionPlan;

Expand All @@ -31,9 +28,10 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalSortExpr;

use crate::PhysicalOptimizerRule;
use datafusion_physical_plan::execution_plan::CardinalityEffect;
use datafusion_physical_plan::projection::ProjectionExec;
use itertools::Itertools;

/// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed
Expand All @@ -48,12 +46,13 @@ impl TopKAggregation {

fn transform_agg(
aggr: &AggregateExec,
order: &PhysicalSortExpr,
order_by: &str,
order_desc: bool,
limit: usize,
) -> Option<Arc<dyn ExecutionPlan>> {
// ensure the sort direction matches aggregate function
let (field, desc) = aggr.get_minmax_desc()?;
if desc != order.options.descending {
if desc != order_desc {
return None;
}
let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
Expand All @@ -66,8 +65,7 @@ impl TopKAggregation {
}

// ensure the sort is on the same field as the aggregate output
let col = order.expr.as_any().downcast_ref::<Column>()?;
if col.name() != field.name() {
if order_by != field.name() {
return None;
}

Expand All @@ -92,31 +90,39 @@ impl TopKAggregation {
let child = children.into_iter().exactly_one().ok()?;
let order = sort.properties().output_ordering()?;
let order = order.iter().exactly_one().ok()?;
let order_desc = order.options.descending;
let order = order.expr.as_any().downcast_ref::<Column>()?;
let mut cur_col_name = order.name().to_string();
let limit = sort.fetch()?;

let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
plan.as_any()
.downcast_ref::<CoalesceBatchesExec>()
.is_some()
|| plan.as_any().downcast_ref::<RepartitionExec>().is_some()
|| plan.as_any().downcast_ref::<FilterExec>().is_some()
};

let mut cardinality_preserved = true;
let closure = |plan: Arc<dyn ExecutionPlan>| {
if !cardinality_preserved {
return Ok(Transformed::no(plan));
}
if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
// either we run into an Aggregate and transform it
match Self::transform_agg(aggr, order, limit) {
match Self::transform_agg(aggr, &cur_col_name, order_desc, limit) {
None => cardinality_preserved = false,
Some(plan) => return Ok(Transformed::yes(plan)),
}
} else if let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() {
// track renames due to successive projections
for (src_expr, proj_name) in proj.expr() {
let Some(src_col) = src_expr.as_any().downcast_ref::<Column>() else {
continue;
};
if *proj_name == cur_col_name {
cur_col_name = src_col.name().to_string();
}
}
} else {
// or we continue down whitelisted nodes of other types
if !is_cardinality_preserving(Arc::clone(&plan)) {
cardinality_preserved = false;
// or we continue down through types that don't reduce cardinality
match plan.cardinality_effect() {
CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {}
CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => {
cardinality_preserved = false;
}
}
}
Ok(Transformed::no(plan))
Expand Down
5 changes: 5 additions & 0 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use datafusion_physical_expr::{
PhysicalExpr, PhysicalSortRequirement,
};

use crate::execution_plan::CardinalityEffect;
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use itertools::Itertools;

Expand Down Expand Up @@ -866,6 +867,10 @@ impl ExecutionPlan for AggregateExec {
}
}
}

fn cardinality_effect(&self) -> CardinalityEffect {
CardinalityEffect::LowerEqual
}
}

fn create_schema(
Expand Down
5 changes: 5 additions & 0 deletions datafusion/physical-plan/src/coalesce_batches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use datafusion_common::Result;
use datafusion_execution::TaskContext;

use crate::coalesce::{BatchCoalescer, CoalescerState};
use crate::execution_plan::CardinalityEffect;
use futures::ready;
use futures::stream::{Stream, StreamExt};

Expand Down Expand Up @@ -199,6 +200,10 @@ impl ExecutionPlan for CoalesceBatchesExec {
fn fetch(&self) -> Option<usize> {
self.fetch
}

fn cardinality_effect(&self) -> CardinalityEffect {
CardinalityEffect::Equal
}
}

/// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details.
Expand Down
5 changes: 5 additions & 0 deletions datafusion/physical-plan/src/coalesce_partitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use super::{

use crate::{DisplayFormatType, ExecutionPlan, Partitioning};

use crate::execution_plan::CardinalityEffect;
use datafusion_common::{internal_err, Result};
use datafusion_execution::TaskContext;

Expand Down Expand Up @@ -178,6 +179,10 @@ impl ExecutionPlan for CoalescePartitionsExec {
fn supports_limit_pushdown(&self) -> bool {
true
}

fn cardinality_effect(&self) -> CardinalityEffect {
CardinalityEffect::Equal
}
}

#[cfg(test)]
Expand Down
19 changes: 19 additions & 0 deletions datafusion/physical-plan/src/execution_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,11 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync {
fn fetch(&self) -> Option<usize> {
None
}

/// Gets the effect on cardinality, if known
fn cardinality_effect(&self) -> CardinalityEffect {
CardinalityEffect::Unknown
}
}

/// Extension trait provides an easy API to fetch various properties of
Expand Down Expand Up @@ -898,6 +903,20 @@ pub fn get_plan_string(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> {
actual.iter().map(|elem| elem.to_string()).collect()
}

/// Indicates the effect an execution plan operator will have on the cardinality
/// of its input stream
pub enum CardinalityEffect {
/// Unknown effect. This is the default
Unknown,
/// The operator is guaranteed to produce exactly one row for
/// each input row
Equal,
/// The operator may produce fewer output rows than it receives input rows
LowerEqual,
/// The operator may produce more output rows than it receives input rows
GreaterEqual,
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
5 changes: 5 additions & 0 deletions datafusion/physical-plan/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use datafusion_physical_expr::{
analyze, split_conjunction, AnalysisContext, ConstExpr, ExprBoundaries, PhysicalExpr,
};

use crate::execution_plan::CardinalityEffect;
use futures::stream::{Stream, StreamExt};
use log::trace;

Expand Down Expand Up @@ -372,6 +373,10 @@ impl ExecutionPlan for FilterExec {
fn statistics(&self) -> Result<Statistics> {
Self::statistics_helper(&self.input, self.predicate(), self.default_selectivity)
}

fn cardinality_effect(&self) -> CardinalityEffect {
CardinalityEffect::LowerEqual
}
}

/// This function ensures that all bounds in the `ExprBoundaries` vector are
Expand Down
5 changes: 5 additions & 0 deletions datafusion/physical-plan/src/limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use arrow::record_batch::RecordBatch;
use datafusion_common::{internal_err, Result};
use datafusion_execution::TaskContext;

use crate::execution_plan::CardinalityEffect;
use futures::stream::{Stream, StreamExt};
use log::trace;

Expand Down Expand Up @@ -336,6 +337,10 @@ impl ExecutionPlan for LocalLimitExec {
fn supports_limit_pushdown(&self) -> bool {
true
}

fn cardinality_effect(&self) -> CardinalityEffect {
CardinalityEffect::LowerEqual
}
}

/// A Limit stream skips `skip` rows, and then fetch up to `fetch` rows.
Expand Down
5 changes: 5 additions & 0 deletions datafusion/physical-plan/src/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::Literal;

use crate::execution_plan::CardinalityEffect;
use futures::stream::{Stream, StreamExt};
use log::trace;

Expand Down Expand Up @@ -233,6 +234,10 @@ impl ExecutionPlan for ProjectionExec {
fn supports_limit_pushdown(&self) -> bool {
true
}

fn cardinality_effect(&self) -> CardinalityEffect {
CardinalityEffect::Equal
}
}

/// If e is a direct column reference, returns the field level
Expand Down
5 changes: 5 additions & 0 deletions datafusion/physical-plan/src/repartition/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr};

use crate::execution_plan::CardinalityEffect;
use futures::stream::Stream;
use futures::{FutureExt, StreamExt, TryStreamExt};
use hashbrown::HashMap;
Expand Down Expand Up @@ -669,6 +670,10 @@ impl ExecutionPlan for RepartitionExec {
fn statistics(&self) -> Result<Statistics> {
self.input.statistics()
}

fn cardinality_effect(&self) -> CardinalityEffect {
CardinalityEffect::Equal
}
}

impl RepartitionExec {
Expand Down
9 changes: 9 additions & 0 deletions datafusion/physical-plan/src/sorts/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ use datafusion_execution::TaskContext;
use datafusion_physical_expr::LexOrdering;
use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement;

use crate::execution_plan::CardinalityEffect;
use futures::{StreamExt, TryStreamExt};
use log::{debug, trace};

Expand Down Expand Up @@ -972,6 +973,14 @@ impl ExecutionPlan for SortExec {
fn fetch(&self) -> Option<usize> {
self.fetch
}

fn cardinality_effect(&self) -> CardinalityEffect {
if self.fetch.is_none() {
CardinalityEffect::Equal
} else {
CardinalityEffect::LowerEqual
}
}
}

#[cfg(test)]
Expand Down
11 changes: 11 additions & 0 deletions datafusion/sqllogictest/test_files/aggregates_topk.slt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ physical_plan
07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)]
08)--------------MemoryExec: partitions=1, partition_sizes=[1]

query TI
select * from (select trace_id, MAX(timestamp) max_ts from traces t group by trace_id) where trace_id != 'b' order by max_ts desc limit 3;
----
c 4
a 1

query TI
select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4;
Expand Down Expand Up @@ -89,6 +94,12 @@ c 1 2
statement ok
set datafusion.optimizer.enable_topk_aggregation = true;

query TI
select * from (select trace_id, MAX(timestamp) max_ts from traces t group by trace_id) where max_ts != 3 order by max_ts desc limit 2;
----
c 4
a 1

query TT
explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4;
----
Expand Down

0 comments on commit c412c74

Please sign in to comment.