diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5c9e604a8d29..0f0ec4737508 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -217,6 +217,18 @@ class CodegenContext {
     splitExpressions(expressions = initCodes, funcName = "init", arguments = Nil)
   }
 
+  /**
+   * Returns true if `varName` has been declared as a mutable state (a global variable).
+   * Array subscripts are ignored, e.g. `array[1]` matches a declared state named `array`.
+   *
+   * @param varName the variable name to look up, optionally followed by array subscripts.
+   */
+  def isDeclaredMutableState(varName: String): Boolean = {
+    def baseName(name: String): String = name.takeWhile(_ != '[')  // drop any subscript
+    val qualifiedName = baseName(varName)
+    mutableStates.exists(state => baseName(state._2) == qualifiedName)
+  }
+
   /**
    * Code statements to initialize states that depend on the partition index.
    * An integer `partitionIndex` will be made available within the scope.
@@ -789,7 +801,8 @@ class CodegenContext {
    * @param expressions the codes to evaluate expressions.
    * @param funcName the split function name base.
    * @param extraArguments the list of (type, name) of the arguments of the split function,
-   *                       except for the current inputs like `ctx.INPUT_ROW`.
+   *                       except for the current inputs like `ctx.INPUT_ROW`. Name must not be
+   *                       mutable state.
    * @param returnType the return type of the split function.
    * @param makeSplitFunction makes split function body, e.g. add preparation or cleanup.
    * @param foldFunctions folds the split function calls.
@@ -823,7 +836,8 @@ class CodegenContext {
    *
    * @param expressions the codes to evaluate expressions.
    * @param funcName the split function name base.
-   * @param arguments the list of (type, name) of the arguments of the split function.
+   * @param arguments the list of (type, name) of the arguments of the split function. Name must
+   *                  not be mutable state.
    * @param returnType the return type of the split function.
    * @param makeSplitFunction makes split function body, e.g. add preparation or cleanup.
    * @param foldFunctions folds the split function calls.
@@ -842,7 +856,10 @@ class CodegenContext {
       blocks.head
     } else {
       val func = freshName(funcName)
-      val argString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ")
+      val argString = arguments.map { case (t, name) =>
+        assert(!isDeclaredMutableState(name),
+          s"$name in arguments should not be declared as a global variable")
+        s"$t $name" }.mkString(", ")
       val functions = blocks.zipWithIndex.map { case (body, i) =>
         val name = s"${func}_$i"
         val code = s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 40bf29bb3b57..c349e44b5945 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -394,4 +394,23 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       Map("add" -> Literal(1))).genCode(ctx)
     assert(ctx.mutableStates.isEmpty)
   }
+
+  test("SPARK-22668: ensure no global variables in split method arguments") {
+    val ctx = new CodegenContext
+    ctx.addMutableState(ctx.JAVA_INT, "ij")
+    ctx.addMutableState("int[]", "array")
+    ctx.addMutableState("int[][]", "b")
+    // A name matches when its part before the first `[` equals that of a declared state.
+    assert(ctx.isDeclaredMutableState("ij"))
+    assert(ctx.isDeclaredMutableState("array"))
+    assert(ctx.isDeclaredMutableState("array[1]"))
+    assert(ctx.isDeclaredMutableState("b[]"))
+    assert(ctx.isDeclaredMutableState("b[1][]"))
+    // Names that only partially overlap a declared name (prefix, suffix, extension) must not match.
+    assert(!ctx.isDeclaredMutableState("i"))
+    assert(!ctx.isDeclaredMutableState("j"))
+    assert(!ctx.isDeclaredMutableState("ij1"))
+    assert(!ctx.isDeclaredMutableState("arr"))
+    assert(!ctx.isDeclaredMutableState("bb[]"))
+  }
 }