diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs
index 482c08fb444e..21396821c723 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -374,6 +374,24 @@ fn optimize_plan(
             )?;
             from_plan(plan, &plan.expressions(), &[child])
         }
+        // at a distinct, all columns are required
+        LogicalPlan::Distinct(distinct) => {
+            let new_required_columns = distinct
+                .input
+                .schema()
+                .fields()
+                .iter()
+                .map(|f| f.qualified_column())
+                .collect();
+            let child = optimize_plan(
+                _optimizer,
+                distinct.input.as_ref(),
+                &new_required_columns,
+                has_projection,
+                _config,
+            )?;
+            from_plan(plan, &[], &[child])
+        }
         // all other nodes: Add any additional columns used by
         // expressions in this node to the list of required columns
         LogicalPlan::Limit(_)
@@ -392,7 +410,6 @@ fn optimize_plan(
         | LogicalPlan::DropView(_)
         | LogicalPlan::SetVariable(_)
         | LogicalPlan::CrossJoin(_)
-        | LogicalPlan::Distinct(_)
         | LogicalPlan::Extension { .. }
         | LogicalPlan::Prepare(_) => {
             let expr = plan.expressions();
@@ -1009,6 +1026,25 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn pushdown_through_distinct() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), col("b")])?
+            .distinct()?
+            .project(vec![col("a")])?
+            .build()?;
+
+        let expected = "Projection: test.a\
+        \n  Distinct:\
+        \n    TableScan: test projection=[a, b]";
+
+        assert_optimized_plan_eq(&plan, expected);
+
+        Ok(())
+    }
+
     fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
         let optimized_plan = optimize(plan).expect("failed to optimize plan");
         let formatted_plan = format!("{optimized_plan:?}");