Skip to content

Commit ef3d3dd

Browse files
committed
Handle table scan filters that reference dropped columns (datafusion-contrib#59)
* use full table schema when analyzing predicates
* uncomment tests
* comment
* fmt
1 parent f5ec880 commit ef3d3dd

File tree

1 file changed

+69
-17
lines changed

1 file changed

+69
-17
lines changed

src/rewrite/normal_form.rs

Lines changed: 69 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -334,6 +334,7 @@ impl SpjNormalForm {
334334
/// Stores information on filters from a Select-Project-Join plan.
335335
#[derive(Debug, Clone)]
336336
struct Predicate {
337+
/// Full table schema, including all possible columns.
337338
schema: DFSchema,
338339
/// List of column equivalence classes.
339340
eq_classes: Vec<ColumnEquivalenceClass>,
@@ -350,10 +351,14 @@ impl Predicate {
350351
let mut schema = DFSchema::empty();
351352
plan.apply(|plan| {
352353
if let LogicalPlan::TableScan(scan) = plan {
354+
let new_schema = DFSchema::try_from_qualified_schema(
355+
scan.table_name.clone(),
356+
scan.source.schema().as_ref(),
357+
)?;
353358
schema = if schema.fields().is_empty() {
354-
(*scan.projected_schema).clone()
359+
new_schema
355360
} else {
356-
schema.join(&scan.projected_schema)?
361+
schema.join(&new_schema)?
357362
}
358363
}
359364

@@ -371,7 +376,13 @@ impl Predicate {
371376
// Collect all referenced columns
372377
plan.apply(|plan| {
373378
if let LogicalPlan::TableScan(scan) = plan {
374-
for (i, (table_ref, field)) in scan.projected_schema.iter().enumerate() {
379+
for (i, (table_ref, field)) in DFSchema::try_from_qualified_schema(
380+
scan.table_name.clone(),
381+
scan.source.schema().as_ref(),
382+
)?
383+
.iter()
384+
.enumerate()
385+
{
375386
let column = Column::new(table_ref.cloned(), field.name());
376387
let data_type = field.data_type();
377388
new.eq_classes
@@ -948,17 +959,47 @@ fn get_table_scan_columns(scan: &TableScan) -> Result<Vec<Column>> {
948959
#[cfg(test)]
949960
mod test {
950961
use arrow::compute::concat_batches;
951-
use datafusion::{datasource::provider_as_source, prelude::SessionContext};
962+
use datafusion::{
963+
datasource::provider_as_source,
964+
prelude::{SessionConfig, SessionContext},
965+
};
952966
use datafusion_common::{DataFusionError, Result};
953967
use datafusion_sql::TableReference;
968+
use tempfile::tempdir;
954969

955970
use super::SpjNormalForm;
956971

957972
async fn setup() -> Result<SessionContext> {
958-
let ctx = SessionContext::new();
973+
let ctx = SessionContext::new_with_config(
974+
SessionConfig::new()
975+
.set_bool("datafusion.execution.parquet.pushdown_filters", true)
976+
.set_bool("datafusion.explain.logical_plan_only", true),
977+
);
978+
979+
let t1_path = tempdir()?;
980+
981+
// Create external table to exercise parquet filter pushdown.
982+
// This will put the filters directly inside the `TableScan` node.
983+
// This is important because `TableScan` can have filters on
984+
// columns not in its own output.
985+
ctx.sql(&format!(
986+
"
987+
CREATE EXTERNAL TABLE t1 (
988+
column1 VARCHAR,
989+
column2 BIGINT,
990+
column3 CHAR
991+
)
992+
STORED AS PARQUET
993+
LOCATION '{}'",
994+
t1_path.path().to_string_lossy()
995+
))
996+
.await
997+
.map_err(|e| e.context("setup `t1` table"))?
998+
.collect()
999+
.await?;
9591000

9601001
ctx.sql(
961-
"CREATE TABLE t1 AS VALUES
1002+
"INSERT INTO t1 VALUES
9621003
('2021', 3, 'A'),
9631004
('2022', 4, 'B'),
9641005
('2023', 5, 'C')",
@@ -980,8 +1021,7 @@ mod test {
9801021
o_orderdate DATE,
9811022
p_name VARCHAR,
9821023
p_partkey INT
983-
)
984-
",
1024+
)",
9851025
)
9861026
.await
9871027
.map_err(|e| e.context("parse `example` table ddl"))?
@@ -1014,6 +1054,15 @@ mod test {
10141054
let query_plan = context.sql(case.query).await?.into_optimized_plan()?;
10151055
let query_normal_form = SpjNormalForm::new(&query_plan)?;
10161056

1057+
for plan in [&base_plan, &query_plan] {
1058+
context
1059+
.execute_logical_plan(plan.clone())
1060+
.await?
1061+
.explain(false, false)?
1062+
.show()
1063+
.await?;
1064+
}
1065+
10171066
let table_ref = TableReference::bare("mv");
10181067
let rewritten = query_normal_form
10191068
.rewrite_from(
@@ -1025,16 +1074,14 @@ mod test {
10251074
"expected rewrite to succeed".to_string(),
10261075
))?;
10271076

1028-
assert_eq!(rewritten.schema().as_ref(), query_plan.schema().as_ref());
1077+
context
1078+
.execute_logical_plan(rewritten.clone())
1079+
.await?
1080+
.explain(false, false)?
1081+
.show()
1082+
.await?;
10291083

1030-
for plan in [&base_plan, &query_plan, &rewritten] {
1031-
context
1032-
.execute_logical_plan(plan.clone())
1033-
.await?
1034-
.explain(false, false)?
1035-
.show()
1036-
.await?;
1037-
}
1084+
assert_eq!(rewritten.schema().as_ref(), query_plan.schema().as_ref());
10381085

10391086
let expected = concat_batches(
10401087
&query_plan.schema().as_ref().clone().into(),
@@ -1133,6 +1180,11 @@ mod test {
11331180
l_quantity*l_extendedprice > 100
11341181
",
11351182
},
1183+
TestCase {
1184+
name: "naked table scan with pushed down filters",
1185+
base: "SELECT column1 FROM t1 WHERE column2 <= 3",
1186+
query: "SELECT FROM t1 WHERE column2 <= 3",
1187+
},
11361188
];
11371189

11381190
for case in cases {

0 commit comments

Comments
 (0)