diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index e86d6172cecd..57a5b200b7e8 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -241,6 +241,8 @@ pub enum BuiltinScalarFunction { NullIf, /// octet_length OctetLength, + /// position + Position, /// random Random, /// regexp_replace @@ -460,6 +462,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::MD5 => Volatility::Immutable, BuiltinScalarFunction::NullIf => Volatility::Immutable, BuiltinScalarFunction::OctetLength => Volatility::Immutable, + BuiltinScalarFunction::Position => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, BuiltinScalarFunction::RegexpReplace => Volatility::Immutable, BuiltinScalarFunction::Repeat => Volatility::Immutable, @@ -735,6 +738,9 @@ impl BuiltinScalarFunction { utf8_to_int_type(&input_expr_types[0], "octet_length") } BuiltinScalarFunction::Pi => Ok(Float64), + BuiltinScalarFunction::Position => { + utf8_to_int_type(&input_expr_types[0], "position") + } BuiltinScalarFunction::Random => Ok(Float64), BuiltinScalarFunction::Uuid => Ok(Utf8), BuiltinScalarFunction::RegexpReplace => { @@ -1225,7 +1231,8 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::EndsWith | BuiltinScalarFunction::InStr | BuiltinScalarFunction::Strpos - | BuiltinScalarFunction::StartsWith => Signature::one_of( + | BuiltinScalarFunction::StartsWith + | BuiltinScalarFunction::Position => Signature::one_of( vec![ Exact(vec![Utf8, Utf8]), Exact(vec![Utf8, LargeUtf8]), @@ -1498,6 +1505,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Reverse => &["reverse"], BuiltinScalarFunction::Right => &["right"], BuiltinScalarFunction::Rpad => &["rpad"], + BuiltinScalarFunction::Position => &["position"], BuiltinScalarFunction::Rtrim => &["rtrim"], BuiltinScalarFunction::SplitPart => &["split_part"], BuiltinScalarFunction::StringToArray => { diff --git a/datafusion/expr/src/expr_fn.rs 
b/datafusion/expr/src/expr_fn.rs index 006b5f10f10d..9294d6e4528d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -814,6 +814,12 @@ scalar_expr!( string, "returns the number of bytes of a string" ); +scalar_expr!( + Position, + position, + substring string, + "return the position of the appearence of `substring` in `string`" +); scalar_expr!(Replace, replace, string from to, "replaces all occurrences of `from` with `to` in the `string`"); scalar_expr!(Repeat, repeat, string n, "repeats the `string` to `n` times"); scalar_expr!(Reverse, reverse, string, "reverses the `string`"); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 2bfdf499123b..ad455531c6de 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -938,6 +938,17 @@ pub fn create_physical_fun( "Unsupported data type {other:?} for function overlay", ))), }), + BuiltinScalarFunction::Position => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function_inner(string_expressions::position::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function_inner(string_expressions::position::)(args) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function position" + ))), + }), BuiltinScalarFunction::Levenshtein => { Arc::new(|args| match args[0].data_type() { DataType::Utf8 => make_scalar_function_inner( diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index d5344773cfbc..be8d769283de 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -570,6 +570,48 @@ pub fn uuid(args: &[ColumnarValue]) -> Result { let array = GenericStringArray::::from_iter_values(values); Ok(ColumnarValue::Array(Arc::new(array))) } +/// position function, similar logic as instr +/// 
position('world' in 'Helloworld') = 6 +pub fn position(args: &[ArrayRef]) -> Result { + let substr_arr = as_generic_string_array::(&args[0])?; + let str_arr = as_generic_string_array::(&args[1])?; + + match args[0].data_type() { + DataType::Utf8 => { + let result = str_arr + .iter() + .zip(substr_arr.iter()) + .map(|(string, substr)| match (string, substr) { + (Some(string), Some(substr)) => string + .find(substr) + .map_or(Some(0), |index| Some((index + 1) as i32)), + _ => None, + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) + } + DataType::LargeUtf8 => { + let result = str_arr + .iter() + .zip(substr_arr.iter()) + .map(|(string, substr)| match (string, substr) { + (Some(string), Some(substr)) => string + .find(substr) + .map_or(Some(0), |index| Some((index + 1) as i64)), + _ => None, + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "position was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." + ) + } + } +} /// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) /// Replaces a substring of string1 with string2 starting at the integer bit @@ -787,4 +829,15 @@ mod tests { Ok(()) } + #[test] + fn to_position() -> Result<()> { + let substr_arr = Arc::new(StringArray::from(vec!["world"])); + let str_arr = Arc::new(StringArray::from(vec!["Hello, world"])); + let res = position::(&[substr_arr, str_arr]).unwrap(); + let result = + as_int32_array(&res).expect("failed to initialized function position"); + let expected = Int32Array::from(vec![8]); + assert_eq!(&expected, result); + Ok(()) + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 66c1271e65c1..54fd54a01d9a 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -671,6 +671,7 @@ enum ScalarFunction { ArrayResize = 130; EndsWith = 131; InStr = 132; + Position = 133; } message ScalarFunctionNode { diff --git 
a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 39a8678ef250..b27f2c8c0f49 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22425,6 +22425,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayResize => "ArrayResize", Self::EndsWith => "EndsWith", Self::InStr => "InStr", + Self::Position => "Position", }; serializer.serialize_str(variant) } @@ -22569,6 +22570,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayResize", "EndsWith", "InStr", + "Position", ]; struct GeneratedVisitor; @@ -22742,6 +22744,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayResize" => Ok(ScalarFunction::ArrayResize), "EndsWith" => Ok(ScalarFunction::EndsWith), "InStr" => Ok(ScalarFunction::InStr), + "Position" => Ok(ScalarFunction::Position), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 7bf1d8ed0450..7832f47b6e10 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2766,6 +2766,7 @@ pub enum ScalarFunction { ArrayResize = 130, EndsWith = 131, InStr = 132, + Position = 133, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2907,6 +2908,7 @@ impl ScalarFunction { ScalarFunction::ArrayResize => "ArrayResize", ScalarFunction::EndsWith => "EndsWith", ScalarFunction::InStr => "InStr", + ScalarFunction::Position => "Position", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -3045,6 +3047,7 @@ impl ScalarFunction { "ArrayResize" => Some(Self::ArrayResize), "EndsWith" => Some(Self::EndsWith), "InStr" => Some(Self::InStr), + "Position" => Some(Self::Position), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 42d39b5c5139..7c5851b84abe 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -61,13 +61,13 @@ use datafusion_expr::{ factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, initcap, instr, isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, - radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, - round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, - sqrt, starts_with, string_to_array, strpos, struct_fun, substr, substr_index, - substring, tan, tanh, to_hex, translate, trim, trunc, upper, uuid, AggregateFunction, - Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, - GetFieldAccess, GetIndexedField, GroupingSet, + lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, position, + power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, + right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, + split_part, sqrt, starts_with, string_to_array, strpos, struct_fun, substr, + substr_index, substring, tan, tanh, to_hex, translate, trim, trunc, upper, uuid, + AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, + Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -534,6 +534,7 @@ impl From<&protobuf::ScalarFunction> for 
BuiltinScalarFunction { ScalarFunction::InStr => Self::InStr, ScalarFunction::Left => Self::Left, ScalarFunction::Lpad => Self::Lpad, + ScalarFunction::Position => Self::Position, ScalarFunction::Random => Self::Random, ScalarFunction::RegexpReplace => Self::RegexpReplace, ScalarFunction::Repeat => Self::Repeat, @@ -1592,6 +1593,10 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::Position => Ok(position( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::Gcd => Ok(gcd( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index dbb52eced36c..64fb145dd213 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1530,6 +1530,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::InStr => Self::InStr, BuiltinScalarFunction::Left => Self::Left, BuiltinScalarFunction::Lpad => Self::Lpad, + BuiltinScalarFunction::Position => Self::Position, BuiltinScalarFunction::Random => Self::Random, BuiltinScalarFunction::Uuid => Self::Uuid, BuiltinScalarFunction::RegexpReplace => Self::RegexpReplace, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 9fded63af3fc..7ca9ab762832 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -514,7 +514,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Struct { values, fields } => { self.parse_struct(values, fields, schema, planner_context) } - + SQLExpr::Position { expr, r#in } => { + self.sql_position_to_expr(*expr, *r#in, schema, planner_context) + } _ => not_impl_err!("Unsupported ast node in sqltorel: {sql:?}"), } } @@ -704,7 +706,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) } - + fn 
sql_position_to_expr( + &self, + substr_expr: SQLExpr, + str_expr: SQLExpr, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let fun = BuiltinScalarFunction::Position; + let substr = + self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?; + let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?; + let args = vec![substr, fullstr]; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + } fn sql_agg_with_filter_to_expr( &self, expr: SQLExpr, diff --git a/datafusion/sqllogictest/test_files/position.slt b/datafusion/sqllogictest/test_files/position.slt new file mode 100644 index 000000000000..c1141ab06943 --- /dev/null +++ b/datafusion/sqllogictest/test_files/position.slt @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# test position in select +query I +select position('world' in 'hello world'); +---- +7 + + + +# test in expression +query I +select 1000 where position('world' in 'hello world') != 100; +---- +1000 + + +# test in expression +query I +select 100000 where position('legend' in 'league of legend') = 11; +---- +100000 + + +# test in expression +query I +select 100000 where position('legend' in 'league of legend') != 11; +---- diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 6c526e3ada75..ad5a0f39378c 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -641,6 +641,7 @@ nullif(expression1, expression2) - [levenshtein](#levenshtein) - [substr_index](#substr_index) - [find_in_set](#find_in_set) +- [position](#position) ### `ascii` @@ -1300,6 +1301,19 @@ regexp_replace(str, regexp, replacement, flags) - **g**: (global) Search globally and don't return after the first match. - **i**: (insensitive) Ignore case when matching. +### `position` + +Returns the position of `substr` in `origstr`, counting from 1; returns 0 if `substr` is not found. + +``` +position(substr in origstr) +``` + +#### Arguments + +- **substr**: The substring to search for. +- **origstr**: The string to search in. + ## Time and Date Functions - [now](#now)