diff --git a/datafusion/core/src/logical_plan/expr.rs b/datafusion/core/src/logical_plan/expr.rs index 3ffc1894e5549..6d90c78f171af 100644 --- a/datafusion/core/src/logical_plan/expr.rs +++ b/datafusion/core/src/logical_plan/expr.rs @@ -26,6 +26,7 @@ use crate::sql::utils::find_columns_referenced_by_expr; use arrow::datatypes::DataType; pub use datafusion_common::{Column, ExprSchema}; pub use datafusion_expr::expr_fn::*; +use datafusion_expr::logical_plan::Aggregate; use datafusion_expr::BuiltinScalarFunction; pub use datafusion_expr::Expr; use datafusion_expr::StateTypeFunction; @@ -136,35 +137,63 @@ pub fn create_udaf( ) } +/// Find all columns referenced by the aggregate and group-by expressions of an aggregate plan +fn agg_cols(agg: &Aggregate) -> Result> { + Ok(agg + .aggr_expr + .iter() + .chain(&agg.group_expr) + .flat_map(find_columns_referenced_by_expr) + .collect()) +} + +fn exprlist_to_fields_aggregate( + exprs: &[Expr], + plan: &LogicalPlan, + agg: &Aggregate, +) -> Result> { + let agg_cols = agg_cols(agg)?; + let mut fields = vec![]; + for expr in exprs { + match expr { + Expr::Column(c) if agg_cols.iter().any(|x| x == c) => { + // resolve against schema of input to aggregate + fields.push(expr.to_field(agg.input.schema())?); + } + _ => fields.push(expr.to_field(plan.schema())?), + } + } + Ok(fields) +} + /// Create field meta-data from an expression, for use in a result set schema pub fn exprlist_to_fields<'a>( expr: impl IntoIterator, plan: &LogicalPlan, ) -> Result> { - match plan { + let exprs: Vec = expr.into_iter().cloned().collect(); + // When dealing with aggregate plans we cannot simply look in the aggregate output schema, + // because it will contain columns representing complex expressions (such as a column named + // `#GROUPING(person.state)`), so in order to resolve `person.state` in this case we need to + // look at the input to the aggregate instead. 
+ let fields = match plan { LogicalPlan::Aggregate(agg) => { - let group_expr: Vec = agg - .group_expr - .iter() - .flat_map(find_columns_referenced_by_expr) - .collect(); - let exprs: Vec = expr.into_iter().cloned().collect(); - let mut fields = vec![]; - for expr in &exprs { - match expr { - Expr::Column(c) if group_expr.iter().any(|x| x == c) => { - // resolve against schema of input to aggregate - fields.push(expr.to_field(agg.input.schema())?); - } - _ => fields.push(expr.to_field(plan.schema())?), - } - } - Ok(fields) - } - _ => { - let input_schema = &plan.schema(); - expr.into_iter().map(|e| e.to_field(input_schema)).collect() + Some(exprlist_to_fields_aggregate(&exprs, plan, agg)) } + LogicalPlan::Window(window) => match window.input.as_ref() { + LogicalPlan::Aggregate(agg) => { + Some(exprlist_to_fields_aggregate(&exprs, plan, agg)) + } + _ => None, + }, + _ => None, + }; + if let Some(fields) = fields { + fields + } else { + // look for exact match in plan's output schema + let input_schema = &plan.schema(); + exprs.iter().map(|e| e.to_field(input_schema)).collect() } } diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index af8329018f672..fa15e83794a77 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -4685,6 +4685,38 @@ mod tests { quick_test(sql, expected); } + #[tokio::test] + async fn aggregate_with_rollup_with_grouping() { + let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), COUNT(*) \ + FROM person GROUP BY id, ROLLUP (state, age)"; + let expected = "Projection: #person.id, #person.state, #person.age, #GROUPING(person.state), #GROUPING(person.age), #GROUPING(person.state) + #GROUPING(person.age), #COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[#person.id, ROLLUP (#person.state, #person.age)]], aggr=[[GROUPING(#person.state), GROUPING(#person.age), COUNT(UInt8(1))]]\ + \n TableScan: person projection=None"; + quick_test(sql, 
expected); + } + + #[tokio::test] + async fn rank_partition_grouping() { + let sql = "select + sum(age) as total_sum, + state, + last_name, + grouping(state) + grouping(last_name) as x, + rank() over ( + partition by grouping(state) + grouping(last_name), + case when grouping(last_name) = 0 then state end + order by sum(age) desc + ) as the_rank + from + person + group by rollup(state, last_name)"; + let expected = "Projection: #SUM(person.age) AS total_sum, #person.state, #person.last_name, #GROUPING(person.state) + #GROUPING(person.last_name) AS x, #RANK() PARTITION BY [#GROUPING(person.state) + #GROUPING(person.last_name), CASE WHEN #GROUPING(person.last_name) = Int64(0) THEN #person.state END] ORDER BY [#SUM(person.age) DESC NULLS FIRST] AS the_rank\ + \n WindowAggr: windowExpr=[[RANK() PARTITION BY [#GROUPING(person.state) + #GROUPING(person.last_name), CASE WHEN #GROUPING(person.last_name) = Int64(0) THEN #person.state END] ORDER BY [#SUM(person.age) DESC NULLS FIRST]]]\ + \n Aggregate: groupBy=[[ROLLUP (#person.state, #person.last_name)]], aggr=[[SUM(#person.age), GROUPING(#person.state), GROUPING(#person.last_name)]]\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + #[tokio::test] async fn aggregate_with_cube() { let sql = diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 14cd4661580af..eacb3f74a8644 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -86,6 +86,8 @@ pub enum AggregateFunction { ApproxPercentileContWithWeight, /// ApproxMedian ApproxMedian, + /// Grouping + Grouping, } impl fmt::Display for AggregateFunction { @@ -121,6 +123,7 @@ impl FromStr for AggregateFunction { AggregateFunction::ApproxPercentileContWithWeight } "approx_median" => AggregateFunction::ApproxMedian, + "grouping" => AggregateFunction::Grouping, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", @@ 
-173,6 +176,7 @@ pub fn return_type( Ok(coerced_data_types[0].clone()) } AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()), + AggregateFunction::Grouping => Ok(DataType::Int32), } } @@ -326,6 +330,7 @@ pub fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), } } @@ -335,6 +340,7 @@ pub fn signature(fun: &AggregateFunction) -> Signature { match fun { AggregateFunction::Count | AggregateFunction::ApproxDistinct + | AggregateFunction::Grouping | AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable), AggregateFunction::Min | AggregateFunction::Max => { let valid = STRINGS diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 784cac81b1fbd..6d5dfc75633cd 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -82,6 +82,11 @@ pub fn create_aggregate_expr( name, return_type, )), + (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )), (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new( coerced_phy_exprs[0].clone(), name, diff --git a/datafusion/physical-expr/src/aggregate/grouping.rs b/datafusion/physical-expr/src/aggregate/grouping.rs new file mode 100644 index 0000000000000..4c704b2138f45 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/grouping.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can be evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::{AggregateExpr, PhysicalExpr}; +use arrow::datatypes::DataType; +use arrow::datatypes::Field; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::Accumulator; + +use crate::expressions::format_state_name; + +/// GROUPING aggregate expression +/// Physical execution is not yet implemented: `create_accumulator` returns a `NotImplemented` error. +#[derive(Debug)] +pub struct Grouping { + name: String, + data_type: DataType, + nullable: bool, + expr: Arc, +} + +impl Grouping { + /// Create a new GROUPING aggregate function. 
+ pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + Self { + name: name.into(), + expr, + data_type, + nullable: true, + } + } +} + +impl AggregateExpr for Grouping { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new( + &self.name, + self.data_type.clone(), + self.nullable, + )) + } + + fn state_fields(&self) -> Result> { + Ok(vec![Field::new( + &format_state_name(&self.name, "grouping"), + self.data_type.clone(), + true, + )]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn create_accumulator(&self) -> Result> { + Err(DataFusionError::NotImplemented( + "physical plan is not yet implemented for GROUPING aggregate function" + .to_owned(), + )) + } + + fn name(&self) -> &str { + &self.name + } +} diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 0db35d109c2dd..1cbd4aeea008f 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -36,6 +36,7 @@ pub(crate) mod correlation; pub(crate) mod count; pub(crate) mod count_distinct; pub(crate) mod covariance; +pub(crate) mod grouping; #[macro_use] pub(crate) mod min_max; pub mod build_in; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index d081720b856d4..22e80e20ec9fc 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -50,6 +50,7 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; +pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, Min}; pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; pub use 
crate::aggregate::stats::StatsType; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index df0d51e36cc9a..a4b1863615ca0 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -211,6 +211,7 @@ enum AggregateFunction { APPROX_PERCENTILE_CONT = 14; APPROX_MEDIAN=15; APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; + GROUPING = 17; } message AggregateExprNode { diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index e2f2d62166268..1ecb04bda9bed 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -496,6 +496,7 @@ impl From for AggregateFunction { Self::ApproxPercentileContWithWeight } protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian, + protobuf::AggregateFunction::Grouping => Self::Grouping, } } } diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index fa2251cd0f03d..5970a3c30a5c2 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -356,6 +356,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { Self::ApproxPercentileContWithWeight } AggregateFunction::ApproxMedian => Self::ApproxMedian, + AggregateFunction::Grouping => Self::Grouping, } } } @@ -541,6 +542,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { AggregateFunction::ApproxMedian => { protobuf::AggregateFunction::ApproxMedian } + AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, }; let aggregate_expr = protobuf::AggregateExprNode {