Skip to content

Commit 821e08a

Browse files
committed
Use actual java class instead of string representation.
1 parent 6498884 commit 821e08a

24 files changed

+285
-196
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ abstract class Expression extends TreeNode[Expression] {
104104
}.getOrElse {
105105
val isNull = ctx.freshName("isNull")
106106
val value = ctx.freshName("value")
107-
val eval = doGenCode(ctx, ExprCode("",
108-
VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN),
109-
VariableValue(value, CodeGenerator.javaType(dataType))))
107+
val eval = doGenCode(ctx, ExprCode(
108+
JavaCode.isNullVariable(isNull),
109+
JavaCode.variable(value, dataType)))
110110
reduceCodeSize(ctx, eval)
111111
if (eval.code.nonEmpty) {
112112
// Add `this` in the comment.
@@ -123,7 +123,7 @@ abstract class Expression extends TreeNode[Expression] {
123123
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
124124
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
125125
val localIsNull = eval.isNull
126-
eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN)
126+
eval.isNull = JavaCode.isNullGlobal(globalIsNull)
127127
s"$globalIsNull = $localIsNull;"
128128
} else {
129129
""
@@ -142,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] {
142142
|}
143143
""".stripMargin)
144144

145-
eval.value = VariableValue(newValue, javaType)
145+
eval.value = JavaCode.variable(newValue, dataType)
146146
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
147147
}
148148
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -591,8 +591,7 @@ case class Least(children: Seq[Expression]) extends Expression {
591591

592592
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
593593
val evalChildren = children.map(_.genCode(ctx))
594-
ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull),
595-
CodeGenerator.JAVA_BOOLEAN)
594+
ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull))
596595
val evals = evalChildren.map(eval =>
597596
s"""
598597
|${eval.code}
@@ -671,8 +670,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
671670

672671
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
673672
val evalChildren = children.map(_.genCode(ctx))
674-
ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull),
675-
CodeGenerator.JAVA_BOOLEAN)
673+
ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull))
676674
val evals = evalChildren.map(eval =>
677675
s"""
678676
|${eval.code}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,12 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
5959
case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)
6060

6161
object ExprCode {
62+
def apply(isNull: ExprValue, value: ExprValue): ExprCode = {
63+
ExprCode(code = "", isNull, value)
64+
}
65+
6266
def forNullValue(dataType: DataType): ExprCode = {
63-
val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true)
64-
ExprCode(code = "", isNull = TrueLiteral,
65-
value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType)))
67+
ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
6668
}
6769

6870
def forNonNullValue(value: ExprValue): ExprCode = {
@@ -331,7 +333,7 @@ class CodegenContext {
331333
case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
332334
case _ => s"$value = $initCode;"
333335
}
334-
ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType)))
336+
ExprCode.forNonNullValue(JavaCode.global(value, dataType))
335337
}
336338

337339
def declareMutableStates(): String = {
@@ -1004,8 +1006,9 @@ class CodegenContext {
10041006
// at least two nodes) as the cost of doing it is expected to be low.
10051007

10061008
subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
1007-
val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN),
1008-
GlobalValue(value, javaType(expr.dataType)))
1009+
val state = SubExprEliminationState(
1010+
JavaCode.isNullGlobal(isNull),
1011+
JavaCode.global(value, expr.dataType))
10091012
subExprEliminationExprs ++= e.map(_ -> state).toMap
10101013
}
10111014
}
@@ -1479,6 +1482,26 @@ object CodeGenerator extends Logging {
14791482
case _ => "Object"
14801483
}
14811484

1485+
def javaClass(dt: DataType): Class[_] = dt match {
1486+
case BooleanType => java.lang.Boolean.TYPE
1487+
case ByteType => java.lang.Byte.TYPE
1488+
case ShortType => java.lang.Short.TYPE
1489+
case IntegerType | DateType => java.lang.Integer.TYPE
1490+
case LongType | TimestampType => java.lang.Long.TYPE
1491+
case FloatType => java.lang.Float.TYPE
1492+
case DoubleType => java.lang.Double.TYPE
1493+
case _: DecimalType => classOf[Decimal]
1494+
case BinaryType => classOf[Array[Byte]]
1495+
case StringType => classOf[UTF8String]
1496+
case CalendarIntervalType => classOf[CalendarInterval]
1497+
case _: StructType => classOf[InternalRow]
1498+
case _: ArrayType => classOf[ArrayData]
1499+
case _: MapType => classOf[MapData]
1500+
case udt: UserDefinedType[_] => javaClass(udt.sqlType)
1501+
case ObjectType(cls) => cls
1502+
case _ => classOf[Object]
1503+
}
1504+
14821505
/**
14831506
* Returns the boxed type in Java.
14841507
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala

Lines changed: 0 additions & 76 deletions
This file was deleted.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,43 +52,45 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
5252
expressions: Seq[Expression],
5353
useSubexprElimination: Boolean): MutableProjection = {
5454
val ctx = newCodeGenContext()
55-
val (validExpr, index) = expressions.zipWithIndex.filter {
55+
val validExpr = expressions.zipWithIndex.filter {
5656
case (NoOp, _) => false
5757
case _ => true
58-
}.unzip
59-
val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination)
58+
}
59+
val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination)
6060

6161
// 2-tuples: (code for projection, code for updating the column in the mutable row)
62-
val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map {
63-
case (ev, i) =>
64-
val e = expressions(i)
65-
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value")
66-
if (e.nullable) {
62+
val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map {
63+
case ((e, i), ev) =>
64+
val value = JavaCode.global(
65+
ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value"),
66+
e.dataType)
67+
val (code, isNull) = if (e.nullable) {
6768
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull")
6869
(s"""
6970
|${ev.code}
7071
|$isNull = ${ev.isNull};
7172
|$value = ${ev.value};
72-
""".stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i)
73+
""".stripMargin, JavaCode.isNullGlobal(isNull))
7374
} else {
7475
(s"""
7576
|${ev.code}
7677
|$value = ${ev.value};
77-
""".stripMargin, ev.isNull, value, i)
78+
""".stripMargin, FalseLiteral)
7879
}
80+
val update = CodeGenerator.updateColumn(
81+
"mutableRow",
82+
e.dataType,
83+
i,
84+
ExprCode(isNull, value),
85+
e.nullable)
86+
(code, update)
7987
}
8088

8189
// Evaluate all the subexpressions.
8290
val evalSubexpr = ctx.subexprFunctions.mkString("\n")
8391

84-
val updates = validExpr.zip(projectionCodes).map {
85-
case (e, (_, isNull, value, i)) =>
86-
val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType)))
87-
CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
88-
}
89-
9092
val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1))
91-
val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates)
93+
val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2))
9294

9395
val codeBody = s"""
9496
public java.lang.Object generate(Object[] references) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen
1919

2020
import scala.annotation.tailrec
2121

22+
import org.apache.spark.sql.catalyst.InternalRow
2223
import org.apache.spark.sql.catalyst.expressions._
2324
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
24-
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
25+
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
2526
import org.apache.spark.sql.types._
2627

2728
/**
@@ -53,9 +54,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
5354
val rowClass = classOf[GenericInternalRow].getName
5455

5556
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
56-
val converter = convertToSafe(ctx,
57-
StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString),
58-
CodeGenerator.javaType(dt)), dt)
57+
val converter = convertToSafe(
58+
ctx,
59+
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt),
60+
dt)
5961
s"""
6062
if (!$tmpInput.isNullAt($i)) {
6163
${converter.code}
@@ -76,7 +78,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
7678
|final InternalRow $output = new $rowClass($values);
7779
""".stripMargin
7880

79-
ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow"))
81+
ExprCode(code, FalseLiteral, VariableValue(output, classOf[InternalRow]))
8082
}
8183

8284
private def createCodeForArray(
@@ -91,9 +93,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
9193
val index = ctx.freshName("index")
9294
val arrayClass = classOf[GenericArrayData].getName
9395

94-
val elementConverter = convertToSafe(ctx,
95-
StatementValue(CodeGenerator.getValue(tmpInput, elementType, index),
96-
CodeGenerator.javaType(elementType)), elementType)
96+
val elementConverter = convertToSafe(
97+
ctx,
98+
JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType),
99+
elementType)
97100
val code = s"""
98101
final ArrayData $tmpInput = $input;
99102
final int $numElements = $tmpInput.numElements();
@@ -107,7 +110,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
107110
final ArrayData $output = new $arrayClass($values);
108111
"""
109112

110-
ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData"))
113+
ExprCode(code, FalseLiteral, VariableValue(output, classOf[ArrayData]))
111114
}
112115

113116
private def createCodeForMap(
@@ -128,7 +131,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
128131
final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value});
129132
"""
130133

131-
ExprCode(code, FalseLiteral, VariableValue(output, "MapData"))
134+
ExprCode(code, FalseLiteral, VariableValue(output, classOf[MapData]))
132135
}
133136

134137
@tailrec

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
5252
// Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
5353
val tmpInput = ctx.freshName("tmpInput")
5454
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
55-
ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN),
56-
StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString),
57-
CodeGenerator.javaType(dt)))
55+
ExprCode(
56+
JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"),
57+
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
5858
}
5959

6060
val rowWriterClass = classOf[UnsafeRowWriter].getName
@@ -337,8 +337,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
337337
$writeExpressions
338338
"""
339339
// `rowWriter` is declared as a class field, so we can access it directly in methods.
340-
ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow",
341-
canDirectAccess = true))
340+
ExprCode(code, FalseLiteral, SimpleExprValue(s"$rowWriter.getRow()", classOf[UnsafeRow]))
342341
}
343342

344343
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =

0 commit comments

Comments
 (0)