Skip to content

Commit 44bbb0e

Browse files
authored
Simplifications (#7907)
1 parent eee790f commit 44bbb0e

File tree

2 files changed

+45
-56
lines changed

2 files changed

+45
-56
lines changed

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,23 @@ impl PhysicalGroupBy {
218218
pub fn is_single(&self) -> bool {
219219
self.null_expr.is_empty()
220220
}
221+
222+
/// Calculate GROUP BY expressions according to input schema.
223+
pub fn input_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
224+
self.expr
225+
.iter()
226+
.map(|(expr, _alias)| expr.clone())
227+
.collect()
228+
}
229+
230+
/// Return grouping expressions as they occur in the output schema.
231+
fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
232+
self.expr
233+
.iter()
234+
.enumerate()
235+
.map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _)
236+
.collect()
237+
}
221238
}
222239

223240
impl PartialEq for PhysicalGroupBy {
@@ -319,11 +336,7 @@ fn get_working_mode(
319336
// Since direction of the ordering is not important for GROUP BY columns,
320337
// we convert PhysicalSortExpr to PhysicalExpr in the existing ordering.
321338
let ordering_exprs = convert_to_expr(output_ordering);
322-
let groupby_exprs = group_by
323-
.expr
324-
.iter()
325-
.map(|(item, _)| item.clone())
326-
.collect::<Vec<_>>();
339+
let groupby_exprs = group_by.input_exprs();
327340
// Find where each expression of the GROUP BY clause occurs in the existing
328341
// ordering (if it occurs):
329342
let mut ordered_indices =
@@ -363,7 +376,7 @@ fn calc_aggregation_ordering(
363376
) -> Option<AggregationOrdering> {
364377
get_working_mode(input, group_by).map(|(mode, order_indices)| {
365378
let existing_ordering = input.output_ordering().unwrap_or(&[]);
366-
let out_group_expr = output_group_expr_helper(group_by);
379+
let out_group_expr = group_by.output_exprs();
367380
// Calculate output ordering information for the operator:
368381
let out_ordering = order_indices
369382
.iter()
@@ -381,18 +394,6 @@ fn calc_aggregation_ordering(
381394
})
382395
}
383396

384-
/// This function returns grouping expressions as they occur in the output schema.
385-
fn output_group_expr_helper(group_by: &PhysicalGroupBy) -> Vec<Arc<dyn PhysicalExpr>> {
386-
// Update column indices. Since the group by columns come first in the output schema, their
387-
// indices are simply 0..self.group_expr(len).
388-
group_by
389-
.expr()
390-
.iter()
391-
.enumerate()
392-
.map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _)
393-
.collect()
394-
}
395-
396397
/// This function returns the ordering requirement of the first non-reversible
397398
/// order-sensitive aggregate function such as ARRAY_AGG. This requirement serves
398399
/// as the initial requirement while calculating the finest requirement among all
@@ -591,11 +592,7 @@ fn group_by_contains_all_requirements(
591592
group_by: &PhysicalGroupBy,
592593
requirement: &LexOrdering,
593594
) -> bool {
594-
let physical_exprs = group_by
595-
.expr()
596-
.iter()
597-
.map(|(expr, _alias)| expr.clone())
598-
.collect::<Vec<_>>();
595+
let physical_exprs = group_by.input_exprs();
599596
// When we have multiple groups (grouping set)
600597
// since group by may be calculated on the subset of the group_by.expr()
601598
// it is not guaranteed to have all of the requirements among group by expressions.
@@ -735,7 +732,7 @@ impl AggregateExec {
735732

736733
/// Grouping expressions as they occur in the output schema
737734
pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
738-
output_group_expr_helper(&self.group_by)
735+
self.group_by.output_exprs()
739736
}
740737

741738
/// Aggregate expressions
@@ -894,28 +891,24 @@ impl ExecutionPlan for AggregateExec {
894891

895892
/// Get the output partitioning of this plan
896893
fn output_partitioning(&self) -> Partitioning {
897-
match &self.mode {
898-
AggregateMode::Partial | AggregateMode::Single => {
899-
// Partial and Single Aggregation will not change the output partitioning but need to respect the Alias
900-
let input_partition = self.input.output_partitioning();
901-
match input_partition {
902-
Partitioning::Hash(exprs, part) => {
903-
let normalized_exprs = exprs
904-
.into_iter()
905-
.map(|expr| {
906-
normalize_out_expr_with_columns_map(
907-
expr,
908-
&self.columns_map,
909-
)
910-
})
911-
.collect::<Vec<_>>();
912-
Partitioning::Hash(normalized_exprs, part)
913-
}
914-
_ => input_partition,
915-
}
894+
let input_partition = self.input.output_partitioning();
895+
if self.mode.is_first_stage() {
896+
// First stage Aggregation will not change the output partitioning but need to respect the Alias
897+
let input_partition = self.input.output_partitioning();
898+
if let Partitioning::Hash(exprs, part) = input_partition {
899+
let normalized_exprs = exprs
900+
.into_iter()
901+
.map(|expr| {
902+
normalize_out_expr_with_columns_map(expr, &self.columns_map)
903+
})
904+
.collect::<Vec<_>>();
905+
Partitioning::Hash(normalized_exprs, part)
906+
} else {
907+
input_partition
916908
}
909+
} else {
917910
// Final Aggregation's output partitioning is the same as its real input
918-
_ => self.input.output_partitioning(),
911+
input_partition
919912
}
920913
}
921914

datafusion/physical-plan/src/projection.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -224,18 +224,14 @@ impl ExecutionPlan for ProjectionExec {
224224
fn output_partitioning(&self) -> Partitioning {
225225
// Output partition need to respect the alias
226226
let input_partition = self.input.output_partitioning();
227-
match input_partition {
228-
Partitioning::Hash(exprs, part) => {
229-
let normalized_exprs = exprs
230-
.into_iter()
231-
.map(|expr| {
232-
normalize_out_expr_with_columns_map(expr, &self.columns_map)
233-
})
234-
.collect::<Vec<_>>();
235-
236-
Partitioning::Hash(normalized_exprs, part)
237-
}
238-
_ => input_partition,
227+
if let Partitioning::Hash(exprs, part) = input_partition {
228+
let normalized_exprs = exprs
229+
.into_iter()
230+
.map(|expr| normalize_out_expr_with_columns_map(expr, &self.columns_map))
231+
.collect::<Vec<_>>();
232+
Partitioning::Hash(normalized_exprs, part)
233+
} else {
234+
input_partition
239235
}
240236
}
241237

0 commit comments

Comments
 (0)