@@ -88,7 +88,7 @@ private[spark] trait ExecutorAllocationClient {
* Default implementation delegates to kill, scheduler must override
* if it supports graceful decommissioning.
*
* @param executorsAndDecominfo identifiers of executors & decom info.
* @param executorsAndDecomInfo identifiers of executors & decom info.
* @param adjustTargetNumExecutors whether the target number of executors will be adjusted down
* after these executors have been decommissioned.
* @return the ids of the executors acknowledged by the cluster manager to be removed.
@@ -1825,7 +1825,7 @@ private[spark] class DAGScheduler(
if (bmAddress != null) {
val externalShuffleServiceEnabled = env.blockManager.externalShuffleServiceEnabled
val isHostDecommissioned = taskScheduler
.getExecutorDecommissionInfo(bmAddress.executorId)
.getExecutorDecommissionState(bmAddress.executorId)
.exists(_.isHostDecommissioned)

// Shuffle output of all executors on host `bmAddress.host` may be lost if:
@@ -18,11 +18,23 @@
package org.apache.spark.scheduler

/**
* Provides more detail when an executor is being decommissioned.
* Message providing more detail when an executor is being decommissioned.
* @param message Human readable reason for why the decommissioning is happening.
* @param isHostDecommissioned Whether the host (aka the `node` or `worker` in other places) is
* being decommissioned too. Used to infer if the shuffle data might
* be lost even if the external shuffle service is enabled.
*/
private[spark]
case class ExecutorDecommissionInfo(message: String, isHostDecommissioned: Boolean)

/**
* State related to decommissioning that is kept by the TaskSchedulerImpl. This state is derived
* from the info message above but it is kept distinct to allow the state to evolve independently
* from the message.
*/
case class ExecutorDecommissionState(
// Timestamp the decommissioning commenced as per the Driver's clock,
// to estimate when the executor might eventually be lost if EXECUTOR_DECOMMISSION_KILL_INTERVAL
// is configured.
startTime: Long,
isHostDecommissioned: Boolean)
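
As an aside (not part of this diff), the merge rule that TaskSchedulerImpl applies further down can be summarized in a small sketch: the first notification pins startTime from the driver's clock, and isHostDecommissioned may only be upgraded from false to true, since only the cluster manager reports it as true.

```scala
// Hypothetical helper, mirroring the merge logic in TaskSchedulerImpl.executorDecommission.
def mergeDecommissionState(
    existing: Option[ExecutorDecommissionState],
    info: ExecutorDecommissionInfo,
    nowMs: Long): ExecutorDecommissionState = existing match {
  // First notification for this executor: record the driver-side start time.
  case None => ExecutorDecommissionState(nowMs, info.isHostDecommissioned)
  // Later notifications keep the original start time and may only flip
  // isHostDecommissioned from false to true (the cluster manager is authoritative).
  case Some(prev) =>
    prev.copy(isHostDecommissioned = prev.isHostDecommissioned || info.isHostDecommissioned)
}
```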
@@ -106,7 +106,7 @@ private[spark] trait TaskScheduler {
/**
* If an executor is decommissioned, return its corresponding decommission info
*/
def getExecutorDecommissionInfo(executorId: String): Option[ExecutorDecommissionInfo]
def getExecutorDecommissionState(executorId: String): Option[ExecutorDecommissionState]

/**
* Process a lost executor
@@ -141,7 +141,7 @@ private[spark] class TaskSchedulerImpl(

// We add executors here when we first get decommission notification for them. Executors can
// continue to run even after being asked to decommission, but they will eventually exit.
val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionState]

// When they exit and we know of that via heartbeat failure, we will add them to this cache.
// This cache is consulted to know if a fetch failure is because a source executor was
@@ -152,7 +152,7 @@
.ticker(new Ticker{
override def read(): Long = TimeUnit.MILLISECONDS.toNanos(clock.getTimeMillis())
})
.build[String, ExecutorDecommissionInfo]()
.build[String, ExecutorDecommissionState]()
.asMap()

def runningTasksByExecutors: Map[String, Int] = synchronized {
@@ -293,7 +293,7 @@
private[scheduler] def createTaskSetManager(
taskSet: TaskSet,
maxTaskFailures: Int): TaskSetManager = {
new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt)
new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt, clock)
}

override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
@@ -922,22 +922,36 @@ private[spark] class TaskSchedulerImpl(
synchronized {
// Don't bother noting decommissioning for executors that we don't know about
if (executorIdToHost.contains(executorId)) {
// The scheduler can get multiple decommission updates from multiple sources,
// and some of those can have isHostDecommissioned false. We merge them such that
// if we heard isHostDecommissioned ever true, then we keep that one since it is
// most likely coming from the cluster manager and thus authoritative
val oldDecomInfo = executorsPendingDecommission.get(executorId)
if (!oldDecomInfo.exists(_.isHostDecommissioned)) {
executorsPendingDecommission(executorId) = decommissionInfo
val oldDecomStateOpt = executorsPendingDecommission.get(executorId)
val newDecomState = if (oldDecomStateOpt.isEmpty) {
// This is the first time we are hearing of decommissioning this executor,
// so create a brand new state.
ExecutorDecommissionState(
clock.getTimeMillis(),
decommissionInfo.isHostDecommissioned)
} else {
val oldDecomState = oldDecomStateOpt.get
if (!oldDecomState.isHostDecommissioned && decommissionInfo.isHostDecommissioned) {
// Only the cluster manager is allowed to send decommission messages with
// isHostDecommissioned set. So the new decommissionInfo is from the cluster
// manager and is thus authoritative. Flip isHostDecommissioned to true but keep the old
// decommission start time.
ExecutorDecommissionState(
oldDecomState.startTime,
isHostDecommissioned = true)
} else {
oldDecomState
}
}
executorsPendingDecommission(executorId) = newDecomState
}
}
rootPool.executorDecommission(executorId)
backend.reviveOffers()
}

override def getExecutorDecommissionInfo(executorId: String)
: Option[ExecutorDecommissionInfo] = synchronized {
override def getExecutorDecommissionState(executorId: String)
: Option[ExecutorDecommissionState] = synchronized {
executorsPendingDecommission
.get(executorId)
.orElse(Option(decommissionedExecutorsRemoved.get(executorId)))
@@ -948,14 +962,14 @@
val reason = givenReason match {
// Handle executor process loss due to decommissioning
case ExecutorProcessLost(message, origWorkerLost, origCausedByApp) =>
val executorDecommissionInfo = getExecutorDecommissionInfo(executorId)
val executorDecommissionState = getExecutorDecommissionState(executorId)
ExecutorProcessLost(
message,
// Also mark the worker lost if we know that the host was decommissioned
origWorkerLost || executorDecommissionInfo.exists(_.isHostDecommissioned),
origWorkerLost || executorDecommissionState.exists(_.isHostDecommissioned),
// Executor loss is certainly not caused by app if we knew that this executor is being
// decommissioned
causedByApp = executorDecommissionInfo.isEmpty && origCausedByApp)
causedByApp = executorDecommissionState.isEmpty && origCausedByApp)
case e => e
}

@@ -1047,8 +1061,8 @@ private[spark] class TaskSchedulerImpl(
}


val decomInfo = executorsPendingDecommission.remove(executorId)
decomInfo.foreach(decommissionedExecutorsRemoved.put(executorId, _))
val decomState = executorsPendingDecommission.remove(executorId)
decomState.foreach(decommissionedExecutorsRemoved.put(executorId, _))

if (reason != LossReasonPending) {
executorIdToHost -= executorId
@@ -1085,12 +1099,12 @@

// exposed for test
protected final def isExecutorDecommissioned(execId: String): Boolean =
getExecutorDecommissionInfo(execId).nonEmpty
getExecutorDecommissionState(execId).isDefined

// exposed for test
protected final def isHostDecommissioned(host: String): Boolean = {
hostToExecutors.get(host).exists { executors =>
executors.exists(e => getExecutorDecommissionInfo(e).exists(_.isHostDecommissioned))
executors.exists(e => getExecutorDecommissionState(e).exists(_.isHostDecommissioned))
}
}
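
One detail worth calling out from earlier in this file's diff: decommissionedExecutorsRemoved is a Guava cache whose Ticker reads the scheduler's clock, so entries age out against driver time rather than System.nanoTime(). A self-contained sketch of that pattern, with a made-up five-minute retention and a plain var standing in for Spark's Clock:

```scala
import java.util.concurrent.TimeUnit
import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder

// Sketch only: a manually advanced millisecond clock in place of Spark's Clock.
var nowMs: Long = 0L

val removedExecutors = CacheBuilder.newBuilder()
  .expireAfterWrite(5, TimeUnit.MINUTES) // the real retention config is not shown in this hunk
  .ticker(new Ticker {
    // The cache measures entry age via this ticker (in nanoseconds), so expiry
    // follows the manual clock instead of wall-clock time.
    override def read(): Long = TimeUnit.MILLISECONDS.toNanos(nowMs)
  })
  .build[String, ExecutorDecommissionState]()
  .asMap()

removedExecutors.put("exec-1", ExecutorDecommissionState(nowMs, isHostDecommissioned = false))
nowMs += TimeUnit.MINUTES.toMillis(6)
assert(removedExecutors.get("exec-1") == null) // entry has aged out under the advanced clock
```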

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala (15 additions, 19 deletions)
@@ -169,7 +169,6 @@ private[spark] class TaskSetManager(

// Task index, start and finish time for each task attempt (indexed by task ID)
private[scheduler] val taskInfos = new HashMap[Long, TaskInfo]
private[scheduler] val tidToExecutorKillTimeMapping = new HashMap[Long, Long]

// Use a MedianHeap to record durations of successful tasks so we know when to launch
// speculative tasks. This is only used when speculation is enabled, to avoid the overhead
@@ -943,7 +942,6 @@

/** If the given task ID is in the set of running tasks, removes it. */
def removeRunningTask(tid: Long): Unit = {
tidToExecutorKillTimeMapping.remove(tid)
if (runningTasksSet.remove(tid) && parent != null) {
parent.decreaseRunningTasks(1)
}
@@ -1054,15 +1052,21 @@
logDebug("Task length threshold for speculation: " + threshold)
for (tid <- runningTasksSet) {
var speculated = checkAndSubmitSpeculatableTask(tid, time, threshold)
if (!speculated && tidToExecutorKillTimeMapping.contains(tid)) {
// Check whether this task will finish before the exectorKillTime assuming
// it will take medianDuration overall. If this task cannot finish within
// executorKillInterval, then this task is a candidate for speculation
val taskEndTimeBasedOnMedianDuration = taskInfos(tid).launchTime + medianDuration
val canExceedDeadline = tidToExecutorKillTimeMapping(tid) <
taskEndTimeBasedOnMedianDuration
if (canExceedDeadline) {
speculated = checkAndSubmitSpeculatableTask(tid, time, 0)
if (!speculated && executorDecommissionKillInterval.isDefined) {
val taskInfo = taskInfos(tid)
val decomState = sched.getExecutorDecommissionState(taskInfo.executorId)
if (decomState.isDefined) {
// Check if this task might finish after this executor is decommissioned.
// We estimate the task's finish time by using the median task duration.
// Whereas the time when the executor might be decommissioned is estimated using the
// config executorDecommissionKillInterval. If the task is going to finish after
// decommissioning, then we will eagerly speculate the task.
val taskEndTimeBasedOnMedianDuration = taskInfos(tid).launchTime + medianDuration
val executorDecomTime = decomState.get.startTime + executorDecommissionKillInterval.get
val canExceedDeadline = executorDecomTime < taskEndTimeBasedOnMedianDuration
if (canExceedDeadline) {
speculated = checkAndSubmitSpeculatableTask(tid, time, 0)
}
}
}
foundTasks |= speculated
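
Concretely, with EXECUTOR_DECOMMISSION_KILL_INTERVAL set, a running task becomes a speculation candidate as soon as the executor's projected kill time falls before the task's projected finish time (launch time plus the median successful-task duration). A tiny numeric illustration of the comparison above, with made-up values:

```scala
// Illustrative numbers only (all in milliseconds).
val launchTime = 1000L              // when the task was launched
val medianDuration = 60000L         // median duration of successful tasks in this task set
val decommissionStartTime = 20000L  // ExecutorDecommissionState.startTime
val killInterval = 30000L           // assumed EXECUTOR_DECOMMISSION_KILL_INTERVAL value

val taskEndTimeBasedOnMedianDuration = launchTime + medianDuration // 61000
val executorDecomTime = decommissionStartTime + killInterval       // 50000

// The executor is expected to be killed before the task's estimated finish time,
// so checkSpeculatableTasks would speculate this task eagerly (threshold 0).
assert(executorDecomTime < taskEndTimeBasedOnMedianDuration)
```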
@@ -1123,14 +1127,6 @@

def executorDecommission(execId: String): Unit = {
recomputeLocality()
if (speculationEnabled) {
Contributor:

Do we have a test that this is being speculated?

Author:

This was an efficiency improvement: avoid doing this bookkeeping in the driver when speculation is not enabled, saving both some CPU cycles and memory.

Now this check is done in checkSpeculatableTasks, which is not even called when speculation is disabled, so we get the same efficiency improvement automatically. This is a positive side effect of merging the tidToExecutorKillTimeMapping bookkeeping into ExecutorDecommissionState.

In the meanwhile I will hunt for a suitable test that adds some coverage here, or consider adding one.

Author:

Thank you for keeping me honest here. I added a test and a tiny change to make it pass. The behavior is that when speculation is disabled, tasks should only be rerun when the executor is actually lost.

executorDecommissionKillInterval.foreach { interval =>
val executorKillTime = clock.getTimeMillis() + interval
runningTasksSet.filter(taskInfos(_).executorId == execId).foreach { tid =>
tidToExecutorKillTimeMapping(tid) = executorKillTime
}
}
}
}

def recomputeLocality(): Unit = {
@@ -178,8 +178,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
override def executorDecommission(
executorId: String,
decommissionInfo: ExecutorDecommissionInfo): Unit = {}
override def getExecutorDecommissionInfo(
executorId: String): Option[ExecutorDecommissionInfo] = None
override def getExecutorDecommissionState(
executorId: String): Option[ExecutorDecommissionState] = None
}

/**
@@ -787,8 +787,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
override def executorDecommission(
executorId: String,
decommissionInfo: ExecutorDecommissionInfo): Unit = {}
override def getExecutorDecommissionInfo(
executorId: String): Option[ExecutorDecommissionInfo] = None
override def getExecutorDecommissionState(
executorId: String): Option[ExecutorDecommissionState] = None
}
val noKillScheduler = new DAGScheduler(
sc,
@@ -101,6 +101,6 @@ private class DummyTaskScheduler extends TaskScheduler {
override def executorDecommission(
executorId: String,
decommissionInfo: ExecutorDecommissionInfo): Unit = {}
override def getExecutorDecommissionInfo(
executorId: String): Option[ExecutorDecommissionInfo] = None
override def getExecutorDecommissionState(
executorId: String): Option[ExecutorDecommissionState] = None
}
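
If a suite ever needs this stub to answer something other than None, one option (purely a sketch, not taken from this PR; the subclass name is hypothetical) is to record state the same way TaskSchedulerImpl does:

```scala
import scala.collection.mutable

// Hypothetical test double (not part of this PR): tracks decommission state so a
// test can assert on what was reported, instead of always answering None.
private class RecordingDummyTaskScheduler extends DummyTaskScheduler {
  private val decomStates = new mutable.HashMap[String, ExecutorDecommissionState]

  override def executorDecommission(
      executorId: String,
      decommissionInfo: ExecutorDecommissionInfo): Unit = {
    val prev = decomStates.get(executorId)
    decomStates(executorId) = ExecutorDecommissionState(
      // Keep the first-seen start time; only upgrade isHostDecommissioned to true.
      prev.map(_.startTime).getOrElse(System.currentTimeMillis()),
      prev.exists(_.isHostDecommissioned) || decommissionInfo.isHostDecommissioned)
  }

  override def getExecutorDecommissionState(
      executorId: String): Option[ExecutorDecommissionState] = decomStates.get(executorId)
}
```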