From bdc88aeaf24039750e51ca27d1ec735e51354b4c Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Tue, 1 Oct 2024 21:50:47 +0200 Subject: [PATCH 1/2] Implement grouping function using grouping id This patch adds an Analyzer rule to transform the grouping aggregation function into a computation on top of the grouping id that is used internally for grouping sets. --- .../functions-aggregate/src/grouping.rs | 2 +- datafusion/optimizer/src/analyzer/mod.rs | 3 + .../src/analyzer/resolve_grouping_function.rs | 242 ++++++++++++++++++ .../sqllogictest/test_files/explain.slt | 1 + .../sqllogictest/test_files/grouping.slt | 214 ++++++++++++++++ 5 files changed, 461 insertions(+), 1 deletion(-) create mode 100644 datafusion/optimizer/src/analyzer/resolve_grouping_function.rs create mode 100644 datafusion/sqllogictest/test_files/grouping.slt diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 09e9b90b2e6d..558d3055f1bf 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -63,7 +63,7 @@ impl Grouping { /// Create a new GROUPING aggregate function. 
pub fn new() -> Self { Self { - signature: Signature::any(1, Volatility::Immutable), + signature: Signature::variadic_any(Volatility::Immutable), } } } diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 4cd891664e7f..a9fd4900b2f4 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -34,6 +34,7 @@ use datafusion_expr::{Expr, LogicalPlan}; use crate::analyzer::count_wildcard_rule::CountWildcardRule; use crate::analyzer::expand_wildcard_rule::ExpandWildcardRule; use crate::analyzer::inline_table_scan::InlineTableScan; +use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction; use crate::analyzer::subquery::check_subquery_expr; use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; @@ -44,6 +45,7 @@ pub mod count_wildcard_rule; pub mod expand_wildcard_rule; pub mod function_rewrite; pub mod inline_table_scan; +pub mod resolve_grouping_function; pub mod subquery; pub mod type_coercion; @@ -96,6 +98,7 @@ impl Analyzer { // Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule]. Arc::new(ExpandWildcardRule::new()), // [Expr::Wildcard] should be expanded before [TypeCoercion] + Arc::new(ResolveGroupingFunction::new()), Arc::new(TypeCoercion::new()), Arc::new(CountWildcardRule::new()), ]; diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs new file mode 100644 index 000000000000..24560ca481ce --- /dev/null +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -0,0 +1,242 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Analyzer rule that replaces the `grouping` aggregate function +//! with an expression derived from the internal grouping id. + +use std::cmp::Ordering; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::analyzer::AnalyzerRule; + +use arrow::datatypes::DataType; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{Column, DFSchemaRef, DataFusionError, Result, ScalarValue}; +use datafusion_expr::expr::{AggregateFunction, Alias}; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::utils::grouping_set_to_exprlist; +use datafusion_expr::{ + bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate, + Expr, Projection, +}; +use itertools::Itertools; + +/// Replaces grouping aggregation function with value derived from internal grouping id +#[derive(Default, Debug)] +pub struct ResolveGroupingFunction; + +impl ResolveGroupingFunction { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for ResolveGroupingFunction { + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + plan.transform_up(analyze_internal).data() + } + + fn name(&self) -> &str { + "resolve_grouping_function" + } +} + +fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result> { + Ok(grouping_set_to_exprlist(group_expr)? 
+ .into_iter() + .rev() + .enumerate() + .map(|(idx, v)| (v, idx)) + .collect::>()) +} + +fn replace_grouping_exprs( + input: Arc, + schema: DFSchemaRef, + group_expr: Vec, + aggr_expr: Vec, +) -> Result { + // Create HashMap from Expr to index in the grouping_id bitmap + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?; + let columns = schema.columns(); + let mut new_agg_expr = Vec::new(); + let mut projection_exprs = Vec::new(); + let grouping_id_len = if is_grouping_set { 1 } else { 0 }; + let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len; + projection_exprs.extend( + columns + .iter() + .take(group_expr_len) + .map(|column| Expr::Column(column.clone())), + ); + for (expr, column) in aggr_expr + .into_iter() + .zip(columns.into_iter().skip(group_expr_len + grouping_id_len)) + { + match expr { + Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => { + let grouping_expr = grouping_function_on_id( + function, + &group_expr_to_bitmap_index, + is_grouping_set, + )?; + projection_exprs.push(Expr::Alias(Alias::new( + grouping_expr, + column.relation, + column.name, + ))); + } + _ => { + projection_exprs.push(Expr::Column(column)); + new_agg_expr.push(expr); + } + } + } + // Recreate aggregate without grouping functions + let new_aggregate = + LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?); + // Create projection with grouping functions calculations + let projection = LogicalPlan::Projection(Projection::try_new( + projection_exprs, + new_aggregate.into(), + )?); + Ok(projection) +} + +fn analyze_internal(plan: LogicalPlan) -> Result> { + // rewrite any subqueries in the plan first + let transformed_plan = + plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?; + + let transformed_plan = transformed_plan.transform_data(|plan| match plan { + LogicalPlan::Aggregate(Aggregate { + input, + 
group_expr, + aggr_expr, + schema, + .. + }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes( + replace_grouping_exprs(input, schema, group_expr, aggr_expr)?, + )), + _ => Ok(Transformed::no(plan)), + })?; + + Ok(transformed_plan) +} + +fn is_grouping_function(expr: &Expr) -> bool { + // TODO: Do something better than name here should grouping be a built + // in expression? + matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping") +} + +fn contains_grouping_function(exprs: &[Expr]) -> bool { + exprs.iter().any(is_grouping_function) +} + +fn validate_args( + function: &AggregateFunction, + group_by_expr: &HashMap<&Expr, usize>, +) -> Result<()> { + let expr_not_in_group_by = function + .args + .iter() + .find(|expr| !group_by_expr.contains_key(expr)); + if let Some(expr) = expr_not_in_group_by { + Err(DataFusionError::Plan(format!( + "Argument {} to grouping function is not in grouping columns {}", + expr, + group_by_expr.keys().map(|e| e.to_string()).join(", ") + ))) + } else { + Ok(()) + } +} + +fn grouping_function_on_id( + function: &AggregateFunction, + group_by_expr: &HashMap<&Expr, usize>, + is_grouping_set: bool, +) -> Result { + validate_args(function, group_by_expr)?; + let args = &function.args; + + // Postgres allows grouping function for group by without grouping sets, the result is then + // always 0 + if !is_grouping_set { + return Ok(Expr::Literal(ScalarValue::from(0u32))); + } + + let group_by_expr_count = group_by_expr.len(); + let literal = |value: usize| { + if group_by_expr_count < 8 { + Expr::Literal(ScalarValue::from(value as u8)) + } else if group_by_expr_count < 16 { + Expr::Literal(ScalarValue::from(value as u16)) + } else if group_by_expr_count < 32 { + Expr::Literal(ScalarValue::from(value as u32)) + } else { + Expr::Literal(ScalarValue::from(value as u64)) + } + }; + + let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); + // The grouping 
call is exactly our internal grouping id + if args.len() == group_by_expr_count + && args + .iter() + .rev() + .enumerate() + .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) + { + return Ok(cast(grouping_id_column, DataType::UInt32)); + } + + args.iter() + .rev() + .enumerate() + .map(|(arg_idx, expr)| { + group_by_expr.get(expr).map(|group_by_idx| { + let group_by_bit = + bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx)); + match group_by_idx.cmp(&arg_idx) { + Ordering::Less => { + bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx)) + } + Ordering::Greater => { + bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx)) + } + Ordering::Equal => group_by_bit, + } + }) + }) + .collect::>>() + .and_then(|bit_exprs| { + bit_exprs + .into_iter() + .reduce(bitwise_or) + .map(|expr| cast(expr, DataType::UInt32)) + }) + .ok_or_else(|| { + DataFusionError::Internal( + "Grouping sets should contains at least one element".to_string(), + ) + }) +} diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 6dc92bae828b..b1962ffcc116 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -176,6 +176,7 @@ initial_logical_plan 02)--TableScan: simple_explain_test logical_plan after inline_table_scan SAME TEXT AS ABOVE logical_plan after expand_wildcard_rule SAME TEXT AS ABOVE +logical_plan after resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE logical_plan after count_wildcard_rule SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE diff --git a/datafusion/sqllogictest/test_files/grouping.slt b/datafusion/sqllogictest/test_files/grouping.slt new file mode 100644 index 000000000000..78fe40379574 --- /dev/null +++ b/datafusion/sqllogictest/test_files/grouping.slt @@ -0,0 +1,214 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more 
contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE test (c1 VARCHAR,c2 VARCHAR,c3 INT) as values +('a','A',1), ('b','B',2) + +# grouping_with_grouping_sets +query TTIIII +select + c1, + c2, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, c1) as g3 +from + test +group by + grouping sets ( + (c1, c2), + (c1), + (c2), + () + ) +order by + c1, c2, g0, g1, g2, g3; +---- +a A 0 0 0 0 +a NULL 0 1 1 2 +b B 0 0 0 0 +b NULL 0 1 1 2 +NULL A 1 0 2 1 +NULL B 1 0 2 1 +NULL NULL 1 1 3 3 + +# grouping_with_cube +query TTIIII +select + c1, + c2, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, c1) as g3 +from + test +group by + cube(c1, c2) +order by + c1, c2, g0, g1, g2, g3; +---- +a A 0 0 0 0 +a NULL 0 1 1 2 +b B 0 0 0 0 +b NULL 0 1 1 2 +NULL A 1 0 2 1 +NULL B 1 0 2 1 +NULL NULL 1 1 3 3 + +# grouping_with_rollup +query TTIIII +select + c1, + c2, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, c1) as g3 +from + test +group by + rollup(c1, c2) +order by + c1, c2, g0, g1, g2, g3; +---- +a A 0 0 0 0 +a NULL 0 1 1 2 +b B 0 0 0 0 +b NULL 0 1 1 2 +NULL NULL 1 1 3 3 + +query TTIIIIIIII +select + c1, + c2, + c3, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, 
c1) as g3, + grouping(c1, c2, c3) as g4, + grouping(c2, c3, c1) as g5, + grouping(c3, c2, c1) as g6 +from + test +group by + rollup(c1, c2, c3) +order by + c1, c2, g0, g1, g2, g3, g4, g5, g6; +---- +a A 1 0 0 0 0 0 0 0 +a A NULL 0 0 0 0 1 2 4 +a NULL NULL 0 1 1 2 3 6 6 +b B 2 0 0 0 0 0 0 0 +b B NULL 0 0 0 0 1 2 4 +b NULL NULL 0 1 1 2 3 6 6 +NULL NULL NULL 1 1 3 3 7 7 7 + +# grouping_with_add +query TTI +select + c1, + c2, + grouping(c1)+grouping(c2) as g0 +from + test +group by + rollup(c1, c2) +order by + c1, c2, g0; +---- +a A 0 +a NULL 1 +b B 0 +b NULL 1 +NULL NULL 2 + +# grouping_with_window_function +query TTIII +select + c1, + c2, + count(c1) as cnt, + grouping(c1)+ grouping(c2) as g0, + rank() over ( + partition by grouping(c1)+grouping(c2), + case when grouping(c2) = 0 then c1 end + order by + count(c1) desc + ) as rank_within_parent +from + test +group by + rollup(c1, c2) +order by + c1, + c2, + cnt, + g0 desc, + rank_within_parent; +---- +a A 1 0 1 +a NULL 1 1 1 +b B 1 0 1 +b NULL 1 1 1 +NULL NULL 2 2 1 + +# grouping_with_non_columns +query TIIIII +select + c1, + c3 + 1 as c3_add_one, + grouping(c1) as g0, + grouping(c3 + 1) as g1, + grouping(c1, c3 + 1) as g2, + grouping(c3 + 1, c1) as g3 +from + test +group by + grouping sets ( + (c1, c3 + 1), + (c3 + 1), + (c1) + ) +order by + c1, c3_add_one, g0, g1, g2, g3; +---- +a 2 0 0 0 0 +a NULL 0 1 1 2 +b 3 0 0 0 0 +b NULL 0 1 1 2 +NULL 2 1 0 2 1 +NULL 3 1 0 2 1 + +# postgres allows grounping function for group by without grouping sets/rollup/cube +query TI +select c1, grouping(c1) from test group by c1 order by c1; +---- +a 0 +b 0 + +statement error c2.*not in grouping columns +select c1, grouping(c2) from test group by c1; + +statement error c2.*not in grouping columns +select c1, grouping(c1, c2) from test group by CUBE(c1); + +statement error zero arguments +select c1, grouping() from test group by CUBE(c1); From 8acd5cb0f13620c85d86d9ebdb8475a3905f7714 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: 
Tue, 15 Oct 2024 19:42:52 +0200 Subject: [PATCH 2/2] PR comments --- .../src/analyzer/resolve_grouping_function.rs | 23 +++++++++++-------- .../sqllogictest/test_files/grouping.slt | 2 +- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index 24560ca481ce..16ebb8cd3972 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -27,7 +27,9 @@ use crate::analyzer::AnalyzerRule; use arrow::datatypes::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{Column, DFSchemaRef, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue, +}; use datafusion_expr::expr::{AggregateFunction, Alias}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::utils::grouping_set_to_exprlist; @@ -57,6 +59,10 @@ impl AnalyzerRule for ResolveGroupingFunction { } } +/// Create a map from grouping expr to index in the internal grouping id. +/// +/// For more details on how the grouping id bitmap works, see the documentation for +/// [`Aggregate::INTERNAL_GROUPING_ID`] fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result> { Ok(grouping_set_to_exprlist(group_expr)? .into_iter() @@ -151,6 +157,7 @@ fn contains_grouping_function(exprs: &[Expr]) -> bool { exprs.iter().any(is_grouping_function) } +/// Validate that the arguments to the grouping function are in the group by clause. 
fn validate_args( function: &AggregateFunction, group_by_expr: &HashMap<&Expr, usize>, @@ -160,11 +167,11 @@ fn validate_args( .iter() .find(|expr| !group_by_expr.contains_key(expr)); if let Some(expr) = expr_not_in_group_by { - Err(DataFusionError::Plan(format!( + plan_err!( "Argument {} to grouping function is not in grouping columns {}", expr, group_by_expr.keys().map(|e| e.to_string()).join(", ") - ))) + ) } else { Ok(()) } @@ -181,7 +188,7 @@ fn grouping_function_on_id( // Postgres allows grouping function for group by without grouping sets, the result is then // always 0 if !is_grouping_set { - return Ok(Expr::Literal(ScalarValue::from(0u32))); + return Ok(Expr::Literal(ScalarValue::from(0i32))); } let group_by_expr_count = group_by_expr.len(); @@ -206,7 +213,7 @@ fn grouping_function_on_id( .enumerate() .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) { - return Ok(cast(grouping_id_column, DataType::UInt32)); + return Ok(cast(grouping_id_column, DataType::Int32)); } args.iter() @@ -232,11 +239,9 @@ fn grouping_function_on_id( bit_exprs .into_iter() .reduce(bitwise_or) - .map(|expr| cast(expr, DataType::UInt32)) + .map(|expr| cast(expr, DataType::Int32)) }) .ok_or_else(|| { - DataFusionError::Internal( - "Grouping sets should contains at least one element".to_string(), - ) + internal_datafusion_err!("Grouping sets should contains at least one element") }) } diff --git a/datafusion/sqllogictest/test_files/grouping.slt b/datafusion/sqllogictest/test_files/grouping.slt index 78fe40379574..64d040d012f9 100644 --- a/datafusion/sqllogictest/test_files/grouping.slt +++ b/datafusion/sqllogictest/test_files/grouping.slt @@ -197,7 +197,7 @@ b NULL 0 1 1 2 NULL 2 1 0 2 1 NULL 3 1 0 2 1 -# postgres allows grounping function for group by without grouping sets/rollup/cube +# postgres allows grouping function for GROUP BY without GROUPING SETS/ROLLUP/CUBE query TI select c1, grouping(c1) from test group by c1 order by c1; ----