diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs
index 184b9cd1a31c..a1b10ff5b828 100644
--- a/datafusion/core/src/datasource/listing/helpers.rs
+++ b/datafusion/core/src/datasource/listing/helpers.rs
@@ -69,6 +69,8 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
         Expr::Literal(_)
         | Expr::Alias(_, _)
         | Expr::OuterReferenceColumn(_, _)
+        | Expr::HiddenColumn(_, _)
+        | Expr::HiddenExpr(_, _)
         | Expr::ScalarVariable(_, _)
        | Expr::Not(_)
        | Expr::IsNotNull(_)
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index c41cc438c898..6b132dc0d79c 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -86,7 +86,11 @@ pub enum AggregateMode {
 #[derive(Clone, Debug, Default)]
 pub struct PhysicalGroupBy {
     /// Distinct (Physical Expr, Alias) in the grouping set
-    expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+    grouping_set_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+    /// Hidden grouping set expr in the grouping set
+    hidden_grouping_set_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+    /// Distinct result expr for the grouping set, used to generate the output schema
+    result_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
     /// Corresponding NULL expressions for expr
     null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
     /// Null mask for each group in this grouping set. Each group is
@@ -99,12 +103,16 @@ pub struct PhysicalGroupBy {
 impl PhysicalGroupBy {
     /// Create a new `PhysicalGroupBy`
     pub fn new(
-        expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+        grouping_set_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+        hidden_grouping_set_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+        result_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
         null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
         groups: Vec<Vec<bool>>,
     ) -> Self {
         Self {
-            expr,
+            grouping_set_expr,
+            hidden_grouping_set_expr,
+            result_expr,
             null_expr,
             groups,
         }
@@ -115,7 +123,9 @@ impl PhysicalGroupBy {
     pub fn new_single(expr: Vec<(Arc<dyn PhysicalExpr>, String)>) -> Self {
         let num_exprs = expr.len();
         Self {
-            expr,
+            grouping_set_expr: expr.clone(),
+            hidden_grouping_set_expr: vec![],
+            result_expr: expr,
             null_expr: vec![],
             groups: vec![vec![false; num_exprs]],
         }
@@ -128,7 +138,12 @@ impl PhysicalGroupBy {

     /// Returns the group expressions
     pub fn expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
-        &self.expr
+        &self.grouping_set_expr
+    }
+
+    /// Returns the group result expressions
+    pub fn result_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
+        &self.result_expr
     }

     /// Returns the null expressions
@@ -136,6 +151,11 @@ impl PhysicalGroupBy {
         &self.null_expr
     }

+    /// Returns the hidden grouping set expressions
+    pub fn hidden_grouping_set_expr(&self) -> &[(Arc<dyn PhysicalExpr>, String)] {
+        &self.hidden_grouping_set_expr
+    }
+
     /// Returns the group null masks
     pub fn groups(&self) -> &[Vec<bool>] {
         &self.groups
@@ -143,7 +163,7 @@ impl PhysicalGroupBy {

     /// Returns true if this `PhysicalGroupBy` has no group expressions
     pub fn is_empty(&self) -> bool {
-        self.expr.is_empty()
+        self.grouping_set_expr.is_empty()
     }
 }
@@ -196,7 +216,7 @@ impl AggregateExec {
     ) -> Result<Self> {
         let schema = create_schema(
             &input.schema(),
-            &group_by.expr,
+            group_by.result_expr(),
             &aggr_expr,
             group_by.contains_null(),
             mode,
@@ -205,7 +225,7 @@ impl AggregateExec {
         let schema = Arc::new(schema);
         let mut alias_map: HashMap<Column, Vec<usize>> = HashMap::new();
-        for (expression, name) in group_by.expr.iter() {
+        for (expression, name) in group_by.result_expr().iter() {
             if let Some(column) = expression.as_any().downcast_ref::<Column>() {
                 let new_col_idx = schema.index_of(name)?;
                 // When the column name is the same but the index is not, treat it as an alias
@@ -243,7 +263,7 @@ impl AggregateExec {
         // Update column indices. Since the group by columns come first in the output schema, their
         // indices are simply 0..self.group_expr(len).
         self.group_by
-            .expr()
+            .result_expr()
             .iter()
             .enumerate()
             .map(|(index, (_col, name))| {
@@ -275,7 +295,7 @@ impl AggregateExec {
         let batch_size = context.session_config().batch_size();
         let input = self.input.execute(partition, Arc::clone(&context))?;
         let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
-        if self.group_by.expr.is_empty() {
+        if self.group_by.result_expr().is_empty() {
             Ok(StreamType::AggregateStream(AggregateStream::new(
                 self.mode,
                 self.schema.clone(),
@@ -418,7 +438,7 @@ impl ExecutionPlan for AggregateExec {
                 write!(f, "AggregateExec: mode={:?}", self.mode)?;
                 let g: Vec<String> = if self.group_by.groups.len() == 1 {
                     self.group_by
-                        .expr
+                        .grouping_set_expr
                         .iter()
                         .map(|(e, alias)| {
                             let e = e.to_string();
@@ -447,7 +467,8 @@ impl ExecutionPlan for AggregateExec {
                                         e
                                     }
                                 } else {
-                                    let (e, alias) = &self.group_by.expr[idx];
+                                    let (e, alias) =
+                                        &self.group_by.grouping_set_expr[idx];
                                     let e = e.to_string();
                                     if &e != alias {
                                         format!("{e} as {alias}")
@@ -484,7 +505,7 @@ impl ExecutionPlan for AggregateExec {
         // - aggregations sometimes also preserve invariants such as min, max...
         match self.mode {
             AggregateMode::Final | AggregateMode::FinalPartitioned
-                if self.group_by.expr.is_empty() =>
+                if self.group_by.result_expr().is_empty() =>
             {
                 Statistics {
                     num_rows: Some(1),
@@ -671,8 +692,8 @@ fn evaluate_group_by(
     group_by: &PhysicalGroupBy,
     batch: &RecordBatch,
 ) -> Result<Vec<Vec<ArrayRef>>> {
-    let exprs: Vec<ArrayRef> = group_by
-        .expr
+    let exprs_value: Vec<ArrayRef> = group_by
+        .grouping_set_expr
         .iter()
         .map(|(expr, _)| {
             let value = expr.evaluate(batch)?;
@@ -680,7 +701,7 @@ fn evaluate_group_by(
         })
         .collect::<Result<Vec<_>>>()?;

-    let null_exprs: Vec<ArrayRef> = group_by
+    let null_exprs_value: Vec<ArrayRef> = group_by
         .null_expr
         .iter()
         .map(|(expr, _)| {
@@ -689,23 +710,61 @@ fn evaluate_group_by(
         })
         .collect::<Result<Vec<_>>>()?;

-    Ok(group_by
-        .groups
-        .iter()
-        .map(|group| {
-            group
-                .iter()
-                .enumerate()
-                .map(|(idx, is_null)| {
-                    if *is_null {
-                        null_exprs[idx].clone()
-                    } else {
-                        exprs[idx].clone()
-                    }
-                })
-                .collect()
-        })
-        .collect())
+    if !group_by.hidden_grouping_set_expr().is_empty() {
+        let hidden_exprs_value: Vec<ArrayRef> = group_by
+            .hidden_grouping_set_expr
+            .iter()
+            .map(|(expr, _)| {
+                let value = expr.evaluate(batch)?;
+                Ok(value.into_array(batch.num_rows()))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        let chunk_size = hidden_exprs_value.len() / group_by.groups.len();
+        let hidden_expr_value_chunks =
+            hidden_exprs_value.chunks(chunk_size).collect::<Vec<_>>();
+
+        Ok(group_by
+            .groups
+            .iter()
+            .enumerate()
+            .map(|(group_id, group)| {
+                let mut group_data = group
+                    .iter()
+                    .enumerate()
+                    .map(|(idx, is_null)| {
+                        if *is_null {
+                            null_exprs_value[idx].clone()
+                        } else {
+                            exprs_value[idx].clone()
+                        }
+                    })
+                    .collect::<Vec<_>>();
+                for data in hidden_expr_value_chunks[group_id] {
+                    group_data.push(data.clone());
+                }
+                group_data
+            })
+            .collect())
+    } else {
+        Ok(group_by
+            .groups
+            .iter()
+            .map(|group| {
+                group
+                    .iter()
+                    .enumerate()
+                    .map(|(idx, is_null)| {
+                        if *is_null {
+                            null_exprs_value[idx].clone()
+                        } else {
+                            exprs_value[idx].clone()
+                        }
+                    })
+                    .collect::<Vec<_>>()
+            })
+            .collect())
+    }
 }

 #[cfg(test)]
@@ -775,7 +834,12 @@ mod tests {
         let input_schema = input.schema();

         let grouping_set = PhysicalGroupBy {
-            expr: vec![
+            grouping_set_expr: vec![
+                (col("a", &input_schema)?, "a".to_string()),
+                (col("b", &input_schema)?, "b".to_string()),
+            ],
+            hidden_grouping_set_expr: vec![],
+            result_expr: vec![
                 (col("a", &input_schema)?, "a".to_string()),
                 (col("b", &input_schema)?, "b".to_string()),
             ],
@@ -890,9 +954,11 @@ mod tests {
         let input_schema = input.schema();

         let grouping_set = PhysicalGroupBy {
-            expr: vec![(col("a", &input_schema)?, "a".to_string())],
+            grouping_set_expr: vec![(col("a", &input_schema)?, "a".to_string())],
+            hidden_grouping_set_expr: vec![],
             null_expr: vec![],
             groups: vec![vec![false]],
+            result_expr: vec![(col("a", &input_schema)?, "a".to_string())],
         };

         let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
@@ -929,7 +995,7 @@ mod tests {
         let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));

         let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = grouping_set
-            .expr
+            .result_expr()
             .iter()
             .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
             .collect::<Result<_>>()?;
@@ -1119,9 +1185,11 @@ mod tests {
         let groups_none = PhysicalGroupBy::default();
         let groups_some = PhysicalGroupBy {
-            expr: vec![(col("a", &input_schema)?, "a".to_string())],
+            grouping_set_expr: vec![(col("a", &input_schema)?, "a".to_string())],
+            hidden_grouping_set_expr: vec![],
             null_expr: vec![],
             groups: vec![vec![false]],
+            result_expr: vec![(col("a", &input_schema)?, "a".to_string())],
         };

         // something that allocates within the aggregator
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index 612b707cc19e..ea3c97481d8b 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -134,7 +134,7 @@ impl GroupedHashAggregateStream {
     ) -> Result<Self> {
         let timer = baseline_metrics.elapsed_compute().timer();

-        let mut start_idx = group_by.expr.len();
+        let mut start_idx = group_by.result_expr().len();
         let mut row_aggr_expr = vec![];
         let mut row_agg_indices = vec![];
         let mut row_aggregate_expressions = vec![];
@@ -175,7 +175,8 @@ impl GroupedHashAggregateStream {

         let row_aggr_schema = aggr_state_schema(&row_aggr_expr)?;

-        let group_schema = group_schema(&schema, group_by.expr.len());
+        let group_schema = group_schema(&schema, group_by.result_expr().len());
+
         let row_converter = RowConverter::new(
             group_schema
                 .fields()
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 51653450a699..0bce37947df2 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -68,11 +68,10 @@ use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessar
 use datafusion_expr::{logical_plan, StringifiedPlan};
 use datafusion_expr::{WindowFrame, WindowFrameBound};
 use datafusion_optimizer::utils::unalias;
-use datafusion_physical_expr::expressions::Literal;
+use datafusion_physical_expr::expressions::{HiddenColumn, Literal};
 use datafusion_sql::utils::window_expr_common_partition_keys;
 use futures::future::BoxFuture;
 use futures::{FutureExt, StreamExt, TryStreamExt};
-use itertools::Itertools;
 use log::{debug, trace};
 use std::collections::HashMap;
 use std::fmt::Write;
@@ -334,6 +333,8 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
                 Ok(format!("{expr} SIMILAR TO {pattern}{escape}"))
             }
         }
+        Expr::HiddenColumn(_dt, c) => Ok(format!("#{}", c)),
+        Expr::HiddenExpr(expr, _) => Ok(create_physical_name(expr, false)?),
        Expr::Sort { ..
} => Err(DataFusionError::Internal(
            "Create physical name does not support sort expression".to_string(),
        )),
@@ -700,7 +701,7 @@ impl DefaultPhysicalPlanner {
                         final_group
                             .iter()
                             .enumerate()
-                            .map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone()))
+                            .map(|(i, expr)| (expr.clone(), groups.result_expr()[i].1.clone()))
                             .collect()
                     );
@@ -1271,20 +1272,12 @@ impl DefaultPhysicalPlanner {
                     session_state,
                 )
             }
-            Expr::GroupingSet(GroupingSet::Cube(exprs)) => create_cube_physical_expr(
-                exprs,
-                input_dfschema,
-                input_schema,
-                session_state,
-            ),
-            Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
-                create_rollup_physical_expr(
-                    exprs,
-                    input_dfschema,
-                    input_schema,
-                    session_state,
-                )
-            }
+            Expr::GroupingSet(GroupingSet::Cube(_exprs)) => Err(DataFusionError::Internal(
+                "Unsupported logical plan: GroupingSet::Cube should be replaced with GroupingSet::GroupingSets".to_string(),
+            )),
+            Expr::GroupingSet(GroupingSet::Rollup(_exprs)) => Err(DataFusionError::Internal(
+                "Unsupported logical plan: GroupingSet::Rollup should be replaced with GroupingSet::GroupingSets".to_string(),
+            )),
             expr => Ok(PhysicalGroupBy::new_single(vec![tuple_err((
                 self.create_physical_expr(
                     expr,
@@ -1333,145 +1326,81 @@ fn merge_grouping_set_physical_expr(
     session_state: &SessionState,
 ) -> Result<PhysicalGroupBy> {
     let num_groups = grouping_sets.len();
-    let mut all_exprs: Vec<Expr> = vec![];
-    let mut grouping_set_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
-    let mut null_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+    let mut all_normal_exprs: Vec<Expr> = vec![];
+    let mut all_hidden_result_exprs: Vec<Expr> = vec![];
+
+    let mut grouping_set_phy_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+    let mut hidden_grouping_set_phy_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+    let mut null_phy_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+
+    let mut hidden_grouping_set_result_phy_expr: Vec<(Arc<dyn PhysicalExpr>, String)> =
+        vec![];
+    let mut grouping_set_result_phy_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];

     for expr in grouping_sets.iter().flatten() {
-        if !all_exprs.contains(expr) {
-            all_exprs.push(expr.clone());
+        if let Expr::HiddenExpr(first, second) = expr {
+            if let Expr::HiddenColumn(dt, _) = second.as_ref() {
+                hidden_grouping_set_phy_expr.push(get_physical_expr_pair(
+                    first,
+                    input_dfschema,
+                    input_schema,
+                    session_state,
+                )?);
+
+                if !all_hidden_result_exprs.contains(second) {
+                    all_hidden_result_exprs.push(*second.clone());
+                    let hidden_column_name = second.display_name()?;
+                    // The second element in the hidden expr should be converted to a physical HiddenColumn
+                    hidden_grouping_set_result_phy_expr.push((
+                        Arc::new(HiddenColumn::new(&hidden_column_name, dt)),
+                        hidden_column_name,
+                    ));
+                }
+            } else {
+                return Err(DataFusionError::Internal(
+                    "The second part of an Expr::HiddenExpr should be an Expr::HiddenColumn"
+                        .to_string(),
+                ));
+            }
+        } else if !all_normal_exprs.contains(expr) {
+            all_normal_exprs.push(expr.clone());

-            grouping_set_expr.push(get_physical_expr_pair(
+            let phy_expr = get_physical_expr_pair(
                 expr,
                 input_dfschema,
                 input_schema,
                 session_state,
-            )?);
-
-            null_exprs.push(get_null_physical_expr_pair(
+            )?;
+            grouping_set_phy_expr.push(phy_expr.clone());
+            null_phy_exprs.push(get_null_physical_expr_pair(
                 expr,
                 input_dfschema,
                 input_schema,
                 session_state,
             )?);
+            grouping_set_result_phy_expr.push(phy_expr);
         }
     }
+    grouping_set_result_phy_expr.append(&mut hidden_grouping_set_result_phy_expr);

     let mut merged_sets: Vec<Vec<bool>> = Vec::with_capacity(num_groups);

     for expr_group in grouping_sets.iter() {
-        let group: Vec<bool> = all_exprs
+        let group: Vec<bool> = all_normal_exprs
             .iter()
             .map(|expr|
!expr_group.contains(expr)) .collect(); - merged_sets.push(group) } Ok(PhysicalGroupBy::new( - grouping_set_expr, - null_exprs, + grouping_set_phy_expr, + hidden_grouping_set_phy_expr, + grouping_set_result_phy_expr, + null_phy_exprs, merged_sets, )) } -/// Expand and align a CUBE expression. This is a special case of GROUPING SETS -/// (see https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS) -fn create_cube_physical_expr( - exprs: &[Expr], - input_dfschema: &DFSchema, - input_schema: &Schema, - session_state: &SessionState, -) -> Result { - let num_of_exprs = exprs.len(); - let num_groups = num_of_exprs * num_of_exprs; - - let mut null_exprs: Vec<(Arc, String)> = - Vec::with_capacity(num_of_exprs); - let mut all_exprs: Vec<(Arc, String)> = - Vec::with_capacity(num_of_exprs); - - for expr in exprs { - null_exprs.push(get_null_physical_expr_pair( - expr, - input_dfschema, - input_schema, - session_state, - )?); - - all_exprs.push(get_physical_expr_pair( - expr, - input_dfschema, - input_schema, - session_state, - )?) - } - - let mut groups: Vec> = Vec::with_capacity(num_groups); - - groups.push(vec![false; num_of_exprs]); - - for null_count in 1..=num_of_exprs { - for null_idx in (0..num_of_exprs).combinations(null_count) { - let mut next_group: Vec = vec![false; num_of_exprs]; - null_idx.into_iter().for_each(|i| next_group[i] = true); - groups.push(next_group); - } - } - - Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups)) -} - -/// Expand and align a ROLLUP expression. This is a special case of GROUPING SETS -/// (see https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS) -fn create_rollup_physical_expr( - exprs: &[Expr], - input_dfschema: &DFSchema, - input_schema: &Schema, - session_state: &SessionState, -) -> Result { - let num_of_exprs = exprs.len(); - - let mut null_exprs: Vec<(Arc, String)> = - Vec::with_capacity(num_of_exprs); - let mut all_exprs: Vec<(Arc, String)> = - Vec::with_capacity(num_of_exprs); - - let mut groups: Vec> = Vec::with_capacity(num_of_exprs + 1); - - for expr in exprs { - null_exprs.push(get_null_physical_expr_pair( - expr, - input_dfschema, - input_schema, - session_state, - )?); - - all_exprs.push(get_physical_expr_pair( - expr, - input_dfschema, - input_schema, - session_state, - )?) 
- } - - for total in 0..=num_of_exprs { - let mut group: Vec = Vec::with_capacity(num_of_exprs); - - for index in 0..num_of_exprs { - if index < total { - group.push(false); - } else { - group.push(true); - } - } - - groups.push(group) - } - - Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups)) -} - /// For a given logical expr, get a properly typed NULL ScalarValue physical expression fn get_null_physical_expr_pair( expr: &Expr, @@ -1929,60 +1858,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_create_cube_expr() -> Result<()> { - let logical_plan = test_csv_scan().await?.build()?; - - let plan = plan(&logical_plan).await?; - - let exprs = vec![col("c1"), col("c2"), col("c3")]; - - let physical_input_schema = plan.schema(); - let physical_input_schema = physical_input_schema.as_ref(); - let logical_input_schema = logical_plan.schema(); - let session_state = make_session_state(); - - let cube = create_cube_physical_expr( - &exprs, - logical_input_schema, - physical_input_schema, - &session_state, - ); - - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; - - assert_eq!(format!("{cube:?}"), expected); - - Ok(()) - } - - #[tokio::test] - async fn test_create_rollup_expr() -> Result<()> { - let logical_plan = test_csv_scan().await?.build()?; - - let plan = plan(&logical_plan).await?; - - let exprs = vec![col("c1"), col("c2"), col("c3")]; - - let physical_input_schema = plan.schema(); - let physical_input_schema = physical_input_schema.as_ref(); - let logical_input_schema = logical_plan.schema(); - let session_state = make_session_state(); - - let rollup = create_rollup_physical_expr( - &exprs, - logical_input_schema, - physical_input_schema, - &session_state, - ); - - let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; - - assert_eq!(format!("{rollup:?}"), expected); - - Ok(()) - } - #[tokio::test] async fn test_create_not() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); diff --git a/datafusion/core/tests/sql/group_by.rs b/datafusion/core/tests/sql/group_by.rs index a92eaf0f4d31..dcb6002cac6d 100644 --- a/datafusion/core/tests/sql/group_by.rs +++ b/datafusion/core/tests/sql/group_by.rs @@ -110,6 +110,28 @@ async fn csv_query_group_by_boolean() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_group_by_boolean2() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; + + let sql = + "SELECT COUNT(*), c3 FROM aggregate_simple GROUP BY c3 ORDER BY COUNT(*) DESC"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+-----------------+-------+", + "| COUNT(UInt8(1)) | c3 |", + "+-----------------+-------+", + "| 9 | true |", + "| 6 | false |", + "+-----------------+-------+", + ]; + 
assert_batches_eq!(expected, &actual); + + Ok(()) +} + #[tokio::test] async fn csv_query_group_by_two_columns() -> Result<()> { let ctx = SessionContext::new(); @@ -905,3 +927,541 @@ async fn csv_query_group_by_order_by_avg_group_by_substr() -> Result<()> { assert_batches_sorted_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn group_by_with_dup_group_set() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY GROUPING SETS((c1),(c1),())"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: aggregate_test_100.c1, AVG(aggregate_test_100.c12) [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1, UInt32(1) AS #grouping_set_id), (aggregate_test_100.c1, UInt32(2) AS #grouping_set_id), (UInt32(3) AS #grouping_set_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, #grouping_set_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " TableScan: aggregate_test_100 projection=[c1, c12] [c1:Utf8, c12:Float64]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+", + "| c1 | AVG(aggregate_test_100.c12) |", + "+----+-----------------------------+", + "| | 0.5089725099127211 |", + "| a | 0.48754517466109415 |", + "| b | 0.41040709263815384 |", + "| c | 0.6600456536439784 |", + "| d | 0.48855379387549824 |", + "| e | 0.48600669271341534 |", + "| a | 0.48754517466109415 |", + "| b | 0.41040709263815384 |", + "| c | 0.6600456536439784 |", + "| d | 0.48855379387549824 |", + "| e | 0.48600669271341534 |", + "+----+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn group_by_with_grouping_id_func() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, c2, c3, GROUPING_ID(c1, c2, c3), avg(c12) FROM \ + (select c1, c2, '0' as c3, c12 from aggregate_test_100) + GROUP BY GROUPING SETS((c1, c2), (c1, c3), (c3),())"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: aggregate_test_100.c1, aggregate_test_100.c2, c3, #grouping_id AS GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2,c3), AVG(aggregate_test_100.c12) [c1:Utf8, c2:UInt32, c3:Utf8, GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2,c3):UInt32;N, AVG(aggregate_test_100.c12):Float64;N]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1, aggregate_test_100.c2, UInt32(1) AS #grouping_id), (aggregate_test_100.c1, c3, UInt32(2) AS #grouping_id), (c3, UInt32(6) AS #grouping_id), (UInt32(7) AS #grouping_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, c2:UInt32, c3:Utf8, #grouping_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " Projection: aggregate_test_100.c1, aggregate_test_100.c2, Utf8(\"0\") AS c3, aggregate_test_100.c12 [c1:Utf8, c2:UInt32, c3:Utf8, c12:Float64]", + " TableScan: aggregate_test_100 projection=[c1, c2, c12] 
[c1:Utf8, c2:UInt32, c12:Float64]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+----+-------------------------------------------------------------+-----------------------------+", + "| c1 | c2 | c3 | GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2,c3) | AVG(aggregate_test_100.c12) |", + "+----+----+----+-------------------------------------------------------------+-----------------------------+", + "| | | | 7 | 0.5089725099127211 |", + "| | | 0 | 6 | 0.5089725099127211 |", + "| a | | 0 | 2 | 0.48754517466109415 |", + "| a | 1 | | 1 | 0.4693685626367209 |", + "| a | 2 | | 1 | 0.5945188963859894 |", + "| a | 3 | | 1 | 0.5996111195922015 |", + "| a | 4 | | 1 | 0.3653038379118398 |", + "| a | 5 | | 1 | 0.3497223654469457 |", + "| b | | 0 | 2 | 0.41040709263815384 |", + "| b | 1 | | 1 | 0.16148594845154118 |", + "| b | 2 | | 1 | 0.5857678873564655 |", + "| b | 3 | | 1 | 0.42804338065410286 |", + "| b | 4 | | 1 | 0.33400957036260354 |", + "| b | 5 | | 1 | 0.4888141504446429 |", + "| c | | 0 | 2 | 0.6600456536439784 |", + "| c | 1 | | 1 | 0.6430620563927849 |", + "| c | 2 | | 1 | 0.7736013221256991 |", + "| c | 3 | | 1 | 0.421733279717472 |", + "| c | 4 | | 1 | 0.6827805579021969 |", + "| c | 5 | | 1 | 0.7277229477969185 |", + "| d | | 0 | 2 | 0.48855379387549824 |", + "| d | 1 | | 1 | 0.49931809179640024 |", + "| d | 2 | | 1 | 0.5181987328311988 |", + "| d | 3 | | 1 | 0.586369575965718 |", + "| d | 4 | | 1 | 0.49575895804943215 |", + "| d | 5 | | 1 | 0.2488799233225611 |", + "| e | | 0 | 2 | 0.48600669271341534 |", + "| e | 1 | | 1 | 0.780297346359783 |", + "| e | 2 | | 1 | 0.660795726704708 |", + "| e | 3 | | 1 | 0.5165824734324667 |", + "| e | 4 | | 1 | 0.2720288398836001 |", + "| e | 5 | | 1 | 0.29536905073188496 |", + "+----+----+----+-------------------------------------------------------------+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn group_by_with_multi_grouping_funcs() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, c2, GROUPING(C1), GROUPING(C2), GROUPING_ID(c1, c2), avg(c12) FROM aggregate_test_100 GROUP BY CUBE(c1, c2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: aggregate_test_100.c1, aggregate_test_100.c2, CAST((#grouping_id AS #grouping_id AS #grouping_id AS #grouping_id >> UInt32(1)) & UInt32(1) AS UInt8) AS GROUPING(aggregate_test_100.c1), CAST(#grouping_id AS #grouping_id AS #grouping_id AS #grouping_id & UInt32(1) AS UInt8) AS GROUPING(aggregate_test_100.c2), #grouping_id AS #grouping_id AS #grouping_id AS #grouping_id AS GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2), AVG(aggregate_test_100.c12) [c1:Utf8, c2:UInt32, GROUPING(aggregate_test_100.c1):UInt8;N, GROUPING(aggregate_test_100.c2):UInt8;N, GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2):UInt32;N, AVG(aggregate_test_100.c12):Float64;N]", + " Aggregate: groupBy=[[GROUPING SETS ((UInt32(3) AS #grouping_id), (aggregate_test_100.c1, UInt32(1) AS #grouping_id), (aggregate_test_100.c2, UInt32(2) AS #grouping_id), 
(aggregate_test_100.c1, aggregate_test_100.c2, UInt32(0) AS #grouping_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, c2:UInt32, #grouping_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " TableScan: aggregate_test_100 projection=[c1, c2, c12] [c1:Utf8, c2:UInt32, c12:Float64]", + ]; + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+---------------------------------+---------------------------------+----------------------------------------------------------+-----------------------------+", + "| c1 | c2 | GROUPING(aggregate_test_100.c1) | GROUPING(aggregate_test_100.c2) | GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2) | AVG(aggregate_test_100.c12) |", + "+----+----+---------------------------------+---------------------------------+----------------------------------------------------------+-----------------------------+", + "| | | 1 | 1 | 3 | 0.5089725099127211 |", + "| | 1 | 1 | 0 | 2 | 0.5108939802619781 |", + "| | 2 | 1 | 0 | 2 | 0.6545641966127662 |", + "| | 3 | 1 | 0 | 2 | 0.5245329062820169 |", + "| | 4 | 1 | 0 | 2 | 0.40234192123489837 |", + "| | 5 | 1 | 0 | 2 | 0.4312272637333415 |", + "| a | | 0 | 1 | 1 | 0.48754517466109415 |", + "| a | 1 | 0 | 0 | 0 | 0.4693685626367209 |", + "| a | 2 | 0 | 0 | 0 | 0.5945188963859894 |", + "| a | 3 | 0 | 0 | 0 | 0.5996111195922015 |", + "| a | 4 | 0 | 0 | 0 | 0.3653038379118398 |", + "| a | 5 | 0 | 0 | 0 | 0.3497223654469457 |", + "| b | | 0 | 1 | 1 | 0.41040709263815384 |", + "| b | 1 | 0 | 0 | 0 | 0.16148594845154118 |", + "| b | 2 | 0 | 0 | 0 | 0.5857678873564655 |", + "| b | 3 | 0 | 0 | 0 | 0.42804338065410286 |", + "| b | 4 | 0 | 0 | 0 | 0.33400957036260354 |", + "| b | 5 | 0 | 0 | 0 | 0.4888141504446429 |", + "| c | | 0 | 1 | 1 | 0.6600456536439784 |", + "| c | 1 | 0 | 0 | 0 | 0.6430620563927849 |", + "| c | 2 | 0 | 0 | 0 | 0.7736013221256991 |", + "| c | 3 | 0 | 0 | 0 | 0.421733279717472 |", + "| c | 4 | 0 | 0 | 0 | 0.6827805579021969 |", + "| c | 5 | 0 | 0 | 0 | 0.7277229477969185 |", + "| d | | 0 | 1 | 1 | 0.48855379387549824 |", + "| d | 1 | 0 | 0 | 0 | 0.49931809179640024 |", + "| d | 2 | 0 | 0 | 0 | 0.5181987328311988 |", + "| d | 3 | 0 | 0 | 0 | 0.586369575965718 |", + "| d | 4 | 0 | 0 | 0 | 0.49575895804943215 |", + "| d | 5 | 0 | 0 | 0 | 0.2488799233225611 |", + "| e | | 0 | 1 | 1 | 0.48600669271341534 |", + "| e | 1 | 0 | 0 | 0 | 0.780297346359783 |", + "| e | 2 | 0 | 0 | 0 | 0.660795726704708 |", + "| e | 3 | 0 | 0 | 0 | 0.5165824734324667 |", + "| e | 4 | 0 | 0 | 0 | 0.2720288398836001 |", + "| e | 5 | 0 | 0 | 0 | 0.29536905073188496 |", + "+----+----+---------------------------------+---------------------------------+----------------------------------------------------------+-----------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn group_by_with_dup_group_set_and_grouping_func() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, avg(c12), GROUPING(C1) FROM aggregate_test_100 GROUP BY GROUPING SETS((c1),(c1),())"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: 
aggregate_test_100.c1, AVG(aggregate_test_100.c12), CAST(#grouping_id & UInt32(1) AS UInt8) AS GROUPING(aggregate_test_100.c1) [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, GROUPING(aggregate_test_100.c1):UInt8;N]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1, UInt32(0) AS #grouping_id, UInt32(1) AS #grouping_set_id), (aggregate_test_100.c1, UInt32(0) AS #grouping_id, UInt32(2) AS #grouping_set_id), (UInt32(1) AS #grouping_id, UInt32(3) AS #grouping_set_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, #grouping_id:UInt32, #grouping_set_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " TableScan: aggregate_test_100 projection=[c1, c12] [c1:Utf8, c12:Float64]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+---------------------------------+", + "| c1 | AVG(aggregate_test_100.c12) | GROUPING(aggregate_test_100.c1) |", + "+----+-----------------------------+---------------------------------+", + "| | 0.5089725099127211 | 1 |", + "| a | 0.48754517466109415 | 0 |", + "| a | 0.48754517466109415 | 0 |", + "| b | 0.41040709263815384 | 0 |", + "| b | 0.41040709263815384 | 0 |", + "| c | 0.6600456536439784 | 0 |", + "| c | 0.6600456536439784 | 0 |", + "| d | 0.48855379387549824 | 0 |", + "| d | 0.48855379387549824 | 0 |", + "| e | 0.48600669271341534 | 0 |", + "| e | 0.48600669271341534 | 0 |", + "+----+-----------------------------+---------------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn group_by_with_grouping_func_and_having() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, avg(c12), GROUPING(C1) FROM aggregate_test_100 \ + GROUP BY GROUPING SETS((c1),(c1),()) HAVING GROUPING(C1) = 1 and avg(c12) > 0"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Projection: aggregate_test_100.c1, AVG(aggregate_test_100.c12), CAST(#grouping_id & UInt32(1) AS UInt8) AS GROUPING(aggregate_test_100.c1) [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, GROUPING(aggregate_test_100.c1):UInt8;N]", + " Filter: (#grouping_id & UInt32(1)) = UInt32(1) AND AVG(aggregate_test_100.c12) > Float64(0) [c1:Utf8, #grouping_id:UInt32, #grouping_set_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1, UInt32(0) AS #grouping_id, UInt32(1) AS #grouping_set_id), (aggregate_test_100.c1, UInt32(0) AS #grouping_id, UInt32(2) AS #grouping_set_id), (UInt32(1) AS #grouping_id, UInt32(3) AS #grouping_set_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, #grouping_id:UInt32, #grouping_set_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " TableScan: aggregate_test_100 projection=[c1, c12] [c1:Utf8, c12:Float64]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + 
"+----+-----------------------------+---------------------------------+", + "| c1 | AVG(aggregate_test_100.c12) | GROUPING(aggregate_test_100.c1) |", + "+----+-----------------------------+---------------------------------+", + "| | 0.5089725099127211 | 1 |", + "+----+-----------------------------+---------------------------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn group_by_with_grouping_func_as_expr() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, avg(c12), GROUPING(C1) + GROUPING(C2) as grouping_lvl FROM aggregate_test_100 \ + GROUP BY GROUPING SETS((c1),(c1),(c1, c2))\ + ORDER BY CASE WHEN grouping_lvl = 0 THEN 0 ELSE 1 END"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Sort: CASE WHEN grouping_lvl = UInt8(0) THEN Int64(0) ELSE Int64(1) END AS CASE WHEN grouping_lvl = Int64(0) THEN Int64(0) ELSE Int64(1) END ASC NULLS LAST [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, grouping_lvl:UInt8;N]", + " Projection: aggregate_test_100.c1, AVG(aggregate_test_100.c12), CAST((#grouping_id >> UInt32(1)) & UInt32(1) AS UInt8) + CAST(#grouping_id & UInt32(1) AS UInt8) AS grouping_lvl [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, grouping_lvl:UInt8;N]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1, UInt32(1) AS #grouping_id, UInt32(1) AS #grouping_set_id), (aggregate_test_100.c1, UInt32(1) AS #grouping_id, UInt32(2) AS #grouping_set_id), (aggregate_test_100.c1, aggregate_test_100.c2, UInt32(0) AS #grouping_id, UInt32(3) AS #grouping_set_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, c2:UInt32, #grouping_id:UInt32, #grouping_set_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " TableScan: aggregate_test_100 projection=[c1, c2, c12] [c1:Utf8, c2:UInt32, c12:Float64]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+--------------+", + "| c1 | AVG(aggregate_test_100.c12) | grouping_lvl |", + "+----+-----------------------------+--------------+", + "| a | 0.3497223654469457 | 0 |", + "| a | 0.3653038379118398 | 0 |", + "| a | 0.4693685626367209 | 0 |", + "| a | 0.48754517466109415 | 1 |", + "| a | 0.48754517466109415 | 1 |", + "| a | 0.5945188963859894 | 0 |", + "| a | 0.5996111195922015 | 0 |", + "| b | 0.16148594845154118 | 0 |", + "| b | 0.33400957036260354 | 0 |", + "| b | 0.41040709263815384 | 1 |", + "| b | 0.41040709263815384 | 1 |", + "| b | 0.42804338065410286 | 0 |", + "| b | 0.4888141504446429 | 0 |", + "| b | 0.5857678873564655 | 0 |", + "| c | 0.421733279717472 | 0 |", + "| c | 0.6430620563927849 | 0 |", + "| c | 0.6600456536439784 | 1 |", + "| c | 0.6600456536439784 | 1 |", + "| c | 0.6827805579021969 | 0 |", + "| c | 0.7277229477969185 | 0 |", + "| c | 0.7736013221256991 | 0 |", + "| d | 0.2488799233225611 | 0 |", + "| d | 0.48855379387549824 | 1 |", + "| d | 0.48855379387549824 | 1 |", + "| d | 0.49575895804943215 | 0 |", + "| d | 0.49931809179640024 | 0 |", + "| d | 0.5181987328311988 | 0 |", + "| d | 0.586369575965718 | 0 |", + "| e | 0.2720288398836001 | 0 |", + "| e | 
0.29536905073188496 | 0 |", + "| e | 0.48600669271341534 | 1 |", + "| e | 0.48600669271341534 | 1 |", + "| e | 0.5165824734324667 | 0 |", + "| e | 0.660795726704708 | 0 |", + "| e | 0.780297346359783 | 0 |", + "+----+-----------------------------+--------------+", + ]; + assert_batches_sorted_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn group_by_with_grouping_func_and_order_by() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, avg(c12), GROUPING_ID(C1, c2) FROM aggregate_test_100 \ + GROUP BY CUBE(c1,c2) ORDER BY GROUPING_ID(C1, c2) DESC, avg(c12) ASC"; + + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Sort: GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2) DESC NULLS FIRST, AVG(aggregate_test_100.c12) ASC NULLS LAST [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2):UInt32;N]", + " Projection: aggregate_test_100.c1, AVG(aggregate_test_100.c12), #grouping_id AS GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2) [c1:Utf8, AVG(aggregate_test_100.c12):Float64;N, GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2):UInt32;N]", + " Aggregate: groupBy=[[GROUPING SETS ((UInt32(3) AS #grouping_id), (aggregate_test_100.c1, UInt32(1) AS #grouping_id), (aggregate_test_100.c2, UInt32(2) AS #grouping_id), (aggregate_test_100.c1, aggregate_test_100.c2, UInt32(0) AS #grouping_id))]], aggr=[[AVG(aggregate_test_100.c12)]] [c1:Utf8, c2:UInt32, #grouping_id:UInt32, AVG(aggregate_test_100.c12):Float64;N]", + " TableScan: aggregate_test_100 projection=[c1, c2, c12] [c1:Utf8, c2:UInt32, c12:Float64]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+-----------------------------+----------------------------------------------------------+", + "| c1 | AVG(aggregate_test_100.c12) | GROUPING_ID(aggregate_test_100.c1,aggregate_test_100.c2) |", + "+----+-----------------------------+----------------------------------------------------------+", + "| | 0.5089725099127211 | 3 |", + "| | 0.40234192123489837 | 2 |", + "| | 0.4312272637333415 | 2 |", + "| | 0.5108939802619781 | 2 |", + "| | 0.5245329062820169 | 2 |", + "| | 0.6545641966127662 | 2 |", + "| b | 0.41040709263815384 | 1 |", + "| e | 0.48600669271341534 | 1 |", + "| a | 0.48754517466109415 | 1 |", + "| d | 0.48855379387549824 | 1 |", + "| c | 0.6600456536439784 | 1 |", + "| b | 0.16148594845154118 | 0 |", + "| d | 0.2488799233225611 | 0 |", + "| e | 0.2720288398836001 | 0 |", + "| e | 0.29536905073188496 | 0 |", + "| b | 0.33400957036260354 | 0 |", + "| a | 0.3497223654469457 | 0 |", + "| a | 0.3653038379118398 | 0 |", + "| c | 0.421733279717472 | 0 |", + "| b | 0.42804338065410286 | 0 |", + "| a | 0.4693685626367209 | 0 |", + "| b | 0.4888141504446429 | 0 |", + "| d | 0.49575895804943215 | 0 |", + "| d | 0.49931809179640024 | 0 |", + "| e | 0.5165824734324667 | 0 |", + "| d | 0.5181987328311988 | 0 |", + "| b | 0.5857678873564655 | 0 |", + "| d | 0.586369575965718 | 0 |", + "| a | 0.5945188963859894 | 0 |", + "| a | 0.5996111195922015 | 0 |", + "| c | 0.6430620563927849 | 0 |", + "| e | 0.660795726704708 
| 0 |", + "| c | 0.6827805579021969 | 0 |", + "| c | 0.7277229477969185 | 0 |", + "| c | 0.7736013221256991 | 0 |", + "| e | 0.780297346359783 | 0 |", + "+----+-----------------------------+----------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn group_by_rollup_with_count_wildcard_and_order_by() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, c2, c3, COUNT(*) \ + FROM aggregate_test_100 \ + WHERE c1 IN ('a', 'b', NULL) \ + GROUP BY c1, ROLLUP (c2, c3) \ + ORDER BY c1, c2, c3"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let plan = dataframe.clone().into_optimized_plan()?; + let expected = vec![ + "Sort: aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST [c1:Utf8, c2:UInt32, c3:Int8, COUNT(UInt8(1)):Int64;N]", + " Aggregate: groupBy=[[GROUPING SETS ((aggregate_test_100.c1), (aggregate_test_100.c1, aggregate_test_100.c2), (aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3))]], aggr=[[COUNT(UInt8(1))]] [c1:Utf8, c2:UInt32, c3:Int8, COUNT(UInt8(1)):Int64;N]", + " Filter: aggregate_test_100.c1 = Utf8(NULL) OR aggregate_test_100.c1 = Utf8(\"b\") OR aggregate_test_100.c1 = Utf8(\"a\") [c1:Utf8, c2:UInt32, c3:Int8]", + " TableScan: aggregate_test_100 projection=[c1, c2, c3], partial_filters=[aggregate_test_100.c1 = Utf8(NULL) OR aggregate_test_100.c1 = Utf8(\"b\") OR aggregate_test_100.c1 = Utf8(\"a\")] [c1:Utf8, c2:UInt32, c3:Int8]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+------+-----------------+", + "| c1 | c2 | c3 | COUNT(UInt8(1)) |", + "+----+----+------+-----------------+", + "| a | 1 | -85 | 1 |", + "| a | 1 | -56 | 1 |", + "| a | 1 | -25 | 1 |", + "| a | 1 | -5 | 1 |", + "| a | 1 | 83 | 1 |", + "| a | 1 | | 5 |", + "| a | 2 | -48 | 1 |", + "| a | 2 | -43 | 1 |", + "| a | 2 | 45 | 1 |", + "| a | 2 | | 3 |", + "| a | 3 | -72 | 1 |", + "| a | 3 | -12 | 1 |", + "| a | 3 | 13 | 2 |", + "| a | 3 | 14 | 1 |", + "| a | 3 | 17 | 1 |", + "| a | 3 | | 6 |", + "| a | 4 | -101 | 1 |", + "| a | 4 | -54 | 1 |", + "| a | 4 | -38 | 1 |", + "| a | 4 | 65 | 1 |", + "| a | 4 | | 4 |", + "| a | 5 | -101 | 1 |", + "| a | 5 | -31 | 1 |", + "| a | 5 | 36 | 1 |", + "| a | 5 | | 3 |", + "| a | | | 21 |", + "| b | 1 | 12 | 1 |", + "| b | 1 | 29 | 1 |", + "| b | 1 | 54 | 1 |", + "| b | 1 | | 3 |", + "| b | 2 | -60 | 1 |", + "| b | 2 | 31 | 1 |", + "| b | 2 | 63 | 1 |", + "| b | 2 | 68 | 1 |", + "| b | 2 | | 4 |", + "| b | 3 | -101 | 1 |", + "| b | 3 | 17 | 1 |", + "| b | 3 | | 2 |", + "| b | 4 | -117 | 1 |", + "| b | 4 | -111 | 1 |", + "| b | 4 | -59 | 1 |", + "| b | 4 | 17 | 1 |", + "| b | 4 | 47 | 1 |", + "| b | 4 | | 5 |", + "| b | 5 | -82 | 1 |", + "| b | 5 | -44 | 1 |", + "| b | 5 | -5 | 1 |", + "| b | 5 | 62 | 1 |", + "| b | 5 | 68 | 1 |", + "| b | 5 | | 5 |", + "| b | | | 19 |", + "+----+----+------+-----------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn invalid_grouping_func() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = 
"SELECT c1, avg(c12), GROUPING(c3) FROM aggregate_test_100 GROUP BY GROUPING SETS((c1),(c2),())"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let err = dataframe.into_optimized_plan().err().unwrap(); + assert_eq!( + "Plan(\"Column of GROUPING(aggregate_test_100.c3) can't be found in GROUP BY columns [aggregate_test_100.c1, aggregate_test_100.c2]\")", + &format!("{err:?}") + ); + + Ok(()) +} + +#[tokio::test] +async fn invalid_grouping_id_func() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c1, avg(c12), GROUPING_ID(c1) FROM aggregate_test_100 GROUP BY GROUPING SETS((c1),(c2),())"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let err = dataframe.into_optimized_plan().err().unwrap(); + assert_eq!( + "Plan(\"Columns of GROUPING_ID([aggregate_test_100.c1]) does not match GROUP BY columns [aggregate_test_100.c1, aggregate_test_100.c2]\")", + &format!("{err:?}") + ); + + Ok(()) +} + +#[tokio::test] +async fn invalid_grouping_id_func2() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + + // The column ordering of the GROUPING_ID() matters + let sql = "SELECT c1, avg(c12), GROUPING_ID(c2, c1) FROM aggregate_test_100 GROUP BY CUBE(c1,c2)"; + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(sql).await.expect(&msg); + let err = dataframe.into_optimized_plan().err().unwrap(); + assert_eq!( + "Plan(\"Columns of GROUPING_ID([aggregate_test_100.c2, aggregate_test_100.c1]) does not match GROUP BY columns [aggregate_test_100.c1, aggregate_test_100.c2]\")", + &format!("{err:?}") + ); + + Ok(()) +} diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index b7fb7d47d297..ad5db1c0a14d 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -63,12 +63,19 @@ pub enum AggregateFunction { ApproxMedian, /// Grouping Grouping, + /// GroupingID + GroupingId, } impl fmt::Display for AggregateFunction { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // uppercase of the debug. 
-        write!(f, "{}", format!("{self:?}").to_uppercase())
+        match self {
+            AggregateFunction::GroupingId => {
+                write!(f, "GROUPING_ID")
+            }
+            _ => write!(f, "{}", format!("{self:?}").to_uppercase()),
+        }
     }
 }
@@ -101,6 +108,7 @@ impl FromStr for AggregateFunction {
             }
             "approx_median" => AggregateFunction::ApproxMedian,
             "grouping" => AggregateFunction::Grouping,
+            "grouping_id" => AggregateFunction::GroupingId,
             _ => {
                 return Err(DataFusionError::Plan(format!(
                     "There is no built-in function named {name}"
@@ -157,7 +165,8 @@ pub fn return_type(
         AggregateFunction::ApproxMedian | AggregateFunction::Median => {
             Ok(coerced_data_types[0].clone())
         }
-        AggregateFunction::Grouping => Ok(DataType::Int32),
+        AggregateFunction::Grouping => Ok(DataType::UInt8),
+        AggregateFunction::GroupingId => Ok(DataType::UInt32),
     }
 }
@@ -220,5 +229,9 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
             .collect(),
             Volatility::Immutable,
         ),
+        AggregateFunction::GroupingId => Signature {
+            type_signature: TypeSignature::Arbitrary,
+            volatility: Volatility::Immutable,
+        },
     }
 }
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 2806683ab87b..4e36b4d67c11 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -21,7 +21,7 @@ use crate::aggregate_function;
 use crate::built_in_function;
 use crate::expr_fn::binary_expr;
 use crate::logical_plan::Subquery;
-use crate::utils::{expr_to_columns, find_out_reference_exprs};
+use crate::utils::{expr_to_columns, find_hidden_columns, find_out_reference_exprs};
 use crate::window_frame;
 use crate::window_function;
 use crate::AggregateUDF;
@@ -223,6 +223,10 @@ pub enum Expr {
     /// A placeholder which holds a reference to a qualified field
     /// in the outer query, used for correlated sub queries.
     OuterReferenceColumn(DataType, Column),
+    /// A hidden column used by the system internally
+    HiddenColumn(DataType, String),
+    /// A hidden expr pair used by the system internally, evaluated to a HiddenColumn
+    HiddenExpr(Box<Expr>, Box<Expr>),
 }

 /// Binary expression
@@ -505,14 +509,18 @@ impl GroupingSet {
     /// Return all distinct exprs in the grouping set. For `CUBE` and `ROLLUP` this
     /// is just the underlying list of exprs. For `GROUPING SET` we need to deduplicate
     /// the exprs in the underlying sets.
-    pub fn distinct_expr(&self) -> Vec<Expr> {
+    pub fn distinct_expr(&self, include_hidden: bool) -> Vec<Expr> {
         match self {
             GroupingSet::Rollup(exprs) => exprs.clone(),
             GroupingSet::Cube(exprs) => exprs.clone(),
             GroupingSet::GroupingSets(groups) => {
                 let mut exprs: Vec<Expr> = vec![];
                 for exp in groups.iter().flatten() {
-                    if !exprs.contains(exp) {
+                    if let Expr::HiddenExpr(_, second) = exp {
+                        if include_hidden && !exprs.contains(second) {
+                            exprs.push(*second.clone());
+                        }
+                    } else if !exprs.contains(exp) {
                         exprs.push(exp.clone());
                     }
                 }
@@ -520,6 +528,48 @@ impl GroupingSet {
             }
         }
     }
+
+    pub fn contains_duplicate_grouping(&self) -> bool {
+        match self {
+            GroupingSet::Rollup(_) => false,
+            GroupingSet::Cube(_) => false,
+            GroupingSet::GroupingSets(groups) => {
+                let exclude_hidden = groups
+                    .clone()
+                    .into_iter()
+                    .map(|group| {
+                        group
+                            .into_iter()
+                            .filter(|e| !matches!(e, Expr::HiddenExpr(_, _)))
+                            .collect::<Vec<_>>()
+                    })
+                    .collect::<Vec<_>>();
+                let exclude_hidden_len = exclude_hidden.len();
+                let distinct_set = exclude_hidden.into_iter().collect::<HashSet<_>>();
+                exclude_hidden_len != distinct_set.len()
+            }
+        }
+    }
+
+    pub fn contains_hidden_expr(&self) -> bool {
+        match self {
+            GroupingSet::Rollup(_) => false,
+            GroupingSet::Cube(_) => false,
+            GroupingSet::GroupingSets(groups) => groups
+                .iter()
+                .flatten()
+                .any(|e| matches!(e, Expr::HiddenExpr(_, _))),
+        }
+    }
+
+    /// Return the number of input exprs in the grouping set
+    pub fn input_expr_len(&self) -> usize {
+        match self {
+            GroupingSet::Rollup(exprs) => exprs.len(),
+            GroupingSet::Cube(exprs) => exprs.len(),
+            GroupingSet::GroupingSets(groups) => groups.len(),
+        }
+    }
 }

 /// Fixed seed for the hashing so that Ords are consistent across runs
@@ -600,6 +650,8 @@ impl Expr {
             Expr::TryCast { .. } => "TryCast",
             Expr::WindowFunction { .. } => "WindowFunction",
             Expr::Wildcard => "Wildcard",
+            Expr::HiddenColumn(..) => "HiddenColumn",
+            Expr::HiddenExpr(..) => "HiddenExpr",
         }
     }
@@ -794,6 +846,11 @@ impl Expr {
     pub fn contains_outer(&self) -> bool {
         !find_out_reference_exprs(self).is_empty()
     }
+
+    /// Return true when the expression contains hidden columns.
+    pub fn contains_hidden_columns(&self) -> bool {
+        !find_hidden_columns(self).is_empty()
+    }
 }

 impl Not for Expr {
@@ -1081,6 +1138,8 @@ impl fmt::Debug for Expr {
                 }
             },
             Expr::Placeholder { id, .. } => write!(f, "{id}"),
+            Expr::HiddenColumn(_, c) => write!(f, "#{}", c),
+            Expr::HiddenExpr(first, _) => write!(f, "{}", first),
         }
     }
 }
@@ -1364,6 +1423,8 @@ fn create_name(e: &Expr) -> Result<String> {
             "Create name does not support qualified wildcard".to_string(),
         )),
         Expr::Placeholder { id, .. } => Ok((*id).to_string()),
+        Expr::HiddenColumn(_, c) => Ok(format!("#{}", c)),
+        Expr::HiddenExpr(first, _) => Ok(format!("#{}", first)),
     }
 }
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index fafda79a6f61..d9b5e25155cc 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -66,6 +66,8 @@ impl ExprSchemable for Expr {
         Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => expr.get_type(schema),
         Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
         Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()),
+        Expr::HiddenColumn(ty, _) => Ok(ty.clone()),
+        Expr::HiddenExpr(_, second) => second.get_type(schema),
         Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
         Expr::Literal(l) => Ok(l.get_datatype()),
         Expr::Case(case) => case.when_then_expr[0].1.get_type(schema),
@@ -212,6 +214,7 @@ impl ExprSchemable for Expr {
             | Expr::IsNotUnknown(_)
             | Expr::Exists { .. }
             | Expr::Placeholder { .. } => Ok(true),
+            Expr::HiddenColumn(_, _) | Expr::HiddenExpr(_, _) => Ok(false),
             Expr::InSubquery { expr, .. } => expr.nullable(input_schema),
             Expr::ScalarSubquery(subquery) => {
                 Ok(subquery.subquery.schema().field(0).is_nullable())
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index ea37ab603f1f..9a4fcf25e3cf 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -21,8 +21,8 @@ use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
 use crate::logical_plan::extension::UserDefinedLogicalNode;
 use crate::logical_plan::plan;
 use crate::utils::{
-    enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, from_plan,
-    grouping_set_expr_count, grouping_set_to_exprlist, inspect_expr_pre,
+    distinct_group_exprs, exprlist_to_fields, find_out_reference_exprs, from_plan,
+    grouping_set_expr_count, inspect_expr_pre,
 };
 use crate::{
     build_join_schema, Expr, ExprSchemable, TableProviderFilterPushDown, TableSource,
@@ -1667,25 +1667,39 @@ pub struct Aggregate {
 }

 impl Aggregate {
-    /// Create a new aggregate operator.
+    /// Create a new aggregate operator. The group_expr may contain multiple [Expr::GroupingSet] expressions
     pub fn try_new(
         input: Arc<LogicalPlan>,
         group_expr: Vec<Expr>,
         aggr_expr: Vec<Expr>,
     ) -> Result<Self> {
-        let group_expr = enumerate_grouping_sets(group_expr)?;
-        let grouping_expr: Vec<Expr> = grouping_set_to_exprlist(group_expr.as_slice())?;
-        let all_expr = grouping_expr.iter().chain(aggr_expr.iter());
+        if group_expr.is_empty() && aggr_expr.is_empty() {
+            return Err(DataFusionError::Plan(
+                "Aggregate requires at least one grouping or aggregate expression"
+                    .to_string(),
+            ));
+        }
+        let distinct_grouping_expr: Vec<Expr> =
+            distinct_group_exprs(group_expr.as_slice(), true);
+
+        let all_expr = distinct_grouping_expr.iter().chain(aggr_expr.iter());
         validate_unique_names("Aggregations", all_expr.clone())?;
-        let schema = DFSchema::new_with_metadata(
+        let schema = Arc::new(DFSchema::new_with_metadata(
             exprlist_to_fields(all_expr, &input)?,
             input.schema().metadata().clone(),
-        )?;
-        Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema))
+        )?);
+
+        Ok(Self {
+            input,
+            group_expr,
+            aggr_expr,
+            schema,
+        })
     }

     /// Create a new aggregate operator using the provided schema to avoid the overhead of
-    /// building the schema again when the schema is already known.
+    /// building the schema again when the schema is already known.
+    /// The group_expr must not contain multiple [Expr::GroupingSet] expressions.
     ///
     /// This method should only be called when you are absolutely sure that the schema being
     /// provided is correct for the aggregate. If in doubt, call [try_new](Self::try_new) instead.
diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs
index 19909cf2fbf4..accd2f08e1ae 100644
--- a/datafusion/expr/src/signature.rs
+++ b/datafusion/expr/src/signature.rs
@@ -56,6 +56,8 @@ pub enum TypeSignature {
     Any(usize),
     /// One of a list of signatures
     OneOf(Vec<TypeSignature>),
+    /// Arbitrary number of arguments of arbitrary types
+    Arbitrary,
 }

 /// The Signature of a function defines its supported input types as well as its volatility.
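Aside: the GROUPING_ID() semantics these hunks introduce (a `TypeSignature::Arbitrary` aggregate returning a `UInt32` bitmask per grouping set) can be illustrated with a small standalone sketch. It mirrors the mask computation in `generate_grouping_ids` further down in this diff: the mask is built left to right over the distinct GROUP BY columns, with a 0 bit for a column present in the grouping set and a 1 bit for one rolled up. The helper name and `&str` columns below are illustrative only, not part of the patch.

```rust
// Standalone sketch of the grouping-id bit encoding used by this PR.
fn grouping_id(all_cols: &[&str], group: &[&str]) -> u32 {
    let mut mask = 0u32;
    for col in all_cols {
        // Shift previous bits left, then append this column's bit:
        // 0 if selected in this grouping set, 1 if rolled up.
        mask = (mask << 1) + u32::from(!group.contains(col));
    }
    mask
}

fn main() {
    let cols = ["c1", "c2"];
    assert_eq!(grouping_id(&cols, &["c1", "c2"]), 0); // (c1, c2)
    assert_eq!(grouping_id(&cols, &["c1"]), 1);       // (c1): c2 rolled up
    assert_eq!(grouping_id(&cols, &["c2"]), 2);       // (c2): c1 rolled up
    assert_eq!(grouping_id(&cols, &[]), 3);           // (): both rolled up
}
```

These values line up with the expected `GROUPING_ID` column in the `CUBE(c1, c2)` tests above, and `GROUPING(cN)` is recovered by shifting and masking the id, as the projected expressions in those plans (e.g. `(#grouping_id >> UInt32(1)) & UInt32(1)`) show.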
diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 61a5c91fec09..37fa07aa1e62 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -60,6 +60,7 @@ impl TreeNode for Expr { Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) + | Expr::HiddenColumn(_, _) | Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Exists { .. } @@ -67,6 +68,9 @@ impl TreeNode for Expr { | Expr::Wildcard | Expr::QualifiedWildcard { .. } | Expr::Placeholder { .. } => vec![], + Expr::HiddenExpr(first, _) => { + vec![first.as_ref().clone()] + } Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { vec![left.as_ref().clone(), right.as_ref().clone()] } @@ -148,6 +152,8 @@ impl TreeNode for Expr { } Expr::Column(_) => self, Expr::OuterReferenceColumn(_, _) => self, + Expr::HiddenColumn(_, _) => self, + Expr::HiddenExpr(_, _) => self, Expr::Exists { .. } => self, Expr::InSubquery { expr, diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 3ad197afb64a..97b9eb87affa 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -217,6 +217,7 @@ pub fn coerce_types( } AggregateFunction::Median => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), + AggregateFunction::GroupingId => Ok(input_types.to_vec()), } } @@ -263,6 +264,7 @@ fn check_arg_count( ))); } } + TypeSignature::Arbitrary => return Ok(()), _ => { return Err(DataFusionError::Internal(format!( "Aggregate functions do not support this {signature:?}" diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index a038fdcc92d0..1bdec8a4891c 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -94,6 +94,9 @@ fn get_valid_types( .filter_map(|t| get_valid_types(t, current_types).ok()) .flatten() .collect::>(), + TypeSignature::Arbitrary => { + vec![current_types.to_vec()] + } }; Ok(valid_types) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index bfcadd25ea9d..ca7c30d326ca 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -53,6 +53,16 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result Ok(()) } +/// Check whether the group_expr contains [Expr::GroupingSet]. +pub fn contains_grouping_set(group_expr: &[Expr]) -> bool { + group_expr.iter().any(|e| matches!(e, Expr::GroupingSet(_))) +} + +/// Check whether the group_expr contains [Expr::GroupingSet] without any hidden expr. +pub fn contains_grouping_set_without_hidden_expr(group_expr: &[Expr]) -> bool { + group_expr.iter().any(|e| matches!(e, Expr::GroupingSet(grouping_set) if !grouping_set.contains_hidden_expr())) +} + /// Count the number of distinct exprs in a list of group by expressions. If the /// first element is a `GroupingSet` expression then it must be the only expr. 
 pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
@@ -63,7 +73,7 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
                 .to_string(),
             ));
         }
-        Ok(grouping_set.distinct_expr().len())
+        Ok(grouping_set.distinct_expr(true).len())
     } else {
         Ok(group_expr.len())
     }
@@ -118,6 +128,17 @@ fn check_grouping_set_size_limit(size: usize) -> Result<()> {
     Ok(())
 }
 
+/// check the number of distinct expressions contained in the grouping_set when using group id
+fn check_grouping_set_distinct_expression_size_limit(size: usize) -> Result<()> {
+    // we use u32 to represent the grouping id
+    let max_expression_set_size = 32;
+    if size > max_expression_set_size {
+        return Err(DataFusionError::Plan(format!("The number of distinct group expressions in grouping_set exceeds the maximum limit {} when using group id, found {}", max_expression_set_size, size)));
+    }
+
+    Ok(())
+}
+
 /// check the number of grouping_set contained in the grouping sets
 fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
     let max_grouping_sets_size = 4096;
@@ -191,14 +212,10 @@ fn cross_join_grouping_sets(
 /// (person.id, person.age, person.state),\
 /// (person.id, person.age, person.state, person.birth_date)\
 /// )
-pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
-    let has_grouping_set = group_expr
-        .iter()
-        .any(|expr| matches!(expr, Expr::GroupingSet(_)));
-    if !has_grouping_set || group_expr.len() == 1 {
-        return Ok(group_expr);
+pub fn enumerate_grouping_sets(group_expr: &[Expr]) -> Result<Vec<Expr>> {
+    if !contains_grouping_set(group_expr) {
+        return Ok(group_expr.to_vec());
     }
-    // only process mix grouping sets
     let partial_sets = group_expr
         .iter()
         .map(|expr| {
@@ -245,22 +262,94 @@ pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
     ))])
 }
 
-/// Find all distinct exprs in a list of group by expressions. If the
-/// first element is a `GroupingSet` expression then it must be the only expr.
-pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
-    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
-        if group_expr.len() > 1 {
-            return Err(DataFusionError::Plan(
-                "Invalid group by expressions, GroupingSet must be the only expression"
-                    .to_string(),
-            ));
+/// Generate the grouping ids for each group in this grouping set.
+/// Each group id represents the level of grouping which combines the GROUPING() function
+/// for several columns into one by assigning each column a bit.
+///
+/// For example, given the GROUP BY columns (person.id, person.age, person.salary),
+/// the Grouping Set (person.id, person.age) will be represented as '001': a selected
+/// column's bit is set to '0' and an unselected column's bit is set to '1'
+pub fn generate_grouping_ids(grouping_set: &GroupingSet) -> Result<Vec<u32>> {
+    match grouping_set {
+        GroupingSet::Rollup(_) => Ok(vec![]),
+        GroupingSet::Cube(_) => Ok(vec![]),
+        GroupingSet::GroupingSets(groups) => {
+            let distinct_exprs = grouping_set.distinct_expr(false);
+            check_grouping_set_distinct_expression_size_limit(distinct_exprs.len())?;
+            Ok(groups
+                .iter()
+                .map(|group| {
+                    let mut mask = 0u32;
+                    distinct_exprs.iter().for_each(|expr| {
+                        mask = (mask << 1) + (if !group.contains(expr) { 1 } else { 0 })
+                    });
+                    mask
+                })
+                .collect::<Vec<u32>>())
+        }
-        Ok(grouping_set.distinct_expr())
-    } else {
-        Ok(group_expr.to_vec())
     }
 }
 
+/// Add hidden grouping set expression to each group in the grouping_set
+pub fn add_hidden_grouping_set_expr<F>(
+    grouping_set: &mut GroupingSet,
+    hidden_grouping_expr: F,
+) -> Result<()>
+where
+    F: Fn(usize) -> Expr,
+{
+    if let GroupingSet::GroupingSets(groups) = grouping_set {
+        groups
+            .iter_mut()
+            .enumerate()
+            .for_each(|(idx, expr)| expr.push(hidden_grouping_expr(idx)));
+    }
+    Ok(())
+}
+
+/// Find all distinct exprs in a list of group by expressions.
+pub fn distinct_group_exprs(group_expr: &[Expr], include_hidden: bool) -> Vec<Expr> {
+    let mut dedup_expr = Vec::new();
+    let mut dedup_set = HashSet::new();
+    let mut dedup_hidden_expr = Vec::new();
+    let mut dedup_hidden_set = HashSet::new();
+    group_expr.iter().for_each(|expr| match expr {
+        Expr::GroupingSet(grouping_set) => match grouping_set {
+            GroupingSet::Rollup(exprs) => exprs.iter().for_each(|e| {
+                if !dedup_set.contains(e) {
+                    dedup_expr.push(e.clone());
+                    dedup_set.insert(e.clone());
+                }
+            }),
+            GroupingSet::Cube(exprs) => exprs.iter().for_each(|e| {
+                if !dedup_set.contains(e) {
+                    dedup_expr.push(e.clone());
+                    dedup_set.insert(e.clone());
+                }
+            }),
+            GroupingSet::GroupingSets(groups) => groups.iter().flatten().for_each(|e| {
+                if let Expr::HiddenExpr(_, second) = e {
+                    if include_hidden && !dedup_hidden_set.contains(second.as_ref()) {
+                        dedup_hidden_expr.push(*second.clone());
+                        dedup_hidden_set.insert(*second.clone());
+                    }
+                } else if !dedup_set.contains(e) {
+                    dedup_expr.push(e.clone());
+                    dedup_set.insert(e.clone());
+                }
+            }),
+        },
+        _ => {
+            if !dedup_set.contains(expr) {
+                dedup_expr.push(expr.clone());
+                dedup_set.insert(expr.clone());
+            }
+        }
+    });
+    dedup_expr.append(&mut dedup_hidden_expr);
+    dedup_expr
+}
+
 /// Recursively walk an expression tree, collecting the unique set of columns
 /// referenced in the expression
 pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
@@ -310,7 +399,9 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
             | Expr::QualifiedWildcard { .. }
             | Expr::GetIndexedField { .. }
             | Expr::Placeholder { .. }
-            | Expr::OuterReferenceColumn { .. } => {}
+            | Expr::OuterReferenceColumn { .. }
+            | Expr::HiddenColumn { .. }
+            | Expr::HiddenExpr { .. } => {}
         }
         Ok(())
     })
@@ -566,10 +657,18 @@ pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
     })
 }
 
+/// Collect all deeply nested `Expr::HiddenColumn`. They are returned in order of occurrence
+/// (depth first), with duplicates omitted.
+pub fn find_hidden_columns(expr: &Expr) -> Vec<Expr> {
+    find_exprs_in_expr(expr, &|nested_expr| {
+        matches!(nested_expr, Expr::HiddenColumn { .. })
+    })
+}
+
 /// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
 /// pass the provided test. The returned `Expr`'s are deduplicated and returned
 /// in order of appearance (depth first).
-fn find_exprs_in_exprs<F>(exprs: &[Expr], test_fn: &F) -> Vec<Expr>
+pub fn find_exprs_in_exprs<F>(exprs: &[Expr], test_fn: &F) -> Vec<Expr>
 where
     F: Fn(&Expr) -> bool,
 {
@@ -587,7 +686,7 @@ where
 /// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the
 /// provided test. The returned `Expr`'s are deduplicated and returned in order
 /// of appearance (depth first).
-fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
+pub fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
 where
     F: Fn(&Expr) -> bool,
 {
@@ -1425,22 +1524,25 @@ mod tests {
         let grouping_set = grouping_set(vec![multi_cols]);
 
         // 1. col
-        let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
+        let sets = enumerate_grouping_sets(&vec![simple_col.clone()])?;
         let result = format!("{sets:?}");
        assert_eq!("[simple_col]", &result);

         // 2. cube
-        let sets = enumerate_grouping_sets(vec![cube.clone()])?;
+        let sets = enumerate_grouping_sets(&vec![cube.clone()])?;
         let result = format!("{sets:?}");
-        assert_eq!("[CUBE (col1, col2, col3)]", &result);
+        assert_eq!("[GROUPING SETS ((), (col1), (col2), (col1, col2), (col3), (col1, col3), (col2, col3), (col1, col2, col3))]", &result);

         // 3. rollup
-        let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
+        let sets = enumerate_grouping_sets(&vec![rollup.clone()])?;
         let result = format!("{sets:?}");
-        assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
+        assert_eq!(
+            "[GROUPING SETS ((), (col1), (col1, col2), (col1, col2, col3))]",
+            &result
+        );

         // 4. col + cube
-        let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
+        let sets = enumerate_grouping_sets(&vec![simple_col.clone(), cube.clone()])?;
         let result = format!("{sets:?}");
         assert_eq!(
             "[GROUPING SETS (\
@@ -1456,7 +1558,7 @@ mod tests {
         );

         // 5. col + rollup
-        let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
+        let sets = enumerate_grouping_sets(&vec![simple_col.clone(), rollup.clone()])?;
         let result = format!("{sets:?}");
         assert_eq!(
             "[GROUPING SETS (\
@@ -1469,7 +1571,7 @@ mod tests {

         // 6. col + grouping_set
         let sets =
-            enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
+            enumerate_grouping_sets(&vec![simple_col.clone(), grouping_set.clone()])?;
         let result = format!("{sets:?}");
         assert_eq!(
             "[GROUPING SETS (\
@@ -1478,11 +1580,9 @@ mod tests {
         );

         // 7. col + grouping_set + rollup
-        let sets = enumerate_grouping_sets(vec![
-            simple_col.clone(),
-            grouping_set,
-            rollup.clone(),
-        ])?;
+        let sets = enumerate_grouping_sets(
+            vec![simple_col.clone(), grouping_set, rollup.clone()].as_slice(),
+        )?;
         let result = format!("{sets:?}");
         assert_eq!(
             "[GROUPING SETS (\
@@ -1494,7 +1594,7 @@ mod tests {
         );

         // 8. col + cube + rollup
-        let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
+        let sets = enumerate_grouping_sets(&vec![simple_col, cube, rollup])?;
         let result = format!("{sets:?}");
         assert_eq!(
             "[GROUPING SETS (\
@@ -1535,4 +1635,41 @@ mod tests {

         Ok(())
     }
+
+    #[test]
+    fn test_generate_grouping_ids() -> Result<()> {
+        // 001
+        let multi_cols1 = vec![col("col1"), col("col2")];
+        // 010
+        let multi_cols2 = vec![col("col1"), col("col3")];
+        // 100
+        let multi_cols3 = vec![col("col2"), col("col3")];
+        // 000
+        let multi_cols4 = vec![col("col1"), col("col2"), col("col3")];
+        // 011
+        let multi_cols5 = vec![col("col1")];
+        // 101
+        let multi_cols6 = vec![col("col2")];
+        // 110
+        let multi_cols7 = vec![col("col3")];
+        // 011
+        let multi_cols8 = vec![col("col1"), col("col1"), col("col1")];
+
+        let grouping_set = GroupingSet::GroupingSets(vec![
+            multi_cols1,
+            multi_cols2,
+            multi_cols3,
+            multi_cols4,
+            multi_cols5,
+            multi_cols6,
+            multi_cols7,
+            multi_cols8,
+        ]);
+
+        let grouping_id = generate_grouping_ids(&grouping_set)?;
+        let grouping_id_result = format!("{grouping_id:?}");
+        assert_eq!("[1, 2, 4, 0, 3, 5, 6, 3]", &grouping_id_result);
+
+        Ok(())
+    }
 }
diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs
index aef46926f517..46a4c1ad6038 100644
--- a/datafusion/optimizer/src/analyzer/mod.rs
+++ b/datafusion/optimizer/src/analyzer/mod.rs
@@ -17,9 +17,11 @@
 
 mod count_wildcard_rule;
 mod inline_table_scan;
+mod resolve_grouping_analytics;
 
 use crate::analyzer::count_wildcard_rule::CountWildcardRule;
 use crate::analyzer::inline_table_scan::InlineTableScan;
+use crate::analyzer::resolve_grouping_analytics::ResolveGroupingAnalytics;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{TreeNode, VisitRecursion};
@@ -64,6 +66,7 @@ impl Analyzer {
     /// Create a new analyzer using the recommended list of rules
     pub fn new() -> Self {
         let rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>> = vec![
+            Arc::new(ResolveGroupingAnalytics::new()),
             Arc::new(CountWildcardRule::new()),
             Arc::new(InlineTableScan::new()),
         ];
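To make the grouping-id bitmask described in generate_grouping_ids concrete, here is a small standalone sketch (plain Rust, not part of the patch; the column names are illustrative). With distinct GROUP BY exprs (id, age, salary), a group keeps bit '0' for each selected column and '1' for each unselected one, most significant bit first:

    let distinct = ["id", "age", "salary"];
    let group = ["id", "age"];
    let mut mask = 0u32;
    for e in distinct {
        // Selected columns contribute 0, unselected columns contribute 1.
        mask = (mask << 1) + u32::from(!group.contains(&e));
    }
    assert_eq!(mask, 0b001); // only "salary" is unselected

This mirrors the shift-and-add loop the new utils function uses, which is also why the duplicate-expression group in test_generate_grouping_ids collapses to the same id as the single-column group.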
diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_analytics.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_analytics.rs
new file mode 100644
index 000000000000..a6c8c2476a55
--- /dev/null
+++ b/datafusion/optimizer/src/analyzer/resolve_grouping_analytics.rs
@@ -0,0 +1,208 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::analyzer::AnalyzerRule;
+use arrow::datatypes::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_expr::expr::AggregateFunction;
+use datafusion_expr::utils::{
+    add_hidden_grouping_set_expr, contains_grouping_set_without_hidden_expr,
+    distinct_group_exprs, enumerate_grouping_sets, generate_grouping_ids,
+};
+use datafusion_expr::{
+    aggregate_function, bitwise_and, bitwise_shift_right, cast, lit, Projection,
+};
+use datafusion_expr::{Aggregate, Expr, LogicalPlan};
+use std::sync::Arc;
+
+use datafusion_common::{DataFusionError, Result};
+
+pub struct ResolveGroupingAnalytics;
+
+impl ResolveGroupingAnalytics {
+    #[allow(missing_docs)]
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+// Internal column used to represent the grouping_id, used by the grouping functions.
+// It is "spark_grouping_id" in Spark
+const INTERNAL_GROUPING_ID: &str = "grouping_id";
+// Internal column used to represent different grouping sets when there are duplicated grouping sets
+const INTERNAL_GROUPING_SET_ID: &str = "grouping_set_id";
+
+impl AnalyzerRule for ResolveGroupingAnalytics {
+    fn analyze(
+        &self,
+        plan: LogicalPlan,
+        _config: &ConfigOptions,
+    ) -> datafusion_common::Result<LogicalPlan> {
+        plan.transform_down(&|plan| match plan {
+            LogicalPlan::Aggregate(Aggregate {
+                input,
+                aggr_expr,
+                group_expr,
+                ..
+            }) if contains_grouping_set_without_hidden_expr(&group_expr) => {
+                let mut expanded_grouping = enumerate_grouping_sets(&group_expr)?;
+                let mut new_project_exec = vec![];
+                if let [Expr::GroupingSet(ref mut grouping_set)] =
+                    expanded_grouping.as_mut_slice()
+                {
+                    let new_agg_expr = if contains_grouping_funcs_as_agg_expr(&aggr_expr) {
+                        let gid_column = Expr::HiddenColumn(
+                            DataType::UInt32,
+                            INTERNAL_GROUPING_ID.to_string(),
+                        );
+                        let hidden_name = gid_column.display_name()?;
+                        let grouping_ids = generate_grouping_ids(grouping_set)?;
+                        let hidden_grouping_expr = |group_set_idx: usize| {
+                            Expr::HiddenExpr(
+                                Box::new(
+                                    lit(grouping_ids[group_set_idx])
+                                        .alias(hidden_name.clone()),
+                                ),
+                                Box::new(gid_column.clone()),
+                            )
+                        };
+                        add_hidden_grouping_set_expr(grouping_set, hidden_grouping_expr)?;
+
+                        let distinct_group_by = distinct_group_exprs(&group_expr, false);
+                        let mut new_agg_expr = vec![];
+                        aggr_expr.into_iter().try_for_each(|expr| {
+                            let new_expr = replace_grouping_func(
+                                expr.clone(),
+                                &distinct_group_by,
+                                gid_column.clone(),
+                            )?;
+                            // The grouping func is rewritten to a normal expr and is no
+                            // longer an AggregateFunction, so remove it from the aggr_expr
+                            if new_expr.ne(&expr) {
+                                new_project_exec.push(new_expr);
+                            } else {
+                                new_agg_expr.push(new_expr);
+                            }
+                            Ok::<(), DataFusionError>(())
+                        })?;
+                        new_agg_expr
+                    } else {
+                        aggr_expr
+                    };
+                    if grouping_set.contains_duplicate_grouping() {
+                        let grouping_set_id_column = Expr::HiddenColumn(
+                            DataType::UInt32,
+                            INTERNAL_GROUPING_SET_ID.to_string(),
+                        );
+                        let hidden_name = grouping_set_id_column.display_name()?;
+                        let hidden_grouping_expr = |group_set_idx: usize| {
+                            Expr::HiddenExpr(
+                                Box::new(
+                                    lit((group_set_idx + 1) as u32)
+                                        .alias(hidden_name.clone()),
+                                ),
+                                Box::new(grouping_set_id_column.clone()),
+                            )
+                        };
+                        add_hidden_grouping_set_expr(grouping_set, hidden_grouping_expr)?;
+                    }
+
+                    let aggregate = Aggregate::try_new(
+                        input,
+                        vec![Expr::GroupingSet(grouping_set.clone())],
+                        new_agg_expr,
+                    )?;
+                    let agg_schema = aggregate.schema.clone();
+                    let new_agg = LogicalPlan::Aggregate(aggregate);
+                    if !new_project_exec.is_empty() {
+                        let mut expr: Vec<Expr> = agg_schema
+                            .fields()
+                            .iter()
+                            .map(|field| field.qualified_column())
+                            .map(Expr::Column)
+                            .collect();
+                        expr.append(&mut new_project_exec);
+                        Ok(Transformed::Yes(LogicalPlan::Projection(
+                            Projection::try_new(expr, Arc::new(new_agg))?,
+                        )))
+                    } else {
+                        Ok(Transformed::Yes(new_agg))
+                    }
+                } else {
+                    Err(DataFusionError::Plan(
+                        "Invalid group by expressions, GroupingSet must be the only expression"
+                            .to_string(),
+                    ))
+                }
+            }
+            _ => Ok(Transformed::No(plan)),
+        })
+    }
+    fn name(&self) -> &str {
+        "resolve_grouping_analytics"
+    }
+}
+
+fn contains_grouping_funcs_as_agg_expr(aggr_expr: &[Expr]) -> bool {
+    aggr_expr.iter().any(|expr| {
+        matches!(
+            expr,
+            Expr::AggregateFunction(AggregateFunction {
+                fun: aggregate_function::AggregateFunction::Grouping,
+                ..
+            }) | Expr::AggregateFunction(AggregateFunction {
+                fun: aggregate_function::AggregateFunction::GroupingId,
+                ..
+            })
+        )
+    })
+}
+
+fn replace_grouping_func(
+    expr: Expr,
+    group_by_exprs: &[Expr],
+    gid_column: Expr,
+) -> Result<Expr> {
+    expr.transform(&|expr| {
+        let display_name = expr.display_name()?;
+        match expr {
+            Expr::AggregateFunction(AggregateFunction {
+                fun: aggregate_function::AggregateFunction::Grouping,
+                args,
+                ..
+            }) => {
+                let grouping_col = &args[0];
+                match group_by_exprs.iter().position(|e| e == grouping_col) {
+                    Some(idx) => Ok(Transformed::Yes(
+                        cast(
+                            bitwise_and(
+                                bitwise_shift_right(
+                                    gid_column.clone(),
+                                    lit((group_by_exprs.len() - 1 - idx) as u32),
+                                ),
+                                lit(1u32),
+                            ),
+                            DataType::UInt8,
+                        )
+                        .alias(display_name),
+                    )),
+                    None => Err(DataFusionError::Plan(format!(
+                        "Column of GROUPING({:?}) can't be found in GROUP BY columns {:?}",
+                        grouping_col, group_by_exprs
+                    ))),
+                }
+            }
+            Expr::AggregateFunction(AggregateFunction {
+                fun: aggregate_function::AggregateFunction::GroupingId,
+                args,
+                ..
+            }) => {
+                if group_by_exprs.is_empty()
+                    || (group_by_exprs.len() == args.len()
+                        && group_by_exprs.iter().zip(args.iter()).all(|(g, a)| g == a))
+                {
+                    Ok(Transformed::Yes(gid_column.clone().alias(display_name)))
+                } else {
+                    Err(DataFusionError::Plan(format!(
+                        "Columns of GROUPING_ID({:?}) do not match GROUP BY columns {:?}",
+                        args, group_by_exprs
+                    )))
+                }
+            }
+            _ => Ok(Transformed::No(expr)),
+        }
+    })
+}
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 2a78551ea131..b434da5f21ea 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -649,7 +649,9 @@ impl OptimizerRule for PushDownFilter {
             let mut push_predicates = vec![];
             for expr in predicates {
                 let cols = expr.to_columns()?;
-                if cols.iter().all(|c| group_expr_columns.contains(c)) {
+                if !expr.contains_hidden_columns()
+                    && cols.iter().all(|c| group_expr_columns.contains(c))
+                {
                     push_predicates.push(expr);
                 } else {
                     keep_predicates.push(expr);
diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs
index 767077aa0c02..f21900497a2a 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -329,6 +329,8 @@ impl OptimizerRule for PushDownProjection {
                     let new_proj = plan.with_new_inputs(&[filter.input.as_ref().clone()])?;
                     child_plan.with_new_inputs(&[new_proj])?
+                } else if filter.predicate.contains_hidden_columns() {
+                    return Ok(None);
                 } else {
                     let mut required_columns = HashSet::new();
                     exprlist_to_columns(&projection.expr, &mut required_columns)?;
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index e904f895e12a..0569a283c118 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -256,6 +256,8 @@ impl<'a> ConstEvaluator<'a> {
             | Expr::ScalarVariable(_, _)
             | Expr::Column(_)
             | Expr::OuterReferenceColumn(_, _)
+            | Expr::HiddenColumn(_, _)
+            | Expr::HiddenExpr(_, _)
             | Expr::Exists { .. }
             | Expr::InSubquery { .. }
             | Expr::ScalarSubquery(_)
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index cee31b5b3352..e0cda718262d 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -20,6 +20,7 @@
 use crate::optimizer::ApplyOrder;
 use crate::{OptimizerConfig, OptimizerRule};
 
 use datafusion_common::{DFSchema, Result};
+use datafusion_expr::utils::contains_grouping_set;
 use datafusion_expr::{
     col,
     expr::AggregateFunction,
@@ -82,11 +83,6 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result<bool> {
     }
 }
 
-/// Check if the first expr is [Expr::GroupingSet].
-fn contains_grouping_set(expr: &[Expr]) -> bool {
-    matches!(expr.first(), Some(Expr::GroupingSet(_)))
-}
-
 impl OptimizerRule for SingleDistinctToGroupBy {
     fn try_optimize(
         &self,
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index b3dbef7dfdf5..f0ab7baf3374 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -62,11 +62,6 @@ pub fn create_aggregate_expr(
             input_phy_exprs[0].clone(),
             name,
         )),
-        (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new(
-            input_phy_exprs[0].clone(),
-            name,
-            return_type,
-        )),
         (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new(
             input_phy_exprs[0].clone(),
             name,
@@ -250,6 +245,16 @@ pub fn create_aggregate_expr(
                 "MEDIAN(DISTINCT) aggregations are not available".to_string(),
             ));
         }
+        (AggregateFunction::Grouping, _) => {
+            return Err(DataFusionError::Plan(
+                "GROUPING() aggregations are not evaluable, should be converted by the Analyzer".to_string(),
+            ));
+        }
+        (AggregateFunction::GroupingId, _) => {
+            return Err(DataFusionError::Plan(
+                "GROUPING_ID() aggregations are not evaluable, should be converted by the Analyzer".to_string(),
+            ));
+        }
     })
 }
diff --git a/datafusion/physical-expr/src/aggregate/grouping.rs b/datafusion/physical-expr/src/aggregate/grouping.rs
deleted file mode 100644
index 9ddd17c035e8..000000000000
--- a/datafusion/physical-expr/src/aggregate/grouping.rs
+++ /dev/null
@@ -1,93 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Defines physical expressions that can evaluated at runtime during query execution
-
-use std::any::Any;
-use std::sync::Arc;
-
-use crate::{AggregateExpr, PhysicalExpr};
-use arrow::datatypes::DataType;
-use arrow::datatypes::Field;
-use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::Accumulator;
-
-use crate::expressions::format_state_name;
-
-/// GROUPING aggregate expression
-/// Returns the amount of non-null values of the given expression.
-#[derive(Debug)]
-pub struct Grouping {
-    name: String,
-    data_type: DataType,
-    nullable: bool,
-    expr: Arc<dyn PhysicalExpr>,
-}
-
-impl Grouping {
-    /// Create a new GROUPING aggregate function.
-    pub fn new(
-        expr: Arc<dyn PhysicalExpr>,
-        name: impl Into<String>,
-        data_type: DataType,
-    ) -> Self {
-        Self {
-            name: name.into(),
-            expr,
-            data_type,
-            nullable: true,
-        }
-    }
-}
-
-impl AggregateExpr for Grouping {
-    /// Return a reference to Any that can be used for downcasting
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn field(&self) -> Result<Field> {
-        Ok(Field::new(
-            &self.name,
-            self.data_type.clone(),
-            self.nullable,
-        ))
-    }
-
-    fn state_fields(&self) -> Result<Vec<Field>> {
-        Ok(vec![Field::new(
-            format_state_name(&self.name, "grouping"),
-            self.data_type.clone(),
-            true,
-        )])
-    }
-
-    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        vec![self.expr.clone()]
-    }
-
-    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Err(DataFusionError::NotImplemented(
-            "physical plan is not yet implemented for GROUPING aggregate function"
-                .to_owned(),
-        ))
-    }
-
-    fn name(&self) -> &str {
-        &self.name
-    }
-}
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs
index c42a5c03b306..9a25bc3d7a97 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -35,7 +35,6 @@ pub(crate) mod correlation;
 pub(crate) mod count;
 pub(crate) mod count_distinct;
 pub(crate) mod covariance;
-pub(crate) mod grouping;
 pub(crate) mod median;
 #[macro_use]
 pub(crate) mod min_max;
diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs
index eb2be5ef217c..43c06f22b565 100644
--- a/datafusion/physical-expr/src/expressions/column.rs
+++ b/datafusion/physical-expr/src/expressions/column.rs
@@ -202,6 +202,77 @@ impl PartialEq<dyn Any> for UnKnownColumn {
     }
 }
 
+#[derive(Debug, Hash, PartialEq, Eq, Clone)]
+pub struct HiddenColumn {
+    name: String,
+    data_type: DataType,
+}
+
+impl HiddenColumn {
+    /// Create a new hidden column
+    pub fn new(name: &str, data_type: &DataType) -> Self {
+        Self {
+            name: name.to_owned(),
+            data_type: data_type.clone(),
+        }
+    }
+
+    /// Get the column name
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl std::fmt::Display for HiddenColumn {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "{}", self.name)
+    }
+}
+
+impl PhysicalExpr for HiddenColumn {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    /// Get the data type of this expression, given the schema of the input
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(self.data_type.clone())
+    }
+
+    /// Decide whether this expression is nullable, given the schema of the input
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(false)
+    }
+
+    /// Evaluate the expression
+    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
+        Err(DataFusionError::Plan(
+            "HiddenColumn::evaluate() should not be called".to_owned(),
+        ))
+    }
+
+    fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(self)
+    }
+}
+
+impl PartialEq<dyn Any> for HiddenColumn {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| self == x)
+            .unwrap_or(false)
+    }
+}
+
 /// Create a column expression
 pub fn col(name: &str, schema: &Schema) -> Result<Arc<dyn PhysicalExpr>> {
     Ok(Arc::new(Column::new_with_schema(name, schema)?))
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index 63fb7b7d37ad..2ac11c23cc52 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -52,7 +52,6 @@ pub use crate::aggregate::correlation::Correlation;
 pub use crate::aggregate::count::Count;
 pub use crate::aggregate::count_distinct::DistinctCount;
 pub use crate::aggregate::covariance::{Covariance, CovariancePop};
-pub use crate::aggregate::grouping::Grouping;
 pub use crate::aggregate::median::Median;
 pub use crate::aggregate::min_max::{Max, Min};
 pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator};
@@ -77,7 +76,7 @@ pub use case::{case, CaseExpr};
 pub use cast::{
     cast, cast_column, cast_with_options, CastExpr, DEFAULT_DATAFUSION_CAST_OPTIONS,
 };
-pub use column::{col, Column, UnKnownColumn};
+pub use column::{col, Column, HiddenColumn, UnKnownColumn};
 pub use datetime::DateTimeIntervalExpr;
 pub use get_indexed_field::GetIndexedFieldExpr;
 pub use in_list::{in_list, InListExpr};
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 1fbd73b3ba01..d63a6a897b4a 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -26,7 +26,9 @@ use crate::{
     PhysicalExpr,
 };
 use arrow::datatypes::{DataType, Schema};
-use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
+use datafusion_common::{
+    unqualified_field_not_found, DFSchema, DataFusionError, Result, ScalarValue,
+};
 use datafusion_expr::expr::Cast;
 use datafusion_expr::{
     binary_expr, Between, BinaryExpr, Expr, GetIndexedField, Like, Operator, TryCast,
@@ -474,6 +476,22 @@ pub fn create_physical_expr(
                 expressions::in_list(value_expr, list_exprs, negated, input_schema)
             }
         },
+        Expr::HiddenColumn(_, _) => {
+            let hidden_col_name = e.display_name()?;
+            let col_idx =
+                input_dfschema.index_of_column_by_name(None, &hidden_col_name)?;
+            if let Some(idx) = col_idx {
+                Ok(Arc::new(Column::new(&hidden_col_name, idx)))
+            } else {
+                Err(unqualified_field_not_found(
+                    &hidden_col_name,
+                    input_dfschema,
+                ))
+            }
+        }
+        Expr::HiddenExpr(expr, _) => {
+            create_physical_expr(expr, input_dfschema, input_schema, execution_props)
+        }
         other => Err(DataFusionError::NotImplemented(format!(
             "Physical plan does not support logical expression {other:?}"
         ))),
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 76ec5b001708..ee1c37d1b5d6 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -550,6 +550,7 @@ enum AggregateFunction {
   APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
   GROUPING = 17;
   MEDIAN = 18;
+  GROUPING_ID = 19;
 }
 
 message AggregateExprNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 406d9ee27aa5..df190c78adb1 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -404,6 +404,7 @@ impl serde::Serialize for AggregateFunction {
             Self::ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
             Self::Grouping => "GROUPING",
             Self::Median => "MEDIAN",
+            Self::GroupingId => "GROUPING_ID",
         };
         serializer.serialize_str(variant)
     }
@@ -434,6 +435,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
             "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
             "GROUPING",
             "MEDIAN",
+            "GROUPING_ID",
         ];
 
         struct GeneratedVisitor;
@@ -495,6 +497,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction {
                     "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => Ok(AggregateFunction::ApproxPercentileContWithWeight),
                     "GROUPING" => Ok(AggregateFunction::Grouping),
                     "MEDIAN" => Ok(AggregateFunction::Median),
+                    "GROUPING_ID" => Ok(AggregateFunction::GroupingId),
                     _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
                 }
             }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index e5b4534f60a0..5d17c859fced 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2311,6 +2311,7 @@ pub enum AggregateFunction {
     ApproxPercentileContWithWeight = 16,
     Grouping = 17,
     Median = 18,
+    GroupingId = 19,
 }
 impl AggregateFunction {
     /// String value of the enum field names used in the ProtoBuf definition.
@@ -2340,6 +2341,7 @@ impl AggregateFunction {
             }
             AggregateFunction::Grouping => "GROUPING",
             AggregateFunction::Median => "MEDIAN",
+            AggregateFunction::GroupingId => "GROUPING_ID",
         }
     }
     /// Creates an enum from field names used in the ProtoBuf definition.
@@ -2366,6 +2368,7 @@ impl AggregateFunction {
             }
             "GROUPING" => Some(Self::Grouping),
             "MEDIAN" => Some(Self::Median),
+            "GROUPING_ID" => Some(Self::GroupingId),
             _ => None,
         }
     }
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 845cb60d17a9..ab0c2915c687 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -502,6 +502,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
             protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian,
             protobuf::AggregateFunction::Grouping => Self::Grouping,
             protobuf::AggregateFunction::Median => Self::Median,
+            protobuf::AggregateFunction::GroupingId => Self::GroupingId,
         }
     }
 }
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index e8570cf3c7e4..42cfab2371d9 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -371,6 +371,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
             AggregateFunction::ApproxMedian => Self::ApproxMedian,
             AggregateFunction::Grouping => Self::Grouping,
             AggregateFunction::Median => Self::Median,
+            AggregateFunction::GroupingId => Self::GroupingId,
         }
     }
 }
@@ -644,6 +645,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                     }
                     AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping,
                     AggregateFunction::Median => protobuf::AggregateFunction::Median,
+                    AggregateFunction::GroupingId => {
+                        protobuf::AggregateFunction::GroupingId
+                    }
                 };
 
                 let aggregate_expr = protobuf::AggregateExprNode {
@@ -885,10 +889,12 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
             Expr::ScalarSubquery(_)
             | Expr::InSubquery { .. }
             | Expr::Exists { .. }
-            | Expr::OuterReferenceColumn { .. } => {
+            | Expr::OuterReferenceColumn { .. }
+            | Expr::HiddenColumn { .. }
+            | Expr::HiddenExpr { .. } => {
                 // we would need to add logical plan operators to datafusion.proto to support this
                 // see discussion in https://github.com/apache/arrow-datafusion/issues/2565
-                return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string()));
+                return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::OuterReferenceColumn | Expr::HiddenColumn | Expr::HiddenExpr not supported".to_string()));
             }
             Expr::GetIndexedField(GetIndexedField { key, expr }) => Self {
                 expr_type: Some(ExprType::GetIndexedField(Box::new(
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 92986b0b39b2..9f0494122e9e 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -449,7 +449,13 @@ impl AsExecutionPlan for PhysicalPlanNode {
 
                 Ok(Arc::new(AggregateExec::try_new(
                     agg_mode,
-                    PhysicalGroupBy::new(group_expr, null_expr, groups),
+                    PhysicalGroupBy::new(
+                        group_expr.clone(),
+                        vec![],
+                        group_expr,
+                        null_expr,
+                        groups,
+                    ),
                     physical_aggr_expr,
                     input,
                     Arc::new((&input_schema).try_into()?),
diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs
index 91cef6d4712e..8c0cd0e6a53f 100644
--- a/datafusion/sql/src/utils.rs
+++ b/datafusion/sql/src/utils.rs
@@ -362,6 +362,8 @@ where
             ))),
             Expr::Column { .. }
             | Expr::OuterReferenceColumn(_, _)
+            | Expr::HiddenColumn(_, _)
+            | Expr::HiddenExpr(_, _)
             | Expr::Literal(_)
             | Expr::ScalarVariable(_, _)
             | Expr::Exists { .. }
diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs
index 3242989f574d..6e12c4faa09e 100644
--- a/datafusion/sql/tests/integration_test.rs
+++ b/datafusion/sql/tests/integration_test.rs
@@ -2856,8 +2856,8 @@ fn aggregate_with_rollup() {
     let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)";
     let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\
-        \n  Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[COUNT(UInt8(1))]]\
-        \n    TableScan: person";
+        \n  Aggregate: groupBy=[[person.id, ROLLUP (person.state, person.age)]], aggr=[[COUNT(UInt8(1))]]\
+        \n    TableScan: person";
     quick_test(sql, expected);
 }
 
@@ -2866,8 +2866,8 @@ fn aggregate_with_rollup_with_grouping() {
     let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), COUNT(*) \
         FROM person GROUP BY id, ROLLUP (state, age)";
     let expected = "Projection: person.id, person.state, person.age, GROUPING(person.state), GROUPING(person.age), GROUPING(person.state) + GROUPING(person.age), COUNT(UInt8(1))\
-        \n  Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[GROUPING(person.state), GROUPING(person.age), COUNT(UInt8(1))]]\
-        \n    TableScan: person";
+        \n  Aggregate: groupBy=[[person.id, ROLLUP (person.state, person.age)]], aggr=[[GROUPING(person.state), GROUPING(person.age), COUNT(UInt8(1))]]\
+        \n    TableScan: person";
     quick_test(sql, expected);
 }
 
@@ -2898,8 +2898,8 @@ fn aggregate_with_cube() {
     let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, CUBE (state, age)";
     let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\
-        \n  Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[COUNT(UInt8(1))]]\
-        \n    TableScan: person";
+        \n  Aggregate: groupBy=[[person.id, CUBE (person.state, person.age)]], aggr=[[COUNT(UInt8(1))]]\
+        \n    TableScan: person";
     quick_test(sql, expected);
 }
 
@@ -2915,8 +2915,8 @@ fn round_decimal() {
 fn aggregate_with_grouping_sets() {
     let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))";
     let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\
-        \n  Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[COUNT(UInt8(1))]]\
-        \n    TableScan: person";
+        \n  Aggregate: groupBy=[[person.id, GROUPING SETS ((person.state), (person.state, person.age), (person.id, person.state))]], aggr=[[COUNT(UInt8(1))]]\
+        \n    TableScan: person";
     quick_test(sql, expected);
 }
 
@@ -3902,7 +3902,7 @@ fn test_multi_grouping_sets() {
         GROUPING SETS ((person.age,person.salary),(person.age))";
 
     let expected = "Projection: person.id, person.age\
-        \n  Aggregate: groupBy=[[GROUPING SETS ((person.id, person.age, person.salary), (person.id, person.age))]], aggr=[[]]\
+        \n  Aggregate: groupBy=[[person.id, GROUPING SETS ((person.age, person.salary), (person.age))]], aggr=[[]]\
         \n    TableScan: person";
 
     quick_test(sql, expected);
@@ -3914,13 +3908,7 @@ fn test_multi_grouping_sets() {
         ROLLUP(person.state, person.birth_date)";
 
     let expected = "Projection: person.id, person.age\
-        \n  Aggregate: groupBy=[[GROUPING SETS (\
-        (person.id, person.age, person.salary), \
-        (person.id, person.age, person.salary, person.state), \
-        (person.id, person.age, person.salary, person.state, person.birth_date), \
-        (person.id, person.age), \
-        (person.id, person.age, person.state), \
-        (person.id, person.age, person.state, person.birth_date))]], aggr=[[]]\
+        \n  Aggregate: groupBy=[[person.id, GROUPING SETS ((person.age, person.salary), (person.age)), ROLLUP (person.state, person.birth_date)]], aggr=[[]]\
         \n    TableScan: person";
 
     quick_test(sql, expected);
}