
Commit 3476390

[SPARK-20715] Store MapStatuses only in MapOutputTracker, not ShuffleMapStage
## What changes were proposed in this pull request?

This PR refactors `ShuffleMapStage` and `MapOutputTracker` in order to simplify the management of `MapStatuses`, reduce driver memory consumption, and remove a potential source of scheduler correctness bugs.

### Background

In Spark there are currently two places where MapStatuses are tracked:

- The `MapOutputTracker` maintains an `Array[MapStatus]` storing a single location for each map output. This mapping is used by the `DAGScheduler` for determining reduce-task locality preferences (when locality-aware reduce task scheduling is enabled) and is also used to serve map output locations to executors / tasks.
- Each `ShuffleMapStage` also contains a mapping of `Array[List[MapStatus]]` which holds the complete set of locations where each map output could be available. This mapping is used to determine which map tasks need to be run when constructing `TaskSets` for the stage.

This duplication adds complexity and creates the potential for certain types of correctness bugs. Bad things can happen if these two copies of the map output locations get out of sync. For instance, if the `MapOutputTracker` is missing locations for a map output but `ShuffleMapStage` believes that locations are available, then tasks will fail with `MetadataFetchFailedException` but `ShuffleMapStage` will not be updated to reflect the missing map outputs, leading to situations where the stage will be reattempted (because downstream stages experienced fetch failures) but no task sets will be launched (because `ShuffleMapStage` thinks all maps are available).

I observed this behavior in a real-world deployment. I'm still not quite sure how the state got out of sync in the first place, but we can completely avoid this class of bug if we eliminate the duplicate state.

### Why we only need to track a single location for each map output

I think that storing an `Array[List[MapStatus]]` in `ShuffleMapStage` is unnecessary.

First, note that this adds memory/object bloat to the driver: we need one extra `List` per task. If you have millions of tasks across all stages then this can add up to be a significant amount of resources.

Secondly, I believe that it's extremely uncommon that these lists will ever contain more than one entry. It's not impossible, but it is very unlikely given the conditions which must occur for that to happen:

- In normal operation (no task failures) we'll only run each task once and thus will have at most one output.
- If speculation is enabled then it's possible that we'll have multiple attempts of a task. The TaskSetManager will [kill duplicate attempts of a task](https://github.com/apache/spark/blob/04901dd03a3f8062fd39ea38d585935ff71a9248/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L717) after a task finishes successfully, reducing the likelihood that both the original and speculated task will successfully register map outputs.
- There is a [comment in `TaskSetManager`](https://github.com/apache/spark/blob/04901dd03a3f8062fd39ea38d585935ff71a9248/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L113) which suggests that running tasks are not killed if a task set becomes a zombie. However:
  - If the task set becomes a zombie due to the job being cancelled then it doesn't matter whether we record map outputs.
  - If the task set became a zombie because of a stage failure (e.g. the map stage itself had a fetch failure from an upstream map stage) then I believe that the "failedEpoch" will be updated, which may cause map outputs from still-running tasks to [be ignored](https://github.com/apache/spark/blob/04901dd03a3f8062fd39ea38d585935ff71a9248/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L1213). (I'm not 100% sure on this point, though.)
- Even if you _do_ manage to record multiple map outputs for a stage, only a single map output is reported to / tracked by the MapOutputTracker. The only situation where the additional output locations could actually be read or used would be if a task experienced a `FetchFailure` exception. The most likely cause of a `FetchFailure` is a lost executor, which will most likely have caused the loss of several map tasks' outputs, so saving the potential re-execution of a single map task isn't a huge win if we're going to have to recompute several other lost map outputs from other tasks which ran on that lost executor. Also note that the re-population of `MapOutputTracker` state from the state in `ShuffleMapStage` only happens after the reduce stage has failed; the additional locations don't help to prevent FetchFailures but, instead, can only reduce the amount of work when recomputing missing parent stages.

Given this, this patch chooses to do away with tracking multiple locations for map outputs and instead stores only a single location. This change removes the main distinction between the `ShuffleMapStage` and `MapOutputTracker`'s copies of this state, paving the way for storing it only in the `MapOutputTracker`.

### Overview of other changes

- Significantly simplified the cache / lock management inside of the `MapOutputTrackerMaster`:
  - The old code had several parallel `HashMap`s which had to be guarded by maps of `Object`s which were used as locks. This code was somewhat complicated to follow.
  - The new code uses a new `ShuffleStatus` class to group together all of the state associated with a particular shuffle, including cached serialized map statuses, significantly simplifying the logic (a minimal sketch follows this commit message).
- Moved more code out of the shared `MapOutputTracker` abstract base class and into the `MapOutputTrackerMaster` and `MapOutputTrackerWorker` subclasses. This makes it easier to reason about which functionality needs to be supported only on the driver or the executor.
- Removed a bunch of code from the `DAGScheduler` which was used to synchronize information from the `MapOutputTracker` to `ShuffleMapStage`.
- Added comments to clarify the role of `MapOutputTrackerMaster`'s `epoch` in invalidating executor-side shuffle map output caches.

I will comment on these changes via inline GitHub review comments.

/cc hvanhovell and rxin (whom I discussed this with offline), tgravescs (who recently worked on caching of serialized MapOutputStatuses), and kayousterhout and markhamstra (for scheduler changes).

## How was this patch tested?

Existing tests. I purposely avoided making interface / API changes which would have required significant updates or modifications to test code.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #17955 from JoshRosen/map-output-tracker-rewrite.
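To make the `ShuffleStatus` refactoring concrete, here is a minimal sketch of the per-shuffle state container described above. Only the class name `ShuffleStatus` and the single-location-per-output model come from the commit description; the field and method names below are assumptions, not Spark's actual code. The point is that all state for one shuffle, including the serialized-status cache, sits behind a single lock, replacing the parallel `HashMap`s guarded by maps of lock `Object`s:

```scala
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId

// Sketch only: groups all state for one shuffle behind a single lock.
private class ShuffleStatus(numPartitions: Int) {
  // At most one MapStatus per map partition; null means the output is missing.
  private[this] val mapStatuses = new Array[MapStatus](numPartitions)
  // Kept consistent with the number of non-null entries in mapStatuses.
  private[this] var _numAvailableOutputs: Int = 0
  // Cached serialized statuses; must be dropped whenever mapStatuses changes.
  private[this] var cachedSerializedStatuses: Array[Byte] = null

  def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
    if (mapStatuses(mapId) == null) _numAvailableOutputs += 1
    mapStatuses(mapId) = status
    cachedSerializedStatuses = null
  }

  def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
    if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
      _numAvailableOutputs -= 1
      mapStatuses(mapId) = null
      cachedSerializedStatuses = null
    }
  }

  // Mirrors the logic deleted from ShuffleMapStage below: drop every output
  // registered on the given executor (including outputs served by an external
  // shuffle service, which are still registered under this execId).
  def removeOutputsOnExecutor(execId: String): Unit = synchronized {
    for (mapId <- 0 until numPartitions) {
      val status = mapStatuses(mapId)
      if (status != null && status.location.executorId == execId) {
        removeMapOutput(mapId, status.location)
      }
    }
  }

  def numAvailableOutputs: Int = synchronized { _numAvailableOutputs }

  def findMissingPartitions(): Seq[Int] = synchronized {
    (0 until numPartitions).filter(mapStatuses(_) == null)
  }
}
```

With this shape, the stage-side queries in the diffs below become trivial delegations: `numAvailableOutputs` is a counter read and `findMissingPartitions` is a filter over the single array.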
1 parent f48273c commit 3476390

8 files changed: 398 additions & 389 deletions


core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 360 additions & 276 deletions
Large diffs are not rendered by default.

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 8 additions & 2 deletions
```diff
@@ -322,8 +322,14 @@ private[spark] class Executor(
           throw new TaskKilledException(killReason.get)
         }
 
-        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
-        env.mapOutputTracker.updateEpoch(task.epoch)
+        // The purpose of updating the epoch here is to invalidate executor map output status cache
+        // in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
+        // MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
+        // we don't need to make any special calls here.
+        if (!isLocal) {
+          logDebug("Task " + taskId + "'s epoch is " + task.epoch)
+          env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
+        }
 
         // Run the actual task and measure its runtime.
         taskStart = System.currentTimeMillis()
```
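For context on the comment introduced above, here is a hedged sketch of the worker-side epoch check it relies on. The real logic lives in `MapOutputTrackerWorker`; the class and field names below are assumptions. When a task arrives carrying a newer epoch than the executor has seen, the executor drops its cached map output statuses so that locations invalidated by a `FetchFailure` are re-fetched from the driver:

```scala
import scala.collection.mutable
import org.apache.spark.scheduler.MapStatus

// Illustrative only, not the actual MapOutputTrackerWorker code.
private class WorkerEpochCacheSketch {
  private var epoch: Long = 0
  private val mapStatuses = mutable.Map.empty[Int, Array[MapStatus]]

  def updateEpoch(newEpoch: Long): Unit = synchronized {
    if (newEpoch > epoch) {
      epoch = newEpoch
      // Locations cached before the failure may be stale; drop them all and
      // re-fetch from the driver on demand.
      mapStatuses.clear()
    }
  }
}
```

This also explains the `isLocal` guard: in local mode the tracker is the master itself, whose cache is invalidated directly as outputs change rather than via epoch comparisons.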

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 14 additions & 37 deletions
```diff
@@ -328,25 +328,14 @@ class DAGScheduler(
     val numTasks = rdd.partitions.length
     val parents = getOrCreateParentStages(rdd, jobId)
     val id = nextStageId.getAndIncrement()
-    val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep)
+    val stage = new ShuffleMapStage(
+      id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker)
 
     stageIdToStage(id) = stage
     shuffleIdToMapStage(shuffleDep.shuffleId) = stage
     updateJobIdStageIdMaps(jobId, stage)
 
-    if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
-      // A previously run stage generated partitions for this shuffle, so for each output
-      // that's still available, copy information about that output location to the new stage
-      // (so we don't unnecessarily re-compute that data).
-      val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
-      val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
-      (0 until locs.length).foreach { i =>
-        if (locs(i) ne null) {
-          // locs(i) will be null if missing
-          stage.addOutputLoc(i, locs(i))
-        }
-      }
-    } else {
+    if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
       // Kind of ugly: need to register RDDs with the cache and map output tracker here
       // since we can't do it in the RDD constructor because # of partitions is unknown
       logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
@@ -1217,7 +1206,8 @@
               // The epoch of the task is acceptable (i.e., the task was launched after the most
               // recent failure we're aware of for the executor), so mark the task's output as
               // available.
-              shuffleStage.addOutputLoc(smt.partitionId, status)
+              mapOutputTracker.registerMapOutput(
+                shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
               // Remove the task's partition from pending partitions. This may have already been
               // done above, but will not have been done yet in cases where the task attempt was
               // from an earlier attempt of the stage (i.e., not the attempt that's currently
@@ -1234,16 +1224,14 @@
               logInfo("waiting: " + waitingStages)
               logInfo("failed: " + failedStages)
 
-              // We supply true to increment the epoch number here in case this is a
-              // recomputation of the map outputs. In that case, some nodes may have cached
-              // locations with holes (from when we detected the error) and will need the
-              // epoch incremented to refetch them.
-              // TODO: Only increment the epoch number if this is not the first time
-              // we registered these map outputs.
-              mapOutputTracker.registerMapOutputs(
-                shuffleStage.shuffleDep.shuffleId,
-                shuffleStage.outputLocInMapOutputTrackerFormat(),
-                changeEpoch = true)
+              // This call to increment the epoch may not be strictly necessary, but it is retained
+              // for now in order to minimize the changes in behavior from an earlier version of the
+              // code. This existing behavior of always incrementing the epoch following any
+              // successful shuffle map stage completion may have benefits by causing unneeded
+              // cached map outputs to be cleaned up earlier on executors. In the future we can
+              // consider removing this call, but this will require some extra investigation.
+              // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details.
+              mapOutputTracker.incrementEpoch()
 
               clearCacheLocs()
 
@@ -1343,7 +1331,6 @@
         }
         // Mark the map whose fetch failed as broken in the map stage
         if (mapId != -1) {
-          mapStage.removeOutputLoc(mapId, bmAddress)
           mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
         }
 
@@ -1393,17 +1380,7 @@
 
       if (filesLost || !env.blockManager.externalShuffleServiceEnabled) {
         logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
-        // TODO: This will be really slow if we keep accumulating shuffle map stages
-        for ((shuffleId, stage) <- shuffleIdToMapStage) {
-          stage.removeOutputsOnExecutor(execId)
-          mapOutputTracker.registerMapOutputs(
-            shuffleId,
-            stage.outputLocInMapOutputTrackerFormat(),
-            changeEpoch = true)
-        }
-        if (shuffleIdToMapStage.isEmpty) {
-          mapOutputTracker.incrementEpoch()
-        }
+        mapOutputTracker.removeOutputsOnExecutor(execId)
         clearCacheLocs()
       }
     } else {
```
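The single `removeOutputsOnExecutor` call that replaces the deleted per-stage loop is possible because all per-shuffle state now lives in the tracker. Here is a sketch of what the consolidated driver-side method might look like, reusing the hypothetical `ShuffleStatus` container from the sketch after the commit message (the real method lives in `MapOutputTrackerMaster`; the `shuffleStatuses` field is an assumption):

```scala
import scala.collection.concurrent.TrieMap

// Illustrative only, not Spark's exact code.
private class ExecutorLossSketch {
  private val shuffleStatuses = TrieMap.empty[Int, ShuffleStatus]
  private var epoch: Long = 0

  def incrementEpoch(): Unit = synchronized { epoch += 1 }

  // One pass over all registered shuffles replaces the old DAGScheduler loop
  // over shuffleIdToMapStage; a single epoch bump then tells executors to
  // discard their cached (now possibly stale) map output locations.
  def removeOutputsOnExecutor(execId: String): Unit = {
    shuffleStatuses.valuesIterator.foreach(_.removeOutputsOnExecutor(execId))
    incrementEpoch()
  }
}
```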

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala

Lines changed: 8 additions & 68 deletions
```diff
@@ -19,9 +19,8 @@ package org.apache.spark.scheduler
 
 import scala.collection.mutable.HashSet
 
-import org.apache.spark.ShuffleDependency
+import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency, SparkEnv}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.CallSite
 
 /**
@@ -42,13 +41,12 @@ private[spark] class ShuffleMapStage(
     parents: List[Stage],
     firstJobId: Int,
     callSite: CallSite,
-    val shuffleDep: ShuffleDependency[_, _, _])
+    val shuffleDep: ShuffleDependency[_, _, _],
+    mapOutputTrackerMaster: MapOutputTrackerMaster)
   extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) {
 
   private[this] var _mapStageJobs: List[ActiveJob] = Nil
 
-  private[this] var _numAvailableOutputs: Int = 0
-
   /**
    * Partitions that either haven't yet been computed, or that were computed on an executor
    * that has since been lost, so should be re-computed. This variable is used by the
@@ -60,13 +58,6 @@ private[spark] class ShuffleMapStage(
    */
   val pendingPartitions = new HashSet[Int]
 
-  /**
-   * List of [[MapStatus]] for each partition. The index of the array is the map partition id,
-   * and each value in the array is the list of possible [[MapStatus]] for a partition
-   * (a single task might run multiple times).
-   */
-  private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
-
   override def toString: String = "ShuffleMapStage " + id
 
   /**
@@ -88,69 +79,18 @@ private[spark] class ShuffleMapStage(
   /**
    * Number of partitions that have shuffle outputs.
    * When this reaches [[numPartitions]], this map stage is ready.
-   * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`.
    */
-  def numAvailableOutputs: Int = _numAvailableOutputs
+  def numAvailableOutputs: Int = mapOutputTrackerMaster.getNumAvailableOutputs(shuffleDep.shuffleId)
 
   /**
    * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs.
-   * This should be the same as `outputLocs.contains(Nil)`.
    */
-  def isAvailable: Boolean = _numAvailableOutputs == numPartitions
+  def isAvailable: Boolean = numAvailableOutputs == numPartitions
 
   /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */
   override def findMissingPartitions(): Seq[Int] = {
-    val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
-    assert(missing.size == numPartitions - _numAvailableOutputs,
-      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
-    missing
-  }
-
-  def addOutputLoc(partition: Int, status: MapStatus): Unit = {
-    val prevList = outputLocs(partition)
-    outputLocs(partition) = status :: prevList
-    if (prevList == Nil) {
-      _numAvailableOutputs += 1
-    }
-  }
-
-  def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = {
-    val prevList = outputLocs(partition)
-    val newList = prevList.filterNot(_.location == bmAddress)
-    outputLocs(partition) = newList
-    if (prevList != Nil && newList == Nil) {
-      _numAvailableOutputs -= 1
-    }
-  }
-
-  /**
-   * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned
-   * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition,
-   * that position is filled with null.
-   */
-  def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = {
-    outputLocs.map(_.headOption.orNull)
-  }
-
-  /**
-   * Removes all shuffle outputs associated with this executor. Note that this will also remove
-   * outputs which are served by an external shuffle server (if one exists), as they are still
-   * registered with this execId.
-   */
-  def removeOutputsOnExecutor(execId: String): Unit = {
-    var becameUnavailable = false
-    for (partition <- 0 until numPartitions) {
-      val prevList = outputLocs(partition)
-      val newList = prevList.filterNot(_.location.executorId == execId)
-      outputLocs(partition) = newList
-      if (prevList != Nil && newList == Nil) {
-        becameUnavailable = true
-        _numAvailableOutputs -= 1
-      }
-    }
-    if (becameUnavailable) {
-      logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
-        this, execId, _numAvailableOutputs, numPartitions, isAvailable))
-    }
+    mapOutputTrackerMaster
+      .findMissingPartitions(shuffleDep.shuffleId)
+      .getOrElse(0 until numPartitions)
   }
 }
```
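For reference, a sketch of the two tracker-side lookups the stage now delegates to. Only the method names and the `Option` return type of `findMissingPartitions` (implied by the `.getOrElse` fallback above) appear in the diff; the `shuffleStatuses` map and `ShuffleStatus` class follow the earlier hypothetical sketch:

```scala
import scala.collection.concurrent.TrieMap

// Illustrative only, not Spark's exact code.
private class TrackerDelegationSketch {
  private val shuffleStatuses = TrieMap.empty[Int, ShuffleStatus]

  // 0 for an unregistered shuffle, so isAvailable stays false until outputs exist.
  def getNumAvailableOutputs(shuffleId: Int): Int =
    shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0)

  // None for an unregistered shuffle; ShuffleMapStage then falls back to
  // treating every partition as missing via `.getOrElse(0 until numPartitions)`.
  def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] =
    shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
}
```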

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -129,7 +129,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
 
   var backend: SchedulerBackend = null
 
-  val mapOutputTracker = SparkEnv.get.mapOutputTracker
+  val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
 
   private var schedulableBuilder: SchedulableBuilder = null
   // default scheduler is FIFO
```

core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -139,21 +139,21 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
 
     masterTracker.registerShuffle(10, 1)
-    masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
+    // This is expected to fail because no outputs have been registered for the shuffle.
     intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
 
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
     masterTracker.registerMapOutput(10, 0, MapStatus(
       BlockManagerId("a", "hostA", 1000), Array(1000L)))
-    masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
       Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
     assert(0 == masterTracker.getNumCachedSerializedBroadcast)
 
+    val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch
     masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
-    masterTracker.incrementEpoch()
+    assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput)
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
```

core/src/test/scala/org/apache/spark/ShuffleSuite.scala

Lines changed: 2 additions & 1 deletion
```diff
@@ -359,6 +359,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
     val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep)
+    mapTrackerMaster.registerShuffle(0, 1)
 
     // first attempt -- its successful
     val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
@@ -393,7 +394,7 @@
 
     // register one of the map outputs -- doesn't matter which one
     mapOutput1.foreach { case mapStatus =>
-      mapTrackerMaster.registerMapOutputs(0, Array(mapStatus))
+      mapTrackerMaster.registerMapOutput(0, 0, mapStatus)
     }
 
     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
```

core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala

Lines changed: 2 additions & 1 deletion
```diff
@@ -86,7 +86,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
     sc = new SparkContext(conf)
     val scheduler = mock[TaskSchedulerImpl]
     when(scheduler.sc).thenReturn(sc)
-    when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker)
+    when(scheduler.mapOutputTracker).thenReturn(
+      SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])
     scheduler
   }
```
