Skip to content

Commit e161cd6

Browse files
authored
fix NamedStructField should be rewritten in OperatorToFunction in subquery regression (change ApplyFunctionRewrites to use TreeNode API (#10032)
* fix NamedStructField should be rewritten in OperatorToFunction in subquery * Use TreeNode rewriter
1 parent 2def10f commit e161cd6

File tree

3 files changed

+133
-65
lines changed

3 files changed

+133
-65
lines changed

datafusion/optimizer/src/analyzer/function_rewrite.rs

Lines changed: 34 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
2020
use super::AnalyzerRule;
2121
use datafusion_common::config::ConfigOptions;
22-
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
22+
use datafusion_common::tree_node::{Transformed, TreeNode};
2323
use datafusion_common::{DFSchema, Result};
24-
use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite};
24+
25+
use crate::utils::NamePreserver;
26+
use datafusion_expr::expr_rewriter::FunctionRewrite;
2527
use datafusion_expr::utils::merge_schema;
26-
use datafusion_expr::{Expr, LogicalPlan};
28+
use datafusion_expr::LogicalPlan;
2729
use std::sync::Arc;
2830

2931
/// Analyzer rule that invokes [`FunctionRewrite`]s on expressions
@@ -37,86 +39,53 @@ impl ApplyFunctionRewrites {
3739
pub fn new(function_rewrites: Vec<Arc<dyn FunctionRewrite + Send + Sync>>) -> Self {
3840
Self { function_rewrites }
3941
}
40-
}
41-
42-
impl AnalyzerRule for ApplyFunctionRewrites {
43-
fn name(&self) -> &str {
44-
"apply_function_rewrites"
45-
}
46-
47-
fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result<LogicalPlan> {
48-
self.analyze_internal(&plan, options)
49-
}
50-
}
5142

52-
impl ApplyFunctionRewrites {
53-
fn analyze_internal(
43+
/// Rewrite a single plan, and all its expressions using the provided rewriters
44+
fn rewrite_plan(
5445
&self,
55-
plan: &LogicalPlan,
46+
plan: LogicalPlan,
5647
options: &ConfigOptions,
57-
) -> Result<LogicalPlan> {
58-
// optimize child plans first
59-
let new_inputs = plan
60-
.inputs()
61-
.iter()
62-
.map(|p| self.analyze_internal(p, options))
63-
.collect::<Result<Vec<_>>>()?;
64-
48+
) -> Result<Transformed<LogicalPlan>> {
6549
// get schema representing all available input fields. This is used for data type
6650
// resolution only, so order does not matter here
67-
let mut schema = merge_schema(new_inputs.iter().collect());
51+
let mut schema = merge_schema(plan.inputs());
6852

69-
if let LogicalPlan::TableScan(ts) = plan {
53+
if let LogicalPlan::TableScan(ts) = &plan {
7054
let source_schema = DFSchema::try_from_qualified_schema(
7155
ts.table_name.clone(),
7256
&ts.source.schema(),
7357
)?;
7458
schema.merge(&source_schema);
7559
}
7660

77-
let mut expr_rewrite = OperatorToFunctionRewriter {
78-
function_rewrites: &self.function_rewrites,
79-
options,
80-
schema: &schema,
81-
};
61+
let name_preserver = NamePreserver::new(&plan);
62+
63+
plan.map_expressions(|expr| {
64+
let original_name = name_preserver.save(&expr)?;
8265

83-
let new_expr = plan
84-
.expressions()
85-
.into_iter()
86-
.map(|expr| {
87-
// ensure names don't change:
88-
// https://github.com/apache/arrow-datafusion/issues/3555
89-
rewrite_preserving_name(expr, &mut expr_rewrite)
90-
})
91-
.collect::<Result<Vec<_>>>()?;
66+
// recursively transform the expression, applying the rewrites at each step
67+
let result = expr.transform_up(&|expr| {
68+
let mut result = Transformed::no(expr);
69+
for rewriter in self.function_rewrites.iter() {
70+
result = result.transform_data(|expr| {
71+
rewriter.rewrite(expr, &schema, options)
72+
})?;
73+
}
74+
Ok(result)
75+
})?;
9276

93-
plan.with_new_exprs(new_expr, new_inputs)
77+
result.map_data(|expr| original_name.restore(expr))
78+
})
9479
}
9580
}
96-
struct OperatorToFunctionRewriter<'a> {
97-
function_rewrites: &'a [Arc<dyn FunctionRewrite + Send + Sync>],
98-
options: &'a ConfigOptions,
99-
schema: &'a DFSchema,
100-
}
101-
102-
impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> {
103-
type Node = Expr;
10481

105-
fn f_up(&mut self, mut expr: Expr) -> Result<Transformed<Expr>> {
106-
// apply transforms one by one
107-
let mut transformed = false;
108-
for rewriter in self.function_rewrites.iter() {
109-
let result = rewriter.rewrite(expr, self.schema, self.options)?;
110-
if result.transformed {
111-
transformed = true;
112-
}
113-
expr = result.data
114-
}
82+
impl AnalyzerRule for ApplyFunctionRewrites {
83+
fn name(&self) -> &str {
84+
"apply_function_rewrites"
85+
}
11586

116-
Ok(if transformed {
117-
Transformed::yes(expr)
118-
} else {
119-
Transformed::no(expr)
120-
})
87+
fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result<LogicalPlan> {
88+
plan.transform_up_with_subqueries(&|plan| self.rewrite_plan(plan, options))
89+
.map(|res| res.data)
12190
}
12291
}

datafusion/optimizer/src/utils.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,47 @@ pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
288288
pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
289289
expr_utils::merge_schema(inputs)
290290
}
291+
292+
/// Handles ensuring the name of rewritten expressions is not changed.
293+
///
294+
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
295+
/// expression should be preserved: `3 as "1 + 2"`
296+
///
297+
/// See <https://github.com/apache/arrow-datafusion/issues/3555> for details
298+
pub struct NamePreserver {
299+
use_alias: bool,
300+
}
301+
302+
/// If the name of an expression is remembered, it will be preserved when
303+
/// rewriting the expression
304+
pub struct SavedName(Option<String>);
305+
306+
impl NamePreserver {
307+
/// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
308+
pub fn new(plan: &LogicalPlan) -> Self {
309+
Self {
310+
use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)),
311+
}
312+
}
313+
314+
pub fn save(&self, expr: &Expr) -> Result<SavedName> {
315+
let original_name = if self.use_alias {
316+
Some(expr.name_for_alias()?)
317+
} else {
318+
None
319+
};
320+
321+
Ok(SavedName(original_name))
322+
}
323+
}
324+
325+
impl SavedName {
326+
/// Ensures the name of the rewritten expression is preserved
327+
pub fn restore(self, expr: Expr) -> Result<Expr> {
328+
let Self(original_name) = self;
329+
match original_name {
330+
Some(name) => expr.alias_if_changed(name),
331+
None => Ok(expr),
332+
}
333+
}
334+
}

datafusion/sqllogictest/test_files/subquery.slt

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,3 +1060,58 @@ logical_plan
10601060
Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1)
10611061
--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a
10621062
----TableScan: t projection=[a]
1063+
1064+
###
1065+
## Ensure that operators are rewritten in subqueries
1066+
###
1067+
1068+
statement ok
1069+
create table foo(x int) as values (1);
1070+
1071+
# Show input data
1072+
query ?
1073+
select struct(1, 'b')
1074+
----
1075+
{c0: 1, c1: b}
1076+
1077+
1078+
query T
1079+
select (select struct(1, 'b')['c1']);
1080+
----
1081+
b
1082+
1083+
query T
1084+
select 'foo' || (select struct(1, 'b')['c1']);
1085+
----
1086+
foob
1087+
1088+
query I
1089+
SELECT * FROM (VALUES (1), (2))
1090+
WHERE column1 IN (SELECT struct(1, 'b')['c0']);
1091+
----
1092+
1
1093+
1094+
# also add an expression so the subquery is the output expr
1095+
query I
1096+
SELECT * FROM (VALUES (1), (2))
1097+
WHERE 1+2 = 3 AND column1 IN (SELECT struct(1, 'b')['c0']);
1098+
----
1099+
1
1100+
1101+
1102+
query I
1103+
SELECT * FROM foo
1104+
WHERE EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1);
1105+
----
1106+
1
1107+
1108+
# also add an expression so the subquery is the output expr
1109+
query I
1110+
SELECT * FROM foo
1111+
WHERE 1+2 = 3 AND EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1);
1112+
----
1113+
1
1114+
1115+
1116+
statement ok
1117+
drop table foo;

0 commit comments

Comments
 (0)