19
19
20
20
use super :: AnalyzerRule ;
21
21
use datafusion_common:: config:: ConfigOptions ;
22
- use datafusion_common:: tree_node:: { Transformed , TreeNodeRewriter } ;
22
+ use datafusion_common:: tree_node:: { Transformed , TreeNode } ;
23
23
use datafusion_common:: { DFSchema , Result } ;
24
- use datafusion_expr:: expr_rewriter:: { rewrite_preserving_name, FunctionRewrite } ;
24
+
25
+ use crate :: utils:: NamePreserver ;
26
+ use datafusion_expr:: expr_rewriter:: FunctionRewrite ;
25
27
use datafusion_expr:: utils:: merge_schema;
26
- use datafusion_expr:: { Expr , LogicalPlan } ;
28
+ use datafusion_expr:: LogicalPlan ;
27
29
use std:: sync:: Arc ;
28
30
29
31
/// Analyzer rule that invokes [`FunctionRewrite`]s on expressions
@@ -37,86 +39,53 @@ impl ApplyFunctionRewrites {
37
39
pub fn new ( function_rewrites : Vec < Arc < dyn FunctionRewrite + Send + Sync > > ) -> Self {
38
40
Self { function_rewrites }
39
41
}
40
- }
41
-
42
- impl AnalyzerRule for ApplyFunctionRewrites {
43
- fn name ( & self ) -> & str {
44
- "apply_function_rewrites"
45
- }
46
-
47
- fn analyze ( & self , plan : LogicalPlan , options : & ConfigOptions ) -> Result < LogicalPlan > {
48
- self . analyze_internal ( & plan, options)
49
- }
50
- }
51
42
52
- impl ApplyFunctionRewrites {
53
- fn analyze_internal (
43
+ /// Rewrite a single plan, and all its expressions using the provided rewriters
44
+ fn rewrite_plan (
54
45
& self ,
55
- plan : & LogicalPlan ,
46
+ plan : LogicalPlan ,
56
47
options : & ConfigOptions ,
57
- ) -> Result < LogicalPlan > {
58
- // optimize child plans first
59
- let new_inputs = plan
60
- . inputs ( )
61
- . iter ( )
62
- . map ( |p| self . analyze_internal ( p, options) )
63
- . collect :: < Result < Vec < _ > > > ( ) ?;
64
-
48
+ ) -> Result < Transformed < LogicalPlan > > {
65
49
// get schema representing all available input fields. This is used for data type
66
50
// resolution only, so order does not matter here
67
- let mut schema = merge_schema ( new_inputs . iter ( ) . collect ( ) ) ;
51
+ let mut schema = merge_schema ( plan . inputs ( ) ) ;
68
52
69
- if let LogicalPlan :: TableScan ( ts) = plan {
53
+ if let LogicalPlan :: TableScan ( ts) = & plan {
70
54
let source_schema = DFSchema :: try_from_qualified_schema (
71
55
ts. table_name . clone ( ) ,
72
56
& ts. source . schema ( ) ,
73
57
) ?;
74
58
schema. merge ( & source_schema) ;
75
59
}
76
60
77
- let mut expr_rewrite = OperatorToFunctionRewriter {
78
- function_rewrites : & self . function_rewrites ,
79
- options,
80
- schema : & schema,
81
- } ;
61
+ let name_preserver = NamePreserver :: new ( & plan) ;
62
+
63
+ plan. map_expressions ( |expr| {
64
+ let original_name = name_preserver. save ( & expr) ?;
82
65
83
- let new_expr = plan
84
- . expressions ( )
85
- . into_iter ( )
86
- . map ( |expr| {
87
- // ensure names don't change:
88
- // https://github.com/apache/arrow-datafusion/issues/3555
89
- rewrite_preserving_name ( expr, & mut expr_rewrite)
90
- } )
91
- . collect :: < Result < Vec < _ > > > ( ) ?;
66
+ // recursively transform the expression, applying the rewrites at each step
67
+ let result = expr. transform_up ( & |expr| {
68
+ let mut result = Transformed :: no ( expr) ;
69
+ for rewriter in self . function_rewrites . iter ( ) {
70
+ result = result. transform_data ( |expr| {
71
+ rewriter. rewrite ( expr, & schema, options)
72
+ } ) ?;
73
+ }
74
+ Ok ( result)
75
+ } ) ?;
92
76
93
- plan. with_new_exprs ( new_expr, new_inputs)
77
+ result. map_data ( |expr| original_name. restore ( expr) )
78
+ } )
94
79
}
95
80
}
96
- struct OperatorToFunctionRewriter < ' a > {
97
- function_rewrites : & ' a [ Arc < dyn FunctionRewrite + Send + Sync > ] ,
98
- options : & ' a ConfigOptions ,
99
- schema : & ' a DFSchema ,
100
- }
101
-
102
- impl < ' a > TreeNodeRewriter for OperatorToFunctionRewriter < ' a > {
103
- type Node = Expr ;
104
81
105
- fn f_up ( & mut self , mut expr : Expr ) -> Result < Transformed < Expr > > {
106
- // apply transforms one by one
107
- let mut transformed = false ;
108
- for rewriter in self . function_rewrites . iter ( ) {
109
- let result = rewriter. rewrite ( expr, self . schema , self . options ) ?;
110
- if result. transformed {
111
- transformed = true ;
112
- }
113
- expr = result. data
114
- }
82
+ impl AnalyzerRule for ApplyFunctionRewrites {
83
+ fn name ( & self ) -> & str {
84
+ "apply_function_rewrites"
85
+ }
115
86
116
- Ok ( if transformed {
117
- Transformed :: yes ( expr)
118
- } else {
119
- Transformed :: no ( expr)
120
- } )
87
+ fn analyze ( & self , plan : LogicalPlan , options : & ConfigOptions ) -> Result < LogicalPlan > {
88
+ plan. transform_up_with_subqueries ( & |plan| self . rewrite_plan ( plan, options) )
89
+ . map ( |res| res. data )
121
90
}
122
91
}
0 commit comments