28 changes: 18 additions & 10 deletions core/src/main/scala/org/apache/spark/Dependency.scala
@@ -27,7 +27,9 @@ import org.apache.spark.shuffle.ShuffleHandle
* Base class for dependencies.
*/
@DeveloperApi
abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
abstract class Dependency[T] extends Serializable {
def rdd: RDD[T]
}


/**
@@ -36,41 +38,47 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
* partition of the child RDD. Narrow dependencies allow for pipelined execution.
*/
@DeveloperApi
abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
Review comment (Contributor; see the sketch after this class):
just FYI - this is an API breaking change... probably not a huge deal, but FYI

/**
* Get the parent partitions for a child partition.
* @param partitionId a partition of the child RDD
* @return the partitions of the parent RDD that the child partition depends upon
*/
def getParents(partitionId: Int): Seq[Int]

override def rdd: RDD[T] = _rdd
}
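
As the review comment above points out, turning the parent RDD from a constructor parameter of `Dependency` into an abstract `rdd` accessor is a source-incompatible change for third-party code that subclasses `Dependency` directly. A minimal sketch of what such a caller would have to change (`MyDependency` is a hypothetical name, not part of this patch):

```scala
import org.apache.spark.Dependency
import org.apache.spark.rdd.RDD

// Before this patch a subclass passed the parent RDD to the superclass constructor:
//   class MyDependency[T](parent: RDD[T]) extends Dependency[T](parent)
// That no longer compiles. After this patch the subclass overrides the accessor instead:
class MyDependency[T](parent: RDD[T]) extends Dependency[T] {
  override def rdd: RDD[T] = parent
}
```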


/**
* :: DeveloperApi ::
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
* the RDD is transient since we don't need it on the executor side.
*
* @param _rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
* be used.
*/
@DeveloperApi
class ShuffleDependency[K, V, C](
@transient rdd: RDD[_ <: Product2[K, V]],
@transient _rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
val serializer: Option[Serializer] = None,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
extends Dependency[Product2[K, V]] {

override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]]

val shuffleId: Int = rdd.context.newShuffleId()
val shuffleId: Int = _rdd.context.newShuffleId()

val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
shuffleId, rdd.partitions.size, this)
val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
shuffleId, _rdd.partitions.size, this)

rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
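
The updated doc comment captures the reason for `@transient _rdd`: a `ShuffleDependency` travels to executors, but the parent RDD lineage is only needed on the driver, so the reference is dropped from the serialized form. For anyone unfamiliar with the mechanism, here is a standalone sketch of what `@transient` does under plain JVM serialization (hypothetical `Holder` class, not Spark code):

```scala
import java.io._

// Transient fields are skipped when an object is serialized, so the
// deserialized copy carries the field's default value (null here) instead of
// dragging the whole referenced object graph across the wire.
class Holder(@transient val payload: Array[Int]) extends Serializable

object TransientDemo extends App {
  val bos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bos)
  oos.writeObject(new Holder(Array(1, 2, 3)))
  oos.close()

  val ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
  val copy = ois.readObject().asInstanceOf[Holder]
  println(copy.payload == null)  // true: the transient field was not shipped
}
```

This is also why the tasks in this patch rebuild the RDD from the broadcast task binary (see the DAGScheduler changes below) rather than reaching through the dependency on the executor side.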


2 changes: 0 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -997,8 +997,6 @@ class SparkContext(config: SparkConf) extends Logging {
// TODO: Cache.stop()?
env.stop()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
ResultTask.clearCache()
listenerBus.stop()
eventLogger.foreach(_.stop())
logInfo("Successfully stopped SparkContext")
11 changes: 4 additions & 7 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -35,12 +35,13 @@ import org.apache.spark.Partitioner._
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}

@@ -1206,16 +1207,12 @@ abstract class RDD[T: ClassTag](
/**
* Return whether this RDD has been checkpointed or not
*/
def isCheckpointed: Boolean = {
checkpointData.map(_.isCheckpointed).getOrElse(false)
}
def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)

/**
* Gets the name of the file to which this RDD was checkpointed
*/
def getCheckpointFile: Option[String] = {
checkpointData.flatMap(_.getCheckpointFile)
}
def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)
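
For context, a hypothetical usage sketch of the two checkpoint accessors simplified above (assumes a checkpoint directory has been configured; the path and values are illustrative):

```scala
sc.setCheckpointDir("/tmp/spark-checkpoints")   // hypothetical local path

val rdd = sc.parallelize(1 to 100).map(_ * 2)
rdd.checkpoint()                 // only marks the RDD for checkpointing
println(rdd.isCheckpointed)      // false: nothing has been written yet
rdd.count()                      // an action triggers the actual checkpoint
println(rdd.isCheckpointed)      // true once the data has been materialized
println(rdd.getCheckpointFile)   // Some(<directory under the checkpoint dir>)
```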

// =======================================================================
// Other internal methods and fields
core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
RDDCheckpointData.clearTaskCaches()
}
logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
}
@@ -131,9 +130,5 @@ }
}
}

private[spark] object RDDCheckpointData {
def clearTaskCaches() {
ShuffleMapTask.clearCache()
ResultTask.clearCache()
}
}
// Used for synchronization
private[spark] object RDDCheckpointData
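
The now-empty companion object is kept only to act as a lock, as the `// Used for synchronization` comment says (code in this class locks on `RDDCheckpointData.synchronized`). A small sketch of that idiom, with illustrative names:

```scala
// An empty object can serve as a process-wide monitor for code that needs to
// serialize access across threads without exposing a public lock field.
object CheckpointLock

class Registry {
  private var entries = List.empty[String]
  def add(name: String): Unit = CheckpointLock.synchronized {
    entries = name :: entries
  }
}
```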
87 changes: 63 additions & 24 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -17,7 +17,7 @@

package org.apache.spark.scheduler

import java.io.{NotSerializableException, PrintWriter, StringWriter}
import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger

@@ -35,6 +35,7 @@ import akka.pattern.ask
import akka.util.Timeout

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
@@ -114,6 +115,10 @@ class DAGScheduler(
private val dagSchedulerActorSupervisor =
env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))

// A closure serializer that we reuse.
// This is only safe because DAGScheduler runs in a single thread.
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()

private[scheduler] var eventProcessActor: ActorRef = _

private def initializeEventProcessActor() {
@@ -361,9 +366,6 @@ class DAGScheduler(
// data structures based on StageId
stageIdToStage -= stageId

ShuffleMapTask.removeStage(stageId)
ResultTask.removeStage(stageId)

logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
@@ -691,49 +693,83 @@
}
}


/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()
var tasks = ArrayBuffer[Task[_]]()

val properties = if (jobIdToActiveJob.contains(jobId)) {
jobIdToActiveJob(stage.jobId).properties
} else {
// this stage will be assigned to "default" pool
null
}

runningStages += stage
// SparkListenerStageSubmitted should be posted before testing whether tasks are
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))

// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
// Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
// the serialized copy of the RDD and for each task we will deserialize it, which means each
// task gets a different copy of the RDD. This provides stronger isolation between tasks that
// might modify state of objects referenced in their closures. This is necessary in Hadoop
// where the JobConf/Configuration object is not thread-safe.
var taskBinary: Broadcast[Array[Byte]] = null
try {
// For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
// For ResultTask, serialize and broadcast (rdd, func).
val taskBinaryBytes: Array[Byte] =
if (stage.isShuffleMap) {
closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array()
} else {
closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array()
}
taskBinary = sc.broadcast(taskBinaryBytes)
} catch {
// In the case of a failure during serialization, abort the stage.
case e: NotSerializableException =>
abortStage(stage, "Task not serializable: " + e.toString)
runningStages -= stage
return
case NonFatal(e) =>
abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
runningStages -= stage
return
}

if (stage.isShuffleMap) {
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
val locs = getPreferredLocs(stage.rdd, p)
tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
val part = stage.rdd.partitions(p)
tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs)
}
} else {
// This is a final stage; figure out its job's missing partitions
val job = stage.resultOfJob.get
for (id <- 0 until job.numPartitions if !job.finished(id)) {
val partition = job.partitions(id)
val locs = getPreferredLocs(stage.rdd, partition)
tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = getPreferredLocs(stage.rdd, p)
tasks += new ResultTask(stage.id, taskBinary, part, locs, id)
}
}

val properties = if (jobIdToActiveJob.contains(jobId)) {
jobIdToActiveJob(stage.jobId).properties
} else {
// this stage will be assigned to "default" pool
null
}

if (tasks.size > 0) {
runningStages += stage
// SparkListenerStageSubmitted should be posted before testing whether tasks are
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))

// Preemptively serialize a task to make sure it can be serialized. We are catching this
// exception here because it would be fairly hard to catch the non-serializable exception
// down the road, where we have several different implementations for local scheduler and
// cluster schedulers.
//
// We've already serialized RDDs and closures in taskBinary, but here we check for all other
// objects such as Partition.
try {
SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
closureSerializer.serialize(tasks.head)
} catch {
case e: NotSerializableException =>
abortStage(stage, "Task not serializable: " + e.toString)
@@ -752,6 +788,9 @@ class DAGScheduler(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
stage.info.submissionTime = Some(clock.getTime())
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should post
// SparkListenerStageCompleted here in case there are no tasks to run.
listenerBus.post(SparkListenerStageCompleted(stage.info))
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
runningStages -= stage
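
The core of the patch is the `taskBinary` flow in `submitMissingTasks` above: serialize `(rdd, shuffleDep)` or `(rdd, func)` once, broadcast the bytes, and have every `ShuffleMapTask`/`ResultTask` deserialize its own private copy. A simplified standalone sketch of that round trip, assuming local mode and the `SparkEnv` plumbing of this era of Spark (it is not the DAGScheduler code itself):

```scala
import java.nio.ByteBuffer

import org.apache.spark.{SparkConf, SparkContext, SparkEnv}

object TaskBinarySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("taskBinary-sketch"))

    // Driver side: serialize the closure once and broadcast the resulting bytes,
    // mirroring how submitMissingTasks builds taskBinary.
    val func: Int => Int = x => x * x
    val driverSer = SparkEnv.get.closureSerializer.newInstance()
    val taskBinary = sc.broadcast(driverSer.serialize(func: AnyRef).array())

    // "Executor" side: each task deserializes its own copy from the broadcast
    // bytes, so tasks never share mutable state reachable from the closure.
    val results = sc.parallelize(1 to 4, 2).map { x =>
      val taskSer = SparkEnv.get.closureSerializer.newInstance()
      val f = taskSer.deserialize[Int => Int](ByteBuffer.wrap(taskBinary.value))
      f(x)
    }.collect()

    println(results.mkString(", "))  // 1, 4, 9, 16
    sc.stop()
  }
}
```

Broadcasting the bytes once and deserializing per task is what allows the per-stage serialized-task caches to go away, which is why the `ShuffleMapTask.clearCache()`/`ResultTask.clearCache()` calls are deleted from SparkContext and RDDCheckpointData earlier in this diff.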