diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 1a84859cc3a1..38a61196d0dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -21,7 +21,9 @@ import java.util.Objects
 
 import scala.collection.mutable
 
+import org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.supportedExpression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.ExprValue
 import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Utils
@@ -163,20 +165,6 @@
     case _ => Nil
   }
 
-  private def supportedExpression(e: Expression) = {
-    !e.exists {
-      // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
-      // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
-      case _: LambdaVariable => true
-
-      // `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor,
-      // can cause error like NPE.
-      case _: PlanExpression[_] => Utils.isInRunningSparkTask
-
-      case _ => false
-    }
-  }
-
   /**
    * Adds the expression to this data structure recursively. Stops if a matching expression
    * is found. That is, if `expr` has already been added, its children are not added.
@@ -202,6 +190,30 @@
     }
   }
 
+  /**
+   * Adds the expression to this data structure recursively, descending into conditionally
+   * evaluated children too. Stops if a matching expression is found; that is, if `expr`
+   * has already been added, its children are not added.
+   */
+  def addConditionalExprTree(
+      expr: Expression,
+      map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = {
+    if (supportedExpression(expr)) {
+      updateConditionalExprTree(expr, map)
+    }
+  }
+
+  private def updateConditionalExprTree(
+      expr: Expression,
+      map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap,
+      useCount: Int = 1): Unit = {
+    val skip = useCount == 0 || expr.isInstanceOf[LeafExpression]
+
+    if (!skip && !updateExprInMap(expr, map, useCount)) {
+      val uc = useCount.signum
+      expr.children.foreach(updateConditionalExprTree(_, map, uc))
+    }
+  }
+
   /**
    * Returns the state of the given expression in the `equivalenceMap`. Returns None if there is no
    * equivalent expressions.
@@ -240,6 +252,23 @@
   }
 }
 
+object EquivalentExpressions {
+  def supportedExpression(e: Expression): Boolean = {
+    !e.exists {
+      // `LambdaVariable` is usually used as a loop variable and `NamedLambdaVariable` is used
+      // in higher-order functions; neither can be evaluated ahead of execution.
+      case _: LambdaVariable => true
+      case _: NamedLambdaVariable => true
+
+      // `PlanExpression` wraps a query plan. Comparing query plans of `PlanExpression` on the
+      // executor can cause errors such as NPE.
+      case _: PlanExpression[_] => Utils.isInRunningSparkTask
+
+      case _ => false
+    }
+  }
+}
+
 /**
  * Wrapper around an Expression that provides semantic equality.
  */
@@ -267,4 +296,11 @@ case class ExpressionEquals(e: Expression) {
  * Instead of appending to a mutable list/buffer of Expressions, just update the "flattened"
  * useCount in this wrapper in-place.
  */
-case class ExpressionStats(expr: Expression)(var useCount: Int)
+case class ExpressionStats(expr: Expression)(
+    var useCount: Int,
+    var initialized: Option[String] = None,
+    var isNull: Option[ExprValue] = None,
+    var value: Option[ExprValue] = None,
+    var funcName: Option[String] = None,
+    var params: Option[Seq[Class[_]]] = None,
+    var addedFunction: Boolean = false)
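Note on the new API above: the existing `addExprTree` deliberately does not descend into conditionally evaluated children (e.g. `If`/`CaseWhen` branches), since those cannot be hoisted and evaluated eagerly. `addConditionalExprTree` does descend, because the code generated for these expressions is guarded by a lazy-initialization flag instead of being run up front. A minimal sketch of the difference — the expressions are illustrative, the APIs are the ones in this patch:

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// Illustrative input: Add(x, 1) occurs twice, but only inside one If branch.
val x = AttributeReference("x", IntegerType)()
val shared = Add(x, Literal(1))
val expr = If(GreaterThan(x, Literal(0)), Multiply(shared, shared), Literal(0))

val conditional = new EquivalentExpressions
conditional.addConditionalExprTree(expr)
// Unlike addExprTree, the conditional variant descends into the If branches,
// so Add(x, 1) is recorded with useCount = 2 and qualifies for lazy reuse.
conditional.getAllExprStates(1).foreach(s => println(s.expr))
```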
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index c2330cdb59db..2e410cef6fe8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -194,11 +194,36 @@ abstract class Expression extends TreeNode[Expression] {
         subExprState.eval.isNull, subExprState.eval.value)
     }.getOrElse {
-      val isNull = ctx.freshName("isNull")
-      val value = ctx.freshName("value")
-      val eval = doGenCode(ctx, ExprCode(
-        JavaCode.isNullVariable(isNull),
-        JavaCode.variable(value, dataType)))
+      val eval =
+        if (EquivalentExpressions.supportedExpression(this)) {
+          ctx.commonExpressions.get(ExpressionEquals(this)) match {
+            case Some(stats) =>
+              // Only reuse the generated code if no referenced `currentVars` slot still
+              // carries not-yet-evaluated code.
+              val nonEmptyRefs = this.exists {
+                case BoundReference(ordinal, _, _) =>
+                  ctx.currentVars != null && ctx.currentVars(ordinal) != null &&
+                    ctx.currentVars(ordinal).code != EmptyBlock
+                case _ => false
+              }
+              val eval = doGenCode(ctx, ExprCode(
+                JavaCode.isNullVariable(ctx.freshName("isNull")),
+                JavaCode.variable(ctx.freshName("value"), dataType)))
+              if (eval.code != EmptyBlock && !nonEmptyRefs) {
+                ctx.genReusedCode(stats, eval)
+              } else {
+                eval
+              }
+
+            case None =>
+              doGenCode(ctx, ExprCode(
+                JavaCode.isNullVariable(ctx.freshName("isNull")),
+                JavaCode.variable(ctx.freshName("value"), dataType)))
+          }
+        } else {
+          doGenCode(ctx, ExprCode(
+            JavaCode.isNullVariable(ctx.freshName("isNull")),
+            JavaCode.variable(ctx.freshName("value"), dataType)))
+        }
       reduceCodeSize(ctx, eval)
       if (eval.code.toString.nonEmpty) {
         // Add `this` in the comment.
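The reuse path added to `genCode` above relies on `genReusedCode`, introduced in `CodeGenerator.scala` below: each common expression gets a mutable slot guarded by an `initialized` flag that the per-row init block resets, so the expression body runs at most once per row no matter how many conditional occurrences are hit. A plain-Scala analogue of that contract, with illustrative names (the real mechanism emits equivalent Java into the generated class):

```scala
// Plain-Scala analogue of the generated lazy slot; the patch itself emits Java
// backed by subExprInit / subExprIsNull / subExprValue mutable state.
final class LazySubExprSlot[T](compute: () => T) {
  private var initialized = false          // mirrors the `subExprInit` flag
  private var value: T = _                 // mirrors `subExprValue`
  def reset(): Unit = initialized = false  // what `$initBlock` does per row
  def get(): T = {
    if (!initialized) { value = compute(); initialized = true }
    value
  }
}

// Three conditional occurrences, at most one evaluation per row:
val slot = new LazySubExprSlot(() => { println("evaluated once"); 42 })
slot.reset()                                   // start of a new input row
val sum = slot.get() + slot.get() + slot.get() // prints "evaluated once" once
```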
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5651a30515f2..a534c6f4f071 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1263,6 +1263,97 @@ class CodegenContext extends Logging {
     }
   }
 
+  /**
+   * Collects common subexpressions, including those that are only conditionally evaluated,
+   * and returns a code block that resets their initialization flags. If
+   * `includeDefiniteExpression` is false, common expressions that will definitely be executed
+   * (which regular subexpression elimination already covers) are excluded.
+   * @param expressions the bound expressions to scan for common subexpressions
+   * @return the initialization code block to emit once per input row
+   */
+  def conditionalSubexpressionElimination(
+      expressions: Seq[Expression],
+      includeDefiniteExpression: Boolean = true): Block = {
+    var initBlock: Block = EmptyBlock
+    if (!SQLConf.get.subexpressionEliminationEnabled) return initBlock
+
+    val equivalence = new EquivalentExpressions
+    expressions.foreach(equivalence.addConditionalExprTree(_))
+    val commonExpressions = equivalence.getAllExprStates(1)
+    if (includeDefiniteExpression) {
+      commonExpressions.foreach(initBlock += initCommonExpression(_))
+    } else {
+      val definiteEquivalence = new EquivalentExpressions
+      expressions.foreach(definiteEquivalence.addExprTree(_))
+      (commonExpressions diff definiteEquivalence.getAllExprStates(1))
+        .foreach(initBlock += initCommonExpression(_))
+    }
+    initBlock
+  }
+
+  def initCommonExpression(stats: ExpressionStats): Block = {
+    if (stats.initialized.isEmpty) {
+      val expr = stats.expr
+      stats.initialized = Some(addMutableState(JAVA_BOOLEAN, "subExprInit"))
+      stats.isNull = Some(JavaCode.isNullGlobal(addMutableState(JAVA_BOOLEAN, "subExprIsNull")))
+      stats.value = Some(JavaCode.global(addMutableState(javaType(expr.dataType), "subExprValue"),
+        expr.dataType))
+      stats.funcName = Some(freshName("subExpr"))
+      commonExpressions += ExpressionEquals(expr) -> stats
+      code"${stats.initialized.get} = false;\n"
+    } else {
+      EmptyBlock
+    }
+  }
+
+  def genReusedCode(stats: ExpressionStats, eval: ExprCode): ExprCode = {
+    val (inputVars, _) = getLocalInputVariableValues(this, stats.expr, subExprEliminationExprs)
+    val (initialized, isNull, value) = (stats.initialized.get, stats.isNull.get, stats.value.get)
+    val validParamLength = isValidParamLength(calculateParamLengthFromExprValues(inputVars))
+    if (!stats.addedFunction && validParamLength) {
+      // Wrap the expression code in a function.
+      val argList =
+        inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}")
+      val fn =
+        s"""
+           |private void ${stats.funcName.get}(${argList.mkString(", ")}) {
+           |  if (!$initialized) {
+           |    ${eval.code}
+           |    $initialized = true;
+           |    $isNull = ${eval.isNull};
+           |    $value = ${eval.value};
+           |  }
+           |}
+         """.stripMargin
+      stats.funcName = Some(addNewFunction(stats.funcName.get, fn))
+      stats.params = Some(inputVars.map(_.javaType))
+      stats.addedFunction = true
+    }
+    if (!classFunctions.values.flatMap(_.keys).toSet.contains(stats.funcName.get)) {
+      // The function is not registered in this CodegenContext, so its captured variables are
+      // unavailable here; fall back to the inline code.
+      eval
+    } else if (inputVars.map(_.javaType) != stats.params.get) {
+      // The input vars changed, e.g. some are now `GlobalValue`s, so the recorded signature
+      // no longer matches; fall back to the inline code.
+      eval
+    } else {
+      val code =
+        if (validParamLength) {
+          val inputVariables = inputVars.map(_.variableName).mkString(", ")
+          code"${stats.funcName.get}($inputVariables);"
+        } else {
+          code"""
+               |if (!$initialized) {
+               |  ${eval.code}
+               |  $initialized = true;
+               |  $isNull = ${eval.isNull};
+               |  $value = ${eval.value};
+               |}
+             """.stripMargin
+        }
+      ExprCode(code, isNull, value)
+    }
+  }
+
   /**
    * Generates code for expressions. If doSubexpressionElimination is true, subexpression
    * elimination will be performed. Subexpression elimination assumes that the code for each
@@ -1270,12 +1361,17 @@
    */
   def generateExpressions(
       expressions: Seq[Expression],
-      doSubexpressionElimination: Boolean = false): Seq[ExprCode] = {
+      doSubexpressionElimination: Boolean = false): (Seq[ExprCode], Block) = {
     // We need to make sure that we do not reuse stateful expressions. This is needed for codegen
     // as well because some expressions may implement `CodegenFallback`.
     val cleanedExpressions = expressions.map(_.freshCopyIfContainsStatefulExpression())
-    if (doSubexpressionElimination) subexpressionElimination(cleanedExpressions)
-    cleanedExpressions.map(e => e.genCode(this))
+    val initBlock = if (doSubexpressionElimination) {
+      subexpressionElimination(cleanedExpressions)
+      conditionalSubexpressionElimination(cleanedExpressions, false)
+    } else {
+      EmptyBlock
+    }
+    (cleanedExpressions.map(e => e.genCode(this)), initBlock)
   }
 
   /**
@@ -1314,6 +1410,8 @@ class CodegenContext extends Logging {
       EmptyBlock
     }
   }
+
+  var commonExpressions = mutable.Map[ExpressionEquals, ExpressionStats]()
 }
 
 /**
@@ -1843,16 +1941,16 @@ object CodeGenerator extends Logging {
    * elimination states for a given `expr`. This result will be used to split the
    * generated code of expressions into multiple functions.
    *
-   * Second value: Returns the set of `ExprCodes`s which are necessary codes before
+   * Second value: Returns the seq of `ExprCode`s which are necessary codes before
    * evaluating subexpressions.
    */
   def getLocalInputVariableValues(
       ctx: CodegenContext,
       expr: Expression,
       subExprs: Map[ExpressionEquals, SubExprEliminationState] = Map.empty)
-    : (Set[VariableValue], Set[ExprCode]) = {
-    val argSet = mutable.Set[VariableValue]()
-    val exprCodesNeedEvaluate = mutable.Set[ExprCode]()
+    : (Seq[VariableValue], Seq[ExprCode]) = {
+    val argSet = mutable.LinkedHashSet[VariableValue]()
+    val exprCodesNeedEvaluate = mutable.LinkedHashSet[ExprCode]()
 
     if (ctx.INPUT_ROW != null) {
       argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
@@ -1889,7 +1987,7 @@
       }
     }
 
-    (argSet.toSet, exprCodesNeedEvaluate.toSet)
+    (argSet.toSeq, exprCodesNeedEvaluate.toSeq)
   }
 
   /**
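A subtle but necessary change above: `getLocalInputVariableValues` now builds its results with `mutable.LinkedHashSet` and returns `Seq`s. `genReusedCode` defines the shared function once, from one call site's input variables; later call sites both compare against the recorded `stats.params` and pass arguments positionally, so iteration order must be deterministic. A short illustration (variable names are illustrative):

```scala
import scala.collection.mutable

// mutable.LinkedHashSet iterates in insertion order, so the parameter list used
// when the shared function is defined matches the argument list at every later
// call site (and the `stats.params` signature check compares like with like).
val args = mutable.LinkedHashSet[String]()
args += "value_1"
args += "isNull_1"
args += "value_0"
assert(args.toSeq == Seq("value_1", "isNull_1", "value_0")) // stable, every run

// A plain mutable.Set makes no ordering guarantee: two equal sets can iterate
// differently, so generated calls could pass arguments in the wrong positions.
```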
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 2e018de07101..4829348aacf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -61,7 +61,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
       case (NoOp, _) => false
       case _ => true
     }
-    val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination)
+    val (exprVals, initBlock) = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination)
 
     // 4-tuples: (code for projection, isNull variable name, value variable name, column index)
     val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map {
@@ -130,6 +130,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
 
         public java.lang.Object apply(java.lang.Object _i) {
           InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
+          $initBlock
           $evalSubexpr
           $allProjections
           // copy all the results into MutableRow
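To make the generator-side protocol concrete: `generateExpressions` now returns the per-row init block alongside the `ExprCode`s, and each generator splices it in ahead of the eager subexpression code, as `GenerateMutableProjection` does above. A condensed sketch of that wiring (`boundExprs` and the surrounding scaffolding are hypothetical):

```scala
// Condensed, hypothetical wiring; mirrors GenerateMutableProjection above.
val ctx = new CodegenContext
val (exprVals, initBlock) =
  ctx.generateExpressions(boundExprs, doSubexpressionElimination = true)
val evalSubexpr = ctx.subexprFunctionsCode
val applyBody =
  s"""
     |public java.lang.Object apply(java.lang.Object _i) {
     |  InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
     |  $initBlock    // reset every subExprInit flag for this row
     |  $evalSubexpr  // eager (unconditional) common subexpressions
     |  ${exprVals.map(_.code).mkString("\n")}
     |  // ... write the results out ...
     |}
   """.stripMargin
```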
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index c246d07f189b..ab2f6f47e847 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -38,7 +38,8 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] {
     val ctx = newCodeGenContext()
 
     // Do sub-expression elimination for predicates.
-    val eval = ctx.generateExpressions(Seq(predicate), useSubexprElimination).head
+    val (evals, initBlock) = ctx.generateExpressions(Seq(predicate), useSubexprElimination)
+    val eval = evals.head
     val evalSubexpr = ctx.subexprFunctionsCode
 
     val codeBody = s"""
@@ -60,6 +61,7 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] {
         }
 
         public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
+          $initBlock
          $evalSubexpr
          ${eval.code}
          return !${eval.isNull} && ${eval.value};
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 459c1d9a8ba1..68bed063ab61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -287,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       ctx: CodegenContext,
       expressions: Seq[Expression],
       useSubexprElimination: Boolean = false): ExprCode = {
-    val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
+    val (exprEvals, initBlock) = ctx.generateExpressions(expressions, useSubexprElimination)
     val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable))
 
     val numVarLenFields = exprSchemas.count {
@@ -307,6 +307,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
 
     val code =
       code"""
+        |$initBlock
         |$rowWriter.reset();
        |$evalSubexpr
        |$writeExpressions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index c087fdf5f962..847ca85d9bee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -168,6 +168,8 @@ case class ExpandExec(
     }
 
     // Part 2: switch/case statements
+    initBlock += ctx.conditionalSubexpressionElimination(
+      projections.flatten.map(BindReferences.bindReference(_, attributeSeq)))
     val switchCaseExprs = projections.zipWithIndex.map { case (exprs, row) =>
       val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col =>
         if (!sameOutput(col)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index ddc2cfb56d4f..556299d46cf2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -45,6 +45,9 @@ import org.apache.spark.util.Utils
  */
 trait CodegenSupport extends SparkPlan {
 
+  var initBlock: Block = EmptyBlock
+  var commonExpressions = mutable.Map.empty[ExpressionEquals, ExpressionStats]
+
   /**
    * Prefix used in the current operator's variable names.
   */
   private def variablePrefix: String = this match {
     case _: HashAggregateExec => "hashAgg"
@@ -176,6 +179,7 @@ trait CodegenSupport extends SparkPlan {
     ctx.currentVars = inputVars
     ctx.INPUT_ROW = null
     ctx.freshNamePrefix = parent.variablePrefix
+    ctx.commonExpressions = parent.commonExpressions
     val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
 
     // Under certain conditions, we can put the logic to consume the rows of this operator into
@@ -198,6 +202,7 @@ trait CodegenSupport extends SparkPlan {
       s"""
          |${ctx.registerComment(s"CONSUME: ${parent.simpleString(conf.maxToStringFields)}")}
          |$evaluated
+         |${parent.initBlock}
          |$consumeFunc
        """.stripMargin
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala
index 1377a9842231..bd59c0d21fd1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenSupport.scala
@@ -208,6 +208,7 @@ trait AggregateCodegenSupport
       bindReferences(updateExprsForOneFunc, inputAttrs)
     }
     val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
+    val initBlock = ctx.conditionalSubexpressionElimination(boundUpdateExprs.flatten, false)
     val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
     val bufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
       ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -240,6 +241,7 @@ trait AggregateCodegenSupport
     s"""
        |// do aggregate
        |// common sub-expressions
+       |$initBlock
        |$effectiveCodes
        |// evaluate aggregate functions and update aggregation buffers
        |$codeToEvalAggFuncs
@@ -307,7 +309,7 @@ trait AggregateCodegenSupport
     } else {
       val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc =>
         val inputVarsForOneFunc = aggExprsForOneFunc.map(
-          CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)._1).reduce(_ ++ _).toSeq
+          CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)._1.toSet).reduce(_ ++ _).toSeq
         val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc)
 
         // Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 6c83ba5546d2..b62b80a8aafd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -551,6 +551,7 @@ case class HashAggregateExec(
     def outputFromRowBasedMap: String = {
       s"""
          |while ($limitNotReachedCondition $iterTermForFastHashMap.next()) {
+         |  $initBlock
          |  UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
          |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
          |  $outputFunc($keyTerm, $bufferTerm);
@@ -577,6 +578,7 @@ case class HashAggregateExec(
       s"""
          |while ($limitNotReachedCondition $iterTermForFastHashMap.hasNext()) {
          |  InternalRow $row = (InternalRow) $iterTermForFastHashMap.next();
+         |  $initBlock
          |  ${generateKeyRow.code}
          |  ${generateBufferRow.code}
          |  $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
@@ -591,6 +593,7 @@ case class HashAggregateExec(
     def outputFromRegularHashMap: String = {
       s"""
          |while ($limitNotReachedCondition $iterTerm.next()) {
+         |  $initBlock
          |  UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
          |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
          |  $outputFunc($keyTerm, $bufferTerm);
@@ -625,7 +628,7 @@ case class HashAggregateExec(
     // create grouping key
     val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
       ctx, bindReferences[Expression](groupingExpressions, child.output))
-    val fastRowKeys = ctx.generateExpressions(
+    val (fastRowKeys, initBlock) = ctx.generateExpressions(
       bindReferences[Expression](groupingExpressions, child.output))
     val unsafeRowKeys = unsafeRowKeyCode.value
     val unsafeRowKeyHash = ctx.freshName("unsafeRowKeyHash")
@@ -688,6 +691,7 @@ case class HashAggregateExec(
       // If fast hash map is on, we first generate code to probe and update the fast hash map.
       // If the probe is successful the corresponding fast row buffer will hold the mutable row.
       s"""
+         |$initBlock
         |${fastRowKeys.map(_.code).mkString("\n")}
         |if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
         |  $fastRowBuffer = $fastHashMapTerm.findOrInsert(
@@ -728,6 +732,7 @@ case class HashAggregateExec(
       bindReferences(updateExprsForOneFunc, inputAttrs)
     }
     val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
+    val initBlock = ctx.conditionalSubexpressionElimination(boundUpdateExprs.flatten, false)
     val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
     val unsafeRowBufferEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
       ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -760,6 +765,7 @@ case class HashAggregateExec(
       ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
     s"""
        |// common sub-expressions
+       |$initBlock
       |$effectiveCodes
       |// evaluate aggregate functions and update aggregation buffers
       |$codeToEvalAggFuncs
@@ -774,6 +780,7 @@ case class HashAggregateExec(
       bindReferences(updateExprsForOneFunc, inputAttrs)
     }
     val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
+    val initBlock = ctx.conditionalSubexpressionElimination(boundUpdateExprs.flatten, false)
     val effectiveCodes = ctx.evaluateSubExprEliminationState(subExprs.states.values)
     val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc =>
       ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -810,6 +817,7 @@ case class HashAggregateExec(
     s"""
        |if ($fastRowBuffer != null) {
        |  // common sub-expressions
+       |  $initBlock
        |  $effectiveCodes
        |  // evaluate aggregate functions and update aggregation buffers
        |  $codeToEvalAggFuncs
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 68f056d894b9..c2bf013d4812 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -70,6 +70,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
     val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) {
       // subexpression elimination
       val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
+      initBlock += ctx.conditionalSubexpressionElimination(exprs, false)
       val genVars = ctx.withSubExprEliminationExprs(subExprs.states) {
         exprs.map(_.genCode(ctx))
       }
@@ -178,6 +179,8 @@
     // TODO: revisit this. We can consider reordering predicates as well.
     val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)
     val extraIsNotNullAttrs = mutable.Set[Attribute]()
+    initBlock += ctx.conditionalSubexpressionElimination(
+      otherPreds.map(BindReferences.bindReference(_, inputAttrs)))
     val generated = otherPreds.map { c =>
       val nullChecks = c.references.map { r =>
         val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 84c0cd127f45..3114a5125845 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -464,6 +464,7 @@ case class BroadcastNestedLoopJoinExec(
     s"""
        |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) {
        |  UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex];
+       |  $initBlock
       |  $checkCondition {
       |    $numOutput.add(1);
       |    ${consume(ctx, resultVars)}
@@ -497,6 +498,7 @@ case class BroadcastNestedLoopJoinExec(
     |boolean $foundMatch = false;
     |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) {
     |  UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex];
+    |  $initBlock
    |  boolean $shouldOutputRow = false;
    |  $checkCondition {
    |    $shouldOutputRow = true;
@@ -548,6 +550,7 @@ case class BroadcastNestedLoopJoinExec(
    |boolean $foundMatch = false;
    |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) {
    |  UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex];
+   |  $initBlock
   |  $checkCondition {
   |    $foundMatch = true;
   |    break;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 7c48baf99ef8..f3a1f52c612e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -458,6 +458,7 @@ trait HashJoin extends JoinCodegenSupport {
       val ev = BindReferences.bindReference(expr,
         streamedPlan.output ++ buildPlan.output).genCode(ctx)
       s"""
+         |$initBlock
        |boolean $conditionPassed = true;
        |${eval.trim}
        |if ($matched != null) {
@@ -657,6 +658,7 @@ trait HashJoin extends JoinCodegenSupport {
       val ev = BindReferences.bindReference(expr,
         streamedPlan.output ++ buildPlan.output).genCode(ctx)
       s"""
+         |$initBlock
        |$eval
        |${ev.code}
        |$existsVar = !${ev.isNull} && ${ev.value};
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala
index a7d1edefcd61..43fa0cdb5316 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/JoinCodegenSupport.scala
@@ -55,10 +55,12 @@ trait JoinCodegenSupport extends CodegenSupport with BaseJoinExec {
 
     // filter the output via condition
     ctx.currentVars = streamVars2 ++ buildVars
-    val ev =
-      BindReferences.bindReference(expr, streamPlan.output ++ buildPlan.output).genCode(ctx)
+    val boundExpr = BindReferences.bindReference(expr, streamPlan.output ++ buildPlan.output)
+    val initBlock = ctx.conditionalSubexpressionElimination(Seq(boundExpr))
+    val ev = boundExpr.genCode(ctx)
     val skipRow = s"${ev.isNull} || !${ev.value}"
     s"""
+       |$initBlock
       |$eval
       |${ev.code}
       |if (!($skipRow))
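Taken together, the patch targets query shapes where an expensive expression recurs across conditional positions. Stock subexpression elimination only counts the occurrence in the always-evaluated WHEN predicate below, because the THEN occurrence is conditional and not shared by every branch; matching rows therefore pay for the expression twice today, while conditional elimination lets the second occurrence reuse the lazily cached value. A hypothetical example (`expensive_udf`, table `t`, and the `spark` session are illustrative):

```scala
// Hypothetical motivating query: expensive_udf(c) appears in the WHEN predicate
// and again in the THEN branch. Without this patch, rows matching the predicate
// evaluate it twice; with it, the THEN occurrence reads the cached slot.
val df = spark.sql(
  """SELECT CASE WHEN expensive_udf(c) > 0 THEN expensive_udf(c) ELSE 0 END AS v
    |FROM t""".stripMargin)
```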