diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index b14fbdff236f5..20844d80332f4 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -377,8 +377,17 @@ impl Unparser<'_> { }; if self.dialect.unnest_as_table_factor() && unnest_input_type.is_some() { if let LogicalPlan::Unnest(unnest) = &p.input.as_ref() { - return self - .unnest_to_table_factor_sql(unnest, query, select, relation); + if let Some(unnest_relation) = + self.try_unnest_to_table_factor_sql(unnest)? + { + relation.unnest(unnest_relation); + return self.select_to_sql_recursively( + p.input.as_ref(), + query, + select, + relation, + ); + } } } @@ -854,25 +863,34 @@ impl Unparser<'_> { None } - fn unnest_to_table_factor_sql( + fn try_unnest_to_table_factor_sql( &self, unnest: &Unnest, - query: &mut Option, - select: &mut SelectBuilder, - relation: &mut RelationBuilder, - ) -> Result<()> { + ) -> Result> { let mut unnest_relation = UnnestRelationBuilder::default(); - let LogicalPlan::Projection(p) = unnest.input.as_ref() else { - return internal_err!("Unnest input is not a Projection: {unnest:?}"); + let LogicalPlan::Projection(projection) = unnest.input.as_ref() else { + return Ok(None); }; - let exprs = p + + if !matches!(projection.input.as_ref(), LogicalPlan::EmptyRelation(_)) { + // It may be possible that UNNEST is used as a source for the query. + // However, at this point, we don't yet know if it is just a single expression + // from another source or if it's from UNNEST. + // + // Unnest(Projection(EmptyRelation)) denotes a case with `UNNEST([...])`, + // which is normally safe to unnest as a table factor. + // However, in the future, more comprehensive checks can be added here. + return Ok(None); + }; + + let exprs = projection .expr .iter() .map(|e| self.expr_to_sql(e)) .collect::>>()?; unnest_relation.array_exprs(exprs); - relation.unnest(unnest_relation); - self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) + + Ok(Some(unnest_relation)) } fn is_scan_with_pushdown(scan: &TableScan) -> bool { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 8599de864687b..25dbf51bf721c 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -633,6 +633,18 @@ fn roundtrip_statement_with_dialect() -> Result<()> { parser_dialect: Box::new(GenericDialect {}), unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), }, + TestStatementWithDialect { + sql: "SELECT unnest([1, 2, 3, 4]) from unnest([1, 2, 3]);", + expected: r#"SELECT UNNEST([1, 2, 3, 4]) AS UNNEST(make_array(Int64(1),Int64(2),Int64(3),Int64(4))) FROM UNNEST([1, 2, 3])"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, + TestStatementWithDialect { + sql: "SELECT unnest([1, 2, 3, 4]) from unnest([1, 2, 3]);", + expected: r#"SELECT UNNEST([1, 2, 3, 4]) AS UNNEST(make_array(Int64(1),Int64(2),Int64(3),Int64(4))) FROM UNNEST([1, 2, 3])"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(CustomDialectBuilder::default().with_unnest_as_table_factor(true).build()), + }, TestStatementWithDialect { sql: "SELECT * FROM unnest_table u, UNNEST(u.array_col)", expected: r#"SELECT u.array_col, u.struct_col, "UNNEST(outer_ref(u.array_col))" FROM unnest_table AS u CROSS JOIN LATERAL (SELECT UNNEST(u.array_col) AS "UNNEST(outer_ref(u.array_col))")"#,