Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import java.util.Objects

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.supportedExpression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.codegen.ExprValue
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -163,20 +165,6 @@ class EquivalentExpressions(
case _ => Nil
}

/**
 * Returns true when no node in the tree of `e` prevents the expression from taking part in
 * subexpression elimination. The rejected node kinds are kept identical to
 * `EquivalentExpressions.supportedExpression` in the companion object so that both predicates
 * give the same answer for the same expression.
 */
private def supportedExpression(e: Expression): Boolean = {
  !e.exists {
    // `LambdaVariable` is usually used as a loop variable and `NamedLambdaVariable` is used in
    // higher-order functions. Neither can be evaluated ahead of the execution, so
    // sub-expressions containing them can't be evaluated at the beginning.
    case _: LambdaVariable | _: NamedLambdaVariable => true

    // `PlanExpression` wraps a query plan. Comparing the query plans of `PlanExpression`s on an
    // executor can cause errors such as an NPE, so they are rejected inside a running task.
    case _: PlanExpression[_] => Utils.isInRunningSparkTask

    case _ => false
  }
}

/**
* Adds the expression to this data structure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
Expand All @@ -202,6 +190,30 @@ class EquivalentExpressions(
}
}

/**
 * Entry point for recording `expr` and its sub-expressions in `map` (the `equivalenceMap`
 * unless another map is passed). Nothing is recorded when `supportedExpression` rejects the
 * tree. Otherwise the walk is handed to `updateConditionalExprTree`, which descends into
 * every child of each node it visits and stops descending below a node for which
 * `updateExprInMap` returns true.
 */
def addConditionalExprTree(
    expr: Expression,
    map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = {
  val eligible = supportedExpression(expr)
  if (eligible) updateConditionalExprTree(expr, map, useCount = 1)
}

/**
 * Recursive worker behind `addConditionalExprTree`.
 *
 * Leaf expressions and a `useCount` of zero are ignored. For any other node the count is
 * applied through `updateExprInMap`; only when that call returns false does the walk continue
 * into all children, each with the sign of `useCount` (+1 or -1) as its own count.
 */
private def updateConditionalExprTree(
    expr: Expression,
    map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap,
    useCount: Int = 1): Unit = {
  val isLeaf = expr match {
    case _: LeafExpression => true
    case _ => false
  }

  if (useCount != 0 && !isLeaf) {
    val handledByMap = updateExprInMap(expr, map, useCount)
    if (!handledByMap) {
      // Children are adjusted by one step in the same direction as the parent.
      val childCount = useCount.signum
      expr.children.foreach { child =>
        updateConditionalExprTree(child, map, childCount)
      }
    }
  }
}

/**
* Returns the state of the given expression in the `equivalenceMap`. Returns None if there is no
* equivalent expressions.
Expand Down Expand Up @@ -240,6 +252,23 @@ class EquivalentExpressions(
}
}

object EquivalentExpressions {
  /**
   * Tells whether `e` may be considered for subexpression elimination.
   *
   * @param e root of the expression tree to inspect
   * @return false as soon as any node of the tree is of a rejected kind, true otherwise
   */
  def supportedExpression(e: Expression): Boolean = {
    val containsRejectedNode = e.exists {
      // Loop variables (`LambdaVariable`) and the variables of higher-order functions
      // (`NamedLambdaVariable`) only have a value during the execution of their enclosing
      // construct, so they can't be evaluated ahead of it.
      case _: LambdaVariable | _: NamedLambdaVariable => true

      // A `PlanExpression` wraps a query plan. Comparing query plans on an executor can cause
      // errors such as an NPE, hence the check for a running task.
      case _: PlanExpression[_] => Utils.isInRunningSparkTask

      case _ => false
    }
    !containsRejectedNode
  }
}

/**
* Wrapper around an Expression that provides semantic equality.
*/
Expand Down Expand Up @@ -267,4 +296,11 @@ case class ExpressionEquals(e: Expression) {
* Instead of appending to a mutable list/buffer of Expressions, just update the "flattened"
* useCount in this wrapper in-place.
*/
// `useCount` sits in the second parameter list, so the compiler-generated equals/hashCode of
// this case class look at `expr` only; the counter is mutable bookkeeping updated in place.
case class ExpressionStats(expr: Expression)(var useCount: Int)
// Usage statistics plus code generation state for one expression.
// All fields of the second parameter list are mutable and are NOT part of the
// compiler-generated equals/hashCode, which look at `expr` only.
case class ExpressionStats(expr: Expression)(
// Flattened number of uses of `expr`, updated in place.
var useCount: Int,
// Name of the boolean mutable state telling whether the cached result has been computed;
// filled in by `CodegenContext.initCommonExpression`.
var initialized: Option[String] = None,
// Global variable holding the null flag of the cached result.
var isNull: Option[ExprValue] = None,
// Global variable holding the cached result value.
var value: Option[ExprValue] = None,
// Name of the generated function that computes the result; replaced by the name returned
// from `addNewFunction` once the function has been registered.
var funcName: Option[String] = None,
// Java types of the generated function's parameters, recorded when it is registered.
var params: Option[Seq[Class[_]]] = None,
// True once the generated function has been registered with the CodegenContext.
var addedFunction: Boolean = false)
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,36 @@ abstract class Expression extends TreeNode[Expression] {
subExprState.eval.isNull,
subExprState.eval.value)
}.getOrElse {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val eval = doGenCode(ctx, ExprCode(
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, dataType)))
val eval =
if (EquivalentExpressions.supportedExpression(this)) {
ctx.commonExpressions.get(ExpressionEquals(this)) match {
case Some(stats) =>
// We should reuse the currentVar references which code is not empty
val nonEmptyRefs = this.exists {
case BoundReference(ordinal, _, _) =>
ctx.currentVars != null && ctx.currentVars(ordinal) != null &&
ctx.currentVars(ordinal).code != EmptyBlock
case _ => false
}
val eval = doGenCode(ctx, ExprCode(
JavaCode.isNullVariable(ctx.freshName("isNull")),
JavaCode.variable(ctx.freshName("value"), dataType)))
if (eval.code != EmptyBlock && !nonEmptyRefs) {
ctx.genReusedCode(stats, eval)
} else {
eval
}

case None =>
doGenCode(ctx, ExprCode(
JavaCode.isNullVariable(ctx.freshName("isNull")),
JavaCode.variable(ctx.freshName("value"), dataType)))
}
} else {
doGenCode(ctx, ExprCode(
JavaCode.isNullVariable(ctx.freshName("isNull")),
JavaCode.variable(ctx.freshName("value"), dataType)))
}
reduceCodeSize(ctx, eval)
if (eval.code.toString.nonEmpty) {
// Add `this` in the comment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1263,19 +1263,115 @@ class CodegenContext extends Logging {
}
}

/**
 * Finds the common subexpressions of `expressions` with `addConditionalExprTree`, registers
 * them through `initCommonExpression` and returns the code that resets their "initialized"
 * flags.
 *
 * @param expressions the expressions to scan for repeated subexpressions
 * @param includeDefiniteExpression if true, every common subexpression that was found is
 *                                  registered, whether or not it will definitely be executed;
 *                                  if false, the common subexpressions also reported by the
 *                                  `addExprTree` analysis (those that will definitely be
 *                                  executed) are left out
 * @return the concatenated reset code of the newly registered subexpressions, or `EmptyBlock`
 *         when subexpression elimination is disabled
 */
def conditionalSubexpressionElimination(
    expressions: Seq[Expression],
    includeDefiniteExpression: Boolean = true): Block = {
  if (!SQLConf.get.subexpressionEliminationEnabled) {
    EmptyBlock
  } else {
    val conditionalEquivalence = new EquivalentExpressions
    expressions.foreach(conditionalEquivalence.addConditionalExprTree(_))
    // Named differently from the `commonExpressions` member of this context, which
    // `initCommonExpression` updates, to avoid shadowing it.
    val conditionalStats = conditionalEquivalence.getAllExprStates(1)

    val statsToInit =
      if (includeDefiniteExpression) {
        conditionalStats
      } else {
        val definiteEquivalence = new EquivalentExpressions
        expressions.foreach(definiteEquivalence.addExprTree(_))
        conditionalStats.diff(definiteEquivalence.getAllExprStates(1))
      }

    // `initCommonExpression` mutates the stats and this context, so it is applied strictly in
    // order while the returned blocks are concatenated.
    statsToInit.foldLeft(EmptyBlock: Block) { (block, stats) =>
      block + initCommonExpression(stats)
    }
  }
}

/**
 * Prepares the code generation state of a common subexpression on first use.
 *
 * When `stats.initialized` is still empty, mutable states for the "initialized" flag, the null
 * flag and the value are added, a fresh function name is reserved, and the stats are stored in
 * `commonExpressions` under the expression's `ExpressionEquals` key.
 *
 * @param stats the statistics object of the common subexpression; updated in place
 * @return code that sets the "initialized" flag to false, or `EmptyBlock` when the stats had
 *         already been prepared
 */
def initCommonExpression(stats: ExpressionStats): Block = stats.initialized match {
  case Some(_) =>
    EmptyBlock

  case None =>
    val commonExpr = stats.expr
    val initFlag = addMutableState(JAVA_BOOLEAN, "subExprInit")
    stats.initialized = Some(initFlag)
    val isNullState = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
    stats.isNull = Some(JavaCode.isNullGlobal(isNullState))
    val valueState = addMutableState(javaType(commonExpr.dataType), "subExprValue")
    stats.value = Some(JavaCode.global(valueState, commonExpr.dataType))
    stats.funcName = Some(freshName("subExpr"))
    commonExpressions += (ExpressionEquals(commonExpr) -> stats)
    code"$initFlag = false;\n"
}

/**
 * Wraps the generated code `eval` of a common subexpression so that it runs only while the
 * "initialized" flag of `stats` is false, storing the result in the null/value globals of
 * `stats`. The guarded code is placed in a generated function when the parameter length is
 * valid.
 *
 * @param stats state of the common subexpression; `funcName`, `params` and `addedFunction`
 *              are updated when the function gets registered
 * @param eval  the freshly generated code of the expression
 * @return an `ExprCode` reading the globals of `stats`, or `eval` unchanged when the
 *         registered function cannot be used
 */
def genReusedCode(stats: ExpressionStats, eval: ExprCode): ExprCode = {
val (inputVars, _) = getLocalInputVariableValues(this, stats.expr, subExprEliminationExprs)
// NOTE(review): the `.get` calls assume `initCommonExpression` has already filled in these
// fields for `stats` - confirm that every caller guarantees this.
val (initialized, isNull, value) = (stats.initialized.get, stats.isNull.get, stats.value.get)
val validParamLength = isValidParamLength(calculateParamLengthFromExprValues(inputVars))
// Register the function at most once per `stats`, and only if its parameter list is valid.
if (!stats.addedFunction && validParamLength) {
// Wrap the expression code in a function.
val argList =
inputVars.map(v => s"${CodeGenerator.typeName(v.javaType)} ${v.variableName}")
val fn =
s"""
|private void ${stats.funcName.get}(${argList.mkString(", ")}) {
| if (!$initialized) {
| ${eval.code}
| $initialized = true;
| $isNull = ${eval.isNull};
| $value = ${eval.value};
| }
|}
""".stripMargin
// `addNewFunction` returns the name to call the function by, which replaces the reserved one.
stats.funcName = Some(addNewFunction(stats.funcName.get, fn))
stats.params = Some(inputVars.map(_.javaType))
stats.addedFunction = true
}
// Fall back to the plain `eval` unless some entry of `classFunctions` has a key equal to the
// function name.
if (!classFunctions.values.map(_.keys).flatten.toSet.contains(stats.funcName.get)) {
// The CodegenContext has changed, all the corresponding variables will also not be available
eval
} else if (inputVars.map(_.javaType) != stats.params.get) {
// input vars changed, e.g. some input vars now are GlobalValue.
eval
} else {
val code =
if (validParamLength) {
// Call the registered function with the current input variables.
val inputVariables = inputVars.map(_.variableName).mkString(", ")
code"${stats.funcName.get}($inputVariables);"
} else {
// Inline the same guarded evaluation instead of calling a function.
code"""
|if (!$initialized) {
| ${eval.code}
| $initialized = true;
| $isNull = ${eval.isNull};
| $value = ${eval.value};
|}
""".stripMargin
}
// The result is always read from the globals of `stats`, not from the locals of `eval`.
ExprCode(code, isNull, value)
}
}

/**
* Generates code for expressions. If doSubexpressionElimination is true, subexpression
* elimination will be performed. Subexpression elimination assumes that the code for each
* expression will be combined in the `expressions` order.
*/
def generateExpressions(
expressions: Seq[Expression],
// NOTE(review): two alternative signature endings follow, `Seq[ExprCode]` and
// `(Seq[ExprCode], Block)`. The callers visible elsewhere destructure a pair, so the second
// one appears to be the intended one - confirm and keep only one.
doSubexpressionElimination: Boolean = false): Seq[ExprCode] = {
doSubexpressionElimination: Boolean = false): (Seq[ExprCode], Block) = {
// We need to make sure that we do not reuse stateful expressions. This is needed for codegen
// as well because some expressions may implement `CodegenFallback`.
val cleanedExpressions = expressions.map(_.freshCopyIfContainsStatefulExpression())
// NOTE(review): the next two lines match the `Seq[ExprCode]` signature and look superseded by
// the `initBlock` variant below - confirm and keep only one body.
if (doSubexpressionElimination) subexpressionElimination(cleanedExpressions)
cleanedExpressions.map(e => e.genCode(this))
// Besides the regular elimination, register the common subexpressions that the regular
// analysis does not report (`includeDefiniteExpression` is false) and keep the code that
// resets their "initialized" flags, which is returned next to the expression code.
val initBlock = if (doSubexpressionElimination) {
subexpressionElimination(cleanedExpressions)
conditionalSubexpressionElimination(cleanedExpressions, false)
} else {
EmptyBlock
}
(cleanedExpressions.map(e => e.genCode(this)), initBlock)
}

/**
Expand Down Expand Up @@ -1314,6 +1410,8 @@ class CodegenContext extends Logging {
EmptyBlock
}
}

// Common subexpressions registered by `initCommonExpression`, keyed by their semantic-equality
// wrapper. Declared as a `var` because the whole map is reassigned from outside this class
// (`ctx.commonExpressions = parent.commonExpressions`).
var commonExpressions = mutable.Map[ExpressionEquals, ExpressionStats]()
}

/**
Expand Down Expand Up @@ -1843,16 +1941,16 @@ object CodeGenerator extends Logging {
* elimination states for a given `expr`. This result will be used to split the
* generated code of expressions into multiple functions.
*
* Second value: Returns the set of `ExprCodes`s which are necessary codes before
* Second value: Returns the seq of `ExprCodes`s which are necessary codes before
* evaluating subexpressions.
*/
def getLocalInputVariableValues(
ctx: CodegenContext,
expr: Expression,
subExprs: Map[ExpressionEquals, SubExprEliminationState] = Map.empty)
: (Set[VariableValue], Set[ExprCode]) = {
val argSet = mutable.Set[VariableValue]()
val exprCodesNeedEvaluate = mutable.Set[ExprCode]()
: (Seq[VariableValue], Seq[ExprCode]) = {
val argSet = mutable.LinkedHashSet[VariableValue]()
val exprCodesNeedEvaluate = mutable.LinkedHashSet[ExprCode]()

if (ctx.INPUT_ROW != null) {
argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow])
Expand Down Expand Up @@ -1889,7 +1987,7 @@ object CodeGenerator extends Logging {
}
}

(argSet.toSet, exprCodesNeedEvaluate.toSet)
(argSet.toSeq, exprCodesNeedEvaluate.toSeq)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
case (NoOp, _) => false
case _ => true
}
val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination)
val (exprVals, initBlock) = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination)

// 4-tuples: (code for projection, isNull variable name, value variable name, column index)
val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map {
Expand Down Expand Up @@ -130,6 +130,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP

public java.lang.Object apply(java.lang.Object _i) {
InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
$initBlock
$evalSubexpr
$allProjections
// copy all the results into MutableRow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] {
val ctx = newCodeGenContext()

// Do sub-expression elimination for predicates.
val eval = ctx.generateExpressions(Seq(predicate), useSubexprElimination).head
val (evals, initBlock) = ctx.generateExpressions(Seq(predicate), useSubexprElimination)
val eval = evals.head
val evalSubexpr = ctx.subexprFunctionsCode

val codeBody = s"""
Expand All @@ -60,6 +61,7 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] {
}

public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
$initBlock
$evalSubexpr
${eval.code}
return !${eval.isNull} && ${eval.value};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx: CodegenContext,
expressions: Seq[Expression],
useSubexprElimination: Boolean = false): ExprCode = {
val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
val (exprEvals, initBlock) = ctx.generateExpressions(expressions, useSubexprElimination)
val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable))

val numVarLenFields = exprSchemas.count {
Expand All @@ -307,6 +307,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro

val code =
code"""
|$initBlock
|$rowWriter.reset();
|$evalSubexpr
|$writeExpressions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ case class ExpandExec(
}

// Part 2: switch/case statements
initBlock += ctx.conditionalSubexpressionElimination(
projections.flatten.map(BindReferences.bindReference(_, attributeSeq)))
val switchCaseExprs = projections.zipWithIndex.map { case (exprs, row) =>
val (exprCodesWithIndices, inputVarSets) = exprs.indices.flatMap { col =>
if (!sameOutput(col)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ import org.apache.spark.util.Utils
*/
trait CodegenSupport extends SparkPlan {

// Code of this operator that the CONSUME code path emits between the evaluated input
// variables and the consume function call (`${parent.initBlock}`). Operators append to it,
// e.g. with the result of `ctx.conditionalSubexpressionElimination`.
var initBlock: Block = EmptyBlock
// Common subexpression state of this operator; it is installed into the CodegenContext
// (`ctx.commonExpressions = parent.commonExpressions`) before the operator consumes rows.
var commonExpressions = mutable.Map.empty[ExpressionEquals, ExpressionStats]

/** Prefix used in the current operator's variable names. */
private def variablePrefix: String = this match {
case _: HashAggregateExec => "hashAgg"
Expand Down Expand Up @@ -176,6 +179,7 @@ trait CodegenSupport extends SparkPlan {
ctx.currentVars = inputVars
ctx.INPUT_ROW = null
ctx.freshNamePrefix = parent.variablePrefix
ctx.commonExpressions = parent.commonExpressions
val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)

// Under certain conditions, we can put the logic to consume the rows of this operator into
Expand All @@ -198,6 +202,7 @@ trait CodegenSupport extends SparkPlan {
s"""
|${ctx.registerComment(s"CONSUME: ${parent.simpleString(conf.maxToStringFields)}")}
|$evaluated
|${parent.initBlock}
|$consumeFunc
""".stripMargin
}
Expand Down
Loading