Skip to content

Commit 12d4a8c

Browse files
committed
Implement TreeNode::map_children in place
1 parent bd355a3 commit 12d4a8c

File tree

5 files changed

+265
-18
lines changed

5 files changed

+265
-18
lines changed

datafusion/common/src/tree_node.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,15 @@ impl<T> Transformed<T> {
534534
TreeNodeRecursion::Jump | TreeNodeRecursion::Stop => Ok(self),
535535
}
536536
}
537+
538+
/// Discards the data of this [`Transformed`] object transforming it into Transformed<()>
539+
pub fn discard_data(self) -> Transformed<()> {
540+
Transformed {
541+
data: (),
542+
transformed: self.transformed,
543+
tnr: self.tnr,
544+
}
545+
}
537546
}
538547

539548
/// Transformation helper to process a sequence of iterable tree nodes that are siblings.

datafusion/expr/src/logical_plan/ddl.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,24 @@ impl DdlStatement {
110110
}
111111
}
112112

113+
/// Return a mutable reference to the input `LogicalPlan`, if any
114+
pub fn input_mut(&mut self) -> Option<&mut Arc<LogicalPlan>> {
115+
match self {
116+
DdlStatement::CreateMemoryTable(CreateMemoryTable { input, .. }) => {
117+
Some(input)
118+
}
119+
DdlStatement::CreateExternalTable(_) => None,
120+
DdlStatement::CreateView(CreateView { input, .. }) => Some(input),
121+
DdlStatement::CreateCatalogSchema(_) => None,
122+
DdlStatement::CreateCatalog(_) => None,
123+
DdlStatement::DropTable(_) => None,
124+
DdlStatement::DropView(_) => None,
125+
DdlStatement::DropCatalogSchema(_) => None,
126+
DdlStatement::CreateFunction(_) => None,
127+
DdlStatement::DropFunction(_) => None,
128+
}
129+
}
130+
113131
/// Return a `format`able structure with the a human readable
114132
/// description of this LogicalPlan node per node, not including
115133
/// children.

datafusion/expr/src/logical_plan/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub mod display;
2121
pub mod dml;
2222
mod extension;
2323
mod plan;
24+
mod rewrite;
2425
mod statement;
2526

2627
pub use builder::{
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Methods for rewriting logical plans
19+
20+
use crate::{
21+
Aggregate, CrossJoin, Distinct, DistinctOn, EmptyRelation, Filter, Join, Limit,
22+
LogicalPlan, Prepare, Projection, RecursiveQuery, Repartition, Sort, Subquery,
23+
SubqueryAlias, Union, Unnest, UserDefinedLogicalNode, Window,
24+
};
25+
use datafusion_common::tree_node::{Transformed, TreeNodeIterator};
26+
use datafusion_common::{DFSchema, DFSchemaRef, Result};
27+
use std::sync::{Arc, OnceLock};
28+
29+
/// A temporary node that is left in place while rewriting the children of a
30+
/// [`LogicalPlan`]. This is necessary to ensure that the `LogicalPlan` is
31+
/// always in a valid state (from the Rust perspective)
32+
static PLACEHOLDER: OnceLock<Arc<LogicalPlan>> = OnceLock::new();
33+
34+
/// its inputs, so this code would not be needed. However, for now we try and
35+
/// unwrap the `Arc` which avoids `clone`ing in most cases.
36+
///
37+
/// On error, node be left with a placeholder logical plan
38+
fn rewrite_arc<F>(
39+
node: &mut Arc<LogicalPlan>,
40+
mut f: F,
41+
) -> datafusion_common::Result<Transformed<&mut Arc<LogicalPlan>>>
42+
where
43+
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
44+
{
45+
// We need to leave a valid node in the Arc, while we rewrite the existing
46+
// one, so use a single global static placeholder node
47+
let mut new_node = PLACEHOLDER
48+
.get_or_init(|| {
49+
Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
50+
produce_one_row: false,
51+
schema: DFSchemaRef::new(DFSchema::empty()),
52+
}))
53+
})
54+
.clone();
55+
56+
// take the old value out of the Arc
57+
std::mem::swap(node, &mut new_node);
58+
59+
// try to update existing node, if it isn't shared with others
60+
let new_node = Arc::try_unwrap(new_node)
61+
// if None is returned, there is another reference to this
62+
// LogicalPlan, so we must clone instead
63+
.unwrap_or_else(|node| node.as_ref().clone());
64+
65+
// apply the actual transform
66+
let result = f(new_node)?;
67+
68+
// put the new value back into the Arc
69+
let mut new_node = Arc::new(result.data);
70+
std::mem::swap(node, &mut new_node);
71+
72+
// return the `node` back
73+
Ok(Transformed::new(node, result.transformed, result.tnr))
74+
}
75+
76+
/// Rewrite the arc and discard the contents of Transformed
77+
fn rewrite_arc_no_data<F>(
78+
node: &mut Arc<LogicalPlan>,
79+
f: F,
80+
) -> datafusion_common::Result<Transformed<()>>
81+
where
82+
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
83+
{
84+
rewrite_arc(node, f).map(|res| res.discard_data())
85+
}
86+
87+
/// Rewrites all inputs for an Extension node "in place"
88+
/// (it currently has to copy values because there are no APIs for in place modification)
89+
///
90+
/// Should be removed when we have an API for in place modifications of the
91+
/// extension to avoid these copies
92+
fn rewrite_extension_inputs<F>(
93+
node: &mut Arc<dyn UserDefinedLogicalNode>,
94+
f: F,
95+
) -> datafusion_common::Result<Transformed<()>>
96+
where
97+
F: FnMut(LogicalPlan) -> Result<Transformed<LogicalPlan>>,
98+
{
99+
let Transformed {
100+
data: new_inputs,
101+
transformed,
102+
tnr,
103+
} = node
104+
.inputs()
105+
.into_iter()
106+
.cloned()
107+
.map_until_stop_and_collect(f)?;
108+
109+
let exprs = node.expressions();
110+
let mut new_node = node.from_template(&exprs, &new_inputs);
111+
std::mem::swap(node, &mut new_node);
112+
Ok(Transformed {
113+
data: (),
114+
transformed,
115+
tnr,
116+
})
117+
}
118+
119+
impl LogicalPlan {
120+
/// Applies `f` to each child (input) of this plan node, rewriting them *in place.*
121+
///
122+
/// Note that this function returns `Transformed<()>` because it does not
123+
/// consume `self`, but instead modifies it in place. However, `F` transforms
124+
/// the children by ownership
125+
///
126+
/// # Notes
127+
///
128+
/// Inputs include ONLY direct children, not embedded subquery
129+
/// `LogicalPlan`s, for example such as are in [`Expr::Exists`].
130+
///
131+
/// [`Expr::Exists`]: crate::expr::Expr::Exists
132+
pub(crate) fn rewrite_children<F>(&mut self, mut f: F) -> Result<Transformed<()>>
133+
where
134+
F: FnMut(Self) -> Result<Transformed<Self>>,
135+
{
136+
let children_result = match self {
137+
LogicalPlan::Projection(Projection { input, .. }) => {
138+
rewrite_arc_no_data(input, &mut f)
139+
}
140+
LogicalPlan::Filter(Filter { input, .. }) => {
141+
rewrite_arc_no_data(input, &mut f)
142+
}
143+
LogicalPlan::Repartition(Repartition { input, .. }) => {
144+
rewrite_arc_no_data(input, &mut f)
145+
}
146+
LogicalPlan::Window(Window { input, .. }) => {
147+
rewrite_arc_no_data(input, &mut f)
148+
}
149+
LogicalPlan::Aggregate(Aggregate { input, .. }) => {
150+
rewrite_arc_no_data(input, &mut f)
151+
}
152+
LogicalPlan::Sort(Sort { input, .. }) => rewrite_arc_no_data(input, &mut f),
153+
LogicalPlan::Join(Join { left, right, .. }) => {
154+
let results = [left, right]
155+
.into_iter()
156+
.map_until_stop_and_collect(|input| rewrite_arc(input, &mut f))?;
157+
Ok(results.discard_data())
158+
}
159+
LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
160+
let results = [left, right]
161+
.into_iter()
162+
.map_until_stop_and_collect(|input| rewrite_arc(input, &mut f))?;
163+
Ok(results.discard_data())
164+
}
165+
LogicalPlan::Limit(Limit { input, .. }) => rewrite_arc_no_data(input, &mut f),
166+
LogicalPlan::Subquery(Subquery { subquery, .. }) => {
167+
rewrite_arc_no_data(subquery, &mut f)
168+
}
169+
LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => {
170+
rewrite_arc_no_data(input, &mut f)
171+
}
172+
LogicalPlan::Extension(extension) => {
173+
rewrite_extension_inputs(&mut extension.node, &mut f)
174+
}
175+
LogicalPlan::Union(Union { inputs, .. }) => {
176+
let results = inputs
177+
.iter_mut()
178+
.map_until_stop_and_collect(|input| rewrite_arc(input, &mut f))?;
179+
Ok(results.discard_data())
180+
}
181+
LogicalPlan::Distinct(
182+
Distinct::All(input) | Distinct::On(DistinctOn { input, .. }),
183+
) => rewrite_arc_no_data(input, &mut f),
184+
LogicalPlan::Explain(explain) => {
185+
rewrite_arc_no_data(&mut explain.plan, &mut f)
186+
}
187+
LogicalPlan::Analyze(analyze) => {
188+
rewrite_arc_no_data(&mut analyze.input, &mut f)
189+
}
190+
LogicalPlan::Dml(write) => rewrite_arc_no_data(&mut write.input, &mut f),
191+
LogicalPlan::Copy(copy) => rewrite_arc_no_data(&mut copy.input, &mut f),
192+
LogicalPlan::Ddl(ddl) => {
193+
if let Some(input) = ddl.input_mut() {
194+
rewrite_arc_no_data(input, &mut f)
195+
} else {
196+
Ok(Transformed::no(()))
197+
}
198+
}
199+
LogicalPlan::Unnest(Unnest { input, .. }) => {
200+
rewrite_arc_no_data(input, &mut f)
201+
}
202+
LogicalPlan::Prepare(Prepare { input, .. }) => {
203+
rewrite_arc_no_data(input, &mut f)
204+
}
205+
LogicalPlan::RecursiveQuery(RecursiveQuery {
206+
static_term,
207+
recursive_term,
208+
..
209+
}) => {
210+
let results = [static_term, recursive_term]
211+
.into_iter()
212+
.map_until_stop_and_collect(|input| rewrite_arc(input, &mut f))?;
213+
Ok(results.discard_data())
214+
}
215+
// plans without inputs
216+
LogicalPlan::TableScan { .. }
217+
| LogicalPlan::Statement { .. }
218+
| LogicalPlan::EmptyRelation { .. }
219+
| LogicalPlan::Values { .. }
220+
| LogicalPlan::DescribeTable(_) => Ok(Transformed::no(())),
221+
}?;
222+
223+
// after visiting the actual children we we need to visit any subqueries
224+
// that are inside the expressions
225+
// TODO use pattern introduced in https://github.com/apache/arrow-datafusion/pull/9913
226+
Ok(children_result)
227+
}
228+
}

datafusion/expr/src/tree_node/plan.rs

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,14 @@ impl TreeNode for LogicalPlan {
3232
self.inputs().into_iter().apply_until_stop(f)
3333
}
3434

35-
fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
36-
self,
37-
f: F,
38-
) -> Result<Transformed<Self>> {
39-
let new_children = self
40-
.inputs()
41-
.into_iter()
42-
.cloned()
43-
.map_until_stop_and_collect(f)?;
44-
// Propagate up `new_children.transformed` and `new_children.tnr`
45-
// along with the node containing transformed children.
46-
if new_children.transformed {
47-
new_children.map_data(|new_children| {
48-
self.with_new_exprs(self.expressions(), new_children)
49-
})
50-
} else {
51-
Ok(new_children.update_data(|_| self))
52-
}
35+
fn map_children<F>(mut self, f: F) -> Result<Transformed<Self>>
36+
where
37+
F: FnMut(Self) -> Result<Transformed<Self>>,
38+
{
39+
// Apply the rewrite *in place* for each child to avoid cloning
40+
let result = self.rewrite_children(f)?;
41+
42+
// return ourself
43+
Ok(result.update_data(|_| self))
5344
}
5445
}

0 commit comments

Comments
 (0)