@@ -105,13 +105,15 @@ class DAGScheduler(
105105
106106 private val eventQueue = new LinkedBlockingQueue [DAGSchedulerEvent ]
107107
108- val nextJobId = new AtomicInteger (0 )
108+ private [scheduler] val nextJobId = new AtomicInteger (0 )
109109
110- val nextStageId = new AtomicInteger ( 0 )
110+ def numTotalJobs : Int = nextJobId.get( )
111111
112- val stageIdToStage = new TimeStampedHashMap [ Int , Stage ]
112+ private val nextStageId = new AtomicInteger ( 0 )
113113
114- val shuffleToMapStage = new TimeStampedHashMap [Int , Stage ]
114+ private val stageIdToStage = new TimeStampedHashMap [Int , Stage ]
115+
116+ private val shuffleToMapStage = new TimeStampedHashMap [Int , Stage ]
115117
116118 private [spark] val stageToInfos = new TimeStampedHashMap [Stage , StageInfo ]
117119
@@ -263,54 +265,50 @@ class DAGScheduler(
263265 }
264266
265267 /**
266- * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
267- * JobWaiter whose getResult() method will return the result of the job when it is complete.
268- *
269- * The job is assumed to have at least one partition; zero partition jobs should be handled
270- * without a JobSubmitted event.
268+ * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
269+ * can be used to block until the job finishes executing or can be used to kill the job.
270+ * If no partitions are requested, nothing is submitted and a JobWaiter for zero tasks is returned.
271271 */
272- private [scheduler] def prepareJob [T , U : ClassManifest ](
273- finalRdd : RDD [T ],
272+ def submitJob [T , U ](
273+ rdd : RDD [T ],
274274 func : (TaskContext , Iterator [T ]) => U ,
275275 partitions : Seq [Int ],
276276 callSite : String ,
277277 allowLocal : Boolean ,
278278 resultHandler : (Int , U ) => Unit ,
279- properties : Properties = null )
280- : (JobSubmitted , JobWaiter [U ]) =
279+ properties : Properties = null ): JobWaiter [U ] =
281280 {
281+ val jobId = nextJobId.getAndIncrement()
282+ if (partitions.size == 0 ) {
283+ return new JobWaiter [U ](this , jobId, 0 , resultHandler)
284+ }
285+
286+ // Check to make sure we are not launching a task on a partition that does not exist.
287+ val maxPartitions = rdd.partitions.length
288+ partitions.find(p => p >= maxPartitions).foreach { p =>
289+ throw new IllegalArgumentException (
290+ " Attempting to access a non-existent partition: " + p + " . " +
291+ " Total number of partitions: " + maxPartitions)
292+ }
293+
282294 assert(partitions.size > 0 )
283- val waiter = new JobWaiter (partitions.size, resultHandler)
284295 val func2 = func.asInstanceOf [(TaskContext , Iterator [_]) => _]
285- val toSubmit = JobSubmitted (finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
286- properties)
287- (toSubmit, waiter)
296+ val waiter = new JobWaiter (this , jobId, partitions.size, resultHandler)
297+ eventQueue.put(JobSubmitted (jobId, rdd, func2, partitions.toArray, allowLocal, callSite,
298+ waiter, properties))
299+ waiter
288300 }
289301
290302 def runJob [T , U : ClassManifest ](
291- finalRdd : RDD [T ],
303+ rdd : RDD [T ],
292304 func : (TaskContext , Iterator [T ]) => U ,
293305 partitions : Seq [Int ],
294306 callSite : String ,
295307 allowLocal : Boolean ,
296308 resultHandler : (Int , U ) => Unit ,
297309 properties : Properties = null )
298310 {
299- if (partitions.size == 0 ) {
300- return
301- }
302-
303- // Check to make sure we are not launching a task on a partition that does not exist.
304- val maxPartitions = finalRdd.partitions.length
305- partitions.find(p => p >= maxPartitions).foreach { p =>
306- throw new IllegalArgumentException (
307- " Attempting to access a non-existent partition: " + p + " . " +
308- " Total number of partitions: " + maxPartitions)
309- }
310-
311- val (toSubmit : JobSubmitted , waiter : JobWaiter [_]) = prepareJob(
312- finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
313- eventQueue.put(toSubmit)
311+ val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
314312 waiter.awaitResult() match {
315313 case JobSucceeded => {}
316314 case JobFailed (exception : Exception , _) =>
@@ -331,45 +329,50 @@ class DAGScheduler(
331329 val listener = new ApproximateActionListener (rdd, func, evaluator, timeout)
332330 val func2 = func.asInstanceOf [(TaskContext , Iterator [_]) => _]
333331 val partitions = (0 until rdd.partitions.size).toArray
334- eventQueue.put(JobSubmitted (rdd, func2, partitions, allowLocal = false , callSite, listener, properties))
332+ val jobId = nextJobId.getAndIncrement()
333+ eventQueue.put(JobSubmitted (jobId, rdd, func2, partitions, allowLocal = false , callSite,
334+ listener, properties))
335335 listener.awaitResult() // Will throw an exception if the job fails
336336 }
337337
338+ /**
339+ * Kill a job that is running or waiting in the queue.
340+ */
338341 def killJob (jobId : Int ): Unit = this .synchronized {
339342 activeJobs.find(job => job.jobId == jobId).foreach(job => killJob(job))
340- }
341343
342- private def killJob (job : ActiveJob ): Unit = this .synchronized {
343- logInfo(" Killing Job and cleaning up stages %d" .format(job.jobId))
344- activeJobs.remove(job)
345- idToActiveJob.remove(job.jobId)
346- val stage = job.finalStage
347- resultStageToJob.remove(stage)
348- killStage(job, stage)
349- val e = new SparkException (" Job killed" )
350- job.listener.jobFailed(e)
351- listenerBus.post(SparkListenerJobEnd (job, JobFailed (e, None )))
352- }
353-
354- private def killStage (job : ActiveJob , stage : Stage ): Unit = this .synchronized {
355- // TODO: Can we reuse taskSetFailed?
356- logInfo(" Killing Stage %s" .format(stage.id))
357- stageIdToStage.remove(stage.id)
358- if (stage.isShuffleMap) {
359- shuffleToMapStage.remove(stage.id)
360- }
361- waiting.remove(stage)
362- pendingTasks.remove(stage)
363- taskSched.killTasks(stage.id)
364-
365- if (running.contains(stage)) {
366- running.remove(stage)
344+ def killJob (job : ActiveJob ): Unit = this .synchronized {
345+ logInfo(" Killing Job and cleaning up stages %d" .format(job.jobId))
346+ activeJobs.remove(job)
347+ idToActiveJob.remove(job.jobId)
348+ val stage = job.finalStage
349+ resultStageToJob.remove(stage)
350+ killStage(job, stage)
367351 val e = new SparkException (" Job killed" )
368- listenerBus.post(SparkListenerJobEnd (job, JobFailed (e, Some (stage))))
352+ job.listener.jobFailed(e)
353+ listenerBus.post(SparkListenerJobEnd (job, JobFailed (e, None )))
369354 }
370355
371- stage.parents.foreach(parentStage => killStage(job, parentStage))
372- // stageToInfos -= stage
356+ def killStage (job : ActiveJob , stage : Stage ): Unit = this .synchronized {
357+ // TODO: Can we reuse taskSetFailed?
358+ logInfo(" Killing Stage %s" .format(stage.id))
359+ stageIdToStage.remove(stage.id)
360+ if (stage.isShuffleMap) {
361+ shuffleToMapStage.remove(stage.id)
362+ }
363+ waiting.remove(stage)
364+ pendingTasks.remove(stage)
365+ taskSched.killTasks(stage.id)
366+
367+ if (running.contains(stage)) {
368+ running.remove(stage)
369+ val e = new SparkException (" Job killed" )
370+ listenerBus.post(SparkListenerJobEnd (job, JobFailed (e, Some (stage))))
371+ }
372+
373+ stage.parents.foreach(parentStage => killStage(job, parentStage))
374+ // stageToInfos -= stage
375+ }
373376 }
374377
375378 /**
@@ -378,9 +381,8 @@ class DAGScheduler(
378381 */
379382 private [scheduler] def processEvent (event : DAGSchedulerEvent ): Boolean = {
380383 event match {
381- case JobSubmitted (finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
382- val jobId = nextJobId.getAndIncrement()
383- val finalStage = newStage(finalRDD, None , jobId, Some (callSite))
384+ case JobSubmitted (jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
385+ val finalStage = newStage(rdd, None , jobId, Some (callSite))
384386 val job = new ActiveJob (jobId, finalStage, func, partitions, callSite, listener, properties)
385387 clearCacheLocs()
386388 logInfo(" Got job " + job.jobId + " (" + callSite + " ) with " + partitions.length +
0 commit comments