-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Don't error in simplify_expressions rule #8957
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,11 +33,10 @@ use arrow::{ | |
}; | ||
use datafusion_common::{ | ||
cast::{as_large_list_array, as_list_array}, | ||
plan_err, | ||
tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}, | ||
}; | ||
use datafusion_common::{ | ||
exec_err, internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, | ||
internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, | ||
}; | ||
use datafusion_expr::{ | ||
and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like, | ||
|
@@ -249,6 +248,14 @@ struct ConstEvaluator<'a> { | |
input_batch: RecordBatch, | ||
} | ||
|
||
/// The simplify result of ConstEvaluator | ||
enum ConstSimplifyResult { | ||
// Expr was simplified and contains the new expression | ||
Simplified(ScalarValue), | ||
// Evaluation encountered an error, contains the original expression | ||
SimplifyRuntimeError(DataFusionError, Expr), | ||
} | ||
|
||
impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { | ||
type N = Expr; | ||
|
||
|
@@ -281,7 +288,17 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { | |
|
||
fn mutate(&mut self, expr: Expr) -> Result<Expr> { | ||
match self.can_evaluate.pop() { | ||
Some(true) => Ok(Expr::Literal(self.evaluate_to_scalar(expr)?)), | ||
// Certain expressions such as `CASE` and `COALESCE` are short circuiting | ||
// and may not evaluate all their sub expressions. Thus, if | ||
// any error is encountered during simplification, return the original | ||
// so that normal evaluation can occur | ||
Some(true) => { | ||
let result = self.evaluate_to_scalar(expr); | ||
match result { | ||
ConstSimplifyResult::Simplified(s) => Ok(Expr::Literal(s)), | ||
ConstSimplifyResult::SimplifyRuntimeError(_, expr) => Ok(expr), | ||
} | ||
} | ||
Some(false) => Ok(expr), | ||
_ => internal_err!("Failed to pop can_evaluate"), | ||
} | ||
|
@@ -376,29 +393,40 @@ impl<'a> ConstEvaluator<'a> { | |
} | ||
|
||
/// Internal helper to evaluate an Expr | ||
pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> Result<ScalarValue> { | ||
pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult { | ||
if let Expr::Literal(s) = expr { | ||
return Ok(s); | ||
return ConstSimplifyResult::Simplified(s); | ||
} | ||
|
||
let phys_expr = | ||
create_physical_expr(&expr, &self.input_schema, self.execution_props)?; | ||
let col_val = phys_expr.evaluate(&self.input_batch)?; | ||
match create_physical_expr(&expr, &self.input_schema, self.execution_props) { | ||
Ok(e) => e, | ||
Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), | ||
}; | ||
let col_val = match phys_expr.evaluate(&self.input_batch) { | ||
Ok(v) => v, | ||
Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), | ||
}; | ||
match col_val { | ||
ColumnarValue::Array(a) => { | ||
if a.len() != 1 { | ||
exec_err!( | ||
"Could not evaluate the expression, found a result of length {}", | ||
a.len() | ||
ConstSimplifyResult::SimplifyRuntimeError( | ||
DataFusionError::Execution(format!("Could not evaluate the expression, found a result of length {}", a.len())), | ||
expr, | ||
) | ||
} else if as_list_array(&a).is_ok() || as_large_list_array(&a).is_ok() { | ||
Ok(ScalarValue::List(a.as_list().to_owned().into())) | ||
ConstSimplifyResult::Simplified(ScalarValue::List( | ||
a.as_list().to_owned().into(), | ||
)) | ||
} else { | ||
// Non-ListArray | ||
ScalarValue::try_from_array(&a, 0) | ||
match ScalarValue::try_from_array(&a, 0) { | ||
Ok(s) => ConstSimplifyResult::Simplified(s), | ||
Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr), | ||
} | ||
} | ||
} | ||
ColumnarValue::Scalar(s) => Ok(s), | ||
ColumnarValue::Scalar(s) => ConstSimplifyResult::Simplified(s), | ||
} | ||
} | ||
} | ||
|
@@ -796,18 +824,6 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { | |
op: Divide, | ||
right, | ||
}) if is_null(&right) => *right, | ||
// A / 0 -> Divide by zero error if A is not null and not floating | ||
// (float / 0 -> inf | -inf | NAN) | ||
Expr::BinaryExpr(BinaryExpr { | ||
haohuaijin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
left, | ||
op: Divide, | ||
right, | ||
}) if !info.nullable(&left)? | ||
&& !info.get_data_type(&left)?.is_floating() | ||
&& is_zero(&right) => | ||
{ | ||
return plan_err!("Divide by zero"); | ||
} | ||
|
||
// | ||
// Rules for Modulo | ||
|
@@ -836,21 +852,6 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { | |
{ | ||
lit(0) | ||
} | ||
// A % 0 --> Divide by zero Error (if A is not floating and not null) | ||
// A % 0 --> NAN (if A is floating and not null) | ||
Expr::BinaryExpr(BinaryExpr { | ||
left, | ||
op: Modulo, | ||
right, | ||
}) if !info.nullable(&left)? && is_zero(&right) => { | ||
match info.get_data_type(&left)? { | ||
DataType::Float32 => lit(f32::NAN), | ||
DataType::Float64 => lit(f64::NAN), | ||
_ => { | ||
return plan_err!("Divide by zero"); | ||
} | ||
} | ||
} | ||
|
||
// | ||
// Rules for BitwiseAnd | ||
|
@@ -1317,9 +1318,7 @@ mod tests { | |
array::{ArrayRef, Int32Array}, | ||
datatypes::{DataType, Field, Schema}, | ||
}; | ||
use datafusion_common::{ | ||
assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, ToDFSchema, | ||
}; | ||
use datafusion_common::{assert_contains, cast::as_int32_array, DFField, ToDFSchema}; | ||
use datafusion_expr::{interval_arithmetic::Interval, *}; | ||
use datafusion_physical_expr::execution_props::ExecutionProps; | ||
|
||
|
@@ -1792,27 +1791,6 @@ mod tests { | |
assert_eq!(simplify(expr), expected); | ||
} | ||
|
||
#[test] | ||
fn test_simplify_divide_zero_by_zero() { | ||
haohuaijin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// 0 / 0 -> Divide by zero | ||
let expr = lit(0) / lit(0); | ||
let err = try_simplify(expr).unwrap_err(); | ||
|
||
let _expected = plan_datafusion_err!("Divide by zero"); | ||
|
||
assert!(matches!(err, ref _expected), "{err}"); | ||
} | ||
|
||
#[test] | ||
fn test_simplify_divide_by_zero() { | ||
// A / 0 -> DivideByZeroError | ||
let expr = col("c2_non_null") / lit(0); | ||
assert_eq!( | ||
try_simplify(expr).unwrap_err().strip_backtrace(), | ||
"Error during planning: Divide by zero" | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_simplify_modulo_by_null() { | ||
let null = lit(ScalarValue::Null); | ||
|
@@ -1837,6 +1815,26 @@ mod tests { | |
assert_eq!(simplify(expr), expected); | ||
} | ||
|
||
#[test] | ||
fn test_simplify_divide_zero_by_zero() { | ||
// Because a divide by 0 may occur in a short-circuit expression, | ||
// we should not simplify this; the error is thrown at runtime instead | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
let expr = lit(0) / lit(0); | ||
let expected = expr.clone(); | ||
|
||
assert_eq!(simplify(expr), expected); | ||
} | ||
|
||
#[test] | ||
fn test_simplify_divide_by_zero() { | ||
// Because a divide by 0 may occur in a short-circuit expression, | ||
// we should not simplify this; the error is thrown at runtime instead | ||
let expr = col("c2_non_null") / lit(0); | ||
let expected = expr.clone(); | ||
|
||
assert_eq!(simplify(expr), expected); | ||
} | ||
|
||
#[test] | ||
fn test_simplify_modulo_by_one_non_null() { | ||
let expr = col("c2_non_null") % lit(1); | ||
|
@@ -2231,11 +2229,12 @@ mod tests { | |
|
||
#[test] | ||
fn test_simplify_modulo_by_zero_non_null() { | ||
// Because a modulo by 0 may occur in a short-circuit expression, | ||
// we should not simplify this; the error is thrown at runtime instead. | ||
let expr = col("c2_non_null") % lit(0); | ||
assert_eq!( | ||
try_simplify(expr).unwrap_err().strip_backtrace(), | ||
"Error during planning: Divide by zero" | ||
); | ||
let expected = expr.clone(); | ||
|
||
assert_eq!(simplify(expr), expected); | ||
} | ||
|
||
#[test] | ||
|
@@ -3385,4 +3384,22 @@ mod tests { | |
let output = simplify_with_guarantee(expr.clone(), guarantees); | ||
assert_eq!(&output, &expr_x); | ||
} | ||
|
||
#[test] | ||
fn test_expression_partial_simplify_1() { | ||
// (1 + 2) + (4 / 0) -> 3 + (4 / 0) | ||
let expr = (lit(1) + lit(2)) + (lit(4) / lit(0)); | ||
let expected = (lit(3)) + (lit(4) / lit(0)); | ||
|
||
assert_eq!(simplify(expr), expected); | ||
} | ||
|
||
#[test] | ||
fn test_expression_partial_simplify_2() { | ||
// (1 > 2) and (4 / 0) -> false | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
let expr = (lit(1).gt(lit(2))).and(lit(4) / lit(0)); | ||
let expected = lit(false); | ||
|
||
assert_eq!(simplify(expr), expected); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,28 +138,6 @@ mod tests { | |
ExprSchemable, JoinType, | ||
}; | ||
|
||
/// A macro to assert that one string is contained within another with | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💯 for reusing the existing one |
||
/// a nice error message if they are not. | ||
/// | ||
/// Usage: `assert_contains!(actual, expected)` | ||
/// | ||
/// Is a macro so test error | ||
/// messages are on the same line as the failure; | ||
/// | ||
/// Both arguments must be convertible into Strings (Into<String>) | ||
macro_rules! assert_contains { | ||
($ACTUAL: expr, $EXPECTED: expr) => { | ||
let actual_value: String = $ACTUAL.into(); | ||
let expected_value: String = $EXPECTED.into(); | ||
assert!( | ||
actual_value.contains(&expected_value), | ||
"Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", | ||
expected_value, | ||
actual_value | ||
); | ||
}; | ||
} | ||
|
||
fn test_table_scan() -> LogicalPlan { | ||
let schema = Schema::new(vec![ | ||
Field::new("a", DataType::Boolean, false), | ||
|
@@ -425,18 +403,6 @@ mod tests { | |
assert_optimized_plan_eq(&plan, expected) | ||
} | ||
|
||
// expect optimizing will result in an error, returning the error string | ||
fn get_optimized_plan_err(plan: &LogicalPlan, date_time: &DateTime<Utc>) -> String { | ||
let config = OptimizerContext::new().with_query_execution_start_time(*date_time); | ||
let rule = SimplifyExpressions::new(); | ||
|
||
let err = rule | ||
.try_optimize(plan, &config) | ||
.expect_err("expected optimization to fail"); | ||
|
||
err.to_string() | ||
} | ||
|
||
fn get_optimized_plan_formatted( | ||
plan: &LogicalPlan, | ||
date_time: &DateTime<Utc>, | ||
|
@@ -468,21 +434,6 @@ mod tests { | |
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn to_timestamp_expr_wrong_arg() -> Result<()> { | ||
let table_scan = test_table_scan(); | ||
let proj = vec![to_timestamp_expr("I'M NOT A TIMESTAMP")]; | ||
let plan = LogicalPlanBuilder::from(table_scan) | ||
.project(proj)? | ||
.build()?; | ||
|
||
let expected = | ||
"Error parsing timestamp from 'I'M NOT A TIMESTAMP': error parsing date"; | ||
let actual = get_optimized_plan_err(&plan, &Utc::now()); | ||
assert_contains!(actual, expected); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn cast_expr() -> Result<()> { | ||
let table_scan = test_table_scan(); | ||
|
@@ -498,20 +449,6 @@ mod tests { | |
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn cast_expr_wrong_arg() -> Result<()> { | ||
let table_scan = test_table_scan(); | ||
let proj = vec![Expr::Cast(Cast::new(Box::new(lit("")), DataType::Int32))]; | ||
let plan = LogicalPlanBuilder::from(table_scan) | ||
.project(proj)? | ||
.build()?; | ||
|
||
let expected = "Cannot cast string '' to value of Int32 type"; | ||
let actual = get_optimized_plan_err(&plan, &Utc::now()); | ||
assert_contains!(actual, expected); | ||
Ok(()) | ||
} | ||
|
||
#[test] | ||
fn multiple_now_expr() -> Result<()> { | ||
let table_scan = test_table_scan(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍