Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 48 additions & 41 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! Logical plan types

use std::cmp::Ordering;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, LazyLock};
Expand Down Expand Up @@ -2679,24 +2679,16 @@ impl Union {
Ok(Union { inputs, schema })
}

/// When constructing a `UNION BY NAME`, we may need to wrap inputs
/// When constructing a `UNION BY NAME`, we need to wrap inputs
/// in an additional `Projection` to account for absence of columns
/// in input schemas.
/// in input schemas or differing projection orders.
fn rewrite_inputs_from_schema(
schema: &DFSchema,
schema: &Arc<DFSchema>,
inputs: Vec<Arc<LogicalPlan>>,
) -> Result<Vec<Arc<LogicalPlan>>> {
let schema_width = schema.iter().count();
let mut wrapped_inputs = Vec::with_capacity(inputs.len());
for input in inputs {
// If the input plan's schema contains the same number of fields
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does indeed seem to be an overzealous optimization

// as the derived schema, then it does not to be wrapped in an
// additional `Projection`.
if input.schema().iter().count() == schema_width {
wrapped_inputs.push(input);
continue;
}

// Any columns that exist within the derived schema but do not exist
// within an input's schema should be replaced with `NULL` aliased
// to the appropriate column in the derived schema.
Expand All @@ -2711,9 +2703,9 @@ impl Union {
expr.push(Expr::Literal(ScalarValue::Null).alias(column.name()));
}
}
wrapped_inputs.push(Arc::new(LogicalPlan::Projection(Projection::try_new(
expr, input,
)?)));
wrapped_inputs.push(Arc::new(LogicalPlan::Projection(
Projection::try_new_with_schema(expr, input, Arc::clone(schema))?,
)));
}

Ok(wrapped_inputs)
Expand Down Expand Up @@ -2747,45 +2739,60 @@ impl Union {
inputs: &[Arc<LogicalPlan>],
loose_types: bool,
) -> Result<DFSchemaRef> {
type FieldData<'a> = (&'a DataType, bool, Vec<&'a HashMap<String, String>>);
// Prefer `BTreeMap` as it produces items in order by key when iterated over
let mut cols: BTreeMap<&str, FieldData> = BTreeMap::new();
type FieldData<'a> =
(&'a DataType, bool, Vec<&'a HashMap<String, String>>, usize);
let mut cols: Vec<(&str, FieldData)> = Vec::new();
for input in inputs.iter() {
for field in input.schema().fields() {
match cols.entry(field.name()) {
std::collections::btree_map::Entry::Occupied(mut occupied) => {
let (data_type, is_nullable, metadata) = occupied.get_mut();
if !loose_types && *data_type != field.data_type() {
return plan_err!(
"Found different types for field {}",
field.name()
);
}

metadata.push(field.metadata());
// If the field is nullable in any one of the inputs,
// then the field in the final schema is also nullable.
*is_nullable |= field.is_nullable();
if let Some((_, (data_type, is_nullable, metadata, occurrences))) =
cols.iter_mut().find(|(name, _)| name == field.name())
{
if !loose_types && *data_type != field.data_type() {
return plan_err!(
"Found different types for field {}",
field.name()
);
}
std::collections::btree_map::Entry::Vacant(vacant) => {
vacant.insert((

metadata.push(field.metadata());
// If the field is nullable in any one of the inputs,
// then the field in the final schema is also nullable.
*is_nullable |= field.is_nullable();
*occurrences += 1;
} else {
cols.push((
field.name(),
(
field.data_type(),
field.is_nullable(),
vec![field.metadata()],
));
}
1,
),
));
}
}
}

let union_fields = cols
.into_iter()
.map(|(name, (data_type, is_nullable, unmerged_metadata))| {
let mut field = Field::new(name, data_type.clone(), is_nullable);
field.set_metadata(intersect_maps(unmerged_metadata));
.map(
|(name, (data_type, is_nullable, unmerged_metadata, occurrences))| {
// If the final number of occurrences of the field is less
// than the number of inputs (i.e. the field is missing from
// one or more inputs), then it must be treated as nullable.
let final_is_nullable = if occurrences == inputs.len() {
is_nullable
} else {
true
};

(None, Arc::new(field))
})
let mut field =
Field::new(name, data_type.clone(), final_is_nullable);
field.set_metadata(intersect_maps(unmerged_metadata));

(None, Arc::new(field))
},
)
.collect::<Vec<(Option<TableReference>, _)>>();

let union_schema_metadata =
Expand Down
27 changes: 16 additions & 11 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1898,11 +1898,12 @@ fn union_by_name_different_columns() {
let expected = "\
Distinct:\
\n Union\
\n Projection: NULL AS Int64(1), order_id\
\n Projection: order_id, NULL AS Int64(1)\
\n Projection: orders.order_id\
\n TableScan: orders\
\n Projection: orders.order_id, Int64(1)\
\n TableScan: orders";
\n Projection: order_id, Int64(1)\
\n Projection: orders.order_id, Int64(1)\
\n TableScan: orders";
quick_test(sql, expected);
}

Expand Down Expand Up @@ -1936,22 +1937,26 @@ fn union_all_by_name_different_columns() {
"SELECT order_id from orders UNION ALL BY NAME SELECT order_id, 1 FROM orders";
let expected = "\
Union\
\n Projection: NULL AS Int64(1), order_id\
\n Projection: order_id, NULL AS Int64(1)\
\n Projection: orders.order_id\
\n TableScan: orders\
\n Projection: orders.order_id, Int64(1)\
\n TableScan: orders";
\n Projection: order_id, Int64(1)\
\n Projection: orders.order_id, Int64(1)\
\n TableScan: orders";
quick_test(sql, expected);
}

#[test]
fn union_all_by_name_same_column_names() {
let sql = "SELECT order_id from orders UNION ALL BY NAME SELECT order_id FROM orders";
let expected = "Union\
\n Projection: orders.order_id\
\n TableScan: orders\
\n Projection: orders.order_id\
\n TableScan: orders";
let expected = "\
Union\
\n Projection: order_id\
\n Projection: orders.order_id\
\n TableScan: orders\
\n Projection: order_id\
\n Projection: orders.order_id\
\n TableScan: orders";
quick_test(sql, expected);
}

Expand Down
Loading