diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 12baddf1bf7ac..8cafaef61c7d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1040,7 +1040,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } """ } - val fieldsEvalCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { + val fieldsEvalCodes = if (ctx.currentVars == null) { ctx.splitExpressions( expressions = fieldsEvalCode, funcName = "castStruct", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 668c816b3fd8d..1645db12c53f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -788,11 +788,31 @@ class CodegenContext { * @param expressions the codes to evaluate expressions. */ def splitExpressions(expressions: Seq[String]): String = { + splitExpressions(expressions, funcName = "apply", extraArguments = Nil) + } + + /** + * Similar to [[splitExpressions(expressions: Seq[String])]], but has customized function name + * and extra arguments. + * + * @param expressions the codes to evaluate expressions. + * @param funcName the split function name base. + * @param extraArguments the list of (type, name) of the arguments of the split function + * except for ctx.INPUT_ROW + */ + def splitExpressions( + expressions: Seq[String], + funcName: String, + extraArguments: Seq[(String, String)]): String = { // TODO: support whole stage codegen if (INPUT_ROW == null || currentVars != null) { - return expressions.mkString("\n") + expressions.mkString("\n") + } else { + splitExpressions( + expressions, + funcName, + arguments = ("InternalRow", INPUT_ROW) +: extraArguments) } - splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", INPUT_ROW) :: Nil) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index eb7475354b104..1aaaaf1db48d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -251,12 +251,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { } } """) - val listCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - val args = ("InternalRow", ctx.INPUT_ROW) :: (ctx.javaType(value.dataType), valueArg) :: Nil - ctx.splitExpressions(expressions = listCode, funcName = "valueIn", arguments = args) - } else { - listCode.mkString("\n") - } + val listCodes = ctx.splitExpressions( + expressions = listCode, + funcName = "valueIn", + extraArguments = (ctx.javaType(value.dataType), valueArg) :: Nil) ev.copy(code = s""" ${valueGen.code} ${ev.value} = false; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ee5cf925d3cef..34917ace001fa 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -73,14 +73,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas } """ } - val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions( - expressions = inputs, - funcName = "valueConcat", - arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) - } else { - inputs.mkString("\n") - } + val codes = ctx.splitExpressions( + expressions = inputs, + funcName = "valueConcat", + extraArguments = ("UTF8String[]", args) :: Nil) ev.copy(s""" UTF8String[] $args = new UTF8String[${evals.length}]; $codes @@ -156,14 +152,10 @@ case class ConcatWs(children: Seq[Expression]) "" } } - val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions( + val codes = ctx.splitExpressions( expressions = inputs, funcName = "valueConcatWs", - arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil) - } else { - inputs.mkString("\n") - } + extraArguments = ("UTF8String[]", args) :: Nil) ev.copy(s""" UTF8String[] $args = new UTF8String[$numArgs]; ${separator.code} @@ -1388,14 +1380,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC $argList[$index] = $value; """ } - val argListCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) { - ctx.splitExpressions( - expressions = argListCode, - funcName = "valueFormatString", - arguments = ("InternalRow", ctx.INPUT_ROW) :: ("Object[]", argList) :: Nil) - } else { - argListCode.mkString("\n") - } + val argListCodes = ctx.splitExpressions( + expressions = argListCode, + funcName = "valueFormatString", + extraArguments = ("Object[]", argList) :: Nil) val form = ctx.freshName("formatter") val formatter = classOf[java.util.Formatter].getName