From ed7c884d64b2750529ec6b256cdbab582341de07 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Tue, 25 Jun 2024 05:33:47 +0900 Subject: [PATCH] Strip table qualifiers from schema in `UNION ALL` for unparser (#11082) * Add test cases for issue * Remove test from logical_plan/builder.rs * Remove table qualifiers for sorts following a Union * Use transform_up and add more documentation on why this is needed --- datafusion/sql/src/unparser/mod.rs | 1 + datafusion/sql/src/unparser/plan.rs | 7 +- datafusion/sql/src/unparser/rewrite.rs | 101 ++++++++++++++++++++++ datafusion/sql/tests/cases/plan_to_sql.rs | 14 ++- 4 files changed, 120 insertions(+), 3 deletions(-) create mode 100644 datafusion/sql/src/unparser/rewrite.rs diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index fb0285901c3f..fbbed4972b17 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -18,6 +18,7 @@ mod ast; mod expr; mod plan; +mod rewrite; mod utils; pub use expr::expr_to_sql; diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index a4a457f51dc9..15137403c582 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -28,6 +28,7 @@ use super::{ BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder, }, + rewrite::normalize_union_schema, utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant}, Unparser, }; @@ -63,6 +64,8 @@ pub fn plan_to_sql(plan: &LogicalPlan) -> Result { impl Unparser<'_> { pub fn plan_to_sql(&self, plan: &LogicalPlan) -> Result { + let plan = normalize_union_schema(plan)?; + match plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) @@ -80,8 +83,8 @@ impl Unparser<'_> { | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) - | LogicalPlan::Distinct(_) => self.select_to_sql_statement(plan), - LogicalPlan::Dml(_) => 
self.dml_to_sql(plan), + | LogicalPlan::Distinct(_) => self.select_to_sql_statement(&plan), + LogicalPlan::Dml(_) => self.dml_to_sql(&plan), LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) | LogicalPlan::Extension(_) diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs new file mode 100644 index 000000000000..a73fce30ced3 --- /dev/null +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion_common::{ + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeIterator}, + Result, +}; +use datafusion_expr::{Expr, LogicalPlan, Sort}; + +/// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions. +/// +/// DataFusion will return an error if two columns in the schema have the same name with no table qualifiers. +/// There are certain types of UNION queries that can result in having two columns with the same name, and the +/// solution was to add table qualifiers to the schema fields. +/// See the related upstream DataFusion issue discussion for more context on this decision.
+/// +/// However, this causes a problem when unparsing these queries back to SQL - as the table qualifier has +/// logically been erased and is no longer a valid reference. +/// +/// The following input SQL: +/// ```sql +/// SELECT table1.foo FROM table1 +/// UNION ALL +/// SELECT table2.foo FROM table2 +/// ORDER BY foo +/// ``` +/// +/// Would be unparsed into the following invalid SQL without this transformation: +/// ```sql +/// SELECT table1.foo FROM table1 +/// UNION ALL +/// SELECT table2.foo FROM table2 +/// ORDER BY table1.foo +/// ``` +/// +/// Which would result in a SQL error, as `table1.foo` is not a valid reference in the context of the UNION. +pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result { + let plan = plan.clone(); + + let transformed_plan = plan.transform_up(|plan| match plan { + LogicalPlan::Union(mut union) => { + let schema = match Arc::try_unwrap(union.schema) { + Ok(inner) => inner, + Err(schema) => (*schema).clone(), + }; + let schema = schema.strip_qualifiers(); + + union.schema = Arc::new(schema); + Ok(Transformed::yes(LogicalPlan::Union(union))) + } + LogicalPlan::Sort(sort) => { + // Only rewrite Sort expressions that have a UNION as their input + if !matches!(&*sort.input, LogicalPlan::Union(_)) { + return Ok(Transformed::no(LogicalPlan::Sort(sort))); + } + + Ok(Transformed::yes(LogicalPlan::Sort(Sort { + expr: rewrite_sort_expr_for_union(sort.expr)?, + input: sort.input, + fetch: sort.fetch, + }))) + } + _ => Ok(Transformed::no(plan)), + }); + transformed_plan.data() +} + +/// Rewrite sort expressions that have a UNION plan as their input to remove the table reference. 
+fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { + let sort_exprs: Vec = exprs + .into_iter() + .map_until_stop_and_collect(|expr| { + expr.transform_up(|expr| { + if let Expr::Column(mut col) = expr { + col.relation = None; + Ok(Transformed::yes(Expr::Column(col))) + } else { + Ok(Transformed::no(expr)) + } + }) + }) + .data()?; + + Ok(sort_exprs) +} diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 33e28e7056b9..374403d853f9 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -230,6 +230,16 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), }, + TestStatementWithDialect { + sql: "SELECT j1_id FROM j1 + UNION ALL + SELECT tb.j2_id as j1_id FROM j2 tb + ORDER BY j1_id + LIMIT 10;", + expected: r#"SELECT j1.j1_id FROM j1 UNION ALL SELECT tb.j2_id AS j1_id FROM j2 AS tb ORDER BY j1_id ASC NULLS LAST LIMIT 10"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, ]; for query in tests { @@ -239,7 +249,9 @@ fn roundtrip_statement_with_dialect() -> Result<()> { let context = MockContextProvider::default(); let sql_to_rel = SqlToRel::new(&context); - let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); + let plan = sql_to_rel + .sql_statement_to_plan(statement) + .unwrap_or_else(|e| panic!("Failed to parse sql: {}\n{e}", query.sql)); let unparser = Unparser::new(&*query.unparser_dialect); let roundtrip_statement = unparser.plan_to_sql(&plan)?;