Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make FirstValue a UDAF, Change AggregateUDFImpl::accumulator signature, support ORDER BY for UDAFs #9874

Merged
merged 49 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
b94f70f
first draft
jayzhan211 Feb 16, 2024
c743d13
clippy fix
jayzhan211 Feb 18, 2024
3a7e965
cleanup
jayzhan211 Feb 18, 2024
4917f56
use one vector for ordering req
jayzhan211 Feb 21, 2024
c9e8641
add sort exprs to accumulator
jayzhan211 Feb 21, 2024
3a5f0d1
clippy
jayzhan211 Feb 21, 2024
a3ea00a
cleanup
jayzhan211 Feb 21, 2024
f349f21
fix doc test
jayzhan211 Feb 21, 2024
6fcdaac
change to ref
jayzhan211 Feb 27, 2024
c3512a6
fix typo
jayzhan211 Feb 27, 2024
092d46e
fix doc
jayzhan211 Feb 27, 2024
8592e6b
fmt
jayzhan211 Mar 1, 2024
0f8fc24
move schema and logical ordering exprs
jayzhan211 Mar 1, 2024
3185f9f
remove redundant info
jayzhan211 Mar 1, 2024
3ecc772
rename
jayzhan211 Mar 1, 2024
faadc63
cleanup
jayzhan211 Mar 1, 2024
7e33910
add ignore nulls
jayzhan211 Mar 7, 2024
cfffcbf
Merge remote-tracking branch 'upstream/main' into udf-order-2
jayzhan211 Mar 25, 2024
6aaa15c
fix conflict
jayzhan211 Mar 25, 2024
b74b7d2
backup
jayzhan211 Mar 26, 2024
263e6cb
complete return_type
jayzhan211 Mar 26, 2024
0a77e4f
complete replace
jayzhan211 Mar 30, 2024
7b26377
split to first value udf
jayzhan211 Mar 30, 2024
4bfd91d
replace accumulator
jayzhan211 Mar 30, 2024
7f54141
fmt
jayzhan211 Mar 30, 2024
6339535
cleanup
jayzhan211 Mar 30, 2024
33ae6ee
small fix
jayzhan211 Mar 30, 2024
b4eb865
remove ordering types
jayzhan211 Mar 30, 2024
d8ab6c5
make state fields more flexible
jayzhan211 Mar 30, 2024
a3bff42
cleanup
jayzhan211 Mar 30, 2024
53465fd
replace done
jayzhan211 Mar 30, 2024
cc21496
cleanup
jayzhan211 Mar 30, 2024
b62544f
cleanup
jayzhan211 Mar 30, 2024
ddfabad
Merge remote-tracking branch 'upstream/main' into first-value-udf
jayzhan211 Mar 30, 2024
4b809b0
rm comments
jayzhan211 Mar 30, 2024
2534727
cleanup
jayzhan211 Mar 30, 2024
17378dd
rm test1
jayzhan211 Mar 30, 2024
dd1c4ba
fix state fields
jayzhan211 Mar 31, 2024
5d5d310
fmt
jayzhan211 Mar 31, 2024
23f20f9
args struct for accumulator
jayzhan211 Mar 31, 2024
b2ba8c3
simplify
jayzhan211 Mar 31, 2024
75aa2fe
add sig
jayzhan211 Mar 31, 2024
5b9625f
add comments
jayzhan211 Mar 31, 2024
d5c3f6f
fmt
jayzhan211 Mar 31, 2024
dc9549a
fix docs
jayzhan211 Apr 1, 2024
7ce3d41
Merge remote-tracking branch 'upstream/main' into first-value-udf
jayzhan211 Apr 1, 2024
49b4a76
use exprs utils
jayzhan211 Apr 1, 2024
d70cce5
rm state type
jayzhan211 Apr 2, 2024
29c4018
add comment
jayzhan211 Apr 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use arrow_schema::Schema;
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_physical_expr::NullState;
use std::{any::Any, sync::Arc};
Expand Down Expand Up @@ -85,7 +86,14 @@ impl AggregateUDFImpl for GeoMeanUdaf {
/// is supported, DataFusion will use this row oriented
/// accumulator when the aggregate function is used as a window function
/// or when there are only aggregates (no GROUP BY columns) in the plan.
fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
fn accumulator(
&self,
_arg: &DataType,
_sort_exprs: &[Expr],
_schema: &Schema,
_ignore_nulls: bool,
_requirement_satisfied: bool,
) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(GeometricMean::new()))
}

Expand Down Expand Up @@ -191,7 +199,7 @@ impl Accumulator for GeometricMean {

// create local session context with an in-memory table
fn create_context() -> Result<SessionContext> {
use datafusion::arrow::datatypes::{Field, Schema};
use datafusion::arrow::datatypes::Field;
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ async fn main() -> Result<()> {
Arc::new(DataType::Float64),
Volatility::Immutable,
// This is the accumulator factory; DataFusion uses it to create new accumulators.
Arc::new(|_| Ok(Box::new(GeometricMean::new()))),
Arc::new(|_, _, _| Ok(Box::new(GeometricMean::new()))),
// This is the description of the state. `state()` must match the types here.
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);
Expand Down
10 changes: 10 additions & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,14 @@ use datafusion_common::{
OwnedTableReference, SchemaReference,
};
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::{create_first_value, Signature, Volatility};
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
var_provider::is_system_variables,
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
use datafusion_physical_expr::create_first_value_accumulator;
use datafusion_sql::{
parser::{CopyToSource, CopyToStatement, DFParser},
planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel},
Expand Down Expand Up @@ -1457,6 +1460,13 @@ impl SessionState {
datafusion_functions_array::register_all(&mut new_self)
.expect("can not register array expressions");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also panic if register fails here


let first_value = create_first_value(
"FIRST_VALUE",
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
Arc::new(create_first_value_accumulator),
);
let _ = new_self.register_udaf(Arc::new(first_value));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the return value ignored? If it returns a `Result`, maybe it could be `new_self.register_udaf(Arc::new(first_value))?;`. If it returns an `Option`, leaving it as `new_self.register_udaf(Arc::new(first_value));` might make the intent clearer than `let _`.


new_self
}
/// Returns new [`SessionState`] using the provided
Expand Down
70 changes: 49 additions & 21 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,24 +247,20 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
distinct,
args,
filter,
order_by,
order_by: _,
null_treatment: _,
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(..) => {
create_function_physical_name(func_def.name(), *distinct, args)
}
AggregateFunctionDefinition::UDF(fun) => {
// TODO: Add support for filter and order by in AggregateUDF
// TODO: Add support for filter by in AggregateUDF
if filter.is_some() {
return exec_err!(
"aggregate expression with filter is not supported"
);
}
if order_by.is_some() {
return exec_err!(
"aggregate expression with order_by is not supported"
);
}

let names = args
.iter()
.map(|e| create_physical_name(e, false))
Expand Down Expand Up @@ -1657,8 +1653,11 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
order_by,
null_treatment,
}) => {
let args =
create_physical_exprs(args, logical_input_schema, execution_props)?;
let args = args
.iter()
.map(|e| create_physical_expr(e, logical_input_schema, execution_props))
.collect::<Result<Vec<_>>>()?;

let filter = match filter {
Some(e) => Some(create_physical_expr(
e,
Expand All @@ -1667,20 +1666,28 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
)?),
None => None,
};
let order_by = match order_by {
Some(e) => Some(create_physical_sort_exprs(
e,
logical_input_schema,
execution_props,
)?),
None => None,
};
let ignore_nulls = null_treatment
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
== NullTreatment::IgnoreNulls;
let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
let ordering_reqs = order_by.clone().unwrap_or(vec![]);
let physical_sort_exprs = match order_by {
Some(e) => Some(
e.iter()
.map(|expr| {
create_physical_sort_expr(
expr,
logical_input_schema,
execution_props,
)
})
.collect::<Result<Vec<_>>>()?,
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved
),
None => None,
};
jayzhan211 marked this conversation as resolved.
Show resolved Hide resolved

let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = aggregates::create_aggregate_expr(
fun,
*distinct,
Expand All @@ -1690,16 +1697,37 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
name,
ignore_nulls,
)?;
(agg_expr, filter, order_by)
(agg_expr, filter, physical_sort_exprs)
}
AggregateFunctionDefinition::UDF(fun) => {
let sort_exprs = order_by.clone().unwrap_or(vec![]);
let physical_sort_exprs = match order_by {
Some(e) => Some(
e.iter()
.map(|expr| {
create_physical_sort_expr(
expr,
logical_input_schema,
execution_props,
)
})
.collect::<Result<Vec<_>>>()?,
),
None => None,
};

let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = udaf::create_aggregate_expr(
fun,
&args,
&sort_exprs,
&ordering_reqs,
physical_input_schema,
name,
);
(agg_expr?, filter, order_by)
ignore_nulls,
)?;
(agg_expr, filter, physical_sort_exprs)
}
AggregateFunctionDefinition::Name(_) => {
return internal_err!(
Expand Down
23 changes: 15 additions & 8 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF,
create_udaf, AggregateUDFImpl, Expr, GroupsAccumulator, SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;

Expand Down Expand Up @@ -234,7 +234,7 @@ async fn simple_udaf() -> Result<()> {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(|_, _, _| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down Expand Up @@ -262,7 +262,7 @@ async fn deregister_udaf() -> Result<()> {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(|_, _, _| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down Expand Up @@ -290,7 +290,7 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(|_, _, _| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down Expand Up @@ -333,7 +333,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
vec![DataType::Float64],
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
Arc::new(|_, _, _| Ok(Box::<AvgAccumulator>::default())),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
)
.with_aliases(vec!["dummy_alias"]);
Expand Down Expand Up @@ -497,7 +497,7 @@ impl TimeSum {

let captured_state = Arc::clone(&test_state);
let accumulator: AccumulatorFactoryFunction =
Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));
Arc::new(move |_, _, _| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));

let time_sum = AggregateUDF::from(SimpleAggregateUDF::new(
name,
Expand Down Expand Up @@ -596,7 +596,7 @@ impl FirstSelector {
let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];

let accumulator: AccumulatorFactoryFunction =
Arc::new(|_| Ok(Box::new(Self::new())));
Arc::new(|_, _, _| Ok(Box::new(Self::new())));

let volatility = Volatility::Immutable;

Expand Down Expand Up @@ -717,7 +717,14 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
Ok(DataType::UInt64)
}

fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
fn accumulator(
&self,
_arg: &DataType,
_sort_exprs: &[Expr],
_schema: &Schema,
_ignore_nulls: bool,
_requirement_satisfied: bool,
) -> Result<Box<dyn Accumulator>> {
// should use groups accumulator
panic!("accumulator shouldn't invoke");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ async fn udaf_as_window_func() -> Result<()> {
vec![DataType::Int32],
Arc::new(DataType::Int32),
Volatility::Immutable,
Arc::new(|_| Ok(Box::new(MyAccumulator))),
Arc::new(|_, _, _| Ok(Box::new(MyAccumulator))),
Arc::new(vec![DataType::Int32]),
);

Expand Down
3 changes: 2 additions & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,14 +577,15 @@ impl AggregateFunction {
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Option<Vec<Expr>>,
null_treatment: Option<NullTreatment>,
) -> Self {
Self {
func_def: AggregateFunctionDefinition::UDF(udf),
args,
distinct,
filter,
order_by,
null_treatment: None,
null_treatment,
}
}
}
Expand Down
Loading
Loading