Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add helper function for processing scalar function input #8962

Merged
merged 5 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions datafusion-examples/examples/simple_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::cast::as_float64_array;
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr::functions::columnar_values_to_array;
use std::sync::Arc;

/// create local execution context with an in-memory table:
Expand Down Expand Up @@ -70,22 +71,11 @@ async fn main() -> Result<()> {
// this is guaranteed by DataFusion based on the function's signature.
assert_eq!(args.len(), 2);

// Try to obtain row number
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});

let inferred_length = len.unwrap_or(1);

let arg0 = args[0].clone().into_array(inferred_length)?;
let arg1 = args[1].clone().into_array(inferred_length)?;
let args = columnar_values_to_array(args)?;

// 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics!
let base = as_float64_array(&arg0).expect("cast failed");
let exponent = as_float64_array(&arg1).expect("cast failed");
let base = as_float64_array(&args[0]).expect("cast failed");
let exponent = as_float64_array(&args[1]).expect("cast failed");

// this is guaranteed by DataFusion. We place it just to make it obvious.
assert_eq!(exponent.len(), base.len());
Expand Down
27 changes: 4 additions & 23 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,7 @@ mod tests {
use datafusion_physical_expr::execution_props::ExecutionProps;

use chrono::{DateTime, TimeZone, Utc};
use datafusion_physical_expr::functions::columnar_values_to_array;

// ------------------------------
// --- ExprSimplifier tests -----
Expand Down Expand Up @@ -1437,30 +1438,10 @@ mod tests {
let return_type = Arc::new(DataType::Int32);

let fun = Arc::new(|args: &[ColumnarValue]| {
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});

let inferred_length = len.unwrap_or(1);

let arg0 = match &args[0] {
ColumnarValue::Array(array) => array.clone(),
ColumnarValue::Scalar(scalar) => {
scalar.to_array_of_size(inferred_length).unwrap()
}
};
let arg1 = match &args[1] {
ColumnarValue::Array(array) => array.clone(),
ColumnarValue::Scalar(scalar) => {
scalar.to_array_of_size(inferred_length).unwrap()
}
};
let args = columnar_values_to_array(args)?;

let arg0 = as_int32_array(&arg0)?;
let arg1 = as_int32_array(&arg1)?;
let arg0 = as_int32_array(&args[0])?;
let arg1 = as_int32_array(&args[1])?;

// 2. perform the computation
let array = arg0
Expand Down
54 changes: 50 additions & 4 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use arrow::{
compute::kernels::length::{bit_length, length},
datatypes::{DataType, Int32Type, Int64Type, Schema},
};
use arrow_array::Array;
use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
pub use datafusion_expr::FuncMonotonicity;
use datafusion_expr::{
Expand Down Expand Up @@ -191,6 +192,51 @@ pub(crate) enum Hint {
AcceptsSingular,
}

/// A helper function used to infer the length of arguments of Scalar functions and convert
/// [`ColumnarValue`]s to [`ArrayRef`]s with the inferred length. Note that this function
/// only works for functions that accept either that all arguments are scalars or all arguments
/// are arrays with same length. Otherwise, it will return an error.
pub fn columnar_values_to_array(args: &[ColumnarValue]) -> Result<Vec<ArrayRef>> {
if args.is_empty() {
return Ok(vec![]);
}

let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) if acc.is_none() => Some(1),
ColumnarValue::Scalar(_) => {
if let Some(1) = acc {
acc
} else {
None
}
}
ColumnarValue::Array(a) => {
if let Some(l) = acc {
if l == a.len() {
acc
} else {
None
}
} else {
Some(a.len())
}
}
});
viirya marked this conversation as resolved.
Show resolved Hide resolved

let inferred_length = len.ok_or(DataFusionError::Internal(
alamb marked this conversation as resolved.
Show resolved Hide resolved
"Arguments has mixed length".to_string(),
))?;

let args = args
.iter()
.map(|arg| arg.clone().into_array(inferred_length))
.collect::<Result<Vec<_>>>()?;

Ok(args)
}

/// Decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function
/// and vice-versa after evaluation.
/// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar.
Expand Down Expand Up @@ -559,10 +605,10 @@ pub fn create_physical_fun(
}),
BuiltinScalarFunction::InStr => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(string_expressions::instr::<i32>)(args)
make_scalar_function_inner(string_expressions::instr::<i32>)(args)
}
DataType::LargeUtf8 => {
make_scalar_function(string_expressions::instr::<i64>)(args)
make_scalar_function_inner(string_expressions::instr::<i64>)(args)
}
other => internal_err!("Unsupported data type {other:?} for function instr"),
}),
Expand Down Expand Up @@ -790,10 +836,10 @@ pub fn create_physical_fun(
}),
BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() {
DataType::Utf8 => {
make_scalar_function(string_expressions::ends_with::<i32>)(args)
make_scalar_function_inner(string_expressions::ends_with::<i32>)(args)
}
DataType::LargeUtf8 => {
make_scalar_function(string_expressions::ends_with::<i64>)(args)
make_scalar_function_inner(string_expressions::ends_with::<i64>)(args)
}
other => {
internal_err!("Unsupported data type {other:?} for function ends_with")
Expand Down
11 changes: 5 additions & 6 deletions docs/source/library-user-guide/adding-udfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Int64Array};
use datafusion::common::Result;

use datafusion::common::cast::as_int64_array;
use datafusion::physical_plan::functions::columnar_values_to_array;

pub fn add_one(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn add_one(args: &[ColumnarValue]) -> Result<ArrayRef> {
// Error handling omitted for brevity

let args = columnar_values_to_array(args)?;
let i64s = as_int64_array(&args[0])?;

let new_array = i64s
Expand Down Expand Up @@ -82,7 +82,6 @@ There is a lower level API with more functionality but is more complex, that is

```rust
use datafusion::logical_expr::{Volatility, create_udf};
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::arrow::datatypes::DataType;
use std::sync::Arc;

Expand All @@ -91,13 +90,13 @@ let udf = create_udf(
vec![DataType::Int64],
Arc::new(DataType::Int64),
Volatility::Immutable,
make_scalar_function(add_one),
Arc::new(add_one),
);
```

[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html
[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html
[`make_scalar_function`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html
[`columnar_values_to_array`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.columnar_values_to_array.html
[`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs

A few things to note:
Expand Down
Loading