Skip to content

Commit

Permalink
Make FirstValue an UDAF, Change AggregateUDFImpl::accumulator signa…
Browse files Browse the repository at this point in the history
…ture, support ORDER BY for UDAFs (#9874)

* first draft

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* clippy fix

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* use one vector for ordering req

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add sort exprs to accumulator

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* clippy

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix doc test

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* change to ref

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix typo

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix doc

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fmt

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* move schema and logical ordering exprs

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* remove redundant info

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rename

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add ignore nulls

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix conflict

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* backup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* complete return_type

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* complete replace

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* split to first value udf

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* replace accumulator

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fmt

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* small fix

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* remove ordering types

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* make state fields more flexible

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* replace done

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm comments

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm test1

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix state fields

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fmt

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* args struct for accumulator

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* simplify

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add sig

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add comments

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fmt

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix docs

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* use exprs utils

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm state type

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add comment

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 authored Apr 3, 2024
1 parent d2ba901 commit dfd4442
Show file tree
Hide file tree
Showing 24 changed files with 450 additions and 134 deletions.
19 changes: 14 additions & 5 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_physical_expr::NullState;
use std::{any::Any, sync::Arc};
Expand All @@ -30,7 +31,8 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl,
GroupsAccumulator, Signature,
};

/// This example shows how to use the full AggregateUDFImpl API to implement a user
Expand Down Expand Up @@ -85,13 +87,21 @@ impl AggregateUDFImpl for GeoMeanUdaf {
/// is supported, DataFusion will use this row oriented
/// accumulator when the aggregate function is used as a window function
/// or when there are only aggregates (no GROUP BY columns) in the plan.
fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(GeometricMean::new()))
}

/// This is the description of the state. The accumulator's `state()` must match the types here.
fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
Ok(vec![DataType::Float64, DataType::UInt32])
fn state_fields(
&self,
_name: &str,
value_type: DataType,
_ordering_fields: Vec<arrow_schema::Field>,
) -> Result<Vec<arrow_schema::Field>> {
Ok(vec![
Field::new("prod", value_type, true),
Field::new("n", DataType::UInt32, true),
])
}

/// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator`
Expand Down Expand Up @@ -191,7 +201,6 @@ impl Accumulator for GeometricMean {

// create local session context with an in-memory table
fn create_context() -> Result<SessionContext> {
use datafusion::arrow::datatypes::{Field, Schema};
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Expand Down
20 changes: 20 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,14 @@ use datafusion_common::{
OwnedTableReference, SchemaReference,
};
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::{create_first_value, Signature, Volatility};
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
var_provider::is_system_variables,
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
use datafusion_physical_expr::create_first_value_accumulator;
use datafusion_sql::{
parser::{CopyToSource, CopyToStatement, DFParser},
planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel},
Expand All @@ -82,6 +85,7 @@ use datafusion_sql::{

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use log::debug;
use parking_lot::RwLock;
use sqlparser::dialect::dialect_from_str;
use url::Url;
Expand Down Expand Up @@ -1451,6 +1455,22 @@ impl SessionState {
datafusion_functions_array::register_all(&mut new_self)
.expect("can not register array expressions");

let first_value = create_first_value(
"FIRST_VALUE",
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
Arc::new(create_first_value_accumulator),
);

match new_self.register_udaf(Arc::new(first_value)) {
Ok(Some(existing_udaf)) => {
debug!("Overwrite existing UDAF: {}", existing_udaf.name());
}
Ok(None) => {}
Err(err) => {
panic!("Failed to register UDAF: {}", err);
}
}

new_self
}
/// Returns new [`SessionState`] using the provided
Expand Down
50 changes: 31 additions & 19 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,24 +247,20 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
distinct,
args,
filter,
order_by,
order_by: _,
null_treatment: _,
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(..) => {
create_function_physical_name(func_def.name(), *distinct, args)
}
AggregateFunctionDefinition::UDF(fun) => {
// TODO: Add support for filter and order by in AggregateUDF
// TODO: Add support for filter in AggregateUDF
if filter.is_some() {
return exec_err!(
"aggregate expression with filter is not supported"
);
}
if order_by.is_some() {
return exec_err!(
"aggregate expression with order_by is not supported"
);
}

let names = args
.iter()
.map(|e| create_physical_name(e, false))
Expand Down Expand Up @@ -1667,20 +1663,22 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
)?),
None => None,
};
let order_by = match order_by {
Some(e) => Some(create_physical_sort_exprs(
e,
logical_input_schema,
execution_props,
)?),
None => None,
};

let ignore_nulls = null_treatment
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
== NullTreatment::IgnoreNulls;
let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
let ordering_reqs = order_by.clone().unwrap_or(vec![]);
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
exprs,
logical_input_schema,
execution_props,
)?),
None => None,
};
let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = aggregates::create_aggregate_expr(
fun,
*distinct,
Expand All @@ -1690,16 +1688,30 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
name,
ignore_nulls,
)?;
(agg_expr, filter, order_by)
(agg_expr, filter, physical_sort_exprs)
}
AggregateFunctionDefinition::UDF(fun) => {
let sort_exprs = order_by.clone().unwrap_or(vec![]);
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
exprs,
logical_input_schema,
execution_props,
)?),
None => None,
};
let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = udaf::create_aggregate_expr(
fun,
&args,
&sort_exprs,
&ordering_reqs,
physical_input_schema,
name,
);
(agg_expr?, filter, order_by)
ignore_nulls,
)?;
(agg_expr, filter, physical_sort_exprs)
}
AggregateFunctionDefinition::Name(_) => {
return internal_err!(
Expand Down
20 changes: 11 additions & 9 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF,
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;

Expand Down Expand Up @@ -491,7 +492,7 @@ impl TimeSum {
// Returns the same type as its input
let return_type = timestamp_type.clone();

let state_type = vec![timestamp_type.clone()];
let state_fields = vec![Field::new("sum", timestamp_type, true)];

let volatility = Volatility::Immutable;

Expand All @@ -505,7 +506,7 @@ impl TimeSum {
return_type,
volatility,
accumulator,
state_type,
state_fields,
));

// register the selector as "time_sum"
Expand Down Expand Up @@ -591,6 +592,11 @@ impl FirstSelector {
fn register(ctx: &mut SessionContext) {
let return_type = Self::output_datatype();
let state_type = Self::state_datatypes();
let state_fields = state_type
.into_iter()
.enumerate()
.map(|(i, t)| Field::new(format!("{i}"), t, true))
.collect::<Vec<_>>();

// Possible input signatures
let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];
Expand All @@ -607,7 +613,7 @@ impl FirstSelector {
Signature::one_of(signatures, volatility),
return_type,
accumulator,
state_type,
state_fields,
));

// register the selector as "first"
Expand Down Expand Up @@ -717,15 +723,11 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
Ok(DataType::UInt64)
}

fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
// should use groups accumulator
panic!("accumulator shouldn't invoke");
}

fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
Ok(vec![DataType::UInt64])
}

fn groups_accumulator_supported(&self) -> bool {
true
}
Expand Down
3 changes: 2 additions & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,14 +577,15 @@ impl AggregateFunction {
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Option<Vec<Expr>>,
null_treatment: Option<NullTreatment>,
) -> Self {
Self {
func_def: AggregateFunctionDefinition::UDF(udf),
args,
distinct,
filter,
order_by,
null_treatment: None,
null_treatment,
}
}
}
Expand Down
Loading

0 comments on commit dfd4442

Please sign in to comment.