Changes from all commits
@@ -23,7 +23,6 @@
import org.apache.avro.reflect.Nullable;

import org.apache.spark.TaskContext;
import org.apache.spark.TaskKilledException;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
@@ -291,8 +290,8 @@ public void loadNext() {
// to avoid performance overhead. This check is added here in `loadNext()` instead of in
// `hasNext()` because it's technically possible for the caller to be relying on
// `getNumRecords()` instead of `hasNext()` to know when to stop.
if (taskContext != null && taskContext.isInterrupted()) {
throw new TaskKilledException();
if (taskContext != null) {
taskContext.killTaskIfInterrupted();
}
// This pointer points to a 4-byte record length, followed by the record's bytes
final long recordPointer = array.get(offset + position);
(next file)
@@ -24,7 +24,6 @@

import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskKilledException;
import org.apache.spark.io.NioBufferedFileInputStream;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockId;
@@ -102,8 +101,8 @@ public void loadNext() throws IOException {
// to avoid performance overhead. This check is added here in `loadNext()` instead of in
// `hasNext()` because it's technically possible for the caller to be relying on
// `getNumRecords()` instead of `hasNext()` to know when to stop.
if (taskContext != null && taskContext.isInterrupted()) {
throw new TaskKilledException();
if (taskContext != null) {
taskContext.killTaskIfInterrupted();
}
recordLength = din.readInt();
keyPrefix = din.readLong();
(next file)
@@ -33,11 +33,8 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator
// is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
// (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
// introduces an expensive read fence.
if (context.isInterrupted) {
throw new TaskKilledException
} else {
delegate.hasNext
}
context.killTaskIfInterrupted()
delegate.hasNext
}

def next(): T = delegate.next()
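The hunks above all converge on the same cooperative-cancellation pattern: instead of reading isInterrupted and constructing a TaskKilledException by hand, long-running loops ask the TaskContext to throw on their behalf so the kill reason travels with the exception. A minimal sketch of that pattern, assuming a Spark-internal call site (killTaskIfInterrupted is private[spark]); the consumeAll helper is illustrative, not part of this patch:

```scala
import org.apache.spark.TaskContext

// Illustrative helper, not part of this patch: drain an iterator while
// periodically checking whether the task has been killed. If it has been,
// killTaskIfInterrupted() throws a TaskKilledException carrying the reason.
def consumeAll[T](context: TaskContext, records: Iterator[T])(f: T => Unit): Unit = {
  while (records.hasNext) {
    if (context != null) {
      context.killTaskIfInterrupted()
    }
    f(records.next())
  }
}
```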
18 changes: 18 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2249,6 +2249,24 @@ class SparkContext(config: SparkConf) extends Logging {
dagScheduler.cancelStage(stageId, None)
}

/**
* Kill and reschedule the given task attempt. Task ids can be obtained from the Spark UI
* or through SparkListener.onTaskStart.
*
* @param taskId the task ID to kill. This id uniquely identifies the task attempt.
* @param interruptThread whether to interrupt the thread running the task.
* @param reason the reason for killing the task, which should be a short string. If a task
* is killed multiple times with different reasons, only one reason will be reported.
*
* @return Whether the task was successfully killed.
*/
def killTaskAttempt(
taskId: Long,
interruptThread: Boolean = true,
reason: String = "killed via SparkContext.killTaskAttempt"): Boolean = {
dagScheduler.killTaskAttempt(taskId, interruptThread, reason)
}

/**
* Clean a closure to make it ready to serialized and send to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
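As the killTaskAttempt doc comment above notes, attempt ids can be picked up from SparkListener.onTaskStart. A hedged, driver-side sketch of how the new API could be driven from such a listener; the StragglerKiller class and its isSuspect predicate are assumptions made for this example:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}

// Illustrative listener, not part of this patch: watches tasks as they start
// and kills an attempt via the new API when some user-defined policy fires.
class StragglerKiller(sc: SparkContext) extends SparkListener {
  override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
    val taskId = taskStart.taskInfo.taskId
    if (isSuspect(taskId)) {
      // killTaskAttempt returns whether the kill request succeeded; the reason
      // string surfaces in the UI and in the TaskKilled end reason.
      sc.killTaskAttempt(taskId, interruptThread = true,
        reason = "straggler detected by StragglerKiller")
    }
  }

  // Placeholder policy; a real implementation would look at metrics/runtimes.
  private def isSuspect(taskId: Long): Boolean = false
}
```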
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -184,6 +184,16 @@ abstract class TaskContext extends Serializable {
@DeveloperApi
def getMetricsSources(sourceName: String): Seq[Source]

/**
* If the task is interrupted, throws TaskKilledException with the reason for the interrupt.
*/
private[spark] def killTaskIfInterrupted(): Unit

/**
* If the task is interrupted, the reason this task was killed, otherwise None.
*/
private[spark] def getKillReason(): Option[String]

/**
* Returns the manager for this task's managed memory.
*/
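Taken together, the two new TaskContext hooks let Spark-internal exception handlers turn an interrupt into a TaskKilledException that carries the reason. A minimal sketch; the rethrowIfKilled helper is illustrative (it mirrors the PythonRunner change further down), and both methods are private[spark], so this only applies inside Spark:

```scala
import org.apache.spark.{TaskContext, TaskKilledException}

// Illustrative helper, not part of this patch: if the task was killed,
// surface a TaskKilledException carrying the kill reason; otherwise rethrow
// the original failure.
def rethrowIfKilled(context: TaskContext, cause: Exception): Nothing = {
  if (context.isInterrupted()) {
    // getKillReason() should be defined whenever isInterrupted() is true,
    // but fall back to a default in case the two reads ever race.
    throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
  } else {
    throw cause
  }
}
```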
21 changes: 16 additions & 5 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -59,8 +59,8 @@ private[spark] class TaskContextImpl(
/** List of callback functions to execute when the task fails. */
@transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener]

// Whether the corresponding task has been killed.
@volatile private var interrupted: Boolean = false
// If defined, the corresponding task has been killed and this option contains the reason.
@volatile private var reasonIfKilled: Option[String] = None

// Whether the task has completed.
private var completed: Boolean = false
@@ -140,16 +140,27 @@
}

/** Marks the task for interruption, i.e. cancellation. */
private[spark] def markInterrupted(): Unit = {
interrupted = true
private[spark] def markInterrupted(reason: String): Unit = {
reasonIfKilled = Some(reason)
}

private[spark] override def killTaskIfInterrupted(): Unit = {
val reason = reasonIfKilled
if (reason.isDefined) {
throw new TaskKilledException(reason.get)
}
}

private[spark] override def getKillReason(): Option[String] = {
reasonIfKilled
}

@GuardedBy("this")
override def isCompleted(): Boolean = synchronized(completed)

override def isRunningLocally(): Boolean = false

override def isInterrupted(): Boolean = interrupted
override def isInterrupted(): Boolean = reasonIfKilled.isDefined

override def getLocalProperty(key: String): String = localProperties.getProperty(key)

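A test-style sketch of the round trip through the TaskContextImpl changes above: marking the context interrupted makes the next cooperative check throw with the same reason. newTaskContext() is a stand-in for however the surrounding test builds the context (the real constructor needs a TaskMemoryManager, properties, and a metrics system), so this is illustrative only:

```scala
import org.apache.spark.TaskKilledException

val context = newTaskContext()            // assumed test helper, not in this patch
context.markInterrupted("stage cancelled by user")
assert(context.isInterrupted())
assert(context.getKillReason() == Some("stage cancelled by user"))
try {
  context.killTaskIfInterrupted()         // throws TaskKilledException
} catch {
  case e: TaskKilledException => assert(e.reason == "stage cancelled by user")
}
```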
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -212,8 +212,8 @@ case object TaskResultLost extends TaskFailedReason {
* Task was killed intentionally and needs to be rescheduled.
*/
@DeveloperApi
case object TaskKilled extends TaskFailedReason {
override def toErrorString: String = "TaskKilled (killed intentionally)"
case class TaskKilled(reason: String) extends TaskFailedReason {
override def toErrorString: String = s"TaskKilled ($reason)"
Contributor:
Since this was part of the DeveloperApi, what is the impact of making this change? In JsonProtocol? In mocked user code?

If it does introduce backward-incompatible changes, is there a way to mitigate this? Perhaps make TaskKilled a plain class with apply/unapply/serde/toString (essentially everything the case class provides), plus a case object with an apply that defaults reason to null (and logs a deprecation warning when used)?

Contributor Author (@ericl, Mar 21, 2017):
This is unfortunately not backwards compatible. I've looked into this, and the issue seems to be that we are converting a case object into a case class. If TaskKilled had been a case class to start with, compatibility might have been possible (e.g. implement equals() ignoring the reason), but as it is you would break Scala match statements (see the sketch just below) and possibly other things depending on the user code.

Contributor:
That is unfortunate, but it looks like it can't be helped if we need this feature. Probably something to keep in mind for future uses of case objects!

Thanks for clarifying.

override def countTowardsTaskFailures: Boolean = false
}
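To make the incompatibility discussed in the thread above concrete, here is a hedged sketch of the user-side pattern that changes; the describe helper is invented for illustration:

```scala
import org.apache.spark.{TaskEndReason, TaskKilled}

// Before this patch TaskKilled was a case object, so user code matched on the
// bare name:
//
//   case TaskKilled => ...   // this pattern breaks once TaskKilled becomes a case class
//
// After this patch the pattern has to bind (or ignore) the reason:
def describe(reason: TaskEndReason): String = reason match {
  case TaskKilled(why) => s"task killed: $why"
  case other => other.toString
}
```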

(next file)
@@ -24,4 +24,6 @@ import org.apache.spark.annotation.DeveloperApi
* Exception thrown when a task is explicitly killed (i.e., task failure is expected).
*/
@DeveloperApi
class TaskKilledException extends RuntimeException
class TaskKilledException(val reason: String) extends RuntimeException {
def this() = this("unknown reason")
}
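A quick, REPL-style sketch of the two constructor forms this change allows (illustration only):

```scala
import org.apache.spark.TaskKilledException

// Existing callers keep compiling through the auxiliary no-arg constructor
// and report "unknown reason"; new callers attach the actual kill reason.
val legacy = new TaskKilledException()
val enriched = new TaskKilledException("preempted by a higher-priority job")
assert(legacy.reason == "unknown reason")
assert(enriched.reason == "preempted by a higher-priority job")
```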
(next file)
@@ -215,7 +215,7 @@ private[spark] class PythonRunner(

case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException
throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
Contributor:
Why do you need the getOrElse here? (Since isInterrupted is true, shouldn't this always be defined?)

Contributor Author:
@mridulm pointed out that if the kill reason were reset to None by a concurrent thread, this would crash. However, it is true that this can't happen in the current implementation.

If you think it's clearer, we could throw an AssertionError in this case.

Contributor:
Hm, OK. If Mridul wants this then it's fine to leave as-is.

Contributor:
@kayousterhout I actually had not considered this case, but rather the use of maybeKillReason in Executor and other places; this was a nice catch by @ericl.


case e: Exception if env.isStopped =>
logDebug("Exception thrown after context is stopped", e)
(next file)
@@ -97,11 +97,11 @@ private[spark] class CoarseGrainedExecutorBackend(
executor.launchTask(this, taskDesc)
}

case KillTask(taskId, _, interruptThread) =>
case KillTask(taskId, _, interruptThread, reason) =>
if (executor == null) {
exitExecutor(1, "Received KillTask command but executor was null")
} else {
executor.killTask(taskId, interruptThread)
executor.killTask(taskId, interruptThread, reason)
}

case StopExecutor =>
52 changes: 28 additions & 24 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -158,7 +158,7 @@ private[spark] class Executor(
threadPool.execute(tr)
}

def killTask(taskId: Long, interruptThread: Boolean): Unit = {
def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
val taskRunner = runningTasks.get(taskId)
if (taskRunner != null) {
if (taskReaperEnabled) {
@@ -168,7 +168,8 @@
case Some(existingReaper) => interruptThread && !existingReaper.interruptThread
}
if (shouldCreateReaper) {
val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread)
val taskReaper = new TaskReaper(
taskRunner, interruptThread = interruptThread, reason = reason)
Contributor:
shouldCreateReaper special-cases thread interruption (and the missing-reaper case, of course); specifying a "reason" is trying to piggyback on top of that. What is the developer expectation when killTask is invoked with a message: should that reason override the reason propagated by the TaskReaper? If it might not, we need to document this; the reason might not make it through even if the kill happens after the API invocation.

Contributor Author:
I think it's reasonable to show one (arbitrary) reason since this should be a rare situation. Also updated the killTaskAttempt doc comment to reflect this.

taskReaperForTask(taskId) = taskReaper
Some(taskReaper)
} else {
@@ -178,7 +179,7 @@
// Execute the TaskReaper from outside of the synchronized block.
maybeNewTaskReaper.foreach(taskReaperPool.execute)
} else {
taskRunner.kill(interruptThread = interruptThread)
taskRunner.kill(interruptThread = interruptThread, reason = reason)
}
}
}
@@ -189,8 +190,9 @@
* tasks instead of taking the JVM down.
* @param interruptThread whether to interrupt the task thread
*/
def killAllTasks(interruptThread: Boolean) : Unit = {
runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread))
def killAllTasks(interruptThread: Boolean, reason: String) : Unit = {
runningTasks.keys().asScala.foreach(t =>
killTask(t, interruptThread = interruptThread, reason = reason))
}

def stop(): Unit = {
@@ -217,8 +219,8 @@
val threadName = s"Executor task launch worker for task $taskId"
private val taskName = taskDescription.name

/** Whether this task has been killed. */
@volatile private var killed = false
/** If specified, this task has been killed and this option contains the reason. */
@volatile private var reasonIfKilled: Option[String] = None

@volatile private var threadId: Long = -1

@@ -239,13 +241,13 @@
*/
@volatile var task: Task[Any] = _

def kill(interruptThread: Boolean): Unit = {
logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
killed = true
def kill(interruptThread: Boolean, reason: String): Unit = {
logInfo(s"Executor is trying to kill $taskName (TID $taskId), reason: $reason")
reasonIfKilled = Some(reason)
if (task != null) {
synchronized {
if (!finished) {
task.kill(interruptThread)
task.kill(interruptThread, reason)
}
}
}
@@ -296,12 +298,13 @@

// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
if (killed) {
val killReason = reasonIfKilled
Contributor:
Why rename the variable here (instead of just using reasonIfKilled below)?

Contributor Author:
If we assign to a temporary, then there is no risk of seeing concurrent mutations of the value as we access it below (though, this cannot currently happen).

Contributor:
Ugh in retrospect I think TaskContext should have just clearly documented that an invariant of reasonIfKilled is that, once set, it won't be un-set, and then we'd avoid all of these corner cases. But not worth changing now.

if (killReason.isDefined) {
// Throw an exception rather than returning, because returning within a try{} block
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
throw new TaskKilledException
throw new TaskKilledException(killReason.get)
}

logDebug("Task " + taskId + "'s epoch is " + task.epoch)
@@ -358,9 +361,7 @@
} else 0L

// If the task has been killed, let's fail it.
if (task.killed) {
throw new TaskKilledException
}
task.context.killTaskIfInterrupted()

val resultSer = env.serializer.newInstance()
val beforeSerialization = System.currentTimeMillis()
@@ -426,15 +427,17 @@
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

case _: TaskKilledException =>
logInfo(s"Executor killed $taskName (TID $taskId)")
case t: TaskKilledException =>
logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))

case _: InterruptedException if task.killed =>
logInfo(s"Executor interrupted and killed $taskName (TID $taskId)")
case _: InterruptedException if task.reasonIfKilled.isDefined =>
val killReason = task.reasonIfKilled.getOrElse("unknown reason")
logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
execBackend.statusUpdate(
taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))

case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskFailedReason
@@ -512,7 +515,8 @@
*/
private class TaskReaper(
taskRunner: TaskRunner,
val interruptThread: Boolean)
val interruptThread: Boolean,
val reason: String)
extends Runnable {

private[this] val taskId: Long = taskRunner.taskId
@@ -533,7 +537,7 @@
// Only attempt to kill the task once. If interruptThread = false then a second kill
// attempt would be a no-op and if interruptThread = true then it may not be safe or
// effective to interrupt multiple times:
taskRunner.kill(interruptThread = interruptThread)
taskRunner.kill(interruptThread = interruptThread, reason = reason)
// Monitor the killed task until it exits. The synchronization logic here is complicated
// because we don't want to synchronize on the taskRunner while possibly taking a thread
// dump, but we also need to be careful to avoid races between checking whether the task
(next file)
@@ -738,6 +738,15 @@ class DAGScheduler(
eventProcessLoop.post(StageCancelled(stageId, reason))
}

/**
* Kill a given task. It will be retried.
*
* @return Whether the task was successfully killed.
*/
def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
taskScheduler.killTaskAttempt(taskId, interruptThread, reason)
}

/**
* Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
* the last fetch failure.
@@ -1353,7 +1362,7 @@
case TaskResultLost =>
// Do nothing here; the TaskScheduler handles these failures and resubmits the task.

case _: ExecutorLostFailure | TaskKilled | UnknownReason =>
case _: ExecutorLostFailure | _: TaskKilled | UnknownReason =>
// Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
// will abort the job.
}
(next file)
@@ -30,8 +30,21 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int

def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit =
/**
* Requests that an executor kills a running task.
*
* @param taskId Id of the task.
* @param executorId Id of the executor the task is running on.
* @param interruptThread Whether the executor should interrupt the task thread.
* @param reason The reason for the task kill.
*/
def killTask(
taskId: Long,
executorId: String,
interruptThread: Boolean,
reason: String): Unit =
throw new UnsupportedOperationException

def isReady(): Boolean = true

/**
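Since killTask still throws UnsupportedOperationException by default, concrete backends opt in to the extended signature explicitly. A hedged, test-style sketch of a backend that just records the propagated reason; it assumes start/stop/reviveOffers/defaultParallelism are the remaining abstract members and has to live under org.apache.spark because the trait is private[spark]:

```scala
package org.apache.spark.scheduler

// Illustrative test double, not part of this patch: records the last kill
// request so a test can assert that the reason string survives the call.
class RecordingSchedulerBackend extends SchedulerBackend {
  @volatile var lastKill: Option[(Long, String, Boolean, String)] = None

  override def start(): Unit = ()
  override def stop(): Unit = ()
  override def reviveOffers(): Unit = ()
  override def defaultParallelism(): Int = 1

  override def killTask(
      taskId: Long,
      executorId: String,
      interruptThread: Boolean,
      reason: String): Unit = {
    lastKill = Some((taskId, executorId, interruptThread, reason))
  }
}
```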