From 9e12f59d4e184094d655981342b3610cf0cf616e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 May 2017 19:29:59 -0700 Subject: [PATCH 01/14] WIP towards changing map output tracker internals --- .../org/apache/spark/MapOutputTracker.scala | 270 +++++++----------- .../apache/spark/scheduler/DAGScheduler.scala | 14 +- 2 files changed, 105 insertions(+), 179 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4ef665622245..d92b140e5e48 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -34,6 +34,53 @@ import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ +private class ShuffleStatus(mapStatuses: Array[MapStatus]) { + + private[this] var cachedSerializedMapStatus: Array[Byte] = _ + private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ + + def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized { + mapStatuses(mapId) = status + removeBroadcast() + } + + def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized { + if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) { + mapStatuses(mapId) = null + } + removeBroadcast() + } + + def serializedMapStatus( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int): Array[Byte] = synchronized { + if (cachedSerializedMapStatus eq null) { + val serResult = MapOutputTracker.serializeMapStatuses( + mapStatuses, broadcastManager, isLocal, minBroadcastSize) + cachedSerializedMapStatus = serResult._1 + cachedSerializedBroadcast = serResult._2 + } + cachedSerializedMapStatus + } + + def withStatuses[T](f: Array[MapStatus] => T): T = synchronized { + f(mapStatuses) + } + + def removeBroadcast(): Unit = synchronized { + if (cachedSerializedBroadcast != null) { + cachedSerializedBroadcast.destroy() + cachedSerializedBroadcast = null + } + cachedSerializedMapStatus = null + } +} + +private object ShuffleStatus { + def empty(numMaps: Int): ShuffleStatus = new ShuffleStatus(new Array[MapStatus](numMaps)) +} + private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage @@ -78,10 +125,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the * driver's corresponding HashMap. * - * Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a + * Note: because shuffleStatuses is accessed concurrently, subclasses should make sure it's a * thread-safe map. 
*/ - protected val mapStatuses: Map[Int, Array[MapStatus]] + protected val shuffleStatuses: Map[Int, ShuffleStatus] /** * Incremented every time a fetch fails so that client nodes know to clear @@ -141,42 +188,20 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - val statuses = getStatuses(shuffleId) - // Synchronize on the returned array because, on the driver, it gets mutated in place - statuses.synchronized { + getShuffleStatus(shuffleId).withStatuses { statuses => return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } } /** - * Return statistics about all of the outputs for a given shuffle. - */ - def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { - val statuses = getStatuses(dep.shuffleId) - // Synchronize on the returned array because, on the driver, it gets mutated in place - statuses.synchronized { - val totalSizes = new Array[Long](dep.partitioner.numPartitions) - for (s <- statuses) { - for (i <- 0 until totalSizes.length) { - totalSizes(i) += s.getSizeForBlock(i) - } - } - new MapOutputStatistics(dep.shuffleId, totalSizes) - } - } - - /** - * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize - * on this array when reading it, because on the driver, we may be changing it in place. - * - * (It would be nice to remove this restriction in the future.) + * Get or fetch ShuffleStatus for a given shuffle ID. */ - private def getStatuses(shuffleId: Int): Array[MapStatus] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { + private def getShuffleStatus(shuffleId: Int): ShuffleStatus = { + val shuffleStatus = shuffleStatuses.get(shuffleId).orNull + if (shuffleStatus == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") val startTime = System.currentTimeMillis - var fetchedStatuses: Array[MapStatus] = null + var fetchedStatus: ShuffleStatus = null fetching.synchronized { // Someone else is fetching it; wait for them to be done while (fetching.contains(shuffleId)) { @@ -189,22 +214,22 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // Either while we waited the fetch happened successfully, or // someone fetched it in between the get and the fetching.synchronized. - fetchedStatuses = mapStatuses.get(shuffleId).orNull - if (fetchedStatuses == null) { + fetchedStatus = shuffleStatuses.get(shuffleId).orNull + if (fetchedStatus == null) { // We have to do the fetch, get others to wait for us. 
fetching += shuffleId } } - if (fetchedStatuses == null) { + if (fetchedStatus == null) { // We won the race to fetch the statuses; do so logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) // This try-finally prevents hangs due to timeouts: try { val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) - fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) + fetchedStatus = new ShuffleStatus(MapOutputTracker.deserializeMapStatuses(fetchedBytes)) logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) + shuffleStatuses.put(shuffleId, fetchedStatus) } finally { fetching.synchronized { fetching -= shuffleId @@ -215,15 +240,15 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + s"${System.currentTimeMillis - startTime} ms") - if (fetchedStatuses != null) { - return fetchedStatuses + if (fetchedStatus != null) { + fetchedStatus } else { logError("Missing all output locations for shuffle " + shuffleId) throw new MetadataFetchFailedException( shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) } } else { - return statuses + shuffleStatus } } @@ -244,14 +269,14 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging if (newEpoch > epoch) { logInfo("Updating epoch to " + newEpoch + " and clearing cache") epoch = newEpoch - mapStatuses.clear() + shuffleStatuses.clear() } } } /** Unregister shuffle data. */ def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) + shuffleStatuses.remove(shuffleId) } /** Stop the tracker. */ @@ -265,9 +290,6 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, broadcastManager: BroadcastManager, isLocal: Boolean) extends MapOutputTracker(conf) { - /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ - private var cacheEpoch = epoch - // The size at which we use Broadcast to send the map output statuses to the executors private val minSizeForBroadcast = conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt @@ -287,22 +309,12 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, // can be read locally, but may lead to more delay in scheduling if those locations are busy. private val REDUCER_PREF_LOCS_FRACTION = 0.2 - // HashMaps for storing mapStatuses and cached serialized statuses in the driver. + // HashMap for storing shuffleStatuses in the driver. // Statuses are dropped only by explicit de-registering. - protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala - private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala + protected val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) - // Kept in sync with cachedSerializedStatuses explicitly - // This is required so that the Broadcast variable remains in scope until we remove - // the shuffleId explicitly or implicitly. - private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]() - - // This is to prevent multiple serializations of the same shuffle - which happens when - // there is a request storm when shuffle start. 
- private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]() - // requests for map output statuses private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] @@ -348,8 +360,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, val hostPort = context.senderAddress.hostPort logDebug("Handling request to send map output locations for shuffle " + shuffleId + " to " + hostPort) - val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId) - context.reply(mapOutputStatuses) + val shuffleStatus = shuffleStatuses.get(shuffleId).head + context.reply( + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast)) } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -364,26 +377,21 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, private val PoisonPill = new GetMapOutputMessage(-99, null) // Exposed for testing - private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size + private[spark] def getNumCachedSerializedBroadcast = 0 // TODO(josh) def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + if (shuffleStatuses.put(shuffleId, ShuffleStatus.empty(numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - // add in advance - shuffleIdLocks.putIfAbsent(shuffleId, new Object()) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - val array = mapStatuses(shuffleId) - array.synchronized { - array(mapId) = status - } + shuffleStatuses(shuffleId).addMapOutput(mapId, status) } /** Register multiple map output information for the given shuffle */ def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { - mapStatuses.put(shuffleId, statuses.clone()) + shuffleStatuses.put(shuffleId, new ShuffleStatus(statuses.clone())) if (changeEpoch) { incrementEpoch() } @@ -391,31 +399,38 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - val arrayOpt = mapStatuses.get(shuffleId) - if (arrayOpt.isDefined && arrayOpt.get != null) { - val array = arrayOpt.get - array.synchronized { - if (array(mapId) != null && array(mapId).location == bmAddress) { - array(mapId) = null - } - } - incrementEpoch() - } else { - throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeMapOutput(mapId, bmAddress) + incrementEpoch() + case None => + throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") } } /** Unregister shuffle data */ override def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) - cachedSerializedStatuses.remove(shuffleId) - cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v)) - shuffleIdLocks.remove(shuffleId) + shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => + shuffleStatus.removeBroadcast() + } } /** Check if the given shuffle is being tracked */ - def containsShuffle(shuffleId: Int): Boolean = { - cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) + def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId) + + /** + * Return statistics about all of the outputs for a given shuffle. 
+ */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + shuffleStatuses(dep.shuffleId).withStatuses { statuses => + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } } /** @@ -459,9 +474,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, fractionThreshold: Double) : Option[Array[BlockManagerId]] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses != null) { - statuses.synchronized { + val shuffleStatus = shuffleStatuses.get(shuffleId).orNull + if (shuffleStatus != null) { + shuffleStatus.withStatuses { statuses => if (statuses.nonEmpty) { // HashMap to add up sizes of all blocks at the same location val locs = new HashMap[BlockManagerId, Long] @@ -502,89 +517,12 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, } } - private def removeBroadcast(bcast: Broadcast[_]): Unit = { - if (null != bcast) { - broadcastManager.unbroadcast(bcast.id, - removeFromDriver = true, blocking = false) - } - } - - private def clearCachedBroadcast(): Unit = { - for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2) - cachedSerializedBroadcast.clear() - } - - def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { - var statuses: Array[MapStatus] = null - var retBytes: Array[Byte] = null - var epochGotten: Long = -1 - - // Check to see if we have a cached version, returns true if it does - // and has side effect of setting retBytes. If not returns false - // with side effect of setting statuses - def checkCachedStatuses(): Boolean = { - epochLock.synchronized { - if (epoch > cacheEpoch) { - cachedSerializedStatuses.clear() - clearCachedBroadcast() - cacheEpoch = epoch - } - cachedSerializedStatuses.get(shuffleId) match { - case Some(bytes) => - retBytes = bytes - true - case None => - logDebug("cached status not found for : " + shuffleId) - statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus]) - epochGotten = epoch - false - } - } - } - - if (checkCachedStatuses()) return retBytes - var shuffleIdLock = shuffleIdLocks.get(shuffleId) - if (null == shuffleIdLock) { - val newLock = new Object() - // in general, this condition should be false - but good to be paranoid - val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock) - shuffleIdLock = if (null != prevLock) prevLock else newLock - } - // synchronize so we only serialize/broadcast it once since multiple threads call - // in parallel - shuffleIdLock.synchronized { - // double check to make sure someone else didn't serialize and cache the same - // mapstatus while we were waiting on the synchronize - if (checkCachedStatuses()) return retBytes - - // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "statuses"; let's serialize and return that - val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager, - isLocal, minSizeForBroadcast) - logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) - // Add them into the table only if the epoch hasn't changed while we were working - epochLock.synchronized { - if (epoch == epochGotten) { - cachedSerializedStatuses(shuffleId) = bytes - if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast - } else { - logInfo("Epoch changed, not caching!") - removeBroadcast(bcast) - } - } - 
bytes - } - } - override def stop() { mapOutputRequests.offer(PoisonPill) threadpool.shutdown() sendTracker(StopMapOutputTracker) - mapStatuses.clear() + shuffleStatuses.clear() trackerEndpoint = null - cachedSerializedStatuses.clear() - clearCachedBroadcast() - shuffleIdLocks.clear() } } @@ -593,8 +531,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, * MapOutputTrackerMaster. */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { - protected val mapStatuses: Map[Int, Array[MapStatus]] = - new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + protected val shuffleStatuses: Map[Int, ShuffleStatus] = + new ConcurrentHashMap[Int, ShuffleStatus]().asScala } private[spark] object MapOutputTracker extends Logging { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 68178c7fb3bb..ac2e378bc5c3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -334,19 +334,7 @@ class DAGScheduler( shuffleIdToMapStage(shuffleDep.shuffleId) = stage updateJobIdStageIdMaps(jobId, stage) - if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { - // A previously run stage generated partitions for this shuffle, so for each output - // that's still available, copy information about that output location to the new stage - // (so we don't unnecessarily re-compute that data). - val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) - val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - (0 until locs.length).foreach { i => - if (locs(i) ne null) { - // locs(i) will be null if missing - stage.addOutputLoc(i, locs(i)) - } - } - } else { + if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") From 52be832d0847443ba7ebed1bd2f47ed4bf545678 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 May 2017 23:05:20 -0700 Subject: [PATCH 02/14] Separate driver and executor impls. --- .../org/apache/spark/MapOutputTracker.scala | 370 +++++++++++------- .../org/apache/spark/executor/Executor.scala | 6 +- .../apache/spark/scheduler/DAGScheduler.scala | 23 +- .../spark/scheduler/ShuffleMapStage.scala | 72 +--- .../spark/scheduler/TaskSchedulerImpl.scala | 2 +- .../scala/org/apache/spark/ShuffleSuite.scala | 2 +- .../scheduler/BlacklistTrackerSuite.scala | 3 +- 7 files changed, 238 insertions(+), 240 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index d92b140e5e48..e1296cfe00c7 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -34,23 +34,84 @@ import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ -private class ShuffleStatus(mapStatuses: Array[MapStatus]) { +private class ShuffleStatus(numPartitions: Int) { + /** + * List of [[MapStatus]] for each partition. 
The index of the array is the map partition id, + * and each value in the array is the list of possible [[MapStatus]] for a partition + * (a single task might run multiple times). + */ + private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) private[this] var cachedSerializedMapStatus: Array[Byte] = _ private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ + private[this] var _numAvailableOutputs: Int = 0 def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized { - mapStatuses(mapId) = status + val prevList = outputLocs(mapId) + outputLocs(mapId) = status :: prevList + if (prevList == Nil) { + _numAvailableOutputs += 1 + } removeBroadcast() } def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized { - if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) { - mapStatuses(mapId) = null + val prevList = outputLocs(mapId) + val newList = prevList.filterNot(_.location == bmAddress) + outputLocs(mapId) = newList + if (prevList != Nil && newList == Nil) { + _numAvailableOutputs -= 1 + } + removeBroadcast() + } + + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = synchronized { + var becameUnavailable = false + for (partition <- 0 until outputLocs.length) { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.location.executorId == execId) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + becameUnavailable = true + _numAvailableOutputs -= 1 + } } removeBroadcast() +// if (becameUnavailable) { +// logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( +// this, execId, _numAvailableOutputs, numPartitions, isAvailable)) +// } + } + + /** + * Number of partitions that have shuffle outputs. + * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`. + */ + def numAvailableOutputs: Int = synchronized { + _numAvailableOutputs } + /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ + def findMissingPartitions(): Seq[Int] = synchronized { + val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty) + val _numAvailableOutputs = numAvailableOutputs + assert(missing.size == numPartitions - _numAvailableOutputs, + s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") + missing + } + + /** + * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned + * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition, + * that position is filled with null. 
+ */ + private def mapStatuses: Array[MapStatus] = outputLocs.map(_.headOption.orNull) + def serializedMapStatus( broadcastManager: BroadcastManager, isLocal: Boolean, @@ -64,6 +125,7 @@ private class ShuffleStatus(mapStatuses: Array[MapStatus]) { cachedSerializedMapStatus } + // TODO(josh): we can reduce the number of places this is called in MapOutputTrackerMaster def withStatuses[T](f: Array[MapStatus] => T): T = synchronized { f(mapStatuses) } @@ -77,10 +139,6 @@ private class ShuffleStatus(mapStatuses: Array[MapStatus]) { } } -private object ShuffleStatus { - def empty(numMaps: Int): ShuffleStatus = new ShuffleStatus(new Array[MapStatus](numMaps)) -} - private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage @@ -109,27 +167,16 @@ private[spark] class MapOutputTrackerMasterEndpoint( } /** - * Class that keeps track of the location of the map output of - * a stage. This is abstract because different versions of MapOutputTracker - * (driver and executor) use different HashMap to store its metadata. - */ + * Class that keeps track of the location of the map output of a stage. This is abstract because the + * driver and executor have different versions of the MapOutputTracker. In principle the driver- + * and executor-side classes don't need to share a common base class; the current shared base class + * is maintained primarily for backwards-compatibility in order to avoid having to update existing + * test code. +*/ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */ var trackerEndpoint: RpcEndpointRef = _ - /** - * This HashMap has different behavior for the driver and the executors. - * - * On the driver, it serves as the source of map outputs recorded from ShuffleMapTasks. - * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the - * driver's corresponding HashMap. - * - * Note: because shuffleStatuses is accessed concurrently, subclasses should make sure it's a - * thread-safe map. - */ - protected val shuffleStatuses: Map[Int, ShuffleStatus] - /** * Incremented every time a fetch fails so that client nodes know to clear * their cache of map output locations if this happens. @@ -137,9 +184,6 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging protected var epoch: Long = 0 protected val epochLock = new AnyRef - /** Remembers which map output locations are currently being fetched on an executor. */ - private val fetching = new HashSet[Int] - /** * Send a message to the trackerEndpoint and get its result within a default timeout, or * throw a SparkException if this fails. @@ -163,16 +207,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** - * Called from executors to get the server URIs and output sizes for each shuffle block that - * needs to be read from a given reduce task. - * - * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, - * and the second item is a sequence of (shuffle block id, shuffle block size) tuples - * describing the shuffle blocks that are stored at that block manager. 
- */ + // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) } @@ -186,108 +223,20 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { - logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - getShuffleStatus(shuffleId).withStatuses { statuses => - return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) - } - } - - /** - * Get or fetch ShuffleStatus for a given shuffle ID. - */ - private def getShuffleStatus(shuffleId: Int): ShuffleStatus = { - val shuffleStatus = shuffleStatuses.get(shuffleId).orNull - if (shuffleStatus == null) { - logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") - val startTime = System.currentTimeMillis - var fetchedStatus: ShuffleStatus = null - fetching.synchronized { - // Someone else is fetching it; wait for them to be done - while (fetching.contains(shuffleId)) { - try { - fetching.wait() - } catch { - case e: InterruptedException => - } - } - - // Either while we waited the fetch happened successfully, or - // someone fetched it in between the get and the fetching.synchronized. - fetchedStatus = shuffleStatuses.get(shuffleId).orNull - if (fetchedStatus == null) { - // We have to do the fetch, get others to wait for us. - fetching += shuffleId - } - } - - if (fetchedStatus == null) { - // We won the race to fetch the statuses; do so - logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) - // This try-finally prevents hangs due to timeouts: - try { - val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) - fetchedStatus = new ShuffleStatus(MapOutputTracker.deserializeMapStatuses(fetchedBytes)) - logInfo("Got the output locations") - shuffleStatuses.put(shuffleId, fetchedStatus) - } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() - } - } - } - logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + - s"${System.currentTimeMillis - startTime} ms") - - if (fetchedStatus != null) { - fetchedStatus - } else { - logError("Missing all output locations for shuffle " + shuffleId) - throw new MetadataFetchFailedException( - shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) - } - } else { - shuffleStatus - } - } - - /** Called to get current epoch number. */ - def getEpoch: Long = { - epochLock.synchronized { - return epoch - } - } - - /** - * Called from executors to update the epoch number, potentially clearing old outputs - * because of a fetch failure. Each executor task calls this with the latest epoch - * number on the driver at the time it was created. - */ - def updateEpoch(newEpoch: Long) { - epochLock.synchronized { - if (newEpoch > epoch) { - logInfo("Updating epoch to " + newEpoch + " and clearing cache") - epoch = newEpoch - shuffleStatuses.clear() - } - } - } + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] - /** Unregister shuffle data. */ - def unregisterShuffle(shuffleId: Int) { - shuffleStatuses.remove(shuffleId) - } + def unregisterShuffle(shuffleId: Int): Unit - /** Stop the tracker. 
*/ - def stop() { } + def stop() {} } /** - * MapOutputTracker for the driver. + * Driver-side class that keeps track of the location of the map output of a stage. */ -private[spark] class MapOutputTrackerMaster(conf: SparkConf, - broadcastManager: BroadcastManager, isLocal: Boolean) +private[spark] class MapOutputTrackerMaster( + conf: SparkConf, + broadcastManager: BroadcastManager, + isLocal: Boolean) extends MapOutputTracker(conf) { // The size at which we use Broadcast to send the map output statuses to the executors @@ -311,7 +260,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, // HashMap for storing shuffleStatuses in the driver. // Statuses are dropped only by explicit de-registering. - protected val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala + private val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) @@ -380,7 +329,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, private[spark] def getNumCachedSerializedBroadcast = 0 // TODO(josh) def registerShuffle(shuffleId: Int, numMaps: Int) { - if (shuffleStatuses.put(shuffleId, ShuffleStatus.empty(numMaps)).isDefined) { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } } @@ -389,14 +338,6 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, shuffleStatuses(shuffleId).addMapOutput(mapId, status) } - /** Register multiple map output information for the given shuffle */ - def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { - shuffleStatuses.put(shuffleId, new ShuffleStatus(statuses.clone())) - if (changeEpoch) { - incrementEpoch() - } - } - /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { shuffleStatuses.get(shuffleId) match { @@ -409,15 +350,34 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, } /** Unregister shuffle data */ - override def unregisterShuffle(shuffleId: Int) { + def unregisterShuffle(shuffleId: Int) { shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => shuffleStatus.removeBroadcast() } } + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = { + shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) } + incrementEpoch() + } + /** Check if the given shuffle is being tracked */ def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId) + def getNumAvailableOutputs(shuffleId: Int): Int = { + shuffleStatuses(shuffleId).numAvailableOutputs + } + + /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ + def findMissingPartitions(shuffleId: Int): Seq[Int] = { + shuffleStatuses(shuffleId).findMissingPartitions() + } + /** * Return statistics about all of the outputs for a given shuffle. */ @@ -517,22 +477,130 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, } } - override def stop() { + /** Called to get current epoch number. 
*/ + def getEpoch: Long = { + epochLock.synchronized { + return epoch + } + } + + def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + shuffleStatuses(shuffleId).withStatuses { statuses => + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + } + } + + override def stop() { mapOutputRequests.offer(PoisonPill) threadpool.shutdown() sendTracker(StopMapOutputTracker) - shuffleStatuses.clear() trackerEndpoint = null + shuffleStatuses.clear() } } /** - * MapOutputTracker for the executors, which fetches map output information from the driver's - * MapOutputTrackerMaster. + * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster. */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { - protected val shuffleStatuses: Map[Int, ShuffleStatus] = - new ConcurrentHashMap[Int, ShuffleStatus]().asScala + + val mapStatuses: Map[Int, Array[MapStatus]] = + new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + + /** Remembers which map output locations are currently being fetched on an executor. */ + private val fetching = new HashSet[Int] + + override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + val statuses = getStatuses(shuffleId) + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + } + + /** + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) + */ + private def getStatuses(shuffleId: Int): Array[MapStatus] = { + val statuses = mapStatuses.get(shuffleId).orNull + if (statuses == null) { + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTime = System.currentTimeMillis + var fetchedStatuses: Array[MapStatus] = null + fetching.synchronized { + // Someone else is fetching it; wait for them to be done + while (fetching.contains(shuffleId)) { + try { + fetching.wait() + } catch { + case e: InterruptedException => + } + } + + // Either while we waited the fetch happened successfully, or + // someone fetched it in between the get and the fetching.synchronized. + fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + // We have to do the fetch, get others to wait for us. 
+ fetching += shuffleId + } + } + + if (fetchedStatuses == null) { + // We won the race to fetch the statuses; do so + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + // This try-finally prevents hangs due to timeouts: + try { + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) + fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) + logInfo("Got the output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } finally { + fetching.synchronized { + fetching -= shuffleId + fetching.notifyAll() + } + } + } + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + + s"${System.currentTimeMillis - startTime} ms") + + if (fetchedStatuses != null) { + fetchedStatuses + } else { + logError("Missing all output locations for shuffle " + shuffleId) + throw new MetadataFetchFailedException( + shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) + } + } else { + statuses + } + } + + + /** Unregister shuffle data. */ + def unregisterShuffle(shuffleId: Int): Unit = { + mapStatuses.remove(shuffleId) + } + + /** + * Called from executors to update the epoch number, potentially clearing old outputs + * because of a fetch failure. Each executor task calls this with the latest epoch + * number on the driver at the time it was created. + */ + def updateEpoch(newEpoch: Long): Unit = { + epochLock.synchronized { + if (newEpoch > epoch) { + logInfo("Updating epoch to " + newEpoch + " and clearing cache") + epoch = newEpoch + mapStatuses.clear() + } + } + } } private[spark] object MapOutputTracker extends Logging { @@ -621,7 +689,7 @@ private[spark] object MapOutputTracker extends Logging { * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. */ - private def convertMapStatuses( + def convertMapStatuses( shuffleId: Int, startPartition: Int, endPartition: Int, diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 3bc47b670305..08f6f10b6363 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -322,8 +322,10 @@ private[spark] class Executor( throw new TaskKilledException(killReason.get) } - logDebug("Task " + taskId + "'s epoch is " + task.epoch) - env.mapOutputTracker.updateEpoch(task.epoch) + if (!isLocal) { + logDebug("Task " + taskId + "'s epoch is " + task.epoch) + env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch) + } // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ac2e378bc5c3..4eb28e2be9e0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1204,7 +1204,8 @@ class DAGScheduler( // The epoch of the task is acceptable (i.e., the task was launched after the most // recent failure we're aware of for the executor), so mark the task's output as // available. - shuffleStage.addOutputLoc(smt.partitionId, status) + mapOutputTracker.registerMapOutput( + shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) // Remove the task's partition from pending partitions. 
This may have already been // done above, but will not have been done yet in cases where the task attempt was // from an earlier attempt of the stage (i.e., not the attempt that's currently @@ -1221,16 +1222,13 @@ class DAGScheduler( logInfo("waiting: " + waitingStages) logInfo("failed: " + failedStages) - // We supply true to increment the epoch number here in case this is a + // Increment the epoch number here in case this is a // recomputation of the map outputs. In that case, some nodes may have cached // locations with holes (from when we detected the error) and will need the // epoch incremented to refetch them. // TODO: Only increment the epoch number if this is not the first time // we registered these map outputs. - mapOutputTracker.registerMapOutputs( - shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocInMapOutputTrackerFormat(), - changeEpoch = true) + mapOutputTracker.incrementEpoch() clearCacheLocs() @@ -1330,7 +1328,6 @@ class DAGScheduler( } // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) } @@ -1380,17 +1377,7 @@ class DAGScheduler( if (filesLost || !env.blockManager.externalShuffleServiceEnabled) { logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) - // TODO: This will be really slow if we keep accumulating shuffle map stages - for ((shuffleId, stage) <- shuffleIdToMapStage) { - stage.removeOutputsOnExecutor(execId) - mapOutputTracker.registerMapOutputs( - shuffleId, - stage.outputLocInMapOutputTrackerFormat(), - changeEpoch = true) - } - if (shuffleIdToMapStage.isEmpty) { - mapOutputTracker.incrementEpoch() - } + mapOutputTracker.removeOutputsOnExecutor(execId) clearCacheLocs() } } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index db4d9efa2270..e1000352f122 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -19,9 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashSet -import org.apache.spark.ShuffleDependency +import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency, SparkEnv} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** @@ -47,8 +46,6 @@ private[spark] class ShuffleMapStage( private[this] var _mapStageJobs: List[ActiveJob] = Nil - private[this] var _numAvailableOutputs: Int = 0 - /** * Partitions that either haven't yet been computed, or that were computed on an executor * that has since been lost, so should be re-computed. This variable is used by the @@ -60,13 +57,6 @@ private[spark] class ShuffleMapStage( */ val pendingPartitions = new HashSet[Int] - /** - * List of [[MapStatus]] for each partition. The index of the array is the map partition id, - * and each value in the array is the list of possible [[MapStatus]] for a partition - * (a single task might run multiple times). 
- */ - private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - override def toString: String = "ShuffleMapStage " + id /** @@ -85,72 +75,22 @@ private[spark] class ShuffleMapStage( _mapStageJobs = _mapStageJobs.filter(_ != job) } + private def mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + /** * Number of partitions that have shuffle outputs. * When this reaches [[numPartitions]], this map stage is ready. - * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`. */ - def numAvailableOutputs: Int = _numAvailableOutputs + def numAvailableOutputs: Int = mapOutputTracker.getNumAvailableOutputs(shuffleDep.shuffleId) /** * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs. * This should be the same as `outputLocs.contains(Nil)`. */ - def isAvailable: Boolean = _numAvailableOutputs == numPartitions + def isAvailable: Boolean = numAvailableOutputs == numPartitions /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ override def findMissingPartitions(): Seq[Int] = { - val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty) - assert(missing.size == numPartitions - _numAvailableOutputs, - s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") - missing - } - - def addOutputLoc(partition: Int, status: MapStatus): Unit = { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) { - _numAvailableOutputs += 1 - } - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - _numAvailableOutputs -= 1 - } - } - - /** - * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned - * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition, - * that position is filled with null. - */ - def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = { - outputLocs.map(_.headOption.orNull) - } - - /** - * Removes all shuffle outputs associated with this executor. Note that this will also remove - * outputs which are served by an external shuffle server (if one exists), as they are still - * registered with this execId. 
- */ - def removeOutputsOnExecutor(execId: String): Unit = { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - _numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, _numAvailableOutputs, numPartitions, isAvailable)) - } + mapOutputTracker.findMissingPartitions(shuffleDep.shuffleId) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 1b6bc9139f9c..ae0e6a099186 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -129,7 +129,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( var backend: SchedulerBackend = null - val mapOutputTracker = SparkEnv.get.mapOutputTracker + val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 58b865969f51..192d04884507 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -367,7 +367,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // register one of the map outputs -- doesn't matter which one mapOutput1.foreach { case mapStatus => - mapTrackerMaster.registerMapOutputs(0, Array(mapStatus)) + mapTrackerMaster.registerMapOutput(0, 0, mapStatus) } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 2b18ebee79a2..571c6bbb4585 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -86,7 +86,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M sc = new SparkContext(conf) val scheduler = mock[TaskSchedulerImpl] when(scheduler.sc).thenReturn(sc) - when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker) + when(scheduler.mapOutputTracker).thenReturn( + SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]) scheduler } From 1aa14f895969d252705af15a47cf8bd243da131d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 May 2017 23:11:58 -0700 Subject: [PATCH 03/14] Avoid unnecessary materialization of mapoutputtracker format locs. 
--- .../org/apache/spark/MapOutputTracker.scala | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index e1296cfe00c7..44d925c844f0 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -125,9 +125,8 @@ private class ShuffleStatus(numPartitions: Int) { cachedSerializedMapStatus } - // TODO(josh): we can reduce the number of places this is called in MapOutputTrackerMaster - def withStatuses[T](f: Array[MapStatus] => T): T = synchronized { - f(mapStatuses) + def withOutputLocs[T](f: Array[List[MapStatus]] => T): T = synchronized { + f(outputLocs) } def removeBroadcast(): Unit = synchronized { @@ -382,12 +381,14 @@ private[spark] class MapOutputTrackerMaster( * Return statistics about all of the outputs for a given shuffle. */ def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { - shuffleStatuses(dep.shuffleId).withStatuses { statuses => + shuffleStatuses(dep.shuffleId).withOutputLocs { outputLocs => val totalSizes = new Array[Long](dep.partitioner.numPartitions) - for (s <- statuses) { - for (i <- 0 until totalSizes.length) { - totalSizes(i) += s.getSizeForBlock(i) - } + for ( + mapOutputs <- outputLocs; + s <- mapOutputs.headOption; + i <- 0 until totalSizes.length + ) { + totalSizes(i) += s.getSizeForBlock(i) } new MapOutputStatistics(dep.shuffleId, totalSizes) } @@ -436,14 +437,14 @@ private[spark] class MapOutputTrackerMaster( val shuffleStatus = shuffleStatuses.get(shuffleId).orNull if (shuffleStatus != null) { - shuffleStatus.withStatuses { statuses => + shuffleStatus.withOutputLocs { statuses => if (statuses.nonEmpty) { // HashMap to add up sizes of all blocks at the same location val locs = new HashMap[BlockManagerId, Long] var totalOutputSize = 0L var mapIdx = 0 while (mapIdx < statuses.length) { - val status = statuses(mapIdx) + val status = statuses(mapIdx).headOption.orNull // status may be null here if we are called between registerShuffle, which creates an // array with null entries for each output, and registerMapOutputs, which populates it // with valid status entries. This is possible if one thread schedules a job which @@ -487,7 +488,8 @@ private[spark] class MapOutputTrackerMaster( def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - shuffleStatuses(shuffleId).withStatuses { statuses => + shuffleStatuses(shuffleId).withOutputLocs { outputLocs => + val statuses = outputLocs.map(_.headOption.orNull) MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } } From a4298da66da765a9987bc1d0d794e911c4a45e1f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 May 2017 23:13:20 -0700 Subject: [PATCH 04/14] Fix formatting problem. 
--- core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 44d925c844f0..5f3d2b8ab97b 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -208,7 +208,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) } From 8ed62bdbca293384b55c8eee67ada5fbb9d83481 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 May 2017 23:22:45 -0700 Subject: [PATCH 05/14] Implement getNumCachedSerializedBroadcast --- .../scala/org/apache/spark/MapOutputTracker.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 5f3d2b8ab97b..788707f225c7 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -46,6 +46,10 @@ private class ShuffleStatus(numPartitions: Int) { private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ private[this] var _numAvailableOutputs: Int = 0 + def hasCachedSerializedBroadcast: Boolean = synchronized { + cachedSerializedBroadcast != null + } + def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized { val prevList = outputLocs(mapId) outputLocs(mapId) = status :: prevList @@ -82,10 +86,6 @@ private class ShuffleStatus(numPartitions: Int) { } } removeBroadcast() -// if (becameUnavailable) { -// logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( -// this, execId, _numAvailableOutputs, numPartitions, isAvailable)) -// } } /** @@ -325,7 +325,9 @@ private[spark] class MapOutputTrackerMaster( private val PoisonPill = new GetMapOutputMessage(-99, null) // Exposed for testing - private[spark] def getNumCachedSerializedBroadcast = 0 // TODO(josh) + private[spark] def getNumCachedSerializedBroadcast: Int = { + shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast) + } def registerShuffle(shuffleId: Int, numMaps: Int) { if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { From 5683ec17d6c6df32820e84906107cc838c1166fb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 10 May 2017 23:26:36 -0700 Subject: [PATCH 06/14] Fix getMapSizesByExecutorId in case of no outputs. 
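
(Editorial note in commit-body position.) Indexing shuffleStatuses directly via shuffleStatuses(shuffleId) throws NoSuchElementException when the driver no longer tracks that shuffle; this commit guards the lookup and falls back to an empty result, as in the diff below:

    shuffleStatuses.get(shuffleId) match {
      case Some(shuffleStatus) =>
        shuffleStatus.withOutputLocs { outputLocs =>
          val statuses = outputLocs.map(_.headOption.orNull)
          MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
        }
      case None =>
        Seq.empty  // unknown shuffle: report no map outputs rather than throwing
    }
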
---
 .../scala/org/apache/spark/MapOutputTracker.scala | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 788707f225c7..8f50636f63b9 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -490,9 +490,14 @@ private[spark] class MapOutputTrackerMaster(
   def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
       : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
-    shuffleStatuses(shuffleId).withOutputLocs { outputLocs =>
-      val statuses = outputLocs.map(_.headOption.orNull)
-      MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.withOutputLocs { outputLocs =>
+          val statuses = outputLocs.map(_.headOption.orNull)
+          MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+        }
+      case None =>
+        Seq.empty
     }
   }

From 06ef8d35589901088daa1fd90bd2e8892b8ce7f5 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Thu, 11 May 2017 00:13:13 -0700
Subject: [PATCH 07/14] Get DAGScheduler suite passing.

---
 .../scala/org/apache/spark/MapOutputTracker.scala    | 12 +++++++-----
 .../org/apache/spark/scheduler/DAGScheduler.scala    |  3 ++-
 .../org/apache/spark/scheduler/ShuffleMapStage.scala | 12 ++++++------
 3 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 8f50636f63b9..b69ba4fb8057 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -99,7 +99,6 @@ private class ShuffleStatus(numPartitions: Int) {
   /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */
   def findMissingPartitions(): Seq[Int] = synchronized {
     val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
-    val _numAvailableOutputs = numAvailableOutputs
     assert(missing.size == numPartitions - _numAvailableOutputs,
       s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
     missing
@@ -371,12 +370,15 @@ private[spark] class MapOutputTrackerMaster(
   def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId)

   def getNumAvailableOutputs(shuffleId: Int): Int = {
-    shuffleStatuses(shuffleId).numAvailableOutputs
+    shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0)
   }

-  /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */
-  def findMissingPartitions(shuffleId: Int): Seq[Int] = {
-    shuffleStatuses(shuffleId).findMissingPartitions()
+  /**
+   * Returns the sequence of partition ids that are missing (i.e. needs to be computed), or None
+   * if the MapOutputTrackerMaster doesn't know about this shuffle.
+ */ + def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = { + shuffleStatuses.get(shuffleId).map(_.findMissingPartitions()) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4eb28e2be9e0..f28a5f291d75 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -328,7 +328,8 @@ class DAGScheduler( val numTasks = rdd.partitions.length val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() - val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep) + val stage = new ShuffleMapStage( + id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker) stageIdToStage(id) = stage shuffleIdToMapStage(shuffleDep.shuffleId) = stage diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index e1000352f122..05f650fbf5df 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -41,7 +41,8 @@ private[spark] class ShuffleMapStage( parents: List[Stage], firstJobId: Int, callSite: CallSite, - val shuffleDep: ShuffleDependency[_, _, _]) + val shuffleDep: ShuffleDependency[_, _, _], + mapOutputTrackerMaster: MapOutputTrackerMaster) extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { private[this] var _mapStageJobs: List[ActiveJob] = Nil @@ -75,22 +76,21 @@ private[spark] class ShuffleMapStage( _mapStageJobs = _mapStageJobs.filter(_ != job) } - private def mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] - /** * Number of partitions that have shuffle outputs. * When this reaches [[numPartitions]], this map stage is ready. */ - def numAvailableOutputs: Int = mapOutputTracker.getNumAvailableOutputs(shuffleDep.shuffleId) + def numAvailableOutputs: Int = mapOutputTrackerMaster.getNumAvailableOutputs(shuffleDep.shuffleId) /** * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs. - * This should be the same as `outputLocs.contains(Nil)`. */ def isAvailable: Boolean = numAvailableOutputs == numPartitions /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). 
*/ override def findMissingPartitions(): Seq[Int] = { - mapOutputTracker.findMissingPartitions(shuffleDep.shuffleId) + mapOutputTrackerMaster + .findMissingPartitions(shuffleDep.shuffleId) + .getOrElse(0 until numPartitions) } } From e9caad5246f9d6e81663743ef3de7600a6a65960 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 11 May 2017 00:14:40 -0700 Subject: [PATCH 08/14] formatting --- core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index b69ba4fb8057..de6e0d9e9b82 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -503,7 +503,7 @@ private[spark] class MapOutputTrackerMaster( } } - override def stop() { + override def stop() { mapOutputRequests.offer(PoisonPill) threadpool.shutdown() sendTracker(StopMapOutputTracker) From 54a033cfbeaf99ec380b60a2acf7d93ed24c8545 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 11 May 2017 14:46:21 -0700 Subject: [PATCH 09/14] Remove need to store multiple statuses for a single map (explanation to come in PR desc.) --- .../org/apache/spark/MapOutputTracker.scala | 80 ++++++++----------- 1 file changed, 35 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index de6e0d9e9b82..f432d0b58c9f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -37,11 +37,14 @@ import org.apache.spark.util._ private class ShuffleStatus(numPartitions: Int) { /** - * List of [[MapStatus]] for each partition. The index of the array is the map partition id, - * and each value in the array is the list of possible [[MapStatus]] for a partition - * (a single task might run multiple times). + * [[MapStatus]] for each partition. The index of the array is the map partition id. + * Each value in the array is the [[MapStatus]] for a partition, or null if the partition + * is not available. Even though in theory a task may run multiple times (due to speculation, + * stage retries, etc., in practice the likelihood of a map output being available at multiple + * locations is so small that we choose to ignore that case and store only a single location + * for each output. */ - private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + private[this] val mapStatuses = new Array[MapStatus](numPartitions) private[this] var cachedSerializedMapStatus: Array[Byte] = _ private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ private[this] var _numAvailableOutputs: Int = 0 @@ -51,22 +54,21 @@ private class ShuffleStatus(numPartitions: Int) { } def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized { - val prevList = outputLocs(mapId) - outputLocs(mapId) = status :: prevList - if (prevList == Nil) { + if (mapStatuses(mapId) == null) { _numAvailableOutputs += 1 + invalidateSerializedMapOutputStatusCache() + } else { + // TODO(josh) log? 
} - removeBroadcast() + mapStatuses(mapId) = status } def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized { - val prevList = outputLocs(mapId) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(mapId) = newList - if (prevList != Nil && newList == Nil) { + if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) { _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() } - removeBroadcast() } /** @@ -75,17 +77,13 @@ private class ShuffleStatus(numPartitions: Int) { * registered with this execId. */ def removeOutputsOnExecutor(execId: String): Unit = synchronized { - var becameUnavailable = false - for (partition <- 0 until outputLocs.length) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true + for (mapId <- 0 until mapStatuses.length) { + if (mapStatuses(mapId) != null && mapStatuses(mapId).location.executorId == execId) { _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() } } - removeBroadcast() } /** @@ -98,19 +96,12 @@ private class ShuffleStatus(numPartitions: Int) { /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ def findMissingPartitions(): Seq[Int] = synchronized { - val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty) + val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null) assert(missing.size == numPartitions - _numAvailableOutputs, s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") missing } - /** - * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned - * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition, - * that position is filled with null. - */ - private def mapStatuses: Array[MapStatus] = outputLocs.map(_.headOption.orNull) - def serializedMapStatus( broadcastManager: BroadcastManager, isLocal: Boolean, @@ -124,11 +115,11 @@ private class ShuffleStatus(numPartitions: Int) { cachedSerializedMapStatus } - def withOutputLocs[T](f: Array[List[MapStatus]] => T): T = synchronized { - f(outputLocs) + def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized { + f(mapStatuses) } - def removeBroadcast(): Unit = synchronized { + def invalidateSerializedMapOutputStatusCache(): Unit = synchronized { if (cachedSerializedBroadcast != null) { cachedSerializedBroadcast.destroy() cachedSerializedBroadcast = null @@ -352,7 +343,7 @@ private[spark] class MapOutputTrackerMaster( /** Unregister shuffle data */ def unregisterShuffle(shuffleId: Int) { shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => - shuffleStatus.removeBroadcast() + shuffleStatus.invalidateSerializedMapOutputStatusCache() } } @@ -381,23 +372,23 @@ private[spark] class MapOutputTrackerMaster( shuffleStatuses.get(shuffleId).map(_.findMissingPartitions()) } - /** +/** * Return statistics about all of the outputs for a given shuffle. 
   */
  def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
-    shuffleStatuses(dep.shuffleId).withOutputLocs { outputLocs =>
+    shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
       val totalSizes = new Array[Long](dep.partitioner.numPartitions)
-      for (
-        mapOutputs <- outputLocs;
-        s <- mapOutputs.headOption;
-        i <- 0 until totalSizes.length
-      ) {
-        totalSizes(i) += s.getSizeForBlock(i)
+      for (s <- statuses) {
+        for (i <- 0 until totalSizes.length) {
+          totalSizes(i) += s.getSizeForBlock(i)
+        }
       }
       new MapOutputStatistics(dep.shuffleId, totalSizes)
     }
   }
+
+
+
   /**
    * Return the preferred hosts on which to run the given map output partition in a given shuffle,
    * i.e. the nodes that the most outputs for that partition are on.
@@ -441,14 +432,14 @@ private[spark] class MapOutputTrackerMaster(

     val shuffleStatus = shuffleStatuses.get(shuffleId).orNull
     if (shuffleStatus != null) {
-      shuffleStatus.withOutputLocs { statuses =>
+      shuffleStatus.withMapStatuses { statuses =>
         if (statuses.nonEmpty) {
           // HashMap to add up sizes of all blocks at the same location
           val locs = new HashMap[BlockManagerId, Long]
           var totalOutputSize = 0L
           var mapIdx = 0
           while (mapIdx < statuses.length) {
-            val status = statuses(mapIdx).headOption.orNull
+            val status = statuses(mapIdx)
             // status may be null here if we are called between registerShuffle, which creates an
             // array with null entries for each output, and registerMapOutputs, which populates it
             // with valid status entries. This is possible if one thread schedules a job which
@@ -494,8 +485,7 @@ private[spark] class MapOutputTrackerMaster(
     logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
     shuffleStatuses.get(shuffleId) match {
       case Some(shuffleStatus) =>
-        shuffleStatus.withOutputLocs { outputLocs =>
-          val statuses = outputLocs.map(_.headOption.orNull)
+        shuffleStatus.withMapStatuses { statuses =>
           MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
         }
       case None =>

From f4c096f69cadafd7695b38fad14d58c54943a085 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Thu, 11 May 2017 15:23:13 -0700
Subject: [PATCH 10/14] Add Scaladoc.

---
 .../org/apache/spark/MapOutputTracker.scala | 95 +++++++++++++++----
 1 file changed, 79 insertions(+), 16 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index f432d0b58c9f..708b19718860 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -34,35 +34,68 @@ import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
 import org.apache.spark.util._

+/**
+ * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single
+ * ShuffleMapStage.
+ *
+ * This class maintains a mapping from mapIds to [[MapStatus]]es. It also maintains a cache of
+ * serialized map statuses in order to speed up tasks' requests for map output statuses.
+ *
+ * All public methods of this class are thread-safe.
+ */
 private class ShuffleStatus(numPartitions: Int) {

+  // All accesses to the following state must be guarded with `this.synchronized`.
+
   /**
    * [[MapStatus]] for each partition. The index of the array is the map partition id.
   * Each value in the array is the [[MapStatus]] for a partition, or null if the partition
    * is not available. Even though in theory a task may run multiple times (due to speculation,
-   * stage retries, etc., in practice the likelihood of a map output being available at multiple
+   * stage retries, etc.), in practice the likelihood of a map output being available at multiple
    * locations is so small that we choose to ignore that case and store only a single location
    * for each output.
    */
   private[this] val mapStatuses = new Array[MapStatus](numPartitions)
+
+  /**
+   * The cached result of serializing the map statuses array. This cache is lazily populated when
+   * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed.
+   */
   private[this] var cachedSerializedMapStatus: Array[Byte] = _
+
+  /**
+   * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]]
+   * serializes the map statuses array it may detect that the result is too large to send in a
+   * single RPC, in which case it places the serialized array into a broadcast variable and then
+   * sends a serialized broadcast variable instead. This variable holds a reference to that
+   * broadcast variable in order to keep it from being garbage collected and to allow for it to be
+   * explicitly destroyed later on when the ShuffleMapStage is garbage-collected.
+   */
   private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _
-  private[this] var _numAvailableOutputs: Int = 0

-  def hasCachedSerializedBroadcast: Boolean = synchronized {
-    cachedSerializedBroadcast != null
-  }
+  /**
+   * Counter tracking the number of partitions that have output. This is a performance optimization
+   * to avoid having to count the number of non-null entries in the `mapStatuses` array and should
+   * be equivalent to `mapStatuses.count(_ ne null)`.
+   */
+  private[this] var _numAvailableOutputs: Int = 0

+  /**
+   * Register a map output. If there is already a registered location for the map output, then it
+   * will be replaced by the new location.
+   */
   def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
     if (mapStatuses(mapId) == null) {
       _numAvailableOutputs += 1
-      invalidateSerializedMapOutputStatusCache()
-    } else {
-      // TODO(josh) log?
     }
     mapStatuses(mapId) = status
   }

+  /**
+   * Remove the map output which was served by the specified block manager.
+   * This is a no-op if there is no registered map output or if the registered output is from a
+   * different block manager.
+   */
   def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
     if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
       _numAvailableOutputs -= 1
@@ -72,9 +105,9 @@ private class ShuffleStatus(numPartitions: Int) {
   }

   /**
-   * Removes all shuffle outputs associated with this executor. Note that this will also remove
-   * outputs which are served by an external shuffle server (if one exists), as they are still
-   * registered with this execId.
+   * Removes all map outputs associated with the specified executor. Note that this will also
+   * remove outputs which are served by an external shuffle server (if one exists), as they are
+   * still registered with that execId.
    */
   def removeOutputsOnExecutor(execId: String): Unit = synchronized {
     for (mapId <- 0 until mapStatuses.length) {
@@ -88,13 +121,14 @@ private class ShuffleStatus(numPartitions: Int) {

   /**
    * Number of partitions that have shuffle outputs.
-   * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`.
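+   * This is always equal to the number of non-null entries in `mapStatuses`, maintained
+   * incrementally as outputs are added and removed.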
*/ def numAvailableOutputs: Int = synchronized { _numAvailableOutputs } - /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed). + */ def findMissingPartitions(): Seq[Int] = synchronized { val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null) assert(missing.size == numPartitions - _numAvailableOutputs, @@ -102,6 +136,15 @@ private class ShuffleStatus(numPartitions: Int) { missing } + /** + * Serializes the mapStatuses array into an efficient compressed format. See the comments on + * [[MapOutputTracker.serializeMapStatuses()]] for more details on the serialization format. + * + * This method is designed to be called multiple times and implements caching in order to speed + * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to + * serialize the map statuses then serialization will only be performed in a single thread and all + * other threads will block until the cache is populated. + */ def serializedMapStatus( broadcastManager: BroadcastManager, isLocal: Boolean, @@ -115,10 +158,22 @@ private class ShuffleStatus(numPartitions: Int) { cachedSerializedMapStatus } + // Used in testing. + def hasCachedSerializedBroadcast: Boolean = synchronized { + cachedSerializedBroadcast != null + } + + /** + * Helper function which provides thread-safe access to the mapStatuses array. + * The function should NOT mutate the array. + */ def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized { f(mapStatuses) } + /** + * Clears the cached serialized map output statuses. + */ def invalidateSerializedMapOutputStatusCache(): Unit = synchronized { if (cachedSerializedBroadcast != null) { cachedSerializedBroadcast.destroy() @@ -214,6 +269,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] + /** + * Deletes map output status information for the specified shuffle stage. + */ def unregisterShuffle(shuffleId: Int): Unit def stop() {} @@ -221,6 +279,12 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging /** * Driver-side class that keeps track of the location of the map output of a stage. + * + * The DAGScheduler uses this class to (de)register map output statuses and to look up statistics + * for performing locality-aware reduce task scheduling. + * + * ShuffleMapStage uses this class for tracking available / missing outputs in order to determine + * which tasks need to be run. */ private[spark] class MapOutputTrackerMaster( conf: SparkConf, @@ -372,7 +436,7 @@ private[spark] class MapOutputTrackerMaster( shuffleStatuses.get(shuffleId).map(_.findMissingPartitions()) } -/** + /** * Return statistics about all of the outputs for a given shuffle. */ def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { @@ -387,8 +451,6 @@ private[spark] class MapOutputTrackerMaster( } } - - /** * Return the preferred hosts on which to run the given map output partition in a given shuffle, * i.e. the nodes that the most outputs for that partition are on. @@ -480,6 +542,7 @@ private[spark] class MapOutputTrackerMaster( } } + // This method is only called in local-mode. 
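+  // In local mode the driver and the executors share this MapOutputTrackerMaster instance, so
+  // lookups can read shuffleStatuses directly instead of going through the RPC endpoint that
+  // remote executors use.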
  def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")

From 7d59bbe117a0a06b418c10fd509b9ff7e344bfd7 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Thu, 11 May 2017 15:28:43 -0700
Subject: [PATCH 11/14] Fix test "multiple simultaneous attempts for one task" (SPARK-8029)

---
 core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 2 +-
 core/src/test/scala/org/apache/spark/ShuffleSuite.scala     | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 708b19718860..eafbc976428d 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -378,7 +378,7 @@ private[spark] class MapOutputTrackerMaster(
   /** A poison endpoint that indicates MessageLoop should exit its message loop. */
   private val PoisonPill = new GetMapOutputMessage(-99, null)

-  // Exposed for testing
+  // Used only in unit tests.
   private[spark] def getNumCachedSerializedBroadcast: Int = {
     shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast)
   }
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 192d04884507..a2042cbdbc54 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -333,6 +333,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
     val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep)
+    mapTrackerMaster.registerShuffle(0, 1)

     // first attempt -- its successful
     val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,

From e3da298d59c764388ec6ca93ec23ba3eb8de96d3 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Thu, 11 May 2017 17:47:03 -0700
Subject: [PATCH 12/14] Fix javaunidoc.

---
 .../main/scala/org/apache/spark/MapOutputTracker.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index eafbc976428d..2650ca7069b8 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -38,7 +38,7 @@ import org.apache.spark.util._
  * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single
  * ShuffleMapStage.
  *
- * This class maintains a mapping from mapIds to [[MapStatus]]es. It also maintains a cache of
+ * This class maintains a mapping from mapIds to `MapStatus`. It also maintains a cache of
  * serialized map statuses in order to speed up tasks' requests for map output statuses.
  *
  * All public methods of this class are thread-safe.
@@ -48,8 +48,8 @@ private class ShuffleStatus(numPartitions: Int) {
   // All accesses to the following state must be guarded with `this.synchronized`.

   /**
-   * [[MapStatus]] for each partition. The index of the array is the map partition id.
-   * Each value in the array is the [[MapStatus]] for a partition, or null if the partition
+   * MapStatus for each partition. The index of the array is the map partition id.
+   * Each value in the array is the MapStatus for a partition, or null if the partition
    * is not available. Even though in theory a task may run multiple times (due to speculation,
    * stage retries, etc.), in practice the likelihood of a map output being available at multiple
    * locations is so small that we choose to ignore that case and store only a single location
    * for each output.
    */
   private[this] val mapStatuses = new Array[MapStatus](numPartitions)
@@ -138,7 +138,7 @@ private class ShuffleStatus(numPartitions: Int) {
   /**
    * Serializes the mapStatuses array into an efficient compressed format. See the comments on
-   * [[MapOutputTracker.serializeMapStatuses()]] for more details on the serialization format.
+   * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format.

From a8069a3fb3edbff39786301d7572be6de1cd931c Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Fri, 12 May 2017 16:52:21 -0700
Subject: [PATCH 13/14] Clarify handling of epochs; fix 'remote fetch' failing test.

---
 .../org/apache/spark/MapOutputTracker.scala   | 22 +++++++++++++++----
 .../org/apache/spark/executor/Executor.scala  |  4 ++++
 .../apache/spark/scheduler/DAGScheduler.scala |  8 -------
 .../apache/spark/MapOutputTrackerSuite.scala  |  6 ++---
 4 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 2650ca7069b8..3a44ed7241e6 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -87,6 +87,7 @@ private class ShuffleStatus(numPartitions: Int) {
   def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
     if (mapStatuses(mapId) == null) {
       _numAvailableOutputs += 1
+      invalidateSerializedMapOutputStatusCache()
     }
     mapStatuses(mapId) = status
   }
@@ -222,8 +223,11 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
   var trackerEndpoint: RpcEndpointRef = _

   /**
-   * Incremented every time a fetch fails so that client nodes know to clear
-   * their cache of map output locations if this happens.
+   * The driver-side counter is incremented every time that a map output is lost. This value is sent
+   * to executors as part of tasks, where executors compare the new epoch number to the highest
+   * epoch number that they received in the past. If the new epoch number is higher, then executors
+   * will clear their local caches of map output statuses and will re-fetch (possibly updated)
+   * statuses from the driver.
    */
   protected var epoch: Long = 0
   protected val epochLock = new AnyRef
@@ -528,7 +532,7 @@ private[spark] class MapOutputTrackerMaster(
     None
   }

-  def incrementEpoch() {
+  private def incrementEpoch() {
    epochLock.synchronized {
      epoch += 1
      logDebug("Increasing epoch to " + epoch)
@@ -567,6 +571,9 @@ private[spark] class MapOutputTrackerMaster(

 /**
  * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster.
+ * Note that this is not used in local mode; instead, local-mode Executors access the
+ * MapOutputTrackerMaster directly (which is possible because the master and worker share a common
+ * superclass).
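+ * As a consequence, the epoch-based cache invalidation in this class is never exercised in
+ * local mode (see the corresponding comment in Executor.scala).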
 */
private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
@@ -580,7 +587,14 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
       : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
     val statuses = getStatuses(shuffleId)
-    MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+    try {
+      MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+    } catch {
+      case e: MetadataFetchFailedException =>
+        // We experienced a fetch failure, so our mapStatuses cache is outdated; clear it:
+        mapStatuses.clear()
+        throw e
+    }
   }

   /**
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 08f6f10b6363..425568483440 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -322,6 +322,10 @@ private[spark] class Executor(
          throw new TaskKilledException(killReason.get)
        }

+        // The purpose of updating the epoch here is to invalidate the executor's map output
+        // status cache in case FetchFailures have occurred. In local mode `env.mapOutputTracker`
+        // will be a MapOutputTrackerMaster and its cache invalidation is not based on epoch
+        // numbers, so we don't need to make any special calls here.
        if (!isLocal) {
          logDebug("Task " + taskId + "'s epoch is " + task.epoch)
          env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index f28a5f291d75..c0e02a45fb7f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1223,14 +1223,6 @@ class DAGScheduler(
            logInfo("waiting: " + waitingStages)
            logInfo("failed: " + failedStages)

-            // Increment the epoch number here in case this is a
-            // recomputation of the map outputs. In that case, some nodes may have cached
-            // locations with holes (from when we detected the error) and will need the
-            // epoch incremented to refetch them.
-            // TODO: Only increment the epoch number if this is not the first time
-            // we registered these map outputs.
-            mapOutputTracker.incrementEpoch()
-
            clearCacheLocs()

            if (!shuffleStage.isAvailable) {
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index bb24c6ce4d33..252b784f879b 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -138,21 +138,21 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)

     masterTracker.registerShuffle(10, 1)
-    masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
+    // This is expected to fail because no outputs have been registered for the shuffle.
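+    // (convertMapStatuses surfaces the missing status as a MetadataFetchFailedException,
+    // which is a subclass of FetchFailedException.)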
intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0) === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) + val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) - masterTracker.incrementEpoch() + assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput) slaveTracker.updateEpoch(masterTracker.getEpoch) intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } From 4550f616a4f9c144a2da49a31ef3eaa19a0eeea8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 5 Jun 2017 13:41:42 -0700 Subject: [PATCH 14/14] Restore incrementEpoch() call. --- .../main/scala/org/apache/spark/MapOutputTracker.scala | 2 +- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 3a44ed7241e6..3e10b9eee4e2 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -532,7 +532,7 @@ private[spark] class MapOutputTrackerMaster( None } - private def incrementEpoch() { + def incrementEpoch() { epochLock.synchronized { epoch += 1 logDebug("Increasing epoch to " + epoch) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 09132707ef4e..932e6c138e1c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1224,6 +1224,15 @@ class DAGScheduler( logInfo("waiting: " + waitingStages) logInfo("failed: " + failedStages) + // This call to increment the epoch may not be strictly necessary, but it is retained + // for now in order to minimize the changes in behavior from an earlier version of the + // code. This existing behavior of always incrementing the epoch following any + // successful shuffle map stage completion may have benefits by causing unneeded + // cached map outputs to be cleaned up earlier on executors. In the future we can + // consider removing this call, but this will require some extra investigation. + // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details. + mapOutputTracker.incrementEpoch() + clearCacheLocs() if (!shuffleStage.isAvailable) {
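
To make the bookkeeping contract that this series gives ShuffleStatus concrete, here is a
minimal, self-contained sketch. It is illustrative only and not part of the patch: the names
`ShuffleStatusSketch`, `Status`, and `Loc` are invented for the example. It models the three
invariants the real class maintains: a single location per map output, a counter equal to the
number of non-null entries, and invalidation of the serialized cache whenever the set of
available outputs changes.

// Illustrative sketch only -- not part of the patch series. All names are invented.
object ShuffleStatusSketch {
  final case class Loc(executorId: String, host: String)

  final class Status(numPartitions: Int) {
    private[this] val locs = new Array[Loc](numPartitions)
    private[this] var numAvailable = 0
    private[this] var cachedSerialized: Array[Byte] = _

    def addMapOutput(mapId: Int, loc: Loc): Unit = synchronized {
      if (locs(mapId) == null) {
        numAvailable += 1
        cachedSerialized = null // a newly available output invalidates the serialized cache
      }
      locs(mapId) = loc // a rerun of the same map simply overwrites the old location
    }

    def removeMapOutput(mapId: Int, loc: Loc): Unit = synchronized {
      // No-op unless the registered output came from exactly this location.
      if (locs(mapId) == loc) {
        locs(mapId) = null
        numAvailable -= 1
        cachedSerialized = null
      }
    }

    def findMissingPartitions(): Seq[Int] = synchronized {
      val missing = (0 until numPartitions).filter(locs(_) == null)
      // The counter is an optimization; it must agree with the array contents.
      assert(missing.size == numPartitions - numAvailable)
      missing
    }
  }

  def main(args: Array[String]): Unit = {
    val s = new Status(3)
    s.addMapOutput(0, Loc("exec-1", "hostA"))
    s.addMapOutput(2, Loc("exec-2", "hostB"))
    println(s.findMissingPartitions()) // Vector(1)
    s.removeMapOutput(0, Loc("exec-1", "hostA"))
    println(s.findMissingPartitions()) // Vector(0, 1)
  }
}

As in the real ShuffleStatus, storing one location per map output keeps removal and the missing-
partition check O(1) and O(n) respectively, at the cost of forgetting duplicate outputs from
speculative or retried tasks, a case the patch's Scaladoc argues is rare enough to ignore.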