Skip to content

Commit

Permalink
Improve UserDefinedLogicalNode::from_template API to return Result (
Browse files Browse the repository at this point in the history
apache#10575)

* UserDefinedLogicalNode::from_template return Result

* Rename from_template to with_exprs_and_inputs

* Resolve review comments
  • Loading branch information
lewiszlw authored and findepi committed Jul 16, 2024
1 parent 7f71e4c commit 77ba1d8
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 47 deletions.
44 changes: 24 additions & 20 deletions datafusion/expr/src/logical_plan/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

//! This module defines the interface for logical nodes
use crate::{Expr, LogicalPlan};
use datafusion_common::{DFSchema, DFSchemaRef};
use datafusion_common::{DFSchema, DFSchemaRef, Result};
use std::hash::{Hash, Hasher};
use std::{any::Any, collections::HashSet, fmt, sync::Arc};

Expand Down Expand Up @@ -76,27 +76,31 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
/// For example: `TopK: k=10`
fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;

/// Create a new `ExtensionPlanNode` with the specified children
#[deprecated(since = "39.0.0", note = "use with_exprs_and_inputs instead")]
#[allow(clippy::wrong_self_convention)]
fn from_template(
&self,
exprs: &[Expr],
inputs: &[LogicalPlan],
) -> Arc<dyn UserDefinedLogicalNode> {
self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec())
.unwrap()
}

/// Create a new `UserDefinedLogicalNode` with the specified children
/// and expressions. This function is used during optimization
/// when the plan is being rewritten and a new instance of the
/// `ExtensionPlanNode` must be created.
/// `UserDefinedLogicalNode` must be created.
///
/// Note that exprs and inputs are in the same order as the result
/// of self.inputs and self.exprs.
///
/// So, `self.from_template(exprs, ..).expressions() == exprs
//
// TODO(clippy): This should probably be renamed to use a `with_*` prefix. Something
// like `with_template`, or `with_exprs_and_inputs`.
//
// Also, I think `ExtensionPlanNode` has been renamed to `UserDefinedLogicalNode`
// but the doc comments have not been updated.
#[allow(clippy::wrong_self_convention)]
fn from_template(
/// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs
fn with_exprs_and_inputs(
&self,
exprs: &[Expr],
inputs: &[LogicalPlan],
) -> Arc<dyn UserDefinedLogicalNode>;
exprs: Vec<Expr>,
inputs: Vec<LogicalPlan>,
) -> Result<Arc<dyn UserDefinedLogicalNode>>;

/// Returns the necessary input columns for this node required to compute
/// the columns in the output schema
Expand Down Expand Up @@ -312,12 +316,12 @@ impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
self.fmt_for_explain(f)
}

fn from_template(
fn with_exprs_and_inputs(
&self,
exprs: &[Expr],
inputs: &[LogicalPlan],
) -> Arc<dyn UserDefinedLogicalNode> {
Arc::new(self.from_template(exprs, inputs))
exprs: Vec<Expr>,
inputs: Vec<LogicalPlan>,
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
Ok(Arc::new(self.from_template(&exprs, &inputs)))
}

fn necessary_children_exprs(
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ impl LogicalPlan {
let expr = node.expressions();
let inputs: Vec<_> = node.inputs().into_iter().cloned().collect();
Ok(LogicalPlan::Extension(Extension {
node: node.from_template(&expr, &inputs),
node: node.with_exprs_and_inputs(expr, inputs)?,
}))
}
LogicalPlan::Union(Union { inputs, schema }) => {
Expand Down Expand Up @@ -923,7 +923,7 @@ impl LogicalPlan {
definition: definition.clone(),
}))),
LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension {
node: e.node.from_template(&expr, &inputs),
node: e.node.with_exprs_and_inputs(expr, inputs)?,
})),
LogicalPlan::Union(Union { schema, .. }) => {
let input_schema = inputs[0].schema();
Expand Down
28 changes: 12 additions & 16 deletions datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ where
.map_data(|new_inputs| {
let exprs = node.expressions();
Ok(Extension {
node: node.from_template(&exprs, &new_inputs),
node: node.with_exprs_and_inputs(exprs, new_inputs)?,
})
})
}
Expand Down Expand Up @@ -658,22 +658,18 @@ impl LogicalPlan {
LogicalPlan::Extension(Extension { node }) => {
// would be nice to avoid this copy -- maybe can
// update extension to just observer Exprs
node.expressions()
let exprs = node
.expressions()
.into_iter()
.map_until_stop_and_collect(f)?
.update_data(|exprs| {
LogicalPlan::Extension(Extension {
node: UserDefinedLogicalNode::from_template(
node.as_ref(),
exprs.as_slice(),
node.inputs()
.into_iter()
.cloned()
.collect::<Vec<_>>()
.as_slice(),
),
})
})
.map_until_stop_and_collect(f)?;
let plan = LogicalPlan::Extension(Extension {
node: UserDefinedLogicalNode::with_exprs_and_inputs(
node.as_ref(),
exprs.data,
node.inputs().into_iter().cloned().collect::<Vec<_>>(),
)?,
});
Transformed::new(plan, exprs.transformed, exprs.tnr)
}
LogicalPlan::TableScan(TableScan {
table_name,
Expand Down
5 changes: 3 additions & 2 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,8 @@ pub async fn from_substrait_rel(
);
};
let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?;
let plan = plan.from_template(&plan.expressions(), &[input_plan]);
let plan =
plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?;
Ok(LogicalPlan::Extension(Extension { node: plan }))
}
Some(RelType::ExtensionMulti(extension)) => {
Expand All @@ -567,7 +568,7 @@ pub async fn from_substrait_rel(
let input_plan = from_substrait_rel(ctx, input, extensions).await?;
inputs.push(input_plan);
}
let plan = plan.from_template(&plan.expressions(), &inputs);
let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
Ok(LogicalPlan::Extension(Extension { node: plan }))
}
Some(RelType::Exchange(exchange)) => {
Expand Down
14 changes: 7 additions & 7 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,16 @@ impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan {
)
}

fn from_template(
fn with_exprs_and_inputs(
&self,
_: &[Expr],
inputs: &[LogicalPlan],
) -> Arc<dyn UserDefinedLogicalNode> {
Arc::new(Self {
_: Vec<Expr>,
inputs: Vec<LogicalPlan>,
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
Ok(Arc::new(Self {
validation_bytes: self.validation_bytes.clone(),
inputs: inputs.to_vec(),
inputs,
empty_schema: Arc::new(DFSchema::empty()),
})
}))
}

fn dyn_hash(&self, _: &mut dyn std::hash::Hasher) {
Expand Down

0 comments on commit 77ba1d8

Please sign in to comment.