Skip to content

Commit 5e55154

Browse files
peasee authored and phillipleblanc committed
fix: Ensure only tables or aliases that exist are projected (#52)
1 parent 09495f6 commit 5e55154

File tree

3 files changed

+137
-2
lines changed

3 files changed

+137
-2
lines changed

datafusion/sql/src/unparser/ast.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ impl SelectBuilder {
155155
self.top = value;
156156
self
157157
}
158+
/// Returns a cloned copy of the projection items currently held by this builder.
///
/// The builder itself is left untouched; callers receive an owned `Vec` they can
/// modify and hand back via `projection`.
pub fn get_projection(&self) -> Vec<ast::SelectItem> {
159+
self.projection.clone()
160+
}
158161
pub fn projection(&mut self, value: Vec<ast::SelectItem>) -> &mut Self {
159162
self.projection = value;
160163
self
@@ -307,7 +310,9 @@ impl TableWithJoinsBuilder {
307310
self.relation = Some(value);
308311
self
309312
}
310-
313+
/// Returns a cloned copy of the joins currently held by this builder.
///
/// The builder itself is left untouched; the returned `Vec` is owned by the caller.
pub fn get_joins(&self) -> Vec<ast::Join> {
314+
self.joins.clone()
315+
}
311316
pub fn joins(&mut self, value: Vec<ast::Join>) -> &mut Self {
312317
self.joins = value;
313318
self
@@ -360,6 +365,25 @@ impl RelationBuilder {
360365
pub fn has_relation(&self) -> bool {
361366
self.relation.is_some()
362367
}
368+
/// Returns the relation's table name rendered as a string.
///
/// Yields `Some` only when the relation is a `TableFactorBuilder::Table` whose
/// `name` is set; any other relation kind, or no relation at all, yields `None`.
pub fn get_name(&self) -> Option<String> {
369+
match self.relation {
370+
Some(TableFactorBuilder::Table(ref value)) => {
371+
value.name.as_ref().map(|a| a.to_string())
372+
}
373+
_ => None,
374+
}
375+
}
376+
/// Returns the relation's alias name rendered as a string.
///
/// Handles both `TableFactorBuilder::Table` and `TableFactorBuilder::Derived`
/// relations; yields `None` when the relation has no alias, is of another kind,
/// or is not set.
pub fn get_alias(&self) -> Option<String> {
377+
match self.relation {
378+
Some(TableFactorBuilder::Table(ref value)) => {
379+
value.alias.as_ref().map(|a| a.name.to_string())
380+
}
381+
Some(TableFactorBuilder::Derived(ref value)) => {
382+
value.alias.as_ref().map(|a| a.name.to_string())
383+
}
384+
_ => None,
385+
}
386+
}
363387
pub fn table(&mut self, value: TableRelationBuilder) -> &mut Self {
364388
self.relation = Some(TableFactorBuilder::Table(value));
365389
self

datafusion/sql/src/unparser/plan.rs

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ use datafusion_expr::{
4242
expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan,
4343
LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan,
4444
};
45-
use sqlparser::ast::{self, Ident, SetExpr};
45+
use sqlparser::ast::{self, display_separated, Ident, SetExpr};
4646
use std::sync::Arc;
4747

4848
/// Convert a DataFusion [`LogicalPlan`] to [`ast::Statement`]
@@ -159,10 +159,69 @@ impl Unparser<'_> {
159159
)]);
160160
}
161161

162+
// Construct a list of all the identifiers present in query sources
163+
let mut all_idents = Vec::new();
164+
if let Some(source_alias) = relation_builder.get_alias() {
165+
all_idents.push(source_alias);
166+
} else if let Some(source_name) = relation_builder.get_name() {
167+
all_idents.push(source_name);
168+
}
169+
162170
let mut twj = select_builder.pop_from().unwrap();
171+
twj.get_joins()
172+
.iter()
173+
.for_each(|join| match &join.relation {
174+
ast::TableFactor::Table { alias, name, .. } => {
175+
if let Some(alias) = alias {
176+
all_idents.push(alias.name.to_string());
177+
} else {
178+
all_idents.push(name.to_string());
179+
}
180+
}
181+
ast::TableFactor::Derived { alias, .. } => {
182+
if let Some(alias) = alias {
183+
all_idents.push(alias.name.to_string());
184+
}
185+
}
186+
_ => {}
187+
});
188+
163189
twj.relation(relation_builder);
164190
select_builder.push_from(twj);
165191

192+
// Ensure that the projection contains references to sources that actually exist
193+
let mut projection = select_builder.get_projection();
194+
projection
195+
.iter_mut()
196+
.for_each(|select_item| match select_item {
197+
ast::SelectItem::UnnamedExpr(ast::Expr::CompoundIdentifier(idents)) => {
198+
if idents.len() > 1 {
199+
let ident_source = display_separated(
200+
&idents
201+
.clone()
202+
.into_iter()
203+
.take(idents.len() - 1)
204+
.collect::<Vec<Ident>>(),
205+
".",
206+
)
207+
.to_string();
208+
// If the identifier is not present in the list of all identifiers, it refers to a table that does not exist
209+
if !all_idents.contains(&ident_source) {
210+
let Some(last) = idents.last() else {
211+
unreachable!(
212+
"CompoundIdentifier must have a last element"
213+
);
214+
};
215+
// Reset the identifiers to only the last element, which is the column name
216+
*idents = vec![last.clone()];
217+
}
218+
}
219+
}
220+
_ => {}
221+
});
222+
223+
select_builder.projection(projection);
224+
166225
Ok(SetExpr::Select(Box::new(select_builder.build()?)))
167226
}
168227

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,58 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
337337
parser_dialect: Box::new(GenericDialect {}),
338338
unparser_dialect: Box::new(UnparserDefaultDialect {}),
339339
},
340+
TestStatementWithDialect {
341+
sql: "select j1_id from (select ta.j1_id from j1 ta)",
342+
expected:
343+
// This seems like desirable behavior, but is actually hiding an underlying issue
344+
// The re-written identifier is `ta`.`j1_id`, because `reconstruct_select_statement` runs before the derived projection
345+
// and for some reason, the derived table alias is pre-set to `ta` for the top-level projection
346+
"SELECT `j1_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `derived_projection`",
347+
parser_dialect: Box::new(MySqlDialect {}),
348+
unparser_dialect: Box::new(UnparserMySqlDialect {}),
349+
},
350+
TestStatementWithDialect {
351+
sql: "select j1_id from (select ta.j1_id from j1 ta)",
352+
expected:
353+
"SELECT j1_id FROM (SELECT ta.j1_id FROM j1 AS ta)",
354+
parser_dialect: Box::new(GenericDialect {}),
355+
unparser_dialect: Box::new(UnparserDefaultDialect {}),
356+
},
357+
TestStatementWithDialect {
358+
sql: "select j1_id from (select ta.j1_id from j1 ta) AS tbl1",
359+
expected:
360+
"SELECT tbl1.j1_id FROM (SELECT ta.j1_id FROM j1 AS ta) AS tbl1",
361+
parser_dialect: Box::new(GenericDialect {}),
362+
unparser_dialect: Box::new(UnparserDefaultDialect {}),
363+
},
364+
TestStatementWithDialect {
365+
sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta) AS tbl1, (select ta.j1_id as j2_id from j1 ta) as tbl2",
366+
expected:
367+
"SELECT tbl1.j1_id, tbl2.j2_id FROM (SELECT ta.j1_id FROM j1 AS ta) AS tbl1 JOIN (SELECT ta.j1_id AS j2_id FROM j1 AS ta) AS tbl2 ON true",
368+
parser_dialect: Box::new(GenericDialect {}),
369+
unparser_dialect: Box::new(UnparserDefaultDialect {}),
370+
},
371+
TestStatementWithDialect {
372+
sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta) AS tbl1, (select ta.j1_id as j2_id from j1 ta) as tbl2",
373+
expected:
374+
"SELECT `tbl1`.`j1_id`, `tbl2`.`j2_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `tbl1` JOIN (SELECT `ta`.`j1_id` AS `j2_id` FROM `j1` AS `ta`) AS `tbl2` ON true",
375+
parser_dialect: Box::new(MySqlDialect {}),
376+
unparser_dialect: Box::new(UnparserMySqlDialect {}),
377+
},
378+
TestStatementWithDialect {
379+
sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta), (select ta.j1_id as j2_id from j1 ta)",
380+
expected:
381+
"SELECT `j1_id`, `j2_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `derived_projection` JOIN (SELECT `ta`.`j1_id` AS `j2_id` FROM `j1` AS `ta`) AS `derived_projection` ON true",
382+
parser_dialect: Box::new(MySqlDialect {}),
383+
unparser_dialect: Box::new(UnparserMySqlDialect {}),
384+
},
385+
TestStatementWithDialect {
386+
sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta), (select ta.j1_id AS j2_id from j1 ta)",
387+
expected:
388+
"SELECT j1_id, j2_id FROM (SELECT ta.j1_id FROM j1 AS ta) JOIN (SELECT ta.j1_id AS j2_id FROM j1 AS ta) ON true",
389+
parser_dialect: Box::new(GenericDialect {}),
390+
unparser_dialect: Box::new(UnparserDefaultDialect {}),
391+
},
340392
TestStatementWithDialect {
341393
sql: "SELECT j1_string from j1 join j2 on j1.j1_id = j2.j2_id order by j1_id",
342394
expected: r#"SELECT j1.j1_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id ASC NULLS LAST"#,

0 commit comments

Comments
 (0)