Skip to content

Commit 3519e75

Browse files
committed
fix failed test from apache#12050
Signed-off-by: jayzhan211 <[email protected]>
1 parent 83ce363 commit 3519e75

5 files changed

Lines changed: 33 additions & 14 deletions

File tree

datafusion/expr/src/expr_schema.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,14 @@ impl ExprSchemable for Expr {
338338
Expr::ScalarFunction(ScalarFunction { func, args }) => {
339339
Ok(func.is_nullable(args, input_schema))
340340
}
341-
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
342-
Ok(func.is_nullable())
341+
Expr::AggregateFunction(AggregateFunction { func, args, .. }) => {
342+
let nullables = args
343+
.iter()
344+
.map(|e| e.nullable(input_schema))
345+
.collect::<Result<Vec<_>>>()?;
346+
Ok(func.is_nullable(&nullables))
343347
}
344-
Expr::WindowFunction(WindowFunction { fun, .. }) => match fun {
348+
Expr::WindowFunction(WindowFunction { fun, args, .. }) => match fun {
345349
WindowFunctionDefinition::BuiltInWindowFunction(func) => {
346350
if func.name() == "RANK"
347351
|| func.name() == "NTILE"
@@ -352,7 +356,13 @@ impl ExprSchemable for Expr {
352356
Ok(true)
353357
}
354358
}
355-
WindowFunctionDefinition::AggregateUDF(func) => Ok(func.is_nullable()),
359+
WindowFunctionDefinition::AggregateUDF(func) => {
360+
let nullables = args
361+
.iter()
362+
.map(|e| e.nullable(input_schema))
363+
.collect::<Result<Vec<_>>>()?;
364+
Ok(func.is_nullable(&nullables))
365+
}
356366
WindowFunctionDefinition::WindowUDF(udwf) => Ok(udwf.nullable()),
357367
},
358368
Expr::ScalarVariable(_, _)

datafusion/expr/src/udaf.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ impl AggregateUDF {
163163
self.inner.name()
164164
}
165165

166-
pub fn is_nullable(&self) -> bool {
167-
self.inner.is_nullable()
166+
pub fn is_nullable(&self, nullables: &[bool]) -> bool {
167+
self.inner.is_nullable(nullables)
168168
}
169169

170170
/// Returns the aliases for this function.
@@ -355,8 +355,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
355355
///
356356
/// Nullable means that that the function could return `null` for any inputs.
357357
/// For example, aggregate functions like `COUNT` always return a non null value
358-
/// but others like `MIN` will return `NULL` if there is no non null input.
359-
fn is_nullable(&self) -> bool {
358+
/// but others like `MIN` will return `NULL` if there is nullable input.
359+
fn is_nullable(&self, _nullables: &[bool]) -> bool {
360360
true
361361
}
362362

datafusion/functions-aggregate/src/count.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ impl AggregateUDFImpl for Count {
121121
Ok(DataType::Int64)
122122
}
123123

124-
fn is_nullable(&self) -> bool {
124+
// Count is always nullable regardless of the input nullability
125+
fn is_nullable(&self, _nullables: &[bool]) -> bool {
125126
false
126127
}
127128

datafusion/functions/src/core/arrow_cast.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ use datafusion_common::{
2626
};
2727

2828
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
29-
use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility};
29+
use datafusion_expr::{
30+
ColumnarValue, Expr, ExprSchemable, ScalarUDFImpl, Signature, Volatility,
31+
};
3032

3133
/// Implements casting to arbitrary arrow types (rather than SQL types)
3234
///
@@ -87,6 +89,10 @@ impl ScalarUDFImpl for ArrowCastFunc {
8789
internal_err!("arrow_cast should return type from exprs")
8890
}
8991

92+
fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
93+
args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true))
94+
}
95+
9096
fn return_type_from_exprs(
9197
&self,
9298
args: &[Expr],

datafusion/physical-expr-functions-aggregate/src/aggregate.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,10 @@ pub struct AggregateExprBuilder {
5656
is_distinct: bool,
5757
/// Whether the expression is reversed
5858
is_reversed: bool,
59-
is_nullable: bool,
6059
}
6160

6261
impl AggregateExprBuilder {
6362
pub fn new(fun: Arc<AggregateUDF>, args: Vec<Arc<dyn PhysicalExpr>>) -> Self {
64-
let is_nullable = fun.is_nullable();
6563
Self {
6664
fun,
6765
args,
@@ -71,7 +69,6 @@ impl AggregateExprBuilder {
7169
ignore_nulls: false,
7270
is_distinct: false,
7371
is_reversed: false,
74-
is_nullable,
7572
}
7673
}
7774

@@ -85,7 +82,6 @@ impl AggregateExprBuilder {
8582
ignore_nulls,
8683
is_distinct,
8784
is_reversed,
88-
is_nullable,
8985
} = self;
9086
if args.is_empty() {
9187
return internal_err!("args should not be empty");
@@ -107,13 +103,19 @@ impl AggregateExprBuilder {
107103
.map(|arg| arg.data_type(&schema))
108104
.collect::<Result<Vec<_>>>()?;
109105

106+
let input_nullables = args
107+
.iter()
108+
.map(|arg| arg.nullable(&schema))
109+
.collect::<Result<Vec<_>>>()?;
110+
110111
check_arg_count(
111112
fun.name(),
112113
&input_exprs_types,
113114
&fun.signature().type_signature,
114115
)?;
115116

116117
let data_type = fun.return_type(&input_exprs_types)?;
118+
let is_nullable = fun.is_nullable(&input_nullables);
117119
let name = match alias {
118120
// TODO: Ideally, we should build the name from physical expressions
119121
None => create_function_physical_name(fun.name(), is_distinct, &[], None)?,

0 commit comments

Comments
 (0)