diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4ef665622245..3e10b9eee4e2 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -34,6 +34,156 @@ import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ +/** + * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single + * ShuffleMapStage. + * + * This class maintains a mapping from mapIds to `MapStatus`. It also maintains a cache of + * serialized map statuses in order to speed up tasks' requests for map output statuses. + * + * All public methods of this class are thread-safe. + */ +private class ShuffleStatus(numPartitions: Int) { + + // All accesses to the following state must be guarded with `this.synchronized`. + + /** + * MapStatus for each partition. The index of the array is the map partition id. + * Each value in the array is the MapStatus for a partition, or null if the partition + * is not available. Even though in theory a task may run multiple times (due to speculation, + * stage retries, etc.), in practice the likelihood of a map output being available at multiple + * locations is so small that we choose to ignore that case and store only a single location + * for each output. + */ + private[this] val mapStatuses = new Array[MapStatus](numPartitions) + + /** + * The cached result of serializing the map statuses array. This cache is lazily populated when + * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed. + */ + private[this] var cachedSerializedMapStatus: Array[Byte] = _ + + /** + * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]] + * serializes the map statuses array it may detect that the result is too large to send in a + * single RPC, in which case it places the serialized array into a broadcast variable and then + * sends a serialized broadcast variable instead. This variable holds a reference to that + * broadcast variable in order to keep it from being garbage collected and to allow for it to be + * explicitly destroyed later on when the ShuffleMapStage is garbage-collected. + */ + private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _ + + /** + * Counter tracking the number of partitions that have output. This is a performance optimization + * to avoid having to count the number of non-null entries in the `mapStatuses` array and should + * be equivalent to`mapStatuses.count(_ ne null)`. + */ + private[this] var _numAvailableOutputs: Int = 0 + + /** + * Register a map output. If there is already a registered location for the map output then it + * will be replaced by the new location. + */ + def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized { + if (mapStatuses(mapId) == null) { + _numAvailableOutputs += 1 + invalidateSerializedMapOutputStatusCache() + } + mapStatuses(mapId) = status + } + + /** + * Remove the map output which was served by the specified block manager. + * This is a no-op if there is no registered map output or if the registered output is from a + * different block manager. + */ + def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized { + if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } + } + + /** + * Removes all map outputs associated with the specified executor. Note that this will also + * remove outputs which are served by an external shuffle server (if one exists), as they are + * still registered with that execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = synchronized { + for (mapId <- 0 until mapStatuses.length) { + if (mapStatuses(mapId) != null && mapStatuses(mapId).location.executorId == execId) { + _numAvailableOutputs -= 1 + mapStatuses(mapId) = null + invalidateSerializedMapOutputStatusCache() + } + } + } + + /** + * Number of partitions that have shuffle outputs. + */ + def numAvailableOutputs: Int = synchronized { + _numAvailableOutputs + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed). + */ + def findMissingPartitions(): Seq[Int] = synchronized { + val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null) + assert(missing.size == numPartitions - _numAvailableOutputs, + s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") + missing + } + + /** + * Serializes the mapStatuses array into an efficient compressed format. See the comments on + * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format. + * + * This method is designed to be called multiple times and implements caching in order to speed + * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to + * serialize the map statuses then serialization will only be performed in a single thread and all + * other threads will block until the cache is populated. + */ + def serializedMapStatus( + broadcastManager: BroadcastManager, + isLocal: Boolean, + minBroadcastSize: Int): Array[Byte] = synchronized { + if (cachedSerializedMapStatus eq null) { + val serResult = MapOutputTracker.serializeMapStatuses( + mapStatuses, broadcastManager, isLocal, minBroadcastSize) + cachedSerializedMapStatus = serResult._1 + cachedSerializedBroadcast = serResult._2 + } + cachedSerializedMapStatus + } + + // Used in testing. + def hasCachedSerializedBroadcast: Boolean = synchronized { + cachedSerializedBroadcast != null + } + + /** + * Helper function which provides thread-safe access to the mapStatuses array. + * The function should NOT mutate the array. + */ + def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized { + f(mapStatuses) + } + + /** + * Clears the cached serialized map output statuses. + */ + def invalidateSerializedMapOutputStatusCache(): Unit = synchronized { + if (cachedSerializedBroadcast != null) { + cachedSerializedBroadcast.destroy() + cachedSerializedBroadcast = null + } + cachedSerializedMapStatus = null + } +} + private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage @@ -62,37 +212,26 @@ private[spark] class MapOutputTrackerMasterEndpoint( } /** - * Class that keeps track of the location of the map output of - * a stage. This is abstract because different versions of MapOutputTracker - * (driver and executor) use different HashMap to store its metadata. - */ + * Class that keeps track of the location of the map output of a stage. This is abstract because the + * driver and executor have different versions of the MapOutputTracker. In principle the driver- + * and executor-side classes don't need to share a common base class; the current shared base class + * is maintained primarily for backwards-compatibility in order to avoid having to update existing + * test code. +*/ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */ var trackerEndpoint: RpcEndpointRef = _ /** - * This HashMap has different behavior for the driver and the executors. - * - * On the driver, it serves as the source of map outputs recorded from ShuffleMapTasks. - * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the - * driver's corresponding HashMap. - * - * Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a - * thread-safe map. - */ - protected val mapStatuses: Map[Int, Array[MapStatus]] - - /** - * Incremented every time a fetch fails so that client nodes know to clear - * their cache of map output locations if this happens. + * The driver-side counter is incremented every time that a map output is lost. This value is sent + * to executors as part of tasks, where executors compare the new epoch number to the highest + * epoch number that they received in the past. If the new epoch number is higher then executors + * will clear their local caches of map output statuses and will re-fetch (possibly updated) + * statuses from the driver. */ protected var epoch: Long = 0 protected val epochLock = new AnyRef - /** Remembers which map output locations are currently being fetched on an executor. */ - private val fetching = new HashSet[Int] - /** * Send a message to the trackerEndpoint and get its result within a default timeout, or * throw a SparkException if this fails. @@ -116,14 +255,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } - /** - * Called from executors to get the server URIs and output sizes for each shuffle block that - * needs to be read from a given reduce task. - * - * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, - * and the second item is a sequence of (shuffle block id, shuffle block size) tuples - * describing the shuffle blocks that are stored at that block manager. - */ + // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) @@ -139,135 +271,31 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * describing the shuffle blocks that are stored at that block manager. */ def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { - logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") - val statuses = getStatuses(shuffleId) - // Synchronize on the returned array because, on the driver, it gets mutated in place - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) - } - } + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] /** - * Return statistics about all of the outputs for a given shuffle. + * Deletes map output status information for the specified shuffle stage. */ - def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { - val statuses = getStatuses(dep.shuffleId) - // Synchronize on the returned array because, on the driver, it gets mutated in place - statuses.synchronized { - val totalSizes = new Array[Long](dep.partitioner.numPartitions) - for (s <- statuses) { - for (i <- 0 until totalSizes.length) { - totalSizes(i) += s.getSizeForBlock(i) - } - } - new MapOutputStatistics(dep.shuffleId, totalSizes) - } - } - - /** - * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize - * on this array when reading it, because on the driver, we may be changing it in place. - * - * (It would be nice to remove this restriction in the future.) - */ - private def getStatuses(shuffleId: Int): Array[MapStatus] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses == null) { - logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") - val startTime = System.currentTimeMillis - var fetchedStatuses: Array[MapStatus] = null - fetching.synchronized { - // Someone else is fetching it; wait for them to be done - while (fetching.contains(shuffleId)) { - try { - fetching.wait() - } catch { - case e: InterruptedException => - } - } - - // Either while we waited the fetch happened successfully, or - // someone fetched it in between the get and the fetching.synchronized. - fetchedStatuses = mapStatuses.get(shuffleId).orNull - if (fetchedStatuses == null) { - // We have to do the fetch, get others to wait for us. - fetching += shuffleId - } - } + def unregisterShuffle(shuffleId: Int): Unit - if (fetchedStatuses == null) { - // We won the race to fetch the statuses; do so - logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) - // This try-finally prevents hangs due to timeouts: - try { - val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) - fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) - logInfo("Got the output locations") - mapStatuses.put(shuffleId, fetchedStatuses) - } finally { - fetching.synchronized { - fetching -= shuffleId - fetching.notifyAll() - } - } - } - logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + - s"${System.currentTimeMillis - startTime} ms") - - if (fetchedStatuses != null) { - return fetchedStatuses - } else { - logError("Missing all output locations for shuffle " + shuffleId) - throw new MetadataFetchFailedException( - shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) - } - } else { - return statuses - } - } - - /** Called to get current epoch number. */ - def getEpoch: Long = { - epochLock.synchronized { - return epoch - } - } - - /** - * Called from executors to update the epoch number, potentially clearing old outputs - * because of a fetch failure. Each executor task calls this with the latest epoch - * number on the driver at the time it was created. - */ - def updateEpoch(newEpoch: Long) { - epochLock.synchronized { - if (newEpoch > epoch) { - logInfo("Updating epoch to " + newEpoch + " and clearing cache") - epoch = newEpoch - mapStatuses.clear() - } - } - } - - /** Unregister shuffle data. */ - def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) - } - - /** Stop the tracker. */ - def stop() { } + def stop() {} } /** - * MapOutputTracker for the driver. + * Driver-side class that keeps track of the location of the map output of a stage. + * + * The DAGScheduler uses this class to (de)register map output statuses and to look up statistics + * for performing locality-aware reduce task scheduling. + * + * ShuffleMapStage uses this class for tracking available / missing outputs in order to determine + * which tasks need to be run. */ -private[spark] class MapOutputTrackerMaster(conf: SparkConf, - broadcastManager: BroadcastManager, isLocal: Boolean) +private[spark] class MapOutputTrackerMaster( + conf: SparkConf, + broadcastManager: BroadcastManager, + isLocal: Boolean) extends MapOutputTracker(conf) { - /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ - private var cacheEpoch = epoch - // The size at which we use Broadcast to send the map output statuses to the executors private val minSizeForBroadcast = conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt @@ -287,22 +315,12 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, // can be read locally, but may lead to more delay in scheduling if those locations are busy. private val REDUCER_PREF_LOCS_FRACTION = 0.2 - // HashMaps for storing mapStatuses and cached serialized statuses in the driver. + // HashMap for storing shuffleStatuses in the driver. // Statuses are dropped only by explicit de-registering. - protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala - private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala + private val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) - // Kept in sync with cachedSerializedStatuses explicitly - // This is required so that the Broadcast variable remains in scope until we remove - // the shuffleId explicitly or implicitly. - private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]() - - // This is to prevent multiple serializations of the same shuffle - which happens when - // there is a request storm when shuffle start. - private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]() - // requests for map output statuses private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage] @@ -348,8 +366,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, val hostPort = context.senderAddress.hostPort logDebug("Handling request to send map output locations for shuffle " + shuffleId + " to " + hostPort) - val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId) - context.reply(mapOutputStatuses) + val shuffleStatus = shuffleStatuses.get(shuffleId).head + context.reply( + shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast)) } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -363,59 +382,77 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, /** A poison endpoint that indicates MessageLoop should exit its message loop. */ private val PoisonPill = new GetMapOutputMessage(-99, null) - // Exposed for testing - private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size + // Used only in unit tests. + private[spark] def getNumCachedSerializedBroadcast: Int = { + shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast) + } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - // add in advance - shuffleIdLocks.putIfAbsent(shuffleId, new Object()) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - val array = mapStatuses(shuffleId) - array.synchronized { - array(mapId) = status - } - } - - /** Register multiple map output information for the given shuffle */ - def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { - mapStatuses.put(shuffleId, statuses.clone()) - if (changeEpoch) { - incrementEpoch() - } + shuffleStatuses(shuffleId).addMapOutput(mapId, status) } /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - val arrayOpt = mapStatuses.get(shuffleId) - if (arrayOpt.isDefined && arrayOpt.get != null) { - val array = arrayOpt.get - array.synchronized { - if (array(mapId) != null && array(mapId).location == bmAddress) { - array(mapId) = null - } - } - incrementEpoch() - } else { - throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeMapOutput(mapId, bmAddress) + incrementEpoch() + case None => + throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") } } /** Unregister shuffle data */ - override def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) - cachedSerializedStatuses.remove(shuffleId) - cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v)) - shuffleIdLocks.remove(shuffleId) + def unregisterShuffle(shuffleId: Int) { + shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => + shuffleStatus.invalidateSerializedMapOutputStatusCache() + } + } + + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. + */ + def removeOutputsOnExecutor(execId: String): Unit = { + shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) } + incrementEpoch() } /** Check if the given shuffle is being tracked */ - def containsShuffle(shuffleId: Int): Boolean = { - cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) + def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId) + + def getNumAvailableOutputs(shuffleId: Int): Int = { + shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0) + } + + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed), or None + * if the MapOutputTrackerMaster doesn't know about this shuffle. + */ + def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = { + shuffleStatuses.get(shuffleId).map(_.findMissingPartitions()) + } + + /** + * Return statistics about all of the outputs for a given shuffle. + */ + def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = { + shuffleStatuses(dep.shuffleId).withMapStatuses { statuses => + val totalSizes = new Array[Long](dep.partitioner.numPartitions) + for (s <- statuses) { + for (i <- 0 until totalSizes.length) { + totalSizes(i) += s.getSizeForBlock(i) + } + } + new MapOutputStatistics(dep.shuffleId, totalSizes) + } } /** @@ -459,9 +496,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, fractionThreshold: Double) : Option[Array[BlockManagerId]] = { - val statuses = mapStatuses.get(shuffleId).orNull - if (statuses != null) { - statuses.synchronized { + val shuffleStatus = shuffleStatuses.get(shuffleId).orNull + if (shuffleStatus != null) { + shuffleStatus.withMapStatuses { statuses => if (statuses.nonEmpty) { // HashMap to add up sizes of all blocks at the same location val locs = new HashMap[BlockManagerId, Long] @@ -502,77 +539,24 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, } } - private def removeBroadcast(bcast: Broadcast[_]): Unit = { - if (null != bcast) { - broadcastManager.unbroadcast(bcast.id, - removeFromDriver = true, blocking = false) + /** Called to get current epoch number. */ + def getEpoch: Long = { + epochLock.synchronized { + return epoch } } - private def clearCachedBroadcast(): Unit = { - for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2) - cachedSerializedBroadcast.clear() - } - - def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { - var statuses: Array[MapStatus] = null - var retBytes: Array[Byte] = null - var epochGotten: Long = -1 - - // Check to see if we have a cached version, returns true if it does - // and has side effect of setting retBytes. If not returns false - // with side effect of setting statuses - def checkCachedStatuses(): Boolean = { - epochLock.synchronized { - if (epoch > cacheEpoch) { - cachedSerializedStatuses.clear() - clearCachedBroadcast() - cacheEpoch = epoch - } - cachedSerializedStatuses.get(shuffleId) match { - case Some(bytes) => - retBytes = bytes - true - case None => - logDebug("cached status not found for : " + shuffleId) - statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus]) - epochGotten = epoch - false - } - } - } - - if (checkCachedStatuses()) return retBytes - var shuffleIdLock = shuffleIdLocks.get(shuffleId) - if (null == shuffleIdLock) { - val newLock = new Object() - // in general, this condition should be false - but good to be paranoid - val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock) - shuffleIdLock = if (null != prevLock) prevLock else newLock - } - // synchronize so we only serialize/broadcast it once since multiple threads call - // in parallel - shuffleIdLock.synchronized { - // double check to make sure someone else didn't serialize and cache the same - // mapstatus while we were waiting on the synchronize - if (checkCachedStatuses()) return retBytes - - // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "statuses"; let's serialize and return that - val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager, - isLocal, minSizeForBroadcast) - logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) - // Add them into the table only if the epoch hasn't changed while we were working - epochLock.synchronized { - if (epoch == epochGotten) { - cachedSerializedStatuses(shuffleId) = bytes - if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast - } else { - logInfo("Epoch changed, not caching!") - removeBroadcast(bcast) + // This method is only called in local-mode. + def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + shuffleStatuses.get(shuffleId) match { + case Some (shuffleStatus) => + shuffleStatus.withMapStatuses { statuses => + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) } - } - bytes + case None => + Seq.empty } } @@ -580,21 +564,121 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, mapOutputRequests.offer(PoisonPill) threadpool.shutdown() sendTracker(StopMapOutputTracker) - mapStatuses.clear() trackerEndpoint = null - cachedSerializedStatuses.clear() - clearCachedBroadcast() - shuffleIdLocks.clear() + shuffleStatuses.clear() } } /** - * MapOutputTracker for the executors, which fetches map output information from the driver's - * MapOutputTrackerMaster. + * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster. + * Note that this is not used in local-mode; instead, local-mode Executors access the + * MapOutputTrackerMaster directly (which is possible because the master and worker share a comon + * superclass). */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { - protected val mapStatuses: Map[Int, Array[MapStatus]] = + + val mapStatuses: Map[Int, Array[MapStatus]] = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + + /** Remembers which map output locations are currently being fetched on an executor. */ + private val fetching = new HashSet[Int] + + override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") + val statuses = getStatuses(shuffleId) + try { + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + } catch { + case e: MetadataFetchFailedException => + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + mapStatuses.clear() + throw e + } + } + + /** + * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize + * on this array when reading it, because on the driver, we may be changing it in place. + * + * (It would be nice to remove this restriction in the future.) + */ + private def getStatuses(shuffleId: Int): Array[MapStatus] = { + val statuses = mapStatuses.get(shuffleId).orNull + if (statuses == null) { + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") + val startTime = System.currentTimeMillis + var fetchedStatuses: Array[MapStatus] = null + fetching.synchronized { + // Someone else is fetching it; wait for them to be done + while (fetching.contains(shuffleId)) { + try { + fetching.wait() + } catch { + case e: InterruptedException => + } + } + + // Either while we waited the fetch happened successfully, or + // someone fetched it in between the get and the fetching.synchronized. + fetchedStatuses = mapStatuses.get(shuffleId).orNull + if (fetchedStatuses == null) { + // We have to do the fetch, get others to wait for us. + fetching += shuffleId + } + } + + if (fetchedStatuses == null) { + // We won the race to fetch the statuses; do so + logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint) + // This try-finally prevents hangs due to timeouts: + try { + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) + fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) + logInfo("Got the output locations") + mapStatuses.put(shuffleId, fetchedStatuses) + } finally { + fetching.synchronized { + fetching -= shuffleId + fetching.notifyAll() + } + } + } + logDebug(s"Fetching map output statuses for shuffle $shuffleId took " + + s"${System.currentTimeMillis - startTime} ms") + + if (fetchedStatuses != null) { + fetchedStatuses + } else { + logError("Missing all output locations for shuffle " + shuffleId) + throw new MetadataFetchFailedException( + shuffleId, -1, "Missing all output locations for shuffle " + shuffleId) + } + } else { + statuses + } + } + + + /** Unregister shuffle data. */ + def unregisterShuffle(shuffleId: Int): Unit = { + mapStatuses.remove(shuffleId) + } + + /** + * Called from executors to update the epoch number, potentially clearing old outputs + * because of a fetch failure. Each executor task calls this with the latest epoch + * number on the driver at the time it was created. + */ + def updateEpoch(newEpoch: Long): Unit = { + epochLock.synchronized { + if (newEpoch > epoch) { + logInfo("Updating epoch to " + newEpoch + " and clearing cache") + epoch = newEpoch + mapStatuses.clear() + } + } + } } private[spark] object MapOutputTracker extends Logging { @@ -683,7 +767,7 @@ private[spark] object MapOutputTracker extends Logging { * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. */ - private def convertMapStatuses( + def convertMapStatuses( shuffleId: Int, startPartition: Int, endPartition: Int, diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index f83f5278e8b8..93a933720f22 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -32,6 +32,9 @@ import org.apache.spark.util.random.SamplingUtils /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. * Maps each key to a partition ID, from 0 to `numPartitions - 1`. + * + * Note that, partitioner must be deterministic, i.e. it must return the same partition id given + * the same partition key. */ abstract class Partitioner extends Serializable { def numPartitions: Int diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 36b17439aa81..47c51c0474f2 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -325,8 +325,14 @@ private[spark] class Executor( throw new TaskKilledException(killReason.get) } - logDebug("Task " + taskId + "'s epoch is " + task.epoch) - env.mapOutputTracker.updateEpoch(task.epoch) + // The purpose of updating the epoch here is to invalidate executor map output status cache + // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be + // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so + // we don't need to make any special calls here. + if (!isLocal) { + logDebug("Task " + taskId + "'s epoch is " + task.epoch) + env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch) + } // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index e4587c96eae1..15128f0913af 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -23,11 +23,22 @@ import org.apache.spark.{Partition, TaskContext} /** * An RDD that applies the provided function to every partition of the parent RDD. + * + * @param prev the parent RDD. + * @param f The function used to map a tuple of (TaskContext, partition index, input iterator) to + * an output iterator. + * @param preservesPartitioning Whether the input function preserves the partitioner, which should + * be `false` unless `prev` is a pair RDD and the input function + * doesn't modify the keys. + * @param isOrderSensitive whether or not the function is order-sensitive. If it's order + * sensitive, it may return totally different result when the input order + * is changed. Mostly stateful functions are order-sensitive. */ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( var prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) - preservesPartitioning: Boolean = false) + preservesPartitioning: Boolean = false, + isOrderSensitive: Boolean = false) extends RDD[U](prev) { override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None @@ -41,4 +52,12 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( super.clearDependencies() prev = null } + + override protected def getOutputDeterministicLevel = { + if (isOrderSensitive && prev.outputDeterministicLevel == DeterministicLevel.UNORDERED) { + DeterministicLevel.INDETERMINATE + } else { + super.getOutputDeterministicLevel + } + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 102836d5a558..4ff0f83263db 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -461,8 +461,9 @@ abstract class RDD[T: ClassTag]( // include a shuffle step so that our upstream tasks are still distributed new CoalescedRDD( - new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition), - new HashPartitioner(numPartitions)), + new ShuffledRDD[Int, T, T]( + mapPartitionsWithIndexInternal(distributePartition, isOrderSensitive = true), + new HashPartitioner(numPartitions)), numPartitions, partitionCoalescer).values } else { @@ -806,16 +807,21 @@ abstract class RDD[T: ClassTag]( * serializable and don't require closure cleaning. * * @param preservesPartitioning indicates whether the input function preserves the partitioner, - * which should be `false` unless this is a pair RDD and the input function doesn't modify - * the keys. + * which should be `false` unless this is a pair RDD and the input + * function doesn't modify the keys. + * @param isOrderSensitive whether or not the function is order-sensitive. If it's order + * sensitive, it may return totally different result when the input order + * is changed. Mostly stateful functions are order-sensitive. */ private[spark] def mapPartitionsWithIndexInternal[U: ClassTag]( f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = withScope { + preservesPartitioning: Boolean = false, + isOrderSensitive: Boolean = false): RDD[U] = withScope { new MapPartitionsRDD( this, (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter), - preservesPartitioning) + preservesPartitioning = preservesPartitioning, + isOrderSensitive = isOrderSensitive) } /** @@ -1634,6 +1640,16 @@ abstract class RDD[T: ClassTag]( } } + /** + * Return whether this RDD is reliably checkpointed and materialized. + */ + private[rdd] def isReliablyCheckpointed: Boolean = { + checkpointData match { + case Some(reliable: ReliableRDDCheckpointData[_]) if reliable.isCheckpointed => true + case _ => false + } + } + /** * Gets the name of the directory to which this RDD was checkpointed. * This is not defined if the RDD is checkpointed locally. @@ -1838,6 +1854,63 @@ abstract class RDD[T: ClassTag]( def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) } + + /** + * Returns the deterministic level of this RDD's output. Please refer to [[DeterministicLevel]] + * for the definition. + * + * By default, an reliably checkpointed RDD, or RDD without parents(root RDD) is DETERMINATE. For + * RDDs with parents, we will generate a deterministic level candidate per parent according to + * the dependency. The deterministic level of the current RDD is the deterministic level + * candidate that is deterministic least. Please override [[getOutputDeterministicLevel]] to + * provide custom logic of calculating output deterministic level. + */ + // TODO: make it public so users can set deterministic level to their custom RDDs. + // TODO: this can be per-partition. e.g. UnionRDD can have different deterministic level for + // different partitions. + private[spark] final lazy val outputDeterministicLevel: DeterministicLevel.Value = { + if (isReliablyCheckpointed) { + DeterministicLevel.DETERMINATE + } else { + getOutputDeterministicLevel + } + } + + @DeveloperApi + protected def getOutputDeterministicLevel: DeterministicLevel.Value = { + val deterministicLevelCandidates = dependencies.map { + // The shuffle is not really happening, treat it like narrow dependency and assume the output + // deterministic level of current RDD is same as parent. + case dep: ShuffleDependency[_, _, _] if dep.rdd.partitioner.exists(_ == dep.partitioner) => + dep.rdd.outputDeterministicLevel + + case dep: ShuffleDependency[_, _, _] => + if (dep.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) { + // If map output was indeterminate, shuffle output will be indeterminate as well + DeterministicLevel.INDETERMINATE + } else if (dep.keyOrdering.isDefined && dep.aggregator.isDefined) { + // if aggregator specified (and so unique keys) and key ordering specified - then + // consistent ordering. + DeterministicLevel.DETERMINATE + } else { + // In Spark, the reducer fetches multiple remote shuffle blocks at the same time, and + // the arrival order of these shuffle blocks are totally random. Even if the parent map + // RDD is DETERMINATE, the reduce RDD is always UNORDERED. + DeterministicLevel.UNORDERED + } + + // For narrow dependency, assume the output deterministic level of current RDD is same as + // parent. + case dep => dep.rdd.outputDeterministicLevel + } + + if (deterministicLevelCandidates.isEmpty) { + // By default we assume the root RDD is determinate. + DeterministicLevel.DETERMINATE + } else { + deterministicLevelCandidates.maxBy(_.id) + } + } } @@ -1891,3 +1964,18 @@ object RDD { new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) } } + +/** + * The deterministic level of RDD's output (i.e. what `RDD#compute` returns). This explains how + * the output will diff when Spark reruns the tasks for the RDD. There are 3 deterministic levels: + * 1. DETERMINATE: The RDD output is always the same data set in the same order after a rerun. + * 2. UNORDERED: The RDD output is always the same data set but the order can be different + * after a rerun. + * 3. INDETERMINATE. The RDD output can be different after a rerun. + * + * Note that, the output of an RDD usually relies on the parent RDDs. When the parent RDD's output + * is INDETERMINATE, it's very likely the RDD's output is also INDETERMINATE. + */ +private[spark] object DeterministicLevel extends Enumeration { + val DETERMINATE, UNORDERED, INDETERMINATE = Value +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 099bc2e7c2cb..cb6cdcd192a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -328,25 +328,14 @@ class DAGScheduler( val numTasks = rdd.partitions.length val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() - val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep) + val stage = new ShuffleMapStage( + id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker) stageIdToStage(id) = stage shuffleIdToMapStage(shuffleDep.shuffleId) = stage updateJobIdStageIdMaps(jobId, stage) - if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { - // A previously run stage generated partitions for this shuffle, so for each output - // that's still available, copy information about that output location to the new stage - // (so we don't unnecessarily re-compute that data). - val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) - val locs = MapOutputTracker.deserializeMapStatuses(serLocs) - (0 until locs.length).foreach { i => - if (locs(i) ne null) { - // locs(i) will be null if missing - stage.addOutputLoc(i, locs(i)) - } - } - } else { + if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")") @@ -1240,7 +1229,8 @@ class DAGScheduler( // The epoch of the task is acceptable (i.e., the task was launched after the most // recent failure we're aware of for the executor), so mark the task's output as // available. - shuffleStage.addOutputLoc(smt.partitionId, status) + mapOutputTracker.registerMapOutput( + shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) // Remove the task's partition from pending partitions. This may have already been // done above, but will not have been done yet in cases where the task attempt was // from an earlier attempt of the stage (i.e., not the attempt that's currently @@ -1257,16 +1247,14 @@ class DAGScheduler( logInfo("waiting: " + waitingStages) logInfo("failed: " + failedStages) - // We supply true to increment the epoch number here in case this is a - // recomputation of the map outputs. In that case, some nodes may have cached - // locations with holes (from when we detected the error) and will need the - // epoch incremented to refetch them. - // TODO: Only increment the epoch number if this is not the first time - // we registered these map outputs. - mapOutputTracker.registerMapOutputs( - shuffleStage.shuffleDep.shuffleId, - shuffleStage.outputLocInMapOutputTrackerFormat(), - changeEpoch = true) + // This call to increment the epoch may not be strictly necessary, but it is retained + // for now in order to minimize the changes in behavior from an earlier version of the + // code. This existing behavior of always incrementing the epoch following any + // successful shuffle map stage completion may have benefits by causing unneeded + // cached map outputs to be cleaned up earlier on executors. In the future we can + // consider removing this call, but this will require some extra investigation. + // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details. + mapOutputTracker.incrementEpoch() clearCacheLocs() @@ -1344,6 +1332,63 @@ class DAGScheduler( failedStages += failedStage failedStages += mapStage if (noResubmitEnqueued) { + // If the map stage is INDETERMINATE, which means the map tasks may return + // different result when re-try, we need to re-try all the tasks of the failed + // stage and its succeeding stages, because the input data will be changed after the + // map tasks are re-tried. + // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is + // guaranteed to be determinate, so the input data of the reducers will not change + // even if the map tasks are re-tried. + if (mapStage.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) { + // It's a little tricky to find all the succeeding stages of `failedStage`, because + // each stage only know its parents not children. Here we traverse the stages from + // the leaf nodes (the result stages of active jobs), and rollback all the stages + // in the stage chains that connect to the `failedStage`. To speed up the stage + // traversing, we collect the stages to rollback first. If a stage needs to + // rollback, all its succeeding stages need to rollback to. + val stagesToRollback = scala.collection.mutable.HashSet(failedStage) + + def collectStagesToRollback(stageChain: List[Stage]): Unit = { + if (stagesToRollback.contains(stageChain.head)) { + stageChain.drop(1).foreach(s => stagesToRollback += s) + } else { + stageChain.head.parents.foreach { s => + collectStagesToRollback(s :: stageChain) + } + } + } + + def generateErrorMessage(stage: Stage): String = { + "A shuffle map stage with indeterminate output was failed and retried. " + + s"However, Spark cannot rollback the $stage to re-process the input data, " + + "and has to fail this job. Please eliminate the indeterminacy by " + + "checkpointing the RDD before repartition and try again." + } + + activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil)) + + stagesToRollback.foreach { + case mapStage: ShuffleMapStage => + val numMissingPartitions = mapStage.findMissingPartitions().length + if (numMissingPartitions < mapStage.numTasks) { + // TODO: support to rollback shuffle files. + // Currently the shuffle writing is "first write wins", so we can't re-run a + // shuffle map stage and overwrite existing shuffle files. We have to finish + // SPARK-8029 first. + abortStage(mapStage, generateErrorMessage(mapStage), None) + } + + case resultStage: ResultStage if resultStage.activeJob.isDefined => + val numMissingPartitions = resultStage.findMissingPartitions().length + if (numMissingPartitions < resultStage.numTasks) { + // TODO: support to rollback result tasks. + abortStage(resultStage, generateErrorMessage(resultStage), None) + } + + case _ => + } + } + // We expect one executor failure to trigger many FetchFailures in rapid succession, // but all of those task failures can typically be handled by a single resubmission of // the failed stage. We avoid flooding the scheduler's event queue with resubmit @@ -1367,7 +1412,6 @@ class DAGScheduler( } // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) } @@ -1416,17 +1460,7 @@ class DAGScheduler( if (filesLost || !env.blockManager.externalShuffleServiceEnabled) { logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch)) - // TODO: This will be really slow if we keep accumulating shuffle map stages - for ((shuffleId, stage) <- shuffleIdToMapStage) { - stage.removeOutputsOnExecutor(execId) - mapOutputTracker.registerMapOutputs( - shuffleId, - stage.outputLocInMapOutputTrackerFormat(), - changeEpoch = true) - } - if (shuffleIdToMapStage.isEmpty) { - mapOutputTracker.incrementEpoch() - } + mapOutputTracker.removeOutputsOnExecutor(execId) clearCacheLocs() } } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index db4d9efa2270..05f650fbf5df 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -19,9 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashSet -import org.apache.spark.ShuffleDependency +import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency, SparkEnv} import org.apache.spark.rdd.RDD -import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.CallSite /** @@ -42,13 +41,12 @@ private[spark] class ShuffleMapStage( parents: List[Stage], firstJobId: Int, callSite: CallSite, - val shuffleDep: ShuffleDependency[_, _, _]) + val shuffleDep: ShuffleDependency[_, _, _], + mapOutputTrackerMaster: MapOutputTrackerMaster) extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { private[this] var _mapStageJobs: List[ActiveJob] = Nil - private[this] var _numAvailableOutputs: Int = 0 - /** * Partitions that either haven't yet been computed, or that were computed on an executor * that has since been lost, so should be re-computed. This variable is used by the @@ -60,13 +58,6 @@ private[spark] class ShuffleMapStage( */ val pendingPartitions = new HashSet[Int] - /** - * List of [[MapStatus]] for each partition. The index of the array is the map partition id, - * and each value in the array is the list of possible [[MapStatus]] for a partition - * (a single task might run multiple times). - */ - private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) - override def toString: String = "ShuffleMapStage " + id /** @@ -88,69 +79,18 @@ private[spark] class ShuffleMapStage( /** * Number of partitions that have shuffle outputs. * When this reaches [[numPartitions]], this map stage is ready. - * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`. */ - def numAvailableOutputs: Int = _numAvailableOutputs + def numAvailableOutputs: Int = mapOutputTrackerMaster.getNumAvailableOutputs(shuffleDep.shuffleId) /** * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs. - * This should be the same as `outputLocs.contains(Nil)`. */ - def isAvailable: Boolean = _numAvailableOutputs == numPartitions + def isAvailable: Boolean = numAvailableOutputs == numPartitions /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ override def findMissingPartitions(): Seq[Int] = { - val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty) - assert(missing.size == numPartitions - _numAvailableOutputs, - s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}") - missing - } - - def addOutputLoc(partition: Int, status: MapStatus): Unit = { - val prevList = outputLocs(partition) - outputLocs(partition) = status :: prevList - if (prevList == Nil) { - _numAvailableOutputs += 1 - } - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - _numAvailableOutputs -= 1 - } - } - - /** - * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned - * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition, - * that position is filled with null. - */ - def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = { - outputLocs.map(_.headOption.orNull) - } - - /** - * Removes all shuffle outputs associated with this executor. Note that this will also remove - * outputs which are served by an external shuffle server (if one exists), as they are still - * registered with this execId. - */ - def removeOutputsOnExecutor(execId: String): Unit = { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.location.executorId == execId) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - _numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format( - this, execId, _numAvailableOutputs, numPartitions, isAvailable)) - } + mapOutputTrackerMaster + .findMissingPartitions(shuffleDep.shuffleId) + .getOrElse(0 until numPartitions) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index f8c62b4bea83..bc0d4700bb76 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -129,7 +129,7 @@ private[spark] class TaskSchedulerImpl private[scheduler]( var backend: SchedulerBackend = null - val mapOutputTracker = SparkEnv.get.mapOutputTracker + val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index ca94fd1560b3..82b6fd1a7469 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -138,21 +138,21 @@ class MapOutputTrackerSuite extends SparkFunSuite { slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) + // This is expected to fail because no outputs have been registered for the shuffle. intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0) === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) + val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) - masterTracker.incrementEpoch() + assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput) slaveTracker.updateEpoch(masterTracker.getEpoch) intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 3b564df8a44b..62c40d13ad10 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -333,6 +333,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val shuffleMapRdd = new MyRDD(sc, 1, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep) + mapTrackerMaster.registerShuffle(0, 1) // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, @@ -367,7 +368,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // register one of the map outputs -- doesn't matter which one mapOutput1.foreach { case mapStatus => - mapTrackerMaster.registerMapOutputs(0, Array(mapStatus)) + mapTrackerMaster.registerMapOutput(0, 0, mapStatus) } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index 2b18ebee79a2..571c6bbb4585 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -86,7 +86,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M sc = new SparkContext(conf) val scheduler = mock[TaskSchedulerImpl] when(scheduler.sc).thenReturn(sc) - when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker) + when(scheduler.mapOutputTracker).thenReturn( + SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]) scheduler } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 9112065c9f1f..1fff0d0a2f94 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -30,7 +30,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} @@ -56,6 +56,20 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) } +class MyCheckpointRDD( + sc: SparkContext, + numPartitions: Int, + dependencies: List[Dependency[_]], + locations: Seq[Seq[String]] = Nil, + @(transient @param) tracker: MapOutputTrackerMaster = null, + indeterminate: Boolean = false) + extends MyRDD(sc, numPartitions, dependencies, locations, tracker, indeterminate) { + + // Allow doCheckpoint() on this RDD. + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + Iterator.empty +} + /** * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and * preferredLocations (if any) that are passed to them. They are deliberately not executable @@ -70,7 +84,8 @@ class MyRDD( numPartitions: Int, dependencies: List[Dependency[_]], locations: Seq[Seq[String]] = Nil, - @(transient @param) tracker: MapOutputTrackerMaster = null) + @(transient @param) tracker: MapOutputTrackerMaster = null, + indeterminate: Boolean = false) extends RDD[(Int, Int)](sc, dependencies) with Serializable { override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = @@ -80,6 +95,10 @@ class MyRDD( override def index: Int = i }).toArray + override protected def getOutputDeterministicLevel = { + if (indeterminate) DeterministicLevel.INDETERMINATE else super.getOutputDeterministicLevel + } + override def getPreferredLocations(partition: Partition): Seq[String] = { if (locations.isDefinedAt(partition.index)) { locations(partition.index) @@ -2307,6 +2326,152 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou } } + test("SPARK-23207: retry all the succeeding stages when the map stage is indeterminate") { + val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) + + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) + val shuffleId1 = shuffleDep1.shuffleId + val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker) + + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) + val shuffleId2 = shuffleDep2.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1)) + + // Finish the first shuffle map stage. + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + + // Finish the second shuffle map stage. + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostC", 2)), + (Success, makeMapStatus("hostD", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + + // The first task of the final stage failed with fetch failure + runEvent(makeCompletionEvent( + taskSets(2).tasks(0), + FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0, 0, "ignored"), + null)) + + val failedStages = scheduler.failedStages.toSeq + assert(failedStages.length == 2) + // Shuffle blocks of "hostC" is lost, so first task of the `shuffleMapRdd2` needs to retry. + assert(failedStages.collect { + case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage + }.head.findMissingPartitions() == Seq(0)) + // The result stage is still waiting for its 2 tasks to complete + assert(failedStages.collect { + case stage: ResultStage => stage + }.head.findMissingPartitions() == Seq(0, 1)) + + scheduler.resubmitFailedStages() + + // The first task of the `shuffleMapRdd2` failed with fetch failure + runEvent(makeCompletionEvent( + taskSets(3).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0, 0, "ignored"), + null)) + + // The job should fail because Spark can't rollback the shuffle map stage. + assert(failure != null && failure.getMessage.contains("Spark cannot rollback")) + } + + private def assertResultStageFailToRollback(mapRdd: MyRDD): Unit = { + val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1)) + + completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, numShufflePartitions = 2) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // Finish the first task of the result stage + runEvent(makeCompletionEvent( + taskSets.last.tasks(0), Success, 42, + Seq.empty, createFakeTaskInfoWithId(0))) + + // Fail the second task with FetchFailed. + runEvent(makeCompletionEvent( + taskSets.last.tasks(1), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null)) + + // The job should fail because Spark can't rollback the result stage. + assert(failure != null && failure.getMessage.contains("Spark cannot rollback")) + } + + test("SPARK-23207: cannot rollback a result stage") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil, indeterminate = true) + assertResultStageFailToRollback(shuffleMapRdd) + } + + test("SPARK-23207: local checkpoint fail to rollback (checkpointed before)") { + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) + shuffleMapRdd.localCheckpoint() + shuffleMapRdd.doCheckpoint() + assertResultStageFailToRollback(shuffleMapRdd) + } + + test("SPARK-23207: local checkpoint fail to rollback (checkpointing now)") { + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) + shuffleMapRdd.localCheckpoint() + assertResultStageFailToRollback(shuffleMapRdd) + } + + private def assertResultStageNotRollbacked(mapRdd: MyRDD): Unit = { + val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1)) + + completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, numShufflePartitions = 2) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // Finish the first task of the result stage + runEvent(makeCompletionEvent( + taskSets.last.tasks(0), Success, 42, + Seq.empty, createFakeTaskInfoWithId(0))) + + // Fail the second task with FetchFailed. + runEvent(makeCompletionEvent( + taskSets.last.tasks(1), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null)) + + assert(failure == null, "job should not fail") + val failedStages = scheduler.failedStages.toSeq + assert(failedStages.length == 2) + // Shuffle blocks of "hostA" is lost, so first task of the `shuffleMapRdd2` needs to retry. + assert(failedStages.collect { + case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId => stage + }.head.findMissingPartitions() == Seq(0)) + // The first task of result stage remains completed. + assert(failedStages.collect { + case stage: ResultStage => stage + }.head.findMissingPartitions() == Seq(1)) + } + + test("SPARK-23207: reliable checkpoint can avoid rollback (checkpointed before)") { + sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath) + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) + shuffleMapRdd.checkpoint() + shuffleMapRdd.doCheckpoint() + assertResultStageNotRollbacked(shuffleMapRdd) + } + + test("SPARK-23207: reliable checkpoint fail to rollback (checkpointing now)") { + sc.setCheckpointDir(Utils.createTempDir().getCanonicalPath) + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil, indeterminate = true) + shuffleMapRdd.checkpoint() + assertResultStageFailToRollback(shuffleMapRdd) + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index c0ba5135e5ed..4496afb4ed71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -247,6 +247,9 @@ object ShuffleExchange { case _ => sys.error(s"Exchange not implemented for $newPartitioning") } + val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] && + newPartitioning.numPartitions > 1 + val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic, // otherwise a retry task may output different rows and thus lead to data loss. @@ -256,9 +259,7 @@ object ShuffleExchange { // // Note that we don't perform local sort if the new partitioning has only 1 partition, under // that case all output rows go to the same partition. - val newRdd = if (SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION) && - newPartitioning.numPartitions > 1 && - newPartitioning.isInstanceOf[RoundRobinPartitioning]) { + val newRdd = if (isRoundRobin && SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION)) { rdd.mapPartitionsInternal { iter => val recordComparatorSupplier = new Supplier[RecordComparator] { override def get: RecordComparator = new RecordBinaryComparator() @@ -294,17 +295,19 @@ object ShuffleExchange { rdd } + // round-robin function is order sensitive if we don't sort the input. + val isOrderSensitive = isRoundRobin && !SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION) if (needToCopyObjectsBeforeShuffle(part, serializer)) { - newRdd.mapPartitionsInternal { iter => + newRdd.mapPartitionsWithIndexInternal((_, iter) => { val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } - } + }, isOrderSensitive = isOrderSensitive) } else { - newRdd.mapPartitionsInternal { iter => + newRdd.mapPartitionsWithIndexInternal((_, iter) => { val getPartitionKey = getPartitionKeyExtractor() val mutablePair = new MutablePair[Int, InternalRow]() iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } - } + }, isOrderSensitive = isOrderSensitive) } }