Merged
Changes from 3 commits
70 changes: 63 additions & 7 deletions datafusion/functions-aggregate/src/array_agg.rs
@@ -19,7 +19,7 @@

use std::cmp::Ordering;
use std::collections::{HashSet, VecDeque};
use std::mem::{size_of, size_of_val};
use std::mem::{size_of, size_of_val, take};
use std::sync::Arc;

use arrow::array::{
@@ -31,14 +31,17 @@ use arrow::datatypes::{DataType, Field, FieldRef, Fields};

use datafusion_common::cast::as_list_array;
use datafusion_common::scalar::copy_array_data;
use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
use datafusion_common::utils::{
compare_rows, get_row_at_idx, take_function_args, SingleRowListArrayBuilder,
};
use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility,
};
use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays;
use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
use datafusion_functions_aggregate_common::utils::ordering_fields;
use datafusion_macros::user_doc;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
@@ -78,12 +81,14 @@ This aggregation function can only mix DISTINCT and ORDER BY if the ordering expr
/// ARRAY_AGG aggregate expression
pub struct ArrayAgg {
signature: Signature,
is_input_pre_ordered: bool,
Member Author: Adding this new field should trigger adding equals/hash_value implementations.
Being fixed in #17065
}
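
The review note above flags that equality and hashing do not yet account for the new is_input_pre_ordered field; the real fix is deferred to #17065. Purely as a hypothetical sketch (not the code from that PR), overrides added to the existing impl AggregateUDFImpl for ArrayAgg block could take roughly this shape, assuming the trait's equals/hash_value hooks and that Signature implements PartialEq and Hash:

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
    // Hypothetical sketch, not the change from #17065: compare the concrete
    // type, the signature, and the newly added flag.
    other.as_any().downcast_ref::<Self>().is_some_and(|o| {
        self.signature == o.signature
            && self.is_input_pre_ordered == o.is_input_pre_ordered
    })
}

fn hash_value(&self) -> u64 {
    // Hash the same fields that participate in `equals`.
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    let mut hasher = DefaultHasher::new();
    self.signature.hash(&mut hasher);
    self.is_input_pre_ordered.hash(&mut hasher);
    hasher.finish()
}
// (end of hypothetical sketch; the actual change lands in #17065)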

impl Default for ArrayAgg {
fn default() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
is_input_pre_ordered: false,
}
}
}
@@ -144,6 +149,20 @@ impl AggregateUDFImpl for ArrayAgg {
Ok(fields)
}

fn order_sensitivity(&self) -> AggregateOrderSensitivity {
AggregateOrderSensitivity::Beneficial
}

fn with_beneficial_ordering(
self: Arc<Self>,
beneficial_ordering: bool,
) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
Ok(Some(Arc::new(Self {
signature: self.signature.clone(),
is_input_pre_ordered: beneficial_ordering,
})))
}
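
Taken together, the two methods above define the new contract: order_sensitivity reports Beneficial, so ARRAY_AGG can exploit a pre-sorted input without demanding one, and with_beneficial_ordering lets the planner hand back a copy with is_input_pre_ordered set once the input ordering has been proven. A minimal sketch of that hand-off; the helper name mark_pre_ordered is hypothetical, and only the two trait-method calls come from the diff:

fn mark_pre_ordered(
    udf: Arc<ArrayAgg>,
    input_satisfies_order: bool,
) -> Result<Arc<dyn AggregateUDFImpl>> {
    // `Beneficial`: use the ordering when it is there, fall back to sorting
    // inside the accumulator when it is not.
    debug_assert!(matches!(
        udf.order_sensitivity(),
        AggregateOrderSensitivity::Beneficial
    ));
    match udf.with_beneficial_ordering(input_satisfies_order)? {
        Some(rewritten) => Ok(rewritten),
        // Not expected for ARRAY_AGG, which always returns a rewritten copy.
        None => internal_err!("ARRAY_AGG should accept a beneficial-ordering hint"),
    }
}
// (end of sketch)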

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
let ignore_nulls =
@@ -196,6 +215,7 @@ impl AggregateUDFImpl for ArrayAgg {
&data_type,
&ordering_dtypes,
ordering,
self.is_input_pre_ordered,
acc_args.is_reversed,
ignore_nulls,
)
@@ -512,6 +532,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
datatypes: Vec<DataType>,
/// Stores the ordering requirement of the `Accumulator`.
ordering_req: LexOrdering,
/// Whether the input is known to be pre-ordered
is_input_pre_ordered: bool,
/// Whether the aggregation is running in reverse.
reverse: bool,
/// Whether the aggregation should ignore null values.
@@ -525,6 +547,7 @@ impl OrderSensitiveArrayAggAccumulator {
datatype: &DataType,
ordering_dtypes: &[DataType],
ordering_req: LexOrdering,
is_input_pre_ordered: bool,
reverse: bool,
ignore_nulls: bool,
) -> Result<Self> {
@@ -535,11 +558,34 @@ impl OrderSensitiveArrayAggAccumulator {
ordering_values: vec![],
datatypes,
ordering_req,
is_input_pre_ordered,
reverse,
ignore_nulls,
})
}

fn sort(&mut self) {
let sort_options = self
.ordering_req
.iter()
.map(|sort_expr| sort_expr.options)
.collect::<Vec<_>>();
let mut values = take(&mut self.values)
.into_iter()
.zip(take(&mut self.ordering_values))
.collect::<Vec<_>>();
let mut delayed_cmp_err = Ok(());
values.sort_by(|(_, left_ordering), (_, right_ordering)| {
compare_rows(left_ordering, right_ordering, &sort_options).unwrap_or_else(
|err| {
delayed_cmp_err = Err(err);
Ordering::Equal
},
)
});
(self.values, self.ordering_values) = values.into_iter().unzip();
}
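
The new sort helper keeps values and ordering_values aligned by sorting them as zipped pairs with compare_rows. Since the comparator handed to sort_by cannot return a Result, a comparison error is stashed in delayed_cmp_err and the failing pair is treated as Equal. A generic, self-contained sketch of that stash-the-error pattern follows; the helper name and signature are illustrative and not part of the diff (Ordering and Result are the imports already at the top of this file):

fn sort_by_fallible_cmp<T, K>(
    items: &mut [(T, K)],
    mut cmp: impl FnMut(&K, &K) -> Result<Ordering>,
) -> Result<()> {
    // `sort_by` cannot propagate errors, so stash the first one, pretend the
    // pair compared equal, and surface the stashed error after sorting.
    let mut delayed_cmp_err = Ok(());
    items.sort_by(|(_, a), (_, b)| {
        cmp(a, b).unwrap_or_else(|err| {
            if delayed_cmp_err.is_ok() {
                delayed_cmp_err = Err(err);
            }
            Ordering::Equal
        })
    });
    delayed_cmp_err
}
// (end of sketch)

Surfacing the stashed error only after the sort finishes keeps the comparator itself infallible, which is what sort_by requires.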

fn evaluate_orderings(&self) -> Result<ScalarValue> {
let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);

@@ -610,9 +656,8 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
// inside `ARRAY_AGG` list, we will receive an `Array` that stores values
// received from its ordering requirement expression. (This information
// is necessary during merging.)
let [array_agg_values, agg_orderings, ..] = &states else {
return exec_err!("State should have two elements");
};
let [array_agg_values, agg_orderings] =
take_function_args("OrderSensitiveArrayAggAccumulator::merge_batch", states)?;
let Some(agg_orderings) = agg_orderings.as_list_opt::<i32>() else {
return exec_err!("Expects to receive a list array");
};
@@ -623,8 +668,11 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
let mut partition_ordering_values = vec![];

// Existing values should be merged also.
partition_values.push(self.values.clone().into());
partition_ordering_values.push(self.ordering_values.clone().into());
if !self.is_input_pre_ordered {
self.sort();
}
partition_values.push(take(&mut self.values).into());
partition_ordering_values.push(take(&mut self.ordering_values).into());

// Convert array to Scalars to sort them easily. Convert back to array at evaluation.
let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
@@ -673,13 +721,21 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
}

fn state(&mut self) -> Result<Vec<ScalarValue>> {
if !self.is_input_pre_ordered {
self.sort();
}

let mut result = vec![self.evaluate()?];
result.push(self.evaluate_orderings()?);

Ok(result)
}

fn evaluate(&mut self) -> Result<ScalarValue> {
if !self.is_input_pre_ordered {
self.sort();
}

if self.values.is_empty() {
return Ok(ScalarValue::new_null_list(
self.datatypes[0].clone(),
149 changes: 89 additions & 60 deletions datafusion/physical-plan/src/aggregates/mod.rs
@@ -52,6 +52,7 @@ use datafusion_physical_expr_common::sort_expr::{
LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortRequirement,
};

use datafusion_expr::utils::AggregateOrderSensitivity;
use itertools::Itertools;

pub(crate) mod group_values;
Expand Down Expand Up @@ -1071,13 +1072,25 @@ fn get_aggregate_expr_req(
aggr_expr: &AggregateFunctionExpr,
group_by: &PhysicalGroupBy,
agg_mode: &AggregateMode,
include_soft_requirement: bool,
) -> Option<LexOrdering> {
// If the aggregation function is ordering requirement is not absolutely
// necessary, or the aggregation is performing a "second stage" calculation,
// then ignore the ordering requirement.
if !aggr_expr.order_sensitivity().hard_requires() || !agg_mode.is_first_stage() {
// If the aggregation is performing a "second stage" calculation,
// then ignore the ordering requirement. Ordering requirement applies
// only to the aggregation input data.
if !agg_mode.is_first_stage() {
return None;
}

match aggr_expr.order_sensitivity() {
AggregateOrderSensitivity::Insensitive => return None,
AggregateOrderSensitivity::HardRequirement => {}
AggregateOrderSensitivity::Beneficial => {
if !include_soft_requirement {
return None;
}
}
}

let mut sort_exprs = aggr_expr.order_bys().to_vec();
// In non-first stage modes, we accumulate data (using `merge_batch`) from
// different partitions (i.e. merge partial results). During this merge, we
@@ -1142,60 +1155,73 @@ pub fn get_finer_aggregate_exprs_requirement(
agg_mode: &AggregateMode,
) -> Result<Vec<PhysicalSortRequirement>> {
let mut requirement = None;
for aggr_expr in aggr_exprs.iter_mut() {
let Some(aggr_req) = get_aggregate_expr_req(aggr_expr, group_by, agg_mode)
.and_then(|o| eq_properties.normalize_sort_exprs(o))
else {
// There is no aggregate ordering requirement, or it is trivially
// satisfied -- we can skip this expression.
continue;
};
// If the common requirement is finer than the current expression's,
// we can skip this expression. If the latter is finer than the former,
// adopt it if it is satisfied by the equivalence properties. Otherwise,
// defer the analysis to the reverse expression.
let forward_finer = determine_finer(&requirement, &aggr_req);
if let Some(finer) = forward_finer {
if !finer {
continue;
} else if eq_properties.ordering_satisfy(aggr_req.clone())? {
requirement = Some(aggr_req);
continue;
}
}
if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
let Some(rev_aggr_req) =
get_aggregate_expr_req(&reverse_aggr_expr, group_by, agg_mode)
.and_then(|o| eq_properties.normalize_sort_exprs(o))
else {
// The reverse requirement is trivially satisfied -- just reverse
// the expression and continue with the next one:
*aggr_expr = Arc::new(reverse_aggr_expr);

for include_soft_requirement in [false, true] {
for aggr_expr in aggr_exprs.iter_mut() {
let Some(aggr_req) = get_aggregate_expr_req(
aggr_expr,
group_by,
agg_mode,
include_soft_requirement,
)
.and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
// There is no aggregate ordering requirement, or it is trivially
// satisfied -- we can skip this expression.
continue;
};
// If the common requirement is finer than the reverse expression's,
// just reverse it and continue the loop with the next aggregate
// expression. If the latter is finer than the former, adopt it if
// it is satisfied by the equivalence properties. Otherwise, adopt
// the forward expression.
if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) {
// If the common requirement is finer than the current expression's,
// we can skip this expression. If the latter is finer than the former,
// adopt it if it is satisfied by the equivalence properties. Otherwise,
// defer the analysis to the reverse expression.
let forward_finer = determine_finer(&requirement, &aggr_req);
if let Some(finer) = forward_finer {
if !finer {
continue;
} else if eq_properties.ordering_satisfy(aggr_req.clone())? {
requirement = Some(aggr_req);
continue;
}
}
if let Some(reverse_aggr_expr) = aggr_expr.reverse_expr() {
let Some(rev_aggr_req) = get_aggregate_expr_req(
&reverse_aggr_expr,
group_by,
agg_mode,
include_soft_requirement,
)
.and_then(|o| eq_properties.normalize_sort_exprs(o)) else {
// The reverse requirement is trivially satisfied -- just reverse
// the expression and continue with the next one:
*aggr_expr = Arc::new(reverse_aggr_expr);
} else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? {
*aggr_expr = Arc::new(reverse_aggr_expr);
requirement = Some(rev_aggr_req);
} else {
continue;
};
// If the common requirement is finer than the reverse expression's,
// just reverse it and continue the loop with the next aggregate
// expression. If the latter is finer than the former, adopt it if
// it is satisfied by the equivalence properties. Otherwise, adopt
// the forward expression.
if let Some(finer) = determine_finer(&requirement, &rev_aggr_req) {
if !finer {
*aggr_expr = Arc::new(reverse_aggr_expr);
} else if eq_properties.ordering_satisfy(rev_aggr_req.clone())? {
*aggr_expr = Arc::new(reverse_aggr_expr);
requirement = Some(rev_aggr_req);
} else {
requirement = Some(aggr_req);
}
} else if forward_finer.is_some() {
requirement = Some(aggr_req);
} else {
// Neither the existing requirement nor the current aggregate
// requirement satisfy the other (forward or reverse), this
// means they are conflicting. This is a problem only for hard
// requirements. Unsatisfied soft requirements can be ignored.
if !include_soft_requirement {
return not_impl_err!(
"Conflicting ordering requirements in aggregate functions is not supported"
);
}
}
} else if forward_finer.is_some() {
requirement = Some(aggr_req);
} else {
// Neither the existing requirement nor the current aggregate
// requirement satisfy the other (forward or reverse), this
// means they are conflicting.
return not_impl_err!(
"Conflicting ordering requirements in aggregate functions is not supported"
);
}
}
}
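
The reworked loop above makes two passes: the first (include_soft_requirement = false) merges only hard ordering requirements and still reports a genuine conflict as an error, while the second folds in Beneficial (soft) requirements and silently drops any that cannot be reconciled with what the hard pass established. A toy, self-contained illustration of that policy, using plain strings instead of LexOrdering and prefix containment instead of determine_finer; none of these names exist in DataFusion:

fn merge_requirements(
    hard: &[&str],
    soft: &[&str],
) -> Result<Option<String>, String> {
    let mut common: Option<String> = None;
    for (reqs, is_soft) in [(hard, false), (soft, true)] {
        for &req in reqs {
            match common.as_deref() {
                // First requirement seen becomes the shared candidate.
                None => common = Some(req.to_string()),
                // The existing candidate is already at least as fine: keep it.
                Some(c) if c.starts_with(req) => {}
                // The new requirement is finer: adopt it.
                Some(c) if req.starts_with(c) => common = Some(req.to_string()),
                // Incompatible soft requirement: silently ignore it.
                Some(_) if is_soft => {}
                // Incompatible hard requirement: a real conflict, report it.
                Some(c) => {
                    return Err(format!("conflicting requirements: {c} vs {req}"))
                }
            }
        }
    }
    Ok(common)
}
// merge_requirements(&["a,b"], &["a,b,c"]) == Ok(Some("a,b,c".into()))
// merge_requirements(&["a"], &["b"])       == Ok(Some("a".into()))  (soft "b" dropped)
// (end of toy illustration)

In the real function each aggregate also gets a chance to be reversed before a conflict is declared, but the hard-before-soft policy is the same.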
@@ -1442,7 +1468,7 @@ mod tests {
use datafusion_execution::config::SessionConfig;
use datafusion_execution::memory_pool::FairSpillPool;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_functions_aggregate::array_agg::array_agg_udaf;
use datafusion_expr::test::function_stub::max_udaf;
use datafusion_functions_aggregate::average::avg_udaf;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
@@ -2428,13 +2454,16 @@ mod tests {
let mut aggr_exprs = order_by_exprs
.into_iter()
.map(|order_by_expr| {
AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)])
.alias("a")
.order_by(order_by_expr)
.schema(Arc::clone(&test_schema))
.build()
.map(Arc::new)
.unwrap()
AggregateExprBuilder::new(
max_udaf(), // any UDAF not using Beneficial order sensitivity
vec![Arc::clone(col_a)],
)
.alias("a")
.order_by(order_by_expr)
.schema(Arc::clone(&test_schema))
.build()
.map(Arc::new)
.unwrap()
})
.collect::<Vec<_>>();
let group_by = PhysicalGroupBy::new_single(vec![]);