Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,12 @@ case class HashAggregateExec(
val resultVars = bindReferences[Expression](
resultExpressions,
inputAttrs).map(_.genCode(ctx))
val evaluateResultVars = evaluateVariables(resultVars)
s"""
$evaluateKeyVars
$evaluateBufferVars
$evaluateAggResults
$evaluateResultVars
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. If you replace .distinct() to .groupBy("idx").max() in the example then this code path runs and the change fixes the same issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If so, could you please add test cases to cover all the code paths you added in this pr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I've added that path to the test.

${consume(ctx, resultVars)}
"""
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
Expand Down Expand Up @@ -497,19 +499,25 @@ case class HashAggregateExec(
val resultVars = bindReferences[Expression](
resultExpressions,
inputAttrs).map(_.genCode(ctx))
val evaluateResultVars = evaluateVariables(resultVars)
s"""
$evaluateKeyVars
$evaluateResultBufferVars
$evaluateResultVars
${consume(ctx, resultVars)}
"""
} else {
// generate result based on grouping key
ctx.INPUT_ROW = keyTerm
ctx.currentVars = null
val eval = bindReferences[Expression](
val resultVars = bindReferences[Expression](
resultExpressions,
groupingAttributes).map(_.genCode(ctx))
consume(ctx, eval)
val evaluateResultVars = evaluateVariables(resultVars)
s"""
$evaluateResultVars
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For non broadcast join cases, the change will force evaluation unnecessarily too. We should move evaluation out of the loop in broadcast join, if possible.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I a bit concern about is; is it semantically ok to defer the evaluation of nondeterministic exprs if HashAggregateExec has these exprs?

I think, to fix this issue, its ok to modify code in the join side if we could find a simpler solution there with no performance regression. But, I have just a question about the design regardless of this issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh.. Kris answered my question.. in #23731 (review)

${consume(ctx, resultVars)}
"""
}
ctx.addNewFunction(funcName,
s"""
Expand Down
10 changes: 10 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2110,4 +2110,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(res, Row("1-1", 6, 6))
}
}

test("SPARK-26572: fix aggregate codegen result evaluation") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a problem with whole stage codegen, waht about moving this test to WholeStageCodegenSuite? And adding an assert that whole stage codegen is actually used, ie. the HashAggregate is a child of WholeStageCodegenExec?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with moving it to WholeStageCodegenSuite but the plan looks like:

*(3) Project [idx#4, id#6L]
+- *(3) BroadcastHashJoin [idx#4], [idx#9], Inner, BuildRight
   :- *(3) HashAggregate(keys=[idx#4], functions=[], output=[idx#4, id#6L])
   :  +- Exchange hashpartitioning(idx#4, 1)
   :     +- *(1) HashAggregate(keys=[idx#4], functions=[], output=[idx#4])
   :        +- *(1) Project [value#1 AS idx#4]
   :           +- LocalTableScan [value#1]
   +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint)))
      +- *(2) Project [value#1 AS idx#9]
         +- LocalTableScan [value#1]

so I guess you mean checking WholeStageCodegenExec has a ProjectExec child that has a BroadcastHashJoinExec child?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved and added WholeStageCodegenExec check.

val baseTable = Seq((1), (1)).toDF("idx")
val distinctWithId =
baseTable.distinct.withColumn("id", functions.monotonically_increasing_id())
val res = baseTable.join(distinctWithId, "idx")
.groupBy("id").count().as("count")
.select("count")
checkAnswer(res, Row(2))
}
}