Skip to content

Commit 22ceda9

Browse files
author
Davies Liu
committed
generate Expand
1 parent 0d50a22 commit 22ceda9

2 files changed

Lines changed: 94 additions & 1 deletion

File tree

sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@
1717

1818
package org.apache.spark.sql.execution
1919

20+
import scala.collection.immutable.IndexedSeq
21+
2022
import org.apache.spark.rdd.RDD
2123
import org.apache.spark.sql.catalyst.InternalRow
2224
import org.apache.spark.sql.catalyst.errors._
2325
import org.apache.spark.sql.catalyst.expressions._
26+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
2427
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
28+
import org.apache.spark.sql.execution.metric.SQLMetrics
2529

2630
/**
2731
* Apply all of the GroupExpressions to every input row, hence we will get
@@ -35,7 +39,10 @@ case class Expand(
3539
projections: Seq[Seq[Expression]],
3640
output: Seq[Attribute],
3741
child: SparkPlan)
38-
extends UnaryNode {
42+
extends UnaryNode with CodegenSupport {
43+
44+
private[sql] override lazy val metrics = Map(
45+
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
3946

4047
// The GroupExpressions can output data with arbitrary partitioning, so set it
4148
// as UNKNOWN partitioning
@@ -48,6 +55,8 @@ case class Expand(
4855
(exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output)
4956

5057
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
58+
val numOutputRows = longMetric("numOutputRows")
59+
5160
child.execute().mapPartitions { iter =>
5261
val groups = projections.map(projection).toArray
5362
new Iterator[InternalRow] {
@@ -71,9 +80,76 @@ case class Expand(
7180
idx = 0
7281
}
7382

83+
numOutputRows += 1
7484
result
7585
}
7686
}
7787
}
7888
}
89+
90+
override def upstream(): RDD[InternalRow] = {
91+
child.asInstanceOf[CodegenSupport].upstream()
92+
}
93+
94+
protected override def doProduce(ctx: CodegenContext): String = {
95+
child.asInstanceOf[CodegenSupport].produce(ctx, this)
96+
}
97+
98+
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
99+
val uniqExprs: IndexedSeq[Set[Expression]] = output.indices.map { i =>
100+
projections.map(p => p(i)).toSet
101+
}
102+
103+
ctx.currentVars = input
104+
val resultVars = uniqExprs.zipWithIndex.map { case (exprs, i) =>
105+
val expr = exprs.head
106+
if (exprs.size == 1) {
107+
// it's common to have the same expression for some columns in all the projections; for example,
108+
// GroupingSet will copy all the output from child as the first part of output.
109+
// We should only generate the columns once.
110+
BindReferences.bindReference(expr, child.output).gen(ctx)
111+
} else {
112+
val isNull = ctx.freshName("isNull")
113+
val value = ctx.freshName("value")
114+
val code =
115+
s"""
116+
|boolean $isNull = true;
117+
|${ctx.javaType(expr.dataType)} $value = ${ctx.defaultValue(expr.dataType)};
118+
""".stripMargin
119+
ExprCode(code, isNull, value)
120+
}
121+
}
122+
123+
// In order to prevent code explosion, we can't call `consume()` many times, so we call
124+
// that in a loop, and use switch/case to select the projections.
125+
val projectCodes = projections.zipWithIndex.map { case (exprs, i) =>
126+
val need = exprs.zipWithIndex.filter { case (e, j) =>
127+
uniqExprs(j).size > 1
128+
}
129+
val updates = need.map { case (e, j) =>
130+
val ev = BindReferences.bindReference(e, child.output).gen(ctx)
131+
s"""
132+
|${ev.code}
133+
|${resultVars(j).isNull} = ${ev.isNull};
134+
|${resultVars(j).value} = ${ev.value};
135+
""".stripMargin
136+
}
137+
s"""
138+
|case $i:
139+
| ${updates.mkString("\n").trim}
140+
| break;
141+
""".stripMargin
142+
}
143+
144+
val i = ctx.freshName("i")
145+
s"""
146+
|${resultVars.map(_.code).mkString("\n").trim}
147+
|for (int $i = 0; $i < ${projections.length}; $i ++) {
148+
| switch ($i) {
149+
| ${projectCodes.mkString("\n").trim}
150+
| }
151+
| ${consume(ctx, resultVars)}
152+
|}
153+
""".stripMargin
154+
}
79155
}

sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,23 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
157157

158158
}
159159

160+
ignore("rube") {
161+
val N = 5 << 20
162+
163+
runBenchmark("cube", N) {
164+
sqlContext.range(N).selectExpr("id", "id % 1000 as k1", "id & 256 as k2")
165+
.cube("k1", "k2").sum("id").collect()
166+
}
167+
168+
/**
169+
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
170+
cube: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
171+
-------------------------------------------------------------------------------------------
172+
cube codegen=false 3188 / 3392 1.6 608.2 1.0X
173+
cube codegen=true 1239 / 1394 4.2 236.3 2.6X
174+
*/
175+
}
176+
160177
ignore("hash and BytesToBytesMap") {
161178
val N = 50 << 20
162179

0 commit comments

Comments
 (0)