diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 1366251d0618..427fb7590a4e 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -675,7 +675,9 @@ private[spark] class ExecutorAllocationManager(
       val taskIndex = taskEnd.taskInfo.index
       val stageId = taskEnd.stageId
       allocationManager.synchronized {
-        numRunningTasks -= 1
+        if (stageIdToNumTasks.contains(stageId)) {
+          numRunningTasks -= 1
+        }
         // If the executor is no longer running any scheduled tasks, mark it as idle
         if (executorIdToTaskIds.contains(executorId)) {
           executorIdToTaskIds(executorId) -= taskId
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index ec409712b953..c62736284048 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -938,6 +938,33 @@ class ExecutorAllocationManagerSuite
     assert(removeTimes(manager) === Map.empty)
   }
 
+  test("SPARK-18981: maxNumExecutorsNeeded should properly handle speculated tasks") {
+    sc = createSparkContext()
+    val manager = sc.executorAllocationManager.get
+    assert(maxNumExecutorsNeeded(manager) === 0)
+
+    val stageInfo = createStageInfo(0, 1)
+    sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo))
+    assert(maxNumExecutorsNeeded(manager) === 1)
+
+    val taskInfo = createTaskInfo(0, 0, "executor-1")
+    val speculatedTaskInfo = createTaskInfo(1, 0, "executor-1")
+    sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfo))
+    assert(maxNumExecutorsNeeded(manager) === 1)
+
+    sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, speculatedTaskInfo))
+    assert(maxNumExecutorsNeeded(manager) === 2)
+
+    sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, Success, taskInfo, null))
+    assert(maxNumExecutorsNeeded(manager) === 1)
+
+    sc.listenerBus.postToAll(SparkListenerStageCompleted(stageInfo))
+    assert(maxNumExecutorsNeeded(manager) === 0)
+
+    sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, TaskKilled, speculatedTaskInfo, null))
+    assert(maxNumExecutorsNeeded(manager) === 0)
+  }
+
   private def createSparkContext(
       minExecutors: Int = 1,
       maxExecutors: Int = 5,