Skip to content

Commit e151863

Browse files
peaseephillipleblanc
authored andcommitted
fix: Ensure only tables or aliases that exist are projected (#52)
fix: More dangling references (#54) * fix: More dangling references * test: Add tests for remove_dangling_identifiers UPSTREAM NOTE: This PR was attempted to be upstreamed in apache#13405 - but it was not accepted due to the complexity it brought. Phillip needs to figure out what a good solution that solves our problem and can be upstreamed is.
1 parent 26058ac commit e151863

File tree

4 files changed

+231
-3
lines changed

4 files changed

+231
-3
lines changed

datafusion/sql/src/unparser/ast.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ impl QueryBuilder {
4646
pub fn take_body(&mut self) -> Option<Box<ast::SetExpr>> {
4747
self.body.take()
4848
}
49+
pub fn get_order_by(&self) -> Vec<ast::OrderByExpr> {
50+
self.order_by.clone()
51+
}
4952
pub fn order_by(&mut self, value: Vec<ast::OrderByExpr>) -> &mut Self {
5053
self.order_by = value;
5154
self
@@ -150,6 +153,9 @@ impl SelectBuilder {
150153
self.top = value;
151154
self
152155
}
156+
pub fn get_projection(&self) -> Vec<ast::SelectItem> {
157+
self.projection.clone()
158+
}
153159
pub fn projection(&mut self, value: Vec<ast::SelectItem>) -> &mut Self {
154160
self.projection = value;
155161
self
@@ -217,6 +223,9 @@ impl SelectBuilder {
217223
self.sort_by = value;
218224
self
219225
}
226+
pub fn get_sort_by(&self) -> Vec<ast::Expr> {
227+
self.sort_by.clone()
228+
}
220229
pub fn having(&mut self, value: Option<ast::Expr>) -> &mut Self {
221230
self.having = value;
222231
self
@@ -304,7 +313,9 @@ impl TableWithJoinsBuilder {
304313
self.relation = Some(value);
305314
self
306315
}
307-
316+
pub fn get_joins(&self) -> Vec<ast::Join> {
317+
self.joins.clone()
318+
}
308319
pub fn joins(&mut self, value: Vec<ast::Join>) -> &mut Self {
309320
self.joins = value;
310321
self
@@ -358,6 +369,25 @@ impl RelationBuilder {
358369
pub fn has_relation(&self) -> bool {
359370
self.relation.is_some()
360371
}
372+
pub fn get_name(&self) -> Option<String> {
373+
match self.relation {
374+
Some(TableFactorBuilder::Table(ref value)) => {
375+
value.name.as_ref().map(|a| a.to_string())
376+
}
377+
_ => None,
378+
}
379+
}
380+
pub fn get_alias(&self) -> Option<String> {
381+
match self.relation {
382+
Some(TableFactorBuilder::Table(ref value)) => {
383+
value.alias.as_ref().map(|a| a.name.to_string())
384+
}
385+
Some(TableFactorBuilder::Derived(ref value)) => {
386+
value.alias.as_ref().map(|a| a.name.to_string())
387+
}
388+
_ => None,
389+
}
390+
}
361391
pub fn table(&mut self, value: TableRelationBuilder) -> &mut Self {
362392
self.relation = Some(TableFactorBuilder::Table(value));
363393
self

datafusion/sql/src/unparser/plan.rs

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use super::{
2222
},
2323
rewrite::{
2424
inject_column_aliases_into_subquery, normalize_union_schema,
25-
rewrite_plan_for_sort_on_non_projected_fields,
25+
remove_dangling_identifiers, rewrite_plan_for_sort_on_non_projected_fields,
2626
subquery_alias_inner_query_and_columns, TableAliasRewriter,
2727
},
2828
utils::{
@@ -209,10 +209,69 @@ impl Unparser<'_> {
209209
)]);
210210
}
211211

212+
// Construct a list of all the identifiers present in query sources
213+
let mut all_idents = Vec::new();
214+
if let Some(source_alias) = relation_builder.get_alias() {
215+
all_idents.push(source_alias);
216+
} else if let Some(source_name) = relation_builder.get_name() {
217+
all_idents.push(source_name);
218+
}
219+
212220
let mut twj = select_builder.pop_from().unwrap();
221+
twj.get_joins()
222+
.iter()
223+
.for_each(|join| match &join.relation {
224+
ast::TableFactor::Table { alias, name, .. } => {
225+
if let Some(alias) = alias {
226+
all_idents.push(alias.name.to_string());
227+
} else {
228+
all_idents.push(name.to_string());
229+
}
230+
}
231+
ast::TableFactor::Derived { alias, .. } => {
232+
if let Some(alias) = alias {
233+
all_idents.push(alias.name.to_string());
234+
}
235+
}
236+
_ => {}
237+
});
238+
213239
twj.relation(relation_builder);
214240
select_builder.push_from(twj);
215241

242+
// Ensure that the projection contains references to sources that actually exist
243+
let mut projection = select_builder.get_projection();
244+
projection
245+
.iter_mut()
246+
.for_each(|select_item| match select_item {
247+
ast::SelectItem::UnnamedExpr(ast::Expr::CompoundIdentifier(idents)) => {
248+
remove_dangling_identifiers(idents, &all_idents);
249+
}
250+
_ => {}
251+
});
252+
253+
// Check the order by as well
254+
if let Some(query) = query.as_mut() {
255+
let mut order_by = query.get_order_by();
256+
order_by.iter_mut().for_each(|sort_item| {
257+
if let ast::Expr::CompoundIdentifier(idents) = &mut sort_item.expr {
258+
remove_dangling_identifiers(idents, &all_idents);
259+
}
260+
});
261+
262+
query.order_by(order_by);
263+
}
264+
265+
// Order by could be a sort in the select builder
266+
let mut sort = select_builder.get_sort_by();
267+
sort.iter_mut().for_each(|sort_item| {
268+
if let ast::Expr::CompoundIdentifier(idents) = sort_item {
269+
remove_dangling_identifiers(idents, &all_idents);
270+
}
271+
});
272+
273+
select_builder.projection(projection);
274+
216275
Ok(SetExpr::Select(Box::new(select_builder.build()?)))
217276
}
218277

datafusion/sql/src/unparser/rewrite.rs

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use datafusion_common::{
2525
};
2626
use datafusion_expr::expr::{Alias, UNNEST_COLUMN_PREFIX};
2727
use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr};
28-
use sqlparser::ast::Ident;
28+
use sqlparser::ast::{display_separated, Ident};
2929

3030
/// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions.
3131
///
@@ -417,3 +417,68 @@ impl TreeNodeRewriter for TableAliasRewriter<'_> {
417417
}
418418
}
419419
}
420+
421+
/// Takes an input list of identifiers and a list of identifiers that are available from relations or joins.
422+
/// Removes any table identifiers that are not present in the list of available identifiers, retains original column names.
423+
pub fn remove_dangling_identifiers(
424+
idents: &mut Vec<Ident>,
425+
available_idents: &Vec<String>,
426+
) -> () {
427+
if idents.len() > 1 {
428+
let ident_source = display_separated(
429+
&idents
430+
.clone()
431+
.into_iter()
432+
.take(idents.len() - 1)
433+
.collect::<Vec<Ident>>(),
434+
".",
435+
)
436+
.to_string();
437+
// If the identifier is not present in the list of all identifiers, it refers to a table that does not exist
438+
if !available_idents.contains(&ident_source) {
439+
let Some(last) = idents.last() else {
440+
unreachable!("CompoundIdentifier must have a last element");
441+
};
442+
// Reset the identifiers to only the last element, which is the column name
443+
*idents = vec![last.clone()];
444+
}
445+
}
446+
}
447+
448+
#[cfg(test)]
449+
mod test {
450+
use super::*;
451+
452+
#[test]
453+
fn test_remove_dangling_identifiers() {
454+
let tests = vec![
455+
(vec![], vec![Ident::new("column1".to_string())]),
456+
(
457+
vec!["table1.table2".to_string()],
458+
vec![
459+
Ident::new("table1".to_string()),
460+
Ident::new("table2".to_string()),
461+
Ident::new("column1".to_string()),
462+
],
463+
),
464+
(
465+
vec!["table1".to_string()],
466+
vec![Ident::new("column1".to_string())],
467+
),
468+
];
469+
470+
for test in tests {
471+
let test_in = test.0;
472+
let test_out = test.1;
473+
474+
let mut idents = vec![
475+
Ident::new("table1".to_string()),
476+
Ident::new("table2".to_string()),
477+
Ident::new("column1".to_string()),
478+
];
479+
480+
remove_dangling_identifiers(&mut idents, &test_in);
481+
assert_eq!(idents, test_out);
482+
}
483+
}
484+
}

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,80 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
366366
parser_dialect: Box::new(GenericDialect {}),
367367
unparser_dialect: Box::new(UnparserDefaultDialect {}),
368368
},
369+
TestStatementWithDialect {
370+
sql: "select j1_id from (select ta.j1_id from j1 ta)",
371+
expected:
372+
// This seems like desirable behavior, but is actually hiding an underlying issue
373+
// The re-written identifier is `ta`.`j1_id`, because `reconstuct_select_statement` runs before the derived projection
374+
// and for some reason, the derived table alias is pre-set to `ta` for the top-level projection
375+
"SELECT `j1_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `derived_projection`",
376+
parser_dialect: Box::new(MySqlDialect {}),
377+
unparser_dialect: Box::new(UnparserMySqlDialect {}),
378+
},
379+
TestStatementWithDialect {
380+
sql: "select j1_id from (select ta.j1_id from j1 ta)",
381+
expected:
382+
"SELECT j1_id FROM (SELECT ta.j1_id FROM j1 AS ta)",
383+
parser_dialect: Box::new(GenericDialect {}),
384+
unparser_dialect: Box::new(UnparserDefaultDialect {}),
385+
},
386+
TestStatementWithDialect {
387+
sql: "select j1_id from (select ta.j1_id from j1 ta) order by j1_id",
388+
expected:
389+
"SELECT j1_id FROM (SELECT ta.j1_id FROM j1 AS ta) ORDER BY j1_id ASC NULLS LAST",
390+
parser_dialect: Box::new(GenericDialect {}),
391+
unparser_dialect: Box::new(UnparserDefaultDialect {}),
392+
},
393+
// TODO: remove dangling identifiers from group by, filter, etc
394+
// TestStatementWithDialect {
395+
// sql: "select j1_id from (select ta.j1_id from j1 ta) where j1_id = 1",
396+
// expected:
397+
// "SELECT j1_id FROM (SELECT ta.j1_id FROM j1 AS ta) WHERE (ta.j1_id = 1)",
398+
// parser_dialect: Box::new(GenericDialect {}),
399+
// unparser_dialect: Box::new(UnparserDefaultDialect {}),
400+
// },
401+
TestStatementWithDialect {
402+
sql: "select j1_id from (select ta.j1_id from j1 ta) order by j1_id",
403+
expected:
404+
"SELECT `j1_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `derived_projection` ORDER BY `j1_id` ASC",
405+
parser_dialect: Box::new(MySqlDialect {}),
406+
unparser_dialect: Box::new(UnparserMySqlDialect {}),
407+
},
408+
TestStatementWithDialect {
409+
sql: "select j1_id from (select ta.j1_id from j1 ta) AS tbl1",
410+
expected:
411+
"SELECT tbl1.j1_id FROM (SELECT ta.j1_id FROM j1 AS ta) AS tbl1",
412+
parser_dialect: Box::new(GenericDialect {}),
413+
unparser_dialect: Box::new(UnparserDefaultDialect {}),
414+
},
415+
TestStatementWithDialect {
416+
sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta) AS tbl1, (select ta.j1_id as j2_id from j1 ta) as tbl2",
417+
expected:
418+
"SELECT tbl1.j1_id, tbl2.j2_id FROM (SELECT ta.j1_id FROM j1 AS ta) AS tbl1 JOIN (SELECT ta.j1_id AS j2_id FROM j1 AS ta) AS tbl2 ON true",
419+
parser_dialect: Box::new(GenericDialect {}),
420+
unparser_dialect: Box::new(UnparserDefaultDialect {}),
421+
},
422+
TestStatementWithDialect {
423+
sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta) AS tbl1, (select ta.j1_id as j2_id from j1 ta) as tbl2",
424+
expected:
425+
"SELECT `tbl1`.`j1_id`, `tbl2`.`j2_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `tbl1` JOIN (SELECT `ta`.`j1_id` AS `j2_id` FROM `j1` AS `ta`) AS `tbl2` ON true",
426+
parser_dialect: Box::new(MySqlDialect {}),
427+
unparser_dialect: Box::new(UnparserMySqlDialect {}),
428+
},
429+
TestStatementWithDialect {
430+
sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta), (select ta.j1_id as j2_id from j1 ta)",
431+
expected:
432+
"SELECT `j1_id`, `j2_id` FROM (SELECT `ta`.`j1_id` FROM `j1` AS `ta`) AS `derived_projection` JOIN (SELECT `ta`.`j1_id` AS `j2_id` FROM `j1` AS `ta`) AS `derived_projection` ON true",
433+
parser_dialect: Box::new(MySqlDialect {}),
434+
unparser_dialect: Box::new(UnparserMySqlDialect {}),
435+
},
436+
TestStatementWithDialect {
437+
sql: "select j1_id, j2_id from (select ta.j1_id from j1 ta), (select ta.j1_id AS j2_id from j1 ta)",
438+
expected:
439+
"SELECT j1_id, j2_id FROM (SELECT ta.j1_id FROM j1 AS ta) JOIN (SELECT ta.j1_id AS j2_id FROM j1 AS ta) ON true",
440+
parser_dialect: Box::new(GenericDialect {}),
441+
unparser_dialect: Box::new(UnparserDefaultDialect {}),
442+
},
369443
TestStatementWithDialect {
370444
sql: "SELECT j1_string from j1 join j2 on j1.j1_id = j2.j2_id order by j1_id",
371445
expected: r#"SELECT j1.j1_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id ASC NULLS LAST"#,

0 commit comments

Comments
 (0)