Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Make fields of ScalarUDF non pub #8039

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
}
}
Expr::ScalarUDF(ScalarUDF { fun, .. }) => {
match fun.signature.volatility {
match fun.signature().volatility {
Volatility::Immutable => VisitRecursion::Continue,
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ impl SessionContext {
self.state
.write()
.scalar_functions
.insert(f.name.clone(), Arc::new(f));
.insert(f.name().to_string(), Arc::new(f));
}

/// Registers an aggregate UDF within this context.
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
create_function_physical_name(&func.fun.to_string(), false, &func.args)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
create_function_physical_name(&fun.name, false, args)
create_function_physical_name(fun.name(), false, args)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
create_function_physical_name(&fun.to_string(), false, args)
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1178,7 +1178,7 @@ impl fmt::Display for Expr {
fmt_function(f, &func.fun.to_string(), false, &func.args, true)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
fmt_function(f, &fun.name, false, args, true)
fmt_function(f, fun.name(), false, args, true)
}
Expr::WindowFunction(WindowFunction {
fun,
Expand Down Expand Up @@ -1512,7 +1512,7 @@ fn create_name(e: &Expr) -> Result<String> {
create_function_name(&func.fun.to_string(), false, &func.args)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
create_function_name(&fun.name, false, args)
create_function_name(fun.name(), false, args)
}
Expr::WindowFunction(WindowFunction {
fun,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
Ok(fun.return_type(&data_types)?)
}
Expr::ScalarFunction(ScalarFunction { fun, args }) => {
let data_types = args
Expand Down
45 changes: 36 additions & 9 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,31 @@
// specific language governing permissions and limitations
// under the License.

//! Udf module contains foundational types that are used to represent UDFs in DataFusion.
//! [`ScalarUDF`]: Scalar User Defined Functions

use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::fmt;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::sync::Arc;

/// Logical representation of a UDF.
/// Logical representation of a Scalar User Defined Function.
///
/// A scalar function produces a single row output for each row of input.
///
/// This struct contains the information DataFusion needs to plan and invoke
/// functions such name, type signature, return type, and actual implementation.
///
#[derive(Clone)]
pub struct ScalarUDF {
/// name
pub name: String,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the key change in this PR is to remove pub

/// signature
pub signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
/// The name of the function
name: String,
/// The signature (the types of arguments that are supported)
signature: Signature,
/// Function that returns the return type given the argument types
return_type: ReturnTypeFunction,
/// actual implementation
///
/// The fn param is the wrapped function but be aware that the function will
Expand All @@ -40,7 +48,7 @@ pub struct ScalarUDF {
/// will be passed. In that case the single element is a null array to indicate
/// the batch's row count (so that the generative zero-argument function can know
/// the result array size).
pub fun: ScalarFunctionImplementation,
fun: ScalarFunctionImplementation,
}

impl Debug for ScalarUDF {
Expand Down Expand Up @@ -89,4 +97,23 @@ impl ScalarUDF {
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()), args))
}

/// Returns this function's name
pub fn name(&self) -> &str {
&self.name
}
/// Returns this function's signature
pub fn signature(&self) -> &Signature {
&self.signature
}
/// return the return type of this function given the types of the arguments
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than expose the underlying function pointer directly, I opted to handle the nuance of calling it here.

// Old API returns an Arc of the datatype for some reason
let res = (self.return_type)(args)?;
Ok(res.as_ref().clone())
}
/// return the implementation of this function
pub fn fun(&self) -> &ScalarFunctionImplementation {
&self.fun
}
}
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
fun.signature(),
)?;
Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr)))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ impl<'a> ConstEvaluator<'a> {
Self::volatility_ok(fun.volatility())
}
Expr::ScalarUDF(expr::ScalarUDF { fun, .. }) => {
Self::volatility_ok(fun.signature.volatility)
Self::volatility_ok(fun.signature().volatility)
}
Expr::Literal(_)
| Expr::BinaryExpr { .. }
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ pub fn create_physical_expr(
&format!("{fun}"),
fun_expr,
input_phy_exprs.to_vec(),
&data_type,
data_type,
monotonicity,
)))
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ impl ScalarFunctionExpr {
name: &str,
fun: ScalarFunctionImplementation,
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: &DataType,
return_type: DataType,
monotonicity: Option<FuncMonotonicity>,
) -> Self {
Self {
fun,
name: name.to_owned(),
args,
return_type: return_type.clone(),
return_type,
monotonicity,
}
}
Expand Down Expand Up @@ -165,7 +165,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
&self.name,
self.fun.clone(),
children,
self.return_type(),
self.return_type().clone(),
self.monotonicity.clone(),
)))
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/physical-expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ pub fn create_physical_expr(
.collect::<Result<Vec<_>>>()?;

Ok(Arc::new(ScalarFunctionExpr::new(
&fun.name,
fun.fun.clone(),
fun.name(),
fun.fun().clone(),
input_phy_exprs.to_vec(),
(fun.return_type)(&input_exprs_types)?.as_ref(),
fun.return_type(&input_exprs_types)?,
None,
)))
}
2 changes: 1 addition & 1 deletion datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => Self {
expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
fun_name: fun.name.clone(),
fun_name: fun.name().to_string(),
args: args
.iter()
.map(|expr| expr.try_into())
Expand Down
7 changes: 3 additions & 4 deletions datafusion/proto/src/physical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
//! Serde code to convert from protocol buffers to Rust data structures.

use std::convert::{TryFrom, TryInto};
use std::ops::Deref;
use std::sync::Arc;

use arrow::compute::SortOptions;
Expand Down Expand Up @@ -308,12 +307,12 @@ pub fn parse_physical_expr(
&e.name,
fun_expr,
args,
&convert_required!(e.return_type)?,
convert_required!(e.return_type)?,
None,
))
}
ExprType::ScalarUdf(e) => {
let scalar_fun = registry.udf(e.name.as_str())?.deref().clone().fun;
let scalar_fun = registry.udf(e.name.as_str())?.fun().clone();

let args = e
.args
Expand All @@ -325,7 +324,7 @@ pub fn parse_physical_expr(
e.name.as_str(),
scalar_fun,
args,
&convert_required!(e.return_type)?,
convert_required!(e.return_type)?,
None,
))
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ fn roundtrip_builtin_scalar_function() -> Result<()> {
"acos",
fun_expr,
vec![col("a", &schema)?],
&DataType::Int64,
DataType::Int64,
None,
);

Expand Down Expand Up @@ -549,7 +549,7 @@ fn roundtrip_scalar_udf() -> Result<()> {
"dummy",
scalar_fn,
vec![col("a", &schema)?],
&DataType::Int64,
DataType::Int64,
None,
);

Expand Down