diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 797867e7d0c6f..46bb31556c17e 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -138,17 +138,25 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( let inner_exprs = inner_p .expr .iter() - .map(|f| { - if let Expr::Alias(alias) = f { + .enumerate() + .map(|(i, f)| match f { + Expr::Alias(alias) => { let a = Expr::Column(alias.name.clone().into()); map.insert(a.clone(), f.clone()); a - } else { - // inner expr may have different type to outer expr: e.g. a + 1 is a column of - // string in outer, but a expr of math in inner - map.insert(Expr::Column(f.to_string().into()), f.clone()); + } + Expr::Column(_) => { + map.insert( + Expr::Column(inner_p.schema.field(i).name().into()), + f.clone(), + ); f.clone() } + _ => { + let a = Expr::Column(inner_p.schema.field(i).name().into()); + map.insert(a.clone(), f.clone()); + a + } }) .collect::>(); @@ -159,8 +167,10 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( } } - // inner expr may have different type to outer expr: e.g. a + 1 is a column of - // string in outer, but a expr of math in inner + // Compare outer collects Expr::to_string with inner collected transformed values + // alias -> alias column + // column -> remain + // others, extract schema field name let outer_collects = collects.iter().map(Expr::to_string).collect::>(); let inner_collects = inner_exprs .iter() @@ -236,11 +246,14 @@ pub(super) fn subquery_alias_inner_query_and_columns( return (plan, vec![]); }; - let expr = outer_alias.expr.clone(); + // inner projection schema fields store the projection name which is used in outer + // projection expr + let inner_expr_string = match inner_expr { + Expr::Column(_) => inner_expr.to_string(), + _ => inner_projection.schema.field(i).name().clone(), + }; - // inner expr may have different type to outer expr: e.g. a + 1 is a column of - // string in outer, but a expr of math in inner - if expr.to_string() != inner_expr.to_string() { + if outer_alias.expr.to_string() != inner_expr_string { return (plan, vec![]); }; diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 09f1be094840c..faa6a1e18a39e 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -346,6 +346,18 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(UnparserDefaultDialect {}), }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT CAST((CAST(j1_id as BIGINT) + 1) as int) * 10 FROM j1 LIMIT 1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT (CAST((CAST(j1.j1_id AS BIGINT) + 1) AS INTEGER) * 10) FROM j1 LIMIT 1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT CAST(j1_id as BIGINT) + 1 FROM j1 ORDER BY j1_id LIMIT 1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT (CAST(j1.j1_id AS BIGINT) + 1) FROM j1 ORDER BY j1.j1_id ASC NULLS LAST LIMIT 1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + } ]; for query in tests {