Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ class EquivalentExpressions {
// Chooses which children of `expr` are recursed into, based on the expression's type.
// NOTE(review): this hunk appears to show both the pre-change `CaseWhenCodegen` case and the
// post-change `CaseWhen` case of the diff side by side — confirm against the final file.
def childrenToRecurse: Seq[Expression] = expr match {
// Expressions mixing in `CodegenFallback` contribute no children.
case _: CodegenFallback => Nil
// For `If`, only the predicate is returned; the two value branches are left out.
case i: If => i.predicate :: Nil
// `CaseWhen` implements `CodegenFallback`, we only need to handle `CaseWhenCodegen` here.
case c: CaseWhenCodegen => c.children.head :: Nil
// Only the first child is returned for `CaseWhen` and `Coalesce`.
case c: CaseWhen => c.children.head :: Nil
case c: Coalesce => c.children.head :: Nil
// Every other expression contributes all of its children.
case other => other.children
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,34 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}

/**
* Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
* When a = true, returns b; when c = true, returns d; else returns e.
*
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
abstract class CaseWhenBase(
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.",
arguments = """
Arguments:
* expr1, expr3 - the branch condition expressions should all be boolean type.
* expr2, expr4, expr5 - the branch value expressions and else value expression should all be
same type or coercible to a common type.
""",
examples = """
Examples:
> SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
1
> SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
2
> SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END;
->
SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 END;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you double-check that Hive returns NULL in the following case?

SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 END;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can follow your first suggestion and test this on Hive, but I haven't actually changed this part of the code. I will post the Hive result ASAP.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I confirm that Hive returns NULL, so I am updating the description as requested.

NULL
""")
// scalastyle:on line.size.limit
case class CaseWhen(
branches: Seq[(Expression, Expression)],
elseValue: Option[Expression])
elseValue: Option[Expression] = None)
extends Expression with Serializable {

override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
Expand Down Expand Up @@ -211,111 +231,62 @@ abstract class CaseWhenBase(
val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
"CASE" + cases + elseCase + " END"
}
}


/**
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
* When a = true, returns b; when c = true, returns d; else returns e.
*
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; else when `expr3` = true, returns `expr4`; else returns `expr5`.",
arguments = """
Arguments:
* expr1, expr3 - the branch condition expressions should all be boolean type.
* expr2, expr4, expr5 - the branch value expressions and else value expression should all be
same type or coercible to a common type.
""",
examples = """
Examples:
> SELECT CASE WHEN 1 > 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
1
> SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 > 0 THEN 2.0 ELSE 1.2 END;
2
> SELECT CASE WHEN 1 < 0 THEN 1 WHEN 2 < 0 THEN 2.0 ELSE null END;
NULL
""")
// scalastyle:on line.size.limit
case class CaseWhen(
val branches: Seq[(Expression, Expression)],
val elseValue: Option[Expression] = None)
extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {

// Explicitly selects the `doGenCode` implementation inherited from `CodegenFallback`.
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
super[CodegenFallback].doGenCode(ctx, ev)
}

// Builds a `CaseWhenCodegen` from the same `branches` and `elseValue` as this expression.
def toCodegen(): CaseWhenCodegen = {
CaseWhenCodegen(branches, elseValue)
}
}

/**
* CaseWhen expression used when code generation condition is satisfied.
* OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen.
*
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
case class CaseWhenCodegen(
val branches: Seq[(Expression, Expression)],
val elseValue: Option[Expression] = None)
extends CaseWhenBase(branches, elseValue) with Serializable {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Generate code that looks like:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we keep this comment and update it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it is necessary, since the generated code is now much simpler and more standard, and no comment like this is provided anywhere else. Anyway, if you feel it is needed, I can add it.

//
// condA = ...
// if (condA) {
// valueA
// } else {
// condB = ...
// if (condB) {
// valueB
// } else {
// condC = ...
// if (condC) {
// valueC
// } else {
// elseValue
// }
// }
// }
val conditionMet = ctx.freshName("caseWhenConditionMet")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment to explain what it is.

ctx.addMutableState("boolean", ev.isNull, "")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx.JAVA_BOOLEAN

ctx.addMutableState(ctx.javaType(dataType), ev.value, "")
val cases = branches.map { case (condExpr, valueExpr) =>
val cond = condExpr.genCode(ctx)
val res = valueExpr.genCode(ctx)
s"""
${cond.code}
if (!${cond.isNull} && ${cond.value}) {
${res.code}
${ev.isNull} = ${res.isNull};
${ev.value} = ${res.value};
if(!$conditionMet) {
${cond.code}
if (!${cond.isNull} && ${cond.value}) {
${res.code}
${ev.isNull} = ${res.isNull};
${ev.value} = ${res.value};
$conditionMet = true;
}
}
"""
}

var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n")

elseValue.foreach { elseExpr =>
val elseCode = elseValue.map { elseExpr =>
val res = elseExpr.genCode(ctx)
generatedCode +=
s"""
s"""
if(!$conditionMet) {
${res.code}
${ev.isNull} = ${res.isNull};
${ev.value} = ${res.value};
"""
}
"""
}

generatedCode += "}\n" * cases.size
val allConditions = cases ++ elseCode

val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
allConditions.mkString("\n")
} else {
ctx.splitExpressions(allConditions, "caseWhen",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style issue: please fix the indentation.

("InternalRow", ctx.INPUT_ROW) :: ("boolean", conditionMet) :: Nil, returnType = "boolean",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx.JAVA_BOOLEAN

makeSplitFunction = {
func =>
s"""
$func
return $conditionMet;
"""
},
foldFunctions = { funcCalls =>
funcCalls.map(funcCall => s"$conditionMet = $funcCall;").mkString("\n")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When caseWhenConditionMet is already true, we do not need to call funcCall.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, but we already check for it inside the functions. Do we also need to check it outside?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to avoid the extra function calls here; they are not cheap when the number of rows is large. We now split functions pretty aggressively, and I have seen many new functions being generated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'll do that. I'd then suggest doing the same in other places as well. I can check where an analogous pattern is used and create a PR, if that is OK.

})
}

ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$generatedCode""")
${ev.isNull} = true;
${ev.value} = ${ctx.defaultValue(dataType)};
boolean $conditionMet = false;
$code""")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
// The following batch should be executed after batch "Join Reorder" and "LocalRelation".
Batch("Check Cartesian Products", Once,
CheckCartesianProducts) ::
Batch("OptimizeCodegen", Once,
OptimizeCodegen) ::
Batch("RewriteSubquery", Once,
RewritePredicateSubquery,
CollapseProject) :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -552,21 +552,6 @@ object FoldablePropagation extends Rule[LogicalPlan] {
}


/**
* Optimizes expressions by replacing according to CodeGen configuration.
*
* Every [[CaseWhen]] accepted by `canCodegen` is replaced with the result of its `toCodegen()`
* call; all other expressions are left untouched.
*/
object OptimizeCodegen extends Rule[LogicalPlan] {
// Applies the replacement to all expressions of the plan via `transformAllExpressions`.
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case e: CaseWhen if canCodegen(e) => e.toCodegen()
}

// Returns true when the number of branches plus the size of `elseValue` does not exceed
// `SQLConf.get.maxCaseBranchesForCodegen`.
private def canCodegen(e: CaseWhen): Boolean = {
val numBranches = e.branches.size + e.elseValue.size
numBranches <= SQLConf.get.maxCaseBranchesForCodegen
}
}


/**
* Removes [[Cast Casts]] that are unnecessary because the input is already the correct type.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,12 +570,6 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val MAX_CASES_BRANCHES = buildConf("spark.sql.codegen.maxCaseBranches")
.internal()
.doc("The maximum number of switches supported with codegen.")
.intConf
.createWithDefault(20)

val CODEGEN_LOGGING_MAX_LINES = buildConf("spark.sql.codegen.logging.maxLines")
.internal()
.doc("The maximum number of codegen lines to log when errors occur. Use -1 for unlimited.")
Expand Down Expand Up @@ -1084,8 +1078,6 @@ class SQLConf extends Serializable with Logging {

def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK)

def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)

def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES)

def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("SPARK-13242: case-when expression with large number of branches (or cases)") {
val cases = 50
val cases = 500
val clauses = 20

// Generate an individual case
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class FlatMapGroupsWithState_StateManager(
val deser = stateEncoder.resolveAndBind().deserializer.transformUp {
case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal)
}
CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser).toCodegen()
CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser)
}

// Converters for translating state between rows and Java objects
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -2126,4 +2126,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val mean = result.select("DecimalCol").where($"summary" === "mean")
assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000")))
}

test("SPARK-22520: support code generation for large CaseWhen") {
val N = 30
var expr1 = when($"id" === lit(0), 0)
var expr2 = when($"id" === lit(0), 10)
(1 to N).foreach { i =>
expr1 = expr1.when($"id" === lit(i), -i)
expr2 = expr2.when($"id" === lit(i + 10), i)
}
val df = spark.range(1).select(expr1, expr2.otherwise(0))
df.show
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also compare the results?

assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
}
}
Loading