Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions datafusion/sql/src/unparser/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ impl SelectBuilder {
self.top = value;
self
}
/// Returns an owned copy of the projection items currently stored in this
/// builder. The builder itself is left untouched; the caller receives a
/// clone and may mutate it freely.
pub fn get_projection(&self) -> Vec<ast::SelectItem> {
self.projection.clone()
}
pub fn projection(&mut self, value: Vec<ast::SelectItem>) -> &mut Self {
self.projection = value;
self
Expand Down Expand Up @@ -193,9 +196,9 @@ impl SelectBuilder {
(None, Some(new_selection)) => {
self.selection = Some(new_selection);
}
(_, None) => ()
(_, None) => (),
}

self
}
pub fn group_by(&mut self, value: ast::GroupByExpr) -> &mut Self {
Expand Down Expand Up @@ -299,7 +302,9 @@ impl TableWithJoinsBuilder {
self.relation = Some(value);
self
}

/// Returns an owned copy of the joins currently stored in this builder.
/// The builder itself is left untouched; the caller receives a clone.
pub fn get_joins(&self) -> Vec<ast::Join> {
self.joins.clone()
}
pub fn joins(&mut self, value: Vec<ast::Join>) -> &mut Self {
self.joins = value;
self
Expand Down Expand Up @@ -352,6 +357,25 @@ impl RelationBuilder {
/// Returns `true` when a relation has been set on this builder
/// (i.e. `self.relation` is `Some`), and `false` otherwise.
pub fn has_relation(&self) -> bool {
self.relation.is_some()
}
/// Returns the table name, rendered via its `Display`/`ToString` form, when
/// the relation is a `TableFactorBuilder::Table` that has a name set.
///
/// Yields `None` when no relation is set, when the relation is any other
/// kind of table factor, or when the table has no name.
pub fn get_name(&self) -> Option<String> {
    // Guard clause: only plain table relations carry a name.
    let Some(TableFactorBuilder::Table(table)) = &self.relation else {
        return None;
    };
    table.name.as_ref().map(ToString::to_string)
}
/// Returns the alias name of the relation as a string, if one is set.
///
/// Only `TableFactorBuilder::Table` and `TableFactorBuilder::Derived`
/// relations are inspected; every other case (including no relation at all,
/// or a relation without an alias) yields `None`.
pub fn get_alias(&self) -> Option<String> {
    // First pick out the optional alias for the relation kinds handled here...
    let table_alias = match &self.relation {
        Some(TableFactorBuilder::Table(table)) => table.alias.as_ref(),
        Some(TableFactorBuilder::Derived(derived)) => derived.alias.as_ref(),
        _ => None,
    };
    // ...then render its identifier once, in a single place.
    table_alias.map(|alias| alias.name.to_string())
}
pub fn table(&mut self, value: TableRelationBuilder) -> &mut Self {
self.relation = Some(TableFactorBuilder::Table(value));
self
Expand Down
61 changes: 60 additions & 1 deletion datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use datafusion_expr::{
expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan,
LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan,
};
use sqlparser::ast::{self, Ident, SetExpr};
use sqlparser::ast::{self, display_separated, Ident, SetExpr};
use std::{sync::Arc, vec};

use super::{
Expand Down Expand Up @@ -162,10 +162,69 @@ impl Unparser<'_> {
)]);
}

// Construct a list of all the identifiers present in query sources
let mut all_idents = Vec::new();
if let Some(source_alias) = relation_builder.get_alias() {
all_idents.push(source_alias);
} else if let Some(source_name) = relation_builder.get_name() {
all_idents.push(source_name);
}

let mut twj = select_builder.pop_from().unwrap();
twj.get_joins()
.iter()
.for_each(|join| match &join.relation {
ast::TableFactor::Table { alias, name, .. } => {
if let Some(alias) = alias {
all_idents.push(alias.name.to_string());
} else {
all_idents.push(name.to_string());
}
}
ast::TableFactor::Derived { alias, .. } => {
if let Some(alias) = alias {
all_idents.push(alias.name.to_string());
}
}
_ => {}
});

twj.relation(relation_builder);
select_builder.push_from(twj);

// Ensure that the projection contains references to sources that actually exist
let mut projection = select_builder.get_projection();
projection
.iter_mut()
.for_each(|select_item| match select_item {
ast::SelectItem::UnnamedExpr(ast::Expr::CompoundIdentifier(idents)) => {
if idents.len() > 1 {
let ident_source = display_separated(
&idents
.clone()
.into_iter()
.take(idents.len() - 1)
.collect::<Vec<Ident>>(),
".",
)
.to_string();
// If the identifier is not present in the list of all identifiers, it refers to a table that does not exist
if !all_idents.contains(&ident_source) {
let Some(last) = idents.last() else {
unreachable!(
"CompoundIdentifier must have a last element"
);
};
// Reset the identifiers to only the last element, which is the column name
*idents = vec![last.clone()];
}
}
}
_ => {}
});

select_builder.projection(projection);

Ok(SetExpr::Select(Box::new(select_builder.build()?)))
}

Expand Down
52 changes: 52 additions & 0 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,58 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "select j1_id from (select ta.j1_id from j1 ta)",
expected:
// This seems like desirable behavior, but is actually hiding an underlying issue
// The re-written identifier is `ta`.`j1_id`, because `reconstruct_select_statement` runs before the derived projection
// and for some reason, the derived table alias is pre-set to `ta` for the top-level projection
"SELECT `j1_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `derived_projection`",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select j1_id from (select ta.j1_id from j1 ta)",
expected:
"SELECT j1_id FROM (SELECT ta.j1_id FROM j1 AS ta)",
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "select j1_id from (select ta.j1_id from j1 ta) AS tbl1",
expected:
"SELECT tbl1.j1_id FROM (SELECT ta.j1_id FROM j1 AS ta) AS tbl1",
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta) AS tbl1, (select ta.j1_id as j2_id from j1 ta) as tbl2",
expected:
"SELECT tbl1.j1_id, tbl2.j2_id FROM (SELECT ta.j1_id FROM j1 AS ta) AS tbl1 JOIN (SELECT ta.j1_id AS j2_id FROM j1 AS ta) AS tbl2 ON true",
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta) AS tbl1, (select ta.j1_id as j2_id from j1 ta) as tbl2",
expected:
"SELECT `tbl1`.`j1_id`, `tbl2`.`j2_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `tbl1` JOIN (SELECT `ta`.`j1_id` AS `j2_id` FROM `j1` AS `ta`) AS `tbl2` ON true",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta), (select ta.j1_id as j2_id from j1 ta)",
expected:
"SELECT `j1_id`, `j2_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `derived_projection` JOIN (SELECT `ta`.`j1_id` AS `j2_id` FROM `j1` AS `ta`) AS `derived_projection` ON true",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta), (select ta.j1_id AS j2_id from j1 ta)",
expected:
"SELECT j1_id, j2_id FROM (SELECT ta.j1_id FROM j1 AS ta) JOIN (SELECT ta.j1_id AS j2_id FROM j1 AS ta) ON true",
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "SELECT j1_string from j1 join j2 on j1.j1_id = j2.j2_id order by j1_id",
expected: r#"SELECT j1.j1_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id ASC NULLS LAST"#,
Expand Down