diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 674e85a55c92..35450b1f32ff 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -19,8 +19,10 @@ use std::ops::Not; -use super::or_in_list_simplifier::OrInListSimplifier; use super::utils::*; +use super::{ + inlist_simplifier::InListSimplifier, or_in_list_simplifier::OrInListSimplifier, +}; use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; use crate::simplify_expressions::regex::simplify_regex_expr; @@ -133,6 +135,7 @@ impl ExprSimplifier { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut or_in_list_simplifier = OrInListSimplifier::new(); + let mut inlist_simplifier = InListSimplifier::new(); let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); // TODO iterate until no changes are made during rewrite @@ -142,6 +145,7 @@ impl ExprSimplifier { expr.rewrite(&mut const_evaluator)? .rewrite(&mut simplifier)? .rewrite(&mut or_in_list_simplifier)? + .rewrite(&mut inlist_simplifier)? .rewrite(&mut guarantee_rewriter)? // run both passes twice to try an minimize simplifications that we missed .rewrite(&mut const_evaluator)? @@ -3201,11 +3205,118 @@ mod tests { col("c1").eq(subquery1).or(col("c1").eq(subquery2)) ); - // c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) -> - // c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) + // 1. c1 IN (1,2,3,4) AND c1 IN (5,6,7,8) -> false + let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and( + in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false), + ); + assert_eq!(simplify(expr.clone()), lit(false)); + + // 2. 
c1 IN (1,2,3,4) AND c1 IN (4,5,6,7) -> c1 = 4 + let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and( + in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], false), + ); + assert_eq!(simplify(expr.clone()), col("c1").eq(lit(4))); + + // 3. c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) -> true let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or( in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true), ); + assert_eq!(simplify(expr.clone()), lit(true)); + + // 4. c1 NOT IN (1,2,3,4) AND c1 NOT IN (4,5,6,7) -> c1 NOT IN (1,2,3,4,5,6,7) + let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and( + in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true), + ); + assert_eq!( + simplify(expr.clone()), + in_list( + col("c1"), + vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6), lit(7)], + true + ) + ); + + // 5. c1 IN (1,2,3,4) OR c1 IN (2,3,4,5) -> c1 IN (1,2,3,4,5) + let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).or( + in_list(col("c1"), vec![lit(2), lit(3), lit(4), lit(5)], false), + ); + assert_eq!( + simplify(expr.clone()), + in_list( + col("c1"), + vec![lit(1), lit(2), lit(3), lit(4), lit(5)], + false + ) + ); + + // 6. c1 IN (1,2,3) AND c1 NOT IN (1,2,3,4,5) -> false + let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3)], false).and(in_list( + col("c1"), + vec![lit(1), lit(2), lit(3), lit(4), lit(5)], + true, + )); + assert_eq!(simplify(expr.clone()), lit(false)); + + // 7. c1 NOT IN (1,2,3,4) AND c1 IN (1,2,3,4,5) -> c1 = 5 + let expr = + in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and(in_list( + col("c1"), + vec![lit(1), lit(2), lit(3), lit(4), lit(5)], + false, + )); + assert_eq!(simplify(expr.clone()), col("c1").eq(lit(5))); + + // 8. 
c1 IN (1,2,3,4) AND c1 NOT IN (5,6,7,8) -> c1 IN (1,2,3,4) + let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and( + in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true), + ); + assert_eq!( + simplify(expr.clone()), + in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false) + ); + + // inlist with more than two expressions + // c1 IN (1,2,3,4,5,6) AND c1 IN (1,3,5,6) AND c1 IN (3,6) -> c1 = 3 OR c1 = 6 + let expr = in_list( + col("c1"), + vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6)], + false, + ) + .and(in_list( + col("c1"), + vec![lit(1), lit(3), lit(5), lit(6)], + false, + )) + .and(in_list(col("c1"), vec![lit(3), lit(6)], false)); + assert_eq!( + simplify(expr.clone()), + col("c1").eq(lit(3)).or(col("c1").eq(lit(6))) + ); + + // c1 NOT IN (1,2,3,4) AND c1 IN (5,6,7,8) AND c1 NOT IN (3,4,5,6) AND c1 IN (8,9,10) -> c1 = 8 + let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and( + in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false) + .and(in_list( + col("c1"), + vec![lit(3), lit(4), lit(5), lit(6)], + true, + )) + .and(in_list(col("c1"), vec![lit(8), lit(9), lit(10)], false)), + ); + assert_eq!(simplify(expr.clone()), col("c1").eq(lit(8))); + + // Contains non-InList expression + // c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9) -> c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9) + let expr = + in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or(col("c1") + .not_eq(lit(5)) + .or(in_list( + col("c1"), + vec![lit(6), lit(7), lit(8), lit(9)], + true, + ))); + // TODO: Further simplify this expression + // assert_eq!(simplify(expr.clone()), lit(true)); assert_eq!(simplify(expr.clone()), expr); } diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs new file mode 100644 index 000000000000..fa95f1688e6f --- /dev/null +++ 
b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -0,0 +1,136 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module implements a rule that simplifies the values for `InList`s + +use std::collections::HashSet; + +use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::Result; +use datafusion_expr::expr::InList; +use datafusion_expr::{lit, BinaryExpr, Expr, Operator}; + +/// Simplify expressions that are guaranteed to be true or false to a literal boolean expression +/// +/// Rules: +/// If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists +/// Intersection: +/// 1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false` +/// 2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)` +/// 3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)` +/// Union: +/// 4. `a not in (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)` +/// # This rule is handled by `or_in_list_simplifier.rs` +/// 5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)` +/// If one of the expressions is `IN` and another one is `NOT IN`, then we apply set difference (EXCEPT) on the `IN` expression +/// 6. 
`a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false` +/// 7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5` +/// 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` +pub(super) struct InListSimplifier {} + +impl InListSimplifier { + pub(super) fn new() -> Self { + Self {} + } +} + +impl TreeNodeRewriter for InListSimplifier { + type N = Expr; + + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { + if let (Expr::InList(l1), Operator::And, Expr::InList(l2)) = + (left.as_ref(), op, right.as_ref()) + { + if l1.expr == l2.expr && !l1.negated && !l2.negated { + return inlist_intersection(l1, l2, false); + } else if l1.expr == l2.expr && l1.negated && l2.negated { + return inlist_union(l1, l2, true); + } else if l1.expr == l2.expr && !l1.negated && l2.negated { + return inlist_except(l1, l2); + } else if l1.expr == l2.expr && l1.negated && !l2.negated { + return inlist_except(l2, l1); + } + } else if let (Expr::InList(l1), Operator::Or, Expr::InList(l2)) = + (left.as_ref(), op, right.as_ref()) + { + if l1.expr == l2.expr && l1.negated && l2.negated { + return inlist_intersection(l1, l2, true); + } + } + } + + Ok(expr) + } +} + +fn inlist_union(l1: &InList, l2: &InList, negated: bool) -> Result { + let mut seen: HashSet = HashSet::new(); + let list = l1 + .list + .iter() + .chain(l2.list.iter()) + .filter(|&e| seen.insert(e.to_owned())) + .cloned() + .collect::>(); + let merged_inlist = InList { + expr: l1.expr.clone(), + list, + negated, + }; + Ok(Expr::InList(merged_inlist)) +} + +fn inlist_intersection(l1: &InList, l2: &InList, negated: bool) -> Result { + let l1_set: HashSet = l1.list.iter().cloned().collect(); + let intersect_list: Vec = l2 + .list + .iter() + .filter(|x| l1_set.contains(x)) + .cloned() + .collect(); + // e in () is always false + // e not in () is always true + if intersect_list.is_empty() { + return Ok(lit(negated)); + } + let merged_inlist = InList { + expr: 
l1.expr.clone(), + list: intersect_list, + negated, + }; + Ok(Expr::InList(merged_inlist)) +} + +fn inlist_except(l1: &InList, l2: &InList) -> Result { + let l2_set: HashSet = l2.list.iter().cloned().collect(); + let except_list: Vec = l1 + .list + .iter() + .filter(|x| !l2_set.contains(x)) + .cloned() + .collect(); + if except_list.is_empty() { + return Ok(lit(false)); + } + let merged_inlist = InList { + expr: l1.expr.clone(), + list: except_list, + negated: false, + }; + Ok(Expr::InList(merged_inlist)) +} diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index 2cf6ed166cdd..44ba5b3e3b84 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -18,6 +18,7 @@ pub mod context; pub mod expr_simplifier; mod guarantees; +mod inlist_simplifier; mod or_in_list_simplifier; mod regex; pub mod simplify_exprs; diff --git a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs index cebaaccc41c7..fd5c9ecaf82c 100644 --- a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs @@ -18,6 +18,7 @@ //! 
This module implements a rule that simplifies OR expressions into IN list expressions use std::borrow::Cow; +use std::collections::HashSet; use datafusion_common::tree_node::TreeNodeRewriter; use datafusion_common::Result; @@ -52,9 +53,14 @@ impl TreeNodeRewriter for OrInListSimplifier { { let lhs = lhs.into_owned(); let rhs = rhs.into_owned(); - let mut list = vec![]; - list.extend(lhs.list); - list.extend(rhs.list); + let mut seen: HashSet = HashSet::new(); + let list = lhs + .list + .into_iter() + .chain(rhs.list) + .filter(|e| seen.insert(e.to_owned())) + .collect::>(); + let merged_inlist = InList { expr: lhs.expr, list, diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index e32e415338a7..b5347f997a5a 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -725,3 +725,40 @@ AggregateExec: mode=SinglePartitioned, gby=[p_partkey@2 as p_partkey], aggr=[SUM --------CoalesceBatchesExec: target_batch_size=8192 ----------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] + +# Inlist simplification + +statement ok +create table t(x int) as values (1), (2), (3); + +query TT +explain select x from t where x IN (1,2,3) AND x IN (4,5); +---- +logical_plan EmptyRelation +physical_plan EmptyExec + +query TT +explain select x from t where x NOT IN (1,2,3,4) OR x NOT IN (5,6,7,8); +---- +logical_plan TableScan: t projection=[x] +physical_plan MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select x from t where x IN (1,2,3,4,5) AND x NOT IN (1,2,3,4); +---- +logical_plan +Filter: t.x = Int32(5) +--TableScan: t projection=[x] +physical_plan +CoalesceBatchesExec: target_batch_size=8192 +--FilterExec: x@0 = 5 +----MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select x from t where x NOT IN (1,2,3,4,5) AND x IN (1,2,3); +---- 
+logical_plan EmptyRelation +physical_plan EmptyExec + +statement ok +drop table t;