diff --git a/Cargo.lock b/Cargo.lock index a2939f4257127..734153d72c2b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2377,6 +2377,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-expr-common", + "datafusion-functions", "datafusion-functions-aggregate", "datafusion-functions-window", "datafusion-functions-window-common", diff --git a/datafusion/expr/src/grouping.rs b/datafusion/expr/src/grouping.rs new file mode 100644 index 0000000000000..2873d6f526014 --- /dev/null +++ b/datafusion/expr/src/grouping.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::{ + array::{Array, Int32Array}, + datatypes::DataType, +}; +use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_expr_common::{ + accumulator::Accumulator, + signature::{Signature, Volatility}, +}; +use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs; + +use crate::{ + expr::{AggregateFunction, ScalarFunction}, + utils::grouping_set_to_exprlist, + Aggregate, AggregateUDF, AggregateUDFImpl, Expr, +}; + +// To avoid adding datafusion-functions-aggregate dependency, implement a DummyGroupingUDAF here +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct DummyGroupingUDAF { + signature: Signature, +} + +impl Default for DummyGroupingUDAF { + fn default() -> Self { + Self::new() + } +} + +impl DummyGroupingUDAF { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } + + pub fn from_scalar_function( + func: &ScalarFunction, + agg: &Aggregate, + ) -> Result { + if func.args.len() != 1 && func.args.len() != 2 { + return internal_err!("Grouping function must have one or two arguments"); + } + let grouping_expr = grouping_set_to_exprlist(&agg.group_expr)?; + let args = if func.args.len() == 1 { + grouping_expr.iter().map(|e| (*e).clone()).collect() + } else if let Expr::Literal(ScalarValue::List(list), _) = &func.args[1] { + if list.len() != 1 { + return internal_err!("The second argument of grouping function must be a list with exactly one element"); + } + + let grouping_expr = grouping_expr.into_iter().rev().collect::>(); + let values = list + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec(); + values + .iter() + .map(|i: &i32| grouping_expr[*i as usize].clone()) + .collect() + } else { + return internal_err!( + "The second argument of grouping function must be a list" + ); + }; + Ok(AggregateFunction::new_udf( + Arc::new(AggregateUDF::from(Self::new())), + args, + false, + None, + vec![], + None, + )) + } +} + +impl AggregateUDFImpl for DummyGroupingUDAF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "grouping" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + todo!() + } +} diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 1c9734a89bd37..bf1eb84fdd7f1 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -51,6 +51,7 @@ pub mod expr_fn; pub mod expr_rewriter; pub mod expr_schema; pub mod function; +pub mod grouping; pub mod select_expr; pub mod groups_accumulator { pub use datafusion_expr_common::groups_accumulator::*; diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 4d1da1dad5949..c970e501d9d3e 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -35,25 +35,25 @@ make_udaf_expr_and_func!( Grouping, grouping, expression, - "Returns 1 if the data is aggregated across the specified column or 0 for not aggregated in the result set.", + "Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn).", grouping_udaf ); #[user_doc( doc_section(label = "General Functions"), - description = "Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set.", + description = "Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn).", syntax_example = "grouping(expression)", sql_example = r#"```sql > SELECT column_name, GROUPING(column_name) AS group_column FROM table_name GROUP BY GROUPING SETS ((column_name), ()); -+-------------+-------------+ ++-------------+--------------+ | column_name | group_column | -+-------------+-------------+ -| value1 | 0 | -| value2 | 0 | -| NULL | 1 | -+-------------+-------------+ ++-------------+--------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+--------------+ ```"#, argument( name = "expression", diff --git a/datafusion/functions/src/core/grouping.rs b/datafusion/functions/src/core/grouping.rs new file mode 100644 index 0000000000000..97c7b5c024246 --- /dev/null +++ b/datafusion/functions/src/core/grouping.rs @@ -0,0 +1,440 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, AsArray, Int32Array, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, +}; +use arrow::compute::cast; +use arrow::datatypes::{DataType, Field, Int32Type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_macros::user_doc; +use std::any::Any; +use std::sync::Arc; + +use crate::utils::make_scalar_function; + +macro_rules! grouping_id { + ($grouping_id:expr, $indices:expr, $type:ty, $array_type:ty) => {{ + let grouping_id = match $grouping_id.as_any().downcast_ref::<$array_type>() { + Some(array) => array, + None => { + return exec_err!( + "grouping function requires {} grouping_id array", + stringify!($type) + ) + } + }; + grouping_id + .iter() + .zip($indices.iter()) + .map(|(grouping_id, indices)| { + grouping_id.map(|grouping_id| { + let mut result = 0 as $type; + match indices { + Some(indices) => { + for index in indices.as_primitive::().iter() { + if let Some(index) = index { + let bit = (grouping_id >> index) & 1; + result = (result << 1) | bit; + } + } + } + None => { + result = grouping_id; + } + } + result as i32 + }) + }) + .collect() + }}; +} + +#[user_doc( + doc_section(label = "Other Functions"), + description = "Developer API: Returns the level of grouping, equals to (((grouping_id >> array[0]) & 1) << (n-1)) + (((grouping_id >> array[1]) & 1) << (n-2)) + ... + (((grouping_id >> array[n-1]) & 1) << 0). Returns grouping_id if indices is not provided.", + syntax_example = "grouping(grouping_id[, indices])", + sql_example = r#"```sql +> SELECT grouping(__grouping_id, make_array(0)) FROM table GROUP BY GROUPING SETS ((a), (b)); ++----------------+ +| grouping | ++----------------+ +| 1 | +| 0 | ++----------------+ +```"#, + argument( + name = "grouping_id", + description = "The internal grouping ID column (UInt8/16/32/64)" + ), + argument( + name = "indices", + description = "The indices of the column in the grouping set (Int32)" + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct GroupingFunc { + signature: Signature, +} + +impl Default for GroupingFunc { + fn default() -> Self { + GroupingFunc::new() + } +} + +impl GroupingFunc { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for GroupingFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "grouping" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(grouping_inner, vec![])(&args.args) + } + + fn short_circuits(&self) -> bool { + false + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 && arg_types.len() != 1 { + return exec_err!( + "grouping function requires 1 or 2 arguments, got {}", + arg_types.len() + ); + } + + if !arg_types[0].is_unsigned_integer() { + return exec_err!( + "grouping function requires unsigned integer for first argument, got {}", + arg_types[0] + ); + } + + if arg_types.len() == 1 { + return Ok(vec![arg_types[0].clone()]); + } + + let DataType::List(field) = &arg_types[1] else { + return exec_err!( + "grouping function requires list for second argument, got {}", + arg_types[1] + ); + }; + + if !field.data_type().is_integer() { + return exec_err!( + "grouping function requires list of integers for second argument, got {}", + arg_types[1] + ); + } + + Ok(vec![ + arg_types[0].clone(), + DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))), + ]) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn grouping_inner(args: &[ArrayRef]) -> Result { + if args.len() != 2 && args.len() != 1 { + return exec_err!( + "grouping function requires 1 or 2 arguments, got {}", + args.len() + ); + } + + if args.len() == 1 { + return cast(&args[0], &DataType::Int32).map_err(|e| e.into()); + } + + let grouping_id = &args[0]; + let indices = &args[1]; + let indices = indices.as_list::(); + + let result: Int32Array = match grouping_id.data_type() { + DataType::UInt8 => grouping_id!(grouping_id, indices, u8, UInt8Array), + DataType::UInt16 => grouping_id!(grouping_id, indices, u16, UInt16Array), + DataType::UInt32 => grouping_id!(grouping_id, indices, u32, UInt32Array), + DataType::UInt64 => grouping_id!(grouping_id, indices, u64, UInt64Array), + _ => { + return exec_err!( + "grouping function requires UInt8/16/32/64 for grouping_id, got {}", + grouping_id.data_type() + ) + } + }; + + Ok(Arc::new(result)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::{ + array::{ + Int32Array, ListArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + }, + datatypes::Int32Type, + }; + use datafusion_common::{config::ConfigOptions, Result, ScalarValue}; + + #[test] + fn test_grouping_uint8() -> Result<()> { + let grouping_id = UInt8Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), Some(1)])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(indices.clone()), + ))), + ]; + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + + let func = GroupingFunc::new(); + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), + config_options: Arc::new(ConfigOptions::default()), + })?; + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array result"), + }; + + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.values().to_vec(), vec![2, 1, 3, 0]); + Ok(()) + } + + #[test] + fn test_grouping_uint16() -> Result<()> { + let grouping_id = UInt16Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), Some(1)])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(indices), + ))), + ]; + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + + let func = GroupingFunc::new(); + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), + config_options: Arc::new(ConfigOptions::default()), + })?; + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array result"), + }; + + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.values().to_vec(), vec![2, 1, 3, 0]); + Ok(()) + } + + #[test] + fn test_grouping_uint32() -> Result<()> { + let grouping_id = UInt32Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), Some(1)])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(indices), + ))), + ]; + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + + let func = GroupingFunc::new(); + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + number_rows: 4, + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), + return_field: Arc::new(Field::new("f", DataType::Int32, true)), + config_options: Arc::new(ConfigOptions::default()), + })?; + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array result"), + }; + + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.values().to_vec(), vec![2, 1, 3, 0]); + Ok(()) + } + + #[test] + fn test_grouping_uint64() -> Result<()> { + let grouping_id = UInt64Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0), Some(1)])]; + + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(indices), + ))), + ]; + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + + let func = GroupingFunc::new(); + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), + number_rows: 4, + return_field: Arc::new(Field::new("f", DataType::Int32, true)), + config_options: Arc::new(ConfigOptions::default()), + })?; + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array result"), + }; + + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.values().to_vec(), vec![2, 1, 3, 0]); + Ok(()) + } + + #[test] + fn test_grouping_with_invalid_args() -> Result<()> { + let grouping_id = UInt8Array::from(vec![Some(1), Some(2), Some(3), Some(4)]); + let indices = vec![Some(vec![Some(0)])]; + + // Test with too many arguments + let args = vec![ + ColumnarValue::Array(Arc::new(grouping_id)), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(indices.clone()), + ))), + ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), + ]; + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + + let func = GroupingFunc::new(); + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), + number_rows: 4, + return_field: Arc::new(Field::new("f", DataType::Int32, true)), + config_options: Arc::new(ConfigOptions::default()), + }); + assert!(result.is_err()); + + // Test with invalid array type + let args = vec![ + ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(1)]))), + ColumnarValue::Scalar(ScalarValue::List(Arc::new( + ListArray::from_iter_primitive::(indices), + ))), + ]; + + let func = GroupingFunc::new(); + let result = func.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields: arg_fields_owned + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), + number_rows: 1, + return_field: Arc::new(Field::new("f", DataType::Int32, true)), + config_options: Arc::new(ConfigOptions::default()), + }); + assert!(result.is_err()); + Ok(()) + } +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index db080cd628478..0a1e0b6b07176 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -27,6 +27,7 @@ pub mod expr_ext; pub mod getfield; pub mod greatest; mod greatest_least_utils; +pub mod grouping; pub mod least; pub mod named_struct; pub mod nullif; @@ -55,6 +56,7 @@ make_udf_function!(least::LeastFunc, least); make_udf_function!(union_extract::UnionExtractFun, union_extract); make_udf_function!(union_tag::UnionTagFunc, union_tag); make_udf_function!(version::VersionFunc, version); +make_udf_function!(grouping::GroupingFunc, grouping); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index f10510e0973c3..c824cd20f8987 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -46,6 +46,7 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } +datafusion-functions = { workspace = true } datafusion-physical-expr = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index fa7ff1b8b19d6..46177884a26d1 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -18,25 +18,21 @@ //! Analyzed rule to replace TableScan references //! such as DataFrames and Views and inlines the LogicalPlan. -use std::cmp::Ordering; use std::collections::HashMap; use std::sync::Arc; use crate::analyzer::AnalyzerRule; -use arrow::datatypes::DataType; +use arrow::array::ListArray; +use arrow::datatypes::Int32Type; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{ - internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue, -}; +use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{AggregateFunction, Alias}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::utils::grouping_set_to_exprlist; -use datafusion_expr::{ - bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate, - Expr, Projection, -}; +use datafusion_expr::{Aggregate, Expr, Projection}; +use datafusion_functions::core::grouping; use itertools::Itertools; /// Replaces grouping aggregation function with value derived from internal grouping id @@ -193,18 +189,6 @@ fn grouping_function_on_id( } let group_by_expr_count = group_by_expr.len(); - let literal = |value: usize| { - if group_by_expr_count < 8 { - Expr::Literal(ScalarValue::from(value as u8), None) - } else if group_by_expr_count < 16 { - Expr::Literal(ScalarValue::from(value as u16), None) - } else if group_by_expr_count < 32 { - Expr::Literal(ScalarValue::from(value as u32), None) - } else { - Expr::Literal(ScalarValue::from(value as u64), None) - } - }; - let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); // The grouping call is exactly our internal grouping id if args.len() == group_by_expr_count @@ -214,35 +198,23 @@ fn grouping_function_on_id( .enumerate() .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) { - return Ok(cast(grouping_id_column, DataType::Int32)); + return Ok(grouping().call(vec![grouping_id_column])); } - args.iter() - .rev() - .enumerate() - .map(|(arg_idx, expr)| { - group_by_expr.get(expr).map(|group_by_idx| { - let group_by_bit = - bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx)); - match group_by_idx.cmp(&arg_idx) { - Ordering::Less => { - bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx)) - } - Ordering::Greater => { - bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx)) - } - Ordering::Equal => group_by_bit, - } - }) - }) - .collect::>>() - .and_then(|bit_exprs| { - bit_exprs - .into_iter() - .reduce(bitwise_or) - .map(|expr| cast(expr, DataType::Int32)) - }) - .ok_or_else(|| { - internal_datafusion_err!("Grouping sets should contains at least one element") + let args = args + .iter() + .flat_map(|expr| { + group_by_expr + .get(expr) + .map(|group_by_idx| Some(*group_by_idx as i32)) }) + .collect::>(); + + let indices = Expr::Literal( + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::( + vec![Some(args)], + ))), + None, + ); + Ok(grouping().call(vec![grouping_id_column, indices])) } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 8b3791017a8af..8e8a104f53ade 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -27,8 +27,11 @@ use datafusion_common::{ Column, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, - LogicalPlanBuilder, Projection, SortExpr, Unnest, Window, + expr::{self}, + grouping::DummyGroupingUDAF, + utils::grouping_set_to_exprlist, + Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Projection, SortExpr, Unnest, + Window, }; use indexmap::IndexSet; @@ -195,6 +198,17 @@ pub(crate) fn unproject_agg_exprs( agg: &Aggregate, windows: Option<&[&Window]>, ) -> Result { + // replace grouping function + let expr = expr + .transform(|sub_expr| match sub_expr { + Expr::ScalarFunction(grouping) if grouping.name() == "grouping" => { + Ok(Transformed::yes(Expr::AggregateFunction( + DummyGroupingUDAF::from_scalar_function(&grouping, agg)?, + ))) + } + _ => Ok(Transformed::no(sub_expr)), + }) + .map(|e| e.data)?; expr.transform(|sub_expr| { if let Expr::Column(c) = sub_expr { if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 7aa982dcf3dd9..248d00ed7b3eb 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -15,19 +15,22 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::array::ListArray; +use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use datafusion_common::{ - assert_contains, Column, DFSchema, DFSchemaRef, DataFusionError, Result, + assert_contains, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::test::function_stub::{ count_udaf, max_udaf, min_udaf, sum, sum_udaf, }; use datafusion_expr::{ - cast, col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan, - LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + cast, col, lit, table_scan, wildcard, Aggregate, EmptyRelation, Expr, Extension, + LogicalPlan, LogicalPlanBuilder, Union, UserDefinedLogicalNode, + UserDefinedLogicalNodeCore, }; +use datafusion_functions::core::grouping; use datafusion_functions::unicode; use datafusion_functions_aggregate::grouping::grouping_udaf; use datafusion_functions_nested::make_array::make_array_udf; @@ -2605,6 +2608,62 @@ fn test_not_ilike_filter_with_escape() { ); } +#[test] +fn test_grouping() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + Field::new("c3", DataType::Int32, false), + ]); + let table_scan = table_scan(Some("test"), &schema, Some(vec![0, 1, 2]))?.build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![Expr::GroupingSet( + datafusion_expr::GroupingSet::GroupingSets(vec![ + vec![col("c1"), col("c2")], + vec![col("c1")], + vec![col("c2")], + vec![], + ]), + )], + vec![sum(col("c3"))], + )? + .build()?; + + let group1 = + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::( + vec![Some(vec![Some(1)])], + ))); + let group2 = + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::( + vec![Some(vec![Some(0)])], + ))); + let group3 = + ScalarValue::List(Arc::new(ListArray::from_iter_primitive::( + vec![Some(vec![Some(1), Some(0)])], + ))); + let project = LogicalPlanBuilder::from(plan) + .project(vec![ + grouping() + .call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group1)]) + .alias("grouping(test.c1)"), + grouping() + .call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group2)]) + .alias("grouping(test.c2)"), + grouping() + .call(vec![col(Aggregate::INTERNAL_GROUPING_ID), lit(group3)]) + .alias("grouping(test.c1,test.c2)"), + ])? + .build()?; + let unparser = Unparser::new(&UnparserPostgreSqlDialect {}); + let sql = unparser.plan_to_sql(&project)?; + assert_snapshot!( + sql, + @r#"SELECT grouping("test"."c1") AS "grouping(test.c1)", grouping("test"."c2") AS "grouping(test.c2)", grouping("test"."c1", "test"."c2") AS "grouping(test.c1,test.c2)" FROM "test" GROUP BY GROUPING SETS (("test"."c1", "test"."c2"), ("test"."c1"), ("test"."c2"), ())"# + ); + Ok(()) +} + #[test] fn test_struct_expr() { let statement = generate_round_trip_statement( diff --git a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs index c3599a2635ffa..d3950240d6e26 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs @@ -45,7 +45,14 @@ pub fn to_substrait_rel( plan: &LogicalPlan, ) -> datafusion::common::Result> { match plan { - LogicalPlan::Projection(plan) => producer.handle_projection(plan), + LogicalPlan::Projection(plan) => { + if producer.has_grouping_set(plan) { + let plan = producer.unproject_grouping_set(plan)?; + producer.handle_projection(&plan) + } else { + producer.handle_projection(plan) + } + } LogicalPlan::Filter(plan) => producer.handle_filter(plan), LogicalPlan::Window(plan) => producer.handle_window(plan), LogicalPlan::Aggregate(plan) => producer.handle_aggregate(plan), diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index 56edfac5769cf..a66bf617ba715 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::extensions::Extensions; use crate::logical_plan::producer::{ from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, @@ -24,10 +26,11 @@ use crate::logical_plan::producer::{ from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, }; -use datafusion::common::{substrait_err, Column, DFSchemaRef, ScalarValue}; +use datafusion::common::{internal_err, substrait_err, Column, DFSchemaRef, ScalarValue}; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::SessionState; use datafusion::logical_expr::expr::{Alias, InList, InSubquery, WindowFunction}; +use datafusion::logical_expr::grouping::DummyGroupingUDAF; use datafusion::logical_expr::{ expr, Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, Extension, Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort, @@ -346,6 +349,60 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> datafusion::common::Result { from_in_subquery(self, in_subquery, schema) } + + fn has_grouping_set(&self, plan: &Projection) -> bool { + for expr in plan.expr.iter() { + let Expr::Alias(Alias { expr, .. }) = expr else { + continue; + }; + let Expr::ScalarFunction(expr::ScalarFunction { func, .. }) = expr.as_ref() + else { + continue; + }; + if func.name() == "grouping" { + return true; + } + } + false + } + + fn unproject_grouping_set( + &self, + plan: &Projection, + ) -> datafusion::common::Result { + let input = plan.input.as_ref(); + let LogicalPlan::Aggregate(agg) = input else { + return internal_err!( + "Projecting grouping set is not supported for non-aggregate input" + ); + }; + + let mut exprs = vec![]; + let mut agg_expr = agg.aggr_expr.clone(); + + for expr in plan.expr.iter() { + if let Expr::Alias(Alias { expr, name, .. }) = expr { + if let Expr::ScalarFunction(f @ expr::ScalarFunction { func, .. }) = + expr.as_ref() + { + if func.name() == "grouping" { + exprs.push(Expr::Column(Column::from_name(name))); + agg_expr.push( + Expr::AggregateFunction( + DummyGroupingUDAF::from_scalar_function(f, agg)?, + ) + .alias(name), + ); + continue; + } + }; + }; + exprs.push(expr.clone()); + } + let agg = + Aggregate::try_new(Arc::clone(&agg.input), agg.group_expr.clone(), agg_expr)?; + Projection::try_new(exprs, Arc::new(LogicalPlan::Aggregate(agg))) + } } pub struct DefaultSubstraitProducer<'a> { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 39e4984ab9f79..6e5922bf7a194 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -299,6 +299,14 @@ async fn aggregate_grouping_sets() -> Result<()> { .await } +#[tokio::test] +async fn aggregate_grouping_sets_with_grouping() -> Result<()> { + roundtrip( + "SELECT a, c, grouping(a) as g1, grouping(c) as g2, grouping(a, c) as g3, avg(b) FROM data GROUP BY GROUPING SETS ((a, c), (a), ())", + ) + .await +} + #[tokio::test] async fn aggregate_grouping_rollup() -> Result<()> { let plan = generate_plan_from_sql( diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 205962031b1d0..57f441d8cec45 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -267,7 +267,7 @@ first_value(expression [ORDER BY expression]) ### `grouping` -Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set. +Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn). ```sql grouping(expression) @@ -283,13 +283,13 @@ grouping(expression) > SELECT column_name, GROUPING(column_name) AS group_column FROM table_name GROUP BY GROUPING SETS ((column_name), ()); -+-------------+-------------+ ++-------------+--------------+ | column_name | group_column | -+-------------+-------------+ -| value1 | 0 | -| value2 | 0 | -| NULL | 1 | -+-------------+-------------+ ++-------------+--------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+--------------+ ``` ### `last_value`