Skip to content

Commit

Permalink
Introduce Signature::String and return error if input of strpos is …
Browse files Browse the repository at this point in the history
…integer (#12751)

* fix sig

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix error

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix all signature

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix all signature

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* change default type

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* clippy

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix docs

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm deadcode

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* cleanup

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm test

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 authored Oct 8, 2024
1 parent d8405ba commit 47664df
Show file tree
Hide file tree
Showing 31 changed files with 184 additions and 384 deletions.
8 changes: 4 additions & 4 deletions datafusion/core/tests/expr_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ mod simplification;
fn test_octet_length() {
#[rustfmt::skip]
evaluate_expr_test(
octet_length(col("list")),
octet_length(col("id")),
vec![
"+------+",
"| expr |",
"+------+",
"| 5 |",
"| 18 |",
"| 6 |",
"| 1 |",
"| 1 |",
"| 1 |",
"+------+",
],
);
Expand Down
18 changes: 17 additions & 1 deletion datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ pub enum TypeSignature {
/// Fixed number of arguments of numeric types.
/// See <https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric> to know which type is considered numeric
Numeric(usize),
/// Fixed number of arguments of all the same string types.
/// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8.
/// Null is considerd as Utf8 by default
/// Dictionary with string value type is also handled.
String(usize),
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
Expand Down Expand Up @@ -190,8 +195,11 @@ impl TypeSignature {
.collect::<Vec<String>>()
.join(", ")]
}
TypeSignature::String(num) => {
vec![format!("String({num})")]
}
TypeSignature::Numeric(num) => {
vec![format!("Numeric({})", num)]
vec![format!("Numeric({num})")]
}
TypeSignature::Exact(types) | TypeSignature::Coercible(types) => {
vec![Self::join_types(types, ", ")]
Expand Down Expand Up @@ -280,6 +288,14 @@ impl Signature {
}
}

/// A specified number of numeric arguments
pub fn string(arg_count: usize, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::String(arg_count),
volatility,
}
}

/// An arbitrary number of arguments of any type.
pub fn variadic_any(volatility: Volatility) -> Self {
Self {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,7 @@ fn string_concat_internal_coercion(
/// based on the observation that StringArray to StringViewArray is cheap but not vice versa.
///
/// Between Utf8 and LargeUtf8, we coerce to LargeUtf8.
fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
// If Utf8View is in any side, we coerce to Utf8View.
Expand Down
67 changes: 65 additions & 2 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ use datafusion_common::{
utils::{coerced_fixed_size_list_to_list, list_ndims},
Result,
};
use datafusion_expr_common::signature::{
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
use datafusion_expr_common::{
signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
type_coercion::binary::string_coercion,
};
use std::sync::Arc;

Expand Down Expand Up @@ -176,6 +177,7 @@ fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
type_signature,
TypeSignature::UserDefined
| TypeSignature::Numeric(_)
| TypeSignature::String(_)
| TypeSignature::Coercible(_)
| TypeSignature::Any(_)
)
Expand Down Expand Up @@ -381,6 +383,67 @@ fn get_valid_types(
.iter()
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::String(number) => {
if *number < 1 {
return plan_err!(
"The signature expected at least one argument but received {}",
current_types.len()
);
}
if *number != current_types.len() {
return plan_err!(
"The signature expected {} arguments but received {}",
number,
current_types.len()
);
}

fn coercion_rule(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Result<DataType> {
match (lhs_type, rhs_type) {
(DataType::Null, DataType::Null) => Ok(DataType::Utf8),
(DataType::Null, data_type) | (data_type, DataType::Null) => {
coercion_rule(data_type, &DataType::Utf8)
}
(DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
coercion_rule(lhs, rhs)
}
(DataType::Dictionary(_, v), other)
| (other, DataType::Dictionary(_, v)) => coercion_rule(v, other),
_ => {
if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
Ok(coerced_type)
} else {
plan_err!(
"{} and {} are not coercible to a common string type",
lhs_type,
rhs_type
)
}
}
}
}

// Length checked above, safe to unwrap
let mut coerced_type = current_types.first().unwrap().to_owned();
for t in current_types.iter().skip(1) {
coerced_type = coercion_rule(&coerced_type, t)?;
}

fn base_type_or_default_type(data_type: &DataType) -> DataType {
if data_type.is_null() {
DataType::Utf8
} else if let DataType::Dictionary(_, v) = data_type {
base_type_or_default_type(v)
} else {
data_type.to_owned()
}
}

vec![vec![base_type_or_default_type(&coerced_type); *number]]
}
TypeSignature::Numeric(number) => {
if *number < 1 {
return plan_err!(
Expand Down
6 changes: 3 additions & 3 deletions datafusion/functions/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ macro_rules! make_math_binary_udf {
use arrow::datatypes::DataType;
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::TypeSignature;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

#[derive(Debug)]
Expand All @@ -298,8 +298,8 @@ macro_rules! make_math_binary_udf {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Float32, Float32]),
Exact(vec![Float64, Float64]),
TypeSignature::Exact(vec![Float32, Float32]),
TypeSignature::Exact(vec![Float64, Float64]),
],
Volatility::Immutable,
),
Expand Down
8 changes: 5 additions & 3 deletions datafusion/functions/src/math/nans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@

use arrow::datatypes::DataType;
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use datafusion_expr::{ColumnarValue, TypeSignature};

use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;
Expand All @@ -43,7 +42,10 @@ impl IsNanFunc {
use DataType::*;
Self {
signature: Signature::one_of(
vec![Exact(vec![Float32]), Exact(vec![Float64])],
vec![
TypeSignature::Exact(vec![Float32]),
TypeSignature::Exact(vec![Float64]),
],
Volatility::Immutable,
),
}
Expand Down
8 changes: 5 additions & 3 deletions datafusion/functions/src/math/power.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ use datafusion_common::{
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{ColumnarValue, Expr, ScalarUDF};
use datafusion_expr::{ColumnarValue, Expr, ScalarUDF, TypeSignature};

use arrow::array::{ArrayRef, Float64Array, Int64Array};
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;
Expand All @@ -52,7 +51,10 @@ impl PowerFunc {
use DataType::*;
Self {
signature: Signature::one_of(
vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])],
vec![
TypeSignature::Exact(vec![Int64, Int64]),
TypeSignature::Exact(vec![Float64, Float64]),
],
Volatility::Immutable,
),
aliases: vec![String::from("pow")],
Expand Down
11 changes: 5 additions & 6 deletions datafusion/functions/src/regex/regexplike.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ use datafusion_common::{
cast::as_generic_string_array, internal_err, DataFusionError, Result,
};
use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX;
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ColumnarValue, Documentation};
use datafusion_expr::{ColumnarValue, Documentation, TypeSignature};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::{Arc, OnceLock};
Expand Down Expand Up @@ -87,10 +86,10 @@ impl RegexpLikeFunc {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
Exact(vec![Utf8, Utf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
TypeSignature::Exact(vec![Utf8, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
],
Volatility::Immutable,
),
Expand Down
11 changes: 5 additions & 6 deletions datafusion/functions/src/regex/regexpmatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ use datafusion_common::{arrow_datafusion_err, plan_err};
use datafusion_common::{
cast::as_generic_string_array, internal_err, DataFusionError, Result,
};
use datafusion_expr::ColumnarValue;
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ColumnarValue, TypeSignature};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;
Expand All @@ -53,10 +52,10 @@ impl RegexpMatchFunc {
// For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8, Utf8)`.
// If that fails, it proceeds to `(LargeUtf8, Utf8)`.
// TODO: Native support Utf8View for regexp_match.
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
Exact(vec![Utf8, Utf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
TypeSignature::Exact(vec![Utf8, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
],
Volatility::Immutable,
),
Expand Down
10 changes: 5 additions & 5 deletions datafusion/functions/src/regex/regexpreplace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use datafusion_common::{
};
use datafusion_expr::function::Hint;
use datafusion_expr::ColumnarValue;
use datafusion_expr::TypeSignature::*;
use datafusion_expr::TypeSignature;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use regex::Regex;
use std::any::Any;
Expand All @@ -56,10 +56,10 @@ impl RegexpReplaceFunc {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Utf8, Utf8]),
Exact(vec![Utf8View, Utf8, Utf8]),
Exact(vec![Utf8, Utf8, Utf8, Utf8]),
Exact(vec![Utf8View, Utf8, Utf8, Utf8]),
TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]),
TypeSignature::Exact(vec![Utf8, Utf8, Utf8, Utf8]),
TypeSignature::Exact(vec![Utf8View, Utf8, Utf8, Utf8]),
],
Volatility::Immutable,
),
Expand Down
7 changes: 1 addition & 6 deletions datafusion/functions/src/string/ascii.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,8 @@ impl Default for AsciiFunc {

impl AsciiFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8, Utf8View],
Volatility::Immutable,
),
signature: Signature::string(1, Volatility::Immutable),
}
}
}
Expand Down
7 changes: 1 addition & 6 deletions datafusion/functions/src/string/bit_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,8 @@ impl Default for BitLengthFunc {

impl BitLengthFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8],
Volatility::Immutable,
),
signature: Signature::string(1, Volatility::Immutable),
}
}
}
Expand Down
17 changes: 4 additions & 13 deletions datafusion/functions/src/string/btrim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ use arrow::datatypes::DataType;
use datafusion_common::{exec_err, Result};
use datafusion_expr::function::Hint;
use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use std::any::Any;
use std::sync::OnceLock;

Expand All @@ -49,18 +49,9 @@ impl Default for BTrimFunc {

impl BTrimFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
// Planner attempts coercion to the target type starting with the most preferred candidate.
// For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
// If that fails, it proceeds to `(Utf8, Utf8)`.
Exact(vec![Utf8View, Utf8View]),
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8View]),
Exact(vec![Utf8]),
],
vec![TypeSignature::String(2), TypeSignature::String(1)],
Volatility::Immutable,
),
aliases: vec![String::from("trim")],
Expand Down
Loading

0 comments on commit 47664df

Please sign in to comment.