From e086a99c1ee7e0e14048f3d46e1f5e75723fc593 Mon Sep 17 00:00:00 2001 From: ritchie Date: Thu, 3 Oct 2024 11:01:16 +0200 Subject: [PATCH] fix: Only rewrite numeric ineq joins --- .../polars-plan/src/plans/conversion/join.rs | 21 +++++++++++++-- .../src/plans/optimizer/collapse_joins.rs | 26 ++++++++++++++++--- .../unit/operations/test_inequality_join.py | 19 ++++++++++++++ 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index b1577da41a33..60f7fc20f57e 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -316,8 +316,25 @@ fn resolve_join_where( )?; if let Some(ie_op_) = to_inequality_operator(&op) { - // We already have an IEjoin or an Inner join, push to remaining - if ie_op.len() >= 2 || !eq_right_on.is_empty() { + fn is_numeric(e: &Expr, schema: &Schema) -> bool { + expr_to_leaf_column_names_iter(e).any(|name| { + if let Some(dt) = schema.get(name.as_str()) { + dt.to_physical().is_numeric() + } else { + false + } + }) + } + + // We fallback to remaining if: + // - we already have an IEjoin or Inner join + // - we already have an Inner join + // - data is not numeric (our iejoin doesn't yet implement that) + if ie_op.len() >= 2 + || !eq_right_on.is_empty() + || !is_numeric(&left, &schema_left) + || !is_numeric(&right, &schema_right) + { remaining_preds.push(to_binary_post_join(left, op, right, &schema_right, &suffix)) } else { ie_left_on.push(left); diff --git a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs index 5656734969fe..608c7122f2ec 100644 --- a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs +++ b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs @@ -5,7 +5,7 @@ use std::sync::Arc; -use polars_core::schema::SchemaRef; +use polars_core::schema::*; #[cfg(feature = "iejoin")] use polars_ops::frame::{IEJoinOptions, InequalityOperator}; use polars_ops::frame::{JoinCoalesce, JoinType}; @@ -304,8 +304,28 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &mut Arena= 2 || !eq_left_on.is_empty() { + fn is_numeric( + node: Node, + expr_arena: &Arena, + schema: &Schema, + ) -> bool { + aexpr_to_leaf_names_iter(node, expr_arena).any(|name| { + if let Some(dt) = schema.get(name.as_str()) { + dt.to_physical().is_numeric() + } else { + false + } + }) + } + + // We fallback to remaining if: + // - we already have an IEjoin or Inner join + // - we already have an Inner join + // - data is not numeric (our iejoin doesn't yet implement that) + if ie_op.len() >= 2 + || !eq_left_on.is_empty() + || !is_numeric(left, expr_arena, left_schema) + { remaining_predicates.push(node); } else { ie_left_on.push(ExprIR::from_node(left, expr_arena)); diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index 18a24c1d8a7d..872361197a8d 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -575,3 +575,22 @@ def test_raise_invalid_predicate() -> None: with pytest.raises(pl.exceptions.InvalidOperationError): left.join_where(right, pl.col.index >= pl.col.a).collect() + + +def test_join_on_strings() -> None: + df = pl.LazyFrame( + { + "a": ["a", "b", "c"], + "b": ["b", "b", "b"], + } + ) + + q = df.join_where(df, pl.col("a").ge(pl.col("a_right"))) + + assert "CROSS JOIN" in q.explain() + assert q.collect().to_dict(as_series=False) == { + "a": ["a", "b", "b", "c", "c", "c"], + "b": ["b", "b", "b", "b", "b", "b"], + "a_right": ["a", "a", "b", "a", "b", "c"], + "b_right": ["b", "b", "b", "b", "b", "b"], + }