Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve numerice_coercion and reuse in both compare op and math op #8385

Closed
wants to merge 11 commits into from
258 changes: 145 additions & 113 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ use datafusion_common::{
exec_datafusion_err, plan_datafusion_err, plan_err, DataFusionError, Result,
};

/// Returns true if this type is Decimal.
fn is_decimal(data_type: &DataType) -> bool {
use DataType::*;
matches!(data_type, Decimal128(_, _) | Decimal256(_, _))
}

/// The type signature of an instantiation of binary operator expression such as
/// `lhs + rhs`
///
Expand Down Expand Up @@ -290,7 +296,8 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
// same type => equality is possible
return Some(lhs_type.clone());
}
comparison_binary_numeric_coercion(lhs_type, rhs_type)

numeric_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| string_coercion(lhs_type, rhs_type))
Expand Down Expand Up @@ -354,73 +361,144 @@ fn string_temporal_coercion(
match_rule(lhs_type, rhs_type).or_else(|| match_rule(rhs_type, lhs_type))
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
/// where one both are numeric
fn comparison_binary_numeric_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
/// Decimal coercion rules for comparison operations, including comparison between decimal and non-decimal types.
fn binary_decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
if !lhs_type.is_numeric() || !rhs_type.is_numeric() {

// At least on should be decimal
if !is_decimal(lhs_type) && !is_decimal(rhs_type) {
return None;
};

// same type => all good
if lhs_type == rhs_type {
return Some(lhs_type.clone());
match (lhs_type, rhs_type) {
// Prefer decimal data type over floating point for comparison operation
(Decimal128(_, _), Decimal128(_, _)) | (Decimal256(_, _), Decimal256(_, _)) => {
get_wider_decimal_type(lhs_type, rhs_type)
}
(decimal_type @ Decimal128(_, _), other_type)
| (other_type, decimal_type @ Decimal128(_, _))
| (decimal_type @ Decimal256(_, _), other_type)
| (other_type, decimal_type @ Decimal256(_, _)) => {
get_comparison_common_decimal_type(decimal_type, other_type)
}
_ => None,
}
}

/// Coerce non decimal numeric types to a common type for the purposes of a comparison operation and math operation
///
/// We tend to find the narrowest type that can represent both inputs if possible,
/// so the return type MAY not be the same as either input type.
///
/// For example, `Int64` and `Float32` will coerce to `Float64`.
///
/// Also, since there might not be perfect type for both inputs, so data lossy is expected.
/// For example, `UInt64` and `Float64` will coerce to `Float64`, so casting `UInt64` to `Float64` will lose data.
fn non_decimal_numeric_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;

// these are ordered from most informative to least informative so
// that the coercion does not lose information via truncation
match (lhs_type, rhs_type) {
// Prefer decimal data type over floating point for comparison operation
(Decimal128(_, _), Decimal128(_, _)) => {
get_wider_decimal_type(lhs_type, rhs_type)
// f64
// Prefer f64 over u64 and i64, data lossy is expected
(Float64, _) | (_, Float64) => Some(Float64),

// u64
// Prefer f64 over u64, data lossy is expected
(UInt64, Float32) | (Float32, UInt64) | (UInt64, Float16) | (Float16, UInt64) => {
Some(Float64)
}
(Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type),
(_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type),
(Decimal256(_, _), Decimal256(_, _)) => {
get_wider_decimal_type(lhs_type, rhs_type)
// Prefer i64 over u64, data lossy is expected
(UInt64, data_type) | (data_type, UInt64) => {
if data_type.is_signed_integer() {
Some(Int64)
} else {
Some(UInt64)
}
}
(Decimal256(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type),
(_, Decimal256(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type),
(Float64, _) | (_, Float64) => Some(Float64),
(_, Float32) | (Float32, _) => Some(Float32),
// The following match arms encode the following logic: Given the two
// integral types, we choose the narrowest possible integral type that
// accommodates all values of both types. Note that some information
// loss is inevitable when we have a signed type and a `UInt64`, in
// which case we use `Int64`;i.e. the widest signed integral type.
(Int64, _)
| (_, Int64)
| (UInt64, Int8)
| (Int8, UInt64)
| (UInt64, Int16)
| (Int16, UInt64)
| (UInt64, Int32)
| (Int32, UInt64)
| (UInt32, Int8)
| (Int8, UInt32)
| (UInt32, Int16)
| (Int16, UInt32)
| (UInt32, Int32)
| (Int32, UInt32) => Some(Int64),
(UInt64, _) | (_, UInt64) => Some(UInt64),
(Int32, _)
| (_, Int32)
| (UInt16, Int16)
| (Int16, UInt16)
| (UInt16, Int8)
| (Int8, UInt16) => Some(Int32),
(UInt32, _) | (_, UInt32) => Some(UInt32),
(Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
(UInt16, _) | (_, UInt16) => Some(UInt16),
(Int8, _) | (_, Int8) => Some(Int8),
(UInt8, _) | (_, UInt8) => Some(UInt8),

// i64
// Prefer f64 over i64, data lossy is expected
(Int64, Float32) | (Float32, Int64) | (Int64, Float16) | (Float16, Int64) => {
Some(Float64)
}
(Int64, _) | (_, Int64) => Some(Int64),

// f32
// f32 is not guaranteed to be able to represent all i32 values
(Float32, UInt32) | (UInt32, Float32) | (Float32, Int32) | (Int32, Float32) => {
Some(Float64)
}
(Float32, _) | (_, Float32) => Some(Float32),

// u32
(UInt32, Float16) | (Float16, UInt32) => Some(Float64),
(UInt32, data_type) | (data_type, UInt32) => {
if data_type.is_signed_integer() {
Some(Int64)
} else {
Some(UInt32)
}
}

// i32
// f32 is not guaranteed to be able to represent all i32 values, so f64 is preferred
(Int32, Float16) | (Float16, Int32) => Some(Float64),
(Int32, _) | (_, Int32) => Some(Int32),

// f16
(Float16, UInt16) | (UInt16, Float16) | (Float16, Int16) | (Int16, Float16) => {
Some(Float32)
}
(Float16, _) | (_, Float16) => Some(Float16),

// u16
(UInt16, data_type) | (data_type, UInt16) => {
if data_type.is_signed_integer() {
Some(Int32)
} else {
Some(UInt16)
}
}

// i16
(Int16, _) | (_, Int16) => Some(Int16),

// u8
(UInt8, UInt8) => Some(UInt8),
(UInt8, Int8) | (Int8, UInt8) => Some(Int16),

// i8
(Int8, Int8) => Some(Int8),

_ => None,
}
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
/// where both are numeric and the coerced type MAY not be the same as either input type.
fn numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
if !lhs_type.is_numeric() || !rhs_type.is_numeric() {
return None;
};

if is_decimal(lhs_type) || is_decimal(rhs_type) {
return binary_decimal_coercion(lhs_type, rhs_type);
};

non_decimal_numeric_coercion(lhs_type, rhs_type)
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
/// where both are numeric and the coerced type SHOULD be one of the input types.
pub fn exact_numeric_coercion(_: &DataType, _: &DataType) -> Option<DataType> {
todo!("Implement this when we have a use case for it")
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of
/// a comparison operation where one is a decimal
fn get_comparison_common_decimal_type(
Expand Down Expand Up @@ -575,17 +653,19 @@ fn mathematics_numerical_coercion(
(_, Dictionary(_, value_type)) => {
mathematics_numerical_coercion(lhs_type, value_type)
}
(Float64, _) | (_, Float64) => Some(Float64),
(_, Float32) | (Float32, _) => Some(Float32),
(Int64, _) | (_, Int64) => Some(Int64),
(Int32, _) | (_, Int32) => Some(Int32),
(Int16, _) | (_, Int16) => Some(Int16),
(Int8, _) | (_, Int8) => Some(Int8),
(UInt64, _) | (_, UInt64) => Some(UInt64),
(UInt32, _) | (_, UInt32) => Some(UInt32),
(UInt16, _) | (_, UInt16) => Some(UInt16),
(UInt8, _) | (_, UInt8) => Some(UInt8),
_ => None,
_ => {
// `math_decimal_coercion` does not handle coercion between Decimal and Float and Uint.
if is_decimal(lhs_type) && is_decimal(rhs_type) {
unreachable!("Should be handled in `math_decimal_coercion`")
} else if is_decimal(lhs_type) {
Some(rhs_type.to_owned())
} else if is_decimal(rhs_type) {
Some(lhs_type.to_owned())
} else {
// Both are non decimal numeric type
non_decimal_numeric_coercion(lhs_type, rhs_type)
}
}
}
}

Expand Down Expand Up @@ -854,16 +934,6 @@ mod tests {
use arrow::datatypes::DataType;
use datafusion_common::{assert_contains, Result};

#[test]
fn test_coercion_error() -> Result<()> {
let result_type =
get_input_types(&DataType::Float32, &Operator::Plus, &DataType::Utf8);

let e = result_type.unwrap_err();
assert_eq!(e.strip_backtrace(), "Error during planning: Cannot coerce arithmetic expression Float32 + Utf8 to valid types");
Ok(())
}

#[test]
fn test_decimal_binary_comparison_coercion() -> Result<()> {
let input_decimal = DataType::Decimal128(20, 3);
Expand Down Expand Up @@ -1222,44 +1292,6 @@ mod tests {
Ok(())
}

#[test]
fn test_type_coercion_arithmetic() -> Result<()> {
// integer
test_coercion_binary_rule!(
DataType::Int32,
DataType::UInt32,
Operator::Plus,
DataType::Int32
);
test_coercion_binary_rule!(
DataType::Int32,
DataType::UInt16,
Operator::Minus,
DataType::Int32
);
test_coercion_binary_rule!(
DataType::Int8,
DataType::Int64,
Operator::Multiply,
DataType::Int64
);
// float
test_coercion_binary_rule!(
DataType::Float32,
DataType::Int32,
Operator::Plus,
DataType::Float32
);
test_coercion_binary_rule!(
DataType::Float32,
DataType::Float64,
Operator::Multiply,
DataType::Float64
);
// TODO add other data type
Ok(())
}

fn test_math_decimal_coercion_rule(
lhs_type: DataType,
rhs_type: DataType,
Expand Down Expand Up @@ -1333,7 +1365,7 @@ mod tests {
DataType::Float32,
DataType::Int64,
Operator::Eq,
DataType::Float32
DataType::Float64
);
test_coercion_binary_rule!(
DataType::Float32,
Expand Down
8 changes: 4 additions & 4 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ mod tests {

let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\
\n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\
\n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
Expand Down Expand Up @@ -511,9 +511,9 @@ mod tests {
.build()?;

// Should work
let expected = "Projection: group_alias_0 AS test.a + Int32(1), COUNT(alias1) AS COUNT(DISTINCT test.c) [test.a + Int32(1):Int32, COUNT(DISTINCT test.c):Int64;N]\
\n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\
\n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\
let expected = "Projection: group_alias_0 AS test.a + Int32(1), COUNT(alias1) AS COUNT(DISTINCT test.c) [test.a + Int32(1):Int64, COUNT(DISTINCT test.c):Int64;N]\
\n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int64, COUNT(alias1):Int64;N]\
\n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int64, alias1:UInt32]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
Expand Down
4 changes: 2 additions & 2 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ fn push_down_filter_groupby_expr_contains_alias() {
let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3";
let plan = test_sql(sql).unwrap();
let expected = "Projection: test.col_int32 + test.col_uint32 AS c, COUNT(*)\
\n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]\
\n Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\
\n Aggregate: groupBy=[[CAST(test.col_int32 AS Int64) + CAST(test.col_uint32 AS Int64)]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]\
\n Filter: CAST(test.col_int32 AS Int64) + CAST(test.col_uint32 AS Int64) > Int64(3)\
\n TableScan: test projection=[col_int32, col_uint32]";
assert_eq!(expected, format!("{plan:?}"));
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -794,9 +794,9 @@ mod tests {
DataType::UInt32,
vec![1u32, 2u32],
Operator::Plus,
Int32Array,
DataType::Int32,
[2i32, 4i32],
Int64Array,
DataType::Int64,
[2i64, 4i64],
);
test_coercion!(
Int32Array,
Expand Down
Loading