Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't error in simplify_expressions rule #8957

Merged
merged 3 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 85 additions & 68 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@ use arrow::{
};
use datafusion_common::{
cast::{as_large_list_array, as_list_array},
plan_err,
tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter},
};
use datafusion_common::{
exec_err, internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::{
and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr, Like,
Expand Down Expand Up @@ -249,6 +248,14 @@ struct ConstEvaluator<'a> {
input_batch: RecordBatch,
}

/// The outcome of attempting to evaluate an expression in `ConstEvaluator`.
enum ConstSimplifyResult {
/// The expression was simplified; holds the resulting scalar value.
Simplified(ScalarValue),
/// Evaluation encountered an error; holds that error together with the
/// original expression.
SimplifyRuntimeError(DataFusionError, Expr),
}

impl<'a> TreeNodeRewriter for ConstEvaluator<'a> {
type N = Expr;

Expand Down Expand Up @@ -281,7 +288,17 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> {

fn mutate(&mut self, expr: Expr) -> Result<Expr> {
match self.can_evaluate.pop() {
Some(true) => Ok(Expr::Literal(self.evaluate_to_scalar(expr)?)),
// Certain expressions such as `CASE` and `COALESCE` are short-circuiting
// and may not evaluate all of their sub-expressions. Thus, if any error
// is encountered during simplification, return the original expression
// so that normal evaluation can occur
Comment on lines +291 to +294
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Some(true) => {
let result = self.evaluate_to_scalar(expr);
match result {
ConstSimplifyResult::Simplified(s) => Ok(Expr::Literal(s)),
ConstSimplifyResult::SimplifyRuntimeError(_, expr) => Ok(expr),
}
}
Some(false) => Ok(expr),
_ => internal_err!("Failed to pop can_evaluate"),
}
Expand Down Expand Up @@ -376,29 +393,40 @@ impl<'a> ConstEvaluator<'a> {
}

/// Internal helper to evaluate an Expr
pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> Result<ScalarValue> {
pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult {
if let Expr::Literal(s) = expr {
return Ok(s);
return ConstSimplifyResult::Simplified(s);
}

let phys_expr =
create_physical_expr(&expr, &self.input_schema, self.execution_props)?;
let col_val = phys_expr.evaluate(&self.input_batch)?;
match create_physical_expr(&expr, &self.input_schema, self.execution_props) {
Ok(e) => e,
Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
};
let col_val = match phys_expr.evaluate(&self.input_batch) {
Ok(v) => v,
Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr),
};
match col_val {
ColumnarValue::Array(a) => {
if a.len() != 1 {
exec_err!(
"Could not evaluate the expression, found a result of length {}",
a.len()
ConstSimplifyResult::SimplifyRuntimeError(
DataFusionError::Execution(format!("Could not evaluate the expression, found a result of length {}", a.len())),
expr,
)
} else if as_list_array(&a).is_ok() || as_large_list_array(&a).is_ok() {
Ok(ScalarValue::List(a.as_list().to_owned().into()))
ConstSimplifyResult::Simplified(ScalarValue::List(
a.as_list().to_owned().into(),
))
} else {
// Non-ListArray
ScalarValue::try_from_array(&a, 0)
match ScalarValue::try_from_array(&a, 0) {
Ok(s) => ConstSimplifyResult::Simplified(s),
Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr),
}
}
}
ColumnarValue::Scalar(s) => Ok(s),
ColumnarValue::Scalar(s) => ConstSimplifyResult::Simplified(s),
}
}
}
Expand Down Expand Up @@ -796,18 +824,6 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
op: Divide,
right,
}) if is_null(&right) => *right,
// A / 0 -> Divide by zero error if A is not null and not floating
// (float / 0 -> inf | -inf | NAN)
Expr::BinaryExpr(BinaryExpr {
haohuaijin marked this conversation as resolved.
Show resolved Hide resolved
left,
op: Divide,
right,
}) if !info.nullable(&left)?
&& !info.get_data_type(&left)?.is_floating()
&& is_zero(&right) =>
{
return plan_err!("Divide by zero");
}

//
// Rules for Modulo
Expand Down Expand Up @@ -836,21 +852,6 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
{
lit(0)
}
// A % 0 --> Divide by zero Error (if A is not floating and not null)
// A % 0 --> NAN (if A is floating and not null)
Expr::BinaryExpr(BinaryExpr {
left,
op: Modulo,
right,
}) if !info.nullable(&left)? && is_zero(&right) => {
match info.get_data_type(&left)? {
DataType::Float32 => lit(f32::NAN),
DataType::Float64 => lit(f64::NAN),
_ => {
return plan_err!("Divide by zero");
}
}
}

//
// Rules for BitwiseAnd
Expand Down Expand Up @@ -1317,9 +1318,7 @@ mod tests {
array::{ArrayRef, Int32Array},
datatypes::{DataType, Field, Schema},
};
use datafusion_common::{
assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, ToDFSchema,
};
use datafusion_common::{assert_contains, cast::as_int32_array, DFField, ToDFSchema};
use datafusion_expr::{interval_arithmetic::Interval, *};
use datafusion_physical_expr::execution_props::ExecutionProps;

Expand Down Expand Up @@ -1792,27 +1791,6 @@ mod tests {
assert_eq!(simplify(expr), expected);
}

#[test]
fn test_simplify_divide_zero_by_zero() {
haohuaijin marked this conversation as resolved.
Show resolved Hide resolved
// 0 / 0 -> Divide by zero
let expr = lit(0) / lit(0);
let err = try_simplify(expr).unwrap_err();

let _expected = plan_datafusion_err!("Divide by zero");

assert!(matches!(err, ref _expected), "{err}");
}

#[test]
fn test_simplify_divide_by_zero() {
// A / 0 -> DivideByZeroError
let expr = col("c2_non_null") / lit(0);
assert_eq!(
try_simplify(expr).unwrap_err().strip_backtrace(),
"Error during planning: Divide by zero"
);
}

#[test]
fn test_simplify_modulo_by_null() {
let null = lit(ScalarValue::Null);
Expand All @@ -1837,6 +1815,26 @@ mod tests {
assert_eq!(simplify(expr), expected);
}

#[test]
fn test_simplify_divide_zero_by_zero() {
// Division by zero may occur inside a short-circuiting expression,
// so we should not simplify it; the error is left to be raised at runtime.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

let expr = lit(0) / lit(0);
let expected = expr.clone();

assert_eq!(simplify(expr), expected);
}

#[test]
fn test_simplify_divide_by_zero() {
// Division by zero may occur inside a short-circuiting expression,
// so the simplifier should leave `A / 0` unchanged rather than fail;
// the error is left to be raised at runtime.
let expr = col("c2_non_null") / lit(0);
let expected = expr.clone();

assert_eq!(simplify(expr), expected);
}

#[test]
fn test_simplify_modulo_by_one_non_null() {
let expr = col("c2_non_null") % lit(1);
Expand Down Expand Up @@ -2231,11 +2229,12 @@ mod tests {

#[test]
fn test_simplify_modulo_by_zero_non_null() {
// Modulo by zero may occur inside a short-circuiting expression,
// so we should not simplify it; the error is left to be raised at runtime.
let expr = col("c2_non_null") % lit(0);
assert_eq!(
try_simplify(expr).unwrap_err().strip_backtrace(),
"Error during planning: Divide by zero"
);
let expected = expr.clone();

assert_eq!(simplify(expr), expected);
}

#[test]
Expand Down Expand Up @@ -3385,4 +3384,22 @@ mod tests {
let output = simplify_with_guarantee(expr.clone(), guarantees);
assert_eq!(&output, &expr_x);
}

#[test]
fn test_expression_partial_simplify_1() {
// (1 + 2) + (4 / 0) -> 3 + (4 / 0)
// The constant sub-expression `1 + 2` is expected to be folded to `3`,
// while the `4 / 0` sub-expression is expected to be left untouched.
let expr = (lit(1) + lit(2)) + (lit(4) / lit(0));
let expected = (lit(3)) + (lit(4) / lit(0));

assert_eq!(simplify(expr), expected);
}

#[test]
fn test_expression_partial_simplify_2() {
// (1 > 2) and (4 / 0) -> false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

let expr = (lit(1).gt(lit(2))).and(lit(4) / lit(0));
let expected = lit(false);

assert_eq!(simplify(expr), expected);
}
}
63 changes: 0 additions & 63 deletions datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,28 +138,6 @@ mod tests {
ExprSchemable, JoinType,
};

/// A macro to assert that one string is contained within another with
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯 for reusing the existing one

/// a nice error message if they are not.
///
/// Usage: `assert_contains!(actual, expected)`
///
/// Is a macro so test error
/// messages are on the same line as the failure;
///
/// Both arguments must be convertable into Strings (Into<String>)
macro_rules! assert_contains {
($ACTUAL: expr, $EXPECTED: expr) => {
let actual_value: String = $ACTUAL.into();
let expected_value: String = $EXPECTED.into();
assert!(
actual_value.contains(&expected_value),
"Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}",
expected_value,
actual_value
);
};
}

fn test_table_scan() -> LogicalPlan {
let schema = Schema::new(vec![
Field::new("a", DataType::Boolean, false),
Expand Down Expand Up @@ -425,18 +403,6 @@ mod tests {
assert_optimized_plan_eq(&plan, expected)
}

// expect optimizing will result in an error, returning the error string
fn get_optimized_plan_err(plan: &LogicalPlan, date_time: &DateTime<Utc>) -> String {
let config = OptimizerContext::new().with_query_execution_start_time(*date_time);
let rule = SimplifyExpressions::new();

let err = rule
.try_optimize(plan, &config)
.expect_err("expected optimization to fail");

err.to_string()
}

fn get_optimized_plan_formatted(
plan: &LogicalPlan,
date_time: &DateTime<Utc>,
Expand Down Expand Up @@ -468,21 +434,6 @@ mod tests {
Ok(())
}

#[test]
fn to_timestamp_expr_wrong_arg() -> Result<()> {
let table_scan = test_table_scan();
let proj = vec![to_timestamp_expr("I'M NOT A TIMESTAMP")];
let plan = LogicalPlanBuilder::from(table_scan)
.project(proj)?
.build()?;

let expected =
"Error parsing timestamp from 'I'M NOT A TIMESTAMP': error parsing date";
let actual = get_optimized_plan_err(&plan, &Utc::now());
assert_contains!(actual, expected);
Ok(())
}

#[test]
fn cast_expr() -> Result<()> {
let table_scan = test_table_scan();
Expand All @@ -498,20 +449,6 @@ mod tests {
Ok(())
}

#[test]
fn cast_expr_wrong_arg() -> Result<()> {
let table_scan = test_table_scan();
let proj = vec![Expr::Cast(Cast::new(Box::new(lit("")), DataType::Int32))];
let plan = LogicalPlanBuilder::from(table_scan)
.project(proj)?
.build()?;

let expected = "Cannot cast string '' to value of Int32 type";
let actual = get_optimized_plan_err(&plan, &Utc::now());
assert_contains!(actual, expected);
Ok(())
}

#[test]
fn multiple_now_expr() -> Result<()> {
let table_scan = test_table_scan();
Expand Down
10 changes: 10 additions & 0 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ impl CaseExpr {
// Make sure we only consider rows that have not been matched yet
let when_match = and(&when_match, &remainder)?;

// When no rows match this WHEN clause, skip evaluating its THEN clause
if when_match.true_count() == 0 {
continue;
}

haohuaijin marked this conversation as resolved.
Show resolved Hide resolved
let then_value = self.when_then_expr[i]
.1
.evaluate_selection(batch, &when_match)?;
Expand Down Expand Up @@ -214,6 +219,11 @@ impl CaseExpr {
// Make sure we only consider rows that have not been matched yet
let when_value = and(&when_value, &remainder)?;

// When no rows match this WHEN clause, skip evaluating its THEN clause
if when_value.true_count() == 0 {
continue;
}

let then_value = self.when_then_expr[i]
.1
.evaluate_selection(batch, &when_value)?;
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/arrow_typeof.slt
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ select arrow_cast([1], 'FixedSizeList(1, Int64)');
----
[1]

query error DataFusion error: Optimizer rule 'simplify_expressions' failed
query error DataFusion error: Arrow error: Cast error: Cannot cast to FixedSizeList\(4\): value at index 0 has length 3
select arrow_cast(make_array(1, 2, 3), 'FixedSizeList(4, Int64)');

query ?
Expand All @@ -421,4 +421,4 @@ FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0
query ?
select arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)');
----
[1, 2, 3]
[1, 2, 3]
Loading