-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Fix the schema mismatch between logical and physical for aggregate function, add AggregateUDFImpl::is_null
#11989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
aed01f0
cbfefc6
b3fc2c8
20d0a5f
1132686
611092e
e732adc
ab38a5a
1d299eb
19a1ac7
984ced7
9b75540
6361bc4
794ce12
cb63514
9c12566
a42654c
e45d1bb
83ce363
3519e75
da30827
356faa8
043c332
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -670,6 +670,10 @@ impl DefaultPhysicalPlanner { | |
| let input_exec = children.one()?; | ||
| let physical_input_schema = input_exec.schema(); | ||
| let logical_input_schema = input.as_ref().schema(); | ||
| let physical_input_schema_from_logical: Arc<Schema> = | ||
| logical_input_schema.as_ref().clone().into(); | ||
|
|
||
| debug_assert_eq!(physical_input_schema_from_logical, physical_input_schema, "Physical input schema should be the same as the one converted from logical input schema. Please file an issue or send the PR"); | ||
|
|
||
| let groups = self.create_grouping_physical_expr( | ||
| group_expr, | ||
|
|
@@ -1548,7 +1552,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( | |
| e: &Expr, | ||
| name: Option<String>, | ||
| logical_input_schema: &DFSchema, | ||
| _physical_input_schema: &Schema, | ||
| physical_input_schema: &Schema, | ||
| execution_props: &ExecutionProps, | ||
| ) -> Result<AggregateExprWithOptionalArgs> { | ||
| match e { | ||
|
|
@@ -1599,11 +1603,10 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( | |
| let ordering_reqs: Vec<PhysicalSortExpr> = | ||
| physical_sort_exprs.clone().unwrap_or(vec![]); | ||
|
|
||
| let schema: Schema = logical_input_schema.clone().into(); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
workaround cleanup |
||
| let agg_expr = | ||
| AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec()) | ||
| .order_by(ordering_reqs.to_vec()) | ||
| .schema(Arc::new(schema)) | ||
| .schema(Arc::new(physical_input_schema.to_owned())) | ||
| .alias(name) | ||
| .with_ignore_nulls(ignore_nulls) | ||
| .with_distinct(*distinct) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -328,10 +328,45 @@ impl ExprSchemable for Expr { | |
| Ok(true) | ||
| } | ||
| } | ||
| Expr::WindowFunction(WindowFunction { fun, .. }) => { | ||
|
||
| match fun { | ||
| WindowFunctionDefinition::BuiltInWindowFunction(func) => { | ||
| if func.name() == "ROW_NUMBER" | ||
| || func.name() == "RANK" | ||
| || func.name() == "NTILE" | ||
| || func.name() == "CUME_DIST" | ||
| { | ||
| Ok(false) | ||
| } else { | ||
| Ok(true) | ||
| } | ||
| } | ||
| WindowFunctionDefinition::AggregateUDF(func) => { | ||
| // TODO: UDF should be able to customize nullability | ||
| if func.name() == "count" { | ||
| // TODO: there is issue unsolved for count with window, should return false | ||
|
||
| Ok(true) | ||
| } else { | ||
| Ok(true) | ||
| } | ||
| } | ||
| _ => Ok(true), | ||
| } | ||
| } | ||
| Expr::ScalarFunction(ScalarFunction { func, args }) => { | ||
| // If all the element in coalesce is non-null, the result is non-null | ||
|
||
| if func.name() == "coalesce" | ||
| && args | ||
| .iter() | ||
| .all(|e| !e.nullable(input_schema).ok().unwrap_or(true)) | ||
| { | ||
| return Ok(false); | ||
| } | ||
|
|
||
| Ok(true) | ||
| } | ||
| Expr::ScalarVariable(_, _) | ||
| | Expr::TryCast { .. } | ||
| | Expr::ScalarFunction(..) | ||
| | Expr::WindowFunction { .. } | ||
| | Expr::Unnest(_) | ||
| | Expr::Placeholder(_) => Ok(true), | ||
| Expr::IsNull(_) | ||
|
|
@@ -443,6 +478,7 @@ impl ExprSchemable for Expr { | |
| match self { | ||
| Expr::Column(c) => { | ||
| let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; | ||
| // println!("data type: {:?}, nullable: {:?}", data_type, nullable); | ||
| Ok(( | ||
| c.relation.clone(), | ||
| Field::new(&c.name, data_type, nullable) | ||
|
|
@@ -452,6 +488,7 @@ impl ExprSchemable for Expr { | |
| } | ||
| Expr::Alias(Alias { relation, name, .. }) => { | ||
| let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; | ||
| // println!("alias {} data type: {:?}, nullable: {:?}", self, data_type, nullable); | ||
| Ok(( | ||
| relation.clone(), | ||
| Field::new(name, data_type, nullable) | ||
|
|
@@ -461,6 +498,7 @@ impl ExprSchemable for Expr { | |
| } | ||
| _ => { | ||
| let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; | ||
| // println!("otehr data type: {:?}, nullable: {:?}", data_type, nullable); | ||
| Ok(( | ||
| None, | ||
| Field::new(self.schema_name().to_string(), data_type, nullable) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -196,6 +196,10 @@ impl AggregateUDF { | |
| self.inner.state_fields(args) | ||
| } | ||
|
|
||
| pub fn fields(&self, args: StateFieldsArgs) -> Result<Field> { | ||
|
||
| self.inner.field(args) | ||
| } | ||
|
|
||
| /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details. | ||
| pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { | ||
| self.inner.groups_accumulator_supported(args) | ||
|
|
@@ -383,6 +387,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { | |
| .collect()) | ||
| } | ||
|
|
||
| fn field(&self, args: StateFieldsArgs) -> Result<Field> { | ||
| Ok(Field::new(args.name, args.return_type.clone(), true)) | ||
| } | ||
|
|
||
| /// If the aggregate expression has a specialized | ||
| /// [`GroupsAccumulator`] implementation. If this returns true, | ||
| /// `[Self::create_groups_accumulator]` will be called. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -171,6 +171,9 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq<dyn Any> { | |
| fn get_minmax_desc(&self) -> Option<(Field, bool)> { | ||
| None | ||
| } | ||
|
|
||
| /// Get function's name, for example `count(x)` returns `count` | ||
| fn func_name(&self) -> &str; | ||
|
||
| } | ||
|
|
||
| /// Stores the physical expressions used inside the `AggregateExpr`. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,6 @@ | |
|
|
||
| //! Optimizer rule for type validation and coercion | ||
|
|
||
| use std::collections::HashMap; | ||
| use std::sync::Arc; | ||
|
|
||
| use itertools::izip; | ||
|
|
@@ -821,9 +820,18 @@ fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> { | |
| .iter() | ||
| .map(|f| f.is_nullable()) | ||
| .collect::<Vec<_>>(); | ||
| let mut union_field_meta = base_schema | ||
| .fields() | ||
| .iter() | ||
| .map(|f| f.metadata().clone()) | ||
| .collect::<Vec<_>>(); | ||
|
|
||
| let mut metadata = base_schema.metadata().clone(); | ||
|
|
||
| for (i, plan) in inputs.iter().enumerate().skip(1) { | ||
| let plan_schema = plan.schema(); | ||
| metadata.extend(plan_schema.metadata().clone()); | ||
|
|
||
| if plan_schema.fields().len() != base_schema.fields().len() { | ||
| return plan_err!( | ||
| "Union schemas have different number of fields: \ | ||
|
|
@@ -833,39 +841,47 @@ fn coerce_union_schema(inputs: &[Arc<LogicalPlan>]) -> Result<DFSchema> { | |
| plan_schema.fields().len() | ||
| ); | ||
| } | ||
| // coerce data type and nullablity for each field | ||
| for (union_datatype, union_nullable, plan_field) in izip!( | ||
| union_datatypes.iter_mut(), | ||
| union_nullabilities.iter_mut(), | ||
| plan_schema.fields() | ||
| ) { | ||
| let coerced_type = | ||
| comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else( | ||
| || { | ||
| plan_datafusion_err!( | ||
| "Incompatible inputs for Union: Previous inputs were \ | ||
| of type {}, but got incompatible type {} on column '{}'", | ||
| union_datatype, | ||
| plan_field.data_type(), | ||
| plan_field.name() | ||
| ) | ||
| }, | ||
| )?; | ||
| *union_datatype = coerced_type; | ||
| *union_nullable = *union_nullable || plan_field.is_nullable(); | ||
|
|
||
| // Safety: Length is checked | ||
| unsafe { | ||
|
||
| // coerce data type and nullablity for each field | ||
| for (i, plan_field) in plan_schema.fields().iter().enumerate() { | ||
| let union_datatype = union_datatypes.get_unchecked_mut(i); | ||
| let union_nullable = union_nullabilities.get_unchecked_mut(i); | ||
| let union_field_map = union_field_meta.get_unchecked_mut(i); | ||
|
|
||
| let coerced_type = | ||
| comparison_coercion(union_datatype, plan_field.data_type()) | ||
| .ok_or_else(|| { | ||
| plan_datafusion_err!( | ||
| "Incompatible inputs for Union: Previous inputs were \ | ||
| of type {}, but got incompatible type {} on column '{}'", | ||
| union_datatype, | ||
| plan_field.data_type(), | ||
| plan_field.name() | ||
| ) | ||
| })?; | ||
|
|
||
| *union_datatype = coerced_type; | ||
| *union_nullable = *union_nullable || plan_field.is_nullable(); | ||
| union_field_map.extend(plan_field.metadata().clone()); | ||
| } | ||
| } | ||
| } | ||
| let union_qualified_fields = izip!( | ||
| base_schema.iter(), | ||
| union_datatypes.into_iter(), | ||
| union_nullabilities | ||
| union_nullabilities, | ||
| union_field_meta.into_iter() | ||
| ) | ||
| .map(|((qualifier, field), datatype, nullable)| { | ||
| let field = Arc::new(Field::new(field.name().clone(), datatype, nullable)); | ||
| (qualifier.cloned(), field) | ||
| .map(|((qualifier, field), datatype, nullable, metadata)| { | ||
| let mut field = Field::new(field.name().clone(), datatype, nullable); | ||
| field.set_metadata(metadata); | ||
| (qualifier.cloned(), field.into()) | ||
| }) | ||
| .collect::<Vec<_>>(); | ||
| DFSchema::new_with_metadata(union_qualified_fields, HashMap::new()) | ||
|
|
||
| DFSchema::new_with_metadata(union_qualified_fields, metadata) | ||
| } | ||
|
|
||
| /// See `<https://github.com/apache/datafusion/pull/2108>` | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -80,6 +80,14 @@ impl WindowExpr for PlainAggregateWindowExpr { | |
| } | ||
|
|
||
| fn field(&self) -> Result<Field> { | ||
| // TODO: Fix window function to always return non-null for count | ||
|
||
| if let Ok(name) = self.func_name() { | ||
| if name == "count" { | ||
| let field = self.aggregate.field()?; | ||
| return Ok(field.with_nullable(true)); | ||
| } | ||
| } | ||
|
|
||
| self.aggregate.field() | ||
| } | ||
|
|
||
|
|
@@ -157,6 +165,10 @@ impl WindowExpr for PlainAggregateWindowExpr { | |
| fn uses_bounded_memory(&self) -> bool { | ||
| !self.window_frame.end_bound.is_unbounded() | ||
| } | ||
|
|
||
| fn func_name(&self) -> Result<&str> { | ||
| Ok(self.aggregate.func_name()) | ||
| } | ||
| } | ||
|
|
||
| impl AggregateWindowExpr for PlainAggregateWindowExpr { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,7 +32,7 @@ use arrow::compute::SortOptions; | |
| use arrow::datatypes::Field; | ||
| use arrow::record_batch::RecordBatch; | ||
| use datafusion_common::utils::evaluate_partition_ranges; | ||
| use datafusion_common::{Result, ScalarValue}; | ||
| use datafusion_common::{not_impl_err, Result, ScalarValue}; | ||
| use datafusion_expr::window_state::{WindowAggState, WindowFrameContext}; | ||
| use datafusion_expr::WindowFrame; | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Can you move them back to the top? |
||
|
|
@@ -97,6 +97,10 @@ impl BuiltInWindowExpr { | |
| } | ||
|
|
||
| impl WindowExpr for BuiltInWindowExpr { | ||
| fn func_name(&self) -> Result<&str> { | ||
| not_impl_err!("function name not determined") | ||
|
||
| } | ||
|
|
||
| /// Return a reference to Any that can be used for downcasting | ||
| fn as_any(&self) -> &dyn Any { | ||
| self | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main goal of the change is to ensure they are the same. We also pass
`physical_input_schema` through the functions that require the input's schema.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
Did you consider making this function return an
`internal_error` rather than a `debug_assert`? If we are concerned about breaking existing tests, we could add a config setting like
`datafusion.optimizer.skip_failed_rules` to let users bypass the check.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The objective here is to ensure that the logical schema from
`ExprSchemable` and the physical schema from `ExecutionPlan.schema()` are equivalent. If they are not, it indicates a potential schema mismatch issue. This is also why the code changes in this PR are mostly fixing schema-related things; they are all required, thus I don't think we should let users bypass the check 🤔
If we encounter inconsistent schemas, it raises an important question: which schema should we use?
It looks good to me