diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index d2ad14f2a1a9..6ffd6605f75b 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -18,12 +18,15 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import java.lang.invoke.SerializedLambda +import java.lang.invoke.{MethodHandleInfo, SerializedLambda} +import scala.collection.JavaConverters._ import scala.collection.mutable.{Map, Set, Stack} -import org.apache.xbean.asm7.{ClassReader, ClassVisitor, MethodVisitor, Type} +import org.apache.commons.lang3.ClassUtils +import org.apache.xbean.asm7.{ClassReader, ClassVisitor, Handle, MethodVisitor, Type} import org.apache.xbean.asm7.Opcodes._ +import org.apache.xbean.asm7.tree.{ClassNode, MethodNode} import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.internal.Logging @@ -159,39 +162,6 @@ private[spark] object ClosureCleaner extends Logging { clean(closure, checkSerializable, cleanTransitively, Map.empty) } - /** - * Try to get a serialized Lambda from the closure. - * - * @param closure the closure to check. - */ - private def getSerializedLambda(closure: AnyRef): Option[SerializedLambda] = { - val isClosureCandidate = - closure.getClass.isSynthetic && - closure - .getClass - .getInterfaces.exists(_.getName == "scala.Serializable") - - if (isClosureCandidate) { - try { - Option(inspect(closure)) - } catch { - case e: Exception => - // no need to check if debug is enabled here the Spark - // logging api covers this. - logDebug("Closure is not a serialized lambda.", e) - None - } - } else { - None - } - } - - private def inspect(closure: AnyRef): SerializedLambda = { - val writeReplace = closure.getClass.getDeclaredMethod("writeReplace") - writeReplace.setAccessible(true) - writeReplace.invoke(closure).asInstanceOf[java.lang.invoke.SerializedLambda] - } - /** * Helper method to clean the given closure in place. * @@ -239,12 +209,12 @@ private[spark] object ClosureCleaner extends Logging { cleanTransitively: Boolean, accessedFields: Map[Class[_], Set[String]]): Unit = { - // most likely to be the case with 2.12, 2.13 + // indylambda check. Most likely to be the case with 2.12, 2.13 // so we check first // non LMF-closures should be less frequent from now on - val lambdaFunc = getSerializedLambda(func) + val maybeIndylambdaProxy = IndylambdaScalaClosures.getSerializationProxy(func) - if (!isClosure(func.getClass) && lambdaFunc.isEmpty) { + if (!isClosure(func.getClass) && maybeIndylambdaProxy.isEmpty) { logDebug(s"Expected a closure; got ${func.getClass.getName}") return } @@ -256,7 +226,7 @@ private[spark] object ClosureCleaner extends Logging { return } - if (lambdaFunc.isEmpty) { + if (maybeIndylambdaProxy.isEmpty) { logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++") // A list of classes that represents closures enclosed in the given one @@ -300,7 +270,7 @@ private[spark] object ClosureCleaner extends Logging { } } - logDebug(s" + fields accessed by starting closure: " + accessedFields.size) + logDebug(s" + fields accessed by starting closure: ${accessedFields.size} classes") accessedFields.foreach { f => logDebug(" " + f) } // List of outer (class, object) pairs, ordered from outermost to innermost @@ -372,14 +342,64 @@ private[spark] object ClosureCleaner extends Logging { logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++") } else { - logDebug(s"Cleaning lambda: ${lambdaFunc.get.getImplMethodName}") + val lambdaProxy = maybeIndylambdaProxy.get + val implMethodName = lambdaProxy.getImplMethodName + + logDebug(s"Cleaning indylambda closure: $implMethodName") + + // capturing class is the class that declared this lambda + val capturingClassName = lambdaProxy.getCapturingClass.replace('/', '.') + val classLoader = func.getClass.getClassLoader // this is the safest option + // scalastyle:off classforname + val capturingClass = Class.forName(capturingClassName, false, classLoader) + // scalastyle:on classforname - val captClass = Utils.classForName(lambdaFunc.get.getCapturingClass.replace('/', '.'), - initialize = false, noSparkClassLoader = true) // Fail fast if we detect return statements in closures - getClassReader(captClass) - .accept(new ReturnStatementFinder(Some(lambdaFunc.get.getImplMethodName)), 0) - logDebug(s" +++ Lambda closure (${lambdaFunc.get.getImplMethodName}) is now cleaned +++") + val capturingClassReader = getClassReader(capturingClass) + capturingClassReader.accept(new ReturnStatementFinder(Option(implMethodName)), 0) + + val isClosureDeclaredInScalaRepl = capturingClassName.startsWith("$line") && + capturingClassName.endsWith("$iw") + val outerThisOpt = if (lambdaProxy.getCapturedArgCount > 0) { + Option(lambdaProxy.getCapturedArg(0)) + } else { + None + } + + // only need to clean when there is an enclosing "this" captured by the closure, and it + // should be something cleanable, i.e. a Scala REPL line object + val needsCleaning = isClosureDeclaredInScalaRepl && + outerThisOpt.isDefined && outerThisOpt.get.getClass.getName == capturingClassName + + if (needsCleaning) { + // indylambda closures do not reference enclosing closures via an `$outer` chain, so no + // transitive cleaning on the `$outer` chain is needed. + // Thus clean() shouldn't be recursively called with a non-empty accessedFields. + assert(accessedFields.isEmpty) + + initAccessedFields(accessedFields, Seq(capturingClass)) + IndylambdaScalaClosures.findAccessedFields( + lambdaProxy, classLoader, accessedFields, cleanTransitively) + + logDebug(s" + fields accessed by starting closure: ${accessedFields.size} classes") + accessedFields.foreach { f => logDebug(" " + f) } + + if (accessedFields(capturingClass).size < capturingClass.getDeclaredFields.length) { + // clone and clean the enclosing `this` only when there are fields to null out + + val outerThis = outerThisOpt.get + + logDebug(s" + cloning instance of REPL class $capturingClassName") + val clonedOuterThis = cloneAndSetFields( + parent = null, outerThis, capturingClass, accessedFields) + + val outerField = func.getClass.getDeclaredField("arg$1") + outerField.setAccessible(true) + outerField.set(func, clonedOuterThis) + } + } + + logDebug(s" +++ indylambda closure ($implMethodName) is now cleaned +++") } if (checkSerializable) { @@ -414,6 +434,312 @@ private[spark] object ClosureCleaner extends Logging { } } +private[spark] object IndylambdaScalaClosures extends Logging { + // internal name of java.lang.invoke.LambdaMetafactory + val LambdaMetafactoryClassName = "java/lang/invoke/LambdaMetafactory" + // the method that Scala indylambda use for bootstrap method + val LambdaMetafactoryMethodName = "altMetafactory" + val LambdaMetafactoryMethodDesc = "(Ljava/lang/invoke/MethodHandles$Lookup;" + + "Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)" + + "Ljava/lang/invoke/CallSite;" + + /** + * Check if the given reference is a indylambda style Scala closure. + * If so (e.g. for Scala 2.12+ closures), return a non-empty serialization proxy + * (SerializedLambda) of the closure; + * otherwise (e.g. for Scala 2.11 closures) return None. + * + * @param maybeClosure the closure to check. + */ + def getSerializationProxy(maybeClosure: AnyRef): Option[SerializedLambda] = { + def isClosureCandidate(cls: Class[_]): Boolean = { + // TODO: maybe lift this restriction to support other functional interfaces in the future + val implementedInterfaces = ClassUtils.getAllInterfaces(cls).asScala + implementedInterfaces.exists(_.getName.startsWith("scala.Function")) + } + + maybeClosure.getClass match { + // shortcut the fast check: + // 1. indylambda closure classes are generated by Java's LambdaMetafactory, and they're + // always synthetic. + // 2. We only care about Serializable closures, so let's check that as well + case c if !c.isSynthetic || !maybeClosure.isInstanceOf[Serializable] => None + + case c if isClosureCandidate(c) => + try { + Option(inspect(maybeClosure)).filter(isIndylambdaScalaClosure) + } catch { + case e: Exception => + logDebug("The given reference is not an indylambda Scala closure.", e) + None + } + + case _ => None + } + } + + def isIndylambdaScalaClosure(lambdaProxy: SerializedLambda): Boolean = { + lambdaProxy.getImplMethodKind == MethodHandleInfo.REF_invokeStatic && + lambdaProxy.getImplMethodName.contains("$anonfun$") + } + + def inspect(closure: AnyRef): SerializedLambda = { + val writeReplace = closure.getClass.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + writeReplace.invoke(closure).asInstanceOf[SerializedLambda] + } + + /** + * Check if the handle represents the LambdaMetafactory that indylambda Scala closures + * use for creating the lambda class and getting a closure instance. + */ + def isLambdaMetafactory(bsmHandle: Handle): Boolean = { + bsmHandle.getOwner == LambdaMetafactoryClassName && + bsmHandle.getName == LambdaMetafactoryMethodName && + bsmHandle.getDesc == LambdaMetafactoryMethodDesc + } + + /** + * Check if the handle represents a target method that is: + * - a STATIC method that implements a Scala lambda body in the indylambda style + * - captures the enclosing `this`, i.e. the first argument is a reference to the same type as + * the owning class. + * Returns true if both criteria above are met. + */ + def isLambdaBodyCapturingOuter(handle: Handle, ownerInternalName: String): Boolean = { + handle.getTag == H_INVOKESTATIC && + handle.getName.contains("$anonfun$") && + handle.getOwner == ownerInternalName && + handle.getDesc.startsWith(s"(L$ownerInternalName;") + } + + /** + * Check if the callee of a call site is a inner class constructor. + * - A constructor has to be invoked via INVOKESPECIAL + * - A constructor's internal name is "<init>" and the return type is "V" (void) + * - An inner class' first argument in the signature has to be a reference to the + * enclosing "this", aka `$outer` in Scala. + */ + def isInnerClassCtorCapturingOuter( + op: Int, owner: String, name: String, desc: String, callerInternalName: String): Boolean = { + op == INVOKESPECIAL && name == "" && desc.startsWith(s"(L$callerInternalName;") + } + + /** + * Scans an indylambda Scala closure, along with its lexically nested closures, and populate + * the accessed fields info on which fields on the outer object are accessed. + * + * This is equivalent to getInnerClosureClasses() + InnerClosureFinder + FieldAccessFinder fused + * into one for processing indylambda closures. The traversal order along the call graph is the + * same for all three combined, so they can be fused together easily while maintaining the same + * ordering as the existing implementation. + * + * Precondition: this function expects the `accessedFields` to be populated with all known + * outer classes and their super classes to be in the map as keys, e.g. + * initializing via ClosureCleaner.initAccessedFields. + */ + // scalastyle:off line.size.limit + // Example: run the following code snippet in a Spark Shell w/ Scala 2.12+: + // val topLevelValue = "someValue"; val closure = (j: Int) => { + // class InnerFoo { + // val innerClosure = (x: Int) => (1 to x).map { y => y + topLevelValue } + // } + // val innerFoo = new InnerFoo + // (1 to j).flatMap(innerFoo.innerClosure) + // } + // sc.parallelize(0 to 2).map(closure).collect + // + // produces the following trace-level logs: + // (slightly simplified: + // - omitting the "ignoring ..." lines; + // - "$iw" is actually "$line14.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw"; + // - "invokedynamic" lines are simplified to just show the name+desc, omitting the bsm info) + // Cleaning indylambda closure: $anonfun$closure$1$adapted + // scanning $iw.$anonfun$closure$1$adapted(L$iw;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; + // found intra class call to $iw.$anonfun$closure$1(L$iw;I)Lscala/collection/immutable/IndexedSeq; + // scanning $iw.$anonfun$closure$1(L$iw;I)Lscala/collection/immutable/IndexedSeq; + // found inner class $iw$InnerFoo$1 + // found method innerClosure()Lscala/Function1; + // found method $anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // found method $anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // found method (L$iw;)V + // found method $anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; + // found method $anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; + // found method $deserializeLambda$(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object; + // found call to outer $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // scanning $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // scanning $iw$InnerFoo$1.$deserializeLambda$(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object; + // invokedynamic: lambdaDeserialize(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;, bsm...) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) + // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // found call to outer $iw.topLevelValue()Ljava/lang/String; + // scanning $iw.topLevelValue()Ljava/lang/String; + // found field access topLevelValue on $iw + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // scanning $iw$InnerFoo$1.(L$iw;)V + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) + // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) + // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // found call to outer $iw.topLevelValue()Ljava/lang/String; + // scanning $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // + fields accessed by starting closure: 2 classes + // (class java.lang.Object,Set()) + // (class $iw,Set(topLevelValue)) + // + cloning instance of REPL class $iw + // +++ indylambda closure ($anonfun$closure$1$adapted) is now cleaned +++ + // + // scalastyle:on line.size.limit + def findAccessedFields( + lambdaProxy: SerializedLambda, + lambdaClassLoader: ClassLoader, + accessedFields: Map[Class[_], Set[String]], + findTransitively: Boolean): Unit = { + + // We may need to visit the same class multiple times for different methods on it, and we'll + // need to lookup by name. So we use ASM's Tree API and cache the ClassNode/MethodNode. + val classInfoByInternalName = Map.empty[String, (Class[_], ClassNode)] + val methodNodeById = Map.empty[MethodIdentifier[_], MethodNode] + def getOrUpdateClassInfo(classInternalName: String): (Class[_], ClassNode) = { + val classInfo = classInfoByInternalName.getOrElseUpdate(classInternalName, { + val classExternalName = classInternalName.replace('/', '.') + // scalastyle:off classforname + val clazz = Class.forName(classExternalName, false, lambdaClassLoader) + // scalastyle:on classforname + val classNode = new ClassNode() + val classReader = ClosureCleaner.getClassReader(clazz) + classReader.accept(classNode, 0) + + for (m <- classNode.methods.asScala) { + methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m + } + + (clazz, classNode) + }) + classInfo + } + + val implClassInternalName = lambdaProxy.getImplClass + val (implClass, _) = getOrUpdateClassInfo(implClassInternalName) + + val implMethodId = MethodIdentifier( + implClass, lambdaProxy.getImplMethodName, lambdaProxy.getImplMethodSignature) + + // The set internal names of classes that we would consider following the calls into. + // Candidates are: known outer class which happens to be the starting closure's impl class, + // and all inner classes discovered below. + // Note that code in an inner class can make calls to methods in any of its enclosing classes, + // e.g. + // starting closure (in class T) + // inner class A + // inner class B + // inner closure + // we need to track calls from "inner closure" to outer classes relative to it (class T, A, B) + // to better find and track field accesses. + val trackedClassInternalNames = Set[String](implClassInternalName) + + // Depth-first search for inner closures and track the fields that were accessed in them. + // Start from the lambda body's implementation method, follow method invocations + val visited = Set.empty[MethodIdentifier[_]] + val stack = Stack[MethodIdentifier[_]](implMethodId) + def pushIfNotVisited(methodId: MethodIdentifier[_]): Unit = { + if (!visited.contains(methodId)) { + stack.push(methodId) + } + } + + while (!stack.isEmpty) { + val currentId = stack.pop + visited += currentId + + val currentClass = currentId.cls + val currentMethodNode = methodNodeById(currentId) + logTrace(s" scanning ${currentId.cls.getName}.${currentId.name}${currentId.desc}") + currentMethodNode.accept(new MethodVisitor(ASM7) { + val currentClassName = currentClass.getName + val currentClassInternalName = currentClassName.replace('.', '/') + + // Find and update the accessedFields info. Only fields on known outer classes are tracked. + // This is the FieldAccessFinder equivalent. + override def visitFieldInsn(op: Int, owner: String, name: String, desc: String): Unit = { + if (op == GETFIELD || op == PUTFIELD) { + val ownerExternalName = owner.replace('/', '.') + for (cl <- accessedFields.keys if cl.getName == ownerExternalName) { + logTrace(s" found field access $name on $ownerExternalName") + accessedFields(cl) += name + } + } + } + + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = { + val ownerExternalName = owner.replace('/', '.') + if (owner == currentClassInternalName) { + logTrace(s" found intra class call to $ownerExternalName.$name$desc") + // could be invoking a helper method or a field accessor method, just follow it. + pushIfNotVisited(MethodIdentifier(currentClass, name, desc)) + } else if (isInnerClassCtorCapturingOuter( + op, owner, name, desc, currentClassInternalName)) { + // Discover inner classes. + // This this the InnerClassFinder equivalent for inner classes, which still use the + // `$outer` chain. So this is NOT controlled by the `findTransitively` flag. + logDebug(s" found inner class $ownerExternalName") + val innerClassInfo = getOrUpdateClassInfo(owner) + val innerClass = innerClassInfo._1 + val innerClassNode = innerClassInfo._2 + trackedClassInternalNames += owner + // We need to visit all methods on the inner class so that we don't missing anything. + for (m <- innerClassNode.methods.asScala) { + logTrace(s" found method ${m.name}${m.desc}") + pushIfNotVisited(MethodIdentifier(innerClass, m.name, m.desc)) + } + } else if (findTransitively && trackedClassInternalNames.contains(owner)) { + logTrace(s" found call to outer $ownerExternalName.$name$desc") + val (calleeClass, _) = getOrUpdateClassInfo(owner) // make sure MethodNodes are cached + pushIfNotVisited(MethodIdentifier(calleeClass, name, desc)) + } else { + // keep the same behavior as the original ClosureCleaner + logTrace(s" ignoring call to $ownerExternalName.$name$desc") + } + } + + // Find the lexically nested closures + // This is the InnerClosureFinder equivalent for indylambda nested closures + override def visitInvokeDynamicInsn( + name: String, desc: String, bsmHandle: Handle, bsmArgs: Object*): Unit = { + logTrace(s" invokedynamic: $name$desc, bsmHandle=$bsmHandle, bsmArgs=$bsmArgs") + + // fast check: we only care about Scala lambda creation + // TODO: maybe lift this restriction and support other functional interfaces + if (!name.startsWith("apply")) return + if (!Type.getReturnType(desc).getDescriptor.startsWith("Lscala/Function")) return + + if (isLambdaMetafactory(bsmHandle)) { + // OK we're in the right bootstrap method for serializable Java 8 style lambda creation + val targetHandle = bsmArgs(1).asInstanceOf[Handle] + if (isLambdaBodyCapturingOuter(targetHandle, currentClassInternalName)) { + // this is a lexically nested closure that also captures the enclosing `this` + logDebug(s" found inner closure $targetHandle") + val calleeMethodId = + MethodIdentifier(currentClass, targetHandle.getName, targetHandle.getDesc) + pushIfNotVisited(calleeMethodId) + } + } + } + }) + } + } +} + private[spark] class ReturnStatementInClosureException extends SparkException("Return statements aren't allowed in Spark closures") @@ -422,7 +748,7 @@ private class ReturnStatementFinder(targetMethodName: Option[String] = None) override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { - // $anonfun$ covers Java 8 lambdas + // $anonfun$ covers indylambda closures if (name.contains("apply") || name.contains("$anonfun$")) { // A method with suffix "$adapted" will be generated in cases like // { _:Int => return; Seq()} but not { _:Int => return; true} diff --git a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala index 4795306692f7..e11a54bc8807 100644 --- a/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/SingletonReplSuite.scala @@ -380,6 +380,67 @@ class SingletonReplSuite extends SparkFunSuite { assertDoesNotContain("Exception", output) } + test("SPARK-31399: should clone+clean line object w/ non-serializable state in ClosureCleaner") { + // Test ClosureCleaner when a closure captures the enclosing `this` REPL line object, and that + // object contains an unused non-serializable field. + // Specifically, the closure in this test case contains a directly nested closure, and the + // capture is triggered by the inner closure. + // `ns` should be nulled out, but `topLevelValue` should stay intact. + + // Can't use :paste mode because PipedOutputStream/PipedInputStream doesn't work well with the + // EOT control character (i.e. Ctrl+D). + // Just write things on a single line to emulate :paste mode. + + // NOTE: in order for this test case to trigger the intended scenario, the following three + // variables need to be in the same "input", which will make the REPL pack them into the + // same REPL line object: + // - ns: a non-serializable state, not accessed by the closure; + // - topLevelValue: a serializable state, accessed by the closure; + // - closure: the starting closure, captures the enclosing REPL line object. + val output = runInterpreter( + """ + |class NotSerializableClass(val x: Int) + |val ns = new NotSerializableClass(42); val topLevelValue = "someValue"; val closure = + |(j: Int) => { + | (1 to j).flatMap { x => + | (1 to x).map { y => y + topLevelValue } + | } + |} + |val r = sc.parallelize(0 to 2).map(closure).collect + """.stripMargin) + assertContains("r: Array[scala.collection.immutable.IndexedSeq[String]] = " + + "Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output) + assertDoesNotContain("Exception", output) + } + + test("SPARK-31399: ClosureCleaner should discover indirectly nested closure in inner class") { + // Similar to the previous test case, but with indirect closure nesting instead. + // There's still nested closures involved, but the inner closure is indirectly nested in the + // outer closure, with a level of inner class in between them. + // This changes how the inner closure references/captures the outer closure/enclosing `this` + // REPL line object, and covers a different code path in inner closure discovery. + + // `ns` should be nulled out, but `topLevelValue` should stay intact. + + val output = runInterpreter( + """ + |class NotSerializableClass(val x: Int) + |val ns = new NotSerializableClass(42); val topLevelValue = "someValue"; val closure = + |(j: Int) => { + | class InnerFoo { + | val innerClosure = (x: Int) => (1 to x).map { y => y + topLevelValue } + | } + | val innerFoo = new InnerFoo + | (1 to j).flatMap(innerFoo.innerClosure) + |} + |val r = sc.parallelize(0 to 2).map(closure).collect + """.stripMargin) + assertContains("r: Array[scala.collection.immutable.IndexedSeq[String]] = " + + "Array(Vector(), Vector(1someValue), Vector(1someValue, 1someValue, 2someValue))", output) + assertDoesNotContain("Array(Vector(), Vector(1null), Vector(1null, 1null, 2null)", output) + assertDoesNotContain("Exception", output) + } + test("newProductSeqEncoder with REPL defined class") { val output = runInterpreter( """