Commit 0e7c26b

perf: Use DataFusion's count_udaf instead of SUM(IF(expr IS NOT NULL, 1, 0)) (#2407)
1 parent 3bb1b40 commit 0e7c26b

3 files changed: 7 additions (+) and 53 deletions (−)


.gitignore

Lines changed: 1 addition & 0 deletions

@@ -17,3 +17,4 @@ dev/dist
 apache-rat-*.jar
 venv
 dev/release/comet-rm/workdir
+spark/benchmarks

native/core/src/execution/planner.rs

Lines changed: 2 additions & 23 deletions

@@ -30,6 +30,7 @@ use crate::{
 use arrow::compute::CastOptions;
 use arrow::datatypes::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION};
 use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf};
+use datafusion::functions_aggregate::count::count_udaf;
 use datafusion::functions_aggregate::min_max::max_udaf;
 use datafusion::functions_aggregate::min_max::min_udaf;
 use datafusion::functions_aggregate::sum::sum_udaf;
@@ -1904,35 +1905,13 @@ impl PhysicalPlanner {
         match spark_expr.expr_struct.as_ref().unwrap() {
             AggExprStruct::Count(expr) => {
                 assert!(!expr.children.is_empty());
-                // Using `count_udaf` from Comet is exceptionally slow for some reason, so
-                // as a workaround we translate it to `SUM(IF(expr IS NOT NULL, 1, 0))`
-                // https://github.com/apache/datafusion-comet/issues/744
-
                 let children = expr
                     .children
                     .iter()
                     .map(|child| self.create_expr(child, Arc::clone(&schema)))
                     .collect::<Result<Vec<_>, _>>()?;
 
-                // create `IS NOT NULL expr` and join them with `AND` if there are multiple
-                let not_null_expr: Arc<dyn PhysicalExpr> = children.iter().skip(1).fold(
-                    Arc::new(IsNotNullExpr::new(Arc::clone(&children[0]))) as Arc<dyn PhysicalExpr>,
-                    |acc, child| {
-                        Arc::new(BinaryExpr::new(
-                            acc,
-                            DataFusionOperator::And,
-                            Arc::new(IsNotNullExpr::new(Arc::clone(child))),
-                        ))
-                    },
-                );
-
-                let child = Arc::new(IfExpr::new(
-                    not_null_expr,
-                    Arc::new(Literal::new(ScalarValue::Int64(Some(1)))),
-                    Arc::new(Literal::new(ScalarValue::Int64(Some(0)))),
-                ));
-
-                AggregateExprBuilder::new(sum_udaf(), vec![child])
+                AggregateExprBuilder::new(count_udaf(), children)
                     .schema(schema)
                     .alias("count")
                     .with_ignore_nulls(false)
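
Why this is a safe swap: a multi-argument COUNT counts only the rows where every argument is non-null, which is exactly the predicate the removed code synthesized by AND-ing IsNotNullExpr terms and feeding 1/0 into SUM. Below is a minimal plain-Rust sketch of that equivalence; the row data and the Option-based modeling are illustrative only, not taken from the commit or from DataFusion's API.

// Illustrative sketch: COUNT(a, b) over nullable columns equals
// SUM(IF(a IS NOT NULL AND b IS NOT NULL, 1, 0)), the workaround this
// commit removes. The rows below are made-up example data.
fn main() {
    let rows: Vec<(Option<i64>, Option<i64>)> = vec![
        (Some(1), Some(2)),
        (None, Some(3)),
        (Some(4), None),
        (Some(5), Some(6)),
    ];

    // COUNT(a, b): a row is counted only when every argument is non-null.
    let count = rows
        .iter()
        .filter(|(a, b)| a.is_some() && b.is_some())
        .count() as i64;

    // SUM(IF(a IS NOT NULL AND b IS NOT NULL, 1, 0)): same predicate, summed as 1/0.
    let sum: i64 = rows
        .iter()
        .map(|(a, b)| i64::from(a.is_some() && b.is_some()))
        .sum();

    assert_eq!(count, sum); // both are 2
}

With count_udaf the null filtering is the aggregate's own semantics, so the planner no longer has to build the IsNotNullExpr/IfExpr tree. The removed comment cites apache/datafusion-comet#744, which reported count_udaf as exceptionally slow and motivated the workaround; per the commit title, calling count_udaf directly is now the faster path.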

spark/src/test/scala/org/apache/spark/sql/benchmark/CometAggregateBenchmark.scala

Lines changed: 4 additions & 30 deletions

@@ -64,13 +64,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
       spark.sql(query).noop()
     }
 
-    benchmark.addCase(s"SQL Parquet - Comet (Scan) ($aggregateFunction)") { _ =>
-      withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
-        spark.sql(query).noop()
-      }
-    }
-
-    benchmark.addCase(s"SQL Parquet - Comet (Scan, Exec) ($aggregateFunction)") { _ =>
+    benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
       withSQLConf(
         CometConf.COMET_ENABLED.key -> "true",
         CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -111,13 +105,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
       spark.sql(query).noop()
     }
 
-    benchmark.addCase(s"SQL Parquet - Comet (Scan) ($aggregateFunction)") { _ =>
-      withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
-        spark.sql(query).noop()
-      }
-    }
-
-    benchmark.addCase(s"SQL Parquet - Comet (Scan, Exec) ($aggregateFunction)") { _ =>
+    benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
       withSQLConf(
         CometConf.COMET_ENABLED.key -> "true",
         CometConf.COMET_EXEC_ENABLED.key -> "true") {
@@ -153,15 +141,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
       spark.sql(query).noop()
     }
 
-    benchmark.addCase(s"SQL Parquet - Comet (Scan) ($aggregateFunction)") { _ =>
-      withSQLConf(
-        CometConf.COMET_ENABLED.key -> "true",
-        CometConf.COMET_MEMORY_OVERHEAD.key -> "1G") {
-        spark.sql(query).noop()
-      }
-    }
-
-    benchmark.addCase(s"SQL Parquet - Comet (Scan, Exec) ($aggregateFunction)") { _ =>
+    benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
       withSQLConf(
         CometConf.COMET_ENABLED.key -> "true",
         CometConf.COMET_EXEC_ENABLED.key -> "true",
@@ -198,13 +178,7 @@ object CometAggregateBenchmark extends CometBenchmarkBase {
       spark.sql(query).noop()
     }
 
-    benchmark.addCase(s"SQL Parquet - Comet (Scan) ($aggregateFunction)") { _ =>
-      withSQLConf(CometConf.COMET_ENABLED.key -> "true") {
-        spark.sql(query).noop()
-      }
-    }
-
-    benchmark.addCase(s"SQL Parquet - Comet (Scan, Exec) ($aggregateFunction)") { _ =>
+    benchmark.addCase(s"SQL Parquet - Comet ($aggregateFunction)") { _ =>
       withSQLConf(
         CometConf.COMET_ENABLED.key -> "true",
         CometConf.COMET_EXEC_ENABLED.key -> "true") {
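
Net effect across the four benchmark call sites above: the scan-only Comet case is removed, and the remaining case, which enables both COMET_ENABLED and COMET_EXEC_ENABLED, is relabeled from "Comet (Scan, Exec)" to plain "Comet".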
