diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 7d96962c4acd7..ce468d197e112 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -189,6 +189,76 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
     }
   }
 
+  def getUpdatedStatus(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
+    val statuses = mapStatuses.get(shuffleId).orNull
+    if (statuses == null) {
+      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
+      var fetchedStatuses: Array[MapStatus] = null
+      fetching.synchronized {
+        if (fetching.contains(shuffleId)) {
+          // Someone else is fetching it; wait for them to be done
+          while (fetching.contains(shuffleId)) {
+            try {
+              fetching.wait()
+            } catch {
+              case e: InterruptedException =>
+            }
+          }
+        }
+
+        // Either while we waited the fetch happened successfully, or
+        // someone fetched it in between the get and the fetching.synchronized.
+        fetchedStatuses = mapStatuses.get(shuffleId).orNull
+        if (fetchedStatuses == null) {
+          // We have to do the fetch, get others to wait for us.
+          fetching += shuffleId
+        }
+      }
+
+      if (fetchedStatuses == null) {
+        // We won the race to fetch the output locs; do so
+        logInfo("Doing the fetch; tracker actor = " + trackerActor)
+        // This try-finally prevents hangs due to timeouts:
+        try {
+          val fetchedBytes =
+            askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
+          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
+          logInfo("Got the output locations")
+        } finally {
+          fetching.synchronized {
+            fetching -= shuffleId
+            fetching.notifyAll()
+          }
+        }
+      }
+      if (fetchedStatuses != null) {
+        var isMapFinished: Boolean = true
+        fetchedStatuses.synchronized {
+          val statuses: Array[(BlockManagerId, Long)] = fetchedStatuses.map { status =>
+            if (status == null) {
+              isMapFinished = false
+              null
+            } else {
+              (status.location, status.getSizeForBlock(reduceId))
+            }
+          }
+          if (isMapFinished) {
+            mapStatuses.put(shuffleId, fetchedStatuses)
+          }
+          statuses
+        }
+      } else {
+        throw new MetadataFetchFailedException(
+          shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
+      }
+    } else {
+      statuses.synchronized {
+        return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
+      }
+    }
+  }
+
   /** Called to get current epoch number. */
   def getEpoch: Long = {
     epochLock.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 22449517d100f..d061439481ff2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -112,6 +112,10 @@ class DAGScheduler(
   // stray messages to detect.
   private val failedEpoch = new HashMap[String, Long]
 
+  val removeStageBarrier = sc.getConf.getBoolean("spark.scheduler.removeStageBarrier", false)
+  // Track the pre-started stages that depend on a given stage (the key)
+  private val dependantStagePreStarted = new HashMap[Stage, ArrayBuffer[Stage]]()
+
   private val dagSchedulerActorSupervisor =
     env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))
 
@@ -901,6 +905,52 @@ class DAGScheduler(
     }
   }
 
+  // Select a waiting stage to pre-start
+  private def getPreStartableStage(stage: Stage): Option[Stage] = {
+    for (waitingStage <- waitingStages) {
+      val missingParents = getMissingParentStages(waitingStage)
+      if (missingParents.contains(stage)) {
+        for (parent <- missingParents) {
+          if (!(waitingStages.contains(parent) || failedStages.contains(parent) ||
+              parent.pendingTasks.size > 0 || parent.rdd.getStorageLevel != StorageLevel.NONE)) {
+            return Some(waitingStage)
+          }
+        }
+      }
+    }
+    None
+  }
+
+  private def maybePreStartWaitingStage(stage: Stage) {
+    if (removeStageBarrier && taskScheduler.isInstanceOf[TaskSchedulerImpl]) {
+      val backend = taskScheduler.asInstanceOf[TaskSchedulerImpl].backend
+      var numPendingTask: Int = 0
+      runningStages.foreach { runningStage =>
+        numPendingTask += runningStage.pendingTasks.size
+      }
+      val numWaitingStage = waitingStages.size
+      if (backend.freeSlotAvail(numPendingTask) && numWaitingStage > 0 &&
+          stage.shuffleDep.isDefined) {
+        for (preStartStage <- getPreStartableStage(stage)) {
+          logInfo("Pre-start stage " + preStartStage.id)
+          // Register the map outputs finished so far
+          mapOutputTracker.registerMapOutputs(stage.shuffleDep.get.shuffleId,
+            stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
+            changeEpoch = false)
+          waitingStages -= preStartStage
+          runningStages += preStartStage
+          // Inform parent stages that the dependant stage has been pre-started
+          for (parentStage <- getMissingParentStages(preStartStage)
+              if runningStages.contains(parentStage)) {
+            dependantStagePreStarted.getOrElseUpdate(
+              parentStage, new ArrayBuffer[Stage]()) += preStartStage
+          }
+          submitMissingTasks(preStartStage, activeJobForStage(preStartStage).get)
+        }
+      }
+    }
+  }
+
   /**
    * Responds to a task finishing. This is called inside the event loop so it assumes that it can
    * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
@@ -1002,6 +1052,13 @@ class DAGScheduler(
             } else {
               stage.addOutputLoc(smt.partitionId, status)
             }
+            // Register map outputs progressively when stage-barrier removal is enabled
+            if (removeStageBarrier && dependantStagePreStarted.contains(stage) &&
+                stage.shuffleDep.isDefined) {
+              mapOutputTracker.registerMapOutputs(stage.shuffleDep.get.shuffleId,
+                stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
+                changeEpoch = false)
+            }
             if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) {
               markStageAsFinished(stage)
               logInfo("looking for newly runnable stages")
@@ -1046,6 +1103,9 @@ class DAGScheduler(
                   submitMissingTasks(stage, jobId)
                 }
               }
+              dependantStagePreStarted -= stage
+            } else if (removeStageBarrier) {
+              maybePreStartWaitingStage(stage)
             }
         }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index 992c477493d8e..69021957a8508 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -34,6 +34,8 @@ private[spark] trait SchedulerBackend {
     throw new UnsupportedOperationException
   def isReady(): Boolean = true
 
+  def freeSlotAvail(numPendingTask: Int): Boolean = false
+
   /**
    * Get an application ID associated with the job.
    *
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 047fae104b485..18affb0451c53 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -348,6 +348,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
    */
   protected def doKillExecutors(executorIds: Seq[String]): Boolean = false
 
+  override def freeSlotAvail(numPendingTask: Int): Boolean = {
+    numPendingTask * scheduler.CPUS_PER_TASK < totalCoreCount.get()
+  }
+
 }
 
 private[spark] object CoarseGrainedSchedulerBackend {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index e3e7434df45b0..af9247b39c838 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -24,7 +24,7 @@ import scala.util.{Failure, Success, Try}
 import org.apache.spark._
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
+import org.apache.spark.storage.{BlockId, BlockManagerId, PartialBlockFetcherIterator, ShuffleBlockFetcherIterator, ShuffleBlockId}
 import org.apache.spark.util.CompletionIterator
 
 private[hash] object BlockStoreShuffleFetcher extends Logging {
@@ -39,18 +39,39 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
     val blockManager = SparkEnv.get.blockManager
 
     val startTime = System.currentTimeMillis
-    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
-    logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
-      shuffleId, reduceId, System.currentTimeMillis - startTime))
-
-    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
-    for (((address, size), index) <- statuses.zipWithIndex) {
-      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
-    }
-
-    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
-      case (address, splits) =>
-        (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
+    var statuses: Array[(BlockManagerId, Long)] = null
+    val blockFetcherItr = if (blockManager.conf.getBoolean("spark.scheduler.removeStageBarrier",
+        false)) {
+      statuses = SparkEnv.get.mapOutputTracker.getUpdatedStatus(shuffleId, reduceId)
+      logDebug("Fetching partial output for shuffle %d, reduce %d took %d ms".format(
+        shuffleId, reduceId, System.currentTimeMillis - startTime))
+      new PartialBlockFetcherIterator(
+        context,
+        SparkEnv.get.blockManager.shuffleClient,
+        blockManager,
+        statuses,
+        serializer,
+        shuffleId,
+        reduceId)
+    } else {
+      statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
+      logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
+        shuffleId, reduceId, System.currentTimeMillis - startTime))
+      val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
+      for (((address, size), index) <- statuses.zipWithIndex) {
+        splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
+      }
+      val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
+        case (address, splits) =>
+          (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
+      }
+      new ShuffleBlockFetcherIterator(
+        context,
+        SparkEnv.get.blockManager.shuffleClient,
+        blockManager,
+        blocksByAddress,
+        serializer,
+        SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)
     }
 
     def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
@@ -73,13 +94,6 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       }
     }
 
-    val blockFetcherItr = new ShuffleBlockFetcherIterator(
-      context,
-      SparkEnv.get.blockManager.shuffleClient,
-      blockManager,
-      blocksByAddress,
-      serializer,
-      SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)
     val itr = blockFetcherItr.flatMap(unpackBlock)
 
     val completionIter = CompletionIterator[T, Iterator[T]](itr, {
diff --git a/core/src/main/scala/org/apache/spark/storage/PartialBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/PartialBlockFetcherIterator.scala
new file mode 100644
index 0000000000000..9ed3ea65dd412
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/PartialBlockFetcherIterator.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.util.Try
+
+import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.network.shuffle.ShuffleClient
+import org.apache.spark.serializer.Serializer
+
+private[spark]
+class PartialBlockFetcherIterator(
+    context: TaskContext,
+    shuffleClient: ShuffleClient,
+    blockManager: BlockManager,
+    var statuses: Array[(BlockManagerId, Long)],
+    serializer: Serializer,
+    shuffleId: Int,
+    reduceId: Int)
+  extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
+
+  private val mapOutputFetchInterval =
+    SparkEnv.get.conf.getInt("spark.reducer.mapOutput.fetchInterval", 1000)
+
+  private var iterator: Iterator[(BlockId, Try[Iterator[Any]])] = null
+
+  // Track the map outputs we've delegated
+  private val delegatedStatuses = new HashSet[Int]()
+
+  private var fetchTime: Int = 1
+
+  initialize()
+
+  // Fetch the updated map output statuses from the tracker
+  private def updateStatuses() {
+    fetchTime += 1
+    logDebug("Still missing " + statuses.filter(_ == null).size + " map outputs for reduce " +
+      reduceId + " of shuffle " + shuffleId + " next fetchTime=" + fetchTime)
+    val update = SparkEnv.get.mapOutputTracker.getUpdatedStatus(shuffleId, reduceId)
+    statuses = update
+  }
+
+  private def readyStatuses = (0 until statuses.size).filter(statuses(_) != null)
+
+  // Check whether new map outputs have become available
+  private def newStatusesReady = readyStatuses.exists(!delegatedStatuses.contains(_))
+
+  private def getIterator() = {
+    while (!newStatusesReady) {
+      Thread.sleep(mapOutputFetchInterval)
+      updateStatuses()
+    }
+    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
+    for (index <- readyStatuses if !delegatedStatuses.contains(index)) {
+      splitsByAddress.getOrElseUpdate(statuses(index)._1, ArrayBuffer()) +=
+        ((index, statuses(index)._2))
+      delegatedStatuses += index
+    }
+    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
+      case (address, splits) =>
+        (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
+    }
+    logDebug("Delegating " + blocksByAddress.map(_._2.size).sum +
+      " blocks to a new iterator for reduce " + reduceId + " of shuffle " + shuffleId)
+    val blockFetcherItr = new ShuffleBlockFetcherIterator(
+      context,
+      SparkEnv.get.blockManager.shuffleClient,
+      blockManager,
+      blocksByAddress,
+      serializer,
+      SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)
+    blockFetcherItr
+  }
+
+  private[this] def initialize() {
+    iterator = getIterator()
+  }
+
+  override def hasNext: Boolean = {
+    // First, see if the delegated iterators have more blocks for us
+    if (iterator.hasNext) {
+      return true
+    }
+    // If we have blocks not delegated yet, try to delegate them to a new iterator
+    // and depend on the iterator to tell us if there are valid blocks.
+    while (delegatedStatuses.size < statuses.size) {
+      iterator = getIterator()
+      if (iterator.hasNext) {
+        return true
+      }
+    }
+    false
+  }
+
+  override def next(): (BlockId, Try[Iterator[Any]]) = {
+    iterator.next()
+  }
+}
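
Usage note (not part of the patch): the sketch below is a minimal, illustrative driver for exercising this change. Only the configuration keys "spark.scheduler.removeStageBarrier" and "spark.reducer.mapOutput.fetchInterval" come from the diff above; the application name, job shape, and values are assumptions chosen to produce a simple two-stage shuffle whose reduce stage can be pre-started while map outputs are still being registered.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object RemoveStageBarrierExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("remove-stage-barrier-demo")              // illustrative name
      .set("spark.scheduler.removeStageBarrier", "true")    // key introduced by this patch
      .set("spark.reducer.mapOutput.fetchInterval", "1000") // poll interval in ms, read by PartialBlockFetcherIterator
    val sc = new SparkContext(conf)

    // reduceByKey introduces a shuffle, so the job has a map stage and a reduce stage.
    // With the barrier removed, the reduce stage may be pre-started and its tasks fetch
    // map outputs incrementally through PartialBlockFetcherIterator.
    val counts = sc.parallelize(1 to 1000000, 100)
      .map(i => (i % 100, 1L))
      .reduceByKey(_ + _)
      .collect()

    println("Distinct keys: " + counts.length)
    sc.stop()
  }
}

Running the same job with spark.scheduler.removeStageBarrier left at its default of false falls back to the existing getServerStatuses / ShuffleBlockFetcherIterator path, which is the easiest way to compare the two behaviours.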