Skip to content

Commit 930620a

Browse files
authored
Introduce expr_fields to AccumulatorArgs to hold input argument fields (#18100)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #16997 - Part of #11725 - Supersedes #17085 ## Rationale for this change <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> When reviewing #17085 I was very confused by the fix suggested, and tried to understand why `AccumulatorArgs` didn't have easy access to `Field`s of its input expressions, as compared to scalar/window functions which do. Introducing this new field should make it easier for users to grab datatype, metadata, nullability of their input expressions for aggregate functions. ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> Add a slice of `FieldRef` to `AccumulatorArgs` so users don't need to compute the input expression fields themselves via using schema. This addresses #16997 as it was confusing to have only the schema available as there are valid (?) cases where the schema is empty (such as literal only input). This fix differs from #17085 in that it doesn't special case for when there is literal only input; it leaves the physical `schema` provided to `AccumulatorArgs` untouched but provides a more ergonomic (and less confusing) API for users to retrieve `Field`s of their input arguments. - I'm still not sure if the schema being empty for literal only inputs is correct or not, so this might be considered a side step. If we could remove `schema` entirely from `AccumulatorArgs` maybe we wouldn't need to worry about this, but see my comment for why that wasn't done in this PR ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> Existing unit tests. ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> Yes, new field to `AccumulatorArgs` which is publicly exposed (with all it's fields). <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 931ffab commit 930620a

File tree

19 files changed

+126
-51
lines changed

19 files changed

+126
-51
lines changed

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -954,13 +954,7 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf {
954954
}
955955

956956
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
957-
let input_expr = acc_args
958-
.exprs
959-
.first()
960-
.ok_or(exec_datafusion_err!("Expected one argument"))?;
961-
let input_field = input_expr.return_field(acc_args.schema)?;
962-
963-
let double_output = input_field
957+
let double_output = acc_args.expr_fields[0]
964958
.metadata()
965959
.get("modify_values")
966960
.map(|v| v == "double_output")

datafusion/ffi/src/udaf/accumulator_args.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ impl TryFrom<AccumulatorArgs<'_>> for FFI_AccumulatorArgs {
9797
pub struct ForeignAccumulatorArgs {
9898
pub return_field: FieldRef,
9999
pub schema: Schema,
100+
pub expr_fields: Vec<FieldRef>,
100101
pub ignore_nulls: bool,
101102
pub order_bys: Vec<PhysicalSortExpr>,
102103
pub is_reversed: bool,
@@ -132,9 +133,15 @@ impl TryFrom<FFI_AccumulatorArgs> for ForeignAccumulatorArgs {
132133

133134
let exprs = parse_physical_exprs(&proto_def.expr, &task_ctx, &schema, &codex)?;
134135

136+
let expr_fields = exprs
137+
.iter()
138+
.map(|e| e.return_field(&schema))
139+
.collect::<Result<Vec<_>, _>>()?;
140+
135141
Ok(Self {
136142
return_field,
137143
schema,
144+
expr_fields,
138145
ignore_nulls: proto_def.ignore_nulls,
139146
order_bys,
140147
is_reversed: value.is_reversed,
@@ -150,6 +157,7 @@ impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> {
150157
Self {
151158
return_field: Arc::clone(&value.return_field),
152159
schema: &value.schema,
160+
expr_fields: &value.expr_fields,
153161
ignore_nulls: value.ignore_nulls,
154162
order_bys: &value.order_bys,
155163
is_reversed: value.is_reversed,
@@ -175,6 +183,7 @@ mod tests {
175183
let orig_args = AccumulatorArgs {
176184
return_field: Field::new("f", DataType::Float64, true).into(),
177185
schema: &schema,
186+
expr_fields: &[Field::new("a", DataType::Int32, true).into()],
178187
ignore_nulls: false,
179188
order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
180189
is_reversed: false,

datafusion/ffi/src/udaf/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,7 @@ mod tests {
705705
let acc_args = AccumulatorArgs {
706706
return_field: Field::new("f", DataType::Float64, true).into(),
707707
schema: &schema,
708+
expr_fields: &[Field::new("a", DataType::Float64, true).into()],
708709
ignore_nulls: true,
709710
order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
710711
is_reversed: false,
@@ -782,6 +783,7 @@ mod tests {
782783
let acc_args = AccumulatorArgs {
783784
return_field: Field::new("f", DataType::Float64, true).into(),
784785
schema: &schema,
786+
expr_fields: &[Field::new("a", DataType::Float64, true).into()],
785787
ignore_nulls: true,
786788
order_bys: &[PhysicalSortExpr::new_default(col("a", &schema)?)],
787789
is_reversed: false,

datafusion/functions-aggregate-common/src/accumulator.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ pub struct AccumulatorArgs<'a> {
3030
/// The return field of the aggregate function.
3131
pub return_field: FieldRef,
3232

33-
/// The schema of the input arguments
33+
/// Input schema to the aggregate function. If you need to check data type, nullability
34+
/// or metadata of input arguments then you should use `expr_fields` below instead.
3435
pub schema: &'a Schema,
3536

3637
/// Whether to ignore nulls.
@@ -67,6 +68,9 @@ pub struct AccumulatorArgs<'a> {
6768

6869
/// The physical expression of arguments the aggregate function takes.
6970
pub exprs: &'a [Arc<dyn PhysicalExpr>],
71+
72+
/// Fields corresponding to each expr (same order & length).
73+
pub expr_fields: &'a [FieldRef],
7074
}
7175

7276
impl AccumulatorArgs<'_> {

datafusion/functions-aggregate/benches/count.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,17 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion};
3333

3434
fn prepare_group_accumulator() -> Box<dyn GroupsAccumulator> {
3535
let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)]));
36+
let expr = col("f", &schema).unwrap();
3637
let accumulator_args = AccumulatorArgs {
3738
return_field: Field::new("f", DataType::Int64, true).into(),
3839
schema: &schema,
40+
expr_fields: &[expr.return_field(&schema).unwrap()],
3941
ignore_nulls: false,
4042
order_bys: &[],
4143
is_reversed: false,
4244
name: "COUNT(f)",
4345
is_distinct: false,
44-
exprs: &[col("f", &schema).unwrap()],
46+
exprs: &[expr],
4547
};
4648
let count_fn = Count::new();
4749

@@ -56,15 +58,17 @@ fn prepare_accumulator() -> Box<dyn Accumulator> {
5658
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
5759
true,
5860
)]));
61+
let expr = col("f", &schema).unwrap();
5962
let accumulator_args = AccumulatorArgs {
6063
return_field: Arc::new(Field::new_list_field(DataType::Int64, true)),
6164
schema: &schema,
65+
expr_fields: &[expr.return_field(&schema).unwrap()],
6266
ignore_nulls: false,
6367
order_bys: &[],
6468
is_reversed: false,
6569
name: "COUNT(f)",
6670
is_distinct: true,
67-
exprs: &[col("f", &schema).unwrap()],
71+
exprs: &[expr],
6872
};
6973
let count_fn = Count::new();
7074

datafusion/functions-aggregate/benches/min_max_bytes.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ fn create_max_bytes_accumulator() -> Box<dyn GroupsAccumulator> {
4444
max.create_groups_accumulator(AccumulatorArgs {
4545
return_field: Arc::new(Field::new("value", DataType::Utf8, true)),
4646
schema: &input_schema,
47+
expr_fields: &[Field::new("value", DataType::Utf8, true).into()],
4748
ignore_nulls: true,
4849
order_bys: &[],
4950
is_reversed: false,

datafusion/functions-aggregate/benches/sum.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ fn prepare_accumulator(data_type: &DataType) -> Box<dyn GroupsAccumulator> {
3131
let field = Field::new("f", data_type.clone(), true).into();
3232
let schema = Arc::new(Schema::new(vec![Arc::clone(&field)]));
3333
let accumulator_args = AccumulatorArgs {
34-
return_field: field,
34+
return_field: Arc::clone(&field),
3535
schema: &schema,
36+
expr_fields: &[field],
3637
ignore_nulls: false,
3738
order_bys: &[],
3839
is_reversed: false,

datafusion/functions-aggregate/src/approx_distinct.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ impl AggregateUDFImpl for ApproxDistinct {
361361
}
362362

363363
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
364-
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
364+
let data_type = acc_args.expr_fields[0].data_type();
365365

366366
let accumulator: Box<dyn Accumulator> = match data_type {
367367
// TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL

datafusion/functions-aggregate/src/approx_median.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ impl AggregateUDFImpl for ApproxMedian {
134134

135135
Ok(Box::new(ApproxPercentileAccumulator::new(
136136
0.5_f64,
137-
acc_args.exprs[0].data_type(acc_args.schema)?,
137+
acc_args.expr_fields[0].data_type().clone(),
138138
)))
139139
}
140140

datafusion/functions-aggregate/src/approx_percentile_cont.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,9 @@ impl ApproxPercentileCont {
187187
None
188188
};
189189

190-
let data_type = args.exprs[0].data_type(args.schema)?;
190+
let data_type = args.expr_fields[0].data_type();
191191
let accumulator: ApproxPercentileAccumulator = match data_type {
192-
t @ (DataType::UInt8
192+
DataType::UInt8
193193
| DataType::UInt16
194194
| DataType::UInt32
195195
| DataType::UInt64
@@ -198,12 +198,11 @@ impl ApproxPercentileCont {
198198
| DataType::Int32
199199
| DataType::Int64
200200
| DataType::Float32
201-
| DataType::Float64) => {
201+
| DataType::Float64 => {
202202
if let Some(max_size) = tdigest_max_size {
203-
ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size)
204-
}else{
205-
ApproxPercentileAccumulator::new(percentile, t)
206-
203+
ApproxPercentileAccumulator::new_with_max_size(percentile, data_type.clone(), max_size)
204+
} else {
205+
ApproxPercentileAccumulator::new(percentile, data_type.clone())
207206
}
208207
}
209208
other => {

0 commit comments

Comments
 (0)