diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index a456f91d4c96..07a71ebed085 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1847,9 +1847,9 @@ private[spark] class DAGScheduler(
       case Success =>
         // An earlier attempt of a stage (which is zombie) may still have running tasks. If these
         // tasks complete, they still count and we can mark the corresponding partitions as
-        // finished. Here we notify the task scheduler to skip running tasks for the same partition,
-        // to save resource.
-        if (task.stageAttemptId < stage.latestInfo.attemptNumber()) {
+        // finished if the stage is determinate. Here we notify the task scheduler to skip running
+        // tasks for the same partition to save resource.
+        if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()) {
           taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
         }

diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 7bb8f49e6bff..7691b98f620b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -3169,13 +3169,16 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
       makeMapStatus("hostB", 2)))

-    // The second task of the shuffle map stage 1 from 1st attempt succeeds
     runEvent(makeCompletionEvent(
       taskSets(1).tasks(1),
       Success,
       makeMapStatus("hostC", 2)))

+    // Above task completion should not mark the partition 1 complete from 2nd attempt
+    assert(!tasksMarkedAsCompleted.contains(taskSets(3).tasks(1)))
+
     // This task completion should get ignored and partition 1 should be missing
     // for shuffle map stage 1
     assert(mapOutputTracker.findMissingPartitions(shuffleId2) == Some(Seq(1)))