Skip to content

Commit

Permalink
Strip table qualifiers from schema in UNION ALL for unparser (#11082)
Browse files Browse the repository at this point in the history
* Add test cases for issue

* Remove test from logical_plan/builder.rs

* Remove table qualifiers for sorts following a Union

* Use transform_up and add more documentation on why this is needed
  • Loading branch information
phillipleblanc authored Jun 24, 2024
1 parent 528c4ab commit ed7c884
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 3 deletions.
1 change: 1 addition & 0 deletions datafusion/sql/src/unparser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
mod ast;
mod expr;
mod plan;
mod rewrite;
mod utils;

pub use expr::expr_to_sql;
Expand Down
7 changes: 5 additions & 2 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use super::{
BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder,
SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
},
rewrite::normalize_union_schema,
utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
Unparser,
};
Expand Down Expand Up @@ -63,6 +64,8 @@ pub fn plan_to_sql(plan: &LogicalPlan) -> Result<ast::Statement> {

impl Unparser<'_> {
pub fn plan_to_sql(&self, plan: &LogicalPlan) -> Result<ast::Statement> {
let plan = normalize_union_schema(plan)?;

match plan {
LogicalPlan::Projection(_)
| LogicalPlan::Filter(_)
Expand All @@ -80,8 +83,8 @@ impl Unparser<'_> {
| LogicalPlan::Limit(_)
| LogicalPlan::Statement(_)
| LogicalPlan::Values(_)
| LogicalPlan::Distinct(_) => self.select_to_sql_statement(plan),
LogicalPlan::Dml(_) => self.dml_to_sql(plan),
| LogicalPlan::Distinct(_) => self.select_to_sql_statement(&plan),
LogicalPlan::Dml(_) => self.dml_to_sql(&plan),
LogicalPlan::Explain(_)
| LogicalPlan::Analyze(_)
| LogicalPlan::Extension(_)
Expand Down
101 changes: 101 additions & 0 deletions datafusion/sql/src/unparser/rewrite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeIterator},
Result,
};
use datafusion_expr::{Expr, LogicalPlan, Sort};

/// Strips table qualifiers from every `UNION` schema in `plan`, and from the sort
/// expressions of any `Sort` node that sits directly on top of a `UNION`.
///
/// DataFusion rejects a schema that holds two unqualified columns with the same name.
/// Some UNION queries produce exactly that, so the planner keeps table qualifiers on the
/// union's schema fields to avoid the clash.
/// See <https://github.com/apache/datafusion/issues/5410> for the background.
///
/// That choice is a problem for the unparser: once the UNION has been applied, the
/// qualifier no longer names anything that can be referenced from SQL.
///
/// Given this input:
/// ```sql
/// SELECT table1.foo FROM table1
/// UNION ALL
/// SELECT table2.foo FROM table2
/// ORDER BY foo
/// ```
///
/// unparsing without this rewrite would produce:
/// ```sql
/// SELECT table1.foo FROM table1
/// UNION ALL
/// SELECT table2.foo FROM table2
/// ORDER BY table1.foo
/// ```
///
/// which is invalid, because `table1.foo` cannot be referenced outside the first branch
/// of the UNION.
///
/// The input plan is cloned; the caller's plan is left untouched.
pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result<LogicalPlan> {
    plan.clone()
        // Bottom-up traversal: a `Union` child is visited before the `Sort` above it.
        .transform_up(|node| match node {
            LogicalPlan::Union(mut union) => {
                // Take ownership of the schema when this is the only reference to it,
                // otherwise fall back to cloning the shared value.
                let unqualified = Arc::try_unwrap(union.schema)
                    .unwrap_or_else(|shared| (*shared).clone())
                    .strip_qualifiers();
                union.schema = Arc::new(unqualified);
                Ok(Transformed::yes(LogicalPlan::Union(union)))
            }
            // Only a `Sort` whose direct input is a `Union` has its expressions rewritten.
            LogicalPlan::Sort(sort)
                if matches!(sort.input.as_ref(), LogicalPlan::Union(_)) =>
            {
                let Sort { expr, input, fetch } = sort;
                let expr = rewrite_sort_expr_for_union(expr)?;
                Ok(Transformed::yes(LogicalPlan::Sort(Sort {
                    expr,
                    input,
                    fetch,
                })))
            }
            // Every other node (including a `Sort` over a non-`Union` input) is unchanged.
            other => Ok(Transformed::no(other)),
        })
        .data()
}

/// Rewrite sort expressions that have a UNION plan as their input to remove the table reference.
fn rewrite_sort_expr_for_union(exprs: Vec<Expr>) -> Result<Vec<Expr>> {
let sort_exprs: Vec<Expr> = exprs
.into_iter()
.map_until_stop_and_collect(|expr| {
expr.transform_up(|expr| {
if let Expr::Column(mut col) = expr {
col.relation = None;
Ok(Transformed::yes(Expr::Column(col)))
} else {
Ok(Transformed::no(expr))
}
})
})
.data()?;

Ok(sort_exprs)
}
14 changes: 13 additions & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,16 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "SELECT j1_id FROM j1
UNION ALL
SELECT tb.j2_id as j1_id FROM j2 tb
ORDER BY j1_id
LIMIT 10;",
expected: r#"SELECT j1.j1_id FROM j1 UNION ALL SELECT tb.j2_id AS j1_id FROM j2 AS tb ORDER BY j1_id ASC NULLS LAST LIMIT 10"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
];

for query in tests {
Expand All @@ -239,7 +249,9 @@ fn roundtrip_statement_with_dialect() -> Result<()> {

let context = MockContextProvider::default();
let sql_to_rel = SqlToRel::new(&context);
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
let plan = sql_to_rel
.sql_statement_to_plan(statement)
.unwrap_or_else(|e| panic!("Failed to parse sql: {}\n{e}", query.sql));

let unparser = Unparser::new(&*query.unparser_dialect);
let roundtrip_statement = unparser.plan_to_sql(&plan)?;
Expand Down

0 comments on commit ed7c884

Please sign in to comment.