From 6d489248f4b2c2e5b40f51524e21992a59226ae8 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Sat, 9 Dec 2017 15:27:24 +0000 Subject: [PATCH 1/2] initial commit --- .../execution/joins/SortMergeJoinExec.scala | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 9c08ec71c1fd..1221b49a8b58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -513,26 +513,28 @@ case class SortMergeJoinExec( * the variables should be declared separately from accessing the columns, we can't use the * codegen of BoundReference here. */ - private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = { + private def createLeftVars(ctx: CodegenContext, leftRow: String): (Seq[ExprCode], Seq[String]) = { ctx.INPUT_ROW = leftRow left.output.zipWithIndex.map { case (a, i) => val value = ctx.freshName("value") val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) - // declare it as class member, so we can access the column before or in the loop. - ctx.addMutableState(ctx.javaType(a.dataType), value) if (a.nullable) { val isNull = ctx.freshName("isNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) val code = s""" |$isNull = $leftRow.isNullAt($i); |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); """.stripMargin - ExprCode(code, isNull, value) + (ExprCode(code, isNull, value), + s""" + |boolean $isNull = false; + |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; + """.stripMargin) } else { - ExprCode(s"$value = $valueCode;", "false", value) + (ExprCode(s"$value = $valueCode;", "false", value), + s"""${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};""") } - } + }.unzip } /** @@ -580,7 +582,7 @@ case class SortMergeJoinExec( val (leftRow, matches) = genScanner(ctx) // Create variables for row from both sides. - val leftVars = createLeftVars(ctx, leftRow) + val (leftVars, leftVarDecl) = createLeftVars(ctx, leftRow) val rightRow = ctx.freshName("rightRow") val rightVars = createRightVar(ctx, rightRow) @@ -617,6 +619,7 @@ case class SortMergeJoinExec( s""" |while (findNextInnerJoinRows($leftInput, $rightInput)) { + | ${leftVarDecl.mkString("\n")} | ${beforeLoop.trim} | scala.collection.Iterator $iterator = $matches.generateIterator(); | while ($iterator.hasNext()) { From 9faf0a2644739e9e19968c5077d6b14011aab9dd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 11 Dec 2017 08:36:52 +0000 Subject: [PATCH 2/2] address review comments --- .../execution/joins/SortMergeJoinExec.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 1221b49a8b58..554b73181116 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -507,7 +507,7 @@ case class SortMergeJoinExec( } /** - * Creates variables for left part of result row. + * Creates variables and declarations for left part of result row. * * In order to defer the access after condition and also only access once in the loop, * the variables should be declared separately from accessing the columns, we can't use the @@ -518,21 +518,25 @@ case class SortMergeJoinExec( left.output.zipWithIndex.map { case (a, i) => val value = ctx.freshName("value") val valueCode = ctx.getValue(leftRow, a.dataType, i.toString) + val javaType = ctx.javaType(a.dataType) + val defaultValue = ctx.defaultValue(a.dataType) if (a.nullable) { val isNull = ctx.freshName("isNull") val code = s""" |$isNull = $leftRow.isNullAt($i); - |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode); + |$value = $isNull ? $defaultValue : ($valueCode); """.stripMargin - (ExprCode(code, isNull, value), + val leftVarsDecl = s""" |boolean $isNull = false; - |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)}; - """.stripMargin) + |$javaType $value = $defaultValue; + """.stripMargin + (ExprCode(code, isNull, value), leftVarsDecl) } else { - (ExprCode(s"$value = $valueCode;", "false", value), - s"""${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};""") + val code = s"$value = $valueCode;" + val leftVarsDecl = s"""$javaType $value = $defaultValue;""" + (ExprCode(code, "false", value), leftVarsDecl) } }.unzip }