51 changes: 43 additions & 8 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -20,6 +20,8 @@ package org.apache.spark
import java.lang.ref.{ReferenceQueue, WeakReference}

import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.reflect.ClassTag
import scala.util.DynamicVariable

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -63,6 +65,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
with SynchronizedBuffer[CleanerListener]

private val cleaningThread = new Thread() { override def run() { keepCleaning() }}

// Reference counts for registered broadcasts, keyed by broadcast id. A broadcast
// is only truly cleaned up once the count for its id drops back to zero.
private var broadcastRefCounts = Map.empty[Long, Long]

/**
* Whether the cleaning thread will block on cleanup tasks.
@@ -102,9 +106,25 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {

/** Register a Broadcast for cleanup when it is garbage collected. */
def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
incBroadcastRefCount(broadcast.id)
registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
}

/** Increment the reference count for the given broadcast id. */
private def incBroadcastRefCount(bid: Long) {
val newRefCount: Long = this.broadcastRefCounts.getOrElse(bid, 0L) + 1
this.broadcastRefCounts = this.broadcastRefCounts + (bid -> newRefCount)
}

/** Decrement the reference count for the given broadcast id, returning the new count (never negative). */
private def decBroadcastRefCount(bid: Long): Long = {
this.broadcastRefCounts.get(bid) match {
case Some(rc: Long) if rc > 0 => {
this.broadcastRefCounts = this.broadcastRefCounts + (bid -> (rc - 1))
rc - 1
}
case _ => 0
}
}

/** Register an object for cleanup. */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
@@ -161,14 +181,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {

/** Perform broadcast cleanup. */
def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) {
try {
logDebug("Cleaning broadcast " + broadcastId)
broadcastManager.unbroadcast(broadcastId, true, blocking)
listeners.foreach(_.broadcastCleaned(broadcastId))
logInfo("Cleaned broadcast " + broadcastId)
} catch {
case e: Exception => logError("Error cleaning broadcast " + broadcastId, e)
// Only unbroadcast once every registered clone of this broadcast has been
// garbage collected, i.e. when the reference count drops to zero.
decBroadcastRefCount(broadcastId) match {
case x if x > 0 => {}
case _ => try {
logDebug("Cleaning broadcast " + broadcastId)
broadcastManager.unbroadcast(broadcastId, true, blocking)
listeners.foreach(_.broadcastCleaned(broadcastId))
logInfo("Cleaned broadcast " + broadcastId)
} catch {
case e: Exception => logError("Error cleaning broadcast " + broadcastId, e)
}
}

}

private def blockManagerMaster = sc.env.blockManager.master
@@ -179,8 +203,19 @@
// to ensure more reliable testing.
}

private object ContextCleaner {
private[spark] object ContextCleaner {
private val REF_QUEUE_POLL_TIMEOUT = 100
val currentCleaner = new DynamicVariable[Option[ContextCleaner]](None)

/**
* Runs the given thunk with a dynamically-scoped binding for the current ContextCleaner.
* This is necessary for blocks of code that serialize and deserialize broadcast variable
* objects, since all clones of a Broadcast object <tt>b</tt> need to be re-registered with the
* context cleaner that is tracking <tt>b</tt>.
*/
def withCurrentCleaner[T <: Any : ClassTag](cc: Option[ContextCleaner])(thunk: => T) = {
currentCleaner.withValue(cc)(thunk)
}
}

/**
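
To see the reference-counting scheme above in isolation, here is a minimal, self-contained sketch; RefCountSketch, retain and release are illustrative names, not Spark APIs. A resource id is retained once per registered clone and only reported as releasable when the last reference is dropped, which is exactly the decision doCleanupBroadcast now makes before calling unbroadcast.

import scala.collection.mutable

// Illustrative sketch of the reference-counting idea (not Spark code).
object RefCountSketch {
  private val refCounts = mutable.Map.empty[Long, Long]

  /** Record one more live reference to the resource with this id. */
  def retain(id: Long): Unit = synchronized {
    refCounts(id) = refCounts.getOrElse(id, 0L) + 1
  }

  /** Drop one reference; returns true when the count reaches zero and cleanup may run. */
  def release(id: Long): Boolean = synchronized {
    val remaining = math.max(refCounts.getOrElse(id, 0L) - 1, 0L)
    refCounts(id) = remaining
    remaining == 0L
  }

  def main(args: Array[String]): Unit = {
    retain(7L); retain(7L)   // e.g. an original broadcast and one deserialized clone
    println(release(7L))     // false: a clone is still registered
    println(release(7L))     // true: last reference gone, safe to unbroadcast
  }
}
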
9 changes: 5 additions & 4 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1058,7 +1058,9 @@ class SparkContext(config: SparkConf) extends Logging {
throw new SparkException("SparkContext has been shutdown")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
// There's no need to check this function for serializability,
// since it will be run right away.
val cleanedFunc = clean(func, false)
logInfo("Starting job: " + callSite.short)
val start = System.nanoTime
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
@@ -1212,9 +1214,8 @@ class SparkContext(config: SparkConf) extends Logging {
* @throws <tt>SparkException</tt> if <tt>checkSerializable</tt> is set but <tt>f</tt> is not
* serializable
*/
private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
ClosureCleaner.clean(f, checkSerializable)
f
private[spark] def clean[F <: AnyRef : ClassTag](f: F, checkSerializable: Boolean = true): F = {
ClosureCleaner.clean(f, checkSerializable, this)
}

/**
12 changes: 10 additions & 2 deletions core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -17,9 +17,9 @@

package org.apache.spark.broadcast

import java.io.Serializable
import java.io.{ObjectInputStream, Serializable}

import org.apache.spark.SparkException
import org.apache.spark.{ContextCleaner, SparkException}

import scala.reflect.ClassTag

@@ -129,4 +129,12 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
}

override def toString = "Broadcast(" + id + ")"

// When a Broadcast is cloned via serialization, re-register the copy with the
// ContextCleaner currently in scope (if any), so cleanup is deferred until every
// clone becomes unreachable.
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
ContextCleaner.currentCleaner.value match {
case None => {}
case Some(cc: ContextCleaner) => cc.registerBroadcastForCleanup(this)
}
}
}
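
The readObject hook above is what keeps the reference counts honest: every deserialized copy of a Broadcast re-registers itself with whatever ContextCleaner withCurrentCleaner has bound on the current thread. Below is a hedged, Spark-free sketch of that pattern; Tracked, Registry, ReregisterSketch and currentRegistry are illustrative names, not Spark classes.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import scala.util.DynamicVariable

// Illustrative sketch: a deserialized instance re-registers itself with a
// dynamically scoped registry, if one is in scope at deserialization time.
class Tracked(val id: Long) extends Serializable {
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    ReregisterSketch.currentRegistry.value.foreach(_.register(id))
  }
}

object ReregisterSketch {
  class Registry {
    var seen: List[Long] = Nil
    def register(id: Long): Unit = { seen = id :: seen }
  }

  val currentRegistry = new DynamicVariable[Option[Registry]](None)

  private def roundTrip[T](value: T): T = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(value)
    out.close()
    new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray)).readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    val registry = new Registry
    currentRegistry.withValue(Some(registry)) { roundTrip(new Tracked(42L)) }
    println(registry.seen)   // List(42): the deserialized copy re-registered itself
    roundTrip(new Tracked(7L))
    println(registry.seen)   // still List(42): no registry was in scope for this round trip
  }
}
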
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -733,14 +733,16 @@
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}

/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
sc.runJob(this, (iter: Iterator[T]) => f(iter))
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}

/**
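
foreach and foreachPartition previously shipped the user closure without cleaning it, so a closure that silently captured a non-serializable enclosing object only failed once the scheduler tried to serialize the task. The following sketch shows that capture behaviour without Spark; ClosureCaptureSketch and NonSerializableHolder are illustrative names, not Spark code.

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Illustrative sketch: a closure that reads a field of a non-serializable object
// captures that object, and the problem only surfaces at serialization time.
object ClosureCaptureSketch {
  class NonSerializableHolder {                    // note: does NOT extend Serializable
    val offset = 10
    def makeAdder: Int => Int = x => x + offset    // captures `this` to read offset
  }

  def trySerialize(obj: AnyRef): Unit =
    try {
      new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(obj)
      println("serializable")
    } catch {
      case e: NotSerializableException => println("not serializable: " + e.getMessage)
    }

  def main(args: Array[String]): Unit = {
    val holder = new NonSerializableHolder
    trySerialize(holder.makeAdder)            // fails: the closure drags `holder` along
    val localOffset = holder.offset
    trySerialize((x: Int) => x + localOffset) // succeeds: only an Int is captured
  }
}
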
21 changes: 14 additions & 7 deletions core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -22,10 +22,12 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import scala.collection.mutable.Map
import scala.collection.mutable.Set

import scala.reflect.ClassTag

import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._

import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.{ContextCleaner, Logging, SparkContext, SparkEnv, SparkException}

private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
@@ -100,8 +102,8 @@ private[spark] object ClosureCleaner extends Logging {
null
}
}

def clean(func: AnyRef, checkSerializable: Boolean = true) {
def clean[F <: AnyRef : ClassTag](func: F, captureNow: Boolean = true, sc: SparkContext): F = {
// TODO: cache outerClasses / innerClasses / accessedFields
val outerClasses = getOuterClasses(func)
val innerClasses = getInnerClasses(func)
@@ -154,14 +156,19 @@ private[spark] object ClosureCleaner extends Logging {
field.set(func, outer)
}

if (checkSerializable) {
ensureSerializable(func)
// Serialize and then deserialize the closure so that (a) non-serializable closures
// fail fast here, and (b) the returned clone captures the closure's state as of now,
// re-registering any broadcasts it references with the current ContextCleaner.
if (captureNow) {
ContextCleaner.withCurrentCleaner(sc.cleaner){
cloneViaSerializing(func)
}
} else {
func
}
}

private def ensureSerializable(func: AnyRef) {
private def cloneViaSerializing[T: ClassTag](func: T): T = {
try {
SparkEnv.get.closureSerializer.newInstance().serialize(func)
val serializer = SparkEnv.get.closureSerializer.newInstance()
serializer.deserialize[T](serializer.serialize[T](func))
} catch {
case ex: Exception => throw new SparkException("Task not serializable", ex)
}
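
Replacing ensureSerializable (serialize and discard) with cloneViaSerializing (serialize, deserialize, and return the clone) means clean() both proves serializability and snapshots the closure's captured state at the point of the call. Here is a hedged sketch of that snapshot semantics using plain Java serialization rather than Spark's closure serializer; Box, CaptureNowSketch and cloneBySerializing are illustrative names.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Illustrative sketch: a serialize/deserialize round trip freezes the closure's
// captured state at clone time, independent of later mutation.
class Box(var value: Int) extends Serializable

object CaptureNowSketch {
  def cloneBySerializing[T](obj: T): T = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(obj)
    out.close()
    new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray)).readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    val box = new Box(1)
    val f: () => Int = () => box.value     // the closure captures the Box instance
    val snapshot = cloneBySerializing(f)   // the clone carries its own copy of the Box
    box.value = 99                         // later mutation of the original environment
    println(f())          // 99: the original closure sees the mutation
    println(snapshot())   // 1:  the clone keeps the state captured at clone time
  }
}
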
37 changes: 36 additions & 1 deletion core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -22,6 +22,7 @@ import java.lang.ref.WeakReference
import scala.collection.mutable.{HashSet, SynchronizedSet}
import scala.language.existentials
import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.Random

import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -141,6 +142,33 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
postGCTester.assertCleanup()
}

test("automatically cleanup broadcast only after all extant clones become unreachable") {
var broadcast = newBroadcast

// clone this broadcast variable
var broadcastClone = cloneBySerializing(broadcast)

val id = broadcast.id

// eliminate all strong references to the original broadcast; keep the clone
broadcast = null

// Test that GC does not cause broadcast cleanup, since a strong reference to a
// clone of the broadcast with the given id still exists
val preGCTester = new CleanerTester(sc, broadcastIds = Seq(id))
runGC()
intercept[Exception] {
preGCTester.assertCleanup()(timeout(1000 millis))
}

// Test that GC causes broadcast cleanup after dereferencing the clone
val postGCTester = new CleanerTester(sc, broadcastIds = Seq(id))
broadcastClone = null
runGC()
postGCTester.assertCleanup()
}


test("automatically cleanup RDD + shuffle + broadcast") {
val numRdds = 100
val numBroadcasts = 4 // Broadcasts are more costly
Expand Down Expand Up @@ -242,7 +270,14 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
Thread.sleep(200)
}
}


def cloneBySerializing[T <: Any : ClassTag](ref: T): T = {
val serializer = SparkEnv.get.closureSerializer.newInstance()
ContextCleaner.withCurrentCleaner[T](sc.cleaner){
serializer.deserialize(serializer.serialize(ref))
}
}

def cleaner = sc.cleaner.get
}

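
The new test leans on the suite's existing pattern of dropping the last strong reference, forcing GC with runGC(), and then asserting whether cleanup happened. A minimal, Spark-free sketch of that WeakReference-plus-System.gc polling idea follows; WeakRefGcSketch is an illustrative name, not part of the suite.

import java.lang.ref.WeakReference

// Illustrative sketch of the GC-driven test pattern: keep only a weak reference,
// drop the strong one, request GC, and poll until the referent is collected or a
// timeout expires.
object WeakRefGcSketch {
  def main(args: Array[String]): Unit = {
    var payload = new Array[Byte](16 * 1024 * 1024)
    val weak = new WeakReference(payload)
    payload = null                              // drop the only strong reference
    val deadline = System.currentTimeMillis + 10000
    while (weak.get != null && System.currentTimeMillis < deadline) {
      System.gc()
      Thread.sleep(100)
    }
    println(if (weak.get == null) "collected" else "still reachable (timed out)")
  }
}
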
17 changes: 16 additions & 1 deletion core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -110,7 +110,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}

test("failure because task closure is not serializable") {
test("failure because closure in final-stage task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

@@ -122,6 +122,13 @@
assert(thrown.getMessage.contains("NotSerializableException") ||
thrown.getCause.getClass === classOf[NotSerializableException])

FailureSuiteState.clear()
}

test("failure because closure in early-stage task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

// Non-serializable closure in an earlier stage
val thrown1 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
@@ -130,6 +137,13 @@
assert(thrown1.getMessage.contains("NotSerializableException") ||
thrown1.getCause.getClass === classOf[NotSerializableException])

FailureSuiteState.clear()
}

test("failure because closure in foreach task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

// Non-serializable closure in foreach function
val thrown2 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).foreach(x => println(a))
@@ -141,5 +155,6 @@
FailureSuiteState.clear()
}


// TODO: Need to add tests with shuffle fetch failures.
}
@@ -51,11 +51,11 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
// transformation on a given RDD, creating one test case for each

for (transformation <-
Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _,
"mapWith" -> xmapWith _, "mapPartitions" -> xmapPartitions _,
"mapPartitionsWithIndex" -> xmapPartitionsWithIndex _,
"mapPartitionsWithContext" -> xmapPartitionsWithContext _,
"filterWith" -> xfilterWith _)) {
Map("map" -> map _, "flatMap" -> flatMap _, "filter" -> filter _,
"mapWith" -> mapWith _, "mapPartitions" -> mapPartitions _,
"mapPartitionsWithIndex" -> mapPartitionsWithIndex _,
"mapPartitionsWithContext" -> mapPartitionsWithContext _,
"filterWith" -> filterWith _)) {
val (name, xf) = transformation

test(s"$name transformations throw proactive serialization exceptions") {
@@ -70,21 +70,28 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
}
}

private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.map(y=>uc.op(y))
private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapWith(x => x.toString)((x,y)=>x + uc.op(y))
private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
def map(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.map(y => uc.op(y))

def mapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapWith(x => x.toString)((x,y) => x + uc.op(y))

def flatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.flatMap(y=>Seq(uc.op(y)))
private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] =

def filter(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filter(y=>uc.pred(y))
private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filterWith(x => x.toString)((x,y)=>uc.pred(y))
private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitions(_.map(y=>uc.op(y)))
private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y)))
private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y)))

def filterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filterWith(x => x.toString)((x,y) => uc.pred(y))

def mapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitions(_.map(y => uc.op(y)))

def mapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))

def mapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y)))

}