[SPARK-2521] Broadcast RDD object (instead of sending it along with every task) #1498
Changes from 11 commits: cae0af3, d256b45, cc152fc, de779f8, 991c002, cf38450, bab1d8b, 797c247, 111007d, 252238d, f8535dc, f7364db
```diff
@@ -17,7 +17,7 @@
 package org.apache.spark.scheduler

-import java.io.{NotSerializableException, PrintWriter, StringWriter}
+import java.io.NotSerializableException
 import java.util.Properties
 import java.util.concurrent.atomic.AtomicInteger

@@ -35,6 +35,7 @@ import akka.pattern.ask
 import akka.util.Timeout

 import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd.RDD
```
```diff
@@ -361,9 +362,6 @@ class DAGScheduler(
     // data structures based on StageId
     stageIdToStage -= stageId

-    ShuffleMapTask.removeStage(stageId)
-    ResultTask.removeStage(stageId)
-
     logDebug("After removal of stage %d, remaining stages = %d"
       .format(stageId, stageIdToStage.size))
   }
```
```diff
@@ -691,47 +689,81 @@ class DAGScheduler(
     }
   }

   /** Called when stage's parents are available and we can now do its task. */
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
     // Get our pending tasks and remember them in our pendingTasks entry
     stage.pendingTasks.clear()
     var tasks = ArrayBuffer[Task[_]]()

+    val properties = if (jobIdToActiveJob.contains(jobId)) {
+      jobIdToActiveJob(stage.jobId).properties
+    } else {
+      // this stage will be assigned to "default" pool
+      null
+    }
+
+    runningStages += stage
+    // SparkListenerStageSubmitted should be posted before testing whether tasks are
+    // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
+    // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
+    // event.
+    listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
+
+    // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
+    // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
+    // the serialized copy of the RDD and for each task we will deserialize it, which means each
+    // task gets a different copy of the RDD. This provides stronger isolation between tasks that
+    // might modify state of objects referenced in their closures. This is necessary in Hadoop
+    // where the JobConf/Configuration object is not thread-safe.
+    var taskBinary: Broadcast[Array[Byte]] = null
+    try {
+      // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
+      // For ResultTask, serialize and broadcast (rdd, func).
+      val taskBinaryBytes: Array[Byte] =
+        if (stage.isShuffleMap) {
+          Utils.serializeTaskClosure((stage.rdd, stage.shuffleDep.get) : AnyRef)
+        } else {
+          Utils.serializeTaskClosure((stage.rdd, stage.resultOfJob.get.func) : AnyRef)
+        }
+      taskBinary = sc.broadcast(taskBinaryBytes)
+    } catch {
+      // In the case of a failure during serialization, abort the stage.
+      case e: NotSerializableException =>
+        abortStage(stage, "Task not serializable: " + e.toString)
+        runningStages -= stage
+        return
+      case NonFatal(e) =>
+        abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
+        runningStages -= stage
+        return
+    }
+
     if (stage.isShuffleMap) {
       for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
         val locs = getPreferredLocs(stage.rdd, p)
-        tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
+        val part = stage.rdd.partitions(p)
+        tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs)
       }
     } else {
       // This is a final stage; figure out its job's missing partitions
       val job = stage.resultOfJob.get
       for (id <- 0 until job.numPartitions if !job.finished(id)) {
-        val partition = job.partitions(id)
-        val locs = getPreferredLocs(stage.rdd, partition)
-        tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
+        val p: Int = job.partitions(id)
+        val part = stage.rdd.partitions(p)
+        val locs = getPreferredLocs(stage.rdd, p)
+        tasks += new ResultTask(stage.id, taskBinary, part, locs, id)
       }
     }

-    val properties = if (jobIdToActiveJob.contains(jobId)) {
-      jobIdToActiveJob(stage.jobId).properties
-    } else {
-      // this stage will be assigned to "default" pool
-      null
-    }
-
     if (tasks.size > 0) {
-      runningStages += stage
-      // SparkListenerStageSubmitted should be posted before testing whether tasks are
-      // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
-      // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
-      // event.
-      listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
-
       // Preemptively serialize a task to make sure it can be serialized. We are catching this
       // exception here because it would be fairly hard to catch the non-serializable exception
       // down the road, where we have several different implementations for local scheduler and
       // cluster schedulers.
+      //
+      // We've already serialized RDDs and closures in taskBinary, but here we check for all other
+      // objects such as Partition.
       try {
         SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
```
Contributor: You can use your Utils function here too

Contributor (author): That one also does compression. But honestly, given they are both one-liners, I'm not sure if it is worth it ...

Contributor (author): actually remove the compression part
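The `Utils.serializeTaskClosure` helper called in the diff is not part of this file, so its body is not shown here. As a rough sketch only, assuming it simply runs the object through the closure serializer and copies the bytes out (per the thread above, the version at this commit also compressed the result, which the author then dropped):

```scala
import java.nio.ByteBuffer

import org.apache.spark.SparkEnv

// Hypothetical sketch of the helper discussed in the thread above. The name mirrors the
// call sites in the diff, but this body is an assumption, not the PR's actual implementation.
object SerializeTaskClosureSketch {
  def serializeTaskClosure(obj: AnyRef): Array[Byte] = {
    // Run the object (e.g. (rdd, shuffleDep) or (rdd, func)) through the closure serializer
    // and copy the resulting bytes out of the ByteBuffer.
    val buf: ByteBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(obj)
    val bytes = new Array[Byte](buf.remaining())
    buf.get(bytes)
    bytes
  }
}
```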
```diff
       } catch {
```
```diff
@@ -752,6 +784,9 @@ class DAGScheduler(
         new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
       stage.info.submissionTime = Some(clock.getTime())
     } else {
+      // Because we posted SparkListenerStageSubmitted earlier, we should post
+      // SparkListenerStageCompleted here in case there are no tasks to run.
+      listenerBus.post(SparkListenerStageCompleted(stage.info))
       logDebug("Stage " + stage + " is actually done; %b %d %d".format(
         stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
       runningStages -= stage
```
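The other half of the change lives in ShuffleMapTask and ResultTask, which are not shown in this diff: each task now carries the `Broadcast[Array[Byte]]` handle and deserializes its own copy of the RDD and closure when it runs. A minimal sketch of that executor-side step, assuming the standard closure-serializer API; this is illustrative, not the PR's actual task code:

```scala
import java.nio.ByteBuffer

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

// Hedged sketch, not the PR's ResultTask.runTask: how a task built from `taskBinary`
// might recover its (rdd, func) pair on the executor. Each task deserializes a fresh
// copy, which is the isolation property described in submitMissingTasks above.
object TaskBinarySketch {
  def deserializeResultTaskBinary[T, U](
      taskBinary: Broadcast[Array[Byte]]): (RDD[T], (TaskContext, Iterator[T]) => U) = {
    val ser = SparkEnv.get.closureSerializer.newInstance()
    // taskBinary.value pulls the broadcast bytes; the broadcast layer lets an executor
    // fetch and cache them once, instead of shipping them along with every task.
    ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](ByteBuffer.wrap(taskBinary.value))
  }
}
```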
Review comment on the diff: just FYI - this is an API breaking change... probably not a huge deal, but FYI
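To make that point concrete: the task constructors change shape at the call sites visible in this diff, so anything constructing ShuffleMapTask or ResultTask directly has to change with it. The comparison below is paraphrased from the diff; the full class signatures are not part of this file.

```scala
// Before: the RDD, dependency, and user function were constructor arguments, serialized
// and shipped with every task:
//   new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
//   new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
// After: tasks take a Broadcast[Array[Byte]] handle plus the Partition object instead:
//   new ShuffleMapTask(stage.id, taskBinary, part, locs)
//   new ResultTask(stage.id, taskBinary, part, locs, id)
```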