@@ -84,10 +84,10 @@ private[spark] class TaskSetManager(
   val successful = new Array[Boolean](numTasks)
   private val numFailures = new Array[Int](numTasks)
 
-  // Set the coresponding index of Boolean var when the task killed by other attempt tasks,
-  // this happened while we set the `spark.speculation` to true. The task killed by others
+  // Add the tid of the task to this HashSet when the task is killed by another attempt.
+  // This can happen when `spark.speculation` is true. The task killed by another attempt
   // should not resubmit while executor lost.
-  private val killedByOtherAttempt: Array[Boolean] = new Array[Boolean](numTasks)
+  private val killedByOtherAttempt = new HashSet[Long]
 
   val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
   private[scheduler] var tasksSuccessful = 0
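Note on this hunk: with speculation enabled, a single partition index can have several attempts, each with its own TID, so one Boolean per index cannot record *which* attempt was killed; a set of TIDs can. A minimal, self-contained sketch of that distinction (the TID values below are made up for illustration):

```scala
import scala.collection.mutable.HashSet

object KilledByOtherAttemptSketch extends App {
  val killedByOtherAttempt = new HashSet[Long]

  // Two attempts of the same partition index, with hypothetical TIDs:
  // 17 is the original attempt, 42 the speculative copy.
  killedByOtherAttempt += 17L  // only the original attempt was killed

  // Per-TID membership keeps the attempts distinct; a single Boolean
  // stored at their shared partition index could not.
  assert(killedByOtherAttempt.contains(17L))
  assert(!killedByOtherAttempt.contains(42L))
}
```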
@@ -735,7 +735,7 @@ private[spark] class TaskSetManager(
       logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " +
         s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " +
         s"as the attempt ${info.attemptNumber} succeeded on ${info.host}")
-      killedByOtherAttempt(index) = true
+      killedByOtherAttempt += attemptInfo.taskId
       sched.backend.killTask(
         attemptInfo.taskId,
         attemptInfo.executorId,
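Worth noticing: the TID is recorded before the backend kill is issued, so any code consulting `killedByOtherAttempt` afterwards already sees the attempt. A compilable sketch of that ordering, with a hypothetical `Backend` trait standing in for the scheduler backend:

```scala
import scala.collection.mutable.HashSet

// Hypothetical stand-in for the scheduler backend, trimmed to one method.
trait Backend {
  def killTask(taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit
}

object KillOtherAttemptSketch extends App {
  val killedByOtherAttempt = new HashSet[Long]

  val backend = new Backend {
    def killTask(taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit =
      println(s"kill TID $taskId on $executorId: $reason")
  }

  def killAttempt(tid: Long, execId: String): Unit = {
    // Record the TID first, then request the kill, matching the hunk's order.
    killedByOtherAttempt += tid
    backend.killTask(tid, execId, interruptThread = true, reason = "another attempt succeeded")
  }

  killAttempt(42L, "exec-1")
}
```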
@@ -758,7 +758,7 @@ private[spark] class TaskSetManager(
     }
     // There may be multiple tasksets for this stage -- we let all of them know that the partition
     // was completed. This may result in some of the tasksets getting completed.
-    sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId)
+    sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info)
     // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
     // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
     // "deserialize" the value when holding a lock to avoid blocking other threads. So we call
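For readers unfamiliar with the multiple-tasksets case: this call fans the completion out to every taskset attempt of the stage, and the succeeding task's `TaskInfo` now rides along so each manager can record its duration (used in the next hunk). A hedged sketch of that fan-out with hypothetical stub types:

```scala
import scala.collection.mutable

// Hypothetical stubs for the scheduler's internals.
case class TaskInfoStub(taskId: Long, duration: Long)

class TaskSetManagerStub(name: String) {
  def markPartitionCompleted(partitionId: Int, info: TaskInfoStub): Unit =
    println(s"$name: partition $partitionId completed in ${info.duration} ms")
}

object FanOutSketch extends App {
  // stageId -> every taskset attempt currently known for that stage
  val taskSetsByStage = mutable.Map(
    7 -> Seq(new TaskSetManagerStub("stage 7, attempt 0"),
             new TaskSetManagerStub("stage 7, attempt 1")))

  def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int, info: TaskInfoStub): Unit =
    taskSetsByStage.getOrElse(stageId, Nil).foreach(_.markPartitionCompleted(partitionId, info))

  markPartitionCompletedInAllTaskSets(7, 3, TaskInfoStub(taskId = 17L, duration = 1200L))
}
```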
@@ -769,9 +769,12 @@ private[spark] class TaskSetManager(
     maybeFinishTaskSet()
   }
 
-  private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
+  private[scheduler] def markPartitionCompleted(partitionId: Int, taskInfo: TaskInfo): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
+        if (speculationEnabled && !isZombie) {
+          successfulTaskDurations.insert(taskInfo.duration)
+        }
         tasksSuccessful += 1
         successful(index) = true
         if (tasksSuccessful == numTasks) {
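The added guard keeps `successfulTaskDurations` in step with `tasksSuccessful` when a partition is completed on this manager's behalf by another taskset, presumably so a later speculation check never takes the median of an empty duration heap. A self-contained sketch of that invariant, with a sorted buffer as a hypothetical stand-in for Spark's `MedianHeap`:

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for Spark's MedianHeap.
class Durations {
  private val buf = ArrayBuffer.empty[Long]
  def insert(d: Long): Unit = buf += d
  def isEmpty: Boolean = buf.isEmpty
  def median: Long = { val sorted = buf.sorted; sorted(sorted.size / 2) }
}

object SpeculationInvariantSketch extends App {
  val successfulTaskDurations = new Durations
  var tasksSuccessful = 0

  // Mirrors the fixed markPartitionCompleted: a success counted while
  // speculation is enabled must also contribute a duration sample.
  def markSuccess(duration: Long, speculationEnabled: Boolean): Unit = {
    if (speculationEnabled) successfulTaskDurations.insert(duration)
    tasksSuccessful += 1
  }

  markSuccess(duration = 800L, speculationEnabled = true)

  // The speculation check can now take a median safely whenever at least
  // one task has been counted successful.
  if (tasksSuccessful > 0 && !successfulTaskDurations.isEmpty) {
    println(s"speculation threshold base: ${successfulTaskDurations.median} ms")
  }
}
```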
@@ -944,7 +947,7 @@ private[spark] class TaskSetManager(
         && !isZombie) {
       for ((tid, info) <- taskInfos if info.executorId == execId) {
         val index = taskInfos(tid).index
-        if (successful(index) && !killedByOtherAttempt(index)) {
+        if (successful(index) && !killedByOtherAttempt.contains(tid)) {
           successful(index) = false
           copiesRunning(index) -= 1
           tasksSuccessful -= 1
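With the per-TID set, the executor-lost path can skip exactly the attempts whose "success" belongs to a sibling attempt on another executor, while still resubmitting successes whose output actually lived on the lost executor. A sketch of that filter over trimmed-down, hypothetical state:

```scala
import scala.collection.mutable.HashSet

object ExecutorLostSketch extends App {
  case class Info(index: Int, executorId: String)

  val killedByOtherAttempt = new HashSet[Long]
  val successful = Array(true, true, false)

  // tid -> per-attempt info, made up for illustration
  val taskInfos = Map(
    17L -> Info(index = 0, executorId = "exec-1"),  // real success on exec-1
    42L -> Info(index = 1, executorId = "exec-1"))  // killed: sibling won elsewhere
  killedByOtherAttempt += 42L

  // Indices to resubmit when exec-1 is lost: successes that actually
  // produced output there, minus attempts killed by a sibling attempt.
  val toResubmit = taskInfos.collect {
    case (tid, info) if info.executorId == "exec-1" &&
        successful(info.index) && !killedByOtherAttempt.contains(tid) => info.index
  }
  println(s"resubmit indices: ${toResubmit.toSeq.sorted}")  // List(0)
}
```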