Skip to content

Commit d009ee8

Browse files
Respect IGNORE NULLS flag in ARRAY_AGG (#260/15544)
1 parent 7952308 commit d009ee8

File tree

1 file changed

+69
-16
lines changed

1 file changed

+69
-16
lines changed

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 69 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,10 @@
1717

1818
//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
1919
20-
use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, StructArray};
20+
use arrow::array::{
21+
new_empty_array, Array, ArrayRef, AsArray, BooleanArray, StructArray,
22+
};
23+
use arrow::compute::filter;
2124
use arrow::datatypes::DataType;
2225

2326
use arrow_schema::{Field, Fields};
@@ -116,13 +119,21 @@ impl AggregateUDFImpl for ArrayAgg {
116119

117120
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
118121
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
122+
let ignore_nulls =
123+
acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?;
119124

120125
if acc_args.is_distinct {
121-
return Ok(Box::new(DistinctArrayAggAccumulator::try_new(&data_type)?));
126+
return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
127+
&data_type,
128+
acc_args.ignore_nulls,
129+
)?));
122130
}
123131

124132
if acc_args.ordering_req.is_empty() {
125-
return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?));
133+
return Ok(Box::new(ArrayAggAccumulator::try_new(
134+
&data_type,
135+
acc_args.ignore_nulls,
136+
)?));
126137
}
127138

128139
let ordering_dtypes = acc_args
@@ -136,6 +147,7 @@ impl AggregateUDFImpl for ArrayAgg {
136147
&ordering_dtypes,
137148
acc_args.ordering_req.to_vec(),
138149
acc_args.is_reversed,
150+
ignore_nulls,
139151
)
140152
.map(|acc| Box::new(acc) as _)
141153
}
@@ -178,14 +190,16 @@ fn get_array_agg_doc() -> &'static Documentation {
178190
pub struct ArrayAggAccumulator {
179191
values: Vec<ArrayRef>,
180192
datatype: DataType,
193+
ignore_nulls: bool,
181194
}
182195

183196
impl ArrayAggAccumulator {
184197
/// Creates a new `ARRAY_AGG` accumulator for the given item data type.
185-
pub fn try_new(datatype: &DataType) -> Result<Self> {
198+
pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
186199
Ok(Self {
187200
values: vec![],
188201
datatype: datatype.clone(),
202+
ignore_nulls,
189203
})
190204
}
191205
}
@@ -201,10 +215,23 @@ impl Accumulator for ArrayAggAccumulator {
201215
return internal_err!("expects single batch");
202216
}
203217

204-
let val = Arc::clone(&values[0]);
205-
if val.len() > 0 {
218+
let val = &values[0];
219+
let nulls = if self.ignore_nulls {
220+
val.logical_nulls()
221+
} else {
222+
None
223+
};
224+
225+
let val = match nulls {
226+
Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
227+
Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
228+
None => Arc::clone(val),
229+
};
230+
231+
if !val.is_empty() {
206232
self.values.push(val);
207233
}
234+
208235
Ok(())
209236
}
210237

@@ -261,13 +288,15 @@ impl Accumulator for ArrayAggAccumulator {
261288
struct DistinctArrayAggAccumulator {
262289
values: HashSet<ScalarValue>,
263290
datatype: DataType,
291+
ignore_nulls: bool,
264292
}
265293

266294
impl DistinctArrayAggAccumulator {
267-
pub fn try_new(datatype: &DataType) -> Result<Self> {
295+
pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
268296
Ok(Self {
269297
values: HashSet::new(),
270298
datatype: datatype.clone(),
299+
ignore_nulls,
271300
})
272301
}
273302
}
@@ -282,11 +311,20 @@ impl Accumulator for DistinctArrayAggAccumulator {
282311
return internal_err!("expects single batch");
283312
}
284313

285-
let array = &values[0];
314+
let val = &values[0];
315+
let nulls = if self.ignore_nulls {
316+
val.logical_nulls()
317+
} else {
318+
None
319+
};
286320

287-
for i in 0..array.len() {
288-
let scalar = ScalarValue::try_from_array(&array, i)?;
289-
self.values.insert(scalar);
321+
let nulls = nulls.as_ref();
322+
if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
323+
for i in 0..val.len() {
324+
if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
325+
self.values.insert(ScalarValue::try_from_array(val, i)?);
326+
}
327+
}
290328
}
291329

292330
Ok(())
@@ -344,6 +382,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
344382
ordering_req: LexOrdering,
345383
/// Whether the aggregation is running in reverse.
346384
reverse: bool,
385+
/// Whether the aggregation should ignore null values.
386+
ignore_nulls: bool,
347387
}
348388

349389
impl OrderSensitiveArrayAggAccumulator {
@@ -354,6 +394,7 @@ impl OrderSensitiveArrayAggAccumulator {
354394
ordering_dtypes: &[DataType],
355395
ordering_req: LexOrdering,
356396
reverse: bool,
397+
ignore_nulls: bool,
357398
) -> Result<Self> {
358399
let mut datatypes = vec![datatype.clone()];
359400
datatypes.extend(ordering_dtypes.iter().cloned());
@@ -363,6 +404,7 @@ impl OrderSensitiveArrayAggAccumulator {
363404
datatypes,
364405
ordering_req,
365406
reverse,
407+
ignore_nulls,
366408
})
367409
}
368410
}
@@ -373,11 +415,22 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
373415
return Ok(());
374416
}
375417

376-
let n_row = values[0].len();
377-
for index in 0..n_row {
378-
let row = get_row_at_idx(values, index)?;
379-
self.values.push(row[0].clone());
380-
self.ordering_values.push(row[1..].to_vec());
418+
let val = &values[0];
419+
let ord = &values[1..];
420+
let nulls = if self.ignore_nulls {
421+
val.logical_nulls()
422+
} else {
423+
None
424+
};
425+
426+
let nulls = nulls.as_ref();
427+
if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
428+
for i in 0..val.len() {
429+
if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
430+
self.values.push(ScalarValue::try_from_array(val, i)?);
431+
self.ordering_values.push(get_row_at_idx(ord, i)?)
432+
}
433+
}
381434
}
382435

383436
Ok(())

0 commit comments

Comments
 (0)