Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ScalarUDF with zero arguments should be provided with one null array as parameter #9031

Merged
merged 6 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,7 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -1336,6 +1337,7 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 3))),
Expand Down Expand Up @@ -1405,6 +1407,7 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -1471,6 +1474,7 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d_new", 3))),
Expand Down
114 changes: 108 additions & 6 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@
// under the License.

use arrow::compute::kernels::numeric::add;
use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch};
use arrow_array::{Array, ArrayRef, Float64Array, Int32Array, RecordBatch, UInt8Array};
use arrow_schema::DataType::Float64;
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, Volatility,
create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, ScalarUDF,
ScalarUDFImpl, Signature, Volatility,
};
use rand::{thread_rng, Rng};
use std::iter;
use std::sync::Arc;

/// test that casting happens on udfs.
Expand Down Expand Up @@ -166,10 +170,7 @@ async fn scalar_udf_zero_params() -> Result<()> {

ctx.register_batch("t", batch)?;
// create a function that just returns 100 regardless of input
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Scalar(_) = &args[0] else {
panic!("expect scalar")
};
let myfunc = Arc::new(|_args: &[ColumnarValue]| {
Ok(ColumnarValue::Array(
Arc::new((0..1).map(|_| 100).collect::<Int32Array>()) as ArrayRef,
))
Expand Down Expand Up @@ -392,6 +393,107 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
Ok(())
}

/// A user-defined scalar function producing a column of random `f64` values.
///
/// It is declared with zero arguments (see `new` below), so DataFusion
/// invokes it with a single null array whose length conveys the batch size.
#[derive(Debug)]
pub struct RandomUDF {
    // Set to `Signature::any(0, Volatility::Volatile)` in `new`:
    // zero arguments, re-evaluated on every invocation.
    signature: Signature,
}

impl RandomUDF {
    /// Creates a `RandomUDF` with a zero-argument, volatile signature.
    ///
    /// `Volatility::Volatile` tells the planner the result may differ on
    /// every call, so invocations must not be constant-folded away.
    pub fn new() -> Self {
        Self {
            signature: Signature::any(0, Volatility::Volatile),
        }
    }
}

// A public no-argument `new` should be accompanied by a `Default` impl
// (clippy::new_without_default); delegating keeps one source of truth.
impl Default for RandomUDF {
    fn default() -> Self {
        Self::new()
    }
}

impl ScalarUDFImpl for RandomUDF {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        "random_udf"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(Float64)
    }

    /// Produces one random value in `[0.1, 1.0)` for every row of the batch.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        // A zero-argument UDF is always handed exactly one null array;
        // its length is the number of rows in the current batch.
        let num_rows = match &args[0] {
            ColumnarValue::Array(batch_hint) if batch_hint.data_type().is_null() => {
                batch_hint.len()
            }
            _ => {
                return Err(datafusion::error::DataFusionError::Internal(
                    "Invalid argument type".to_string(),
                ))
            }
        };

        let mut rng = thread_rng();
        let samples = (0..num_rows).map(|_| rng.gen_range(0.1..1.0));
        Ok(ColumnarValue::Array(Arc::new(
            Float64Array::from_iter_values(samples),
        )))
    }
}

/// Ensure that a user defined function with zero argument will be invoked
/// with a null array indicating the batch size.
#[tokio::test]
alamb marked this conversation as resolved.
Show resolved Hide resolved
async fn test_user_defined_functions_zero_argument() -> Result<()> {
    let ctx = SessionContext::new();

    // A single-column table just so the query has rows to scan.
    let schema = Arc::new(Schema::new(vec![Field::new(
        "index",
        DataType::UInt8,
        false,
    )]));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(UInt8Array::from_iter_values([1, 2, 3]))],
    )?;
    ctx.register_batch("data_table", batch)?;

    ctx.register_udf(ScalarUDF::from(RandomUDF::new()));

    let result = plan_and_collect(
        &ctx,
        "SELECT random_udf() AS random_udf, random() AS native_random FROM data_table",
    )
    .await?;

    assert_eq!(result.len(), 1);
    let batch = &result[0];

    let random_udf = batch
        .column(0)
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();
    let native_random = batch
        .column(1)
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();

    // The UDF column must be sized like the built-in random() column,
    // i.e. one value per input row.
    assert_eq!(random_udf.len(), native_random.len());

    // Every value lies in [0, 1) and consecutive values differ, which
    // (with overwhelming probability) shows the UDF was evaluated per
    // row rather than constant-folded into a single value.
    let mut previous_value = -1.0_f64;
    for row in 0..random_udf.len() {
        let current = random_udf.value(row);
        assert!(current >= 0.0 && current < 1.0);
        assert!(current != previous_value);
        previous_value = current;
    }

    Ok(())
}

fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ pub fn create_physical_expr(
input_phy_exprs.to_vec(),
data_type,
monotonicity,
fun.signature().type_signature.supports_zero_argument(),
)))
}

Expand Down
18 changes: 6 additions & 12 deletions datafusion/physical-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ pub fn create_physical_expr(
}

Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
let mut physical_args = args
let physical_args = args
.iter()
.map(|e| create_physical_expr(e, input_dfschema, execution_props))
.collect::<Result<Vec<_>>>()?;
Expand All @@ -272,17 +272,11 @@ pub fn create_physical_expr(
execution_props,
)
}
ScalarFunctionDefinition::UDF(fun) => {
// udfs with zero params expect null array as input
if args.is_empty() {
physical_args.push(Arc::new(Literal::new(ScalarValue::Null)));
}
udf::create_physical_expr(
fun.clone().as_ref(),
&physical_args,
input_schema,
)
}
ScalarFunctionDefinition::UDF(fun) => udf::create_physical_expr(
fun.clone().as_ref(),
&physical_args,
input_schema,
),
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
Expand Down
17 changes: 15 additions & 2 deletions datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub struct ScalarFunctionExpr {
// and it specifies the effect of an increase or decrease in
// the corresponding `arg` to the function value.
monotonicity: Option<FuncMonotonicity>,
// Whether this function can be invoked with zero arguments
supports_zero_argument: bool,
}

impl Debug for ScalarFunctionExpr {
Expand All @@ -79,13 +81,15 @@ impl ScalarFunctionExpr {
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
monotonicity: Option<FuncMonotonicity>,
supports_zero_argument: bool,
) -> Self {
Self {
fun,
name: name.to_owned(),
args,
return_type,
monotonicity,
supports_zero_argument,
}
}

Expand Down Expand Up @@ -138,9 +142,12 @@ impl PhysicalExpr for ScalarFunctionExpr {
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
// evaluate the arguments, if there are no arguments we'll instead pass in a null array
// indicating the batch size (as a convention)
let inputs = match (self.args.len(), self.name.parse::<BuiltinScalarFunction>()) {
let inputs = match (
self.args.is_empty(),
self.name.parse::<BuiltinScalarFunction>(),
) {
// MakeArray support zero argument but has the different behavior from the array with one null.
(0, Ok(scalar_fun))
(true, Ok(scalar_fun))
if scalar_fun
.signature()
.type_signature
Expand All @@ -149,6 +156,11 @@ impl PhysicalExpr for ScalarFunctionExpr {
{
vec![ColumnarValue::create_null_array(batch.num_rows())]
}
// If the function supports zero argument, we pass in a null array indicating the batch size.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I never fully understood why this didn't just check self.args.is_empty() 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. Changed to self.args.is_empty().

// This is for user-defined functions.
(true, Err(_)) if self.supports_zero_argument => {
vec![ColumnarValue::create_null_array(batch.num_rows())]
}
_ => self
.args
.iter()
Expand All @@ -175,6 +187,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
children,
self.return_type().clone(),
self.monotonicity.clone(),
self.supports_zero_argument,
)))
}

Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub fn create_physical_expr(
input_phy_exprs.to_vec(),
fun.return_type(&input_exprs_types)?,
fun.monotonicity()?,
fun.signature().type_signature.supports_zero_argument(),
)))
}

Expand Down
19 changes: 8 additions & 11 deletions datafusion/proto/src/physical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,21 +340,17 @@ pub fn parse_physical_expr(
// TODO Do not create new the ExecutionProps
let execution_props = ExecutionProps::new();

let fun_expr = functions::create_physical_fun(
functions::create_physical_expr(
&(&scalar_function).into(),
&args,
input_schema,
&execution_props,
)?;

Arc::new(ScalarFunctionExpr::new(
&e.name,
fun_expr,
args,
convert_required!(e.return_type)?,
None,
))
)?
}
ExprType::ScalarUdf(e) => {
let scalar_fun = registry.udf(e.name.as_str())?.fun().clone();
let udf = registry.udf(e.name.as_str())?;
let signature = udf.signature();
let scalar_fun = udf.fun().clone();

let args = e
.args
Expand All @@ -368,6 +364,7 @@ pub fn parse_physical_expr(
args,
convert_required!(e.return_type)?,
None,
signature.type_signature.supports_zero_argument(),
))
}
ExprType::LikeExpr(like_expr) => Arc::new(LikeExpr::new(
Expand Down
4 changes: 3 additions & 1 deletion datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,9 @@ fn roundtrip_builtin_scalar_function() -> Result<()> {
"acos",
fun_expr,
vec![col("a", &schema)?],
DataType::Int64,
DataType::Float64,
Copy link
Member Author

@viirya viirya Jan 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing test is not correct at all. acos built-in scalar function's return type should be Float64.

Previously the roundtrip test passes because from_proto simply takes serde return type and uses it as parameter to ScalarFunctionExpr.

But in this PR, from_proto calls create_physical_expr which gets return type directly from BuiltinScalarFunction. So with the PR, this test issue is found.

None,
false,
);

let project =
Expand Down Expand Up @@ -617,6 +618,7 @@ fn roundtrip_scalar_udf() -> Result<()> {
vec![col("a", &schema)?],
DataType::Int64,
None,
false,
);

let project =
Expand Down
Loading