Skip to content

Commit

Permalink
support FixedSizeList type coercion
Browse files Browse the repository at this point in the history
  • Loading branch information
Weijun-H committed Jan 18, 2024
1 parent 78e7d2f commit f169af2
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 21 deletions.
10 changes: 5 additions & 5 deletions datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,9 @@ pub fn arrays_into_list_array(
/// ```
pub fn base_type(data_type: &DataType) -> DataType {
    // Strip one level of list nesting per recursive call. `List`, `LargeList`
    // and `FixedSizeList` are handled by a single arm so that no alternative
    // is shadowed by an earlier, narrower arm (which rustc reports as an
    // unreachable pattern).
    match data_type {
        DataType::List(field)
        | DataType::LargeList(field)
        | DataType::FixedSizeList(field, _) => base_type(field.data_type()),
        // Anything that is not a list is already the innermost element type.
        _ => data_type.to_owned(),
    }
}
Expand All @@ -464,9 +464,9 @@ pub fn coerced_type_with_base_type_only(
base_type: &DataType,
) -> DataType {
match data_type {
DataType::List(field) => {
DataType::List(field) | DataType::FixedSizeList(field, _) => {
let data_type = match field.data_type() {
DataType::List(_) => {
DataType::List(_) | DataType::FixedSizeList(_, _) => {
coerced_type_with_base_type_only(field.data_type(), base_type)
}
_ => base_type.to_owned(),
Expand Down
43 changes: 30 additions & 13 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -949,12 +949,18 @@ impl BuiltinScalarFunction {
}
BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayEmpty => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayElement => {
Signature::array_and_element(self.volatility())
}
BuiltinScalarFunction::ArrayExcept => {
Signature::array_and_element(self.volatility())
}
BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayHasAll
| BuiltinScalarFunction::ArrayHasAny
| BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()),
| BuiltinScalarFunction::ArrayHas => {
Signature::array_and_element(self.volatility())
}
BuiltinScalarFunction::ArrayLength => {
Signature::variadic_any(self.volatility())
}
Expand All @@ -963,15 +969,22 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayPosition => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayPrepend => Signature {
type_signature: ElementAndArray,
volatility: self.volatility(),
},
BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayPositions => {
Signature::array_and_element(self.volatility())
}
BuiltinScalarFunction::ArrayPrepend => {
Signature::element_and_array(self.volatility())
}
BuiltinScalarFunction::ArrayRepeat => {
Signature::array_and_element(self.volatility())
}
BuiltinScalarFunction::ArrayRemove => {
Signature::array_and_element(self.volatility())
}
BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()),
BuiltinScalarFunction::ArrayRemoveAll => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayRemoveAll => {
Signature::array_and_element(self.volatility())
}
BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()),
BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()),
BuiltinScalarFunction::ArrayReplaceAll => {
Expand All @@ -981,8 +994,12 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayToString => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayIntersect => {
Signature::array_and_element(self.volatility())
}
BuiltinScalarFunction::ArrayUnion => {
Signature::array_and_element(self.volatility())
}
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayResize => {
Signature::variadic_any(self.volatility())
Expand Down
14 changes: 14 additions & 0 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,20 @@ impl Signature {
volatility,
}
}
/// Creates a [`Signature`] whose type signature is
/// [`TypeSignature::ArrayAndElement`], intended for `ArrayAppend`-style
/// functions that take the array argument first and the element second.
pub fn array_and_element(volatility: Volatility) -> Self {
    let type_signature = TypeSignature::ArrayAndElement;
    Signature { type_signature, volatility }
}
/// Creates a [`Signature`] whose type signature is
/// [`TypeSignature::ElementAndArray`], intended for `ArrayPrepend`-style
/// functions that take the element argument first and the array second.
pub fn element_and_array(volatility: Volatility) -> Self {
    let type_signature = TypeSignature::ElementAndArray;
    Signature { type_signature, volatility }
}
}

/// Monotonicity of the `ScalarFunctionExpr` with respect to its arguments.
Expand Down
16 changes: 13 additions & 3 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ fn get_valid_types(

// We need to find the coerced base type, mainly for cases like:
// `array_append(List(null), i64)` -> `List(i64)`
let array_base_type = datafusion_common::utils::base_type(array_type);
let elem_base_type = datafusion_common::utils::base_type(elem_type);
let array_base_type = dbg!(datafusion_common::utils::base_type(array_type));
let elem_base_type = dbg!(datafusion_common::utils::base_type(elem_type));
let new_base_type = comparison_coercion(&array_base_type, &elem_base_type);

if new_base_type.is_none() {
Expand All @@ -125,6 +125,14 @@ fn get_valid_types(
Ok(vec![vec![elem_type.to_owned(), array_type.clone()]])
}
}
DataType::FixedSizeList(ref field, _) => {
let elem_type = field.data_type();
if is_append {
Ok(vec![vec![array_type.clone(), elem_type.to_owned()]])
} else {
Ok(vec![vec![elem_type.to_owned(), array_type.clone()]])
}
}
_ => Ok(vec![vec![]]),
}
}
Expand Down Expand Up @@ -161,7 +169,7 @@ fn get_valid_types(

TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
TypeSignature::ArrayAndElement => {
return array_append_or_prepend_valid_types(current_types, true)
return dbg!(array_append_or_prepend_valid_types(current_types, true))
}
TypeSignature::ElementAndArray => {
return array_append_or_prepend_valid_types(current_types, false)
Expand Down Expand Up @@ -311,6 +319,8 @@ fn coerced_from<'a>(
Utf8 | LargeUtf8 => Some(type_into.clone()),
Null if can_cast_types(type_from, type_into) => Some(type_into.clone()),

List(_) if matches!(type_from, FixedSizeList(_, _)) => Some(type_into.clone()),

// Only accept list and largelist with the same number of dimensions unless the type is Null.
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this.
List(_) | LargeList(_)
Expand Down

0 comments on commit f169af2

Please sign in to comment.