Closed
Changes from 16 commits
38 commits
a9bf31f
wip
squito Jun 9, 2015
28d70aa
wip on getting a better test case ...
squito Jun 9, 2015
c443def
better fix and simpler test case
squito Jun 10, 2015
06a0af6
ignore for jenkins
squito Jun 10, 2015
6e14683
unit test just to make sure we fail fast on concurrent attempts
squito Jun 10, 2015
883fe49
Unit tests for concurrent stages issue
kayousterhout May 26, 2015
7021d28
update test since listenerBus.waitUntilEmpty now throws an exception …
squito Jun 10, 2015
55f4a94
get rid of more random test case since kays tests are clearer
squito Jun 10, 2015
b6bc248
style
squito Jun 10, 2015
ecb4e7d
debugging printlns
squito Jun 11, 2015
9601b47
more debug printlns
squito Jun 11, 2015
89a59b6
more printlns ...
squito Jun 11, 2015
8c29707
Merge branch 'master' into SPARK-8103
squito Jun 11, 2015
cb245da
finally found the issue ... clean up debug stuff
squito Jun 11, 2015
46bc26a
more cleanup of debug garbage
squito Jun 11, 2015
d8eb202
Merge branch 'master' into SPARK-8103
squito Jun 29, 2015
ada7726
reviewer feedback
squito Jul 1, 2015
6542b42
remove extra stageAttemptId
squito Jul 2, 2015
b2faef5
faster check for conflicting task sets
squito Jul 7, 2015
517b6e5
get rid of SparkIllegalStateException
squito Jul 7, 2015
227b40d
style
squito Jul 7, 2015
a5f7c8c
remove comment for reviewers
squito Jul 7, 2015
19685bb
switch to using latestInfo.attemptId, and add comments
squito Jul 7, 2015
baf46e1
Index active task sets by stage Id rather than by task set id
kayousterhout Jul 8, 2015
f025154
Merge pull request #2 from kayousterhout/imran_SPARK-8103
squito Jul 14, 2015
c0d4d90
Revert "Index active task sets by stage Id rather than by task set id"
squito Jul 14, 2015
109900e
Merge branch 'master' into SPARK-8103
squito Jul 14, 2015
906d626
fix merge
squito Jul 14, 2015
a21c8b5
Merge branch 'master' into SPARK-8103
squito Jul 14, 2015
d7f1ef2
get rid of activeTaskSets
squito Jul 15, 2015
88b61cc
add tests to make sure that TaskSchedulerImpl schedules correctly wit…
squito Jul 15, 2015
c04707e
style
squito Jul 15, 2015
4470fa1
rename
squito Jul 15, 2015
6bc23af
update log msg
squito Jul 17, 2015
e43ac25
Merge branch 'master' into SPARK-8103
squito Jul 17, 2015
584acd4
simplify going from taskId to taskSetMgr
squito Jul 17, 2015
e01b7aa
fix some comments, style
squito Jul 17, 2015
fb3acfc
fix log msg
squito Jul 17, 2015
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkException.scala
@@ -30,3 +30,13 @@ class SparkException(message: String, cause: Throwable)
*/
private[spark] class SparkDriverExecutionException(cause: Throwable)
extends SparkException("Execution error", cause)

/**
* Exception indicating an error internal to Spark -- it is in an inconsistent state, not due
* to any error by the user
*/
class SparkIllegalStateException(message: String, cause: Throwable)
Contributor

Why not use the existing IllegalStateException (the non-Spark-specific one)? That's what's done elsewhere in Spark.

Contributor Author

I was hoping to create an exception which would make it clear that the user is totally free from blame. I find IllegalStateExceptions can be confusing -- did this condition result because the user violated some precondition, misused the API, etc.? Or is it a Spark bug?

I felt this would be clearer, but it's not critical.

Contributor

I'd prefer to stick with the existing strategy, which is to use IllegalStateExceptions and introduce this new exception in a separate PR (my understanding is that we do currently use IllegalStateExceptions only for things that aren't the user's fault; if not, agree with your sentiment that it would be nice to clean that up, but not as part of this PR).

extends SparkException(message, cause) {

def this(message: String) = this(message, null)
}
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -30,6 +30,7 @@ private[spark] class TaskContextImpl(
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
val runningLocally: Boolean = false,
val stageAttemptId: Int = 0, // for testing
Contributor

Where is this used?

Contributor Author

Oops, good catch. This was used in earlier versions of the tests before I switched to your test. Thanks, I've updated it now.

val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
with Logging {
77 changes: 43 additions & 34 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -852,7 +852,6 @@ class DAGScheduler(
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()


// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = {
stage match {
@@ -913,7 +912,7 @@ class DAGScheduler(
partitionsToCompute.map { id =>
val locs = getPreferredLocs(stage.rdd, id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, taskBinary, part, locs)
new ShuffleMapTask(stage.id, stage.attemptId, taskBinary, part, locs)
}

case stage: ResultStage =>
@@ -922,7 +921,7 @@ class DAGScheduler(
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = getPreferredLocs(stage.rdd, p)
new ResultTask(stage.id, taskBinary, part, locs, id)
new ResultTask(stage.id, stage.attemptId, taskBinary, part, locs, id)
}
}
} catch {
@@ -1002,6 +1001,7 @@ class DAGScheduler(
val stageId = task.stageId
val taskType = Utils.getFormattedClassName(task)

// REVIEWERS: does this need special handling for multiple completions of the same task?
Contributor

Doesn't look like it to me. The OutputCommitCoordinator should deny any completion from other than the currently authorized Stage, and should only authorize one Task per partition to commit per Stage.
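
The "one authorized committer per partition per stage" rule described above can be pictured with a small first-asker-wins arbiter. This is only a sketch under stated assumptions: `CommitArbiterSketch`, `canCommit`, and `attemptFailed` are hypothetical names made up for illustration, and the real OutputCommitCoordinator has more state and different method signatures.

```scala
import scala.collection.mutable

// Hypothetical, simplified model of commit arbitration; not Spark's OutputCommitCoordinator.
final class CommitArbiterSketch {
  // (stageId, partitionId) -> task attempt number that currently holds the authorization
  private val authorized = mutable.HashMap.empty[(Int, Int), Int]

  /** The first attempt to ask for a partition wins; every other attempt is denied. */
  def canCommit(stageId: Int, partitionId: Int, attempt: Int): Boolean = synchronized {
    authorized.get((stageId, partitionId)) match {
      case Some(winner) => winner == attempt
      case None =>
        authorized((stageId, partitionId)) = attempt
        true
    }
  }

  /** If the authorized attempt fails, release the slot so a later attempt can commit. */
  def attemptFailed(stageId: Int, partitionId: Int, attempt: Int): Unit = synchronized {
    if (authorized.get((stageId, partitionId)).contains(attempt)) {
      authorized.remove((stageId, partitionId))
    }
  }
}

object CommitArbiterDemo {
  def main(args: Array[String]): Unit = {
    val arbiter = new CommitArbiterSketch
    println(arbiter.canCommit(stageId = 1, partitionId = 0, attempt = 0)) // first asker
    println(arbiter.canCommit(stageId = 1, partitionId = 0, attempt = 1)) // duplicate completion
  }
}
```

In this model a duplicate completion of the same partition is simply denied, which is why the scheduler side would not need extra handling for it.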

Contributor Author

Great, thanks Mark. I'll leave the comment in for now just to get any other reviewers to take a look too. Do you think this deserves some special handling in the test cases? (And if so, any suggestions / pointers on what to do?) I'm not very familiar with the OutputCommitCoordinator, so guidance is appreciated.

outputCommitCoordinator.taskCompleted(stageId, task.partitionId,
event.taskInfo.attempt, event.reason)

@@ -1064,10 +1064,11 @@ class DAGScheduler(
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
} else {
shuffleStage.addOutputLoc(smt.partitionId, status)
}

if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
markStageAsFinished(shuffleStage)
logInfo("looking for newly runnable stages")
@@ -1126,39 +1127,47 @@ class DAGScheduler(
case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)
if (failedStage.attemptId - 1 > task.stageAttemptId) {
Contributor

What's the intuition here (I was expecting "if(failedStage.attemptId != task.stageAttemptId)")?

Contributor Author

Stage attempt ids get incremented when a task set is submitted (see here). So it's normal for the attempt id of tasks to be one behind stage.attemptId. This magically still lets all Spark listener events have the right id, b/c we store stage.latestInfo before the attemptId gets incremented.

Honestly that logic is pretty weird to me -- I could change that instead, but I thought it was maybe better not to mess with that now. I suppose the test could be failedStage.attemptId - 1 != task.stageAttemptId. Maybe it would actually make the most sense to change to failedStage.latestInfo.attemptId != task.stageAttemptId, with a comment explaining why failedStage.attemptId can't be used (and maybe add a comment to Stage.attemptId itself)?
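
The off-by-one described above can be reproduced with a toy model. This is a minimal sketch, assuming the counter is bumped right after the info snapshot is taken; `StageSketch`, `StageInfoSketch`, `submitNewAttempt`, and `isStale` are hypothetical names and not Spark's classes.

```scala
// Hypothetical, simplified model of the attempt-id bookkeeping; not Spark's Stage class.
final case class StageInfoSketch(stageId: Int, attemptId: Int)

final class StageSketch(val id: Int) {
  private var nextAttemptId = 0
  private var info = StageInfoSketch(id, nextAttemptId)

  /** The raw counter: after a submission it already points at the next attempt. */
  def attemptId: Int = nextAttemptId

  /** Snapshot taken at submission time, so it carries the id the tasks were given. */
  def latestInfo: StageInfoSketch = info

  /** Snapshot the info first, then bump the counter; returns the id the tasks carry. */
  def submitNewAttempt(): Int = {
    info = StageInfoSketch(id, nextAttemptId)
    nextAttemptId += 1
    info.attemptId
  }
}

object AttemptIdSketch {
  /** A failure is stale when it was reported by a task from an older attempt. */
  def isStale(stage: StageSketch, taskStageAttemptId: Int): Boolean =
    stage.latestInfo.attemptId != taskStageAttemptId

  def main(args: Array[String]): Unit = {
    val stage = new StageSketch(1)
    val first = stage.submitNewAttempt()
    println(s"task attempt = $first, stage.attemptId = ${stage.attemptId}")
    println(s"stale before retry: ${isStale(stage, first)}")
    stage.submitNewAttempt() // the stage is retried
    println(s"stale after retry: ${isStale(stage, first)}")
  }
}
```

Comparing against the snapshot avoids the `- 1` adjustment, since the snapshot and the tasks are stamped with the same value.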

logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
s" ${task.stageAttemptId}, which has already failed")
Contributor

Can you change this to s" ${task.stageAttemptId}, and there is a more recent attempt for that stage (attempt ID ${failedStage.latestInfo.attemptId}) running"? It's a little misleading as-is, because we actually do still do something with the fetch failure if the stage has already been marked as failed, as long as there's not already a newer version of the stage running.

Contributor Author

yes, good idea. done

} else {

// It is likely that we receive multiple FetchFailed for a single stage (because we have
// multiple tasks running concurrently on different executors). In that case, it is possible
// the fetch failure has already been handled by the scheduler.
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
markStageAsFinished(failedStage, Some(failureMessage))
}
// It is likely that we receive multiple FetchFailed for a single stage (because we have
// multiple tasks running concurrently on different executors). In that case, it is
// possible the fetch failure has already been handled by the scheduler.
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
markStageAsFinished(failedStage, Some(failureMessage))
} else {
logInfo(s"Ignoring fetch failure from $task as it's from $failedStage, " +
Contributor

It seems like, even when the failed stage is no longer running, we don't fully ignore the failure, because we remove the map id from the set of output locations (on line 1164)? As a result, I'm not sure this log message makes sense.

Contributor Author

Yes, good point ... perhaps we should just get rid of this msg; I'm not sure there is anything we could say here that is succinct, accurate, and still useful. I think I added this msg during some early debugging. Do you think it's worth logging "Received fetch failure from $task, but it's from $failedStage which is no longer running"? Maybe a logDebug?

Contributor

Doing a logDebug seems fine if you think it will be useful! (I never turn on debug level logging just because Spark's debug logging is sooo verbose, but I can imagine that others probably do turn it on!)

Contributor

It looks like this log statement is still lingering?

Contributor Author

Whoops, sorry I forgot about this one! Thanks, changed to the logDebug.

s"which is no longer running")
}

if (disallowStageRetryForTest) {
abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {
mapStage.removeOutputLoc(mapId, bmAddress)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
}
if (disallowStageRetryForTest) {
abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {
mapStage.removeOutputLoc(mapId, bmAddress)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
}

// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
}

case commitDenied: TaskCommitDenied =>
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
*/
private[spark] class ResultTask[T, U](
stageId: Int,
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient locs: Seq[TaskLocation],
val outputId: Int)
extends Task[U](stageId, partition.index) with Serializable {
extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {

@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
*/
private[spark] class ShuffleMapTask(
stageId: Int,
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation])
extends Task[MapStatus](stageId, partition.index) with Logging {
extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {

/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
this(0, null, new Partition { override def index: Int = 0 }, null)
this(0, 0, null, new Partition { override def index: Int = 0 }, null)
}

@transient private val preferredLocs: Seq[TaskLocation] = {
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
* @param stageId id of the stage this task belongs to
* @param partitionId index of the number in the RDD
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
var partitionId: Int) extends Serializable {

/**
* Called by [[Executor]] to run this task.
@@ -55,6 +58,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
context = new TaskContextImpl(
stageId = stageId,
stageAttemptId = stageAttemptId,
partitionId = partitionId,
taskAttemptId = taskAttemptId,
attemptNumber = attemptNumber,
@@ -163,6 +163,12 @@ private[spark] class TaskSchedulerImpl(
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
val taskSetsPerStage = activeTaskSets.values.filterNot(_.isZombie).groupBy(_.stageId)
taskSetsPerStage.foreach { case (stage, taskSets) =>
if (taskSets.size > 1) {
throw new SparkIllegalStateException("more than one active taskSet for stage " + stage)
}
}
Contributor

@kayousterhout How much of a concern should the extra overhead be here? Just wondering whether this (let's hope rare) condition might better be handled only in a non-production environment and behind an if(debug) kind of flag.

Contributor

Perhaps the code could just look for an existing task set that matches the stage ID of the task set being added? That should be a little better than the filter / groupBy. Something like:

activeTaskSets.exists { case (_, ts) => ts.stageId == taskSet.stageId }

Contributor Author

restoring the comments from the old diff b/c they are still relevant:

from mark:

@kayousterhout How much of a concern should the extra overhead be here? Just wondering whether this (let's hope rare) condition might better be handled only in a non-production environment and behind an if(debug) kind of flag.

from marcelo:

Perhaps the code could just look for an existing task set that matches the stage ID of the task set being added? That should be a little better than the filter / groupBy.

good point, there isn't any need to do the groupBy, so I've made it simpler.

I'd really rather leave the check in place. In fact I think this fail-fast behavior is especially important in a production environment -- that's much better than an infinite loop of failures.
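
A simplified version of the fail-fast check, without the groupBy, might look like the following. It is a sketch only: `ManagerSketch` and `register` are hypothetical stand-ins for the task set manager and the submission path, and the check runs before insertion here rather than after it as in the diff.

```scala
import scala.collection.mutable

// Hypothetical stand-in for a task set manager: only the fields the check needs.
final case class ManagerSketch(taskSetId: String, stageId: Int, isZombie: Boolean = false)

object ConflictCheckSketch {
  /** Registers a manager, failing fast if the stage already has a live (non-zombie) task set. */
  def register(active: mutable.Map[String, ManagerSketch], m: ManagerSketch): Unit = {
    // Single pass over the active managers: O(N) in the number of active task sets.
    val conflict = active.values.exists(ts => ts.stageId == m.stageId && !ts.isZombie)
    if (conflict) {
      throw new IllegalStateException(s"more than one active taskSet for stage ${m.stageId}")
    }
    active(m.taskSetId) = m
  }

  def main(args: Array[String]): Unit = {
    val active = mutable.Map.empty[String, ManagerSketch]
    register(active, ManagerSketch("0.0", stageId = 0, isZombie = true)) // finished attempt
    register(active, ManagerSketch("0.1", stageId = 0))                  // allowed: only a zombie exists
    try register(active, ManagerSketch("0.2", stageId = 0))
    catch { case e: IllegalStateException => println(s"rejected: ${e.getMessage}") }
  }
}
```

Zombie task sets are skipped, so a retry can be registered while the previous attempt is still winding down.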

Contributor

Yes, I totally agree that it is valuable to catch the failure in production. I'm only suggesting that at some point the check becomes a big enough performance hit that it makes sense to compromise on the fail-fast desiderata in order to maintain production performance while trying to ensure in development that the failure can never occur. I doubt that this check is that costly, but my expectation is that Kay has a better sense of how much more we can afford to do within this synchronized block.

Contributor

What if you change activeTaskSets to be keyed on the stageId? That seems pretty easy to do and would make this check O(1) rather than O(N)
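
One way to picture the stage-keyed index is a two-level map, so the conflict check touches only the attempts of one stage. This is a sketch under assumptions: `StageIndexedTaskSets`, `StageManagerSketch`, and `liveAttempt` are hypothetical names, and the nested-map layout is an illustration rather than the structure the PR settled on.

```scala
import scala.collection.mutable

// Hypothetical stand-in for a task set manager; isZombie is mutable so an attempt can be retired.
final case class StageManagerSketch(stageId: Int, stageAttemptId: Int, var isZombie: Boolean = false)

final class StageIndexedTaskSets {
  // stageId -> (stageAttemptId -> manager)
  private val byStage =
    mutable.HashMap.empty[Int, mutable.HashMap[Int, StageManagerSketch]]

  /** Hash lookup by stage, then a scan over that stage's attempts only. */
  def register(m: StageManagerSketch): Unit = {
    val attempts =
      byStage.getOrElseUpdate(m.stageId, mutable.HashMap.empty[Int, StageManagerSketch])
    if (attempts.values.exists(!_.isZombie)) {
      throw new IllegalStateException(s"more than one active taskSet for stage ${m.stageId}")
    }
    attempts(m.stageAttemptId) = m
  }

  /** The attempt of the given stage that is still live, if any. */
  def liveAttempt(stageId: Int): Option[StageManagerSketch] =
    byStage.get(stageId).flatMap(_.values.find(!_.isZombie))
}

object StageIndexDemo {
  def main(args: Array[String]): Unit = {
    val index = new StageIndexedTaskSets
    val first = StageManagerSketch(stageId = 3, stageAttemptId = 0)
    index.register(first)
    first.isZombie = true // the first attempt is abandoned
    index.register(StageManagerSketch(stageId = 3, stageAttemptId = 1))
    println(index.liveAttempt(3).map(_.stageAttemptId))
  }
}
```

The cost of the check no longer grows with the total number of active task sets, only with the number of attempts recorded for the stage being submitted.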

schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

if (!isLocal && !hasReceivedTask) {
2 changes: 1 addition & 1 deletion core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1011,7 +1011,7 @@ public void persist() {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, 0, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}

137 changes: 137 additions & 0 deletions core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -101,9 +101,15 @@ class DAGSchedulerSuite
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
val sparkListener = new SparkListener() {
val submittedStageInfos = new HashSet[StageInfo]
val successfulStages = new HashSet[Int]
val failedStages = new ArrayBuffer[Int]
val stageByOrderOfExecution = new ArrayBuffer[Int]

override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
submittedStageInfos += stageSubmitted.stageInfo
}

override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
val stageInfo = stageCompleted.stageInfo
stageByOrderOfExecution += stageInfo.stageId
@@ -150,6 +156,7 @@ class DAGSchedulerSuite
// Enable local execution for this test
val conf = new SparkConf().set("spark.localExecution.enabled", "true")
sc = new SparkContext("local", "DAGSchedulerSuite", conf)
sparkListener.submittedStageInfos.clear()
sparkListener.successfulStages.clear()
sparkListener.failedStages.clear()
failure = null
@@ -547,6 +554,136 @@ class DAGSchedulerSuite
assert(sparkListener.failedStages.size == 1)
}

/** This tests the case where another FetchFailed comes in while the map stage is getting
Contributor

nit: we generally use javadoc-style (even though this is not a method declaration):

/**
 * Comment.
 */

* re-run. */
test("late fetch failures don't cause multiple concurrent attempts for the same map stage") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
val shuffleId = shuffleDep.shuffleId
val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
submit(reduceRdd, Array(0, 1))

val mapStageId = 0
def countSubmittedMapStageAttempts(): Int = {
sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
}

// The map stage should have been submitted.
assert(countSubmittedMapStageAttempts() === 1)

complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))))
// The MapOutputTracker should know about both map output locations.
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
Array("hostA", "hostB"))
assert(mapOutputTracker.getServerStatuses(shuffleId, 1).map(_._1.host) ===
Array("hostA", "hostB"))

// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(CompletionEvent(
taskSets(1).tasks(0),
FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
null,
Map[Long, Any](),
createFakeTaskInfo(),
null))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(sparkListener.failedStages.contains(1))

// Trigger resubmission of the failed map stage.
runEvent(ResubmitFailedStages)
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)

// Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
assert(countSubmittedMapStageAttempts() === 2)

// The second ResultTask fails, with a fetch failure for the output from the second mapper.
runEvent(CompletionEvent(
taskSets(1).tasks(1),
FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
null,
Map[Long, Any](),
createFakeTaskInfo(),
null))

// Another ResubmitFailedStages event should not result result in another attempt for the map
Contributor

nit: result result

// stage being run concurrently.
runEvent(ResubmitFailedStages)
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 2)

// NOTE: the actual ResubmitFailedStages may get called at any time during this, shouldn't
Contributor

Should this comment be right before L612?

// affect anything -- our calling it just makes *SURE* it gets called between the desired event
// and our check.
}

/** This tests the case where a late FetchFailed comes in after the map stage has finished getting
Contributor

change to javadoc style

* retried and a new reduce stage starts running.
*/
test("extremely late fetch failures don't cause multiple concurrent attempts for " +
"the same stage") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
val shuffleId = shuffleDep.shuffleId
val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
submit(reduceRdd, Array(0, 1))

def countSubmittedReduceStageAttempts(): Int = {
sparkListener.submittedStageInfos.count(_.stageId == 1)
}
def countSubmittedMapStageAttempts(): Int = {
sparkListener.submittedStageInfos.count(_.stageId == 0)
}

// The map stage should have been submitted.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 1)

// Complete the map stage.
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))))

// The reduce stage should have been submitted.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedReduceStageAttempts() === 1)

// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(CompletionEvent(
taskSets(1).tasks(0),
FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
null,
Map[Long, Any](),
createFakeTaskInfo(),
null))

// Trigger resubmission of the failed map stage and finish the re-started map task.
runEvent(ResubmitFailedStages)
complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))

// Because the map stage finished, another attempt for the reduce stage should have been
// submitted, resulting in 2 total attempts for each the map and the reduce stage.
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(countSubmittedMapStageAttempts() === 2)
assert(countSubmittedReduceStageAttempts() === 2)

// A late FetchFailed arrives from the second task in the original reduce stage.
runEvent(CompletionEvent(
taskSets(1).tasks(1),
FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
null,
Map[Long, Any](),
createFakeTaskInfo(),
null))

// Trigger resubmission of the failed map stage and finish the re-started map task.
Contributor

I think this comment is wrong, and should say something like "Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because the FetchFailed should have been ignored"

Contributor Author

GitHub is being weird ... I think this comment was originally on line 683 below (which I've updated), but the comment was originally duplicated. Let me know if I've messed up.

Contributor

Oh weird... yeah your understanding is correct!

runEvent(ResubmitFailedStages)

// The FetchFailed from the original reduce stage should be ignored.
assert(countSubmittedMapStageAttempts() === 2)
}

test("ignore late map task completions") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler

import org.apache.spark.TaskContext

class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
override def runTask(context: TaskContext): Int = 0

override def preferredLocations: Seq[TaskLocation] = prefLocs
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
extends Task[Array[Byte]](stageId, 0) {
extends Task[Array[Byte]](stageId, 0, 0) {

override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()