Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/Dependency.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false,
val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor,
val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS)
val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS,
val checksumMismatchFullRetryEnabled: Boolean = false)
extends Dependency[Product2[K, V]] with Logging {

def this(
Expand Down
10 changes: 7 additions & 3 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,11 @@ private class ShuffleStatus(

/**
* Register a map output. If there is already a registered location for the map output then it
* will be replaced by the new location.
* will be replaced by the new location. Returns true if the checksum in the new MapStatus is
* different from a previous registered MapStatus. Otherwise, returns false.
*/
def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
def addMapOutput(mapIndex: Int, status: MapStatus): Boolean = withWriteLock {
var isChecksumMismatch: Boolean = false
val currentMapStatus = mapStatuses(mapIndex)
if (currentMapStatus == null) {
_numAvailableMapOutputs += 1
Expand All @@ -183,9 +185,11 @@ private class ShuffleStatus(
logInfo(s"Checksum of map output changes from ${preStatus.checksumValue} to " +
s"${status.checksumValue} for task ${status.mapId}.")
checksumMismatchIndices.add(mapIndex)
isChecksumMismatch = true
}
mapStatuses(mapIndex) = status
mapIdToMapIndex(status.mapId) = mapIndex
isChecksumMismatch
}

/**
Expand Down Expand Up @@ -853,7 +857,7 @@ private[spark] class MapOutputTrackerMaster(
}
}

def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Boolean = {
shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
}

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1773,7 +1773,7 @@ abstract class RDD[T: ClassTag](
/**
* Return whether this RDD is reliably checkpointed and materialized.
*/
private[rdd] def isReliablyCheckpointed: Boolean = {
private[spark] def isReliablyCheckpointed: Boolean = {
checkpointData match {
case Some(reliable: ReliableRDDCheckpointData[_]) if reliable.isCheckpointed => true
case _ => false
Expand Down
101 changes: 73 additions & 28 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1551,29 +1551,46 @@ private[spark] class DAGScheduler(
// The operation here can make sure for the partially completed intermediate stage,
// `findMissingPartitions()` returns all partitions every time.
stage match {
case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
// already executed at least once
if (sms.getNextAttemptId > 0) {
// While we previously validated possible rollbacks during the handling of a FetchFailure,
// where we were fetching from an indeterminate source map stages, this later check
// covers additional cases like recalculating an indeterminate stage after an executor
// loss. Moreover, because this check occurs later in the process, if a result stage task
// has successfully completed, we can detect this and abort the job, as rolling back a
// result stage is not possible.
val stagesToRollback = collectSucceedingStages(sms)
abortStageWithInvalidRollBack(stagesToRollback)
// stages which cannot be rolled back were aborted which leads to removing the
// the dependant job(s) from the active jobs set
val numActiveJobsWithStageAfterRollback =
activeJobs.count(job => stagesToRollback.contains(job.finalStage))
if (numActiveJobsWithStageAfterRollback == 0) {
logInfo(log"All jobs depending on the indeterminate stage " +
log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
return
case sms: ShuffleMapStage if !sms.isAvailable =>
val needFullStageRetry = if (sms.shuffleDep.checksumMismatchFullRetryEnabled) {
Copy link
Contributor

@mridulm mridulm Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Catching up on PR's I missed out on reviewing.

This negatively interacts if there is push based shuffle enabled.
The condition should be sms.shuffleDep.checksumMismatchFullRetryEnabled && !pushBasedShuffleEnabled

+CC @ivoson

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @mridulm can you please explain more about the issue with push based shuffle? Thanks.

Copy link
Contributor

@mridulm mridulm Dec 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With push based shuffle enabled - a mappers output would also be pushed to mergers to create a reducer oriented view (all mappers write to a single merger for a given reducer).
If a subset of mapper tasks are now getting reexecuted - the merged output would get impacted as they have already been finalized when the previous attempt completed : causing a disconnect between the mapper output from the new attempt, and merged output from previous attempt.

Essentially, for indeterminate stages, the entire reducer oriented view is unusable - and needs to be recomputed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @mridulm to recompute the indeterminate stages, we'll clean up all the shuffle outputs and shuffle merge state for push-based shuffle. Would that resolve your concern regarding to push-based shuffle?

mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()

mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()

// When the parents of this stage are indeterminate (e.g., some parents are not
// checkpointed and checksum mismatches are detected), the output data of the parents
// may have changed due to task retries. For correctness reason, we need to
// retry all tasks of the current stage. The legacy way of using current stage's
// deterministic level to trigger full stage retry is not accurate.
stage.isParentIndeterminate
} else {
if (stage.isIndeterminate) {
// already executed at least once
if (sms.getNextAttemptId > 0) {
// While we previously validated possible rollbacks during the handling of a FetchFailure,
// where we were fetching from an indeterminate source map stages, this later check
// covers additional cases like recalculating an indeterminate stage after an executor
// loss. Moreover, because this check occurs later in the process, if a result stage task
// has successfully completed, we can detect this and abort the job, as rolling back a
// result stage is not possible.
val stagesToRollback = collectSucceedingStages(sms)
abortStageWithInvalidRollBack(stagesToRollback)
Comment on lines +1572 to +1573
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: We could have delegated this to abortUnrollbackableStages

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mridulm

I am working on another PR to capture more scenarios, can you please also take a look? Thanks

Will do some code refactor in the PR: #53274

// stages which cannot be rolled back were aborted which leads to removing the
// the dependant job(s) from the active jobs set
val numActiveJobsWithStageAfterRollback =
activeJobs.count(job => stagesToRollback.contains(job.finalStage))
if (numActiveJobsWithStageAfterRollback == 0) {
logInfo(log"All jobs depending on the indeterminate stage " +
log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
return
}
}
true
} else {
false
}
}
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()

if (needFullStageRetry) {
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()
}
case _ =>
}

Expand Down Expand Up @@ -1886,6 +1903,20 @@ private[spark] class DAGScheduler(
}
}

/**
* If a map stage is non-deterministic, the map tasks of the stage may return different result
* when re-try. To make sure data correctness, we need to re-try all the tasks of its succeeding
* stages, as the input data may be changed after the map tasks are re-tried. For stages where
* rollback and retry all tasks are not possible, we will need to abort the stages.
*/
private[scheduler] def abortUnrollbackableStages(mapStage: ShuffleMapStage): Unit = {
val stagesToRollback = collectSucceedingStages(mapStage)
val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output " +
log"was failed, we will roll back and rerun below stages which include itself and all its " +
log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
}

/**
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
Expand Down Expand Up @@ -2022,8 +2053,26 @@ private[spark] class DAGScheduler(
// The epoch of the task is acceptable (i.e., the task was launched after the most
// recent failure we're aware of for the executor), so mark the task's output as
// available.
mapOutputTracker.registerMapOutput(
val isChecksumMismatched = mapOutputTracker.registerMapOutput(
shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
if (isChecksumMismatched) {
shuffleStage.isChecksumMismatched = isChecksumMismatched
Copy link
Contributor

@mridulm mridulm Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is never reset back to false when the stage attempt is retried and succeeds - what am I missing ?
This would mean the app will always fail, right ?

Not sure what I am missing here.
+CC @ivoson , @cloud-fan , @attilapiros

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @mridulm , this is not set back to false. Would expect all the succeeding stages do fully retry once there is checksum mismatch happening for the stage, as we don't know the successful tasks consumed which version shuffle output.

This won't fail the app, the impact is that the succeeding stages would have a fully-retry.

The code logic has changed a little bit in PR: #53274

Pls take a look once you get a change. Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On retry - when we throw away the entire mapper output and recompute it -> at which point, we can go back to setting it to false ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, it's not setting back to false. We'll only recompute once any new shuffle checksum mismatch is detected. Maybe we can remove the flag to avoid the confusion here.

// There could be multiple checksum mismatches detected for a single stage attempt.
// We check for stage abortion once and only once when we first detect checksum
// mismatch for each stage attempt. For example, assume that we have
// stage1 -> stage2, and we encounter checksum mismatch during the retry of stage1.
// In this case, we need to call abortUnrollbackableStages() for the succeeding
// stages. Assume that when stage2 is retried, some tasks finish and some tasks
// failed again with FetchFailed. In case that we encounter checksum mismatch again
// during the retry of stage1, we need to call abortUnrollbackableStages() again.
if (shuffleStage.maxChecksumMismatchedId < smt.stageAttemptId) {
shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId
if (shuffleStage.shuffleDep.checksumMismatchFullRetryEnabled
&& shuffleStage.isStageIndeterminate) {
abortUnrollbackableStages(shuffleStage)
}
}
}
}
} else {
logInfo(log"Ignoring ${MDC(TASK_NAME, smt)} completion from an older attempt of indeterminate stage")
Expand Down Expand Up @@ -2148,12 +2197,8 @@ private[spark] class DAGScheduler(
// Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is
// guaranteed to be determinate, so the input data of the reducers will not change
// even if the map tasks are re-tried.
if (mapStage.isIndeterminate) {
val stagesToRollback = collectSucceedingStages(mapStage)
val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output was failed, " +
log"we will roll back and rerun below stages which include itself and all its " +
log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
if (mapStage.isIndeterminate && !mapStage.shuffleDep.checksumMismatchFullRetryEnabled) {
abortUnrollbackableStages(mapStage)
}

// We expect one executor failure to trigger many FetchFailures in rapid succession,
Expand Down
22 changes: 22 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Stage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ private[scheduler] abstract class Stage(
private var nextAttemptId: Int = 0
private[scheduler] def getNextAttemptId: Int = nextAttemptId

/**
* Whether checksum mismatches have been detected across different attempt of the stage, where
* checksum mismatches typically indicates that different stage attempts have produced different
* data.
*/
private[scheduler] var isChecksumMismatched: Boolean = false

/**
* The maximum of task attempt id where checksum mismatches are detected.
*/
private[scheduler] var maxChecksumMismatchedId: Int = nextAttemptId

val name: String = callSite.shortForm
val details: String = callSite.longForm

Expand Down Expand Up @@ -131,4 +143,14 @@ private[scheduler] abstract class Stage(
def isIndeterminate: Boolean = {
rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE
}

// Returns true if any parents of this stage are indeterminate.
def isParentIndeterminate: Boolean = {
parents.exists(_.isStageIndeterminate)
}

// Returns true if the stage itself is indeterminate.
def isStageIndeterminate: Boolean = {
!rdd.isReliablyCheckpointed && isChecksumMismatched
}
}
Loading