-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-19820] [core] Allow reason to be specified for task kill #17166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
91b8aef
a58d391
02d81b5
170fa34
614ad3c
72b28cb
fda712d
8f7ffb3
348e97a
f5069f7
884a3ad
203a900
6e8593b
5707715
a37c09b
71b41b3
3ec3633
145c78a
8c4381f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,8 +59,8 @@ private[spark] class TaskContextImpl( | |
| /** List of callback functions to execute when the task fails. */ | ||
| @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener] | ||
|
|
||
| // Whether the corresponding task has been killed. | ||
| @volatile private var interrupted: Boolean = false | ||
| // If defined, the corresponding task has been killed for the contained reason. | ||
| @volatile private var maybeKillReason: Option[String] = None | ||
|
|
||
| // Whether the task has completed. | ||
| private var completed: Boolean = false | ||
|
|
@@ -140,16 +140,22 @@ private[spark] class TaskContextImpl( | |
| } | ||
|
|
||
| /** Marks the task for interruption, i.e. cancellation. */ | ||
| private[spark] def markInterrupted(): Unit = { | ||
| interrupted = true | ||
| private[spark] def markInterrupted(reason: String): Unit = { | ||
| maybeKillReason = Some(reason) | ||
| } | ||
|
|
||
| private[spark] override def killTaskIfInterrupted(): Unit = { | ||
| if (maybeKillReason.isDefined) { | ||
| throw new TaskKilledException(maybeKillReason.get) | ||
|
||
| } | ||
| } | ||
|
|
||
| @GuardedBy("this") | ||
| override def isCompleted(): Boolean = synchronized(completed) | ||
|
|
||
| override def isRunningLocally(): Boolean = false | ||
|
|
||
| override def isInterrupted(): Boolean = interrupted | ||
| override def isInterrupted(): Boolean = maybeKillReason.isDefined | ||
|
|
||
| override def getLocalProperty(key: String): String = localProperties.getProperty(key) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -212,8 +212,8 @@ case object TaskResultLost extends TaskFailedReason { | |
| * Task was killed intentionally and needs to be rescheduled. | ||
| */ | ||
| @DeveloperApi | ||
| case object TaskKilled extends TaskFailedReason { | ||
| override def toErrorString: String = "TaskKilled (killed intentionally)" | ||
| case class TaskKilled(reason: String) extends TaskFailedReason { | ||
| override def toErrorString: String = s"TaskKilled ($reason)" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since this was part of DeveloperApi, what is the impact of making this change? If it does introduce backward-incompatible changes, is there a way to mitigate this?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is unfortunately not backwards compatible. I've looked into this, and the issue seems to be that we are converting a case object into a case class. If
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That is unfortunate, but looks like it can't be helped if we need this feature. Thx for clarifying. |
||
| override def countTowardsTaskFailures: Boolean = false | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -215,7 +215,8 @@ private[spark] class PythonRunner( | |
|
|
||
| case e: Exception if context.isInterrupted => | ||
| logDebug("Exception thrown after task interruption", e) | ||
| throw new TaskKilledException | ||
| context.killTaskIfInterrupted() | ||
| null // not reached | ||
|
||
|
|
||
| case e: Exception if env.isStopped => | ||
| logDebug("Exception thrown after context is stopped", e) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -158,7 +158,7 @@ private[spark] class Executor( | |
| threadPool.execute(tr) | ||
| } | ||
|
|
||
| def killTask(taskId: Long, interruptThread: Boolean): Unit = { | ||
| def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = { | ||
| val taskRunner = runningTasks.get(taskId) | ||
| if (taskRunner != null) { | ||
| if (taskReaperEnabled) { | ||
|
|
@@ -168,7 +168,8 @@ private[spark] class Executor( | |
| case Some(existingReaper) => interruptThread && !existingReaper.interruptThread | ||
| } | ||
| if (shouldCreateReaper) { | ||
| val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread) | ||
| val taskReaper = new TaskReaper( | ||
| taskRunner, interruptThread = interruptThread, reason = reason) | ||
|
||
| taskReaperForTask(taskId) = taskReaper | ||
| Some(taskReaper) | ||
| } else { | ||
|
|
@@ -178,7 +179,7 @@ private[spark] class Executor( | |
| // Execute the TaskReaper from outside of the synchronized block. | ||
| maybeNewTaskReaper.foreach(taskReaperPool.execute) | ||
| } else { | ||
| taskRunner.kill(interruptThread = interruptThread) | ||
| taskRunner.kill(interruptThread = interruptThread, reason = reason) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -189,8 +190,9 @@ private[spark] class Executor( | |
| * tasks instead of taking the JVM down. | ||
| * @param interruptThread whether to interrupt the task thread | ||
| */ | ||
| def killAllTasks(interruptThread: Boolean) : Unit = { | ||
| runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread)) | ||
| def killAllTasks(interruptThread: Boolean, reason: String) : Unit = { | ||
| runningTasks.keys().asScala.foreach(t => | ||
| killTask(t, interruptThread = interruptThread, reason = reason)) | ||
| } | ||
|
|
||
| def stop(): Unit = { | ||
|
|
@@ -217,8 +219,8 @@ private[spark] class Executor( | |
| val threadName = s"Executor task launch worker for task $taskId" | ||
| private val taskName = taskDescription.name | ||
|
|
||
| /** Whether this task has been killed. */ | ||
| @volatile private var killed = false | ||
| /** If specified, this task has been killed and this option contains the reason. */ | ||
| @volatile private var maybeKillReason: Option[String] = None | ||
|
|
||
| @volatile private var threadId: Long = -1 | ||
|
|
||
|
|
@@ -239,13 +241,13 @@ private[spark] class Executor( | |
| */ | ||
| @volatile var task: Task[Any] = _ | ||
|
|
||
| def kill(interruptThread: Boolean): Unit = { | ||
| logInfo(s"Executor is trying to kill $taskName (TID $taskId)") | ||
| killed = true | ||
| def kill(interruptThread: Boolean, reason: String): Unit = { | ||
| logInfo(s"Executor is trying to kill $taskName (TID $taskId), reason: $reason") | ||
| maybeKillReason = Some(reason) | ||
| if (task != null) { | ||
| synchronized { | ||
| if (!finished) { | ||
| task.kill(interruptThread) | ||
| task.kill(interruptThread, reason) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -296,12 +298,12 @@ private[spark] class Executor( | |
|
|
||
| // If this task has been killed before we deserialized it, let's quit now. Otherwise, | ||
| // continue executing the task. | ||
| if (killed) { | ||
| if (maybeKillReason.isDefined) { | ||
| // Throw an exception rather than returning, because returning within a try{} block | ||
| // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl | ||
| // exception will be caught by the catch block, leading to an incorrect ExceptionFailure | ||
| // for the task. | ||
| throw new TaskKilledException | ||
| throw new TaskKilledException(maybeKillReason.get) | ||
|
||
| } | ||
|
|
||
| logDebug("Task " + taskId + "'s epoch is " + task.epoch) | ||
|
|
@@ -358,9 +360,7 @@ private[spark] class Executor( | |
| } else 0L | ||
|
|
||
| // If the task has been killed, let's fail it. | ||
| if (task.killed) { | ||
| throw new TaskKilledException | ||
| } | ||
| task.context.killTaskIfInterrupted() | ||
|
|
||
| val resultSer = env.serializer.newInstance() | ||
| val beforeSerialization = System.currentTimeMillis() | ||
|
|
@@ -426,15 +426,18 @@ private[spark] class Executor( | |
| setTaskFinishedAndClearInterruptStatus() | ||
| execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) | ||
|
|
||
| case _: TaskKilledException => | ||
| logInfo(s"Executor killed $taskName (TID $taskId)") | ||
| case t: TaskKilledException => | ||
| logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") | ||
| setTaskFinishedAndClearInterruptStatus() | ||
| execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) | ||
| execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) | ||
|
|
||
| case _: InterruptedException if task.killed => | ||
| logInfo(s"Executor interrupted and killed $taskName (TID $taskId)") | ||
| logInfo( | ||
| s"Executor interrupted and killed $taskName (TID $taskId)," + | ||
| s" reason: ${task.maybeKillReason.get}") | ||
| setTaskFinishedAndClearInterruptStatus() | ||
| execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) | ||
| execBackend.statusUpdate( | ||
| taskId, TaskState.KILLED, ser.serialize(TaskKilled(task.maybeKillReason.get))) | ||
|
|
||
| case CausedBy(cDE: CommitDeniedException) => | ||
| val reason = cDE.toTaskFailedReason | ||
|
|
@@ -512,7 +515,8 @@ private[spark] class Executor( | |
| */ | ||
| private class TaskReaper( | ||
| taskRunner: TaskRunner, | ||
| val interruptThread: Boolean) | ||
| val interruptThread: Boolean, | ||
| val reason: String) | ||
| extends Runnable { | ||
|
|
||
| private[this] val taskId: Long = taskRunner.taskId | ||
|
|
@@ -533,7 +537,7 @@ private[spark] class Executor( | |
| // Only attempt to kill the task once. If interruptThread = false then a second kill | ||
| // attempt would be a no-op and if interruptThread = true then it may not be safe or | ||
| // effective to interrupt multiple times: | ||
| taskRunner.kill(interruptThread = interruptThread) | ||
| taskRunner.kill(interruptThread = interruptThread, reason = reason) | ||
| // Monitor the killed task until it exits. The synchronization logic here is complicated | ||
| // because we don't want to synchronize on the taskRunner while possibly taking a thread | ||
| // dump, but we also need to be careful to avoid races between checking whether the task | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -89,8 +89,8 @@ private[spark] abstract class Task[T]( | |
| TaskContext.setTaskContext(context) | ||
| taskThread = Thread.currentThread() | ||
|
|
||
| if (_killed) { | ||
| kill(interruptThread = false) | ||
| if (_maybeKillReason != null) { | ||
| kill(interruptThread = false, _maybeKillReason) | ||
| } | ||
|
|
||
| new CallerContext( | ||
|
|
@@ -160,15 +160,20 @@ private[spark] abstract class Task[T]( | |
|
|
||
| // A flag to indicate whether the task is killed. This is used in case context is not yet | ||
| // initialized when kill() is invoked. | ||
| @volatile @transient private var _killed = false | ||
| @volatile @transient private var _maybeKillReason: String = null | ||
|
||
|
|
||
| protected var _executorDeserializeTime: Long = 0 | ||
| protected var _executorDeserializeCpuTime: Long = 0 | ||
|
|
||
| /** | ||
| * Whether the task has been killed. | ||
| */ | ||
| def killed: Boolean = _killed | ||
| def killed: Boolean = _maybeKillReason != null | ||
|
|
||
| /** | ||
| * If this task has been killed, contains the reason for the kill. | ||
|
||
| */ | ||
| def maybeKillReason: Option[String] = Option(_maybeKillReason) | ||
|
|
||
| /** | ||
| * Returns the amount of time spent deserializing the RDD and function to be run. | ||
|
|
@@ -201,10 +206,11 @@ private[spark] abstract class Task[T]( | |
| * be called multiple times. | ||
| * If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread. | ||
| */ | ||
| def kill(interruptThread: Boolean) { | ||
| _killed = true | ||
| def kill(interruptThread: Boolean, reason: String) { | ||
| require(reason != null) | ||
| _maybeKillReason = reason | ||
| if (context != null) { | ||
| context.markInterrupted() | ||
| context.markInterrupted(reason) | ||
| } | ||
| if (interruptThread && taskThread != null) { | ||
| taskThread.interrupt() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Overloading
maybeKillReason to indicate interrupted status smells a bit; but might be ok for now. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, the reason here is to allow this to be set atomically.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about calling this
reasonIfKilled, here and elsewhere? (if you strongly prefer the existing name fine to leave as-is -- I just slightly prefer making it somewhat more obvious that this and the fact that the task has been killed are tightly intertwined). In any case, can you expand the comment a bit to one you used below: "If specified, this task has been killed and this option contains the reason."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done