Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
}
"""
}
val fieldsEvalCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
val fieldsEvalCodes = if (ctx.currentVars == null) {
ctx.splitExpressions(
expressions = fieldsEvalCode,
funcName = "castStruct",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -788,11 +788,31 @@ class CodegenContext {
* @param expressions the codes to evaluate expressions.
*/
def splitExpressions(expressions: Seq[String]): String = {
splitExpressions(expressions, funcName = "apply", extraArguments = Nil)
}

/**
* Similar to [[splitExpressions(expressions: Seq[String])]], but has customized function name
* and extra arguments.
*
* @param expressions the codes to evaluate expressions.
* @param funcName the split function name base.
* @param extraArguments the list of (type, name) of the arguments of the split function
* except for ctx.INPUT_ROW
*/
def splitExpressions(
expressions: Seq[String],
funcName: String,
extraArguments: Seq[(String, String)]): String = {
// TODO: support whole stage codegen
if (INPUT_ROW == null || currentVars != null) {
return expressions.mkString("\n")
expressions.mkString("\n")
} else {
splitExpressions(
expressions,
funcName,
arguments = ("InternalRow", INPUT_ROW) +: extraArguments)
}
splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", INPUT_ROW) :: Nil)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}
}
""")
val listCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
val args = ("InternalRow", ctx.INPUT_ROW) :: (ctx.javaType(value.dataType), valueArg) :: Nil
ctx.splitExpressions(expressions = listCode, funcName = "valueIn", arguments = args)
} else {
listCode.mkString("\n")
}
val listCodes = ctx.splitExpressions(
expressions = listCode,
funcName = "valueIn",
extraArguments = (ctx.javaType(value.dataType), valueArg) :: Nil)
ev.copy(code = s"""
${valueGen.code}
${ev.value} = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
}
"""
}
val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
ctx.splitExpressions(
expressions = inputs,
funcName = "valueConcat",
arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
} else {
inputs.mkString("\n")
}
val codes = ctx.splitExpressions(
expressions = inputs,
funcName = "valueConcat",
extraArguments = ("UTF8String[]", args) :: Nil)
ev.copy(s"""
UTF8String[] $args = new UTF8String[${evals.length}];
$codes
Expand Down Expand Up @@ -156,14 +152,10 @@ case class ConcatWs(children: Seq[Expression])
""
}
}
val codes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
ctx.splitExpressions(
val codes = ctx.splitExpressions(
expressions = inputs,
funcName = "valueConcatWs",
arguments = ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String[]", args) :: Nil)
} else {
inputs.mkString("\n")
}
extraArguments = ("UTF8String[]", args) :: Nil)
ev.copy(s"""
UTF8String[] $args = new UTF8String[$numArgs];
${separator.code}
Expand Down Expand Up @@ -1388,14 +1380,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
$argList[$index] = $value;
"""
}
val argListCodes = if (ctx.INPUT_ROW != null && ctx.currentVars == null) {
ctx.splitExpressions(
expressions = argListCode,
funcName = "valueFormatString",
arguments = ("InternalRow", ctx.INPUT_ROW) :: ("Object[]", argList) :: Nil)
} else {
argListCode.mkString("\n")
}
val argListCodes = ctx.splitExpressions(
expressions = argListCode,
funcName = "valueFormatString",
extraArguments = ("Object[]", argList) :: Nil)

val form = ctx.freshName("formatter")
val formatter = classOf[java.util.Formatter].getName
Expand Down