diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 2e82656c91a2..fbb4250fc4df 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -62,7 +62,9 @@ use arrow::array::{builder::StringBuilder, RecordBatch}; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, +}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, ScalarValue, @@ -2075,29 +2077,36 @@ fn maybe_fix_physical_column_name( expr: Result>, input_physical_schema: &SchemaRef, ) -> Result> { - if let Ok(e) = &expr { - if let Some(column) = e.as_any().downcast_ref::() { - let physical_field = input_physical_schema.field(column.index()); + let Ok(expr) = expr else { return expr }; + expr.transform_down(|node| { + if let Some(column) = node.as_any().downcast_ref::() { + let idx = column.index(); + let physical_field = input_physical_schema.field(idx); let expr_col_name = column.name(); let physical_name = physical_field.name(); - if physical_name != expr_col_name { + if expr_col_name != physical_name { // handle edge cases where the physical_name contains ':'. 
let colon_count = physical_name.matches(':').count(); let mut splits = expr_col_name.match_indices(':'); let split_pos = splits.nth(colon_count); - if let Some((idx, _)) = split_pos { - let base_name = &expr_col_name[..idx]; + if let Some((i, _)) = split_pos { + let base_name = &expr_col_name[..i]; if base_name == physical_name { - let updated_column = Column::new(physical_name, column.index()); - return Ok(Arc::new(updated_column)); + let updated_column = Column::new(physical_name, idx); + return Ok(Transformed::yes(Arc::new(updated_column))); } } } + + // If names already match or fix is not possible, just leave it as it is + Ok(Transformed::no(node)) + } else { + Ok(Transformed::no(node)) } - } - expr + }) + .data() } struct OptimizationInvariantChecker<'a> { @@ -2203,8 +2212,11 @@ mod tests { }; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; - use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore}; + use datafusion_expr::{ + col, lit, LogicalPlanBuilder, Operator, UserDefinedLogicalNodeCore, + }; use datafusion_functions_aggregate::expr_fn::sum; + use datafusion_physical_expr::expressions::{BinaryExpr, IsNotNullExpr}; use datafusion_physical_expr::EquivalenceProperties; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -2769,6 +2781,47 @@ mod tests { assert_eq!(col.name(), "metric:avg"); } + + #[tokio::test] + async fn test_maybe_fix_nested_column_name_with_colon() { + let schema = Schema::new(vec![Field::new("column", DataType::Int32, false)]); + let schema_ref: SchemaRef = Arc::new(schema); + + // Construct the nested expr + let col_expr = Arc::new(Column::new("column:1", 0)) as Arc; + let is_not_null_expr = Arc::new(IsNotNullExpr::new(col_expr.clone())); + + // Create a binary expression and put the column inside + let binary_expr = Arc::new(BinaryExpr::new( + is_not_null_expr.clone(), + Operator::Or, + is_not_null_expr.clone(), + )) as Arc; + + let fixed_expr = + 
maybe_fix_physical_column_name(Ok(binary_expr), &schema_ref).unwrap(); + + let bin = fixed_expr + .as_any() + .downcast_ref::() + .expect("Expected BinaryExpr"); + + // Check that both sides were renamed + for expr in &[bin.left(), bin.right()] { + let is_not_null = expr + .as_any() + .downcast_ref::() + .expect("Expected IsNotNull"); + + let col = is_not_null + .arg() + .as_any() + .downcast_ref::() + .expect("Expected Column"); + + assert_eq!(col.name(), "column"); + } + } struct ErrorExtensionPlanner {} #[async_trait] diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 9db12a760eaa..930fe793d1d4 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -540,7 +540,12 @@ fn union_schema(inputs: &[Arc]) -> SchemaRef { let fields = (0..first_schema.fields().len()) .map(|i| { - inputs + // We take the name from the left side of the union to match how names are coerced during logical planning, + // which also uses the left side names. + let base_field = first_schema.field(i).clone(); + + // Coerce metadata and nullability across all inputs + let merged_field = inputs .iter() .enumerate() .map(|(input_idx, input)| { @@ -562,6 +567,9 @@ fn union_schema(inputs: &[Arc]) -> SchemaRef { // We can unwrap this because if inputs was empty, this would've already panic'ed when we // indexed into inputs[0]. 
.unwrap() + .with_name(base_field.name()); + + merged_field }) .collect::>(); diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index f32920d53d59..4a121e41d27e 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -560,4 +560,28 @@ mod tests { ); Ok(()) } + + #[tokio::test] + async fn test_multiple_unions() -> Result<()> { + let plan_str = test_plan_to_string("multiple_unions.json").await?; + assert_snapshot!( + plan_str, + @r#" + Projection: Utf8("people") AS product_category, Utf8("people")__temp__0 AS product_type, product_key + Union + Projection: Utf8("people"), Utf8("people") AS Utf8("people")__temp__0, sales.product_key + Left Join: sales.product_key = food.@food_id + TableScan: sales + TableScan: food + Union + Projection: people.$f3, people.$f5, people.product_key0 + Left Join: people.product_key0 = food.@food_id + TableScan: people + TableScan: food + TableScan: more_products + "# + ); + + Ok(()) + } } diff --git a/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json b/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json new file mode 100644 index 000000000000..8b82d6eec755 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/multiple_unions.json @@ -0,0 +1,328 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "equal:any_any" + } + }], + "relations": [{ + "root": { + "input": { + "set": { + "common": { + "direct": { + } + }, + "inputs": [{ + "project": { + "common": { + "emit": { + "outputMapping": [2, 3, 4] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["product_key"], + "struct": { + 
"types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "sales" + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["@food_id"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "food" + ] + } + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "literal": { + "string": "people" + } + }, { + "literal": { + "string": "people" + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }] + } + }, { + "set": { + "common": { + "direct": { + } + }, + "inputs": [{ + "project": { + "common": { + "emit": { + "outputMapping": [4, 5, 6] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["$f3", "$f5", "product_key0"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "people" + ] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["@food_id"], + "struct": { + "types": [{ + 
"string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "food" + ] + } + + } + }, + "expression": { + "scalarFunction": { + "functionReference": 0, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + } + }, { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["$f1000", "$f2000", "more_products_key0000"], + "struct": { + "types": [{ + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "more_products" + ] + } + + } + }], + "op": "SET_OP_UNION_ALL" + } + }], + "op": "SET_OP_UNION_ALL" + } + }, + "names": ["product_category", "product_type", "product_key"] + } + }] +} \ No newline at end of file