 
 package org.apache.spark.sql.execution
 
+import scala.collection.immutable.IndexedSeq
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.execution.metric.SQLMetrics
 
 /**
  * Apply all of the GroupExpressions to every input row, hence we will get
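
(The Scaladoc above is cut off at the hunk boundary: Expand applies every projection in `projections` to each input row, so one input row produces `projections.length` output rows.) As a rough illustration of that semantics only, here is a small plain-Scala sketch; the row and projection types are hypothetical stand-ins, not Spark's internal InternalRow/UnsafeProjection machinery:

    // Illustrative only: a row is a Seq[Any] and a projection is a function from
    // an input row to an output row; every projection is applied to every row,
    // just as doExecute() below does with the compiled UnsafeProjections.
    def expand(
        input: Iterator[Seq[Any]],
        projections: Seq[Seq[Any] => Seq[Any]]): Iterator[Seq[Any]] =
      input.flatMap(row => projections.iterator.map(p => p(row)))

    // Example: a GROUPING SETS-style expansion of (a, b) into the groupings
    // (a, b) and (a), tagging each output row with a grouping id column.
    val rows = Iterator(Seq[Any]("x", 1), Seq[Any]("y", 2))
    expand(rows, Seq(
      r => r :+ 0,                // keep both columns, gid = 0
      r => Seq(r(0), null, 1)     // null out `b`, gid = 1
    )).foreach(println)
    // List(x, 1, 0), List(x, null, 1), List(y, 2, 0), List(y, null, 1)
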
@@ -35,7 +39,10 @@ case class Expand(
     projections: Seq[Seq[Expression]],
     output: Seq[Attribute],
     child: SparkPlan)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {
+
+  private[sql] override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
   // The GroupExpressions can output data with arbitrary partitioning, so set it
   // as UNKNOWN partitioning
@@ -48,6 +55,8 @@ case class Expand(
     (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output)
 
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    val numOutputRows = longMetric("numOutputRows")
+
     child.execute().mapPartitions { iter =>
       val groups = projections.map(projection).toArray
       new Iterator[InternalRow] {
@@ -71,9 +80,76 @@ case class Expand(
             idx = 0
           }
 
+          numOutputRows += 1
           result
         }
       }
     }
   }
+
+  override def upstream(): RDD[InternalRow] = {
+    child.asInstanceOf[CodegenSupport].upstream()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val uniqExprs: IndexedSeq[Set[Expression]] = output.indices.map { i =>
+      projections.map(p => p(i)).toSet
+    }
+
+    ctx.currentVars = input
+    val resultVars = uniqExprs.zipWithIndex.map { case (exprs, i) =>
+      val expr = exprs.head
+      if (exprs.size == 1) {
+        // It's common for some columns to have the same expression in all the projections;
+        // for example, GroupingSets copies the child's entire output as the first part of
+        // the output. We should only generate such a column once.
+        BindReferences.bindReference(expr, child.output).gen(ctx)
+      } else {
+        val isNull = ctx.freshName("isNull")
+        val value = ctx.freshName("value")
+        val code =
+          s"""
+             |boolean $isNull = true;
+             |${ctx.javaType(expr.dataType)} $value = ${ctx.defaultValue(expr.dataType)};
+           """.stripMargin
+        ExprCode(code, isNull, value)
+      }
+    }
+
+    // To prevent code explosion we can't call `consume()` once per projection; instead we
+    // call it a single time inside a loop and use a switch/case to select the projection.
+    val projectCodes = projections.zipWithIndex.map { case (exprs, i) =>
+      val need = exprs.zipWithIndex.filter { case (e, j) =>
+        uniqExprs(j).size > 1
+      }
+      val updates = need.map { case (e, j) =>
+        val ev = BindReferences.bindReference(e, child.output).gen(ctx)
+        s"""
+           |${ev.code}
+           |${resultVars(j).isNull} = ${ev.isNull};
+           |${resultVars(j).value} = ${ev.value};
+         """.stripMargin
+      }
+      s"""
+         |case $i:
+         |  ${updates.mkString("\n").trim}
+         |  break;
+       """.stripMargin
+    }
+
+    val i = ctx.freshName("i")
+    s"""
+       |${resultVars.map(_.code).mkString("\n").trim}
+       |for (int $i = 0; $i < ${projections.length}; $i++) {
+       |  switch ($i) {
+       |    ${projectCodes.mkString("\n").trim}
+       |  }
+       |  ${consume(ctx, resultVars)}
+       |}
+     """.stripMargin
+  }
 }
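
To make the doConsume() strategy above concrete: any output column whose expression is identical across all projections is generated exactly once up front, while the remaining columns are declared as mutable locals and reassigned inside a switch driven by a loop over the projection index, so consume() is emitted only once rather than once per projection. The following is a rough Scala analogue of the structure the generated Java follows, with made-up column names and a hard-coded pair of projections; it sketches the idea, it is not the actual generated code:

    // Two projections over child output (a, b) producing (a, b, gid).
    // `a` is the same expression in both projections, so it is computed once;
    // `b` and `gid` differ, so they are re-assigned on each loop iteration.
    def expandedConsume(a: String, b: Integer)(consume: (String, Integer, Int) => Unit): Unit = {
      // varying columns start as null/default, mirroring the
      // `boolean isNull = true; ... = defaultValue` initialisation above
      var outB: Integer = null
      var gid: Int = 0
      var i = 0
      while (i < 2) {            // for (int i = 0; i < projections.length; i++)
        i match {                // switch (i) { case 0: ...; case 1: ... }
          case 0 => outB = b; gid = 0
          case 1 => outB = null; gid = 1
        }
        consume(a, outB, gid)    // consume() appears exactly once in the code
        i += 1
      }
    }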