Skip to content

Commit

Permalink
Add PartitionEvaluatorArgs to WindowUDFImpl::partition_evaluator (#…
Browse files Browse the repository at this point in the history
…12804)

* Patched from `lead-lag` conversion tree

* Fixes unit tests in `row_number` udwf

* Add doc comments

* Updates doc comment

* Updates API to expose `input_exprs` directly

* Updates API to returns data types of input expressions
  • Loading branch information
jcsherin authored Oct 9, 2024
1 parent 3353c06 commit 30de35e
Show file tree
Hide file tree
Showing 15 changed files with 191 additions and 27 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ dashmap = { workspace = true }
datafusion = { workspace = true, default-features = true, features = ["avro"] }
datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
datafusion-functions-window-common = { workspace = true }
datafusion-optimizer = { workspace = true, default-features = true }
datafusion-physical-expr = { workspace = true, default-features = true }
datafusion-proto = { workspace = true }
Expand Down
6 changes: 5 additions & 1 deletion datafusion-examples/examples/advanced_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use datafusion_expr::function::WindowUDFFieldArgs;
use datafusion_expr::{
PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl,
};
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;

/// This example shows how to use the full WindowUDFImpl API to implement a user
/// defined window function. As in the `simple_udwf.rs` example, this struct implements
Expand Down Expand Up @@ -74,7 +75,10 @@ impl WindowUDFImpl for SmoothItUdf {

/// Create a `PartitionEvaluator` to evaluate this function on a new
/// partition.
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
fn partition_evaluator(
&self,
_partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(MyPartitionEvaluator::new()))
}

Expand Down
6 changes: 5 additions & 1 deletion datafusion-examples/examples/simplify_udwf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use datafusion_expr::{
expr::WindowFunction, simplify::SimplifyInfo, Expr, PartitionEvaluator, Signature,
Volatility, WindowUDF, WindowUDFImpl,
};
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;

/// This UDWF will show how to use the WindowUDFImpl::simplify() API
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -60,7 +61,10 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
&self.signature
}

fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
fn partition_evaluator(
&self,
_partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
todo!()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use datafusion_expr::{
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;

/// A query with a window function evaluated over the entire partition
const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \
Expand Down Expand Up @@ -552,7 +553,10 @@ impl OddCounter {
&self.signature
}

fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
fn partition_evaluator(
&self,
_partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state))))
}

Expand Down
10 changes: 7 additions & 3 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ use crate::function::{
};
use crate::{
conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF,
Signature, Volatility,
AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator,
ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
};
use crate::{
AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
Expand All @@ -39,6 +39,7 @@ use arrow::compute::kernels::cast_utils::{
use arrow::datatypes::{DataType, Field};
use datafusion_common::{plan_err, Column, Result, ScalarValue, TableReference};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::Debug;
Expand Down Expand Up @@ -658,7 +659,10 @@ impl WindowUDFImpl for SimpleWindowUDF {
&self.signature
}

fn partition_evaluator(&self) -> Result<Box<dyn crate::PartitionEvaluator>> {
fn partition_evaluator(
&self,
_partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
(self.partition_evaluator_factory)()
}

Expand Down
46 changes: 34 additions & 12 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ use std::{

use arrow::datatypes::{DataType, Field};

use datafusion_common::{not_impl_err, Result};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;

use crate::expr::WindowFunction;
use crate::{
function::WindowFunctionSimplification, Documentation, Expr, PartitionEvaluator,
Signature,
};
use datafusion_common::{not_impl_err, Result};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;

/// Logical representation of a user-defined window function (UDWF)
/// A UDWF is different from a UDF in that it is stateful across batches.
Expand Down Expand Up @@ -150,8 +150,11 @@ impl WindowUDF {
}

/// Return a `PartitionEvaluator` for evaluating this window function
pub fn partition_evaluator_factory(&self) -> Result<Box<dyn PartitionEvaluator>> {
self.inner.partition_evaluator()
pub fn partition_evaluator_factory(
&self,
partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
self.inner.partition_evaluator(partition_evaluator_args)
}

/// Returns the field of the final result of evaluating this window function.
Expand Down Expand Up @@ -218,8 +221,9 @@ where
/// # use datafusion_common::{DataFusionError, plan_err, Result};
/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt, Documentation};
/// # use datafusion_expr::{WindowUDFImpl, WindowUDF};
/// # use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs;
/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
/// # use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL;
///
/// #[derive(Debug, Clone)]
/// struct SmoothIt {
Expand Down Expand Up @@ -254,7 +258,12 @@ where
/// fn name(&self) -> &str { "smooth_it" }
/// fn signature(&self) -> &Signature { &self.signature }
/// // The actual implementation would smooth the window
/// fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> { unimplemented!() }
/// fn partition_evaluator(
/// &self,
/// _partition_evaluator_args: PartitionEvaluatorArgs,
/// ) -> Result<Box<dyn PartitionEvaluator>> {
/// unimplemented!()
/// }
/// fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
/// if let Some(DataType::Int32) = field_args.get_input_type(0) {
/// Ok(Field::new(field_args.name(), DataType::Int32, false))
Expand Down Expand Up @@ -294,7 +303,10 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
fn signature(&self) -> &Signature;

/// Invoke the function, returning the [`PartitionEvaluator`] instance
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>>;
fn partition_evaluator(
&self,
partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>>;

/// Returns any aliases (alternate names) for this function.
///
Expand Down Expand Up @@ -468,8 +480,11 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
self.inner.signature()
}

fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
self.inner.partition_evaluator()
fn partition_evaluator(
&self,
partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
self.inner.partition_evaluator(partition_evaluator_args)
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -550,6 +565,7 @@ mod test {
use datafusion_common::Result;
use datafusion_expr_common::signature::{Signature, Volatility};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
use std::any::Any;
use std::cmp::Ordering;

Expand Down Expand Up @@ -581,7 +597,10 @@ mod test {
fn signature(&self) -> &Signature {
&self.signature
}
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
fn partition_evaluator(
&self,
_partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
unimplemented!()
}
fn field(&self, _field_args: WindowUDFFieldArgs) -> Result<Field> {
Expand Down Expand Up @@ -617,7 +636,10 @@ mod test {
fn signature(&self) -> &Signature {
&self.signature
}
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
fn partition_evaluator(
&self,
_partition_evaluator_args: PartitionEvaluatorArgs,
) -> Result<Box<dyn PartitionEvaluator>> {
unimplemented!()
}
fn field(&self, _field_args: WindowUDFFieldArgs) -> Result<Field> {
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-window-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ path = "src/lib.rs"

[dependencies]
datafusion-common = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
1 change: 1 addition & 0 deletions datafusion/functions-window-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
//!
//! [DataFusion]: <https://crates.io/crates/datafusion>
pub mod field;
pub mod partition;
89 changes: 89 additions & 0 deletions datafusion/functions-window-common/src/partition.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion_common::arrow::datatypes::DataType;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::sync::Arc;

/// Arguments passed to created user-defined window function state
/// during physical execution.
#[derive(Debug, Default)]
pub struct PartitionEvaluatorArgs<'a> {
/// The expressions passed as arguments to the user-defined window
/// function.
input_exprs: &'a [Arc<dyn PhysicalExpr>],
/// The corresponding data types of expressions passed as arguments
/// to the user-defined window function.
input_types: &'a [DataType],
/// Set to `true` if the user-defined window function is reversed.
is_reversed: bool,
/// Set to `true` if `IGNORE NULLS` is specified.
ignore_nulls: bool,
}

impl<'a> PartitionEvaluatorArgs<'a> {
/// Create an instance of [`PartitionEvaluatorArgs`].
///
/// # Arguments
///
/// * `input_exprs` - The expressions passed as arguments
/// to the user-defined window function.
/// * `input_types` - The data types corresponding to the
/// arguments to the user-defined window function.
/// * `is_reversed` - Set to `true` if and only if the user-defined
/// window function is reversible and is reversed.
/// * `ignore_nulls` - Set to `true` when `IGNORE NULLS` is
/// specified.
///
pub fn new(
input_exprs: &'a [Arc<dyn PhysicalExpr>],
input_types: &'a [DataType],
is_reversed: bool,
ignore_nulls: bool,
) -> Self {
Self {
input_exprs,
input_types,
is_reversed,
ignore_nulls,
}
}

/// Returns the expressions passed as arguments to the user-defined
/// window function.
pub fn input_exprs(&self) -> &'a [Arc<dyn PhysicalExpr>] {
self.input_exprs
}

/// Returns the [`DataType`]s corresponding to the input expressions
/// to the user-defined window function.
pub fn input_types(&self) -> &'a [DataType] {
self.input_types
}

/// Returns `true` when the user-defined window function is
/// reversed, otherwise returns `false`.
pub fn is_reversed(&self) -> bool {
self.is_reversed
}

/// Returns `true` when `IGNORE NULLS` is specified, otherwise
/// returns `false`.
pub fn ignore_nulls(&self) -> bool {
self.ignore_nulls
}
}
Loading

0 comments on commit 30de35e

Please sign in to comment.