diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs
index a120ab427e1d..2e897c9bc5b2 100644
--- a/datafusion/physical-expr/src/projection.rs
+++ b/datafusion/physical-expr/src/projection.rs
@@ -22,12 +22,14 @@ use crate::expressions::Column;
 use crate::utils::collect_columns;
 use crate::PhysicalExpr;
 
+use arrow::array::{RecordBatch, RecordBatchOptions};
 use arrow::datatypes::{Field, Schema, SchemaRef};
 use datafusion_common::stats::{ColumnStatistics, Precision};
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result};
 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
+use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
 use indexmap::IndexMap;
 use itertools::Itertools;
 
@@ -47,6 +49,15 @@ pub struct ProjectionExpr {
     pub alias: String,
 }
 
+impl PartialEq for ProjectionExpr {
+    fn eq(&self, other: &Self) -> bool {
+        let ProjectionExpr { expr, alias } = self;
+        expr.eq(&other.expr) && *alias == other.alias
+    }
+}
+
+impl Eq for ProjectionExpr {}
+
 impl std::fmt::Display for ProjectionExpr {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         if self.expr.to_string() == self.alias {
@@ -99,7 +110,7 @@ impl From<ProjectionExpr> for (Arc<dyn PhysicalExpr>, String) {
 /// This struct encapsulates multiple `ProjectionExpr` instances,
 /// representing a complete projection operation and provides
 /// methods to manipulate and analyze the projection as a whole.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq)]
 pub struct ProjectionExprs {
     exprs: Vec<ProjectionExpr>,
 }
@@ -192,7 +203,7 @@ impl ProjectionExprs {
     /// assert_eq!(projection_with_dups.as_ref()[1].alias, "a"); // duplicate
     /// assert_eq!(projection_with_dups.as_ref()[2].alias, "b");
     /// ```
-    pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Self {
+    pub fn from_indices(indices: &[usize], schema: &Schema) -> Self {
         let projection_exprs = indices.iter().map(|&i| {
             let field = schema.field(i);
             ProjectionExpr {
@@ -396,6 +407,22 @@ impl ProjectionExprs {
         ))
     }
 
+    /// Create a new [`Projector`] from this projection and an input schema.
+    ///
+    /// A [`Projector`] can be used to apply this projection to record batches.
+    ///
+    /// # Errors
+    /// This function returns an error if the output schema cannot be constructed from the input schema
+    /// with the given projection expressions.
+    /// For example, if an expression only works with integer columns but the input schema has a string column at that index.
+    pub fn make_projector(&self, input_schema: &Schema) -> Result<Projector> {
+        let output_schema = Arc::new(self.project_schema(input_schema)?);
+        Ok(Projector {
+            projection: self.clone(),
+            output_schema,
+        })
+    }
+
     /// Project statistics according to this projection.
     /// For example, for a projection `SELECT a AS x, b + 1 AS y`, where `a` is at index 0 and `b` is at index 1,
     /// if the input statistics has column statistics for columns `a`, `b`, and `c`, the output statistics would have column statistics for columns `x` and `y`.
@@ -444,6 +471,57 @@ impl<'a> IntoIterator for &'a ProjectionExprs {
     }
 }
 
+/// Applies a projection to record batches.
+///
+/// A [`Projector`] uses a set of projection expressions and a pre-computed
+/// output schema to project record batches accordingly.
+///
+/// The main reason to use a `Projector` is to avoid repeatedly computing
+/// the output schema for each batch, which can be costly if the projection
+/// expressions are complex.
+#[derive(Clone, Debug)]
+pub struct Projector {
+    projection: ProjectionExprs,
+    output_schema: SchemaRef,
+}
+
+impl Projector {
+    /// Project a record batch according to this projector's expressions.
+    ///
+    /// # Errors
+    /// This function returns an error if any expression evaluation fails
+    /// or if the output schema of the resulting record batch does not match
+    /// the pre-computed output schema of the projector.
+    pub fn project_batch(&self, batch: &RecordBatch) -> Result<RecordBatch> {
+        let arrays = evaluate_expressions_to_arrays(
+            self.projection.exprs.iter().map(|p| &p.expr),
+            batch,
+        )?;
+
+        if arrays.is_empty() {
+            let options =
+                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
+            RecordBatch::try_new_with_options(
+                Arc::clone(&self.output_schema),
+                arrays,
+                &options,
+            )
+            .map_err(Into::into)
+        } else {
+            RecordBatch::try_new(Arc::clone(&self.output_schema), arrays)
+                .map_err(Into::into)
+        }
+    }
+
+    pub fn output_schema(&self) -> &SchemaRef {
+        &self.output_schema
+    }
+
+    pub fn projection(&self) -> &ProjectionExprs {
+        &self.projection
+    }
+}
+
 impl IntoIterator for ProjectionExprs {
     type Item = ProjectionExpr;
     type IntoIter = std::vec::IntoIter<ProjectionExpr>;
@@ -545,7 +623,13 @@ pub fn update_expr(
         })
         .data()?;
 
-    Ok((state == RewriteState::RewrittenValid).then_some(new_expr))
+    match state {
+        RewriteState::RewrittenInvalid => Ok(None),
+        // Both Unchanged and RewrittenValid are valid:
+        // - Unchanged means no columns to rewrite (e.g., literals)
+        // - RewrittenValid means columns were successfully rewritten
+        RewriteState::Unchanged | RewriteState::RewrittenValid => Ok(Some(new_expr)),
+    }
 }
 
 /// Stores target expressions, along with their indices, that associate with a
@@ -2009,6 +2093,94 @@ pub(crate) mod tests {
         );
     }
 
+    #[test]
+    fn test_merge_empty_projection_with_literal() -> Result<()> {
+        // This test reproduces the issue from roundtrip_empty_projection test
+        // Query like: SELECT 1 FROM table
+        // where the file scan needs no columns (empty projection)
+        // but we project a literal on top
+
+        // Empty base projection (no columns needed from file)
+        let base_projection = ProjectionExprs::new(vec![]);
+
+        // Top projection with a literal expression: SELECT 1
+        let top_projection = ProjectionExprs::new(vec![ProjectionExpr {
+            expr: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
+            alias: "Int64(1)".to_string(),
+        }]);
+
+        // This should succeed - literals don't reference columns so they should
+        // pass through unchanged when merged with an empty projection
+        let merged = base_projection.try_merge(&top_projection)?;
+        assert_snapshot!(format!("{merged}"), @"Projection[1 AS Int64(1)]");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_update_expr_with_literal() -> Result<()> {
+        // Test that update_expr correctly handles expressions without column references
+        let literal_expr: Arc<dyn PhysicalExpr> =
+            Arc::new(Literal::new(ScalarValue::Int64(Some(42))));
+        let empty_projection: Vec<ProjectionExpr> = vec![];
+
+        // Updating a literal with an empty projection should return the literal unchanged
+        let result = update_expr(&literal_expr, &empty_projection, true)?;
+        assert!(result.is_some(), "Literal expression should be valid");
+
+        let result_expr = result.unwrap();
+        assert_eq!(
+            result_expr
+                .as_any()
+                .downcast_ref::<Literal>()
+                .unwrap()
+                .value(),
+            &ScalarValue::Int64(Some(42))
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_update_expr_with_complex_literal_expr() -> Result<()> {
+        // Test update_expr with an expression containing both literals and a column
+        // This tests the case where we have: literal + column
+        let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            Arc::new(Literal::new(ScalarValue::Int64(Some(10)))),
+            Operator::Plus,
+            Arc::new(Column::new("x", 0)),
+        ));
+
+        // Base projection that maps column 0 to a different expression
+        let base_projection = vec![ProjectionExpr {
+            expr: Arc::new(Column::new("a", 5)),
+            alias: "x".to_string(),
+        }];
+
+        // The expression should be updated: 10 + x@0 becomes 10 + a@5
+        let result = update_expr(&expr, &base_projection, true)?;
+        assert!(result.is_some(), "Expression should be valid");
+
+        let result_expr = result.unwrap();
+        let binary = result_expr
+            .as_any()
+            .downcast_ref::<BinaryExpr>()
+            .expect("Should be a BinaryExpr");
+
+        // Left side should still be the literal
+        assert!(binary.left().as_any().downcast_ref::<Literal>().is_some());
+
+        // Right side should be updated to reference column at index 5
+        let right_col = binary
+            .right()
+            .as_any()
+            .downcast_ref::<Column>()
+            .expect("Right should be a Column");
+        assert_eq!(right_col.index(), 5);
+
+        Ok(())
+    }
+
     #[test]
     fn test_project_schema_simple_columns() -> Result<()> {
         // Input schema: [col0: Int64, col1: Utf8, col2: Float32]
diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs
index 8bc2bcd6f2e9..abd8daa3fd7e 100644
--- a/datafusion/physical-plan/src/projection.rs
+++ b/datafusion/physical-plan/src/projection.rs
@@ -40,7 +40,7 @@ use std::sync::Arc;
 use std::task::{Context, Poll};
 
 use arrow::datatypes::SchemaRef;
-use arrow::record_batch::{RecordBatch, RecordBatchOptions};
+use arrow::record_batch::RecordBatch;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{
     Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
@@ -48,6 +48,7 @@ use datafusion_common::tree_node::{
 use datafusion_common::{internal_err, JoinSide, Result};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::equivalence::ProjectionMapping;
+use datafusion_physical_expr::projection::Projector;
 use datafusion_physical_expr::utils::collect_columns;
 use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef};
 use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement};
@@ -57,7 +58,6 @@ pub use datafusion_physical_expr::projection::{
     update_expr, ProjectionExpr, ProjectionExprs,
 };
 
-use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
 use futures::stream::{Stream, StreamExt};
 use log::trace;
 
@@ -67,10 +67,9 @@ use log::trace;
 /// output row for each input row.
 #[derive(Debug, Clone)]
 pub struct ProjectionExec {
-    /// The projection expressions stored as tuples of (expression, output column name)
-    projection: ProjectionExprs,
-    /// The schema once the projection has been applied to the input
-    schema: SchemaRef,
+    /// A projector specialized to apply the projection to the input schema from the child node
+    /// and produce [`RecordBatch`]es with the output schema of this node.
+    projector: Projector,
     /// The input plan
     input: Arc<dyn ExecutionPlan>,
     /// Execution metrics
@@ -138,16 +137,17 @@ impl ProjectionExec {
         // convert argument to Vec<ProjectionExpr>
         let expr_vec = expr.into_iter().map(Into::into).collect::<Vec<_>>();
         let projection = ProjectionExprs::new(expr_vec);
-
-        let schema = Arc::new(projection.project_schema(&input_schema)?);
+        let projector = projection.make_projector(&input_schema)?;
 
         // Construct a map from the input expressions to the output expression of the Projection
         let projection_mapping = projection.projection_mapping(&input_schema)?;
-        let cache =
-            Self::compute_properties(&input, &projection_mapping, Arc::clone(&schema))?;
+        let cache = Self::compute_properties(
+            &input,
+            &projection_mapping,
+            Arc::clone(projector.output_schema()),
+        )?;
         Ok(Self {
-            projection,
-            schema,
+            projector,
             input,
             metrics: ExecutionPlanMetricsSet::new(),
             cache,
@@ -156,7 +156,7 @@ impl ProjectionExec {
 
     /// The projection expressions stored as tuples of (expression, output column name)
     pub fn expr(&self) -> &[ProjectionExpr] {
-        self.projection.as_ref()
+        self.projector.projection().as_ref()
     }
 
     /// The input plan
@@ -196,7 +196,8 @@ impl DisplayAs for ProjectionExec {
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
                 let expr: Vec<String> = self
-                    .projection
+                    .projector
+                    .projection()
                     .as_ref()
                     .iter()
                     .map(|proj_expr| {
@@ -247,10 +248,15 @@ impl ExecutionPlan for ProjectionExec {
     }
 
     fn benefits_from_input_partitioning(&self) -> Vec<bool> {
-        let all_simple_exprs = self.projection.iter().all(|proj_expr| {
-            proj_expr.expr.as_any().is::<Column>()
-                || proj_expr.expr.as_any().is::<Literal>()
-        });
+        let all_simple_exprs =
+            self.projector
+                .projection()
+                .as_ref()
+                .iter()
+                .all(|proj_expr| {
+                    proj_expr.expr.as_any().is::<Column>()
+                        || proj_expr.expr.as_any().is::<Literal>()
+                });
         // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename,
         // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false.
         vec![!all_simple_exprs]
     }
@@ -264,8 +270,11 @@ impl ExecutionPlan for ProjectionExec {
     fn with_new_children(
         self: Arc<Self>,
         mut children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        ProjectionExec::try_new(self.projection.clone(), children.swap_remove(0))
-            .map(|p| Arc::new(p) as _)
+        ProjectionExec::try_new(
+            self.projector.projection().clone(),
+            children.swap_remove(0),
+        )
+        .map(|p| Arc::new(p) as _)
     }
 
@@ -275,11 +284,10 @@ impl ExecutionPlan for ProjectionExec {
     fn execute(
         &self,
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
         trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());
         Ok(Box::pin(ProjectionStream::new(
-            Arc::clone(&self.schema),
-            self.projection.expr_iter().collect(),
+            self.projector.clone(),
             self.input.execute(partition, context)?,
             BaselineMetrics::new(&self.metrics, partition),
-        )))
+        )?))
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
@@ -292,7 +300,8 @@ impl ExecutionPlan for ProjectionExec {
     fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
         let input_stats = self.input.partition_statistics(partition)?;
-        self.projection
+        self.projector
+            .projection()
             .project_statistics(input_stats, &self.input.schema())
     }
 
@@ -342,39 +351,27 @@ impl ExecutionPlan for ProjectionExec {
 impl ProjectionStream {
     /// Create a new projection stream
     fn new(
-        schema: SchemaRef,
-        expr: Vec<Arc<dyn PhysicalExpr>>,
+        projector: Projector,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
-    ) -> Self {
-        Self {
-            schema,
-            expr,
+    ) -> Result<Self> {
+        Ok(Self {
+            projector,
             input,
             baseline_metrics,
-        }
+        })
     }
 
     fn batch_project(&self, batch: &RecordBatch) -> Result<RecordBatch> {
         // Records time on drop
         let _timer = self.baseline_metrics.elapsed_compute().timer();
 
-        let arrays = evaluate_expressions_to_arrays(&self.expr, batch)?;
-
-        if arrays.is_empty() {
-            let options =
-                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
-            RecordBatch::try_new_with_options(Arc::clone(&self.schema), arrays, &options)
-                .map_err(Into::into)
-        } else {
-            RecordBatch::try_new(Arc::clone(&self.schema), arrays).map_err(Into::into)
-        }
+        self.projector.project_batch(batch)
     }
 }
 
 /// Projection iterator
 struct ProjectionStream {
-    schema: SchemaRef,
-    expr: Vec<Arc<dyn PhysicalExpr>>,
+    projector: Projector,
     input: SendableRecordBatchStream,
     baseline_metrics: BaselineMetrics,
 }
@@ -403,7 +400,7 @@ impl Stream for ProjectionStream {
 impl RecordBatchStream for ProjectionStream {
     /// Get the schema
     fn schema(&self) -> SchemaRef {
-        Arc::clone(&self.schema)
+        Arc::clone(self.projector.output_schema())
    }
 }
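
Usage sketch (not part of the diff): the snippet below illustrates how the `Projector` introduced above might be used outside of `ProjectionExec` — build a `ProjectionExprs`, derive a `Projector` once via `make_projector`, then reuse it for every batch via `project_batch`. Only `make_projector`, `project_batch`, and `output_schema` come from this change; the import paths, the `Int64Array` column, and the `project_example` wrapper are illustrative assumptions.

```rust
// Hypothetical usage sketch; paths and names outside this diff are assumptions.
use std::sync::Arc;

use arrow::array::{Int64Array, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::projection::{ProjectionExpr, ProjectionExprs};

fn project_example() -> Result<()> {
    // Input: a single Int64 column named "a".
    let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(Int64Array::from(vec![1_i64, 2, 3]))],
    )?;

    // SELECT a AS x
    let projection = ProjectionExprs::new(vec![ProjectionExpr {
        expr: Arc::new(Column::new("a", 0)),
        alias: "x".to_string(),
    }]);

    // The output schema is computed once here ...
    let projector = projection.make_projector(&schema)?;
    // ... and then reused for every batch that is projected.
    let projected = projector.project_batch(&batch)?;
    assert_eq!(projected.schema(), Arc::clone(projector.output_schema()));

    Ok(())
}
```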