From 9a2d8cf825cb2e611179327853a3b39882f5d0f0 Mon Sep 17 00:00:00 2001 From: bgjackma Date: Fri, 20 Sep 2024 22:43:02 -0700 Subject: [PATCH 1/2] Implement GROUPING aggregate function (following Postgres behavior.) --- .../expr-common/src/groups_accumulator.rs | 18 ++ .../functions-aggregate/src/grouping.rs | 199 +++++++++++++++++- .../physical-plan/src/aggregates/mod.rs | 19 +- .../physical-plan/src/aggregates/row_hash.rs | 35 ++- .../src/aggregates/topk_stream.rs | 4 +- .../sqllogictest/test_files/grouping.slt | 198 +++++++++++++++++ 6 files changed, 452 insertions(+), 21 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/grouping.slt diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 8e81c51d8460f..938e83f54abab 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -235,6 +235,24 @@ pub trait GroupsAccumulator: Send { false } + /// Update this accumulator's groupings. Used for aggregates that + /// report data about the grouping strategy e.g. GROUPING. + /// + /// * `group_indices`: Indices of groups in the current grouping set + /// + /// * `group_mask`: Mask for the current grouping set (true means null/aggregated) + /// + /// * `total_num_groups`: the number of groups (the largest + /// group_index is thus `total_num_groups - 1`). + fn update_groupings( + &mut self, + _group_indices: &[usize], + _group_mask: &[bool], + _total_num_groups: usize, + ) -> Result<()> { + Ok(()) + } + /// Amount of memory used to store the state of this accumulator, /// in bytes. /// diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 6fb7c3800f4ed..b3f92e3909368 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -19,20 +19,34 @@ use std::any::Any; use std::fmt; +use std::sync::Arc; +use arrow::array::ArrayRef; +use arrow::array::AsArray; +use arrow::array::BooleanArray; +use arrow::array::UInt32Array; use arrow::datatypes::DataType; use arrow::datatypes::Field; +use arrow::datatypes::UInt32Type; +use datafusion_common::internal_datafusion_err; +use datafusion_common::internal_err; +use datafusion_common::plan_err; use datafusion_common::{not_impl_err, Result}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::utils::format_state_name; +use datafusion_expr::EmitTo; +use datafusion_expr::GroupsAccumulator; use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::PhysicalExpr; make_udaf_expr_and_func!( Grouping, grouping, expression, - "Returns 1 if the data is aggregated across the specified column or 0 for not aggregated in the result set.", + "Returns a bitmap where bit i is 1 if this row is aggregated across the ith argument to GROUPING and 0 otherwise.", grouping_udaf ); @@ -59,9 +73,55 @@ impl Grouping { /// Create a new GROUPING aggregate function. pub fn new() -> Self { Self { - signature: Signature::any(1, Volatility::Immutable), + signature: Signature::variadic_any(Volatility::Immutable), } } + + /// Create an accumulator for GROUPING(grouping_args) in a GROUP BY over group_exprs + /// A special creation function is necessary because GROUPING has unusual input requirements. + pub fn create_grouping_accumulator( + &self, + grouping_args: &[Arc], + group_exprs: &[(Arc, String)], + ) -> Result> { + if grouping_args.len() > 32 { + return plan_err!( + "GROUPING is supported for up to 32 columns. Consider another \ + GROUPING statement if you need to aggregate over more columns." + ); + } + // The PhysicalExprs of grouping_exprs must be Column PhysicalExpr. Because if + // the group by PhysicalExpr in SQL is non-Column PhysicalExpr, then there is + // a ProjectionExec before AggregateExec to convert the non-column PhysicalExpr + // to Column PhysicalExpr. + let column_index = + |expr: &Arc| match expr.as_any().downcast_ref::() { + Some(column) => Ok(column.index()), + None => internal_err!("Grouping doesn't support expr: {}", expr), + }; + let group_by_columns: Result> = + group_exprs.iter().map(|(e, _)| column_index(e)).collect(); + let group_by_columns = group_by_columns?; + + let arg_columns: Result> = + grouping_args.iter().map(column_index).collect(); + let expr_indices: Result> = arg_columns? + .iter() + .map(|arg| { + group_by_columns + .iter() + .position(|gb| arg == gb) + .ok_or_else(|| { + internal_datafusion_err!("Invalid grouping set indices.") + }) + }) + .collect(); + + Ok(Box::new(GroupingAccumulator { + grouping_ids: vec![], + expr_indices: expr_indices?, + })) + } } impl AggregateUDFImpl for Grouping { @@ -78,20 +138,145 @@ impl AggregateUDFImpl for Grouping { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Int32) + Ok(DataType::UInt32) } fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![Field::new( format_state_name(args.name, "grouping"), - DataType::Int32, + DataType::UInt32, true, )]) } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { - not_impl_err!( - "physical plan is not yet implemented for GROUPING aggregate function" - ) + not_impl_err!("The GROUPING function requires a GROUP BY context.") + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + false + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + // Use `create_grouping_accumulator` instead. + not_impl_err!("GROUPING is not supported when invoked this way.") + } +} + +struct GroupingAccumulator { + // Grouping ID value for each group + grouping_ids: Vec, + // Indices of GROUPING arguments as they appear in the GROUPING SET + expr_indices: Vec, +} + +impl GroupingAccumulator { + fn mask_to_id(&self, mask: &[bool]) -> Result { + let mut id: u32 = 0; + // rightmost entry is the LSB + for (i, &idx) in self.expr_indices.iter().rev().enumerate() { + match mask.get(idx) { + Some(true) => id |= 1 << i, + Some(false) => {} + None => { + return internal_err!( + "Index out of bounds while calculating GROUPING id." + ) + } + } + } + Ok(id) + } +} + +impl GroupsAccumulator for GroupingAccumulator { + fn update_batch( + &mut self, + _values: &[ArrayRef], + _group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + _total_num_groups: usize, + ) -> Result<()> { + // No-op since GROUPING doesn't care about values + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to merge_batch"); + self.grouping_ids.resize(total_num_groups, 0); + let other_ids = values[0].as_primitive::(); + accumulate(group_indices, other_ids, None, |group_index, group_id| { + self.grouping_ids[group_index] |= group_id; + }); + Ok(()) + } + + fn update_groupings( + &mut self, + group_indices: &[usize], + group_mask: &[bool], + total_num_groups: usize, + ) -> Result<()> { + self.grouping_ids.resize(total_num_groups, 0); + let group_id = self.mask_to_id(group_mask)?; + for &group_idx in group_indices { + self.grouping_ids[group_idx] = group_id; + } + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let values = emit_to.take_needed(&mut self.grouping_ids); + let values = UInt32Array::new(values.into(), None); + Ok(Arc::new(values)) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + self.evaluate(emit_to).map(|arr| vec![arr]) + } + + fn size(&self) -> usize { + self.grouping_ids.capacity() * std::mem::size_of::() + } +} + +#[cfg(test)] +mod tests { + use crate::grouping::GroupingAccumulator; + + #[test] + fn test_group_ids() { + let grouping = GroupingAccumulator { + grouping_ids: vec![], + expr_indices: vec![0, 1, 3, 2], + }; + let cases = vec![ + (0b0000, vec![false, false, false, false]), + (0b1000, vec![true, false, false, false]), + (0b0100, vec![false, true, false, false]), + (0b1010, vec![true, false, false, true]), + (0b1001, vec![true, false, true, false]), + ]; + for (expected, input) in cases { + assert_eq!(expected, grouping.mask_to_id(&input).unwrap()); + } + } + #[test] + fn test_bad_index() { + let grouping = GroupingAccumulator { + grouping_ids: vec![], + expr_indices: vec![5], + }; + let res = grouping.mask_to_id(&vec![false]); + assert_eq!(res.is_err(), true) } } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index c3bc7b042e655..f2f68c9a17e17 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -238,6 +238,13 @@ impl PartialEq for PhysicalGroupBy { } } +pub(crate) struct PhysicalGroupingSet { + /// Exprs/columns over which the grouping set is aggregated + values: Vec, + /// True if the corresponding value is null in this grouping set + mask: Vec, +} + enum StreamType { AggregateStream(AggregateStream), GroupedHash(GroupedHashAggregateStream), @@ -1140,13 +1147,13 @@ fn evaluate_optional( /// - `batch`: the `RecordBatch` to evaluate against /// /// Returns: A Vec of Vecs of Array of results -/// The outer Vec appears to be for grouping sets +/// The outer Vec contains the grouping sets defined by `group_by.groups` /// The inner Vec contains the results per expression /// The inner-inner Array contains the results per row pub(crate) fn evaluate_group_by( group_by: &PhysicalGroupBy, batch: &RecordBatch, -) -> Result>> { +) -> Result> { let exprs: Vec = group_by .expr .iter() @@ -1169,7 +1176,7 @@ pub(crate) fn evaluate_group_by( .groups .iter() .map(|group| { - group + let v = group .iter() .enumerate() .map(|(idx, is_null)| { @@ -1179,7 +1186,11 @@ pub(crate) fn evaluate_group_by( Arc::clone(&exprs[idx]) } }) - .collect() + .collect(); + PhysicalGroupingSet { + values: v, + mask: group.clone(), + } }) .collect()) } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 60efc77112167..d03c148eceb87 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -45,6 +45,7 @@ use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate::grouping::Grouping; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; @@ -477,7 +478,7 @@ impl GroupedHashAggregateStream { // Instantiate the accumulators let accumulators: Vec<_> = aggregate_exprs .iter() - .map(create_group_accumulator) + .map(|agg_expr| create_group_accumulator(agg_expr, &agg_group_by)) .collect::>()?; let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); @@ -578,7 +579,13 @@ impl GroupedHashAggregateStream { /// [`GroupsAccumulatorAdapter`] if not. pub(crate) fn create_group_accumulator( agg_expr: &AggregateFunctionExpr, + group_by: &PhysicalGroupBy, ) -> Result> { + // GROUPING is a special fxn that exposes info about group organization + if let Some(grouping) = agg_expr.fun().inner().as_any().downcast_ref::() { + let args = agg_expr.all_expressions().args; + return grouping.create_grouping_accumulator(&args, &group_by.expr); + } if agg_expr.groups_accumulator_supported() { agg_expr.create_groups_accumulator() } else { @@ -740,7 +747,7 @@ impl GroupedHashAggregateStream { /// Perform group-by aggregation for the given [`RecordBatch`]. fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> { // Evaluate the grouping expressions - let group_by_values = if self.spill_state.is_stream_merging { + let grouping_sets = if self.spill_state.is_stream_merging { evaluate_group_by(&self.spill_state.merging_group_by, &batch)? } else { evaluate_group_by(&self.group_by, &batch)? @@ -761,18 +768,18 @@ impl GroupedHashAggregateStream { evaluate_optional(&self.filter_expressions, &batch)? }; - for group_values in &group_by_values { + for grouping_set in grouping_sets.iter() { // calculate the group indices for each input row let starting_num_groups = self.group_values.len(); self.group_values - .intern(group_values, &mut self.current_group_indices)?; + .intern(&grouping_set.values, &mut self.current_group_indices)?; let group_indices = &self.current_group_indices; // Update ordering information if necessary let total_num_groups = self.group_values.len(); if total_num_groups > starting_num_groups { self.group_ordering.new_groups( - group_values, + &grouping_set.values, group_indices, total_num_groups, )?; @@ -802,6 +809,12 @@ impl GroupedHashAggregateStream { opt_filter, total_num_groups, )?; + // Update aggregates that care about which exprs are masked + acc.update_groupings( + group_indices, + &grouping_set.mask, + total_num_groups, + )?; } _ => { // if aggregation is over intermediate states, @@ -870,6 +883,7 @@ impl GroupedHashAggregateStream { | AggregateMode::SinglePartitioned => output.push(acc.evaluate(emit_to)?), } } + debug!("Output: {:?}", output); // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is // over the target memory size after emission, we can emit again rather than returning Err. @@ -1052,9 +1066,14 @@ impl GroupedHashAggregateStream { let input_values = evaluate_many(&self.aggregate_arguments, &batch)?; let filter_values = evaluate_optional(&self.filter_expressions, &batch)?; - let mut output = group_values.first().cloned().ok_or_else(|| { - internal_datafusion_err!("group_values expected to have at least one element") - })?; + let mut output = group_values + .first() + .map(|gs| gs.values.clone()) + .ok_or_else(|| { + internal_datafusion_err!( + "group_values expected to have at least one element" + ) + })?; let iter = self .accumulators diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index 075d8c5f28833..a2d13607a5a75 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -135,11 +135,11 @@ impl Stream for GroupedTopKAggregateStream { "Exactly 1 group value required" ); assert_eq!( - group_by_values[0].len(), + group_by_values[0].values.len(), 1, "Exactly 1 group value required" ); - let group_by_values = Arc::clone(&group_by_values[0][0]); + let group_by_values = Arc::clone(&group_by_values[0].values[0]); let input_values = evaluate_many( &self.aggregate_arguments, batches.first().unwrap(), diff --git a/datafusion/sqllogictest/test_files/grouping.slt b/datafusion/sqllogictest/test_files/grouping.slt new file mode 100644 index 0000000000000..7b85ed3b2e40d --- /dev/null +++ b/datafusion/sqllogictest/test_files/grouping.slt @@ -0,0 +1,198 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE test (c1 VARCHAR,c2 VARCHAR,c3 INT) as values +('a','A',1), ('b','B',2) + +# grouping_with_grouping_sets +query TTIIII +select + c1, + c2, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, c1) as g3 +from + test +group by + grouping sets ( + (c1, c2), + (c1), + (c2), + () + ) +order by + c1, c2, g0, g1, g2, g3; +---- +a A 0 0 0 0 +a NULL 0 1 1 2 +b B 0 0 0 0 +b NULL 0 1 1 2 +NULL A 1 0 2 1 +NULL B 1 0 2 1 +NULL NULL 1 1 3 3 + +# grouping_with_cube +query TTIIII +select + c1, + c2, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, c1) as g3 +from + test +group by + cube(c1, c2) +order by + c1, c2, g0, g1, g2, g3; +---- +a A 0 0 0 0 +a NULL 0 1 1 2 +b B 0 0 0 0 +b NULL 0 1 1 2 +NULL A 1 0 2 1 +NULL B 1 0 2 1 +NULL NULL 1 1 3 3 + +# grouping_with_rollup +query TTIIII +select + c1, + c2, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, c1) as g3 +from + test +group by + rollup(c1, c2) +order by + c1, c2, g0, g1, g2, g3; +---- +a A 0 0 0 0 +a NULL 0 1 1 2 +b B 0 0 0 0 +b NULL 0 1 1 2 +NULL NULL 1 1 3 3 + +query TTIIIIIIII +select + c1, + c2, + c3, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, c1) as g3, + grouping(c1, c2, c3) as g4, + grouping(c2, c3, c1) as g5, + grouping(c3, c2, c1) as g6 +from + test +group by + rollup(c1, c2, c3) +order by + c1, c2, g0, g1, g2, g3, g4, g5, g6; +---- +a A 1 0 0 0 0 0 0 0 +a A NULL 0 0 0 0 1 2 4 +a NULL NULL 0 1 1 2 3 6 6 +b B 2 0 0 0 0 0 0 0 +b B NULL 0 0 0 0 1 2 4 +b NULL NULL 0 1 1 2 3 6 6 +NULL NULL NULL 1 1 3 3 7 7 7 + +# grouping_with_add +query TTI +select + c1, + c2, + grouping(c1)+grouping(c2) as g0 +from + test +group by + rollup(c1, c2) +order by + c1, c2, g0; +---- +a A 0 +a NULL 1 +b B 0 +b NULL 1 +NULL NULL 2 + +#grouping_with_windown_function +query TTIII +select + c1, + c2, + count(c1) as cnt, + grouping(c1)+ grouping(c2) as g0, + rank() over ( + partition by grouping(c1)+grouping(c2), + case when grouping(c2) = 0 then c1 end + order by + count(c1) desc + ) as rank_within_parent +from + test +group by + rollup(c1, c2) +order by + c1, + c2, + cnt, + g0 desc, + rank_within_parent; +---- +a A 1 0 1 +a NULL 1 1 1 +b B 1 0 1 +b NULL 1 1 1 +NULL NULL 2 2 1 + +# grouping_with_non_columns +query TIIIII +select + c1, + c3 + 1 as c3_add_one, + grouping(c1) as g0, + grouping(c3 + 1) as g1, + grouping(c1, c3 + 1) as g2, + grouping(c3 + 1, c1) as g3 +from + test +group by + grouping sets ( + (c1, c3 + 1), + (c3 + 1), + (c1) + ) +order by + c1, c3_add_one, g0, g1, g2, g3; +---- +a 2 0 0 0 0 +a NULL 0 1 1 2 +b 3 0 0 0 0 +b NULL 0 1 1 2 +NULL 2 1 0 2 1 +NULL 3 1 0 2 1 From 9c0ed8afa6dedd7800d842b0da42128f10fecfad Mon Sep 17 00:00:00 2001 From: bgjackma Date: Wed, 25 Sep 2024 11:35:34 -0700 Subject: [PATCH 2/2] Satisfy Clippy --- datafusion/functions-aggregate/src/grouping.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index b3f92e3909368..37cd2e7b9fd99 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -276,7 +276,7 @@ mod tests { grouping_ids: vec![], expr_indices: vec![5], }; - let res = grouping.mask_to_id(&vec![false]); - assert_eq!(res.is_err(), true) + let res = grouping.mask_to_id(&[false]); + assert!(res.is_err()) } }