
Commit 1a65122

ivoson authored and cloud-fan committed
[SPARK-54556][CORE] Rollback succeeding shuffle map stages when shuffle checksum mismatch detected
### What changes were proposed in this pull request?

Roll back shuffle map stages when a shuffle checksum mismatch is detected:
- cancel and resubmit the stage if it's running;
- clean up the shuffle status to ensure it'll be resubmitted;
- mark the rollback attemptId and ignore results from elder attempts, which may have consumed inconsistent data.

### Why are the changes needed?

To ensure that all succeeding stages are resubmitted and fully retried when a shuffle checksum mismatch is detected.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

UT added.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #53274 from ivoson/SPARK-54556.

Authored-by: Tengfei Huang <tengfei.h@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 0da9e05)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
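In outline, the mechanism works as sketched below. This is a minimal, self-contained model of the bookkeeping described above; `StageState` is a hypothetical stand-in for Spark's `Stage`, and the real change threads this logic through `DAGScheduler` and `Stage` in the diff that follows.

```scala
// Minimal sketch of the rollback bookkeeping (hypothetical StageState type;
// only the rollback marker and the ignore rule are modeled here).
final case class StageState(latestAttempt: Int, maxAttemptIdToIgnore: Option[Int] = None)

// Rolling back marks the current attempt so that results from it, and from
// any elder attempt, are ignored when they straggle in.
def markAsRollingBack(s: StageState): StageState =
  s.copy(maxAttemptIdToIgnore = Some(s.latestAttempt))

// A task result is accepted only if it comes from an attempt newer than the
// marked rollback attempt.
def acceptResult(s: StageState, taskAttemptId: Int): Boolean =
  s.maxAttemptIdToIgnore.forall(_ < taskAttemptId)
```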
1 parent 41127db commit 1a65122

3 files changed: +313 −54 lines

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 170 additions & 51 deletions
@@ -1560,42 +1560,27 @@ private[spark] class DAGScheduler(
     // `findMissingPartitions()` returns all partitions every time.
     stage match {
       case sms: ShuffleMapStage if !sms.isAvailable =>
-        val needFullStageRetry = if (sms.shuffleDep.checksumMismatchFullRetryEnabled) {
-          // When the parents of this stage are indeterminate (e.g., some parents are not
-          // checkpointed and checksum mismatches are detected), the output data of the parents
-          // may have changed due to task retries. For correctness reasons, we need to
-          // retry all tasks of the current stage. The legacy way of using the current stage's
-          // deterministic level to trigger a full stage retry is not accurate.
-          stage.isParentIndeterminate
-        } else {
-          if (stage.isIndeterminate) {
-            // already executed at least once
-            if (sms.getNextAttemptId > 0) {
-              // While we previously validated possible rollbacks during the handling of a
-              // FetchFailure, where we were fetching from an indeterminate source map stage,
-              // this later check covers additional cases like recalculating an indeterminate
-              // stage after an executor loss. Moreover, because this check occurs later in the
-              // process, if a result stage task has successfully completed, we can detect this
-              // and abort the job, as rolling back a result stage is not possible.
-              val stagesToRollback = collectSucceedingStages(sms)
-              abortStageWithInvalidRollBack(stagesToRollback)
-              // stages which cannot be rolled back were aborted, which leads to removing
-              // the dependent job(s) from the active jobs set
-              val numActiveJobsWithStageAfterRollback =
-                activeJobs.count(job => stagesToRollback.contains(job.finalStage))
-              if (numActiveJobsWithStageAfterRollback == 0) {
-                logInfo(log"All jobs depending on the indeterminate stage " +
-                  log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
-                return
-              }
+        if (!sms.shuffleDep.checksumMismatchFullRetryEnabled && stage.isIndeterminate) {
+          // already executed at least once
+          if (sms.getNextAttemptId > 0) {
+            // While we previously validated possible rollbacks during the handling of a
+            // FetchFailure, where we were fetching from an indeterminate source map stage,
+            // this later check covers additional cases like recalculating an indeterminate
+            // stage after an executor loss. Moreover, because this check occurs later in the
+            // process, if a result stage task has successfully completed, we can detect this
+            // and abort the job, as rolling back a result stage is not possible.
+            val stagesToRollback = collectSucceedingStages(sms)
+            filterAndAbortUnrollbackableStages(stagesToRollback)
+            // stages which cannot be rolled back were aborted, which leads to removing
+            // the dependent job(s) from the active jobs set
+            val numActiveJobsWithStageAfterRollback =
+              activeJobs.count(job => stagesToRollback.contains(job.finalStage))
+            if (numActiveJobsWithStageAfterRollback == 0) {
+              logInfo(log"All jobs depending on the indeterminate stage " +
+                log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
+              return
             }
-            true
-          } else {
-            false
           }
-        }
-
-        if (needFullStageRetry) {
           mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
           sms.shuffleDep.newShuffleMergeState()
         }
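The retained context lines make the cleanup work mechanically: once `unregisterAllMapAndMergeOutput` wipes the registered outputs, every partition of the stage counts as missing, which is what the leading comment about `findMissingPartitions()` relies on. A minimal standalone sketch of that relationship (hypothetical signature, not Spark's actual API):

```scala
// A partition is "missing" exactly when no map output is registered for it,
// so wiping the registry forces a full re-run of the stage.
def findMissingPartitions(numPartitions: Int, registered: Map[Int, String]): Seq[Int] =
  (0 until numPartitions).filterNot(registered.contains)

// After unregistering everything, all partitions are recomputed.
assert(findMissingPartitions(4, Map.empty) == Seq(0, 1, 2, 3))
```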
@@ -1913,16 +1898,127 @@ private[spark] class DAGScheduler(
 
   /**
    * If a map stage is non-deterministic, the map tasks of the stage may return a different result
-   * when re-tried. To make sure data correctness, we need to re-try all the tasks of its
-   * succeeding stages, as the input data may be changed after the map tasks are re-tried. For
-   * stages where rolling back and retrying all tasks is not possible, we will need to abort the stages.
+   * when re-tried. To make sure data correctness, we need to clean up shuffles to make sure
+   * succeeding stages will be resubmitted and re-try all the tasks, as the input data may be
+   * changed after the map tasks are re-tried. For stages where rolling back and retrying all
+   * tasks is not possible, we will need to abort the stages.
+   */
+  private[scheduler] def rollbackSucceedingStages(mapStage: ShuffleMapStage): Unit = {
+    val stagesToRollback = collectSucceedingStages(mapStage).filterNot(_ == mapStage)
+    val stagesCanRollback = filterAndAbortUnrollbackableStages(stagesToRollback)
+    // Stages which cannot be rolled back were aborted, which leads to removing
+    // the dependent job(s) from the active jobs set; there could be no active jobs
+    // left depending on the indeterminate stage and hence no need to roll back any stages.
+    val numActiveJobsWithStageAfterRollback =
+      activeJobs.count(job => stagesToRollback.contains(job.finalStage))
+    if (numActiveJobsWithStageAfterRollback == 0) {
+      logInfo(log"All jobs depending on the indeterminate stage " +
+        log"(${MDC(STAGE_ID, mapStage.id)}) were aborted.")
+    } else {
+      // Mark the rollback attempt to identify elder attempts which could have consumed
+      // inconsistent data; the results from these attempts should be ignored.
+      // Roll back the running stages first to avoid triggering more fetch failures.
+      stagesToRollback.toSeq.sortBy(!runningStages.contains(_)).foreach {
+        case sms: ShuffleMapStage =>
+          rollbackShuffleMapStage(sms, "rolling back due to indeterminate " +
+            s"output of shuffle map stage $mapStage")
+          sms.markAsRollingBack()
+
+        case rs: ResultStage =>
+          rs.markAsRollingBack()
+      }
+
+      logInfo(log"The shuffle map stage ${MDC(STAGE, mapStage)} with indeterminate output " +
+        log"was retried, we will roll back and rerun its succeeding " +
+        log"stages: ${MDC(STAGES, stagesCanRollback)}")
+    }
+  }
+
+  /**
+   * Roll back the given shuffle map stage:
+   * 1. If the stage is running, cancel the stage and kill all running tasks. Clean up the
+   *    shuffle output and resubmit the stage if it has not exceeded the max retries.
+   * 2. If the stage is not running but has generated output, clean up the shuffle output to
+   *    ensure the stage will be re-executed with a full retry.
+   *
+   * @param sms the shuffle map stage to roll back
+   * @param reason the reason for rolling back
+   */
+  private def rollbackShuffleMapStage(sms: ShuffleMapStage, reason: String): Unit = {
+    logInfo(log"Rolling back ${MDC(STAGE, sms)} due to indeterminate rollback")
+    val clearShuffle = if (runningStages.contains(sms)) {
+      logInfo(log"Stage ${MDC(STAGE, sms)} is running, marking it as failed and " +
+        log"resubmitting if allowed")
+      cancelStageAndTryResubmit(sms, reason)
+    } else {
+      true
+    }
+
+    // Clean up shuffle outputs in case the stage is not aborted, to ensure the stage
+    // will be re-executed.
+    if (clearShuffle) {
+      logInfo(log"Cleaning up shuffle for stage ${MDC(STAGE, sms)} to ensure re-execution")
+      mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
+      sms.shuffleDep.newShuffleMergeState()
+    }
+  }
+
+  /**
+   * Cancel the given running shuffle map stage, killing all running tasks, and resubmit it if
+   * it doesn't exceed the max retries.
+   *
+   * @param stage the stage to cancel and resubmit
+   * @param reason the reason for the operation
+   * @return true if the stage is successfully cancelled and resubmitted, otherwise false
    */
-  private[scheduler] def abortUnrollbackableStages(mapStage: ShuffleMapStage): Unit = {
-    val stagesToRollback = collectSucceedingStages(mapStage)
-    val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
-    logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output " +
-      log"was failed, we will roll back and rerun below stages which include itself and all its " +
-      log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
+  private def cancelStageAndTryResubmit(stage: ShuffleMapStage, reason: String): Boolean = {
+    assert(runningStages.contains(stage), "stage must be running to be cancelled and resubmitted")
+    try {
+      // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask.
+      val job = jobIdToActiveJob.get(stage.firstJobId)
+      val shouldInterrupt = job.exists(j => shouldInterruptTaskThread(j))
+      taskScheduler.killAllTaskAttempts(stage.id, shouldInterrupt, reason)
+    } catch {
+      case e: UnsupportedOperationException =>
+        logWarning(log"Could not kill all tasks for stage ${MDC(STAGE_ID, stage.id)}", e)
+        abortStage(stage, "Rollback failed due to: Not able to kill running tasks for stage " +
+          s"$stage (${stage.name})", Some(e))
+        return false
+    }
+
+    stage.failedAttemptIds.add(stage.latestInfo.attemptNumber())
+    val shouldAbortStage = stage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
+      disallowStageRetryForTest
+    markStageAsFinished(stage, Some(reason), willRetry = !shouldAbortStage)
+
+    if (shouldAbortStage) {
+      val abortMessage = if (disallowStageRetryForTest) {
+        "Stage will not retry due to testing config. Most recent failure " +
+          s"reason: $reason"
+      } else {
+        s"$stage (${stage.name}) has failed the maximum allowable number of " +
+          s"times: $maxConsecutiveStageAttempts. Most recent failure reason: $reason"
+      }
+      abortStage(stage, s"rollback failed due to: $abortMessage", None)
+    } else {
+      // In case multiple task failures are triggered for a single stage attempt, ensure we
+      // only resubmit the failed stage once.
+      val noResubmitEnqueued = !failedStages.contains(stage)
+      failedStages += stage
+      if (noResubmitEnqueued) {
+        logInfo(log"Resubmitting ${MDC(FAILED_STAGE, stage)} " +
+          log"(${MDC(FAILED_STAGE_NAME, stage.name)}) due to rollback.")
+        messageScheduler.schedule(
+          new Runnable {
+            override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+          },
+          DAGScheduler.RESUBMIT_TIMEOUT,
+          TimeUnit.MILLISECONDS
+        )
+      }
+    }
+
+    !shouldAbortStage
   }
 
   /**
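One detail worth calling out in `rollbackSucceedingStages`: `sortBy(!runningStages.contains(_))` rolls back the running stages first. Scala orders `false < true`, so stages mapped to `false` (the running ones) sort to the front, and `sortBy` is stable within each group. A self-contained illustration with hypothetical stage names:

```scala
val runningStages = Set("stage_2", "stage_4")
val stagesToRollback = Seq("stage_1", "stage_2", "stage_3", "stage_4")
// Running stages map to false, and false < true, so they come first; sortBy
// is stable, so relative order is preserved within each group.
val ordered = stagesToRollback.sortBy(s => !runningStages.contains(s))
assert(ordered == Seq("stage_2", "stage_4", "stage_1", "stage_3"))
```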
@@ -1990,7 +2086,21 @@ private[spark] class DAGScheduler(
       // tasks complete, they still count and we can mark the corresponding partitions as
       // finished if the stage is determinate. Here we notify the task scheduler to skip running
       // tasks for the same partition to save resources.
-      if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()) {
+      def stageWithChecksumMismatchFullRetryEnabled(stage: Stage): Boolean = {
+        stage match {
+          case s: ShuffleMapStage => s.shuffleDep.checksumMismatchFullRetryEnabled
+          case _ => stage.parents.exists(stageWithChecksumMismatchFullRetryEnabled)
+        }
+      }
+
+      // Ignore task completion for old attempts of an indeterminate stage
+      val ignoreOldTaskAttempts = if (stageWithChecksumMismatchFullRetryEnabled(stage)) {
+        stage.maxAttemptIdToIgnore.exists(_ >= task.stageAttemptId)
+      } else {
+        stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()
+      }
+
+      if (!ignoreOldTaskAttempts && task.stageAttemptId < stage.latestInfo.attemptNumber()) {
         taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
       }
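The predicate above merges two regimes, and the helper walks `stage.parents` recursively, so a `ResultStage` inherits the flag from any upstream shuffle dependency. Restated as a standalone function (plain parameters substituted for the `Stage` fields, behavior mirroring the diff):

```scala
def ignoreOldTaskAttempts(
    fullRetryEnabled: Boolean,
    maxAttemptIdToIgnore: Option[Int],
    isIndeterminate: Boolean,
    latestAttemptNumber: Int,
    taskStageAttemptId: Int): Boolean = {
  if (fullRetryEnabled) {
    // New rule: drop results up to and including the marked rollback attempt.
    maxAttemptIdToIgnore.exists(_ >= taskStageAttemptId)
  } else {
    // Legacy rule: drop only results from attempts older than the latest one.
    isIndeterminate && taskStageAttemptId < latestAttemptNumber
  }
}

// Legacy: attempt 0's result is dropped once attempt 1 is the latest.
assert(ignoreOldTaskAttempts(false, None, true, 1, 0))
// New: after markAsRollingBack() sets Some(1), attempt 1 itself is dropped,
// even before any newer attempt has started.
assert(ignoreOldTaskAttempts(true, Some(1), true, 1, 1))
// New: a post-rollback attempt (id 2) is accepted.
assert(!ignoreOldTaskAttempts(true, Some(1), true, 1, 2))
```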
@@ -2002,6 +2112,13 @@ private[spark] class DAGScheduler(
         resultStage.activeJob match {
           case Some(job) =>
             if (!job.finished(rt.outputId)) {
+              if (ignoreOldTaskAttempts) {
+                val reason = "Task with indeterminate results from old attempt succeeded, " +
+                  s"aborting the stage $resultStage to ensure data correctness."
+                abortStage(resultStage, reason, None)
+                return
+              }
+
               job.finished(rt.outputId) = true
               job.numFinished += 1
               // If the whole job has finished, remove it
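The result-stage case is the one place where rollback is impossible: other partitions of the job may already have been delivered to the user, so a success computed from pre-rollback shuffle data can only be answered by aborting. A tiny hedged model of that guard (hypothetical types, not Spark's handler):

```scala
import scala.collection.mutable

// A stale success aborts instead of completing the partition, because mixing
// pre- and post-rollback results in one job result would be incorrect.
def onResultTaskSuccess(
    fromIgnoredAttempt: Boolean,
    finished: mutable.BitSet,
    outputId: Int): Either[String, Unit] =
  if (fromIgnoredAttempt) {
    Left("abort: result computed from pre-rollback shuffle data")
  } else {
    finished += outputId
    Right(())
  }
```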
@@ -2045,10 +2162,7 @@ private[spark] class DAGScheduler(
 
       case smt: ShuffleMapTask =>
         val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
-        // Ignore task completion for old attempt of indeterminate stage
-        val ignoreIndeterminate = stage.isIndeterminate &&
-          task.stageAttemptId < stage.latestInfo.attemptNumber()
-        if (!ignoreIndeterminate) {
+        if (!ignoreOldTaskAttempts) {
           shuffleStage.pendingPartitions -= task.partitionId
           val status = event.result.asInstanceOf[MapStatus]
           val execId = status.location.executorId
@@ -2077,7 +2191,7 @@ private[spark] class DAGScheduler(
             shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId
             if (shuffleStage.shuffleDep.checksumMismatchFullRetryEnabled
                 && shuffleStage.isStageIndeterminate) {
-              abortUnrollbackableStages(shuffleStage)
+              rollbackSucceedingStages(shuffleStage)
             }
           }
         }
@@ -2206,7 +2320,11 @@ private[spark] class DAGScheduler(
           // guaranteed to be determinate, so the input data of the reducers will not change
           // even if the map tasks are re-tried.
           if (mapStage.isIndeterminate && !mapStage.shuffleDep.checksumMismatchFullRetryEnabled) {
-            abortUnrollbackableStages(mapStage)
+            val stagesToRollback = collectSucceedingStages(mapStage)
+            val stagesCanRollback = filterAndAbortUnrollbackableStages(stagesToRollback)
+            logInfo(log"The shuffle map stage ${MDC(STAGE, mapStage)} with indeterminate output " +
+              log"was failed, we will roll back and rerun below stages which include itself and all " +
+              log"its indeterminate child stages: ${MDC(STAGES, stagesCanRollback)}")
           }
 
           // We expect one executor failure to trigger many FetchFailures in rapid succession,
@@ -2396,7 +2514,8 @@ private[spark] class DAGScheduler(
    * @param stagesToRollback stages to roll back
    * @return Shuffle map stages which need and can be rolled back
    */
-  private def abortStageWithInvalidRollBack(stagesToRollback: HashSet[Stage]): HashSet[Stage] = {
+  private def filterAndAbortUnrollbackableStages(
+      stagesToRollback: HashSet[Stage]): HashSet[Stage] = {
 
     def generateErrorMessage(stage: Stage): String = {
       "A shuffle map stage with indeterminate output was failed and retried. " +

core/src/main/scala/org/apache/spark/scheduler/Stage.scala

Lines changed: 16 additions & 0 deletions
@@ -84,6 +84,14 @@ private[scheduler] abstract class Stage(
    */
   private[scheduler] var maxChecksumMismatchedId: Int = nextAttemptId
 
+  /**
+   * The max attempt id for which we should ignore results for this stage, indicating that
+   * checksum mismatches have been detected in ancestor stages. This stage is probably also
+   * indeterminate, so we need to avoid completing the stage and the job with incorrect
+   * results by ignoring the task output from previous attempts, which might have consumed
+   * inconsistent data.
+   */
+  private[scheduler] var maxAttemptIdToIgnore: Option[Int] = None
+
   val name: String = callSite.shortForm
   val details: String = callSite.longForm
 
@@ -108,6 +116,14 @@ private[scheduler] abstract class Stage(
     failedAttemptIds.clear()
   }
 
+  /** Mark the latest attempt as rolled back. */
+  private[scheduler] def markAsRollingBack(): Unit = {
+    // Only if the stage has been submitted
+    if (getNextAttemptId > 0) {
+      maxAttemptIdToIgnore = Some(latestInfo.attemptNumber())
+    }
+  }
+
   /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */
   def makeNewStageAttempt(
       numPartitionsToCompute: Int,
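A short trace of the intended lifecycle, using a mini-model of the two members added here (self-contained sketch, not Spark's actual `Stage` class):

```scala
final class StageModel {
  private var nextAttemptId = 0
  private var latestAttempt = -1
  var maxAttemptIdToIgnore: Option[Int] = None

  def makeNewStageAttempt(): Unit = {
    latestAttempt = nextAttemptId
    nextAttemptId += 1
  }

  // Mirrors markAsRollingBack(): only a submitted stage has attempts whose
  // straggling results need to be ignored.
  def markAsRollingBack(): Unit =
    if (nextAttemptId > 0) maxAttemptIdToIgnore = Some(latestAttempt)
}

val s = new StageModel
s.markAsRollingBack() // never submitted: nothing to ignore
assert(s.maxAttemptIdToIgnore.isEmpty)
s.makeNewStageAttempt() // attempt 0
s.makeNewStageAttempt() // attempt 1
s.markAsRollingBack()
assert(s.maxAttemptIdToIgnore == Some(1)) // results from attempts <= 1 are dropped
```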
