diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 1eef1b718ba6f..4c6a28bb18117 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -58,6 +58,8 @@ datafusion-functions = { workspace = true, default-features = true } datafusion-functions-aggregate = { workspace = true } datafusion-functions-nested = { workspace = true } datafusion-functions-window = { workspace = true } +datafusion-optimizer = { workspace = true } +datafusion-common = { workspace = true } env_logger = { workspace = true } paste = "^1.0" rstest = { workspace = true } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 0b4ecc95beef6..570b7a604e7b6 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -19,12 +19,14 @@ use std::{cmp::Ordering, sync::Arc, vec}; use datafusion_common::{ internal_err, - tree_node::{Transformed, TransformedResult, TreeNode}, + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion}, Column, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, - LogicalPlanBuilder, Projection, SortExpr, Unnest, Window, + expr::{self, Placeholder}, + utils::grouping_set_to_exprlist, + Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Projection, SortExpr, Unnest, + Window, }; use sqlparser::ast; @@ -168,6 +170,20 @@ pub(crate) fn unproject_agg_exprs( agg: &Aggregate, windows: Option<&[&Window]>, ) -> Result { + // If the current expression is an Alias over the internal grouping id column, + // we need to return a placeholder expression that represents the inverse + // of the replacement done in the [ResolveGroupingFunction] analyzer rule. + // + // [ResolveGroupingFunction]: datafusion_optimizer::resolve_grouping_function::ResolveGroupingFunction + if let Expr::Alias(alias) = &expr { + if find_grouping_id_col(&expr).is_some() { + return Ok(Expr::Placeholder(Placeholder::new( + alias.name.clone(), + None, + ))); + } + } + expr.transform(|sub_expr| { if let Expr::Column(c) = sub_expr { if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { @@ -243,6 +259,21 @@ fn find_window_expr<'a>( .find(|expr| expr.schema_name().to_string() == column_name) } +/// Recursively searches for a Column expression with the name of the internal grouping id column: `__grouping_id`. +fn find_grouping_id_col(expr: &Expr) -> Option<&Expr> { + let mut grouping_id_col: Option<&Expr> = None; + expr.apply(|sub_expr| { + if let Expr::Column(c) = sub_expr { + if c.name == Aggregate::INTERNAL_GROUPING_ID { + grouping_id_col = Some(sub_expr); + } + } + Ok(TreeNodeRecursion::Continue) + }) + .ok()?; + grouping_id_col +} + /// Transforms a Column expression into the actual expression from aggregation or projection if found. /// This is required because if an ORDER BY expression is present in an Aggregate or Select, it is replaced /// with a Column expression (e.g., "sum(catalog_returns.cr_net_loss)"). We need to transform it back to diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 567cf775c77c9..cef4b25d9e289 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use std::vec; use arrow_schema::*; +use datafusion_common::config::ConfigOptions; use datafusion_common::{DFSchema, Result, TableReference}; use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_udaf}; use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder}; @@ -26,6 +27,8 @@ use datafusion_functions::unicode; use datafusion_functions_aggregate::grouping::grouping_udaf; use datafusion_functions_nested::make_array::make_array_udf; use datafusion_functions_window::rank::rank_udwf; +use datafusion_optimizer::analyzer::resolve_grouping_function::ResolveGroupingFunction; +use datafusion_optimizer::AnalyzerRule; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect, @@ -796,6 +799,38 @@ where assert_eq!(roundtrip_statement.to_string(), expect); } +fn sql_round_trip_with_analyzer( + dialect: D, + query: &str, + expect: &str, + analyzer: &dyn AnalyzerRule, +) where + D: Dialect, +{ + let statement = Parser::new(&dialect) + .try_with_sql(query) + .unwrap() + .parse_statement() + .unwrap(); + + let context = MockContextProvider { + state: MockSessionState::default() + .with_aggregate_function(sum_udaf()) + .with_aggregate_function(max_udaf()) + .with_aggregate_function(grouping_udaf()) + .with_window_function(rank_udwf()) + .with_scalar_function(Arc::new(unicode::substr().as_ref().clone())) + .with_scalar_function(make_array_udf()), + }; + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); + + let plan = analyzer.analyze(plan, &ConfigOptions::default()).unwrap(); + + let roundtrip_statement = plan_to_sql(&plan).unwrap(); + assert_eq!(roundtrip_statement.to_string(), expect); +} + #[test] fn test_table_scan_alias() -> Result<()> { let schema = Schema::new(vec![ @@ -1276,6 +1311,16 @@ GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(), ); } +#[test] +fn test_grouping_aggregate_function_to_sql() { + sql_round_trip_with_analyzer( + GenericDialect {}, + r#"SELECT id, first_name, grouping(id) FROM person GROUP BY ROLLUP(id, first_name)"#, + r#"SELECT id, first_name, "grouping(person.id)" FROM (SELECT person.id, person.first_name, grouping(person.id) FROM person GROUP BY ROLLUP (person.id, person.first_name))"#, + &ResolveGroupingFunction, + ); +} + #[test] fn test_unnest_to_sql() { sql_round_trip(