Skip to content

Commit

Permalink
feat(substrait): set ProjectRel output_mapping in producer (#12495)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarua authored Sep 18, 2024
1 parent c763fda commit ec10c04
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 26 deletions.
79 changes: 53 additions & 26 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ use substrait::proto::expression::literal::{
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::read_rel::VirtualTable;
use substrait::proto::{CrossRel, ExchangeRel};
use substrait::proto::rel_common::EmitKind;
use substrait::proto::rel_common::EmitKind::Emit;
use substrait::proto::{rel_common, CrossRel, ExchangeRel, RelCommon};
use substrait::{
proto::{
aggregate_function::AggregationInvocation,
Expand Down Expand Up @@ -219,9 +221,20 @@ pub fn to_substrait_rel(
.iter()
.map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extensions))
.collect::<Result<Vec<_>>>()?;

let emit_kind = create_project_remapping(
expressions.len(),
p.input.as_ref().schema().fields().len(),
);
let common = RelCommon {
emit_kind: Some(emit_kind),
hint: None,
advanced_extension: None,
};

Ok(Box::new(Rel {
rel_type: Some(RelType::Project(Box::new(ProjectRel {
common: None,
common: Some(common),
input: Some(to_substrait_rel(p.input.as_ref(), ctx, extensions)?),
expressions,
advanced_extension: None,
Expand Down Expand Up @@ -432,38 +445,39 @@ pub fn to_substrait_rel(
}
LogicalPlan::Window(window) => {
let input = to_substrait_rel(window.input.as_ref(), ctx, extensions)?;
// If the input is a Project relation, we can just append the WindowFunction expressions
// before returning
// Otherwise, wrap the input in a Project relation before appending the WindowFunction
// expressions
let mut project_rel: Box<ProjectRel> = match &input.as_ref().rel_type {
Some(RelType::Project(p)) => Box::new(*p.clone()),
_ => {
// Create Projection with field referencing all output fields in the input relation
let expressions = (0..window.input.schema().fields().len())
.map(substrait_field_ref)
.collect::<Result<Vec<_>>>()?;
Box::new(ProjectRel {
common: None,
input: Some(input),
expressions,
advanced_extension: None,
})
}
};
// Parse WindowFunction expression
let mut window_exprs = vec![];

// create a field reference for each input field
let mut expressions = (0..window.input.schema().fields().len())
.map(substrait_field_ref)
.collect::<Result<Vec<_>>>()?;

// process and add each window function expression
for expr in &window.window_expr {
window_exprs.push(to_substrait_rex(
expressions.push(to_substrait_rex(
ctx,
expr,
window.input.schema(),
0,
extensions,
)?);
}
// Append parsed WindowFunction expressions
project_rel.expressions.extend(window_exprs);

let emit_kind = create_project_remapping(
expressions.len(),
window.input.schema().fields().len(),
);
let common = RelCommon {
emit_kind: Some(emit_kind),
hint: None,
advanced_extension: None,
};
let project_rel = Box::new(ProjectRel {
common: Some(common),
input: Some(input),
expressions,
advanced_extension: None,
});

Ok(Box::new(Rel {
rel_type: Some(RelType::Project(project_rel)),
}))
Expand Down Expand Up @@ -553,6 +567,19 @@ pub fn to_substrait_rel(
}
}

/// Builds the emit mapping that makes a Substrait Project behave like a DataFusion Projection.
///
/// A Substrait Project emits every input field first and the computed expressions after them,
/// whereas a DataFusion Projection emits the expressions only. The returned mapping therefore
/// selects just the expression slots, i.e. the indices
/// `input_field_count..input_field_count + expr_count`, dropping the pass-through input fields.
///
/// * `expr_count` - number of expressions carried by the Project
/// * `input_field_count` - number of fields produced by the Project's input
fn create_project_remapping(expr_count: usize, input_field_count: usize) -> EmitKind {
    // Expression outputs start immediately after the last input field.
    let first_expr_index = input_field_count;
    let past_last_expr_index = first_expr_index + expr_count;

    let mut output_mapping = Vec::with_capacity(expr_count);
    // The mapping stores i32 indices, hence the narrowing cast.
    output_mapping.extend((first_expr_index..past_last_expr_index).map(|index| index as i32));

    Emit(rel_common::Emit { output_mapping })
}

fn to_substrait_named_struct(
schema: &DFSchemaRef,
extensions: &mut Extensions,
Expand Down
101 changes: 101 additions & 0 deletions datafusion/substrait/tests/cases/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ mod tests {
use datafusion::error::Result;
use datafusion::prelude::*;

use datafusion_substrait::logical_plan::producer::to_substrait_plan;
use std::fs;
use substrait::proto::plan_rel::RelType;
use substrait::proto::rel_common::{Emit, EmitKind};
use substrait::proto::{rel, RelCommon};

#[tokio::test]
async fn serialize_simple_select() -> Result<()> {
Expand Down Expand Up @@ -63,6 +67,103 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn include_remaps_for_projects() -> Result<()> {
    // Checks that a plain Projection is serialized as a Project whose emit mapping
    // keeps only the expression outputs.
    let ctx = create_context().await?;
    let datafusion_plan = ctx
        .sql("SELECT b, a + a, a FROM data")
        .await?
        .into_optimized_plan()?;

    assert_eq!(
        format!("{}", datafusion_plan),
        "Projection: data.b, data.a + data.a, data.a\
        \n TableScan: data projection=[a, b]",
    );

    let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone();

    // Unwrap the plan root to reach the top-most relation.
    let root_rel = match plan.relations.first().unwrap().rel_type.as_ref() {
        Some(RelType::Root(root)) => root.input.as_ref().unwrap(),
        _ => panic!("expected Root"),
    };

    let project = match root_rel.rel_type.as_ref() {
        Some(rel::RelType::Project(p)) => p,
        _ => panic!("plan did not match expected structure"),
    };
    // The input supplies 2 columns [a, b] and the Projection carries 3 expressions
    // [b, a + a, a], so the mapping must skip the 2 input columns: [2, 3, 4].
    assert_emit(project.common.as_ref(), vec![2, 3, 4]);

    let read = match project.input.as_ref().unwrap().rel_type.as_ref() {
        Some(rel::RelType::Read(r)) => r,
        _ => panic!("plan did not match expected structure"),
    };
    let select = read.projection.as_ref().unwrap().select.as_ref().unwrap();
    assert_eq!(
        2,
        select.struct_items.len(),
        "Read outputs two columns: a, b"
    );
    Ok(())
}

#[tokio::test]
async fn include_remaps_for_windows() -> Result<()> {
    // Checks the emit mappings of the two stacked Projects produced for a
    // Projection on top of a WindowAggr.
    let ctx = create_context().await?;
    let datafusion_plan = ctx
        .sql("SELECT b, RANK() OVER (PARTITION BY a), c FROM data;")
        .await?
        .into_optimized_plan()?;
    assert_eq!(
        format!("{}", datafusion_plan),
        "Projection: data.b, RANK() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, data.c\
        \n WindowAggr: windowExpr=[[RANK() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
        \n TableScan: data projection=[a, b, c]",
    );

    let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone();

    // Unwrap the plan root to reach the top-most relation.
    let root_rel = match plan.relations.first().unwrap().rel_type.as_ref() {
        Some(RelType::Root(root)) => root.input.as_ref().unwrap(),
        _ => panic!("expected Root"),
    };

    let outer_project = match root_rel.rel_type.as_ref() {
        Some(rel::RelType::Project(p1)) => p1,
        _ => panic!("plan did not match expected structure"),
    };
    // The WindowAggr below outputs 4 columns and the Projection carries 3 expressions,
    // so the mapping skips the 4 input columns: [4, 5, 6].
    assert_emit(outer_project.common.as_ref(), vec![4, 5, 6]);

    let window_project = match outer_project.input.as_ref().unwrap().rel_type.as_ref() {
        Some(rel::RelType::Project(p2)) => p2,
        _ => panic!("plan did not match expected structure"),
    };
    // The scan supplies 3 columns and the WindowAggr is emitted as 4 expressions
    // (3 pass-through field references plus the window function): [3, 4, 5, 6].
    assert_emit(window_project.common.as_ref(), vec![3, 4, 5, 6]);

    let read = match window_project.input.as_ref().unwrap().rel_type.as_ref() {
        Some(rel::RelType::Read(r)) => r,
        _ => panic!("plan did not match expected structure"),
    };
    let select = read.projection.as_ref().unwrap().select.as_ref().unwrap();
    assert_eq!(
        3,
        select.struct_items.len(),
        "Read outputs three columns: a, b, c"
    );
    Ok(())
}

/// Asserts that `rel_common` is present and that its emit kind is an explicit `Emit`
/// carrying exactly `output_mapping`. Panics if `rel_common` is `None`.
fn assert_emit(rel_common: Option<&RelCommon>, output_mapping: Vec<i32>) {
    let actual = rel_common.unwrap().emit_kind.clone();
    let expected = Some(EmitKind::Emit(Emit { output_mapping }));
    assert_eq!(actual, expected);
}

async fn create_context() -> Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new())
Expand Down

0 comments on commit ec10c04

Please sign in to comment.