Skip to content

Commit fd0ab64

Browse files
authored
feat: Support ANSI mode SUM (Decimal types) (#2826)
1 parent 0bda9d2 commit fd0ab64

File tree

6 files changed

+415
-169
lines changed

6 files changed

+415
-169
lines changed

native/core/src/execution/planner.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2022,7 +2022,9 @@ impl PhysicalPlanner {
20222022

20232023
let builder = match datatype {
20242024
DataType::Decimal128(_, _) => {
2025-
let func = AggregateUDF::new_from_impl(SumDecimal::try_new(datatype)?);
2025+
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
2026+
let func =
2027+
AggregateUDF::new_from_impl(SumDecimal::try_new(datatype, eval_mode)?);
20262028
AggregateExprBuilder::new(Arc::new(func), vec![child])
20272029
}
20282030
_ => {

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ message Count {
120120
message Sum {
121121
Expr child = 1;
122122
DataType datatype = 2;
123-
bool fail_on_error = 3;
123+
EvalMode eval_mode = 3;
124124
}
125125

126126
message Min {

native/spark-expr/benches/aggregate.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ use datafusion::physical_expr::expressions::Column;
3131
use datafusion::physical_expr::PhysicalExpr;
3232
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
3333
use datafusion::physical_plan::ExecutionPlan;
34-
use datafusion_comet_spark_expr::AvgDecimal;
3534
use datafusion_comet_spark_expr::SumDecimal;
35+
use datafusion_comet_spark_expr::{AvgDecimal, EvalMode};
3636
use futures::StreamExt;
3737
use std::hint::black_box;
3838
use std::sync::Arc;
@@ -97,7 +97,7 @@ fn criterion_benchmark(c: &mut Criterion) {
9797

9898
group.bench_function("sum_decimal_comet", |b| {
9999
let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl(
100-
SumDecimal::try_new(DataType::Decimal128(38, 10)).unwrap(),
100+
SumDecimal::try_new(DataType::Decimal128(38, 10), EvalMode::Legacy).unwrap(),
101101
));
102102
b.to_async(&rt).iter(|| {
103103
black_box(agg_test(

0 commit comments

Comments
 (0)