From 5912c74bd291a155c8ee617ac23e48dbe6edaad5 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 26 Jun 2025 13:10:23 -0500 Subject: [PATCH 01/21] wip --- Cargo.lock | 1 + datafusion-examples/Cargo.toml | 1 + datafusion/datasource-parquet/src/opener.rs | 25 ++++++++--- .../datasource-parquet/src/row_filter.rs | 2 + datafusion/datasource-parquet/src/source.rs | 14 ++++++ .../physical-expr/src/schema_rewriter.rs | 44 ++++++++++++++++++- 6 files changed, 81 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b61a08a470ddf..ddea7e14b29fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2195,6 +2195,7 @@ dependencies = [ "nix", "object_store", "prost", + "serde_json", "tempfile", "test-utils", "tokio", diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index b31708a5c1cc7..66d15f16d4187 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -70,6 +70,7 @@ log = { workspace = true } mimalloc = { version = "0.1", default-features = false } object_store = { workspace = true, features = ["aws", "http"] } prost = { workspace = true } +serde_json = { workspace = true } tempfile = { workspace = true } test-utils = { path = "../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index b39ec3929f978..f48ebb71f747a 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -33,6 +33,7 @@ use arrow::datatypes::{FieldRef, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_datasource::PartitionedFile; +use datafusion_physical_expr::schema_rewriter::PhysicalExprSchemaRewriteHook; use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; use datafusion_physical_expr::PhysicalExprSchemaRewriter; use datafusion_physical_expr_common::physical_expr::{ @@ -92,6 +93,8 @@ pub(super) struct ParquetOpener { pub coerce_int96: Option, /// Optional parquet FileDecryptionProperties pub file_decryption_properties: Option>, + /// Rewrite expressions in the context of the file schema + pub predicate_rewrite_hook: Option>, } impl FileOpener for ParquetOpener { @@ -132,6 +135,8 @@ impl FileOpener for ParquetOpener { let predicate_creation_errors = MetricBuilder::new(&self.metrics) .global_counter("num_predicate_creation_errors"); + let predicate_rewrite_hook = self.predicate_rewrite_hook.clone(); + let mut enable_page_index = self.enable_page_index; let file_decryption_properties = self.file_decryption_properties.clone(); @@ -237,17 +242,20 @@ impl FileOpener for ParquetOpener { // This evaluates missing columns and inserts any necessary casts. let predicate = predicate .map(|p| { - PhysicalExprSchemaRewriter::new( + let mut rewriter = PhysicalExprSchemaRewriter::new( &physical_file_schema, &logical_file_schema, ) .with_partition_columns( partition_fields.to_vec(), file.partition_values, - ) - .rewrite(p) - .map_err(ArrowError::from) - .map(|p| { + ); + if let Some(predicate_rewrite_hook) = predicate_rewrite_hook.as_ref() + { + rewriter = rewriter + .with_rewrite_hook(Arc::clone(predicate_rewrite_hook)); + }; + rewriter.rewrite(p).map_err(ArrowError::from).map(|p| { // After rewriting to the file schema, further simplifications may be possible. 
// For example, if `'a' = col_that_is_missing` becomes `'a' = NULL` that can then be simplified to `FALSE` // and we can avoid doing any more work on the file (bloom filters, loading the page index, etc.). @@ -259,6 +267,8 @@ impl FileOpener for ParquetOpener { .transpose()? .transpose()?; + println!("predicate: {predicate:?}"); + // Build predicates for this specific file let (pruning_predicate, page_pruning_predicate) = build_pruning_predicates( predicate.as_ref(), @@ -631,6 +641,7 @@ mod test { enable_row_group_stats_pruning: true, coerce_int96: None, file_decryption_properties: None, + predicate_rewrite_hook: None, } }; @@ -716,6 +727,7 @@ mod test { enable_row_group_stats_pruning: true, coerce_int96: None, file_decryption_properties: None, + predicate_rewrite_hook: None, } }; @@ -817,6 +829,7 @@ mod test { enable_row_group_stats_pruning: true, coerce_int96: None, file_decryption_properties: None, + predicate_rewrite_hook: None, } }; let make_meta = || FileMeta { @@ -928,6 +941,7 @@ mod test { enable_row_group_stats_pruning: false, // note that this is false! coerce_int96: None, file_decryption_properties: None, + predicate_rewrite_hook: None, } }; @@ -1040,6 +1054,7 @@ mod test { enable_row_group_stats_pruning: true, coerce_int96: None, file_decryption_properties: None, + predicate_rewrite_hook: None, } }; diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index 5626f83186e31..816d15fcbc007 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -430,6 +430,8 @@ pub fn build_row_filter( .flatten() .collect(); + println!("Filter candidates: {}", candidates.len()); + // no candidates if candidates.is_empty() { return Ok(None); diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index 8ca36e7cd3216..d7dc8b9f3cf13 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -39,6 +39,7 @@ use datafusion_common::{DataFusionError, Statistics}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_physical_expr::conjunction; +use datafusion_physical_expr::schema_rewriter::PhysicalExprSchemaRewriteHook; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::filter_pushdown::{ @@ -278,6 +279,7 @@ pub struct ParquetSource { /// Optional hint for the size of the parquet metadata pub(crate) metadata_size_hint: Option, pub(crate) projected_statistics: Option, + pub(crate) predicate_rewrite_hook: Option>, } impl ParquetSource { @@ -316,6 +318,17 @@ impl ParquetSource { conf } + /// Register a predicate rewrite hook to transform predicates in the context of each file's physical file schema. + /// This can be used to optimize predicates to take advantage of shredded variant columns or pre-computed expressions + /// that vary on a per-file basis. 
+    pub fn with_predicate_rewrite_hook(
+        mut self,
+        predicate_rewrite_hook: Arc<dyn PhysicalExprSchemaRewriteHook>,
+    ) -> Self {
+        self.predicate_rewrite_hook = Some(predicate_rewrite_hook);
+        self
+    }
+
     /// Options passed to the parquet reader for this scan
     pub fn table_parquet_options(&self) -> &TableParquetOptions {
         &self.table_parquet_options
@@ -509,6 +522,7 @@ impl FileSource for ParquetSource {
             schema_adapter_factory,
             coerce_int96,
             file_decryption_properties,
+            predicate_rewrite_hook: self.predicate_rewrite_hook.clone(),
         })
     }

diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs
index b8759ea16d6e8..70873e78031d6 100644
--- a/datafusion/physical-expr/src/schema_rewriter.rs
+++ b/datafusion/physical-expr/src/schema_rewriter.rs
@@ -21,6 +21,7 @@ use std::sync::Arc;

 use arrow::compute::can_cast_types;
 use arrow::datatypes::{FieldRef, Schema};
+use datafusion_common::tree_node::TreeNodeRecursion;
 use datafusion_common::{
     exec_err,
     tree_node::{Transformed, TransformedResult, TreeNode},
@@ -29,6 +30,16 @@ use datafusion_common::{
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

 use crate::expressions::{self, CastExpr, Column};
+pub trait PhysicalExprSchemaRewriteHook: Send + Sync + std::fmt::Debug {
+    /// Rewrite a physical expression to match the target schema
+    ///
+    /// This method should return a transformed expression that matches the target schema.
+    fn rewrite(
+        &self,
+        expr: Arc<dyn PhysicalExpr>,
+        physical_file_schema: &Schema,
+    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>>;
+}

 /// Builder for rewriting physical expressions to match different schemas.
 ///
@@ -53,6 +64,7 @@ pub struct PhysicalExprSchemaRewriter<'a> {
     logical_file_schema: &'a Schema,
     partition_fields: Vec<FieldRef>,
     partition_values: Vec<ScalarValue>,
+    rewrite_hook: Option<Arc<dyn PhysicalExprSchemaRewriteHook>>,
 }

 impl<'a> PhysicalExprSchemaRewriter<'a> {
@@ -66,6 +78,7 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
             logical_file_schema,
             partition_fields: Vec::new(),
             partition_values: Vec::new(),
+            rewrite_hook: None,
         }
     }

@@ -83,6 +96,15 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
         self
     }

+    /// Add a hook to intercept expression rewrites
+    pub fn with_rewrite_hook(
+        mut self,
+        hook: Arc<dyn PhysicalExprSchemaRewriteHook>,
+    ) -> Self {
+        self.rewrite_hook = Some(hook);
+        self
+    }
+
     /// Rewrite the given physical expression to match the target schema
     ///
     /// This method applies the following transformations:
@@ -90,7 +112,27 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
     /// 2. Handles missing columns by inserting null literals
     /// 3. 
Casts columns when logical and physical schemas have different types pub fn rewrite(&self, expr: Arc) -> Result> { - expr.transform(|expr| self.rewrite_expr(expr)).data() + println!("Top level expression: {expr}"); + expr.transform(|expr| { + println!("Rewriting expression: {expr}"); + let transformed = if let Some(rewriter) = self.rewrite_hook.as_ref() { + // If a rewrite hook is provided, apply it first + let transformed = + rewriter.rewrite(expr.clone(), &self.physical_file_schema)?; + Ok(transformed) + // if transformed.tnr == TreeNodeRecursion::Stop { + // // If the hook indicates no further recursion, return the transformed expression + // return Ok(transformed); + // } else { + // transformed.transform_parent(|expr| self.rewrite_expr(expr)) + // } + } else { + // Otherwise, rewrite the expression directly + self.rewrite_expr(expr.clone()) + }?; + Ok(transformed) + }) + .data() } fn rewrite_expr( From 99f57cc00c4a218c8ff8281b28a6d1c5e9d4bcc7 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 26 Jun 2025 13:10:32 -0500 Subject: [PATCH 02/21] wip --- .../examples/variant_shredding.rs | 470 ++++++++++++++++++ 1 file changed, 470 insertions(+) create mode 100644 datafusion-examples/examples/variant_shredding.rs diff --git a/datafusion-examples/examples/variant_shredding.rs b/datafusion-examples/examples/variant_shredding.rs new file mode 100644 index 0000000000000..ddca86af1031c --- /dev/null +++ b/datafusion-examples/examples/variant_shredding.rs @@ -0,0 +1,470 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + as_string_array, ArrayRef, Int32Array, RecordBatch, StringArray, StructArray, +}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow_schema::Fields; +use async_trait::async_trait; + +use datafusion::assert_batches_eq; +use datafusion::catalog::memory::DataSourceExec; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter +}; +use datafusion::common::{assert_contains, DFSchema, Result}; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::parquet::source; +use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::functions::core::getfield::GetFieldFunc; +use datafusion::logical_expr::utils::conjunction; +use datafusion::logical_expr::{ + ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableProviderFilterPushDown, TableType, Volatility +}; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::parquet::file::properties::WriterProperties; +use datafusion::physical_expr::schema_rewriter::PhysicalExprSchemaRewriteHook; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::{lit, SessionConfig}; +use datafusion::scalar::ScalarValue; +use futures::StreamExt; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, PutPayload}; + +// Example showing how to implement custom filter rewriting for struct fields. +// +// In this example, we have a table with a struct column like: +// struct_col: {"a": 1, "b": "foo"} +// +// Our custom TableProvider will use a FilterExpressionRewriter to rewrite +// expressions like `struct_col['a'] = 10` to use a flattened column name +// `_struct_col.a` if it exists in the file schema. 
+#[tokio::main] +async fn main() -> Result<()> { + println!("=== Creating example data with structs and flattened fields ==="); + + // Create sample data with both struct columns and flattened fields + let (table_schema, batch) = create_sample_data(); + + let store = InMemory::new(); + let buf = { + let mut buf = vec![]; + + let props = WriterProperties::builder() + .set_max_row_group_size(1) + .build(); + + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props)) + .expect("creating writer"); + + writer.write(&batch).expect("Writing batch"); + writer.close().unwrap(); + buf + }; + let path = Path::from("example.parquet"); + let payload = PutPayload::from_bytes(buf.into()); + store.put(&path, payload).await?; + + // Create a custom table provider that rewrites struct field access + let table_provider = Arc::new(ExampleTableProvider::new(table_schema)); + + // Set up query execution + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + + // Register our table + ctx.register_table("structs", table_provider)?; + ctx.register_udf(ScalarUDF::new_from_impl(JsonGetStr::default())); + + ctx.runtime_env().register_object_store( + ObjectStoreUrl::parse("memory://")?.as_ref(), + Arc::new(store), + ); + + // println!("\n=== Showing all data ==="); + // let batches = ctx.sql("SELECT * FROM structs").await?.collect().await?; + // arrow::util::pretty::print_batches(&batches)?; + + println!("\n=== Running query with struct field access and filter < 30 ==="); + let query = "SELECT count(*) FROM structs WHERE json_get_str('name', user_info) = 'Bob'"; + println!("Query: {query}"); + + let batches = ctx + .sql(query) + .await? + .collect() + .await?; + + #[rustfmt::skip] + let expected = [ + "+-------------------------+", + "| structs.user_info[name] |", + "+-------------------------+", + "| Bob |", + "| Dave |", + "+-------------------------+", + ]; + arrow::util::pretty::print_batches(&batches)?; + println!("batches: {batches:?}"); + assert_batches_eq!(expected, &batches); + + // println!("\n=== Running explain analyze to confirm row group pruning ==="); + + // let batches = ctx + // .sql("EXPLAIN ANALYZE SELECT user_info['name'] FROM structs WHERE user_info['age'] < 30") + // .await? + // .collect() + // .await?; + // let plan = format!("{}", arrow::util::pretty::pretty_format_batches(&batches)?); + // println!("{plan}"); + // assert_contains!(&plan, "row_groups_pruned_statistics=2"); + + Ok(()) +} + +/// Create the example data that has a struct column with `name` as a shredded field and `age` as a non-shredded field. 
+fn create_sample_data() -> (SchemaRef, RecordBatch) { + // The table schema doesn't have any shredded fields + let struct_fields = Fields::from(vec![ + Field::new("data", DataType::Utf8, false), + ]); + let struct_field = Field::new("user_info", DataType::Struct(struct_fields), false); + let table_schema = Schema::new(vec![struct_field.clone()]); + // The file schema has `name` as a shredded field + let struct_fields = Fields::from(vec![ + Field::new("data", DataType::Utf8, false), + Field::new("name", DataType::Utf8, false), + ]); + let struct_field = Field::new("user_info", DataType::Struct(struct_fields), false); + let file_schema = Schema::new(vec![struct_field.clone()]); + // Build a RecordBatch with shredded data + let names = StringArray::from(vec!["Alice", "Bob", "Charlie", "Dave"]); + let user_info = StructArray::from(vec![ + ( + Arc::new(Field::new("data", DataType::Utf8, false)), + Arc::new(StringArray::from(vec![ + r#"{"age": 30}"#, + r#"{"age": 25}"#, + r#"{"age": 35}"#, + r#"{"age": 22}"#, + ])) as ArrayRef, + ), + ( + Arc::new(Field::new("name", DataType::Utf8, false)), + Arc::new(names.clone()) as ArrayRef, + ), + ]); + + ( + Arc::new(table_schema), + RecordBatch::try_new( + Arc::new(file_schema), + vec![Arc::new(user_info)], + ) + .unwrap(), + ) +} + +/// Custom TableProvider that uses a StructFieldRewriter +#[derive(Debug)] +struct ExampleTableProvider { + schema: SchemaRef, +} + +impl ExampleTableProvider { + fn new(schema: SchemaRef) -> Self { + Self { schema } + } +} + +#[async_trait] +impl TableProvider for ExampleTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + // Implementers can choose to mark these filters as exact or inexact. + // If marked as exact they cannot have false positives and must always be applied. + // If marked as Inexact they can have false positives and at runtime the rewriter + // can decide to not rewrite / ignore some filters since they will be re-evaluated upstream. + // For the purposes of this example we mark them as Exact to demonstrate the rewriter is working and the filtering is not being re-evaluated upstream. 
+ Ok(vec![TableProviderFilterPushDown::Exact; filters.len()]) + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let schema = self.schema.clone(); + let df_schema = DFSchema::try_from(schema.clone())?; + let filter = state.create_physical_expr( + conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)), + &df_schema, + )?; + + let parquet_source = ParquetSource::default() + .with_predicate(filter) + .with_pushdown_filters(true) + // if the rewriter needs a reference to the table schema you can bind self.schema() here + .with_predicate_rewrite_hook(Arc::new(ShreddedVariantRewriter) as _); + + let object_store_url = ObjectStoreUrl::parse("memory://")?; + + let store = state.runtime_env().object_store(object_store_url)?; + + let mut files = vec![]; + let mut listing = store.list(None); + while let Some(file) = listing.next().await { + if let Ok(file) = file { + files.push(file); + } + } + + let file_group = files + .iter() + .map(|file| PartitionedFile::new(file.location.clone(), file.size)) + .collect(); + + let file_scan_config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("memory://")?, + schema, + Arc::new(parquet_source), + ) + .with_projection(projection.cloned()) + .with_limit(limit) + .with_file_group(file_group); + + Ok(Arc::new(DataSourceExec::new(Arc::new( + file_scan_config.build(), + )))) + } +} + +/// Scalar UDF that uses serde_json to access json fields +#[derive(Debug)] +pub struct JsonGetStr { + signature: Signature, + aliases: [String; 1], +} + +impl Default for JsonGetStr { + fn default() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: ["json_get_str".to_string()], + } + } +} + +impl ScalarUDFImpl for JsonGetStr { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.aliases[0].as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert!( + args.args.len() == 2, + "json_get_str requires exactly 2 arguments" + ); + let key = match &args.args[0] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(key))) => key, + _ => { + return Err(datafusion::error::DataFusionError::Execution( + "json_get_str first argument must be a string".to_string(), + )) + } + }; + // We expect a struct array with a field called `data` that contains JSON strings + let struct_array = match &args.args[1] { + ColumnarValue::Array(array) => array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::error::DataFusionError::Execution( + "json_get_str second argument must be a struct array".to_string(), + ) + })?, + _ => { + return Err(datafusion::error::DataFusionError::Execution( + "json_get_str second argument must be a struct array".to_string(), + )) + } + }; + // Extract the "data" field from the struct array + let data_array = struct_array.column_by_name("data").ok_or_else(|| { + datafusion::error::DataFusionError::Execution( + "json_get_str second argument must have a 'data' field".to_string(), + ) + })?; + let json_array = match data_array.as_any().downcast_ref::() { + Some(array) => array, + None => { + return Err(datafusion::error::DataFusionError::Execution( + "json_get_str second argument 'data' field must be a StringArray" + .to_string(), + )) + } + }; + let values = json_array + .iter() + .map(|value| { + value + .map(|v| { + let 
json_value: serde_json::Value = + serde_json::from_str(&v).unwrap_or_default(); + json_value + .get(&key) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .flatten() + }) + .collect::(); + Ok(ColumnarValue::Array(Arc::new(values))) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Rewriter that converts struct field access to flattened column references +#[derive(Debug)] +struct ShreddedVariantRewriter; + +impl PhysicalExprSchemaRewriteHook for ShreddedVariantRewriter { + fn rewrite( + &self, + expr: Arc, + physical_file_schema: &Schema, + ) -> Result>> { + if let Some(func) = expr.as_any().downcast_ref::() { + if + func.name() == "json_get_str" && func.args().len() == 2 { + // Get the key from the first argument + if let Some(literal) = func.args()[0] + .as_any() + .downcast_ref::() + { + if let ScalarValue::Utf8(Some(field_name)) = literal.value() { + // Get the column from the second argument + if let Some(column) = func.args()[1] + .as_any() + .downcast_ref::() + { + let column_name = column.name(); + // Get the physical file schema's field + if let Ok(source_field_index) = + physical_file_schema.index_of(column_name) + { + let source_field = + physical_file_schema.field(source_field_index); + // If it's a struct field check if there is a shredded field with the name `field_name` + if let DataType::Struct(struct_fields) = + source_field.data_type() + { + if let Some((_, shredded_field)) = + struct_fields.find(field_name) + { + if shredded_field.data_type() == &DataType::Utf8 { + // Replace the whole expression with a struct field access on the shredded field + let args = vec![ + Arc::new(expressions::Column::new( + source_field.name(), + source_field_index, + )) + as Arc, + Arc::new(expressions::Literal::new( + ScalarValue::Utf8(Some( + field_name.clone(), + )), + )), + ]; + let return_field = Arc::new( + Field::new( + format!( + "{}.{}", + source_field.name(), + field_name + ), + DataType::Utf8, + true, + ), + ); + let new_expr = Arc::new( + ScalarFunctionExpr::new( + "get_field", + Arc::new(GetFieldFunc::new().into()), + args, + return_field, + ), + ); + return Ok(Transformed { + data: new_expr, + tnr: TreeNodeRecursion::Stop, + transformed: true, + }); + } + } + } + } + } + } + } + } + } + Ok(Transformed::no(expr)) + } +} From 2deb4b35ebfc9faffea73bb5e8180cf781c366ac Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 26 Jun 2025 13:28:28 -0500 Subject: [PATCH 03/21] working --- .../examples/variant_shredding.rs | 232 ++++++------------ .../datasource-parquet/src/row_filter.rs | 2 - .../physical-expr/src/schema_rewriter.rs | 24 +- 3 files changed, 88 insertions(+), 170 deletions(-) diff --git a/datafusion-examples/examples/variant_shredding.rs b/datafusion-examples/examples/variant_shredding.rs index ddca86af1031c..6d045ae9c5bef 100644 --- a/datafusion-examples/examples/variant_shredding.rs +++ b/datafusion-examples/examples/variant_shredding.rs @@ -18,29 +18,23 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ - as_string_array, ArrayRef, Int32Array, RecordBatch, StringArray, StructArray, -}; +use arrow::array::{RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow_schema::Fields; use async_trait::async_trait; use datafusion::assert_batches_eq; use datafusion::catalog::memory::DataSourceExec; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{ - Transformed, TransformedResult, TreeNode, 
TreeNodeRecursion, TreeNodeRewriter -}; +use datafusion::common::tree_node::{Transformed, TreeNodeRecursion}; use datafusion::common::{assert_contains, DFSchema, Result}; use datafusion::datasource::listing::PartitionedFile; -use datafusion::datasource::physical_plan::parquet::source; use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; use datafusion::execution::context::SessionContext; use datafusion::execution::object_store::ObjectStoreUrl; -use datafusion::functions::core::getfield::GetFieldFunc; use datafusion::logical_expr::utils::conjunction; use datafusion::logical_expr::{ - ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TableProviderFilterPushDown, TableType, Volatility + ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + TableProviderFilterPushDown, TableType, Volatility, }; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; @@ -55,19 +49,19 @@ use object_store::memory::InMemory; use object_store::path::Path; use object_store::{ObjectStore, PutPayload}; -// Example showing how to implement custom filter rewriting for struct fields. +// Example showing how to implement custom filter rewriting for variant shredding. // -// In this example, we have a table with a struct column like: -// struct_col: {"a": 1, "b": "foo"} +// In this example, we have a table with flat columns using underscore prefixes: +// data: "...", _data.name: "..." // // Our custom TableProvider will use a FilterExpressionRewriter to rewrite -// expressions like `struct_col['a'] = 10` to use a flattened column name -// `_struct_col.a` if it exists in the file schema. +// expressions like `json_get_str('name', data)` to use a flattened column name +// `_data.name` if it exists in the file schema. #[tokio::main] async fn main() -> Result<()> { - println!("=== Creating example data with structs and flattened fields ==="); + println!("=== Creating example data with flat columns and underscore prefixes ==="); - // Create sample data with both struct columns and flattened fields + // Create sample data with flat columns using underscore prefixes let (table_schema, batch) = create_sample_data(); let store = InMemory::new(); @@ -75,7 +69,7 @@ async fn main() -> Result<()> { let mut buf = vec![]; let props = WriterProperties::builder() - .set_max_row_group_size(1) + .set_max_row_group_size(2) .build(); let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props)) @@ -106,85 +100,68 @@ async fn main() -> Result<()> { Arc::new(store), ); - // println!("\n=== Showing all data ==="); - // let batches = ctx.sql("SELECT * FROM structs").await?.collect().await?; - // arrow::util::pretty::print_batches(&batches)?; + println!("\n=== Showing all data ==="); + let batches = ctx.sql("SELECT * FROM structs").await?.collect().await?; + arrow::util::pretty::print_batches(&batches)?; - println!("\n=== Running query with struct field access and filter < 30 ==="); - let query = "SELECT count(*) FROM structs WHERE json_get_str('name', user_info) = 'Bob'"; + println!("\n=== Running query with flat column access and filter ==="); + let query = "SELECT json_get_str('age', data) as age FROM structs WHERE json_get_str('name', data) = 'Bob'"; println!("Query: {query}"); - let batches = ctx - .sql(query) - .await? 
- .collect() - .await?; + let batches = ctx.sql(query).await?.collect().await?; #[rustfmt::skip] let expected = [ - "+-------------------------+", - "| structs.user_info[name] |", - "+-------------------------+", - "| Bob |", - "| Dave |", - "+-------------------------+", + "+-----+", + "| age |", + "+-----+", + "| 25 |", + "+-----+", ]; arrow::util::pretty::print_batches(&batches)?; - println!("batches: {batches:?}"); assert_batches_eq!(expected, &batches); - // println!("\n=== Running explain analyze to confirm row group pruning ==="); + println!("\n=== Running explain analyze to confirm row group pruning ==="); - // let batches = ctx - // .sql("EXPLAIN ANALYZE SELECT user_info['name'] FROM structs WHERE user_info['age'] < 30") - // .await? - // .collect() - // .await?; - // let plan = format!("{}", arrow::util::pretty::pretty_format_batches(&batches)?); - // println!("{plan}"); - // assert_contains!(&plan, "row_groups_pruned_statistics=2"); + let batches = ctx + .sql(&format!("EXPLAIN ANALYZE {query}")) + .await? + .collect() + .await?; + let plan = format!("{}", arrow::util::pretty::pretty_format_batches(&batches)?); + println!("{plan}"); + assert_contains!(&plan, "row_groups_pruned_statistics=1"); + assert_contains!(&plan, "pushdown_rows_pruned=1"); Ok(()) } -/// Create the example data that has a struct column with `name` as a shredded field and `age` as a non-shredded field. +/// Create the example data with flat columns using underscore prefixes. +/// The table schema has `data` column, while the file schema has both `data` and `_data.name` as flat columns. fn create_sample_data() -> (SchemaRef, RecordBatch) { - // The table schema doesn't have any shredded fields - let struct_fields = Fields::from(vec![ - Field::new("data", DataType::Utf8, false), - ]); - let struct_field = Field::new("user_info", DataType::Struct(struct_fields), false); - let table_schema = Schema::new(vec![struct_field.clone()]); - // The file schema has `name` as a shredded field - let struct_fields = Fields::from(vec![ + // The table schema only has the main data column + let table_schema = Schema::new(vec![Field::new("data", DataType::Utf8, false)]); + + // The file schema has both the main column and the shredded flat column with underscore prefix + let file_schema = Schema::new(vec![ Field::new("data", DataType::Utf8, false), - Field::new("name", DataType::Utf8, false), + Field::new("_data.name", DataType::Utf8, false), ]); - let struct_field = Field::new("user_info", DataType::Struct(struct_fields), false); - let file_schema = Schema::new(vec![struct_field.clone()]); - // Build a RecordBatch with shredded data - let names = StringArray::from(vec!["Alice", "Bob", "Charlie", "Dave"]); - let user_info = StructArray::from(vec![ - ( - Arc::new(Field::new("data", DataType::Utf8, false)), - Arc::new(StringArray::from(vec![ - r#"{"age": 30}"#, - r#"{"age": 25}"#, - r#"{"age": 35}"#, - r#"{"age": 22}"#, - ])) as ArrayRef, - ), - ( - Arc::new(Field::new("name", DataType::Utf8, false)), - Arc::new(names.clone()) as ArrayRef, - ), + + // Build a RecordBatch with flat columns + let data_array = StringArray::from(vec![ + r#"{"age": 30}"#, + r#"{"age": 25}"#, + r#"{"age": 35}"#, + r#"{"age": 22}"#, ]); + let names_array = StringArray::from(vec!["Alice", "Bob", "Charlie", "Dave"]); ( Arc::new(table_schema), RecordBatch::try_new( Arc::new(file_schema), - vec![Arc::new(user_info)], + vec![Arc::new(data_array), Arc::new(names_array)], ) .unwrap(), ) @@ -326,34 +303,19 @@ impl ScalarUDFImpl for JsonGetStr { )) } }; - // We 
expect a struct array with a field called `data` that contains JSON strings - let struct_array = match &args.args[1] { + // We expect a string array that contains JSON strings + let json_array = match &args.args[1] { ColumnarValue::Array(array) => array .as_any() - .downcast_ref::() + .downcast_ref::() .ok_or_else(|| { datafusion::error::DataFusionError::Execution( - "json_get_str second argument must be a struct array".to_string(), + "json_get_str second argument must be a string array".to_string(), ) })?, _ => { return Err(datafusion::error::DataFusionError::Execution( - "json_get_str second argument must be a struct array".to_string(), - )) - } - }; - // Extract the "data" field from the struct array - let data_array = struct_array.column_by_name("data").ok_or_else(|| { - datafusion::error::DataFusionError::Execution( - "json_get_str second argument must have a 'data' field".to_string(), - ) - })?; - let json_array = match data_array.as_any().downcast_ref::() { - Some(array) => array, - None => { - return Err(datafusion::error::DataFusionError::Execution( - "json_get_str second argument 'data' field must be a StringArray" - .to_string(), + "json_get_str second argument must be a string array".to_string(), )) } }; @@ -364,10 +326,7 @@ impl ScalarUDFImpl for JsonGetStr { .map(|v| { let json_value: serde_json::Value = serde_json::from_str(&v).unwrap_or_default(); - json_value - .get(&key) - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) + json_value.get(&key).map(|v| v.to_string()) }) .flatten() }) @@ -380,7 +339,7 @@ impl ScalarUDFImpl for JsonGetStr { } } -/// Rewriter that converts struct field access to flattened column references +/// Rewriter that converts json_get_str calls to direct flat column references #[derive(Debug)] struct ShreddedVariantRewriter; @@ -391,8 +350,7 @@ impl PhysicalExprSchemaRewriteHook for ShreddedVariantRewriter { physical_file_schema: &Schema, ) -> Result>> { if let Some(func) = expr.as_any().downcast_ref::() { - if - func.name() == "json_get_str" && func.args().len() == 2 { + if func.name() == "json_get_str" && func.args().len() == 2 { // Get the key from the first argument if let Some(literal) = func.args()[0] .as_any() @@ -405,59 +363,29 @@ impl PhysicalExprSchemaRewriteHook for ShreddedVariantRewriter { .downcast_ref::() { let column_name = column.name(); - // Get the physical file schema's field - if let Ok(source_field_index) = - physical_file_schema.index_of(column_name) + // Check if there's a flat column with underscore prefix + let flat_column_name = + format!("_{}.{}", column_name, field_name); + + if let Ok(flat_field_index) = + physical_file_schema.index_of(&flat_column_name) { - let source_field = - physical_file_schema.field(source_field_index); - // If it's a struct field check if there is a shredded field with the name `field_name` - if let DataType::Struct(struct_fields) = - source_field.data_type() - { - if let Some((_, shredded_field)) = - struct_fields.find(field_name) - { - if shredded_field.data_type() == &DataType::Utf8 { - // Replace the whole expression with a struct field access on the shredded field - let args = vec![ - Arc::new(expressions::Column::new( - source_field.name(), - source_field_index, - )) - as Arc, - Arc::new(expressions::Literal::new( - ScalarValue::Utf8(Some( - field_name.clone(), - )), - )), - ]; - let return_field = Arc::new( - Field::new( - format!( - "{}.{}", - source_field.name(), - field_name - ), - DataType::Utf8, - true, - ), - ); - let new_expr = Arc::new( - ScalarFunctionExpr::new( - "get_field", - 
Arc::new(GetFieldFunc::new().into()), - args, - return_field, - ), - ); - return Ok(Transformed { - data: new_expr, - tnr: TreeNodeRecursion::Stop, - transformed: true, - }); - } - } + let flat_field = + physical_file_schema.field(flat_field_index); + + if flat_field.data_type() == &DataType::Utf8 { + // Replace the whole expression with a direct column reference + let new_expr = Arc::new(expressions::Column::new( + &flat_column_name, + flat_field_index, + )) + as Arc; + + return Ok(Transformed { + data: new_expr, + tnr: TreeNodeRecursion::Stop, + transformed: true, + }); } } } diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index 816d15fcbc007..5626f83186e31 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -430,8 +430,6 @@ pub fn build_row_filter( .flatten() .collect(); - println!("Filter candidates: {}", candidates.len()); - // no candidates if candidates.is_empty() { return Ok(None); diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs index 70873e78031d6..b9f57fe7afb2e 100644 --- a/datafusion/physical-expr/src/schema_rewriter.rs +++ b/datafusion/physical-expr/src/schema_rewriter.rs @@ -21,7 +21,6 @@ use std::sync::Arc; use arrow::compute::can_cast_types; use arrow::datatypes::{FieldRef, Schema}; -use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::{ exec_err, tree_node::{Transformed, TransformedResult, TreeNode}, @@ -114,23 +113,16 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { pub fn rewrite(&self, expr: Arc) -> Result> { println!("Top level expression: {expr}"); expr.transform(|expr| { - println!("Rewriting expression: {expr}"); - let transformed = if let Some(rewriter) = self.rewrite_hook.as_ref() { + if let Some(rewriter) = self.rewrite_hook.as_ref() { // If a rewrite hook is provided, apply it first let transformed = - rewriter.rewrite(expr.clone(), &self.physical_file_schema)?; - Ok(transformed) - // if transformed.tnr == TreeNodeRecursion::Stop { - // // If the hook indicates no further recursion, return the transformed expression - // return Ok(transformed); - // } else { - // transformed.transform_parent(|expr| self.rewrite_expr(expr)) - // } - } else { - // Otherwise, rewrite the expression directly - self.rewrite_expr(expr.clone()) - }?; - Ok(transformed) + rewriter.rewrite(Arc::clone(&expr), &self.physical_file_schema)?; + if transformed.transformed { + // If the hook transformed the expression, return it + return Ok(transformed); + } + } + self.rewrite_expr(Arc::clone(&expr)) }) .data() } From 8c7a6fb49714fbd3313c054bb824cd87f9279dd9 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 26 Jun 2025 13:34:09 -0500 Subject: [PATCH 04/21] Update datafusion/datasource-parquet/src/opener.rs --- datafusion/datasource-parquet/src/opener.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index f48ebb71f747a..765bcd3fea9ca 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -267,8 +267,6 @@ impl FileOpener for ParquetOpener { .transpose()? 
.transpose()?; - println!("predicate: {predicate:?}"); - // Build predicates for this specific file let (pruning_predicate, page_pruning_predicate) = build_pruning_predicates( predicate.as_ref(), From 14772764fba68c4c8dd1c15ab04734b73c17d02a Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jun 2025 08:21:33 -0500 Subject: [PATCH 05/21] lint --- datafusion-examples/examples/variant_shredding.rs | 15 ++++++--------- datafusion/physical-expr/src/schema_rewriter.rs | 2 +- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/datafusion-examples/examples/variant_shredding.rs b/datafusion-examples/examples/variant_shredding.rs index 6d045ae9c5bef..aeca332bb269f 100644 --- a/datafusion-examples/examples/variant_shredding.rs +++ b/datafusion-examples/examples/variant_shredding.rs @@ -322,13 +322,11 @@ impl ScalarUDFImpl for JsonGetStr { let values = json_array .iter() .map(|value| { - value - .map(|v| { - let json_value: serde_json::Value = - serde_json::from_str(&v).unwrap_or_default(); - json_value.get(&key).map(|v| v.to_string()) - }) - .flatten() + value.and_then(|v| { + let json_value: serde_json::Value = + serde_json::from_str(v).unwrap_or_default(); + json_value.get(key).map(|v| v.to_string()) + }) }) .collect::(); Ok(ColumnarValue::Array(Arc::new(values))) @@ -364,8 +362,7 @@ impl PhysicalExprSchemaRewriteHook for ShreddedVariantRewriter { { let column_name = column.name(); // Check if there's a flat column with underscore prefix - let flat_column_name = - format!("_{}.{}", column_name, field_name); + let flat_column_name = format!("_{column_name}.{field_name}"); if let Ok(flat_field_index) = physical_file_schema.index_of(&flat_column_name) diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs index b9f57fe7afb2e..1b093dbbe7410 100644 --- a/datafusion/physical-expr/src/schema_rewriter.rs +++ b/datafusion/physical-expr/src/schema_rewriter.rs @@ -116,7 +116,7 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { if let Some(rewriter) = self.rewrite_hook.as_ref() { // If a rewrite hook is provided, apply it first let transformed = - rewriter.rewrite(Arc::clone(&expr), &self.physical_file_schema)?; + rewriter.rewrite(Arc::clone(&expr), self.physical_file_schema)?; if transformed.transformed { // If the hook transformed the expression, return it return Ok(transformed); From f4c37bd364eb04d4b5d51686a8d201be572fb461 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 3 Jul 2025 17:47:03 -0500 Subject: [PATCH 06/21] decouple --- .../examples/variant_shredding.rs | 15 +++++++++- datafusion/datasource-parquet/src/opener.rs | 24 +++++++++++---- .../physical-expr/src/schema_rewriter.rs | 29 ++----------------- 3 files changed, 35 insertions(+), 33 deletions(-) diff --git a/datafusion-examples/examples/variant_shredding.rs b/datafusion-examples/examples/variant_shredding.rs index aeca332bb269f..f491200015b44 100644 --- a/datafusion-examples/examples/variant_shredding.rs +++ b/datafusion-examples/examples/variant_shredding.rs @@ -25,7 +25,9 @@ use async_trait::async_trait; use datafusion::assert_batches_eq; use datafusion::catalog::memory::DataSourceExec; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{Transformed, TreeNodeRecursion}; +use datafusion::common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion::common::{assert_contains, DFSchema, 
Result}; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; @@ -346,6 +348,17 @@ impl PhysicalExprSchemaRewriteHook for ShreddedVariantRewriter { &self, expr: Arc, physical_file_schema: &Schema, + ) -> Result> { + expr.transform(|expr| self.rewrite_impl(expr, physical_file_schema)) + .data() + } +} + +impl ShreddedVariantRewriter { + fn rewrite_impl( + &self, + expr: Arc, + physical_file_schema: &Schema, ) -> Result>> { if let Some(func) = expr.as_any().downcast_ref::() { if func.name() == "json_get_str" && func.args().len() == 2 { diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 765bcd3fea9ca..9fc4d6cf5169c 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -242,7 +242,7 @@ impl FileOpener for ParquetOpener { // This evaluates missing columns and inserts any necessary casts. let predicate = predicate .map(|p| { - let mut rewriter = PhysicalExprSchemaRewriter::new( + let rewriter = PhysicalExprSchemaRewriter::new( &physical_file_schema, &logical_file_schema, ) @@ -250,11 +250,6 @@ impl FileOpener for ParquetOpener { partition_fields.to_vec(), file.partition_values, ); - if let Some(predicate_rewrite_hook) = predicate_rewrite_hook.as_ref() - { - rewriter = rewriter - .with_rewrite_hook(Arc::clone(predicate_rewrite_hook)); - }; rewriter.rewrite(p).map_err(ArrowError::from).map(|p| { // After rewriting to the file schema, further simplifications may be possible. // For example, if `'a' = col_that_is_missing` becomes `'a' = NULL` that can then be simplified to `FALSE` @@ -267,6 +262,23 @@ impl FileOpener for ParquetOpener { .transpose()? .transpose()?; + // if let (Some(predicate_rewrite_hook), Some(predicate)) = (predicate_rewrite_hook.as_ref(), predicate.as_ref()) + // { + // predicate = predicate_rewrite_hook.rewrite(Arc::clone(&predicate), physical_file_schema)?; + // }; + + let predicate = predicate + .map(|p| { + if let Some(predicate_rewrite_hook) = predicate_rewrite_hook.as_ref() + { + predicate_rewrite_hook + .rewrite(Arc::clone(&p), &physical_file_schema) + } else { + Ok(p) + } + }) + .transpose()?; + // Build predicates for this specific file let (pruning_predicate, page_pruning_predicate) = build_pruning_predicates( predicate.as_ref(), diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs index 1b093dbbe7410..03de20d253634 100644 --- a/datafusion/physical-expr/src/schema_rewriter.rs +++ b/datafusion/physical-expr/src/schema_rewriter.rs @@ -37,7 +37,7 @@ pub trait PhysicalExprSchemaRewriteHook: Send + Sync + std::fmt::Debug { &self, expr: Arc, physical_file_schema: &Schema, - ) -> Result>>; + ) -> Result>; } /// Builder for rewriting physical expressions to match different schemas. 
@@ -63,7 +63,6 @@ pub struct PhysicalExprSchemaRewriter<'a> { logical_file_schema: &'a Schema, partition_fields: Vec, partition_values: Vec, - rewrite_hook: Option>, } impl<'a> PhysicalExprSchemaRewriter<'a> { @@ -77,7 +76,6 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { logical_file_schema, partition_fields: Vec::new(), partition_values: Vec::new(), - rewrite_hook: None, } } @@ -95,15 +93,6 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { self } - /// Add a hook to intercept expression rewrites - pub fn with_rewrite_hook( - mut self, - hook: Arc, - ) -> Self { - self.rewrite_hook = Some(hook); - self - } - /// Rewrite the given physical expression to match the target schema /// /// This method applies the following transformations: @@ -111,20 +100,8 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { /// 2. Handles missing columns by inserting null literals /// 3. Casts columns when logical and physical schemas have different types pub fn rewrite(&self, expr: Arc) -> Result> { - println!("Top level expression: {expr}"); - expr.transform(|expr| { - if let Some(rewriter) = self.rewrite_hook.as_ref() { - // If a rewrite hook is provided, apply it first - let transformed = - rewriter.rewrite(Arc::clone(&expr), self.physical_file_schema)?; - if transformed.transformed { - // If the hook transformed the expression, return it - return Ok(transformed); - } - } - self.rewrite_expr(Arc::clone(&expr)) - }) - .data() + expr.transform(|expr| self.rewrite_expr(Arc::clone(&expr))) + .data() } fn rewrite_expr( From b14ec7572031c10d68792a934bd6c268ad8a57ff Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 3 Jul 2025 18:42:29 -0500 Subject: [PATCH 07/21] remove commented out code, flip order --- datafusion/datasource-parquet/src/opener.rs | 33 +++++++++------------ 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 9fc4d6cf5169c..9465e8c361726 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -122,7 +122,7 @@ impl FileOpener for ParquetOpener { let schema_adapter = self .schema_adapter_factory .create(projected_schema, Arc::clone(&self.logical_file_schema)); - let predicate = self.predicate.clone(); + let mut predicate = self.predicate.clone(); let logical_file_schema = Arc::clone(&self.logical_file_schema); let partition_fields = self.partition_fields.clone(); let reorder_predicates = self.reorder_filters; @@ -238,9 +238,21 @@ impl FileOpener for ParquetOpener { } } + predicate = predicate + .map(|p| { + if let Some(predicate_rewrite_hook) = predicate_rewrite_hook.as_ref() + { + predicate_rewrite_hook + .rewrite(Arc::clone(&p), &physical_file_schema) + } else { + Ok(p) + } + }) + .transpose()?; + // Adapt the predicate to the physical file schema. // This evaluates missing columns and inserts any necessary casts. - let predicate = predicate + predicate = predicate .map(|p| { let rewriter = PhysicalExprSchemaRewriter::new( &physical_file_schema, @@ -262,23 +274,6 @@ impl FileOpener for ParquetOpener { .transpose()? 
        .transpose()?;

-            // if let (Some(predicate_rewrite_hook), Some(predicate)) = (predicate_rewrite_hook.as_ref(), predicate.as_ref())
-            // {
-            //     predicate = predicate_rewrite_hook.rewrite(Arc::clone(&predicate), physical_file_schema)?;
-            // };
-
-            let predicate = predicate
-                .map(|p| {
-                    if let Some(predicate_rewrite_hook) = predicate_rewrite_hook.as_ref()
-                    {
-                        predicate_rewrite_hook
-                            .rewrite(Arc::clone(&p), &physical_file_schema)
-                    } else {
-                        Ok(p)
-                    }
-                })
-                .transpose()?;
-
             // Build predicates for this specific file
             let (pruning_predicate, page_pruning_predicate) = build_pruning_predicates(
                 predicate.as_ref(),

From b14ec7572031c10d68792a934bd6c268ad8a57ff Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Thu, 3 Jul 2025 21:01:07 -0500
Subject: [PATCH 08/21] handle edge case with rewrite

---
 .../physical-expr/src/schema_rewriter.rs | 22 ++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs
index 03de20d253634..a1836771b03de 100644
--- a/datafusion/physical-expr/src/schema_rewriter.rs
+++ b/datafusion/physical-expr/src/schema_rewriter.rs
@@ -120,7 +120,7 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
         expr: Arc<dyn PhysicalExpr>,
         column: &Column,
     ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
-        // Get the logical field for this column
+        // Get the logical field for this column if it exists in the logical schema
        let logical_field = match self.logical_file_schema.field_with_name(column.name())
        {
            Ok(field) => field,
@@ -129,10 +129,22 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
                if let Some(partition_value) = self.get_partition_value(column.name()) {
                    return Ok(Transformed::yes(expressions::lit(partition_value)));
                }
-                // If the column is not found in the logical schema and is not a partition value, return an error
-                // This should probably never be hit unless something upstream broke, but nontheless it's better
-                // for us to return a handleable error than to panic / do something unexpected.
-                return Err(e.into());
+                // This can be hit if a custom rewrite injected a reference to a column that doesn't exist in the logical schema.
+                // For example, a pre-computed column that is kept only in the physical schema.
+                // If the column exists in the physical schema, we can still use it.
+                if let Ok(physical_field) =
+                    self.physical_file_schema.field_with_name(column.name())
+                {
+                    // If the column exists in the physical schema, we can use it in place of the logical column.
+                    // This is nice to users because if they do a rewrite that results in something like `physical_int32_col = 123u64`
+                    // we'll at least handle the casts for them.
+                    physical_field
+                } else {
+                    // A completely unknown column that doesn't exist in either schema!
+                    // This should probably never be hit unless something upstream broke, but nonetheless it's better
+                    // for us to return a handleable error than to panic / do something unexpected.
+ return Err(e.into()); + } } }; From 6ef51fcbc99c5dbd0940375f2c30c1ccc5086ade Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 9 Jul 2025 23:35:11 -0500 Subject: [PATCH 09/21] address pr feedback --- .../examples/variant_shredding.rs | 408 ------------------ .../core/tests/parquet/filter_pushdown.rs | 251 +++++++++++ datafusion/datasource-parquet/src/opener.rs | 4 +- datafusion/datasource-parquet/src/source.rs | 6 +- .../physical-expr/src/schema_rewriter.rs | 2 +- 5 files changed, 257 insertions(+), 414 deletions(-) delete mode 100644 datafusion-examples/examples/variant_shredding.rs diff --git a/datafusion-examples/examples/variant_shredding.rs b/datafusion-examples/examples/variant_shredding.rs deleted file mode 100644 index f491200015b44..0000000000000 --- a/datafusion-examples/examples/variant_shredding.rs +++ /dev/null @@ -1,408 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::any::Any; -use std::sync::Arc; - -use arrow::array::{RecordBatch, StringArray}; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use async_trait::async_trait; - -use datafusion::assert_batches_eq; -use datafusion::catalog::memory::DataSourceExec; -use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, -}; -use datafusion::common::{assert_contains, DFSchema, Result}; -use datafusion::datasource::listing::PartitionedFile; -use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; -use datafusion::execution::context::SessionContext; -use datafusion::execution::object_store::ObjectStoreUrl; -use datafusion::logical_expr::utils::conjunction; -use datafusion::logical_expr::{ - ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, - TableProviderFilterPushDown, TableType, Volatility, -}; -use datafusion::parquet::arrow::ArrowWriter; -use datafusion::parquet::file::properties::WriterProperties; -use datafusion::physical_expr::schema_rewriter::PhysicalExprSchemaRewriteHook; -use datafusion::physical_expr::PhysicalExpr; -use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; -use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::{lit, SessionConfig}; -use datafusion::scalar::ScalarValue; -use futures::StreamExt; -use object_store::memory::InMemory; -use object_store::path::Path; -use object_store::{ObjectStore, PutPayload}; - -// Example showing how to implement custom filter rewriting for variant shredding. -// -// In this example, we have a table with flat columns using underscore prefixes: -// data: "...", _data.name: "..." 
-// -// Our custom TableProvider will use a FilterExpressionRewriter to rewrite -// expressions like `json_get_str('name', data)` to use a flattened column name -// `_data.name` if it exists in the file schema. -#[tokio::main] -async fn main() -> Result<()> { - println!("=== Creating example data with flat columns and underscore prefixes ==="); - - // Create sample data with flat columns using underscore prefixes - let (table_schema, batch) = create_sample_data(); - - let store = InMemory::new(); - let buf = { - let mut buf = vec![]; - - let props = WriterProperties::builder() - .set_max_row_group_size(2) - .build(); - - let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props)) - .expect("creating writer"); - - writer.write(&batch).expect("Writing batch"); - writer.close().unwrap(); - buf - }; - let path = Path::from("example.parquet"); - let payload = PutPayload::from_bytes(buf.into()); - store.put(&path, payload).await?; - - // Create a custom table provider that rewrites struct field access - let table_provider = Arc::new(ExampleTableProvider::new(table_schema)); - - // Set up query execution - let mut cfg = SessionConfig::new(); - cfg.options_mut().execution.parquet.pushdown_filters = true; - let ctx = SessionContext::new_with_config(cfg); - - // Register our table - ctx.register_table("structs", table_provider)?; - ctx.register_udf(ScalarUDF::new_from_impl(JsonGetStr::default())); - - ctx.runtime_env().register_object_store( - ObjectStoreUrl::parse("memory://")?.as_ref(), - Arc::new(store), - ); - - println!("\n=== Showing all data ==="); - let batches = ctx.sql("SELECT * FROM structs").await?.collect().await?; - arrow::util::pretty::print_batches(&batches)?; - - println!("\n=== Running query with flat column access and filter ==="); - let query = "SELECT json_get_str('age', data) as age FROM structs WHERE json_get_str('name', data) = 'Bob'"; - println!("Query: {query}"); - - let batches = ctx.sql(query).await?.collect().await?; - - #[rustfmt::skip] - let expected = [ - "+-----+", - "| age |", - "+-----+", - "| 25 |", - "+-----+", - ]; - arrow::util::pretty::print_batches(&batches)?; - assert_batches_eq!(expected, &batches); - - println!("\n=== Running explain analyze to confirm row group pruning ==="); - - let batches = ctx - .sql(&format!("EXPLAIN ANALYZE {query}")) - .await? - .collect() - .await?; - let plan = format!("{}", arrow::util::pretty::pretty_format_batches(&batches)?); - println!("{plan}"); - assert_contains!(&plan, "row_groups_pruned_statistics=1"); - assert_contains!(&plan, "pushdown_rows_pruned=1"); - - Ok(()) -} - -/// Create the example data with flat columns using underscore prefixes. -/// The table schema has `data` column, while the file schema has both `data` and `_data.name` as flat columns. 
-fn create_sample_data() -> (SchemaRef, RecordBatch) { - // The table schema only has the main data column - let table_schema = Schema::new(vec![Field::new("data", DataType::Utf8, false)]); - - // The file schema has both the main column and the shredded flat column with underscore prefix - let file_schema = Schema::new(vec![ - Field::new("data", DataType::Utf8, false), - Field::new("_data.name", DataType::Utf8, false), - ]); - - // Build a RecordBatch with flat columns - let data_array = StringArray::from(vec![ - r#"{"age": 30}"#, - r#"{"age": 25}"#, - r#"{"age": 35}"#, - r#"{"age": 22}"#, - ]); - let names_array = StringArray::from(vec!["Alice", "Bob", "Charlie", "Dave"]); - - ( - Arc::new(table_schema), - RecordBatch::try_new( - Arc::new(file_schema), - vec![Arc::new(data_array), Arc::new(names_array)], - ) - .unwrap(), - ) -} - -/// Custom TableProvider that uses a StructFieldRewriter -#[derive(Debug)] -struct ExampleTableProvider { - schema: SchemaRef, -} - -impl ExampleTableProvider { - fn new(schema: SchemaRef) -> Self { - Self { schema } - } -} - -#[async_trait] -impl TableProvider for ExampleTableProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn table_type(&self) -> TableType { - TableType::Base - } - - fn supports_filters_pushdown( - &self, - filters: &[&Expr], - ) -> Result> { - // Implementers can choose to mark these filters as exact or inexact. - // If marked as exact they cannot have false positives and must always be applied. - // If marked as Inexact they can have false positives and at runtime the rewriter - // can decide to not rewrite / ignore some filters since they will be re-evaluated upstream. - // For the purposes of this example we mark them as Exact to demonstrate the rewriter is working and the filtering is not being re-evaluated upstream. 
- Ok(vec![TableProviderFilterPushDown::Exact; filters.len()]) - } - - async fn scan( - &self, - state: &dyn Session, - projection: Option<&Vec>, - filters: &[Expr], - limit: Option, - ) -> Result> { - let schema = self.schema.clone(); - let df_schema = DFSchema::try_from(schema.clone())?; - let filter = state.create_physical_expr( - conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)), - &df_schema, - )?; - - let parquet_source = ParquetSource::default() - .with_predicate(filter) - .with_pushdown_filters(true) - // if the rewriter needs a reference to the table schema you can bind self.schema() here - .with_predicate_rewrite_hook(Arc::new(ShreddedVariantRewriter) as _); - - let object_store_url = ObjectStoreUrl::parse("memory://")?; - - let store = state.runtime_env().object_store(object_store_url)?; - - let mut files = vec![]; - let mut listing = store.list(None); - while let Some(file) = listing.next().await { - if let Ok(file) = file { - files.push(file); - } - } - - let file_group = files - .iter() - .map(|file| PartitionedFile::new(file.location.clone(), file.size)) - .collect(); - - let file_scan_config = FileScanConfigBuilder::new( - ObjectStoreUrl::parse("memory://")?, - schema, - Arc::new(parquet_source), - ) - .with_projection(projection.cloned()) - .with_limit(limit) - .with_file_group(file_group); - - Ok(Arc::new(DataSourceExec::new(Arc::new( - file_scan_config.build(), - )))) - } -} - -/// Scalar UDF that uses serde_json to access json fields -#[derive(Debug)] -pub struct JsonGetStr { - signature: Signature, - aliases: [String; 1], -} - -impl Default for JsonGetStr { - fn default() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: ["json_get_str".to_string()], - } - } -} - -impl ScalarUDFImpl for JsonGetStr { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - self.aliases[0].as_str() - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Utf8) - } - - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - assert!( - args.args.len() == 2, - "json_get_str requires exactly 2 arguments" - ); - let key = match &args.args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(key))) => key, - _ => { - return Err(datafusion::error::DataFusionError::Execution( - "json_get_str first argument must be a string".to_string(), - )) - } - }; - // We expect a string array that contains JSON strings - let json_array = match &args.args[1] { - ColumnarValue::Array(array) => array - .as_any() - .downcast_ref::() - .ok_or_else(|| { - datafusion::error::DataFusionError::Execution( - "json_get_str second argument must be a string array".to_string(), - ) - })?, - _ => { - return Err(datafusion::error::DataFusionError::Execution( - "json_get_str second argument must be a string array".to_string(), - )) - } - }; - let values = json_array - .iter() - .map(|value| { - value.and_then(|v| { - let json_value: serde_json::Value = - serde_json::from_str(v).unwrap_or_default(); - json_value.get(key).map(|v| v.to_string()) - }) - }) - .collect::(); - Ok(ColumnarValue::Array(Arc::new(values))) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -/// Rewriter that converts json_get_str calls to direct flat column references -#[derive(Debug)] -struct ShreddedVariantRewriter; - -impl PhysicalExprSchemaRewriteHook for ShreddedVariantRewriter { - fn rewrite( - &self, - expr: Arc, - physical_file_schema: &Schema, - ) -> Result> { - 
expr.transform(|expr| self.rewrite_impl(expr, physical_file_schema)) - .data() - } -} - -impl ShreddedVariantRewriter { - fn rewrite_impl( - &self, - expr: Arc, - physical_file_schema: &Schema, - ) -> Result>> { - if let Some(func) = expr.as_any().downcast_ref::() { - if func.name() == "json_get_str" && func.args().len() == 2 { - // Get the key from the first argument - if let Some(literal) = func.args()[0] - .as_any() - .downcast_ref::() - { - if let ScalarValue::Utf8(Some(field_name)) = literal.value() { - // Get the column from the second argument - if let Some(column) = func.args()[1] - .as_any() - .downcast_ref::() - { - let column_name = column.name(); - // Check if there's a flat column with underscore prefix - let flat_column_name = format!("_{column_name}.{field_name}"); - - if let Ok(flat_field_index) = - physical_file_schema.index_of(&flat_column_name) - { - let flat_field = - physical_file_schema.field(flat_field_index); - - if flat_field.data_type() == &DataType::Utf8 { - // Replace the whole expression with a direct column reference - let new_expr = Arc::new(expressions::Column::new( - &flat_column_name, - flat_field_index, - )) - as Arc; - - return Ok(Transformed { - data: new_expr, - tnr: TreeNodeRecursion::Stop, - transformed: true, - }); - } - } - } - } - } - } - } - Ok(Transformed::no(expr)) - } -} diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index b8d570916c7c5..b9144f960d7bf 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -601,3 +601,254 @@ fn get_value(metrics: &MetricsSet, metric_name: &str) -> usize { } } } + +#[cfg(test)] +mod schema_rewriter_tests { + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::datasource::physical_plan::ParquetSource; + use datafusion::physical_expr::expressions::{col, lit}; + use datafusion::physical_expr::schema_rewriter::{ + PhysicalExprSchemaRewriter, PhysicalSchemaExprRewriter, + }; + use datafusion::physical_expr::PhysicalExpr; + use datafusion::scalar::ScalarValue; + use std::sync::Arc; + + /// Test basic functionality of PhysicalExprSchemaRewriter + #[test] + fn test_schema_rewriter_basic() { + let physical_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ]); + + let logical_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, false), // Type mismatch + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float64, true), // Missing in physical + ]); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + + // Test column with type cast + let column_a = col("a", &physical_schema).unwrap(); + let result = rewriter.rewrite(column_a).unwrap(); + assert!(result.to_string().contains("CAST")); + + // Test column with no changes needed + let column_b = col("b", &physical_schema).unwrap(); + let result = rewriter.rewrite(column_b).unwrap(); + assert_eq!(result.to_string(), "b@1"); + + // Test missing column (should be replaced with null) + let column_c = + Arc::new(datafusion::physical_expr::expressions::Column::new("c", 2)); + let result = rewriter.rewrite(column_c).unwrap(); + assert!(result.to_string().contains("NULL")); + } + + /// Test edge case: non-nullable missing column should error + #[test] + fn test_schema_rewriter_non_nullable_missing_error() { + let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let logical_schema = 
Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), // Non-nullable missing + ]); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + let column_b = + Arc::new(datafusion::physical_expr::expressions::Column::new("b", 1)); + let result = rewriter.rewrite(column_b); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Non-nullable column 'b' is missing")); + } + + /// Test partition columns functionality + #[test] + fn test_schema_rewriter_partition_columns() { + let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let logical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let partition_fields = + vec![Arc::new(Field::new("partition_col", DataType::Utf8, false))]; + let partition_values = vec![ScalarValue::Utf8(Some("test_value".to_string()))]; + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema) + .with_partition_columns(partition_fields, partition_values); + + let partition_column = Arc::new( + datafusion::physical_expr::expressions::Column::new("partition_col", 0), + ); + let result = rewriter.rewrite(partition_column).unwrap(); + assert!(result.to_string().contains("test_value")); + } + + /// Test complex expressions with multiple columns + #[test] + fn test_schema_rewriter_complex_expressions() { + let physical_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ]); + + let logical_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, false), // Type mismatch + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Float64, true), // Missing + ]); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + + // Create a complex expression: a + 10 > 5 + let column_a = col("a", &physical_schema).unwrap(); + let literal_10 = lit(ScalarValue::Int64(Some(10))); + let literal_5 = lit(ScalarValue::Int64(Some(5))); + + let add_expr = Arc::new(datafusion::physical_expr::expressions::BinaryExpr::new( + column_a, + datafusion_expr::Operator::Plus, + literal_10, + )); + + let gt_expr = Arc::new(datafusion::physical_expr::expressions::BinaryExpr::new( + add_expr, + datafusion_expr::Operator::Gt, + literal_5, + )); + + let result = rewriter.rewrite(gt_expr).unwrap(); + let result_str = result.to_string(); + // Should contain cast for column 'a' + assert!(result_str.contains("CAST")); + // Should still have the same logical structure + assert!(result_str.contains("+ 10")); + assert!(result_str.contains("> 5")); + } + + /// Test that invalid casts are properly handled + #[test] + fn test_schema_rewriter_invalid_cast_error() { + let physical_schema = Schema::new(vec![Field::new("a", DataType::Binary, false)]); + + // Try to cast binary to struct (should fail) + let logical_schema = Schema::new(vec![Field::new( + "a", + DataType::Struct(vec![Field::new("field", DataType::Int32, true)].into()), + false, + )]); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + let column_a = col("a", &physical_schema).unwrap(); + let result = rewriter.rewrite(column_a); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Cannot cast")); + } + + /// Test that column indexes are properly handled + #[test] + fn test_schema_rewriter_column_indexes() { + let physical_schema = Schema::new(vec![ + Field::new("x", DataType::Int32, false), + 
Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + ]); + + let logical_schema = Schema::new(vec![ + Field::new("a", DataType::Int64, false), // Different index and type + Field::new("b", DataType::Utf8, true), + ]); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + + // Create column with wrong index (should be corrected) + let column_a_wrong_idx = + Arc::new(datafusion::physical_expr::expressions::Column::new("a", 0)); + let result = rewriter.rewrite(column_a_wrong_idx).unwrap(); + let result_str = result.to_string(); + // Should contain cast and correct index + assert!(result_str.contains("CAST")); + assert!(result_str.contains("a@1")); // Should be index 1 in physical schema + } + + /// Mock rewrite hook for testing custom rewriter integration + #[derive(Debug)] + struct TestRewriteHook; + + impl PhysicalSchemaExprRewriter for TestRewriteHook { + fn rewrite( + &self, + expr: Arc, + _physical_file_schema: &Schema, + ) -> datafusion::common::Result> { + // Simple hook that adds 1 to any integer literal + if let Some(literal) = + expr.as_any() + .downcast_ref::() + { + if let ScalarValue::Int32(Some(val)) = literal.value() { + return Ok(lit(ScalarValue::Int32(Some(val + 1)))); + } + } + Ok(expr) + } + } + + /// Test that custom rewrite hooks work with ParquetSource + #[test] + fn test_parquet_source_with_custom_rewrite_hook() { + let hook = Arc::new(TestRewriteHook); + let _parquet_source = + ParquetSource::default().with_predicate_rewrite_hook(hook.clone()); + + // Test that the hook can be configured (we can't easily verify it's stored without accessing private fields) + // This test ensures the API works correctly + assert!(true); // Simple test to verify compilation + } + + /// Test that rewriter handles expressions with columns that exist only in physical schema + #[test] + fn test_schema_rewriter_physical_only_columns() { + let physical_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("_computed_col", DataType::Float64, false), // Only in physical + ]); + + let logical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + + // Reference a column that only exists in physical schema + let computed_col = Arc::new(datafusion::physical_expr::expressions::Column::new( + "_computed_col", + 1, + )); + let result = rewriter.rewrite(computed_col); + + // Should succeed and use the physical column + assert!(result.is_ok()); + assert_eq!(result.unwrap().to_string(), "_computed_col@1"); + } + + /// Test that rewriter handles completely unknown columns properly + #[test] + fn test_schema_rewriter_unknown_column_error() { + let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let logical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + + // Reference a column that doesn't exist in either schema + let unknown_col = Arc::new(datafusion::physical_expr::expressions::Column::new( + "unknown", 99, + )); + let result = rewriter.rewrite(unknown_col); + + // Should return an error + assert!(result.is_err()); + } +} diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 9465e8c361726..4f4e1c8a824c2 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -33,7 
+33,7 @@ use arrow::datatypes::{FieldRef, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_datasource::PartitionedFile; -use datafusion_physical_expr::schema_rewriter::PhysicalExprSchemaRewriteHook; +use datafusion_physical_expr::schema_rewriter::PhysicalSchemaExprRewriter; use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; use datafusion_physical_expr::PhysicalExprSchemaRewriter; use datafusion_physical_expr_common::physical_expr::{ @@ -94,7 +94,7 @@ pub(super) struct ParquetOpener { /// Optional parquet FileDecryptionProperties pub file_decryption_properties: Option>, /// Rewrite expressions in the context of the file schema - pub predicate_rewrite_hook: Option>, + pub predicate_rewrite_hook: Option>, } impl FileOpener for ParquetOpener { diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index d7dc8b9f3cf13..3150da7fb6160 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -39,7 +39,7 @@ use datafusion_common::{DataFusionError, Statistics}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_physical_expr::conjunction; -use datafusion_physical_expr::schema_rewriter::PhysicalExprSchemaRewriteHook; +use datafusion_physical_expr::schema_rewriter::PhysicalSchemaExprRewriter; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::filter_pushdown::{ @@ -279,7 +279,7 @@ pub struct ParquetSource { /// Optional hint for the size of the parquet metadata pub(crate) metadata_size_hint: Option, pub(crate) projected_statistics: Option, - pub(crate) predicate_rewrite_hook: Option>, + pub(crate) predicate_rewrite_hook: Option>, } impl ParquetSource { @@ -323,7 +323,7 @@ impl ParquetSource { /// that vary on a per-file basis. pub fn with_predicate_rewrite_hook( mut self, - predicate_rewrite_hook: Arc, + predicate_rewrite_hook: Arc, ) -> Self { self.predicate_rewrite_hook = Some(predicate_rewrite_hook); self diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs index a1836771b03de..e5a08ccbc7ad9 100644 --- a/datafusion/physical-expr/src/schema_rewriter.rs +++ b/datafusion/physical-expr/src/schema_rewriter.rs @@ -29,7 +29,7 @@ use datafusion_common::{ use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use crate::expressions::{self, CastExpr, Column}; -pub trait PhysicalExprSchemaRewriteHook: Send + Sync + std::fmt::Debug { +pub trait PhysicalSchemaExprRewriter: Send + Sync + std::fmt::Debug { /// Rewrite a physical expression to match the target schema /// /// This method should return a transformed expression that matches the target schema. 
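For readers tracking the rename in this patch, here is a minimal sketch of a downstream hook written against the API as it stands after [PATCH 09/21]: a no-op PhysicalSchemaExprRewriter registered on a ParquetSource through with_predicate_rewrite_hook. The NoopRewriter and parquet_source_with_hook names are illustrative only and do not appear in the patches, and a later patch in this series renames the trait to PhysicalExprAdapter and the builder method to with_expr_adapter.

use std::sync::Arc;

use arrow::datatypes::Schema;
use datafusion::common::Result;
use datafusion::datasource::physical_plan::ParquetSource;
use datafusion::physical_expr::schema_rewriter::PhysicalSchemaExprRewriter;
use datafusion::physical_expr::PhysicalExpr;

/// Illustrative hook that returns every predicate unchanged.
#[derive(Debug)]
struct NoopRewriter;

impl PhysicalSchemaExprRewriter for NoopRewriter {
    fn rewrite(
        &self,
        expr: Arc<dyn PhysicalExpr>,
        _physical_file_schema: &Schema,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        // A real hook would inspect `expr` against the per-file physical schema here,
        // e.g. redirecting a function call to a pre-computed (shredded) column.
        Ok(expr)
    }
}

/// Configure a ParquetSource with the hook installed, following the same
/// builder calls the examples in this series use.
fn parquet_source_with_hook(predicate: Arc<dyn PhysicalExpr>) -> ParquetSource {
    ParquetSource::default()
        .with_predicate(predicate)
        .with_pushdown_filters(true)
        .with_predicate_rewrite_hook(Arc::new(NoopRewriter) as _)
}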
From 3376b53936800d1d67503c8901d30d754818478b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 10 Jul 2025 00:06:06 -0500 Subject: [PATCH 10/21] add missing file --- .../examples/json_shredding.rs | 428 ++++++++++++++++++ .../core/tests/parquet/filter_pushdown.rs | 12 - 2 files changed, 428 insertions(+), 12 deletions(-) create mode 100644 datafusion-examples/examples/json_shredding.rs diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs new file mode 100644 index 0000000000000..9f3af66cbbc1c --- /dev/null +++ b/datafusion-examples/examples/json_shredding.rs @@ -0,0 +1,428 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use async_trait::async_trait; + +use datafusion::assert_batches_eq; +use datafusion::catalog::memory::DataSourceExec; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; +use datafusion::common::{assert_contains, DFSchema, Result}; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::logical_expr::utils::conjunction; +use datafusion::logical_expr::{ + ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + TableProviderFilterPushDown, TableType, Volatility, +}; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::parquet::file::properties::WriterProperties; +use datafusion::physical_expr::schema_rewriter::PhysicalSchemaExprRewriter; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::{lit, SessionConfig}; +use datafusion::scalar::ScalarValue; +use futures::StreamExt; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, PutPayload}; + +// Example showing how to implement custom filter rewriting for JSON shredding. +// +// JSON shredding is a technique for optimizing queries on semi-structured data +// by materializing commonly accessed fields into separate columns for better +// columnar storage performance. 
+// +// In this example, we have a table with both: +// - Original JSON data: data: '{"age": 30}' +// - Shredded flat columns: _data.name: "Alice" (extracted from JSON) +// +// Our custom TableProvider uses a PhysicalSchemaExprRewriter to rewrite +// expressions like `json_get_str('name', data)` to use the pre-computed +// flat column `_data.name` when available. This allows the query engine to: +// 1. Push down predicates for better filtering +// 2. Avoid expensive JSON parsing at query time +// 3. Leverage columnar storage benefits for the materialized fields +#[tokio::main] +async fn main() -> Result<()> { + println!("=== Creating example data with flat columns and underscore prefixes ==="); + + // Create sample data with flat columns using underscore prefixes + let (table_schema, batch) = create_sample_data(); + + let store = InMemory::new(); + let buf = { + let mut buf = vec![]; + + let props = WriterProperties::builder() + .set_max_row_group_size(2) + .build(); + + let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props)) + .expect("creating writer"); + + writer.write(&batch).expect("Writing batch"); + writer.close().unwrap(); + buf + }; + let path = Path::from("example.parquet"); + let payload = PutPayload::from_bytes(buf.into()); + store.put(&path, payload).await?; + + // Create a custom table provider that rewrites struct field access + let table_provider = Arc::new(ExampleTableProvider::new(table_schema)); + + // Set up query execution + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + + // Register our table + ctx.register_table("structs", table_provider)?; + ctx.register_udf(ScalarUDF::new_from_impl(JsonGetStr::default())); + + ctx.runtime_env().register_object_store( + ObjectStoreUrl::parse("memory://")?.as_ref(), + Arc::new(store), + ); + + println!("\n=== Showing all data ==="); + let batches = ctx.sql("SELECT * FROM structs").await?.collect().await?; + arrow::util::pretty::print_batches(&batches)?; + + println!("\n=== Running query with flat column access and filter ==="); + let query = "SELECT json_get_str('age', data) as age FROM structs WHERE json_get_str('name', data) = 'Bob'"; + println!("Query: {query}"); + + let batches = ctx.sql(query).await?.collect().await?; + + #[rustfmt::skip] + let expected = [ + "+-----+", + "| age |", + "+-----+", + "| 25 |", + "+-----+", + ]; + arrow::util::pretty::print_batches(&batches)?; + assert_batches_eq!(expected, &batches); + + println!("\n=== Running explain analyze to confirm row group pruning ==="); + + let batches = ctx + .sql(&format!("EXPLAIN ANALYZE {query}")) + .await? + .collect() + .await?; + let plan = format!("{}", arrow::util::pretty::pretty_format_batches(&batches)?); + println!("{plan}"); + assert_contains!(&plan, "row_groups_pruned_statistics=1"); + assert_contains!(&plan, "pushdown_rows_pruned=1"); + + Ok(()) +} + +/// Create the example data with flat columns using underscore prefixes. 
+/// +/// This demonstrates the logical data structure: +/// - Table schema: What users see (just the 'data' JSON column) +/// - File schema: What's physically stored (both 'data' and materialized '_data.name') +/// +/// The naming convention uses underscore prefixes to indicate shredded columns: +/// - `data` -> original JSON column +/// - `_data.name` -> materialized field from JSON data.name +fn create_sample_data() -> (SchemaRef, RecordBatch) { + // The table schema only has the main data column - this is what users query against + let table_schema = Schema::new(vec![Field::new("data", DataType::Utf8, false)]); + + // The file schema has both the main column and the shredded flat column with underscore prefix + // This represents the actual physical storage with pre-computed columns + let file_schema = Schema::new(vec![ + Field::new("data", DataType::Utf8, false), // Original JSON data + Field::new("_data.name", DataType::Utf8, false), // Materialized name field + ]); + + let batch = create_sample_record_batch(&file_schema); + + (Arc::new(table_schema), batch) +} + +/// Create the actual RecordBatch with sample data +fn create_sample_record_batch(file_schema: &Schema) -> RecordBatch { + // Build a RecordBatch with flat columns + let data_array = StringArray::from(vec![ + r#"{"age": 30}"#, + r#"{"age": 25}"#, + r#"{"age": 35}"#, + r#"{"age": 22}"#, + ]); + let names_array = StringArray::from(vec!["Alice", "Bob", "Charlie", "Dave"]); + + RecordBatch::try_new( + Arc::new(file_schema.clone()), + vec![Arc::new(data_array), Arc::new(names_array)], + ) + .unwrap() +} + +/// Custom TableProvider that uses a StructFieldRewriter +#[derive(Debug)] +struct ExampleTableProvider { + schema: SchemaRef, +} + +impl ExampleTableProvider { + fn new(schema: SchemaRef) -> Self { + Self { schema } + } +} + +#[async_trait] +impl TableProvider for ExampleTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + // Implementers can choose to mark these filters as exact or inexact. + // If marked as exact they cannot have false positives and must always be applied. + // If marked as Inexact they can have false positives and at runtime the rewriter + // can decide to not rewrite / ignore some filters since they will be re-evaluated upstream. + // For the purposes of this example we mark them as Exact to demonstrate the rewriter is working and the filtering is not being re-evaluated upstream. 
+ Ok(vec![TableProviderFilterPushDown::Exact; filters.len()]) + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let schema = self.schema.clone(); + let df_schema = DFSchema::try_from(schema.clone())?; + let filter = state.create_physical_expr( + conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)), + &df_schema, + )?; + + let parquet_source = ParquetSource::default() + .with_predicate(filter) + .with_pushdown_filters(true) + // if the rewriter needs a reference to the table schema you can bind self.schema() here + .with_predicate_rewrite_hook(Arc::new(ShreddedJsonRewriter) as _); + + let object_store_url = ObjectStoreUrl::parse("memory://")?; + + let store = state.runtime_env().object_store(object_store_url)?; + + let mut files = vec![]; + let mut listing = store.list(None); + while let Some(file) = listing.next().await { + if let Ok(file) = file { + files.push(file); + } + } + + let file_group = files + .iter() + .map(|file| PartitionedFile::new(file.location.clone(), file.size)) + .collect(); + + let file_scan_config = FileScanConfigBuilder::new( + ObjectStoreUrl::parse("memory://")?, + schema, + Arc::new(parquet_source), + ) + .with_projection(projection.cloned()) + .with_limit(limit) + .with_file_group(file_group); + + Ok(Arc::new(DataSourceExec::new(Arc::new( + file_scan_config.build(), + )))) + } +} + +/// Scalar UDF that uses serde_json to access json fields +#[derive(Debug)] +pub struct JsonGetStr { + signature: Signature, + aliases: [String; 1], +} + +impl Default for JsonGetStr { + fn default() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: ["json_get_str".to_string()], + } + } +} + +impl ScalarUDFImpl for JsonGetStr { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.aliases[0].as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert!( + args.args.len() == 2, + "json_get_str requires exactly 2 arguments" + ); + let key = match &args.args[0] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(key))) => key, + _ => { + return Err(datafusion::error::DataFusionError::Execution( + "json_get_str first argument must be a string".to_string(), + )) + } + }; + // We expect a string array that contains JSON strings + let json_array = match &args.args[1] { + ColumnarValue::Array(array) => array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + datafusion::error::DataFusionError::Execution( + "json_get_str second argument must be a string array".to_string(), + ) + })?, + _ => { + return Err(datafusion::error::DataFusionError::Execution( + "json_get_str second argument must be a string array".to_string(), + )) + } + }; + let values = json_array + .iter() + .map(|value| { + value.and_then(|v| { + let json_value: serde_json::Value = + serde_json::from_str(v).unwrap_or_default(); + json_value.get(key).map(|v| v.to_string()) + }) + }) + .collect::(); + Ok(ColumnarValue::Array(Arc::new(values))) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Rewriter that converts json_get_str calls to direct flat column references +#[derive(Debug)] +struct ShreddedJsonRewriter; + +impl PhysicalSchemaExprRewriter for ShreddedJsonRewriter { + fn rewrite( + &self, + expr: Arc, + physical_file_schema: &Schema, + ) -> Result> { + 
expr.transform(|expr| self.rewrite_impl(expr, physical_file_schema)) + .data() + } +} + +impl ShreddedJsonRewriter { + fn rewrite_impl( + &self, + expr: Arc, + physical_file_schema: &Schema, + ) -> Result>> { + if let Some(func) = expr.as_any().downcast_ref::() { + if func.name() == "json_get_str" && func.args().len() == 2 { + // Get the key from the first argument + if let Some(literal) = func.args()[0] + .as_any() + .downcast_ref::() + { + if let ScalarValue::Utf8(Some(field_name)) = literal.value() { + // Get the column from the second argument + if let Some(column) = func.args()[1] + .as_any() + .downcast_ref::() + { + let column_name = column.name(); + // Check if there's a flat column with underscore prefix + let flat_column_name = format!("_{column_name}.{field_name}"); + + if let Ok(flat_field_index) = + physical_file_schema.index_of(&flat_column_name) + { + let flat_field = + physical_file_schema.field(flat_field_index); + + if flat_field.data_type() == &DataType::Utf8 { + // Replace the whole expression with a direct column reference + let new_expr = Arc::new(expressions::Column::new( + &flat_column_name, + flat_field_index, + )) + as Arc; + + return Ok(Transformed { + data: new_expr, + tnr: TreeNodeRecursion::Stop, + transformed: true, + }); + } + } + } + } + } + } + } + Ok(Transformed::no(expr)) + } +} diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index b9144f960d7bf..64e8b4fbc1cbf 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -798,18 +798,6 @@ mod schema_rewriter_tests { } } - /// Test that custom rewrite hooks work with ParquetSource - #[test] - fn test_parquet_source_with_custom_rewrite_hook() { - let hook = Arc::new(TestRewriteHook); - let _parquet_source = - ParquetSource::default().with_predicate_rewrite_hook(hook.clone()); - - // Test that the hook can be configured (we can't easily verify it's stored without accessing private fields) - // This test ensures the API works correctly - assert!(true); // Simple test to verify compilation - } - /// Test that rewriter handles expressions with columns that exist only in physical schema #[test] fn test_schema_rewriter_physical_only_columns() { From 438f0b9bdd9d36bc90ab8f869d885582084f5839 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 10 Jul 2025 00:07:21 -0500 Subject: [PATCH 11/21] move tests --- .../core/tests/parquet/filter_pushdown.rs | 239 ------------------ 1 file changed, 239 deletions(-) diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index 64e8b4fbc1cbf..b8d570916c7c5 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -601,242 +601,3 @@ fn get_value(metrics: &MetricsSet, metric_name: &str) -> usize { } } } - -#[cfg(test)] -mod schema_rewriter_tests { - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion::datasource::physical_plan::ParquetSource; - use datafusion::physical_expr::expressions::{col, lit}; - use datafusion::physical_expr::schema_rewriter::{ - PhysicalExprSchemaRewriter, PhysicalSchemaExprRewriter, - }; - use datafusion::physical_expr::PhysicalExpr; - use datafusion::scalar::ScalarValue; - use std::sync::Arc; - - /// Test basic functionality of PhysicalExprSchemaRewriter - #[test] - fn test_schema_rewriter_basic() { - let physical_schema = Schema::new(vec![ - 
Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, true), - ]); - - let logical_schema = Schema::new(vec![ - Field::new("a", DataType::Int64, false), // Type mismatch - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Float64, true), // Missing in physical - ]); - - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); - - // Test column with type cast - let column_a = col("a", &physical_schema).unwrap(); - let result = rewriter.rewrite(column_a).unwrap(); - assert!(result.to_string().contains("CAST")); - - // Test column with no changes needed - let column_b = col("b", &physical_schema).unwrap(); - let result = rewriter.rewrite(column_b).unwrap(); - assert_eq!(result.to_string(), "b@1"); - - // Test missing column (should be replaced with null) - let column_c = - Arc::new(datafusion::physical_expr::expressions::Column::new("c", 2)); - let result = rewriter.rewrite(column_c).unwrap(); - assert!(result.to_string().contains("NULL")); - } - - /// Test edge case: non-nullable missing column should error - #[test] - fn test_schema_rewriter_non_nullable_missing_error() { - let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let logical_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), // Non-nullable missing - ]); - - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); - let column_b = - Arc::new(datafusion::physical_expr::expressions::Column::new("b", 1)); - let result = rewriter.rewrite(column_b); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Non-nullable column 'b' is missing")); - } - - /// Test partition columns functionality - #[test] - fn test_schema_rewriter_partition_columns() { - let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let logical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let partition_fields = - vec![Arc::new(Field::new("partition_col", DataType::Utf8, false))]; - let partition_values = vec![ScalarValue::Utf8(Some("test_value".to_string()))]; - - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema) - .with_partition_columns(partition_fields, partition_values); - - let partition_column = Arc::new( - datafusion::physical_expr::expressions::Column::new("partition_col", 0), - ); - let result = rewriter.rewrite(partition_column).unwrap(); - assert!(result.to_string().contains("test_value")); - } - - /// Test complex expressions with multiple columns - #[test] - fn test_schema_rewriter_complex_expressions() { - let physical_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, true), - ]); - - let logical_schema = Schema::new(vec![ - Field::new("a", DataType::Int64, false), // Type mismatch - Field::new("b", DataType::Utf8, true), - Field::new("c", DataType::Float64, true), // Missing - ]); - - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); - - // Create a complex expression: a + 10 > 5 - let column_a = col("a", &physical_schema).unwrap(); - let literal_10 = lit(ScalarValue::Int64(Some(10))); - let literal_5 = lit(ScalarValue::Int64(Some(5))); - - let add_expr = Arc::new(datafusion::physical_expr::expressions::BinaryExpr::new( - column_a, - datafusion_expr::Operator::Plus, - literal_10, - )); - - let gt_expr = 
Arc::new(datafusion::physical_expr::expressions::BinaryExpr::new( - add_expr, - datafusion_expr::Operator::Gt, - literal_5, - )); - - let result = rewriter.rewrite(gt_expr).unwrap(); - let result_str = result.to_string(); - // Should contain cast for column 'a' - assert!(result_str.contains("CAST")); - // Should still have the same logical structure - assert!(result_str.contains("+ 10")); - assert!(result_str.contains("> 5")); - } - - /// Test that invalid casts are properly handled - #[test] - fn test_schema_rewriter_invalid_cast_error() { - let physical_schema = Schema::new(vec![Field::new("a", DataType::Binary, false)]); - - // Try to cast binary to struct (should fail) - let logical_schema = Schema::new(vec![Field::new( - "a", - DataType::Struct(vec![Field::new("field", DataType::Int32, true)].into()), - false, - )]); - - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); - let column_a = col("a", &physical_schema).unwrap(); - let result = rewriter.rewrite(column_a); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("Cannot cast")); - } - - /// Test that column indexes are properly handled - #[test] - fn test_schema_rewriter_column_indexes() { - let physical_schema = Schema::new(vec![ - Field::new("x", DataType::Int32, false), - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, true), - ]); - - let logical_schema = Schema::new(vec![ - Field::new("a", DataType::Int64, false), // Different index and type - Field::new("b", DataType::Utf8, true), - ]); - - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); - - // Create column with wrong index (should be corrected) - let column_a_wrong_idx = - Arc::new(datafusion::physical_expr::expressions::Column::new("a", 0)); - let result = rewriter.rewrite(column_a_wrong_idx).unwrap(); - let result_str = result.to_string(); - // Should contain cast and correct index - assert!(result_str.contains("CAST")); - assert!(result_str.contains("a@1")); // Should be index 1 in physical schema - } - - /// Mock rewrite hook for testing custom rewriter integration - #[derive(Debug)] - struct TestRewriteHook; - - impl PhysicalSchemaExprRewriter for TestRewriteHook { - fn rewrite( - &self, - expr: Arc, - _physical_file_schema: &Schema, - ) -> datafusion::common::Result> { - // Simple hook that adds 1 to any integer literal - if let Some(literal) = - expr.as_any() - .downcast_ref::() - { - if let ScalarValue::Int32(Some(val)) = literal.value() { - return Ok(lit(ScalarValue::Int32(Some(val + 1)))); - } - } - Ok(expr) - } - } - - /// Test that rewriter handles expressions with columns that exist only in physical schema - #[test] - fn test_schema_rewriter_physical_only_columns() { - let physical_schema = Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("_computed_col", DataType::Float64, false), // Only in physical - ]); - - let logical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); - - // Reference a column that only exists in physical schema - let computed_col = Arc::new(datafusion::physical_expr::expressions::Column::new( - "_computed_col", - 1, - )); - let result = rewriter.rewrite(computed_col); - - // Should succeed and use the physical column - assert!(result.is_ok()); - assert_eq!(result.unwrap().to_string(), "_computed_col@1"); - } - - /// Test that rewriter handles completely unknown columns properly - #[test] - 
fn test_schema_rewriter_unknown_column_error() { - let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let logical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); - - // Reference a column that doesn't exist in either schema - let unknown_col = Arc::new(datafusion::physical_expr::expressions::Column::new( - "unknown", 99, - )); - let result = rewriter.rewrite(unknown_col); - - // Should return an error - assert!(result.is_err()); - } -} From 280fab1abeae30086f147001a6234a38c22c5178 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 10 Jul 2025 14:53:34 -0500 Subject: [PATCH 12/21] add example, refactor --- .../examples/default_column_values.rs | 360 ++++++++++++++++++ .../examples/json_shredding.rs | 37 +- datafusion/datasource-parquet/src/opener.rs | 70 ++-- datafusion/datasource-parquet/src/source.rs | 25 +- datafusion/physical-expr/src/lib.rs | 2 +- .../physical-expr/src/schema_rewriter.rs | 201 +++++++--- 6 files changed, 576 insertions(+), 119 deletions(-) create mode 100644 datafusion-examples/examples/default_column_values.rs diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/default_column_values.rs new file mode 100644 index 0000000000000..a9507c99fa72e --- /dev/null +++ b/datafusion-examples/examples/default_column_values.rs @@ -0,0 +1,360 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + +use arrow::array::RecordBatch; +use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; +use async_trait::async_trait; + +use datafusion::assert_batches_eq; +use datafusion::catalog::memory::DataSourceExec; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::common::tree_node::{Transformed, TreeNode}; +use datafusion::common::DFSchema; +use datafusion::common::{Result, ScalarValue}; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::logical_expr::utils::conjunction; +use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::parquet::file::properties::WriterProperties; +use datafusion::physical_expr::expressions::{CastExpr, Column, Literal}; +use datafusion::physical_expr::schema_rewriter::{ + DefaultPhysicalExprAdapter, PhysicalExprAdapter, +}; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::{lit, SessionConfig}; +use futures::StreamExt; +use object_store::memory::InMemory; +use object_store::path::Path; +use object_store::{ObjectStore, PutPayload}; + +// Metadata key for storing default values in field metadata +const DEFAULT_VALUE_METADATA_KEY: &str = "example.default_value"; + +// Example showing how to implement custom default value handling for missing columns +// using field metadata and PhysicalExprAdapter. +// +// This example demonstrates how to: +// 1. Store default values in field metadata using a constant key +// 2. Create a custom PhysicalExprAdapter that reads these defaults +// 3. Inject default values for missing columns in filter predicates +// 4. Use the DefaultPhysicalExprAdapter as a fallback for standard schema adaptation +// 5. Wrap string default values in cast expressions for proper type conversion +// +// Important: PhysicalExprAdapter is specifically designed for rewriting filter predicates +// that get pushed down to file scans. For handling missing columns in projections, +// other mechanisms in DataFusion are used (like SchemaAdapter). +// +// The metadata-based approach provides a flexible way to store default values as strings +// and cast them to the appropriate types at query time. 
+ +#[tokio::main] +async fn main() -> Result<()> { + println!("=== Creating example data with missing columns and default values ==="); + + // Create sample data where the logical schema has more columns than the physical schema + let (logical_schema, physical_schema, batch) = create_sample_data_with_defaults(); + + let store = InMemory::new(); + let buf = { + let mut buf = vec![]; + + let props = WriterProperties::builder() + .set_max_row_group_size(2) + .build(); + + let mut writer = + ArrowWriter::try_new(&mut buf, physical_schema.clone(), Some(props)) + .expect("creating writer"); + + writer.write(&batch).expect("Writing batch"); + writer.close().unwrap(); + buf + }; + let path = Path::from("example.parquet"); + let payload = PutPayload::from_bytes(buf.into()); + store.put(&path, payload).await?; + + // Create a custom table provider that handles missing columns with defaults + let table_provider = Arc::new(DefaultValueTableProvider::new(logical_schema)); + + // Set up query execution + let mut cfg = SessionConfig::new(); + cfg.options_mut().execution.parquet.pushdown_filters = true; + let ctx = SessionContext::new_with_config(cfg); + + // Register our table + ctx.register_table("example_table", table_provider)?; + + ctx.runtime_env().register_object_store( + ObjectStoreUrl::parse("memory://")?.as_ref(), + Arc::new(store), + ); + + println!("\n=== Demonstrating default value injection in filter predicates ==="); + let query = "SELECT id, name FROM example_table WHERE status = 'active' ORDER BY id"; + println!("Query: {query}"); + println!("Note: The 'status' column doesn't exist in the physical schema,"); + println!( + "but our adapter injects the default value 'active' for the filter predicate." + ); + + let batches = ctx.sql(query).await?.collect().await?; + + #[rustfmt::skip] + let expected = [ + "+----+-------+", + "| id | name |", + "+----+-------+", + "| 1 | Alice |", + "| 2 | Bob |", + "| 3 | Carol |", + "+----+-------+", + ]; + arrow::util::pretty::print_batches(&batches)?; + assert_batches_eq!(expected, &batches); + + println!("\n=== Key Insight ==="); + println!("This example demonstrates how PhysicalExprAdapter works:"); + println!("1. Physical schema only has 'id' and 'name' columns"); + println!("2. Logical schema has 'id', 'name', 'status', and 'priority' columns with defaults"); + println!("3. Our custom adapter intercepts filter expressions on missing columns"); + println!("4. Default values from metadata are injected as cast expressions"); + println!("5. 
The DefaultPhysicalExprAdapter handles other schema adaptations"); + println!("\nNote: PhysicalExprAdapter is specifically for filter predicates."); + println!("For projection columns, different mechanisms handle missing columns."); + + Ok(()) +} + +/// Create sample data with a logical schema that has default values in metadata +/// and a physical schema that's missing some columns +fn create_sample_data_with_defaults() -> (SchemaRef, SchemaRef, RecordBatch) { + // Create metadata for default values + let mut status_metadata = HashMap::new(); + status_metadata.insert(DEFAULT_VALUE_METADATA_KEY.to_string(), "active".to_string()); + + let mut priority_metadata = HashMap::new(); + priority_metadata.insert(DEFAULT_VALUE_METADATA_KEY.to_string(), "1".to_string()); + + // The logical schema includes all columns with their default values in metadata + // Note: We make the columns with defaults nullable to allow the default adapter to handle them + let logical_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("status", DataType::Utf8, true).with_metadata(status_metadata), + Field::new("priority", DataType::Int32, true).with_metadata(priority_metadata), + ]); + + // The physical schema only has some columns (simulating missing columns in storage) + let physical_schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + ]); + + // Create sample data for the physical schema + let batch = RecordBatch::try_new( + Arc::new(physical_schema.clone()), + vec![ + Arc::new(arrow::array::Int32Array::from(vec![1, 2, 3])), + Arc::new(arrow::array::StringArray::from(vec![ + "Alice", "Bob", "Carol", + ])), + ], + ) + .unwrap(); + + (Arc::new(logical_schema), Arc::new(physical_schema), batch) +} + +/// Custom TableProvider that uses DefaultValuePhysicalExprAdapter +#[derive(Debug)] +struct DefaultValueTableProvider { + schema: SchemaRef, +} + +impl DefaultValueTableProvider { + fn new(schema: SchemaRef) -> Self { + Self { schema } + } +} + +#[async_trait] +impl TableProvider for DefaultValueTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()]) + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let schema = self.schema.clone(); + let df_schema = DFSchema::try_from(schema.clone())?; + let filter = state.create_physical_expr( + conjunction(filters.iter().cloned()).unwrap_or_else(|| lit(true)), + &df_schema, + )?; + + let parquet_source = ParquetSource::default() + .with_predicate(filter) + .with_pushdown_filters(true) + .with_expr_adapter(Arc::new(DefaultValuePhysicalExprAdapter { + default_adapter: DefaultPhysicalExprAdapter, + }) as _); + + let object_store_url = ObjectStoreUrl::parse("memory://")?; + let store = state.runtime_env().object_store(object_store_url)?; + + let mut files = vec![]; + let mut listing = store.list(None); + while let Some(file) = listing.next().await { + if let Ok(file) = file { + files.push(file); + } + } + + let file_group = files + .iter() + .map(|file| PartitionedFile::new(file.location.clone(), file.size)) + .collect(); + + let file_scan_config = FileScanConfigBuilder::new( + 
ObjectStoreUrl::parse("memory://")?, + self.schema.clone(), + Arc::new(parquet_source), + ) + .with_projection(projection.cloned()) + .with_limit(limit) + .with_file_group(file_group); + + Ok(Arc::new(DataSourceExec::new(Arc::new( + file_scan_config.build(), + )))) + } +} + +/// Custom PhysicalExprAdapter that handles missing columns with default values from metadata +/// and wraps DefaultPhysicalExprAdapter for standard schema adaptation +#[derive(Debug)] +struct DefaultValuePhysicalExprAdapter { + default_adapter: DefaultPhysicalExprAdapter, +} + +impl PhysicalExprAdapter for DefaultValuePhysicalExprAdapter { + fn rewrite_to_file_schema( + &self, + expr: Arc, + logical_file_schema: &Schema, + physical_file_schema: &Schema, + partition_values: &[(FieldRef, ScalarValue)], + ) -> Result> { + // First try our custom default value injection for missing columns + let rewritten = expr.transform(|expr| { + self.inject_default_values(expr, logical_file_schema, physical_file_schema) + })?; + + // Then apply the default adapter as a fallback to handle standard schema differences + // like type casting, partition column handling, etc. + self.default_adapter.rewrite_to_file_schema( + rewritten.data, + logical_file_schema, + physical_file_schema, + partition_values, + ) + } +} + +impl DefaultValuePhysicalExprAdapter { + fn inject_default_values( + &self, + expr: Arc, + logical_file_schema: &Schema, + physical_file_schema: &Schema, + ) -> Result>> { + if let Some(column) = expr.as_any().downcast_ref::() { + let column_name = column.name(); + + // Check if this column exists in the physical schema + if physical_file_schema.index_of(column_name).is_err() { + // Column is missing from physical schema, check if logical schema has a default + if let Ok(logical_field) = + logical_file_schema.field_with_name(column_name) + { + if let Some(default_value_str) = + logical_field.metadata().get(DEFAULT_VALUE_METADATA_KEY) + { + // Create a string literal and wrap it in a cast expression + let default_literal = self.create_default_value_expr( + default_value_str, + logical_field.data_type(), + )?; + return Ok(Transformed::yes(default_literal)); + } + } + } + } + + // No transformation needed + Ok(Transformed::no(expr)) + } + + fn create_default_value_expr( + &self, + value_str: &str, + data_type: &DataType, + ) -> Result> { + // Create a string literal with the default value + let string_literal = + Arc::new(Literal::new(ScalarValue::Utf8(Some(value_str.to_string())))); + + // If the target type is already Utf8, return the string literal directly + if matches!(data_type, DataType::Utf8) { + return Ok(string_literal); + } + + // Otherwise, wrap the string literal in a cast expression + let cast_expr = Arc::new(CastExpr::new(string_literal, data_type.clone(), None)); + + Ok(cast_expr) + } +} diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs index 9f3af66cbbc1c..771239e0994a6 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/json_shredding.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{RecordBatch, StringArray}; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use async_trait::async_trait; use datafusion::assert_batches_eq; @@ -40,7 +40,9 @@ use datafusion::logical_expr::{ }; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; -use 
datafusion::physical_expr::schema_rewriter::PhysicalSchemaExprRewriter; +use datafusion::physical_expr::schema_rewriter::{ + DefaultPhysicalExprAdapter, PhysicalExprAdapter, +}; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; use datafusion::physical_plan::ExecutionPlan; @@ -61,7 +63,7 @@ use object_store::{ObjectStore, PutPayload}; // - Original JSON data: data: '{"age": 30}' // - Shredded flat columns: _data.name: "Alice" (extracted from JSON) // -// Our custom TableProvider uses a PhysicalSchemaExprRewriter to rewrite +// Our custom TableProvider uses a PhysicalExprAdapter to rewrite // expressions like `json_get_str('name', data)` to use the pre-computed // flat column `_data.name` when available. This allows the query engine to: // 1. Push down predicates for better filtering @@ -245,7 +247,9 @@ impl TableProvider for ExampleTableProvider { .with_predicate(filter) .with_pushdown_filters(true) // if the rewriter needs a reference to the table schema you can bind self.schema() here - .with_predicate_rewrite_hook(Arc::new(ShreddedJsonRewriter) as _); + .with_expr_adapter(Arc::new(ShreddedJsonRewriter { + default_adapter: DefaultPhysicalExprAdapter, + }) as _); let object_store_url = ObjectStoreUrl::parse("memory://")?; @@ -360,17 +364,32 @@ impl ScalarUDFImpl for JsonGetStr { } /// Rewriter that converts json_get_str calls to direct flat column references +/// and wraps DefaultPhysicalExprAdapter for standard schema adaptation #[derive(Debug)] -struct ShreddedJsonRewriter; +struct ShreddedJsonRewriter { + default_adapter: DefaultPhysicalExprAdapter, +} -impl PhysicalSchemaExprRewriter for ShreddedJsonRewriter { - fn rewrite( +impl PhysicalExprAdapter for ShreddedJsonRewriter { + fn rewrite_to_file_schema( &self, expr: Arc, + logical_file_schema: &Schema, physical_file_schema: &Schema, + partition_values: &[(FieldRef, ScalarValue)], ) -> Result> { - expr.transform(|expr| self.rewrite_impl(expr, physical_file_schema)) - .data() + // First try our custom JSON shredding rewrite + let rewritten = + expr.transform(|expr| self.rewrite_impl(expr, physical_file_schema))?; + + // Then apply the default adapter as a fallback to handle standard schema differences + // like type casting, missing columns, and partition column handling + self.default_adapter.rewrite_to_file_schema( + rewritten, + logical_file_schema, + physical_file_schema, + partition_values, + ) } } diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 4f4e1c8a824c2..5389c966ad7db 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -33,9 +33,8 @@ use arrow::datatypes::{FieldRef, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_datasource::PartitionedFile; -use datafusion_physical_expr::schema_rewriter::PhysicalSchemaExprRewriter; +use datafusion_physical_expr::schema_rewriter::PhysicalExprAdapter; use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; -use datafusion_physical_expr::PhysicalExprSchemaRewriter; use datafusion_physical_expr_common::physical_expr::{ is_dynamic_physical_expr, PhysicalExpr, }; @@ -43,6 +42,7 @@ use datafusion_physical_plan::metrics::{Count, ExecutionPlanMetricsSet, MetricBu use datafusion_pruning::{build_pruning_predicate, FilePruner, PruningPredicate}; use futures::{StreamExt, TryStreamExt}; +use itertools::Itertools; use log::debug; use 
parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; use parquet::arrow::async_reader::AsyncFileReader; @@ -94,7 +94,7 @@ pub(super) struct ParquetOpener { /// Optional parquet FileDecryptionProperties pub file_decryption_properties: Option>, /// Rewrite expressions in the context of the file schema - pub predicate_rewrite_hook: Option>, + pub expr_adapter: Arc, } impl FileOpener for ParquetOpener { @@ -135,7 +135,7 @@ impl FileOpener for ParquetOpener { let predicate_creation_errors = MetricBuilder::new(&self.metrics) .global_counter("num_predicate_creation_errors"); - let predicate_rewrite_hook = self.predicate_rewrite_hook.clone(); + let expr_adapter = self.expr_adapter.clone(); let mut enable_page_index = self.enable_page_index; let file_decryption_properties = self.file_decryption_properties.clone(); @@ -238,38 +238,31 @@ impl FileOpener for ParquetOpener { } } - predicate = predicate - .map(|p| { - if let Some(predicate_rewrite_hook) = predicate_rewrite_hook.as_ref() - { - predicate_rewrite_hook - .rewrite(Arc::clone(&p), &physical_file_schema) - } else { - Ok(p) - } - }) - .transpose()?; - // Adapt the predicate to the physical file schema. // This evaluates missing columns and inserts any necessary casts. predicate = predicate .map(|p| { - let rewriter = PhysicalExprSchemaRewriter::new( - &physical_file_schema, - &logical_file_schema, - ) - .with_partition_columns( - partition_fields.to_vec(), - file.partition_values, - ); - rewriter.rewrite(p).map_err(ArrowError::from).map(|p| { - // After rewriting to the file schema, further simplifications may be possible. - // For example, if `'a' = col_that_is_missing` becomes `'a' = NULL` that can then be simplified to `FALSE` - // and we can avoid doing any more work on the file (bloom filters, loading the page index, etc.). - PhysicalExprSimplifier::new(&physical_file_schema) - .simplify(p) - .map_err(ArrowError::from) - }) + let partition_values = partition_fields + .iter() + .cloned() + .zip(file.partition_values) + .collect_vec(); + expr_adapter + .rewrite_to_file_schema( + p, + &logical_file_schema, + &physical_file_schema, + &partition_values, + ) + .map_err(ArrowError::from) + .map(|p| { + // After rewriting to the file schema, further simplifications may be possible. + // For example, if `'a' = col_that_is_missing` becomes `'a' = NULL` that can then be simplified to `FALSE` + // and we can avoid doing any more work on the file (bloom filters, loading the page index, etc.). + PhysicalExprSimplifier::new(&physical_file_schema) + .simplify(p) + .map_err(ArrowError::from) + }) }) .transpose()? 
.transpose()?; @@ -540,7 +533,8 @@ mod test { }; use datafusion_expr::{col, lit}; use datafusion_physical_expr::{ - expressions::DynamicFilterPhysicalExpr, planner::logical2physical, PhysicalExpr, + expressions::DynamicFilterPhysicalExpr, planner::logical2physical, + DefaultPhysicalExprAdapter, PhysicalExpr, }; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use futures::{Stream, StreamExt}; @@ -646,7 +640,7 @@ mod test { enable_row_group_stats_pruning: true, coerce_int96: None, file_decryption_properties: None, - predicate_rewrite_hook: None, + expr_adapter: Arc::new(DefaultPhysicalExprAdapter), } }; @@ -732,7 +726,7 @@ mod test { enable_row_group_stats_pruning: true, coerce_int96: None, file_decryption_properties: None, - predicate_rewrite_hook: None, + expr_adapter: Arc::new(DefaultPhysicalExprAdapter), } }; @@ -834,7 +828,7 @@ mod test { enable_row_group_stats_pruning: true, coerce_int96: None, file_decryption_properties: None, - predicate_rewrite_hook: None, + expr_adapter: Arc::new(DefaultPhysicalExprAdapter), } }; let make_meta = || FileMeta { @@ -946,7 +940,7 @@ mod test { enable_row_group_stats_pruning: false, // note that this is false! coerce_int96: None, file_decryption_properties: None, - predicate_rewrite_hook: None, + expr_adapter: Arc::new(DefaultPhysicalExprAdapter), } }; @@ -1059,7 +1053,7 @@ mod test { enable_row_group_stats_pruning: true, coerce_int96: None, file_decryption_properties: None, - predicate_rewrite_hook: None, + expr_adapter: Arc::new(DefaultPhysicalExprAdapter), } }; diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index 3150da7fb6160..c9e6b9105c630 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -39,7 +39,8 @@ use datafusion_common::{DataFusionError, Statistics}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_physical_expr::conjunction; -use datafusion_physical_expr::schema_rewriter::PhysicalSchemaExprRewriter; +use datafusion_physical_expr::schema_rewriter::PhysicalExprAdapter; +use datafusion_physical_expr::DefaultPhysicalExprAdapter; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::filter_pushdown::{ @@ -279,7 +280,7 @@ pub struct ParquetSource { /// Optional hint for the size of the parquet metadata pub(crate) metadata_size_hint: Option, pub(crate) projected_statistics: Option, - pub(crate) predicate_rewrite_hook: Option>, + pub(crate) expr_adapter: Option>, } impl ParquetSource { @@ -318,14 +319,17 @@ impl ParquetSource { conf } - /// Register a predicate rewrite hook to transform predicates in the context of each file's physical file schema. - /// This can be used to optimize predicates to take advantage of shredded variant columns or pre-computed expressions - /// that vary on a per-file basis. - pub fn with_predicate_rewrite_hook( + /// Register an expression adapter used to adapt filters and projections that are pushed down into the scan + /// from the logical schema to the physical schema of the parquet file. 
+    /// This can include things like:
+    /// - Column ordering changes
+    /// - Handling of missing columns
+    /// - Rewriting expressions to use pre-computed values or file format specific optimizations
+    pub fn with_expr_adapter(
         mut self,
-        predicate_rewrite_hook: Arc<dyn PhysicalExprSchemaRewriteHook>,
+        expr_adapter: Arc<dyn PhysicalExprAdapter>,
     ) -> Self {
-        self.predicate_rewrite_hook = Some(predicate_rewrite_hook);
+        self.expr_adapter = Some(expr_adapter);
         self
     }
 
@@ -522,7 +526,10 @@ impl FileSource for ParquetSource {
             schema_adapter_factory,
             coerce_int96,
             file_decryption_properties,
-            predicate_rewrite_hook: self.predicate_rewrite_hook.clone(),
+            expr_adapter: self
+                .expr_adapter
+                .clone()
+                .unwrap_or(Arc::new(DefaultPhysicalExprAdapter)),
         })
     }
 
diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs
index 03fc77f156d95..845c358d7e58b 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -70,7 +70,7 @@ pub use datafusion_physical_expr_common::sort_expr::{
 pub use planner::{create_physical_expr, create_physical_exprs};
 pub use scalar_function::ScalarFunctionExpr;
 
-pub use schema_rewriter::PhysicalExprSchemaRewriter;
+pub use schema_rewriter::DefaultPhysicalExprAdapter;
 pub use utils::{conjunction, conjunction_opt, split_conjunction};
 
 // For backwards compatibility
diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs
index e5a08ccbc7ad9..4fa8550190bd0 100644
--- a/datafusion/physical-expr/src/schema_rewriter.rs
+++ b/datafusion/physical-expr/src/schema_rewriter.rs
@@ -29,14 +29,71 @@ use datafusion_common::{
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 
 use crate::expressions::{self, CastExpr, Column};
 
-pub trait PhysicalSchemaExprRewriter: Send + Sync + std::fmt::Debug {
-    /// Rewrite a physical expression to match the target schema
+
+/// Trait for adapting physical expressions to match a target schema.
+///
+/// This is used in file scans to rewrite expressions so that they can be evaluated
+/// against the physical schema of the file being scanned. It allows for handling
+/// differences between logical and physical schemas, such as type mismatches or missing columns.
+///
+/// You can create a custom implementation of this trait to handle specific rewriting logic.
+/// For example, to fill in missing columns with default values instead of nulls:
+///
+/// ```rust
+/// use datafusion_physical_expr::schema_rewriter::PhysicalExprAdapter;
+/// use arrow::datatypes::{Schema, Field, DataType, FieldRef};
+/// use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+/// use datafusion_common::{Result, ScalarValue, tree_node::{Transformed, TreeNode}};
+/// use datafusion_physical_expr::expressions::{self, Column};
+/// use std::sync::Arc;
+///
+/// #[derive(Debug)]
+/// pub struct CustomPhysicalExprAdapter;
+///
+/// impl PhysicalExprAdapter for CustomPhysicalExprAdapter {
+///     fn rewrite(
+///         &self,
+///         expr: Arc<dyn PhysicalExpr>,
+///         logical_file_schema: &Schema,
+///         physical_file_schema: &Schema,
+///         partition_values: &[(FieldRef, ScalarValue)],
+///     ) -> Result<Arc<dyn PhysicalExpr>> {
+///         expr.transform(|expr| {
+///             if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+///                 // Check if the column exists in the physical schema
+///                 if physical_file_schema.index_of(column.name()).is_err() {
+///                     // If the column is missing, fill it with a default value instead of null
+///                     // The default value could be stored in the table schema's column metadata for example.
+/// let default_value = ScalarValue::Int32(Some(0)); +/// return Ok(Transformed::yes(expressions::lit(default_value))); +/// } +/// } +/// // If the column exists, return it as is +/// Ok(Transformed::no(expr)) +/// }) +/// } +/// } +/// ``` +pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug { + /// Rewrite a physical expression to match the target schema. /// /// This method should return a transformed expression that matches the target schema. - fn rewrite( + /// + /// Arguments: + /// - `expr`: The physical expression to rewrite. + /// - `logical_file_schema`: The logical schema of the table being queried, excluding any partition columns. + /// - `physical_file_schema`: The physical schema of the file being scanned. + /// - `partition_values`: Optional partition values to use for rewriting partition column references. + /// These are handled as if they were columns appended onto the logical file schema. + /// + /// Returns: + /// - `Arc`: The rewritten physical expression that can be evaluated against the physical schema. + fn rewrite_to_file_schema( &self, expr: Arc, + logical_file_schema: &Schema, physical_file_schema: &Schema, + partition_values: &[(FieldRef, ScalarValue)], ) -> Result>; } @@ -45,7 +102,7 @@ pub trait PhysicalSchemaExprRewriter: Send + Sync + std::fmt::Debug { /// # Example /// /// ```rust -/// use datafusion_physical_expr::schema_rewriter::PhysicalExprSchemaRewriter; +/// use datafusion_physical_expr::schema_rewriter::DefaultPhysicalExprAdapter; /// use arrow::datatypes::Schema; /// /// # fn example( @@ -53,57 +110,45 @@ pub trait PhysicalSchemaExprRewriter: Send + Sync + std::fmt::Debug { /// # physical_file_schema: &Schema, /// # logical_file_schema: &Schema, /// # ) -> datafusion_common::Result<()> { -/// let rewriter = PhysicalExprSchemaRewriter::new(physical_file_schema, logical_file_schema); -/// let adapted_predicate = rewriter.rewrite(predicate)?; +/// let rewriter = DefaultPhysicalExprAdapter; +/// let adapted_predicate = rewriter.rewrite(predicate, logical_file_schema, physical_file_schema, &[])?; /// # Ok(()) /// # } /// ``` -pub struct PhysicalExprSchemaRewriter<'a> { - physical_file_schema: &'a Schema, - logical_file_schema: &'a Schema, - partition_fields: Vec, - partition_values: Vec, -} - -impl<'a> PhysicalExprSchemaRewriter<'a> { - /// Create a new schema rewriter with the given schemas - pub fn new( - physical_file_schema: &'a Schema, - logical_file_schema: &'a Schema, - ) -> Self { - Self { - physical_file_schema, - logical_file_schema, - partition_fields: Vec::new(), - partition_values: Vec::new(), - } - } - - /// Add partition columns and their corresponding values - /// - /// When a column reference matches a partition field, it will be replaced - /// with the corresponding literal value from partition_values. - pub fn with_partition_columns( - mut self, - partition_fields: Vec, - partition_values: Vec, - ) -> Self { - self.partition_fields = partition_fields; - self.partition_values = partition_values; - self - } +#[derive(Debug, Clone)] +pub struct DefaultPhysicalExprAdapter; +impl PhysicalExprAdapter for DefaultPhysicalExprAdapter { /// Rewrite the given physical expression to match the target schema /// /// This method applies the following transformations: /// 1. Replaces partition column references with literal values /// 2. Handles missing columns by inserting null literals /// 3. 
Casts columns when logical and physical schemas have different types - pub fn rewrite(&self, expr: Arc) -> Result> { - expr.transform(|expr| self.rewrite_expr(Arc::clone(&expr))) + fn rewrite_to_file_schema( + &self, + expr: Arc, + logical_file_schema: &Schema, + physical_file_schema: &Schema, + partition_values: &[(FieldRef, ScalarValue)], + ) -> Result> { + let rewriter = DefaultPhysicalExprAdapterRewriter { + logical_file_schema, + physical_file_schema, + partition_fields: partition_values, + }; + expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr))) .data() } +} + +struct DefaultPhysicalExprAdapterRewriter<'a> { + logical_file_schema: &'a Schema, + physical_file_schema: &'a Schema, + partition_fields: &'a [(FieldRef, ScalarValue)], +} +impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { fn rewrite_expr( &self, expr: Arc, @@ -213,7 +258,6 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { fn get_partition_value(&self, column_name: &str) -> Option { self.partition_fields .iter() - .zip(self.partition_values.iter()) .find(|(field, _)| field.name() == column_name) .map(|(_, value)| value.clone()) } @@ -252,10 +296,12 @@ mod tests { fn test_rewrite_column_with_type_cast() { let (physical_schema, logical_schema) = create_test_schema(); - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + let rewriter = DefaultPhysicalExprAdapter; let column_expr = Arc::new(Column::new("a", 0)); - let result = rewriter.rewrite(column_expr).unwrap(); + let result = rewriter + .rewrite_to_file_schema(column_expr, &logical_schema, &physical_schema, &[]) + .unwrap(); // Should be wrapped in a cast expression assert!(result.as_any().downcast_ref::().is_some()); @@ -264,7 +310,7 @@ mod tests { #[test] fn test_rewrite_mulit_column_expr_with_type_cast() { let (physical_schema, logical_schema) = create_test_schema(); - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + let rewriter = DefaultPhysicalExprAdapter; // Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter let column_a = Arc::new(Column::new("a", 0)) as Arc; @@ -284,7 +330,14 @@ mod tests { )), ); - let result = rewriter.rewrite(Arc::new(expr)).unwrap(); + let result = rewriter + .rewrite_to_file_schema( + Arc::new(expr), + &logical_schema, + &physical_schema, + &[], + ) + .unwrap(); println!("Rewritten expression: {result}"); let expected = expressions::BinaryExpr::new( @@ -317,10 +370,15 @@ mod tests { fn test_rewrite_missing_column() -> Result<()> { let (physical_schema, logical_schema) = create_test_schema(); - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + let rewriter = DefaultPhysicalExprAdapter; let column_expr = Arc::new(Column::new("c", 2)); - let result = rewriter.rewrite(column_expr)?; + let result = rewriter.rewrite_to_file_schema( + column_expr, + &logical_schema, + &physical_schema, + &[], + )?; // Should be replaced with a literal null if let Some(literal) = result.as_any().downcast_ref::() { @@ -336,15 +394,20 @@ mod tests { fn test_rewrite_partition_column() -> Result<()> { let (physical_schema, logical_schema) = create_test_schema(); - let partition_fields = - vec![Arc::new(Field::new("partition_col", DataType::Utf8, false))]; - let partition_values = vec![ScalarValue::Utf8(Some("test_value".to_string()))]; + let partition_field = + Arc::new(Field::new("partition_col", DataType::Utf8, false)); + let partition_value = ScalarValue::Utf8(Some("test_value".to_string())); + let partition_values = 
vec![(partition_field, partition_value)]; - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema) - .with_partition_columns(partition_fields, partition_values); + let rewriter = DefaultPhysicalExprAdapter; let column_expr = Arc::new(Column::new("partition_col", 0)); - let result = rewriter.rewrite(column_expr)?; + let result = rewriter.rewrite_to_file_schema( + column_expr, + &logical_schema, + &physical_schema, + &partition_values, + )?; // Should be replaced with the partition value if let Some(literal) = result.as_any().downcast_ref::() { @@ -363,10 +426,15 @@ mod tests { fn test_rewrite_no_change_needed() -> Result<()> { let (physical_schema, logical_schema) = create_test_schema(); - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + let rewriter = DefaultPhysicalExprAdapter; let column_expr = Arc::new(Column::new("b", 1)) as Arc; - let result = rewriter.rewrite(Arc::clone(&column_expr))?; + let result = rewriter.rewrite_to_file_schema( + Arc::clone(&column_expr), + &logical_schema, + &physical_schema, + &[], + )?; // Should be the same expression (no transformation needed) // We compare the underlying pointer through the trait object @@ -386,10 +454,15 @@ mod tests { Field::new("b", DataType::Utf8, false), // Non-nullable missing column ]); - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + let rewriter = DefaultPhysicalExprAdapter; let column_expr = Arc::new(Column::new("b", 1)); - let result = rewriter.rewrite(column_expr); + let result = rewriter.rewrite_to_file_schema( + column_expr, + &logical_schema, + &physical_schema, + &[], + ); assert!(result.is_err()); assert!(result .unwrap_err() @@ -421,7 +494,7 @@ mod tests { } } - /// Example showing how we can use the `PhysicalExprSchemaRewriter` to adapt RecordBatches during a scan + /// Example showing how we can use the `DefaultPhysicalExprAdapter` to adapt RecordBatches during a scan /// to apply projections, type conversions and handling of missing columns all at once. 
#[test] fn test_adapt_batches() { @@ -443,11 +516,15 @@ mod tests { col("a", &logical_schema).unwrap(), ]; - let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + let rewriter = DefaultPhysicalExprAdapter; let adapted_projection = projection .into_iter() - .map(|expr| rewriter.rewrite(expr).unwrap()) + .map(|expr| { + rewriter + .rewrite_to_file_schema(expr, &logical_schema, &physical_schema, &[]) + .unwrap() + }) .collect_vec(); let adapted_schema = Arc::new(Schema::new( From ff4505c5469f5eeb5696bd60b6182e39aa1b2018 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 10 Jul 2025 15:12:36 -0500 Subject: [PATCH 13/21] Fix examples --- datafusion-examples/examples/default_column_values.rs | 6 +++--- datafusion-examples/examples/json_shredding.rs | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/default_column_values.rs index a9507c99fa72e..2a11a844f6491 100644 --- a/datafusion-examples/examples/default_column_values.rs +++ b/datafusion-examples/examples/default_column_values.rs @@ -26,7 +26,7 @@ use async_trait::async_trait; use datafusion::assert_batches_eq; use datafusion::catalog::memory::DataSourceExec; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{Transformed, TreeNode}; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion::common::DFSchema; use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::PartitionedFile; @@ -291,12 +291,12 @@ impl PhysicalExprAdapter for DefaultValuePhysicalExprAdapter { // First try our custom default value injection for missing columns let rewritten = expr.transform(|expr| { self.inject_default_values(expr, logical_file_schema, physical_file_schema) - })?; + }).data()?; // Then apply the default adapter as a fallback to handle standard schema differences // like type casting, partition column handling, etc. 
self.default_adapter.rewrite_to_file_schema( - rewritten.data, + rewritten, logical_file_schema, physical_file_schema, partition_values, diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs index 771239e0994a6..bc042fa3e9d93 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/json_shredding.rs @@ -25,9 +25,7 @@ use async_trait::async_trait; use datafusion::assert_batches_eq; use datafusion::catalog::memory::DataSourceExec; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, -}; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion}; use datafusion::common::{assert_contains, DFSchema, Result}; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; @@ -380,7 +378,7 @@ impl PhysicalExprAdapter for ShreddedJsonRewriter { ) -> Result> { // First try our custom JSON shredding rewrite let rewritten = - expr.transform(|expr| self.rewrite_impl(expr, physical_file_schema))?; + expr.transform(|expr| self.rewrite_impl(expr, physical_file_schema)).data()?; // Then apply the default adapter as a fallback to handle standard schema differences // like type casting, missing columns, and partition column handling From ca4850373e834d9c7a68e1f32e0b7c2da301298b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 10 Jul 2025 15:15:54 -0500 Subject: [PATCH 14/21] fmt --- .../examples/default_column_values.rs | 12 +++++++++--- datafusion-examples/examples/json_shredding.rs | 9 ++++++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/default_column_values.rs index 2a11a844f6491..9be490142f40c 100644 --- a/datafusion-examples/examples/default_column_values.rs +++ b/datafusion-examples/examples/default_column_values.rs @@ -289,9 +289,15 @@ impl PhysicalExprAdapter for DefaultValuePhysicalExprAdapter { partition_values: &[(FieldRef, ScalarValue)], ) -> Result> { // First try our custom default value injection for missing columns - let rewritten = expr.transform(|expr| { - self.inject_default_values(expr, logical_file_schema, physical_file_schema) - }).data()?; + let rewritten = expr + .transform(|expr| { + self.inject_default_values( + expr, + logical_file_schema, + physical_file_schema, + ) + }) + .data()?; // Then apply the default adapter as a fallback to handle standard schema differences // like type casting, partition column handling, etc. 
diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs index bc042fa3e9d93..66a3e7d04e5b7 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/json_shredding.rs @@ -25,7 +25,9 @@ use async_trait::async_trait; use datafusion::assert_batches_eq; use datafusion::catalog::memory::DataSourceExec; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion}; +use datafusion::common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion::common::{assert_contains, DFSchema, Result}; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; @@ -377,8 +379,9 @@ impl PhysicalExprAdapter for ShreddedJsonRewriter { partition_values: &[(FieldRef, ScalarValue)], ) -> Result> { // First try our custom JSON shredding rewrite - let rewritten = - expr.transform(|expr| self.rewrite_impl(expr, physical_file_schema)).data()?; + let rewritten = expr + .transform(|expr| self.rewrite_impl(expr, physical_file_schema)) + .data()?; // Then apply the default adapter as a fallback to handle standard schema differences // like type casting, missing columns, and partition column handling From 50b9d3c4e52854250dfad423af7996f4b4e083d4 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 10 Jul 2025 20:16:40 -0500 Subject: [PATCH 15/21] move to filescanconfigbuilder --- .../examples/default_column_values.rs | 10 ++++---- .../examples/json_shredding.rs | 12 +++++----- datafusion/datasource-parquet/src/source.rs | 18 +------------- datafusion/datasource/src/file_scan_config.rs | 24 +++++++++++++++++++ 4 files changed, 36 insertions(+), 28 deletions(-) diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/default_column_values.rs index 9be490142f40c..836be14c50524 100644 --- a/datafusion-examples/examples/default_column_values.rs +++ b/datafusion-examples/examples/default_column_values.rs @@ -237,10 +237,7 @@ impl TableProvider for DefaultValueTableProvider { let parquet_source = ParquetSource::default() .with_predicate(filter) - .with_pushdown_filters(true) - .with_expr_adapter(Arc::new(DefaultValuePhysicalExprAdapter { - default_adapter: DefaultPhysicalExprAdapter, - }) as _); + .with_pushdown_filters(true); let object_store_url = ObjectStoreUrl::parse("memory://")?; let store = state.runtime_env().object_store(object_store_url)?; @@ -265,7 +262,10 @@ impl TableProvider for DefaultValueTableProvider { ) .with_projection(projection.cloned()) .with_limit(limit) - .with_file_group(file_group); + .with_file_group(file_group) + .with_expr_adapter(Arc::new(DefaultValuePhysicalExprAdapter { + default_adapter: DefaultPhysicalExprAdapter, + }) as _); Ok(Arc::new(DataSourceExec::new(Arc::new( file_scan_config.build(), diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs index 66a3e7d04e5b7..e55693effb098 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/json_shredding.rs @@ -245,11 +245,7 @@ impl TableProvider for ExampleTableProvider { let parquet_source = ParquetSource::default() .with_predicate(filter) - .with_pushdown_filters(true) - // if the rewriter needs a reference to the table schema you can bind self.schema() here - 
.with_expr_adapter(Arc::new(ShreddedJsonRewriter { - default_adapter: DefaultPhysicalExprAdapter, - }) as _); + .with_pushdown_filters(true); let object_store_url = ObjectStoreUrl::parse("memory://")?; @@ -275,7 +271,11 @@ impl TableProvider for ExampleTableProvider { ) .with_projection(projection.cloned()) .with_limit(limit) - .with_file_group(file_group); + .with_file_group(file_group) + // if the rewriter needs a reference to the table schema you can bind self.schema() here + .with_expr_adapter(Arc::new(ShreddedJsonRewriter { + default_adapter: DefaultPhysicalExprAdapter, + }) as _); Ok(Arc::new(DataSourceExec::new(Arc::new( file_scan_config.build(), diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index c9e6b9105c630..d396b27020cfa 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -39,7 +39,6 @@ use datafusion_common::{DataFusionError, Statistics}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_physical_expr::conjunction; -use datafusion_physical_expr::schema_rewriter::PhysicalExprAdapter; use datafusion_physical_expr::DefaultPhysicalExprAdapter; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -280,7 +279,6 @@ pub struct ParquetSource { /// Optional hint for the size of the parquet metadata pub(crate) metadata_size_hint: Option, pub(crate) projected_statistics: Option, - pub(crate) expr_adapter: Option>, } impl ParquetSource { @@ -319,20 +317,6 @@ impl ParquetSource { conf } - /// Register an expression adapter used to adapt filters and projections that are pushed down into the scan - /// from the logical schema to the physical schema of the parquet file. - /// This can include things like: - /// - Column ordering changes - /// - Handling of missing columns - /// - Rewriting expression to use pre-computed values or file format specific optimizations - pub fn with_expr_adapter( - mut self, - expr_adapter: Arc, - ) -> Self { - self.expr_adapter = Some(expr_adapter); - self - } - /// Options passed to the parquet reader for this scan pub fn table_parquet_options(&self) -> &TableParquetOptions { &self.table_parquet_options @@ -526,7 +510,7 @@ impl FileSource for ParquetSource { schema_adapter_factory, coerce_int96, file_decryption_properties, - expr_adapter: self + expr_adapter: base_config .expr_adapter .clone() .unwrap_or(Arc::new(DefaultPhysicalExprAdapter)), diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index 431b6ab0bcf0d..6c64def579c99 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -53,6 +53,7 @@ use datafusion_execution::{ object_store::ObjectStoreUrl, SendableRecordBatchStream, TaskContext, }; use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::schema_rewriter::PhysicalExprAdapter; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -188,6 +189,9 @@ pub struct FileScanConfig { /// Batch size while creating new batches /// Defaults to [`datafusion_common::config::ExecutionOptions`] batch_size. 
pub batch_size: Option, + /// Expression adapter used to adapt filters and projections that are pushed down into the scan + /// from the logical schema to the physical schema of the file. + pub expr_adapter: Option>, } /// A builder for [`FileScanConfig`]'s. @@ -265,6 +269,7 @@ pub struct FileScanConfigBuilder { file_compression_type: Option, new_lines_in_values: Option, batch_size: Option, + expr_adapter: Option>, } impl FileScanConfigBuilder { @@ -293,6 +298,7 @@ impl FileScanConfigBuilder { table_partition_cols: vec![], constraints: None, batch_size: None, + expr_adapter: None, } } @@ -401,6 +407,20 @@ impl FileScanConfigBuilder { self } + /// Register an expression adapter used to adapt filters and projections that are pushed down into the scan + /// from the logical schema to the physical schema of the file. + /// This can include things like: + /// - Column ordering changes + /// - Handling of missing columns + /// - Rewriting expression to use pre-computed values or file format specific optimizations + pub fn with_expr_adapter( + mut self, + expr_adapter: Arc, + ) -> Self { + self.expr_adapter = Some(expr_adapter); + self + } + /// Build the final [`FileScanConfig`] with all the configured settings. /// /// This method takes ownership of the builder and returns the constructed `FileScanConfig`. @@ -420,6 +440,7 @@ impl FileScanConfigBuilder { file_compression_type, new_lines_in_values, batch_size, + expr_adapter, } = self; let constraints = constraints.unwrap_or_default(); @@ -446,6 +467,7 @@ impl FileScanConfigBuilder { file_compression_type, new_lines_in_values, batch_size, + expr_adapter, } } } @@ -466,6 +488,7 @@ impl From for FileScanConfigBuilder { table_partition_cols: config.table_partition_cols, constraints: Some(config.constraints), batch_size: config.batch_size, + expr_adapter: config.expr_adapter, } } } @@ -679,6 +702,7 @@ impl FileScanConfig { new_lines_in_values: false, file_source: Arc::clone(&file_source), batch_size: None, + expr_adapter: None, } } From fc7e01032fe975892d05426abc52cf678023fc66 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 10 Jul 2025 21:39:30 -0500 Subject: [PATCH 16/21] fix --- datafusion/datasource-parquet/src/opener.rs | 2 +- datafusion/datasource-parquet/src/source.rs | 2 +- datafusion/physical-expr/src/schema_rewriter.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index 5389c966ad7db..fa61e350520ac 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -135,7 +135,7 @@ impl FileOpener for ParquetOpener { let predicate_creation_errors = MetricBuilder::new(&self.metrics) .global_counter("num_predicate_creation_errors"); - let expr_adapter = self.expr_adapter.clone(); + let expr_adapter = Arc::clone(&self.expr_adapter); let mut enable_page_index = self.enable_page_index; let file_decryption_properties = self.file_decryption_properties.clone(); diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index d396b27020cfa..84b1b1c1ca077 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -513,7 +513,7 @@ impl FileSource for ParquetSource { expr_adapter: base_config .expr_adapter .clone() - .unwrap_or(Arc::new(DefaultPhysicalExprAdapter)), + .unwrap_or_else(|| Arc::new(DefaultPhysicalExprAdapter)), }) } diff --git 
a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs index 4fa8550190bd0..cc4ad9125b259 100644 --- a/datafusion/physical-expr/src/schema_rewriter.rs +++ b/datafusion/physical-expr/src/schema_rewriter.rs @@ -84,7 +84,7 @@ pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug { /// - `logical_file_schema`: The logical schema of the table being queried, excluding any partition columns. /// - `physical_file_schema`: The physical schema of the file being scanned. /// - `partition_values`: Optional partition values to use for rewriting partition column references. - /// These are handled as if they were columns appended onto the logical file schema. + /// These are handled as if they were columns appended onto the logical file schema. /// /// Returns: /// - `Arc`: The rewritten physical expression that can be evaluated against the physical schema. From 8a5cf4f1fadb3c6f2288e9d44fef04207da936f6 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 15 Jul 2025 09:28:42 -0500 Subject: [PATCH 17/21] use a factory --- datafusion/datasource-parquet/src/opener.rs | 44 ++-- datafusion/datasource-parquet/src/source.rs | 4 +- datafusion/datasource/src/file_scan_config.rs | 8 +- .../physical-expr/src/schema_rewriter.rs | 229 +++++++++++------- 4 files changed, 168 insertions(+), 117 deletions(-) diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index fa61e350520ac..e7cb1e061e8a6 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -33,7 +33,7 @@ use arrow::datatypes::{FieldRef, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_datasource::PartitionedFile; -use datafusion_physical_expr::schema_rewriter::PhysicalExprAdapter; +use datafusion_physical_expr::schema_rewriter::PhysicalExprAdapterFactory; use datafusion_physical_expr::simplifier::PhysicalExprSimplifier; use datafusion_physical_expr_common::physical_expr::{ is_dynamic_physical_expr, PhysicalExpr, @@ -94,7 +94,7 @@ pub(super) struct ParquetOpener { /// Optional parquet FileDecryptionProperties pub file_decryption_properties: Option>, /// Rewrite expressions in the context of the file schema - pub expr_adapter: Arc, + pub expr_adapter: Arc, } impl FileOpener for ParquetOpener { @@ -247,22 +247,20 @@ impl FileOpener for ParquetOpener { .cloned() .zip(file.partition_values) .collect_vec(); - expr_adapter - .rewrite_to_file_schema( - p, - &logical_file_schema, - &physical_file_schema, - &partition_values, + let adapter = expr_adapter + .create( + Arc::clone(&logical_file_schema), + Arc::clone(&physical_file_schema), ) - .map_err(ArrowError::from) - .map(|p| { - // After rewriting to the file schema, further simplifications may be possible. - // For example, if `'a' = col_that_is_missing` becomes `'a' = NULL` that can then be simplified to `FALSE` - // and we can avoid doing any more work on the file (bloom filters, loading the page index, etc.). - PhysicalExprSimplifier::new(&physical_file_schema) - .simplify(p) - .map_err(ArrowError::from) - }) + .with_partition_values(partition_values); + adapter.rewrite(p).map_err(ArrowError::from).map(|p| { + // After rewriting to the file schema, further simplifications may be possible. 
+ // For example, if `'a' = col_that_is_missing` becomes `'a' = NULL` that can then be simplified to `FALSE` + // and we can avoid doing any more work on the file (bloom filters, loading the page index, etc.). + PhysicalExprSimplifier::new(&physical_file_schema) + .simplify(p) + .map_err(ArrowError::from) + }) }) .transpose()? .transpose()?; @@ -534,7 +532,7 @@ mod test { use datafusion_expr::{col, lit}; use datafusion_physical_expr::{ expressions::DynamicFilterPhysicalExpr, planner::logical2physical, - DefaultPhysicalExprAdapter, PhysicalExpr, + schema_rewriter::DefaultPhysicalExprAdapterFactory, PhysicalExpr, }; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use futures::{Stream, StreamExt}; @@ -640,7 +638,7 @@ mod test { enable_row_group_stats_pruning: true, coerce_int96: None, file_decryption_properties: None, - expr_adapter: Arc::new(DefaultPhysicalExprAdapter), + expr_adapter: Arc::new(DefaultPhysicalExprAdapterFactory), } }; @@ -726,7 +724,7 @@ mod test { enable_row_group_stats_pruning: true, coerce_int96: None, file_decryption_properties: None, - expr_adapter: Arc::new(DefaultPhysicalExprAdapter), + expr_adapter: Arc::new(DefaultPhysicalExprAdapterFactory), } }; @@ -828,7 +826,7 @@ mod test { enable_row_group_stats_pruning: true, coerce_int96: None, file_decryption_properties: None, - expr_adapter: Arc::new(DefaultPhysicalExprAdapter), + expr_adapter: Arc::new(DefaultPhysicalExprAdapterFactory), } }; let make_meta = || FileMeta { @@ -940,7 +938,7 @@ mod test { enable_row_group_stats_pruning: false, // note that this is false! coerce_int96: None, file_decryption_properties: None, - expr_adapter: Arc::new(DefaultPhysicalExprAdapter), + expr_adapter: Arc::new(DefaultPhysicalExprAdapterFactory), } }; @@ -1053,7 +1051,7 @@ mod test { enable_row_group_stats_pruning: true, coerce_int96: None, file_decryption_properties: None, - expr_adapter: Arc::new(DefaultPhysicalExprAdapter), + expr_adapter: Arc::new(DefaultPhysicalExprAdapterFactory), } }; diff --git a/datafusion/datasource-parquet/src/source.rs b/datafusion/datasource-parquet/src/source.rs index 84b1b1c1ca077..67a27b5b401ab 100644 --- a/datafusion/datasource-parquet/src/source.rs +++ b/datafusion/datasource-parquet/src/source.rs @@ -39,7 +39,7 @@ use datafusion_common::{DataFusionError, Statistics}; use datafusion_datasource::file::FileSource; use datafusion_datasource::file_scan_config::FileScanConfig; use datafusion_physical_expr::conjunction; -use datafusion_physical_expr::DefaultPhysicalExprAdapter; +use datafusion_physical_expr::schema_rewriter::DefaultPhysicalExprAdapterFactory; use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::filter_pushdown::{ @@ -513,7 +513,7 @@ impl FileSource for ParquetSource { expr_adapter: base_config .expr_adapter .clone() - .unwrap_or_else(|| Arc::new(DefaultPhysicalExprAdapter)), + .unwrap_or_else(|| Arc::new(DefaultPhysicalExprAdapterFactory)), }) } diff --git a/datafusion/datasource/src/file_scan_config.rs b/datafusion/datasource/src/file_scan_config.rs index 6c64def579c99..7e13f84ce0a52 100644 --- a/datafusion/datasource/src/file_scan_config.rs +++ b/datafusion/datasource/src/file_scan_config.rs @@ -53,7 +53,7 @@ use datafusion_execution::{ object_store::ObjectStoreUrl, SendableRecordBatchStream, TaskContext, }; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::schema_rewriter::PhysicalExprAdapter; +use 
datafusion_physical_expr::schema_rewriter::PhysicalExprAdapterFactory; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; @@ -191,7 +191,7 @@ pub struct FileScanConfig { pub batch_size: Option, /// Expression adapter used to adapt filters and projections that are pushed down into the scan /// from the logical schema to the physical schema of the file. - pub expr_adapter: Option>, + pub expr_adapter: Option>, } /// A builder for [`FileScanConfig`]'s. @@ -269,7 +269,7 @@ pub struct FileScanConfigBuilder { file_compression_type: Option, new_lines_in_values: Option, batch_size: Option, - expr_adapter: Option>, + expr_adapter: Option>, } impl FileScanConfigBuilder { @@ -415,7 +415,7 @@ impl FileScanConfigBuilder { /// - Rewriting expression to use pre-computed values or file format specific optimizations pub fn with_expr_adapter( mut self, - expr_adapter: Arc, + expr_adapter: Arc, ) -> Self { self.expr_adapter = Some(expr_adapter); self diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs index cc4ad9125b259..ca3e4ffa20c48 100644 --- a/datafusion/physical-expr/src/schema_rewriter.rs +++ b/datafusion/physical-expr/src/schema_rewriter.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::compute::can_cast_types; -use arrow::datatypes::{FieldRef, Schema}; +use arrow::datatypes::{FieldRef, Schema, SchemaRef}; use datafusion_common::{ exec_err, tree_node::{Transformed, TransformedResult, TreeNode}, @@ -40,28 +40,25 @@ use crate::expressions::{self, CastExpr, Column}; /// For example, to fill in missing columns with default values instead of nulls: /// /// ```rust -/// use datafusion_physical_expr::schema_rewriter::PhysicalExprAdapter; -/// use arrow::datatypes::{Schema, Field, DataType, FieldRef}; +/// use datafusion_physical_expr::schema_rewriter::{PhysicalExprAdapter, PhysicalExprAdapterFactory}; +/// use arrow::datatypes::{Schema, Field, DataType, FieldRef, SchemaRef}; /// use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -/// use datafusion_common::{Result, ScalarValue, tree_node::{Transformed, TreeNode}}; +/// use datafusion_common::{Result, ScalarValue, tree_node::{Transformed, TransformedResult, TreeNode}}; /// use datafusion_physical_expr::expressions::{self, Column}; /// use std::sync::Arc; /// /// #[derive(Debug)] -/// pub struct CustomPhysicalExprAdapter; +/// pub struct CustomPhysicalExprAdapter { +/// logical_file_schema: SchemaRef, +/// physical_file_schema: SchemaRef, +/// } /// /// impl PhysicalExprAdapter for CustomPhysicalExprAdapter { -/// fn rewrite( -/// &self, -/// expr: Arc, -/// logical_file_schema: &Schema, -/// physical_file_schema: &Schema, -/// partition_values: &[(FieldRef, ScalarValue)], -/// ) -> Result> { +/// fn rewrite(&self, expr: Arc) -> Result> { /// expr.transform(|expr| { /// if let Some(column) = expr.as_any().downcast_ref::() { /// // Check if the column exists in the physical schema -/// if physical_file_schema.index_of(column.name()).is_err() { +/// if self.physical_file_schema.index_of(column.name()).is_err() { /// // If the column is missing, fill it with a default value instead of null /// // The default value could be stored in the table schema's column metadata for example. 
/// let default_value = ScalarValue::Int32(Some(0)); @@ -70,6 +67,33 @@ use crate::expressions::{self, CastExpr, Column}; /// } /// // If the column exists, return it as is /// Ok(Transformed::no(expr)) +/// }).data() +/// } +/// +/// fn with_partition_values( +/// &self, +/// partition_values: Vec<(FieldRef, ScalarValue)>, +/// ) -> Arc { +/// // For simplicity, this example ignores partition values +/// Arc::new(CustomPhysicalExprAdapter { +/// logical_file_schema: self.logical_file_schema.clone(), +/// physical_file_schema: self.physical_file_schema.clone(), +/// }) +/// } +/// } +/// +/// #[derive(Debug)] +/// pub struct CustomPhysicalExprAdapterFactory; +/// +/// impl PhysicalExprAdapterFactory for CustomPhysicalExprAdapterFactory { +/// fn create( +/// &self, +/// logical_file_schema: SchemaRef, +/// physical_file_schema: SchemaRef, +/// ) -> Arc { +/// Arc::new(CustomPhysicalExprAdapter { +/// logical_file_schema, +/// physical_file_schema, /// }) /// } /// } @@ -88,58 +112,111 @@ pub trait PhysicalExprAdapter: Send + Sync + std::fmt::Debug { /// /// Returns: /// - `Arc`: The rewritten physical expression that can be evaluated against the physical schema. - fn rewrite_to_file_schema( + fn rewrite(&self, expr: Arc) -> Result>; + + fn with_partition_values( &self, - expr: Arc, - logical_file_schema: &Schema, - physical_file_schema: &Schema, - partition_values: &[(FieldRef, ScalarValue)], - ) -> Result>; + partition_values: Vec<(FieldRef, ScalarValue)>, + ) -> Arc; +} + +pub trait PhysicalExprAdapterFactory: Send + Sync + std::fmt::Debug { + /// Create a new instance of the physical expression adapter. + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc; } -/// Builder for rewriting physical expressions to match different schemas. +#[derive(Debug, Clone)] +pub struct DefaultPhysicalExprAdapterFactory; + +impl PhysicalExprAdapterFactory for DefaultPhysicalExprAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc { + Arc::new(DefaultPhysicalExprAdapter { + logical_file_schema, + physical_file_schema, + partition_values: Vec::new(), + }) + } +} + +/// Default implementation for rewriting physical expressions to match different schemas. 
/// /// # Example /// /// ```rust -/// use datafusion_physical_expr::schema_rewriter::DefaultPhysicalExprAdapter; +/// use datafusion_physical_expr::schema_rewriter::{DefaultPhysicalExprAdapterFactory, PhysicalExprAdapterFactory}; /// use arrow::datatypes::Schema; +/// use std::sync::Arc; /// /// # fn example( /// # predicate: std::sync::Arc, /// # physical_file_schema: &Schema, /// # logical_file_schema: &Schema, /// # ) -> datafusion_common::Result<()> { -/// let rewriter = DefaultPhysicalExprAdapter; -/// let adapted_predicate = rewriter.rewrite(predicate, logical_file_schema, physical_file_schema, &[])?; +/// let factory = DefaultPhysicalExprAdapterFactory; +/// let adapter = factory.create(Arc::new(logical_file_schema.clone()), Arc::new(physical_file_schema.clone())); +/// let adapted_predicate = adapter.rewrite(predicate)?; /// # Ok(()) /// # } /// ``` #[derive(Debug, Clone)] -pub struct DefaultPhysicalExprAdapter; +pub struct DefaultPhysicalExprAdapter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + partition_values: Vec<(FieldRef, ScalarValue)>, +} + +// impl PhysicalExprAdapter for DefaultPhysicalExprAdapter { +// /// Rewrite the given physical expression to match the target schema +// /// +// /// This method applies the following transformations: +// /// 1. Replaces partition column references with literal values +// /// 2. Handles missing columns by inserting null literals +// /// 3. Casts columns when logical and physical schemas have different types +// fn rewrite_to_file_schema( +// &self, +// expr: Arc, +// logical_file_schema: &Schema, +// physical_file_schema: &Schema, +// partition_values: &[(FieldRef, ScalarValue)], +// ) -> Result> { +// let rewriter = DefaultPhysicalExprAdapterRewriter { +// logical_file_schema, +// physical_file_schema, +// partition_fields: partition_values, +// }; +// expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr))) +// .data() +// } +// } impl PhysicalExprAdapter for DefaultPhysicalExprAdapter { - /// Rewrite the given physical expression to match the target schema - /// - /// This method applies the following transformations: - /// 1. Replaces partition column references with literal values - /// 2. Handles missing columns by inserting null literals - /// 3. Casts columns when logical and physical schemas have different types - fn rewrite_to_file_schema( - &self, - expr: Arc, - logical_file_schema: &Schema, - physical_file_schema: &Schema, - partition_values: &[(FieldRef, ScalarValue)], - ) -> Result> { + fn rewrite(&self, expr: Arc) -> Result> { let rewriter = DefaultPhysicalExprAdapterRewriter { - logical_file_schema, - physical_file_schema, - partition_fields: partition_values, + logical_file_schema: &self.logical_file_schema, + physical_file_schema: &self.physical_file_schema, + partition_fields: &self.partition_values, }; expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr))) .data() } + + fn with_partition_values( + &self, + partition_values: Vec<(FieldRef, ScalarValue)>, + ) -> Arc { + Arc::new(DefaultPhysicalExprAdapter { + partition_values, + ..self.clone() + }) + } } struct DefaultPhysicalExprAdapterRewriter<'a> { @@ -185,7 +262,7 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { // we'll at least handle the casts for them. physical_field } else { - // A completely unknwon column that doesn't exist in either schema! + // A completely unknown column that doesn't exist in either schema! 
// This should probably never be hit unless something upstream broke, but nontheless it's better // for us to return a handleable error than to panic / do something unexpected. return Err(e.into()); @@ -296,12 +373,11 @@ mod tests { fn test_rewrite_column_with_type_cast() { let (physical_schema, logical_schema) = create_test_schema(); - let rewriter = DefaultPhysicalExprAdapter; + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); let column_expr = Arc::new(Column::new("a", 0)); - let result = rewriter - .rewrite_to_file_schema(column_expr, &logical_schema, &physical_schema, &[]) - .unwrap(); + let result = adapter.rewrite(column_expr).unwrap(); // Should be wrapped in a cast expression assert!(result.as_any().downcast_ref::().is_some()); @@ -310,7 +386,8 @@ mod tests { #[test] fn test_rewrite_mulit_column_expr_with_type_cast() { let (physical_schema, logical_schema) = create_test_schema(); - let rewriter = DefaultPhysicalExprAdapter; + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); // Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter let column_a = Arc::new(Column::new("a", 0)) as Arc; @@ -330,14 +407,7 @@ mod tests { )), ); - let result = rewriter - .rewrite_to_file_schema( - Arc::new(expr), - &logical_schema, - &physical_schema, - &[], - ) - .unwrap(); + let result = adapter.rewrite(Arc::new(expr)).unwrap(); println!("Rewritten expression: {result}"); let expected = expressions::BinaryExpr::new( @@ -370,15 +440,11 @@ mod tests { fn test_rewrite_missing_column() -> Result<()> { let (physical_schema, logical_schema) = create_test_schema(); - let rewriter = DefaultPhysicalExprAdapter; + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); let column_expr = Arc::new(Column::new("c", 2)); - let result = rewriter.rewrite_to_file_schema( - column_expr, - &logical_schema, - &physical_schema, - &[], - )?; + let result = adapter.rewrite(column_expr)?; // Should be replaced with a literal null if let Some(literal) = result.as_any().downcast_ref::() { @@ -399,15 +465,12 @@ mod tests { let partition_value = ScalarValue::Utf8(Some("test_value".to_string())); let partition_values = vec![(partition_field, partition_value)]; - let rewriter = DefaultPhysicalExprAdapter; + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); + let adapter = adapter.with_partition_values(partition_values); let column_expr = Arc::new(Column::new("partition_col", 0)); - let result = rewriter.rewrite_to_file_schema( - column_expr, - &logical_schema, - &physical_schema, - &partition_values, - )?; + let result = adapter.rewrite(column_expr)?; // Should be replaced with the partition value if let Some(literal) = result.as_any().downcast_ref::() { @@ -426,15 +489,11 @@ mod tests { fn test_rewrite_no_change_needed() -> Result<()> { let (physical_schema, logical_schema) = create_test_schema(); - let rewriter = DefaultPhysicalExprAdapter; + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); let column_expr = Arc::new(Column::new("b", 1)) as Arc; - let result = rewriter.rewrite_to_file_schema( - Arc::clone(&column_expr), - &logical_schema, - &physical_schema, - &[], - )?; + let 
result = adapter.rewrite(Arc::clone(&column_expr))?; // Should be the same expression (no transformation needed) // We compare the underlying pointer through the trait object @@ -454,15 +513,11 @@ mod tests { Field::new("b", DataType::Utf8, false), // Non-nullable missing column ]); - let rewriter = DefaultPhysicalExprAdapter; + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory.create(Arc::new(logical_schema), Arc::new(physical_schema)); let column_expr = Arc::new(Column::new("b", 1)); - let result = rewriter.rewrite_to_file_schema( - column_expr, - &logical_schema, - &physical_schema, - &[], - ); + let result = adapter.rewrite(column_expr); assert!(result.is_err()); assert!(result .unwrap_err() @@ -516,15 +571,13 @@ mod tests { col("a", &logical_schema).unwrap(), ]; - let rewriter = DefaultPhysicalExprAdapter; + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = + factory.create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)); let adapted_projection = projection .into_iter() - .map(|expr| { - rewriter - .rewrite_to_file_schema(expr, &logical_schema, &physical_schema, &[]) - .unwrap() - }) + .map(|expr| adapter.rewrite(expr).unwrap()) .collect_vec(); let adapted_schema = Arc::new(Schema::new( From 5dd47de87455f746e20bf08267c1dd8db0e9faf9 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 15 Jul 2025 09:34:36 -0500 Subject: [PATCH 18/21] Update tests and examples to use new factory-style PhysicalExprAdapter API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Update all tests in schema_rewriter.rs to use DefaultPhysicalExprAdapterFactory - Update documentation examples to demonstrate factory pattern - Update default_column_values.rs example to use factory-style API - Convert from rewrite_to_file_schema method to rewrite method with factory pattern - Add proper partition values handling with with_partition_values method 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../examples/default_column_values.rs | 68 +++++++++++++------ 1 file changed, 49 insertions(+), 19 deletions(-) diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/default_column_values.rs index 836be14c50524..21181a540fce4 100644 --- a/datafusion-examples/examples/default_column_values.rs +++ b/datafusion-examples/examples/default_column_values.rs @@ -39,7 +39,7 @@ use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::expressions::{CastExpr, Column, Literal}; use datafusion::physical_expr::schema_rewriter::{ - DefaultPhysicalExprAdapter, PhysicalExprAdapter, + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, }; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ExecutionPlan; @@ -263,9 +263,7 @@ impl TableProvider for DefaultValueTableProvider { .with_projection(projection.cloned()) .with_limit(limit) .with_file_group(file_group) - .with_expr_adapter(Arc::new(DefaultValuePhysicalExprAdapter { - default_adapter: DefaultPhysicalExprAdapter, - }) as _); + .with_expr_adapter(Arc::new(DefaultValuePhysicalExprAdapterFactory) as _); Ok(Arc::new(DataSourceExec::new(Arc::new( file_scan_config.build(), @@ -273,40 +271,72 @@ impl TableProvider for DefaultValueTableProvider { } } +/// Factory for creating DefaultValuePhysicalExprAdapter instances +#[derive(Debug)] +struct 
DefaultValuePhysicalExprAdapterFactory; + +impl PhysicalExprAdapterFactory for DefaultValuePhysicalExprAdapterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc { + let default_factory = DefaultPhysicalExprAdapterFactory; + let default_adapter = default_factory.create(logical_file_schema.clone(), physical_file_schema.clone()); + + Arc::new(DefaultValuePhysicalExprAdapter { + logical_file_schema, + physical_file_schema, + default_adapter, + partition_values: Vec::new(), + }) + } +} + /// Custom PhysicalExprAdapter that handles missing columns with default values from metadata /// and wraps DefaultPhysicalExprAdapter for standard schema adaptation #[derive(Debug)] struct DefaultValuePhysicalExprAdapter { - default_adapter: DefaultPhysicalExprAdapter, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + default_adapter: Arc, + partition_values: Vec<(FieldRef, ScalarValue)>, } impl PhysicalExprAdapter for DefaultValuePhysicalExprAdapter { - fn rewrite_to_file_schema( - &self, - expr: Arc, - logical_file_schema: &Schema, - physical_file_schema: &Schema, - partition_values: &[(FieldRef, ScalarValue)], - ) -> Result> { + fn rewrite(&self, expr: Arc) -> Result> { // First try our custom default value injection for missing columns let rewritten = expr .transform(|expr| { self.inject_default_values( expr, - logical_file_schema, - physical_file_schema, + &self.logical_file_schema, + &self.physical_file_schema, ) }) .data()?; // Then apply the default adapter as a fallback to handle standard schema differences // like type casting, partition column handling, etc. - self.default_adapter.rewrite_to_file_schema( - rewritten, - logical_file_schema, - physical_file_schema, + let default_adapter = if !self.partition_values.is_empty() { + self.default_adapter.with_partition_values(self.partition_values.clone()) + } else { + self.default_adapter.clone() + }; + + default_adapter.rewrite(rewritten) + } + + fn with_partition_values( + &self, + partition_values: Vec<(FieldRef, ScalarValue)>, + ) -> Arc { + Arc::new(DefaultValuePhysicalExprAdapter { + logical_file_schema: self.logical_file_schema.clone(), + physical_file_schema: self.physical_file_schema.clone(), + default_adapter: self.default_adapter.clone(), partition_values, - ) + }) } } From 72ad6a416ec5d44831864629e447deb024989392 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 15 Jul 2025 09:39:02 -0500 Subject: [PATCH 19/21] fix example --- .../examples/default_column_values.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/default_column_values.rs index 21181a540fce4..b0270ffd8a3c0 100644 --- a/datafusion-examples/examples/default_column_values.rs +++ b/datafusion-examples/examples/default_column_values.rs @@ -282,8 +282,9 @@ impl PhysicalExprAdapterFactory for DefaultValuePhysicalExprAdapterFactory { physical_file_schema: SchemaRef, ) -> Arc { let default_factory = DefaultPhysicalExprAdapterFactory; - let default_adapter = default_factory.create(logical_file_schema.clone(), physical_file_schema.clone()); - + let default_adapter = default_factory + .create(logical_file_schema.clone(), physical_file_schema.clone()); + Arc::new(DefaultValuePhysicalExprAdapter { logical_file_schema, physical_file_schema, @@ -319,14 +320,15 @@ impl PhysicalExprAdapter for DefaultValuePhysicalExprAdapter { // Then apply the 
default adapter as a fallback to handle standard schema differences // like type casting, partition column handling, etc. let default_adapter = if !self.partition_values.is_empty() { - self.default_adapter.with_partition_values(self.partition_values.clone()) + self.default_adapter + .with_partition_values(self.partition_values.clone()) } else { self.default_adapter.clone() }; - + default_adapter.rewrite(rewritten) } - + fn with_partition_values( &self, partition_values: Vec<(FieldRef, ScalarValue)>, From 3ce0a0f65aafc86ef8a7d2a71bbaca2e1d09a1e8 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 15 Jul 2025 09:59:34 -0500 Subject: [PATCH 20/21] Fix json_shredding.rs --- .../examples/json_shredding.rs | 66 ++++++++++++++----- 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs index e55693effb098..c273422d11cf5 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/json_shredding.rs @@ -41,7 +41,7 @@ use datafusion::logical_expr::{ use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::schema_rewriter::{ - DefaultPhysicalExprAdapter, PhysicalExprAdapter, + DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, }; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; @@ -273,9 +273,7 @@ impl TableProvider for ExampleTableProvider { .with_limit(limit) .with_file_group(file_group) // if the rewriter needs a reference to the table schema you can bind self.schema() here - .with_expr_adapter(Arc::new(ShreddedJsonRewriter { - default_adapter: DefaultPhysicalExprAdapter, - }) as _); + .with_expr_adapter(Arc::new(ShreddedJsonRewriterFactory) as _); Ok(Arc::new(DataSourceExec::new(Arc::new( file_scan_config.build(), @@ -363,34 +361,66 @@ impl ScalarUDFImpl for JsonGetStr { } } +/// Factory for creating ShreddedJsonRewriter instances +#[derive(Debug)] +struct ShreddedJsonRewriterFactory; + +impl PhysicalExprAdapterFactory for ShreddedJsonRewriterFactory { + fn create( + &self, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + ) -> Arc { + let default_factory = DefaultPhysicalExprAdapterFactory; + let default_adapter = default_factory.create(logical_file_schema.clone(), physical_file_schema.clone()); + + Arc::new(ShreddedJsonRewriter { + logical_file_schema, + physical_file_schema, + default_adapter, + partition_values: Vec::new(), + }) + } +} + /// Rewriter that converts json_get_str calls to direct flat column references /// and wraps DefaultPhysicalExprAdapter for standard schema adaptation #[derive(Debug)] struct ShreddedJsonRewriter { - default_adapter: DefaultPhysicalExprAdapter, + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, + default_adapter: Arc, + partition_values: Vec<(FieldRef, ScalarValue)>, } impl PhysicalExprAdapter for ShreddedJsonRewriter { - fn rewrite_to_file_schema( - &self, - expr: Arc, - logical_file_schema: &Schema, - physical_file_schema: &Schema, - partition_values: &[(FieldRef, ScalarValue)], - ) -> Result> { + fn rewrite(&self, expr: Arc) -> Result> { // First try our custom JSON shredding rewrite let rewritten = expr - .transform(|expr| self.rewrite_impl(expr, physical_file_schema)) + .transform(|expr| self.rewrite_impl(expr, &self.physical_file_schema)) .data()?; // 
Then apply the default adapter as a fallback to handle standard schema differences // like type casting, missing columns, and partition column handling - self.default_adapter.rewrite_to_file_schema( - rewritten, - logical_file_schema, - physical_file_schema, + let default_adapter = if !self.partition_values.is_empty() { + self.default_adapter.with_partition_values(self.partition_values.clone()) + } else { + self.default_adapter.clone() + }; + + default_adapter.rewrite(rewritten) + } + + fn with_partition_values( + &self, + partition_values: Vec<(FieldRef, ScalarValue)>, + ) -> Arc { + Arc::new(ShreddedJsonRewriter { + logical_file_schema: self.logical_file_schema.clone(), + physical_file_schema: self.physical_file_schema.clone(), + default_adapter: self.default_adapter.clone(), partition_values, - ) + }) } } From 3df40419a23e805d11171ef8af4ce050cbccceff Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 15 Jul 2025 09:59:40 -0500 Subject: [PATCH 21/21] fmt --- datafusion-examples/examples/json_shredding.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs index c273422d11cf5..ba9158f6913e9 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/json_shredding.rs @@ -372,8 +372,9 @@ impl PhysicalExprAdapterFactory for ShreddedJsonRewriterFactory { physical_file_schema: SchemaRef, ) -> Arc { let default_factory = DefaultPhysicalExprAdapterFactory; - let default_adapter = default_factory.create(logical_file_schema.clone(), physical_file_schema.clone()); - + let default_adapter = default_factory + .create(logical_file_schema.clone(), physical_file_schema.clone()); + Arc::new(ShreddedJsonRewriter { logical_file_schema, physical_file_schema, @@ -403,14 +404,15 @@ impl PhysicalExprAdapter for ShreddedJsonRewriter { // Then apply the default adapter as a fallback to handle standard schema differences // like type casting, missing columns, and partition column handling let default_adapter = if !self.partition_values.is_empty() { - self.default_adapter.with_partition_values(self.partition_values.clone()) + self.default_adapter + .with_partition_values(self.partition_values.clone()) } else { self.default_adapter.clone() }; - + default_adapter.rewrite(rewritten) } - + fn with_partition_values( &self, partition_values: Vec<(FieldRef, ScalarValue)>,
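For anyone reviewing the API change end to end, the short sketch below pulls together the call sequence that the updated tests and examples in this series rely on: a factory is bound once to the logical and physical file schemas, and the resulting adapter then rewrites each expression. The trait names, import paths, and the `create`/`rewrite` calls are the ones shown in the diffs above; the concrete schemas and the column `b` are illustrative only, so treat this as a minimal sketch under those assumptions rather than a drop-in snippet.

use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::schema_rewriter::{
    DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
};

fn main() -> Result<()> {
    // Table ("logical") schema as the query sees it; `b` is nullable.
    let logical_schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Utf8, true),
    ]));
    // Physical file schema: this particular file only contains column `a`.
    let physical_schema =
        Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    // Bind both schemas once, then reuse the adapter for every expression
    // that must be evaluated against this file.
    let factory = DefaultPhysicalExprAdapterFactory;
    let adapter = factory.create(logical_schema, physical_schema);

    // `b` is missing from the file; the default adapter rewrites the
    // reference so the expression can still be evaluated against the file
    // schema (a missing non-nullable column errors instead, as the tests
    // above assert).
    let column_expr = Arc::new(Column::new("b", 1));
    let _rewritten = adapter.rewrite(column_expr)?;
    Ok(())
}

Custom adapters such as DefaultValuePhysicalExprAdapterFactory and ShreddedJsonRewriterFactory in the examples above slot into the same two-step flow: the factory captures the per-file schemas, and with_partition_values is used to thread partition values through before delegating to the default adapter.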