Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 175 additions & 3 deletions datafusion/physical-expr/src/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ use crate::expressions::Column;
use crate::utils::collect_columns;
use crate::PhysicalExpr;

use arrow::array::{RecordBatch, RecordBatchOptions};
use arrow::datatypes::{Field, Schema, SchemaRef};
use datafusion_common::stats::{ColumnStatistics, Precision};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result};

use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
use indexmap::IndexMap;
use itertools::Itertools;

Expand All @@ -47,6 +49,15 @@ pub struct ProjectionExpr {
pub alias: String,
}

impl PartialEq for ProjectionExpr {
fn eq(&self, other: &Self) -> bool {
let ProjectionExpr { expr, alias } = self;
expr.eq(&other.expr) && *alias == other.alias
}
}

impl Eq for ProjectionExpr {}

impl std::fmt::Display for ProjectionExpr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.expr.to_string() == self.alias {
Expand Down Expand Up @@ -99,7 +110,7 @@ impl From<ProjectionExpr> for (Arc<dyn PhysicalExpr>, String) {
/// This struct encapsulates multiple `ProjectionExpr` instances,
/// representing a complete projection operation and provides
/// methods to manipulate and analyze the projection as a whole.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProjectionExprs {
exprs: Vec<ProjectionExpr>,
}
Expand Down Expand Up @@ -192,7 +203,7 @@ impl ProjectionExprs {
/// assert_eq!(projection_with_dups.as_ref()[1].alias, "a"); // duplicate
/// assert_eq!(projection_with_dups.as_ref()[2].alias, "b");
/// ```
pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Self {
pub fn from_indices(indices: &[usize], schema: &Schema) -> Self {
let projection_exprs = indices.iter().map(|&i| {
let field = schema.field(i);
ProjectionExpr {
Expand Down Expand Up @@ -396,6 +407,22 @@ impl ProjectionExprs {
))
}

/// Create a new [`Projector`] from this projection and an input schema.
///
/// A [`Projector`] can be used to apply this projection to record batches.
///
/// # Errors
/// This function returns an error if the output schema cannot be constructed from the input schema
/// with the given projection expressions.
/// For example, if an expression only works with integer columns but the input schema has a string column at that index.
pub fn make_projector(&self, input_schema: &Schema) -> Result<Projector> {
let output_schema = Arc::new(self.project_schema(input_schema)?);
Ok(Projector {
projection: self.clone(),
output_schema,
})
}

/// Project statistics according to this projection.
/// For example, for a projection `SELECT a AS x, b + 1 AS y`, where `a` is at index 0 and `b` is at index 1,
/// if the input statistics has column statistics for columns `a`, `b`, and `c`, the output statistics would have column statistics for columns `x` and `y`.
Expand Down Expand Up @@ -444,6 +471,57 @@ impl<'a> IntoIterator for &'a ProjectionExprs {
}
}

/// Applies a projection to record batches.
///
/// A [`Projector`] uses a set of projection expressions to transform
/// and a pre-computed output schema to project record batches accordingly.
///
/// The main reason to use a `Projector` is to avoid repeatedly computing
/// the output schema for each batch, which can be costly if the projection
/// expressions are complex.
#[derive(Clone, Debug)]
pub struct Projector {
projection: ProjectionExprs,
output_schema: SchemaRef,
}

impl Projector {
/// Project a record batch according to this projector's expressions.
///
/// # Errors
/// This function returns an error if any expression evaluation fails
/// or if the output schema of the resulting record batch does not match
/// the pre-computed output schema of the projector.
pub fn project_batch(&self, batch: &RecordBatch) -> Result<RecordBatch> {
let arrays = evaluate_expressions_to_arrays(
self.projection.exprs.iter().map(|p| &p.expr),
batch,
)?;

if arrays.is_empty() {
let options =
RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
RecordBatch::try_new_with_options(
Arc::clone(&self.output_schema),
arrays,
&options,
)
.map_err(Into::into)
} else {
RecordBatch::try_new(Arc::clone(&self.output_schema), arrays)
.map_err(Into::into)
}
}

pub fn output_schema(&self) -> &SchemaRef {
&self.output_schema
}

pub fn projection(&self) -> &ProjectionExprs {
&self.projection
}
}

impl IntoIterator for ProjectionExprs {
type Item = ProjectionExpr;
type IntoIter = std::vec::IntoIter<ProjectionExpr>;
Expand Down Expand Up @@ -545,7 +623,13 @@ pub fn update_expr(
})
.data()?;

Ok((state == RewriteState::RewrittenValid).then_some(new_expr))
match state {
RewriteState::RewrittenInvalid => Ok(None),
// Both Unchanged and RewrittenValid are valid:
// - Unchanged means no columns to rewrite (e.g., literals)
// - RewrittenValid means columns were successfully rewritten
RewriteState::Unchanged | RewriteState::RewrittenValid => Ok(Some(new_expr)),
}
}

/// Stores target expressions, along with their indices, that associate with a
Expand Down Expand Up @@ -2009,6 +2093,94 @@ pub(crate) mod tests {
);
}

#[test]
fn test_merge_empty_projection_with_literal() -> Result<()> {
// This test reproduces the issue from roundtrip_empty_projection test
// Query like: SELECT 1 FROM table
// where the file scan needs no columns (empty projection)
// but we project a literal on top

// Empty base projection (no columns needed from file)
let base_projection = ProjectionExprs::new(vec![]);

// Top projection with a literal expression: SELECT 1
let top_projection = ProjectionExprs::new(vec![ProjectionExpr {
expr: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
alias: "Int64(1)".to_string(),
}]);

// This should succeed - literals don't reference columns so they should
// pass through unchanged when merged with an empty projection
let merged = base_projection.try_merge(&top_projection)?;
assert_snapshot!(format!("{merged}"), @"Projection[1 AS Int64(1)]");

Ok(())
}

#[test]
fn test_update_expr_with_literal() -> Result<()> {
// Test that update_expr correctly handles expressions without column references
let literal_expr: Arc<dyn PhysicalExpr> =
Arc::new(Literal::new(ScalarValue::Int64(Some(42))));
let empty_projection: Vec<ProjectionExpr> = vec![];

// Updating a literal with an empty projection should return the literal unchanged
let result = update_expr(&literal_expr, &empty_projection, true)?;
assert!(result.is_some(), "Literal expression should be valid");

let result_expr = result.unwrap();
assert_eq!(
result_expr
.as_any()
.downcast_ref::<Literal>()
.unwrap()
.value(),
&ScalarValue::Int64(Some(42))
);

Ok(())
}

#[test]
fn test_update_expr_with_complex_literal_expr() -> Result<()> {
// Test update_expr with an expression containing both literals and a column
// This tests the case where we have: literal + column
let expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
Arc::new(Literal::new(ScalarValue::Int64(Some(10)))),
Operator::Plus,
Arc::new(Column::new("x", 0)),
));

// Base projection that maps column 0 to a different expression
let base_projection = vec![ProjectionExpr {
expr: Arc::new(Column::new("a", 5)),
alias: "x".to_string(),
}];

// The expression should be updated: 10 + x@0 becomes 10 + a@5
let result = update_expr(&expr, &base_projection, true)?;
assert!(result.is_some(), "Expression should be valid");

let result_expr = result.unwrap();
let binary = result_expr
.as_any()
.downcast_ref::<BinaryExpr>()
.expect("Should be a BinaryExpr");

// Left side should still be the literal
assert!(binary.left().as_any().downcast_ref::<Literal>().is_some());

// Right side should be updated to reference column at index 5
let right_col = binary
.right()
.as_any()
.downcast_ref::<Column>()
.expect("Right should be a Column");
assert_eq!(right_col.index(), 5);

Ok(())
}

#[test]
fn test_project_schema_simple_columns() -> Result<()> {
// Input schema: [col0: Int64, col1: Utf8, col2: Float32]
Expand Down
Loading