@@ -23,7 +23,6 @@
import org.apache.avro.reflect.Nullable;

import org.apache.spark.TaskContext;
-import org.apache.spark.TaskKilledException;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
@@ -291,8 +290,8 @@ public void loadNext() {
// to avoid performance overhead. This check is added here in `loadNext()` instead of in
// `hasNext()` because it's technically possible for the caller to be relying on
// `getNumRecords()` instead of `hasNext()` to know when to stop.
-if (taskContext != null && taskContext.isInterrupted()) {
-  throw new TaskKilledException();
+if (taskContext != null) {
+  taskContext.killTaskIfInterrupted();
}
// This pointer points to a 4-byte record length, followed by the record's bytes
final long recordPointer = array.get(offset + position);
@@ -24,7 +24,6 @@

import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
-import org.apache.spark.TaskKilledException;
import org.apache.spark.io.NioBufferedFileInputStream;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockId;
@@ -102,8 +101,8 @@ public void loadNext() throws IOException {
// to avoid performance overhead. This check is added here in `loadNext()` instead of in
// `hasNext()` because it's technically possible for the caller to be relying on
// `getNumRecords()` instead of `hasNext()` to know when to stop.
-if (taskContext != null && taskContext.isInterrupted()) {
-  throw new TaskKilledException();
+if (taskContext != null) {
+  taskContext.killTaskIfInterrupted();
}
recordLength = din.readInt();
keyPrefix = din.readLong();
@@ -33,11 +33,8 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator
// is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
// (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
// introduces an expensive read fence.
-if (context.isInterrupted) {
-  throw new TaskKilledException
-} else {
-  delegate.hasNext
-}
+context.killTaskIfInterrupted()
+delegate.hasNext
}

def next(): T = delegate.next()
16 changes: 16 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2249,6 +2249,22 @@ class SparkContext(config: SparkConf) extends Logging {
dagScheduler.cancelStage(stageId, None)
}

/**
* Kill and reschedule the given task attempt. Task ids can be obtained from the Spark UI
* or through SparkListener.onTaskStart.
*
* @param taskId the task ID to kill. This id uniquely identifies the task attempt.
* @param interruptThread whether to interrupt the thread running the task.
* @param reason the reason for killing the task, which should be a short string. If a task
* is killed multiple times with different reasons, only one reason will be reported.
*/
def killTaskAttempt(
taskId: Long,
interruptThread: Boolean = true,
reason: String = "killed via SparkContext.killTaskAttempt"): Unit = {
dagScheduler.killTaskAttempt(taskId, interruptThread, reason)
}

/**
* Clean a closure to make it ready to be serialized and sent to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
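For illustration, a minimal sketch of how the new killTaskAttempt API might be driven from a SparkListener; the listener logic and the toy job below are hypothetical, only SparkContext.killTaskAttempt and SparkListener.onTaskStart come from this change:

// Hedged sketch, not part of the diff: kill the first task attempt it sees and let
// the scheduler retry it. Only killTaskAttempt and onTaskStart are from this PR.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}

object KillTaskAttemptSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("kill-task-sketch").setMaster("local[2]"))
    sc.addSparkListener(new SparkListener {
      override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
        // Task attempt ids come from the listener event, as the doc comment above suggests.
        if (taskStart.taskInfo.taskId == 0) {
          sc.killTaskAttempt(taskStart.taskInfo.taskId, interruptThread = true,
            reason = "killed from example listener")
        }
      }
    })
    // The killed attempt is rescheduled, so the job still completes.
    sc.parallelize(1 to 100, 4).map { i => Thread.sleep(10); i }.count()
    sc.stop()
  }
}
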
5 changes: 5 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -184,6 +184,11 @@ abstract class TaskContext extends Serializable {
@DeveloperApi
def getMetricsSources(sourceName: String): Seq[Source]

/**
* If the task is interrupted, throws TaskKilledException with the reason for the interrupt.
*/
private[spark] def killTaskIfInterrupted(): Unit

/**
* Returns the manager for this task's managed memory.
*/
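killTaskIfInterrupted is private[spark], so only Spark internals can call it; a rough sketch of the cooperative-cancellation pattern it supports (the loop below is made up, the real call sites in this PR are the sorter and iterator changes above):

// Hypothetical internal helper (must live under org.apache.spark to see the
// private[spark] method). It periodically polls for a pending kill request so a
// long-running loop can exit promptly with the caller-supplied reason.
private[spark] def consumeAll(records: Iterator[Array[Byte]]): Long = {
  val context = org.apache.spark.TaskContext.get()
  var processed = 0L
  while (records.hasNext) {
    if (context != null && processed % 1024 == 0) {
      context.killTaskIfInterrupted()  // throws TaskKilledException(reason) if killed
    }
    records.next()
    processed += 1
  }
  processed
}
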
16 changes: 11 additions & 5 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -59,8 +59,8 @@
/** List of callback functions to execute when the task fails. */
@transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener]

-// Whether the corresponding task has been killed.
-@volatile private var interrupted: Boolean = false
+// If defined, the corresponding task has been killed for the contained reason.
+@volatile private var maybeKillReason: Option[String] = None
Contributor: nit: Overloading maybeKillReason to indicate interrupted status smells a bit; but might be ok for now.

Contributor Author: Yeah, the reason here is to allow this to be set atomically.

Contributor: How about calling this reasonIfKilled, here and elsewhere? (If you strongly prefer the existing name, feel free to leave as-is -- I just slightly prefer making it somewhat more obvious that this and the fact that the task has been killed are tightly intertwined.) In any case, can you expand the comment a bit to the one you used below: "If specified, this task has been killed and this option contains the reason."

Contributor Author: Done

// Whether the task has completed.
private var completed: Boolean = false
@@ -140,16 +140,22 @@
}

/** Marks the task for interruption, i.e. cancellation. */
-private[spark] def markInterrupted(): Unit = {
-  interrupted = true
+private[spark] def markInterrupted(reason: String): Unit = {
+  maybeKillReason = Some(reason)
}

private[spark] override def killTaskIfInterrupted(): Unit = {
if (maybeKillReason.isDefined) {
throw new TaskKilledException(maybeKillReason.get)
Contributor: This is not thread safe - while technically we do not allow the kill reason to be reset to None right now and this might be fine, it can lead to future issues. Either make all access/updates to the kill reason synchronized, or capture maybeKillReason to a local variable and use that in both the if and the throw.

Contributor Author: Done

}
}

@GuardedBy("this")
override def isCompleted(): Boolean = synchronized(completed)

override def isRunningLocally(): Boolean = false

-override def isInterrupted(): Boolean = interrupted
+override def isInterrupted(): Boolean = maybeKillReason.isDefined

override def getLocalProperty(key: String): String = localProperties.getProperty(key)

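The "Done" above presumably means the volatile field is now read once into a local before the check and the throw, as the reviewer asked; a sketch of that shape (the updated code itself is not part of this revision):

// Hedged sketch of the thread-safe variant: capture the volatile Option once so the
// same value is used for both the emptiness check and the exception message.
private[spark] override def killTaskIfInterrupted(): Unit = {
  val reason = maybeKillReason
  if (reason.isDefined) {
    throw new TaskKilledException(reason.get)
  }
}
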
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -212,8 +212,8 @@
* Task was killed intentionally and needs to be rescheduled.
*/
@DeveloperApi
-case object TaskKilled extends TaskFailedReason {
-  override def toErrorString: String = "TaskKilled (killed intentionally)"
+case class TaskKilled(reason: String) extends TaskFailedReason {
+  override def toErrorString: String = s"TaskKilled ($reason)"
Contributor: Since this was part of DeveloperApi, what is the impact of making this change? In JsonProtocol, in mocked user code? If it does introduce backward incompatible changes, is there a way to mitigate this? Perhaps make TaskKilled a class with apply/unapply/serde/toString (essentially all that a case class provides) and a case object with apply and a default reason = null (logged as deprecated when used)?

Contributor Author (@ericl, Mar 21, 2017): This is unfortunately not backwards compatible. I've looked into this, and the issue seems to be that we are converting a case object into a case class. If TaskKilled was a case class to start with, compatibility might have been possible (e.g. implement equals() ignoring reason), but as is you would break Scala match statements and possibly other things depending on the user code.

Contributor: That is unfortunate, but it looks like it can't be helped if we need this feature. Probably something to keep in mind with future use of case objects! Thanks for clarifying.

override def countTowardsTaskFailures: Boolean = false
}

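For reference, a rough sketch of the mitigation the reviewer floated above (a plain class plus a companion standing in for the old case object); this was not adopted, and the names and details here are speculative:

// NOT merged -- sketch of the suggested compatibility shim, imagined inside TaskEndReason.scala.
// Old `case TaskKilled =>` matches keep working via the custom equals, while new code can
// carry a reason and pattern match with `case TaskKilled(reason) =>`.
class TaskKilled(val reason: String) extends TaskFailedReason {
  override def toErrorString: String = s"TaskKilled ($reason)"
  override def countTowardsTaskFailures: Boolean = false
  override def equals(other: Any): Boolean = other.isInstanceOf[TaskKilled]
  override def hashCode(): Int = classOf[TaskKilled].hashCode()
}

object TaskKilled extends TaskKilled("killed intentionally") {
  def apply(reason: String): TaskKilled = new TaskKilled(reason)
  def unapply(tk: TaskKilled): Option[String] = Some(tk.reason)
}
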
@@ -24,4 +24,4 @@ import org.apache.spark.annotation.DeveloperApi
* Exception thrown when a task is explicitly killed (i.e., task failure is expected).
*/
@DeveloperApi
-class TaskKilledException extends RuntimeException
+class TaskKilledException(val reason: String) extends RuntimeException
@@ -215,7 +215,8 @@ private[spark] class PythonRunner(

case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
-throw new TaskKilledException
+context.killTaskIfInterrupted()
+null // not reached
Contributor: nit: It would be good to directly throw the exception here instead of relying on killTaskIfInterrupted to do the right thing (it is interrupted already according to the case check). Not only would that remove the unreachable null, it would also ensure that future changes to killTaskIfInterrupted, interrupt reset, etc. do not break this.

Contributor Author: Done

case e: Exception if env.isStopped =>
logDebug("Exception thrown after context is stopped", e)
@@ -97,11 +97,11 @@ private[spark] class CoarseGrainedExecutorBackend(
executor.launchTask(this, taskDesc)
}

-case KillTask(taskId, _, interruptThread) =>
+case KillTask(taskId, _, interruptThread, reason) =>
if (executor == null) {
exitExecutor(1, "Received KillTask command but executor was null")
} else {
-executor.killTask(taskId, interruptThread)
+executor.killTask(taskId, interruptThread, reason)
}

case StopExecutor =>
50 changes: 27 additions & 23 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -158,7 +158,7 @@ private[spark] class Executor(
threadPool.execute(tr)
}

-def killTask(taskId: Long, interruptThread: Boolean): Unit = {
+def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
val taskRunner = runningTasks.get(taskId)
if (taskRunner != null) {
if (taskReaperEnabled) {
@@ -168,7 +168,8 @@
case Some(existingReaper) => interruptThread && !existingReaper.interruptThread
}
if (shouldCreateReaper) {
-val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread)
+val taskReaper = new TaskReaper(
+  taskRunner, interruptThread = interruptThread, reason = reason)
Contributor: shouldCreateReaper special cases thread interruption (and missing, of course) - specifying "reason" is trying to piggyback on top of this. What is the developer expectation when killTask is invoked with a message - should the 'reason' override the reason propagated by the TaskReaper? If it might not, we need to document this - the reason might not make it through even if the kill happens after the API invocation.

Contributor Author: I think it's reasonable to show one (arbitrary) reason since this should be a rare situation. Also updated the killTaskAttempt doc comment to reflect this.

taskReaperForTask(taskId) = taskReaper
Some(taskReaper)
} else {
@@ -178,7 +179,7 @@
// Execute the TaskReaper from outside of the synchronized block.
maybeNewTaskReaper.foreach(taskReaperPool.execute)
} else {
-taskRunner.kill(interruptThread = interruptThread)
+taskRunner.kill(interruptThread = interruptThread, reason = reason)
}
}
}
@@ -189,8 +190,9 @@
* tasks instead of taking the JVM down.
* @param interruptThread whether to interrupt the task thread
*/
-def killAllTasks(interruptThread: Boolean) : Unit = {
-  runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = interruptThread))
+def killAllTasks(interruptThread: Boolean, reason: String) : Unit = {
+  runningTasks.keys().asScala.foreach(t =>
+    killTask(t, interruptThread = interruptThread, reason = reason))
}

def stop(): Unit = {
@@ -217,8 +219,8 @@
val threadName = s"Executor task launch worker for task $taskId"
private val taskName = taskDescription.name

-/** Whether this task has been killed. */
-@volatile private var killed = false
+/** If specified, this task has been killed and this option contains the reason. */
+@volatile private var maybeKillReason: Option[String] = None

@volatile private var threadId: Long = -1

@@ -239,13 +241,13 @@
*/
@volatile var task: Task[Any] = _

-def kill(interruptThread: Boolean): Unit = {
-  logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
-  killed = true
+def kill(interruptThread: Boolean, reason: String): Unit = {
+  logInfo(s"Executor is trying to kill $taskName (TID $taskId), reason: $reason")
+  maybeKillReason = Some(reason)
if (task != null) {
synchronized {
if (!finished) {
-task.kill(interruptThread)
+task.kill(interruptThread, reason)
}
}
}
@@ -296,12 +298,12 @@

// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
-if (killed) {
+if (maybeKillReason.isDefined) {
// Throw an exception rather than returning, because returning within a try{} block
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
-throw new TaskKilledException
+throw new TaskKilledException(maybeKillReason.get)
Contributor: Same as above here - atomic use of maybeKillReason required.

Contributor Author: Fixed

}

logDebug("Task " + taskId + "'s epoch is " + task.epoch)
@@ -358,9 +360,7 @@
} else 0L

// If the task has been killed, let's fail it.
-if (task.killed) {
-  throw new TaskKilledException
-}
+task.context.killTaskIfInterrupted()

val resultSer = env.serializer.newInstance()
val beforeSerialization = System.currentTimeMillis()
@@ -426,15 +426,18 @@
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

-case _: TaskKilledException =>
-  logInfo(s"Executor killed $taskName (TID $taskId)")
+case t: TaskKilledException =>
+  logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
setTaskFinishedAndClearInterruptStatus()
-execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))

case _: InterruptedException if task.killed =>
logInfo(s"Executor interrupted and killed $taskName (TID $taskId)")
logInfo(
s"Executor interrupted and killed $taskName (TID $taskId)," +
s" reason: ${task.maybeKillReason.get}")
setTaskFinishedAndClearInterruptStatus()
-execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+execBackend.statusUpdate(
+  taskId, TaskState.KILLED, ser.serialize(TaskKilled(task.maybeKillReason.get)))

case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskFailedReason
@@ -512,7 +515,8 @@
*/
private class TaskReaper(
taskRunner: TaskRunner,
-val interruptThread: Boolean)
+val interruptThread: Boolean,
+val reason: String)
extends Runnable {

private[this] val taskId: Long = taskRunner.taskId
@@ -533,7 +537,7 @@
// Only attempt to kill the task once. If interruptThread = false then a second kill
// attempt would be a no-op and if interruptThread = true then it may not be safe or
// effective to interrupt multiple times:
-taskRunner.kill(interruptThread = interruptThread)
+taskRunner.kill(interruptThread = interruptThread, reason = reason)
// Monitor the killed task until it exits. The synchronization logic here is complicated
// because we don't want to synchronize on the taskRunner while possibly taking a thread
// dump, but we also need to be careful to avoid races between checking whether the task
@@ -731,6 +731,13 @@ class DAGScheduler(
eventProcessLoop.post(StageCancelled(stageId, reason))
}

/**
* Kill a given task. It will be retried.
*/
def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Unit = {
taskScheduler.killTaskAttempt(taskId, interruptThread, reason)
}

/**
* Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
* the last fetch failure.
@@ -1345,7 +1352,7 @@
case TaskResultLost =>
// Do nothing here; the TaskScheduler handles these failures and resubmits the task.

-case _: ExecutorLostFailure | TaskKilled | UnknownReason =>
+case _: ExecutorLostFailure | _: TaskKilled | UnknownReason =>
// Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
// will abort the job.
}
@@ -30,8 +30,21 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int

-def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit =
/**
* Requests that an executor kills a running task.
*
* @param taskId Id of the task.
* @param executorId Id of the executor the task is running on.
* @param interruptThread Whether the executor should interrupt the task thread.
* @param reason The reason for the task kill.
*/
def killTask(
taskId: Long,
executorId: String,
interruptThread: Boolean,
reason: String): Unit =
throw new UnsupportedOperationException

def isReady(): Boolean = true

/**
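As one concrete implementer, a coarse-grained cluster backend would presumably forward the kill to the executor's endpoint using the KillTask message shown in the CoarseGrainedExecutorBackend hunk above; a rough sketch, not the verbatim implementation, where executorDataMap and executorEndpoint are assumed backend state:

// Hedged sketch: route the kill request, including the new reason, to the right executor.
override def killTask(
    taskId: Long,
    executorId: String,
    interruptThread: Boolean,
    reason: String): Unit = {
  executorDataMap.get(executorId) match {
    case Some(executorData) =>
      executorData.executorEndpoint.send(KillTask(taskId, executorId, interruptThread, reason))
    case None =>
      logWarning(s"Asked to kill task $taskId on unknown executor $executorId")
  }
}
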
20 changes: 13 additions & 7 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -89,8 +89,8 @@ private[spark] abstract class Task[T](
TaskContext.setTaskContext(context)
taskThread = Thread.currentThread()

-if (_killed) {
-  kill(interruptThread = false)
+if (_maybeKillReason != null) {
+  kill(interruptThread = false, _maybeKillReason)
}

new CallerContext(
@@ -160,15 +160,20 @@

// A flag to indicate whether the task is killed. This is used in case context is not yet
// initialized when kill() is invoked.
-@volatile @transient private var _killed = false
+@volatile @transient private var _maybeKillReason: String = null
Contributor: Any reason to make this a String and not Option[String] - like the other places it is defined/used?

Contributor Author (@ericl, Mar 21, 2017): This one gets deserialized to null sometimes (@transient), so it seemed cleaner to use a bare string rather than have an option that can be null.

Contributor: Can you update the comment here?

Contributor Author: Done


protected var _executorDeserializeTime: Long = 0
protected var _executorDeserializeCpuTime: Long = 0

/**
* Whether the task has been killed.
*/
-def killed: Boolean = _killed
+def killed: Boolean = _maybeKillReason != null

/**
* If this task has been killed, contains the reason for the kill.
Contributor: As above, can you make the comment "If specified, this task has been killed and this option contains the reason." (assuming that you get rid of the killed variable)

Contributor Author: Done

*/
def maybeKillReason: Option[String] = Option(_maybeKillReason)

/**
* Returns the amount of time spent deserializing the RDD and function to be run.
@@ -201,10 +206,11 @@
* be called multiple times.
* If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread.
*/
-def kill(interruptThread: Boolean) {
-  _killed = true
+def kill(interruptThread: Boolean, reason: String) {
+  require(reason != null)
+  _maybeKillReason = reason
if (context != null) {
-context.markInterrupted()
+context.markInterrupted(reason)
}
if (interruptThread && taskThread != null) {
taskThread.interrupt()
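The @transient point in the Task.scala discussion above is the usual Java-serialization gotcha: a transient reference field comes back as null after a round trip, so a transient Option[String] can itself be null rather than None, while a bare transient String wrapped with Option(...) on access degrades cleanly to None. A small standalone illustration (not Spark code):

// Hedged demo of why the author kept a bare String: after deserialization the
// transient Option field is null, not None, while Option(stringField) yields None.
import java.io._

class Carrier extends Serializable {
  @transient var optionField: Option[String] = Some("killed")
  @transient var stringField: String = "killed"
  def maybeReason: Option[String] = Option(stringField)  // the accessor pattern used in Task.scala
}

object TransientOptionDemo extends App {
  val bos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bos)
  oos.writeObject(new Carrier)
  oos.close()
  val restored = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
    .readObject().asInstanceOf[Carrier]
  println(restored.optionField)  // null -- not None, the surprising case
  println(restored.maybeReason)  // None -- Option(null) handles the bare String cleanly
}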