diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs
index 8dee79ad61b23..55db0d854204d 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet.rs
@@ -39,7 +39,7 @@ mod tests {
     use crate::test::object_store::local_unpartitioned_file;
     use arrow::array::{
         ArrayRef, AsArray, Date64Array, Int32Array, Int64Array, Int8Array, StringArray,
-        StringViewArray, StructArray,
+        StringViewArray, StructArray, TimestampNanosecondArray,
     };
     use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaBuilder};
     use arrow::record_batch::RecordBatch;
@@ -960,6 +960,72 @@ mod tests {
         assert_eq!(read, 2, "Expected 2 rows to match the predicate");
     }
 
+    #[tokio::test]
+    async fn evolved_schema_column_type_filter_timestamp_units() {
+        // The table and the filter have a common data type:
+        // the table schema is in milliseconds, but the file schema is in nanoseconds.
+        let c1: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
+            Some(1_000_000_000), // 1970-01-01T00:00:01Z
+            Some(2_000_000_000), // 1970-01-01T00:00:02Z
+            Some(3_000_000_000), // 1970-01-01T00:00:03Z
+            Some(4_000_000_000), // 1970-01-01T00:00:04Z
+        ]));
+        let batch = create_batch(vec![("c1", c1.clone())]);
+        let table_schema = Arc::new(Schema::new(vec![Field::new(
+            "c1",
+            DataType::Timestamp(TimeUnit::Millisecond, Some("UTC".into())),
+            false,
+        )]));
+        // One row should match, 2 are pruned via the page index, 1 via filter pushdown
+        let filter = col("c1").eq(lit(ScalarValue::TimestampMillisecond(
+            Some(1_000),
+            Some("UTC".into()),
+        )));
+        let rt = RoundTrip::new()
+            .with_predicate(filter)
+            .with_pushdown_predicate()
+            .with_page_index_predicate() // produces pages with 2 rows each (2 pages total for our data)
+            .with_table_schema(table_schema.clone())
+            .round_trip(vec![batch.clone()])
+            .await;
+        // There should be no predicate evaluation errors and we keep 1 row
+        let metrics = rt.parquet_exec.metrics().unwrap();
+        assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0);
+        let read = rt
+            .batches
+            .unwrap()
+            .iter()
+            .map(|b| b.num_rows())
+            .sum::<usize>();
+        assert_eq!(read, 1, "Expected 1 row to match the predicate");
+        assert_eq!(get_value(&metrics, "row_groups_pruned_statistics"), 0);
+        assert_eq!(get_value(&metrics, "page_index_rows_pruned"), 2);
+        assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 1);
+        // If we filter with a value that is completely out of the range of the data
+        // we prune at the row group level.
+ let filter = col("c1").eq(lit(ScalarValue::TimestampMillisecond( + Some(5_000), + Some("UTC".into()), + ))); + let rt = RoundTrip::new() + .with_predicate(filter) + .with_pushdown_predicate() + .with_table_schema(table_schema) + .round_trip(vec![batch]) + .await; + // There should be no predicate evaluation errors and we keep 0 rows + let metrics = rt.parquet_exec.metrics().unwrap(); + assert_eq!(get_value(&metrics, "predicate_evaluation_errors"), 0); + let read = rt + .batches + .unwrap() + .iter() + .map(|b| b.num_rows()) + .sum::(); + assert_eq!(read, 0, "Expected 0 rows to match the predicate"); + assert_eq!(get_value(&metrics, "row_groups_pruned_statistics"), 1); + } + #[tokio::test] async fn evolved_schema_disjoint_schema_filter() { let c1: ArrayRef = diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 285044803d73c..69ea7a4b7896a 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -37,6 +37,7 @@ use datafusion_common::pruning::{ }; use datafusion_common::{exec_err, Result}; use datafusion_datasource::PartitionedFile; +use datafusion_physical_expr::PhysicalExprSchemaRewriter; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_optimizer::pruning::PruningPredicate; use datafusion_physical_plan::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder}; @@ -117,7 +118,6 @@ impl FileOpener for ParquetOpener { let projected_schema = SchemaRef::from(self.logical_file_schema.project(&self.projection)?); - let schema_adapter_factory = Arc::clone(&self.schema_adapter_factory); let schema_adapter = self .schema_adapter_factory .create(projected_schema, Arc::clone(&self.logical_file_schema)); @@ -159,7 +159,7 @@ impl FileOpener for ParquetOpener { if let Some(pruning_predicate) = pruning_predicate { // The partition column schema is the schema of the table - the schema of the file let mut pruning = Box::new(PartitionPruningStatistics::try_new( - vec![file.partition_values], + vec![file.partition_values.clone()], partition_fields.clone(), )?) as Box; @@ -248,10 +248,27 @@ impl FileOpener for ParquetOpener { } } + // Adapt the predicate to the physical file schema. + // This evaluates missing columns and inserts any necessary casts. + let predicate = predicate + .map(|p| { + PhysicalExprSchemaRewriter::new( + &physical_file_schema, + &logical_file_schema, + ) + .with_partition_columns( + partition_fields.to_vec(), + file.partition_values, + ) + .rewrite(p) + .map_err(ArrowError::from) + }) + .transpose()?; + // Build predicates for this specific file let (pruning_predicate, page_pruning_predicate) = build_pruning_predicates( predicate.as_ref(), - &logical_file_schema, + &physical_file_schema, &predicate_creation_errors, ); @@ -288,11 +305,9 @@ impl FileOpener for ParquetOpener { let row_filter = row_filter::build_row_filter( &predicate, &physical_file_schema, - &logical_file_schema, builder.metadata(), reorder_predicates, &file_metrics, - &schema_adapter_factory, ); match row_filter { @@ -879,4 +894,115 @@ mod test { assert_eq!(num_batches, 0); assert_eq!(num_rows, 0); } + + #[tokio::test] + async fn test_prune_on_partition_value_and_data_value() { + let store = Arc::new(InMemory::new()) as Arc; + + // Note: number 3 is missing! 
+        let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(4)])).unwrap();
+        let data_size =
+            write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await;
+
+        let file_schema = batch.schema();
+        let mut file = PartitionedFile::new(
+            "part=1/file.parquet".to_string(),
+            u64::try_from(data_size).unwrap(),
+        );
+        file.partition_values = vec![ScalarValue::Int32(Some(1))];
+
+        let table_schema = Arc::new(Schema::new(vec![
+            Field::new("part", DataType::Int32, false),
+            Field::new("a", DataType::Int32, false),
+        ]));
+
+        let make_opener = |predicate| {
+            ParquetOpener {
+                partition_index: 0,
+                projection: Arc::new([0]),
+                batch_size: 1024,
+                limit: None,
+                predicate: Some(predicate),
+                logical_file_schema: file_schema.clone(),
+                metadata_size_hint: None,
+                metrics: ExecutionPlanMetricsSet::new(),
+                parquet_file_reader_factory: Arc::new(
+                    DefaultParquetFileReaderFactory::new(Arc::clone(&store)),
+                ),
+                partition_fields: vec![Arc::new(Field::new(
+                    "part",
+                    DataType::Int32,
+                    false,
+                ))],
+                pushdown_filters: true, // note that this is true!
+                reorder_filters: true,
+                enable_page_index: false,
+                enable_bloom_filter: false,
+                schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory),
+                enable_row_group_stats_pruning: false, // note that this is false!
+                coerce_int96: None,
+            }
+        };
+
+        let make_meta = || FileMeta {
+            object_meta: ObjectMeta {
+                location: Path::from("part=1/file.parquet"),
+                last_modified: Utc::now(),
+                size: u64::try_from(data_size).unwrap(),
+                e_tag: None,
+                version: None,
+            },
+            range: None,
+            extensions: None,
+            metadata_size_hint: None,
+        };
+
+        // Filter should match the partition value and the data value
+        let expr = col("part").eq(lit(1)).or(col("a").eq(lit(1)));
+        let predicate = logical2physical(&expr, &table_schema);
+        let opener = make_opener(predicate);
+        let stream = opener
+            .open(make_meta(), file.clone())
+            .unwrap()
+            .await
+            .unwrap();
+        let (num_batches, num_rows) = count_batches_and_rows(stream).await;
+        assert_eq!(num_batches, 1);
+        assert_eq!(num_rows, 3);
+
+        // Filter should match the partition value but not the data value
+        let expr = col("part").eq(lit(1)).or(col("a").eq(lit(3)));
+        let predicate = logical2physical(&expr, &table_schema);
+        let opener = make_opener(predicate);
+        let stream = opener
+            .open(make_meta(), file.clone())
+            .unwrap()
+            .await
+            .unwrap();
+        let (num_batches, num_rows) = count_batches_and_rows(stream).await;
+        assert_eq!(num_batches, 1);
+        assert_eq!(num_rows, 3);
+
+        // Filter should not match the partition value but match the data value
+        let expr = col("part").eq(lit(2)).or(col("a").eq(lit(1)));
+        let predicate = logical2physical(&expr, &table_schema);
+        let opener = make_opener(predicate);
+        let stream = opener
+            .open(make_meta(), file.clone())
+            .unwrap()
+            .await
+            .unwrap();
+        let (num_batches, num_rows) = count_batches_and_rows(stream).await;
+        assert_eq!(num_batches, 1);
+        assert_eq!(num_rows, 1);
+
+        // Filter should not match the partition value or the data value
+        let expr = col("part").eq(lit(2)).or(col("a").eq(lit(3)));
+        let predicate = logical2physical(&expr, &table_schema);
+        let opener = make_opener(predicate);
+        let stream = opener.open(make_meta(), file).unwrap().await.unwrap();
+        let (num_batches, num_rows) = count_batches_and_rows(stream).await;
+        assert_eq!(num_batches, 0);
+        assert_eq!(num_rows, 0);
+    }
 }
diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs
index db455fed61606..5626f83186e31 100644
--- a/datafusion/datasource-parquet/src/row_filter.rs
+++ b/datafusion/datasource-parquet/src/row_filter.rs
@@ -67,6 +67,7 @@ use arrow::array::BooleanArray;
 use arrow::datatypes::{DataType, Schema, SchemaRef};
 use arrow::error::{ArrowError, Result as ArrowResult};
 use arrow::record_batch::RecordBatch;
+use itertools::Itertools;
 use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter};
 use parquet::arrow::ProjectionMask;
 use parquet::file::metadata::ParquetMetaData;
@@ -74,9 +75,8 @@ use parquet::file::metadata::ParquetMetaData;
 
 use datafusion_common::cast::as_boolean_array;
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
 use datafusion_common::Result;
-use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper};
 use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::utils::reassign_predicate_columns;
+use datafusion_physical_expr::utils::{collect_columns, reassign_predicate_columns};
 use datafusion_physical_expr::{split_conjunction, PhysicalExpr};
 use datafusion_physical_plan::metrics;
@@ -106,8 +106,6 @@ pub(crate) struct DatafusionArrowPredicate {
     rows_matched: metrics::Count,
     /// how long was spent evaluating this predicate
     time: metrics::Time,
-    /// used to perform type coercion while filtering rows
-    schema_mapper: Arc<dyn SchemaMapper>,
 }
 
 impl DatafusionArrowPredicate {
@@ -132,7 +130,6 @@ impl DatafusionArrowPredicate {
             rows_pruned,
             rows_matched,
             time,
-            schema_mapper: candidate.schema_mapper,
         })
     }
 }
@@ -143,8 +140,6 @@ impl ArrowPredicate for DatafusionArrowPredicate {
     }
 
     fn evaluate(&mut self, batch: RecordBatch) -> ArrowResult<BooleanArray> {
-        let batch = self.schema_mapper.map_batch(batch)?;
-
         // scoped timer updates on drop
         let mut timer = self.time.timer();
 
@@ -187,9 +182,6 @@ pub(crate) struct FilterCandidate {
    /// required to pass through a `SchemaMapper` to the table schema
    /// upon which we then evaluate the filter expression.
    projection: Vec<usize>,
-    /// A `SchemaMapper` used to map batches read from the file schema to
-    /// the filter's projection of the table schema.
-    schema_mapper: Arc<dyn SchemaMapper>,
    /// The projected table schema that this filter references
    filter_schema: SchemaRef,
 }
@@ -230,26 +222,11 @@ struct FilterCandidateBuilder {
     /// columns in the file schema that are not in the table schema or columns that
     /// are in the table schema that are not in the file schema.
     file_schema: SchemaRef,
-    /// The schema of the table (merged schema) -- columns may be in different
-    /// order than in the file and have columns that are not in the file schema
-    table_schema: SchemaRef,
-    /// A `SchemaAdapterFactory` used to map the file schema to the table schema.
-    schema_adapter_factory: Arc<dyn SchemaAdapterFactory>,
 }
 
 impl FilterCandidateBuilder {
-    pub fn new(
-        expr: Arc<dyn PhysicalExpr>,
-        file_schema: Arc<Schema>,
-        table_schema: Arc<Schema>,
-        schema_adapter_factory: Arc<dyn SchemaAdapterFactory>,
-    ) -> Self {
-        Self {
-            expr,
-            file_schema,
-            table_schema,
-            schema_adapter_factory,
-        }
+    pub fn new(expr: Arc<dyn PhysicalExpr>, file_schema: Arc<Schema>) -> Self {
+        Self { expr, file_schema }
     }
 
     /// Attempt to build a `FilterCandidate` from the expression
     ///
     /// * `Ok(Some(candidate))` if the expression can be used as an ArrowFilter
     /// * `Ok(None)` if the expression cannot be used as an ArrowFilter
     /// * `Err(e)` if an error occurs while building the candidate
     pub fn build(self, metadata: &ParquetMetaData) -> Result<Option<FilterCandidate>> {
         let Some(required_indices_into_table_schema) =
-            pushdown_columns(&self.expr, &self.table_schema)?
+            pushdown_columns(&self.expr, &self.file_schema)?
         else {
             return Ok(None);
         };
 
         let projected_table_schema = Arc::new(
-            self.table_schema
+            self.file_schema
                 .project(&required_indices_into_table_schema)?,
         );
 
-        let (schema_mapper, projection_into_file_schema) = self
-            .schema_adapter_factory
-            .create(Arc::clone(&projected_table_schema), self.table_schema)
-            .map_schema(&self.file_schema)?;
+        let projection_into_file_schema = collect_columns(&self.expr)
+            .iter()
+            .map(|c| c.index())
+            .sorted_unstable()
+            .collect_vec();
 
         let required_bytes = size_of_columns(&projection_into_file_schema, metadata)?;
         let can_use_index = columns_sorted(&projection_into_file_schema, metadata)?;
@@ -284,7 +262,6 @@ impl FilterCandidateBuilder {
             required_bytes,
             can_use_index,
             projection: projection_into_file_schema,
-            schema_mapper: Arc::clone(&schema_mapper),
             filter_schema: Arc::clone(&projected_table_schema),
         }))
     }
@@ -426,11 +403,9 @@ fn columns_sorted(_columns: &[usize], _metadata: &ParquetMetaData) -> Result<bool> {
 pub fn build_row_filter(
     expr: &Arc<dyn PhysicalExpr>,
     physical_file_schema: &SchemaRef,
-    logical_file_schema: &SchemaRef,
     metadata: &ParquetMetaData,
     reorder_predicates: bool,
     file_metrics: &ParquetFileMetrics,
-    schema_adapter_factory: &Arc<dyn SchemaAdapterFactory>,
 ) -> Result<Option<RowFilter>> {
     let rows_pruned = &file_metrics.pushdown_rows_pruned;
     let rows_matched = &file_metrics.pushdown_rows_matched;
@@ -447,8 +422,6 @@ pub fn build_row_filter(
             FilterCandidateBuilder::new(
                 Arc::clone(expr),
                 Arc::clone(physical_file_schema),
-                Arc::clone(logical_file_schema),
-                Arc::clone(schema_adapter_factory),
             )
             .build(metadata)
         })
@@ -492,13 +465,9 @@ mod test {
     use super::*;
     use datafusion_common::ScalarValue;
 
-    use arrow::datatypes::{Field, TimeUnit::Nanosecond};
-    use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory;
     use datafusion_expr::{col, Expr};
     use datafusion_physical_expr::planner::logical2physical;
-    use datafusion_physical_plan::metrics::{Count, Time};
 
-    use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
     use parquet::arrow::parquet_to_arrow_schema;
     use parquet::file::reader::{FileReader, SerializedFileReader};
 
@@ -520,111 +489,15 @@ mod test {
         let expr = col("int64_list").is_not_null();
         let expr = logical2physical(&expr, &table_schema);
 
-        let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory);
         let table_schema = Arc::new(table_schema.clone());
 
-        let candidate = FilterCandidateBuilder::new(
-            expr,
-            table_schema.clone(),
-            table_schema,
-            schema_adapter_factory,
-        )
-        .build(metadata)
-        .expect("building candidate");
+        let candidate = FilterCandidateBuilder::new(expr, table_schema.clone())
+            .build(metadata)
+            .expect("building candidate");
 
         assert!(candidate.is_none());
     }
 
-    #[test]
-    fn test_filter_type_coercion() {
-        let testdata = datafusion_common::test_util::parquet_test_data();
-        let file = std::fs::File::open(format!("{testdata}/alltypes_plain.parquet"))
-            .expect("opening file");
-
-        let parquet_reader_builder =
-            ParquetRecordBatchReaderBuilder::try_new(file).expect("creating reader");
-        let metadata = parquet_reader_builder.metadata().clone();
-        let file_schema = parquet_reader_builder.schema().clone();
-
-        // This is the schema we would like to coerce to,
-        // which is different from the physical schema of the file.
-        let table_schema = Schema::new(vec![Field::new(
-            "timestamp_col",
-            DataType::Timestamp(Nanosecond, Some(Arc::from("UTC"))),
-            false,
-        )]);
-
-        // Test all should fail
-        let expr = col("timestamp_col").lt(Expr::Literal(
-            ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))),
-            None,
-        ));
-        let expr = logical2physical(&expr, &table_schema);
-        let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory);
-        let table_schema = Arc::new(table_schema.clone());
-        let candidate = FilterCandidateBuilder::new(
-            expr,
-            file_schema.clone(),
-            table_schema.clone(),
-            schema_adapter_factory,
-        )
-        .build(&metadata)
-        .expect("building candidate")
-        .expect("candidate expected");
-
-        let mut row_filter = DatafusionArrowPredicate::try_new(
-            candidate,
-            &metadata,
-            Count::new(),
-            Count::new(),
-            Time::new(),
-        )
-        .expect("creating filter predicate");
-
-        let mut parquet_reader = parquet_reader_builder
-            .with_projection(row_filter.projection().clone())
-            .build()
-            .expect("building reader");
-
-        // Parquet file is small, we only need 1 record batch
-        let first_rb = parquet_reader
-            .next()
-            .expect("expected record batch")
-            .expect("expected error free record batch");
-
-        let filtered = row_filter.evaluate(first_rb.clone());
-        assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![false; 8])));
-
-        // Test all should pass
-        let expr = col("timestamp_col").gt(Expr::Literal(
-            ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))),
-            None,
-        ));
-        let expr = logical2physical(&expr, &table_schema);
-        let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory);
-        let candidate = FilterCandidateBuilder::new(
-            expr,
-            file_schema,
-            table_schema,
-            schema_adapter_factory,
-        )
-        .build(&metadata)
-        .expect("building candidate")
-        .expect("candidate expected");
-
-        let mut row_filter = DatafusionArrowPredicate::try_new(
-            candidate,
-            &metadata,
-            Count::new(),
-            Count::new(),
-            Time::new(),
-        )
-        .expect("creating filter predicate");
-
-        let filtered = row_filter.evaluate(first_rb);
-        assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![true; 8])));
-    }
-
     #[test]
     fn nested_data_structures_prevent_pushdown() {
         let table_schema = Arc::new(get_lists_table_schema());
diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs
index be60e26cc2d2a..3bdb9d84d8278 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -38,6 +38,7 @@ mod partitioning;
 mod physical_expr;
 pub mod planner;
 mod scalar_function;
+pub mod schema_rewriter;
 pub mod statistics;
 pub mod utils;
 pub mod window;
@@ -68,6 +69,7 @@ pub use datafusion_physical_expr_common::sort_expr::{
 
 pub use planner::{create_physical_expr, create_physical_exprs};
 pub use scalar_function::ScalarFunctionExpr;
+pub use schema_rewriter::PhysicalExprSchemaRewriter;
 pub use utils::{conjunction, conjunction_opt, split_conjunction};
 
 // For backwards compatibility
diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs
new file mode 100644
index 0000000000000..b8759ea16d6e8
--- /dev/null
+++ b/datafusion/physical-expr/src/schema_rewriter.rs
@@ -0,0 +1,466 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Physical expression schema rewriting utilities
+
+use std::sync::Arc;
+
+use arrow::compute::can_cast_types;
+use arrow::datatypes::{FieldRef, Schema};
+use datafusion_common::{
+    exec_err,
+    tree_node::{Transformed, TransformedResult, TreeNode},
+    Result, ScalarValue,
+};
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+
+use crate::expressions::{self, CastExpr, Column};
+
+/// Builder for rewriting physical expressions to match different schemas.
+///
+/// # Example
+///
+/// ```rust
+/// use datafusion_physical_expr::schema_rewriter::PhysicalExprSchemaRewriter;
+/// use arrow::datatypes::Schema;
+///
+/// # fn example(
+/// #     predicate: std::sync::Arc<dyn datafusion_physical_expr::PhysicalExpr>,
+/// #     physical_file_schema: &Schema,
+/// #     logical_file_schema: &Schema,
+/// # ) -> datafusion_common::Result<()> {
+/// let rewriter = PhysicalExprSchemaRewriter::new(physical_file_schema, logical_file_schema);
+/// let adapted_predicate = rewriter.rewrite(predicate)?;
+/// # Ok(())
+/// # }
+/// ```
+pub struct PhysicalExprSchemaRewriter<'a> {
+    physical_file_schema: &'a Schema,
+    logical_file_schema: &'a Schema,
+    partition_fields: Vec<FieldRef>,
+    partition_values: Vec<ScalarValue>,
+}
+
+impl<'a> PhysicalExprSchemaRewriter<'a> {
+    /// Create a new schema rewriter with the given schemas
+    pub fn new(
+        physical_file_schema: &'a Schema,
+        logical_file_schema: &'a Schema,
+    ) -> Self {
+        Self {
+            physical_file_schema,
+            logical_file_schema,
+            partition_fields: Vec::new(),
+            partition_values: Vec::new(),
+        }
+    }
+
+    /// Add partition columns and their corresponding values
+    ///
+    /// When a column reference matches a partition field, it will be replaced
+    /// with the corresponding literal value from partition_values.
+    pub fn with_partition_columns(
+        mut self,
+        partition_fields: Vec<FieldRef>,
+        partition_values: Vec<ScalarValue>,
+    ) -> Self {
+        self.partition_fields = partition_fields;
+        self.partition_values = partition_values;
+        self
+    }
+
+    /// Rewrite the given physical expression to match the target schema
+    ///
+    /// This method applies the following transformations:
+    /// 1. Replaces partition column references with literal values
+    /// 2. Handles missing columns by inserting null literals
+    /// 3. Casts columns when the logical and physical schemas have different types
+    pub fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
+        expr.transform(|expr| self.rewrite_expr(expr)).data()
+    }
+
+    fn rewrite_expr(
+        &self,
+        expr: Arc<dyn PhysicalExpr>,
+    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
+        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+            return self.rewrite_column(Arc::clone(&expr), column);
+        }
+
+        Ok(Transformed::no(expr))
+    }
+
+    fn rewrite_column(
+        &self,
+        expr: Arc<dyn PhysicalExpr>,
+        column: &Column,
+    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
+        // Get the logical field for this column
+        let logical_field = match self.logical_file_schema.field_with_name(column.name())
+        {
+            Ok(field) => field,
+            Err(e) => {
+                // If the column is a partition field, we can use the partition value
+                if let Some(partition_value) = self.get_partition_value(column.name()) {
+                    return Ok(Transformed::yes(expressions::lit(partition_value)));
+                }
+                // If the column is not found in the logical schema and is not a partition value, return an error.
+                // This should probably never be hit unless something upstream broke, but nonetheless it's better
+                // for us to return a handleable error than to panic / do something unexpected.
+                return Err(e.into());
+            }
+        };
+
+        // Check if the column exists in the physical schema
+        let physical_column_index =
+            match self.physical_file_schema.index_of(column.name()) {
+                Ok(index) => index,
+                Err(_) => {
+                    if !logical_field.is_nullable() {
+                        return exec_err!(
+                            "Non-nullable column '{}' is missing from the physical schema",
+                            column.name()
+                        );
+                    }
+                    // If the column is missing from the physical schema, fill it in with nulls as `SchemaAdapter` would do.
+                    // TODO: do we need to sync this with what the `SchemaAdapter` actually does?
+                    // While the default implementation fills in nulls, in theory a custom `SchemaAdapter` could do something else!
+                    // See https://github.com/apache/datafusion/issues/16527
+                    let null_value =
+                        ScalarValue::Null.cast_to(logical_field.data_type())?;
+                    return Ok(Transformed::yes(expressions::lit(null_value)));
+                }
+            };
+        let physical_field = self.physical_file_schema.field(physical_column_index);
+
+        let column = match (
+            column.index() == physical_column_index,
+            logical_field.data_type() == physical_field.data_type(),
+        ) {
+            // If the column index matches and the data types match, we can use the column as is
+            (true, true) => return Ok(Transformed::no(expr)),
+            // If the indexes or data types do not match, we need to create a new column expression
+            (true, _) => column.clone(),
+            (false, _) => {
+                Column::new_with_schema(logical_field.name(), self.physical_file_schema)?
+            }
+        };
+
+        if logical_field.data_type() == physical_field.data_type() {
+            // If the data types match, we can use the column as is
+            return Ok(Transformed::yes(Arc::new(column)));
+        }
+
+        // We need to cast the column to the logical data type
+        // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
+        // since that's much cheaper to evaluate.
+        // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928
+        if !can_cast_types(physical_field.data_type(), logical_field.data_type()) {
+            return exec_err!(
+                "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
+                column.name(),
+                physical_field.data_type(),
+                logical_field.data_type()
+            );
+        }
+
+        let cast_expr = Arc::new(CastExpr::new(
+            Arc::new(column),
+            logical_field.data_type().clone(),
+            None,
+        ));
+
+        Ok(Transformed::yes(cast_expr))
+    }
+
+    fn get_partition_value(&self, column_name: &str) -> Option<ScalarValue> {
+        self.partition_fields
+            .iter()
+            .zip(self.partition_values.iter())
+            .find(|(field, _)| field.name() == column_name)
+            .map(|(_, value)| value.clone())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::expressions::{col, lit};
+
+    use super::*;
+    use arrow::{
+        array::{RecordBatch, RecordBatchOptions},
+        datatypes::{DataType, Field, Schema, SchemaRef},
+    };
+    use datafusion_common::{record_batch, ScalarValue};
+    use datafusion_expr::Operator;
+    use itertools::Itertools;
+    use std::sync::Arc;
+
+    fn create_test_schema() -> (Schema, Schema) {
+        let physical_schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Utf8, true),
+        ]);
+
+        let logical_schema = Schema::new(vec![
+            Field::new("a", DataType::Int64, false), // Different type
+            Field::new("b", DataType::Utf8, true),
+            Field::new("c", DataType::Float64, true), // Missing from physical
+        ]);
+
+        (physical_schema, logical_schema)
+    }
+
+    #[test]
+    fn test_rewrite_column_with_type_cast() {
+        let (physical_schema, logical_schema) = create_test_schema();
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+        let column_expr = Arc::new(Column::new("a", 0));
+
+        let result = rewriter.rewrite(column_expr).unwrap();
+
+        // Should be wrapped in a cast expression
+        assert!(result.as_any().downcast_ref::<CastExpr>().is_some());
+    }
+
+    #[test]
+    fn test_rewrite_multi_column_expr_with_type_cast() {
+        let (physical_schema, logical_schema) = create_test_schema();
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+
+        // Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter
+        let column_a = Arc::new(Column::new("a", 0)) as Arc<dyn PhysicalExpr>;
+        let column_c = Arc::new(Column::new("c", 2)) as Arc<dyn PhysicalExpr>;
+        let expr = expressions::BinaryExpr::new(
+            Arc::clone(&column_a),
+            Operator::Plus,
+            Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
+        );
+        let expr = expressions::BinaryExpr::new(
+            Arc::new(expr),
+            Operator::Or,
+            Arc::new(expressions::BinaryExpr::new(
+                Arc::clone(&column_c),
+                Operator::Gt,
+                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
+            )),
+        );
+
+        let result = rewriter.rewrite(Arc::new(expr)).unwrap();
+        println!("Rewritten expression: {result}");
+
+        let expected = expressions::BinaryExpr::new(
+            Arc::new(CastExpr::new(
+                Arc::new(Column::new("a", 0)),
+                DataType::Int64,
+                None,
+            )),
+            Operator::Plus,
+            Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))),
+        );
+        let expected = Arc::new(expressions::BinaryExpr::new(
+            Arc::new(expected),
+            Operator::Or,
+            Arc::new(expressions::BinaryExpr::new(
+                lit(ScalarValue::Null),
+                Operator::Gt,
+                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.0)))),
+            )),
+        )) as Arc<dyn PhysicalExpr>;
+
+        assert_eq!(
+            result.to_string(),
+            expected.to_string(),
+            "The rewritten expression did not match the expected output"
+        );
+    }
+
+    #[test]
+    fn test_rewrite_missing_column() -> Result<()> {
+        let (physical_schema, logical_schema) = create_test_schema();
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+        let column_expr = Arc::new(Column::new("c", 2));
+
+        let result = rewriter.rewrite(column_expr)?;
+
+        // Should be replaced with a literal null
+        if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
+            assert_eq!(*literal.value(), ScalarValue::Float64(None));
+        } else {
+            panic!("Expected literal expression");
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_rewrite_partition_column() -> Result<()> {
+        let (physical_schema, logical_schema) = create_test_schema();
+
+        let partition_fields =
+            vec![Arc::new(Field::new("partition_col", DataType::Utf8, false))];
+        let partition_values = vec![ScalarValue::Utf8(Some("test_value".to_string()))];
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema)
+            .with_partition_columns(partition_fields, partition_values);
+
+        let column_expr = Arc::new(Column::new("partition_col", 0));
+        let result = rewriter.rewrite(column_expr)?;
+
+        // Should be replaced with the partition value
+        if let Some(literal) = result.as_any().downcast_ref::<expressions::Literal>() {
+            assert_eq!(
+                *literal.value(),
+                ScalarValue::Utf8(Some("test_value".to_string()))
+            );
+        } else {
+            panic!("Expected literal expression");
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_rewrite_no_change_needed() -> Result<()> {
+        let (physical_schema, logical_schema) = create_test_schema();
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+        let column_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
+
+        let result = rewriter.rewrite(Arc::clone(&column_expr))?;
+
+        // Should be the same expression (no transformation needed)
+        // We compare the underlying pointer through the trait object
+        assert!(std::ptr::eq(
+            column_expr.as_ref() as *const dyn PhysicalExpr,
+            result.as_ref() as *const dyn PhysicalExpr
+        ));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_non_nullable_missing_column_error() {
+        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let logical_schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Utf8, false), // Non-nullable missing column
+        ]);
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+        let column_expr = Arc::new(Column::new("b", 1));
+
+        let result = rewriter.rewrite(column_expr);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("Non-nullable column 'b' is missing"));
+    }
+
+    /// Roughly stolen from ProjectionExec
+    fn batch_project(
+        expr: Vec<Arc<dyn PhysicalExpr>>,
+        batch: &RecordBatch,
+        schema: SchemaRef,
+    ) -> Result<RecordBatch> {
+        let arrays = expr
+            .iter()
+            .map(|expr| {
+                expr.evaluate(batch)
+                    .and_then(|v| v.into_array(batch.num_rows()))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        if arrays.is_empty() {
+            let options =
+                RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
+            RecordBatch::try_new_with_options(Arc::clone(&schema), arrays, &options)
+                .map_err(Into::into)
+        } else {
+            RecordBatch::try_new(Arc::clone(&schema), arrays).map_err(Into::into)
+        }
+    }
+
+    /// Example showing how we can use the `PhysicalExprSchemaRewriter` to adapt RecordBatches during a scan,
+    /// applying projections, type conversions, and missing-column handling all at once.
+    #[test]
+    fn test_adapt_batches() {
+        let physical_batch = record_batch!(
+            ("a", Int32, vec![Some(1), None, Some(3)]),
+            ("extra", Utf8, vec![Some("x"), Some("y"), None])
+        )
+        .unwrap();
+
+        let physical_schema = physical_batch.schema();
+
+        let logical_schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, true), // Different type
+            Field::new("b", DataType::Utf8, true),  // Missing from physical
+        ]));
+
+        let projection = vec![
+            col("b", &logical_schema).unwrap(),
+            col("a", &logical_schema).unwrap(),
+        ];
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+
+        let adapted_projection = projection
+            .into_iter()
+            .map(|expr| rewriter.rewrite(expr).unwrap())
+            .collect_vec();
+
+        let adapted_schema = Arc::new(Schema::new(
+            adapted_projection
+                .iter()
+                .map(|expr| expr.return_field(&physical_schema).unwrap())
+                .collect_vec(),
+        ));
+
+        let res = batch_project(
+            adapted_projection,
+            &physical_batch,
+            Arc::clone(&adapted_schema),
+        )
+        .unwrap();
+
+        assert_eq!(res.num_columns(), 2);
+        assert_eq!(res.column(0).data_type(), &DataType::Utf8);
+        assert_eq!(res.column(1).data_type(), &DataType::Int64);
+        assert_eq!(
+            res.column(0)
+                .as_any()
+                .downcast_ref::<arrow::array::StringArray>()
+                .unwrap()
+                .iter()
+                .collect_vec(),
+            vec![None, None, None]
+        );
+        assert_eq!(
+            res.column(1)
+                .as_any()
+                .downcast_ref::<arrow::array::Int64Array>()
+                .unwrap()
+                .iter()
+                .collect_vec(),
+            vec![Some(1), None, Some(3)]
+        );
+    }
+}
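For reviewers, here is a minimal, self-contained sketch of how the `PhysicalExprSchemaRewriter` introduced in this diff is intended to be used outside of `ParquetOpener`. The schemas, field names, literal values, and the `main` wrapper are illustrative, not part of the change; only the rewriter API calls come from this PR.

```rust
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{col, lit, BinaryExpr};
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprSchemaRewriter};

fn main() -> Result<()> {
    // Physical (file) schema: `a` is stored as Int32 and `part` does not exist on disk.
    let physical = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    // Logical (table) schema: `a` is exposed as Int64.
    let logical = Schema::new(vec![Field::new("a", DataType::Int64, false)]);

    // Predicate written against the logical schema: a = 1
    let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        col("a", &logical)?,
        Operator::Eq,
        lit(ScalarValue::Int64(Some(1))),
    ));

    // Rewrite the predicate so it can be evaluated directly against file data:
    // `a` gets wrapped in a cast from Int32 to Int64, and any reference to the
    // partition column `part` would be replaced with the literal value 7.
    let part_field = Arc::new(Field::new("part", DataType::Int32, false));
    let rewritten = PhysicalExprSchemaRewriter::new(&physical, &logical)
        .with_partition_columns(vec![part_field], vec![ScalarValue::Int32(Some(7))])
        .rewrite(predicate)?;

    println!("{rewritten}");
    Ok(())
}
```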