7 changes: 7 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener}


@@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable {
*/
private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit

/**
* Record that this task has failed due to a fetch failure from a remote host. This allows
* fetch-failure handling to get triggered by the driver, regardless of intervening user-code.
*/
private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit

}
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._

private[spark] class TaskContextImpl(
@@ -56,6 +57,10 @@ private[spark] class TaskContextImpl(
// Whether the task has failed.
@volatile private var failed: Boolean = false

// If there was a fetch failure in the task, we store it here, to make sure user-code doesn't
// hide the exception. See SPARK-19276
@volatile private var _fetchFailedException: Option[FetchFailedException] = None

override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
@@ -126,4 +131,10 @@ private[spark] class TaskContextImpl(
taskMetrics.registerAccumulator(a)
}

private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
this._fetchFailedException = Option(fetchFailed)
}

private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException

}
33 changes: 28 additions & 5 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -18,6 +18,7 @@
package org.apache.spark.executor

import java.io.{File, NotSerializableException}
import java.lang.Thread.UncaughtExceptionHandler
import java.lang.management.ManagementFactory
import java.net.{URI, URL}
import java.nio.ByteBuffer
@@ -52,7 +53,8 @@ private[spark] class Executor(
executorHostname: String,
env: SparkEnv,
userClassPath: Seq[URL] = Nil,
isLocal: Boolean = false)
isLocal: Boolean = false,
uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler)
extends Logging {

logInfo(s"Starting executor ID $executorId on host $executorHostname")
@@ -78,7 +80,7 @@ private[spark] class Executor(
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)
Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler)
}

// Start worker thread pool
@@ -342,6 +344,14 @@
}
}
}
task.context.fetchFailed.foreach { fetchFailure =>
// Uh-oh. It appears the user code has caught the fetch failure without throwing any
// other exceptions. It's *possible* this is what the user meant to do (though highly
// unlikely), so we log an error and keep going.
logError(s"TID ${taskId} completed successfully though internally it encountered " +
s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
}
val taskFinish = System.currentTimeMillis()
val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
@@ -402,8 +412,17 @@
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

} catch {
case ffe: FetchFailedException =>
val reason = ffe.toTaskFailedReason
case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
val reason = task.context.fetchFailed.get.toTaskFailedReason
Contributor:
tiny nit: but does it make sense to store the taskFailedReason (rather than the actual exception) in the task context?
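(Illustrative aside, not part of the PR: keeping the exception itself on the context, as the patch does, lets the executor derive the TaskFailedReason on demand and still log the original stack trace. A minimal sketch, reusing the names from the surrounding TaskRunner code:)

// Sketch only: both uses remain possible because the exception, not a
// pre-computed TaskFailedReason, is stored on the TaskContext.
val fetchFailure = task.context.fetchFailed.get
val reason = fetchFailure.toTaskFailedReason  // derived where it is needed
logError(s"TID ${taskId} hit a fetch failure", fetchFailure)  // full stack trace preserved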

if (!t.isInstanceOf[FetchFailedException]) {
// there was a fetch failure in the task, but some user code wrapped that exception
// and threw something else. Regardless, we treat it as a fetch failure.
val fetchFailedCls = classOf[FetchFailedException].getName
logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
s"failed, but the ${fetchFailedCls} was hidden by another " +
s"exception. Spark is handling this like a fetch failure and ignoring the " +
s"other exception: $t")
}
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
Contributor:
nit: Probably log a similar message as above ?

Contributor (author):
do you mean the msg I added about "TID ${taskId} completed successfully though internally it encountered unrecoverable fetch failures!"? I wouldn't think we'd want to log anything special here. I'm trying to make this a "normal" code path. The user is allowed to do this (Spark SQL already does).

we could log a warning, but then this change should be accompanied by auditing the code and making sure we never do this ourselves.

Contributor:
Yes, something along those lines ...
And I agree, we should not be doing this ourselves as well.
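(For illustration only: a hypothetical user-code pattern, not part of this PR, showing how a fetch failure can get swallowed by a catch-all. The FetchFailureHidingRDD in the test changes below does essentially the same thing.)

import org.apache.spark.SparkContext

// Hypothetical user code: a catch-all around shuffle consumption that swallows
// Spark's internal FetchFailedException.
def silentPartitionSizes(sc: SparkContext): Long = {
  val shuffled = sc.parallelize(1 to 100).map(i => (i % 10, i)).groupByKey()
  shuffled.mapPartitions { iter =>
    try {
      // Forcing the iterator inside the try means a FetchFailedException raised while
      // reading shuffle blocks is caught by the catch-all below instead of propagating.
      Iterator(iter.size)
    } catch {
      case e: Exception =>
        // The fetch failure is swallowed here; with this PR the executor still notices
        // it via the TaskContext and reports FetchFailed to the driver.
        Iterator(-1)
    }
  }.count()
}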


@@ -455,13 +474,17 @@
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
SparkUncaughtExceptionHandler.uncaughtException(t)
uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
}

} finally {
runningTasks.remove(taskId)
}
}

private def hasFetchFailure: Boolean = {
task != null && task.context != null && task.context.fetchFailed.isDefined
}
}

/**
9 changes: 3 additions & 6 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,19 +17,14 @@

package org.apache.spark.scheduler

import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import java.util.Properties

import scala.collection.mutable
import scala.collection.mutable.HashMap

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.config.APP_CALLER_CONTEXT
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util._

/**
@@ -137,6 +132,8 @@ private[spark] abstract class Task[T](
memoryManager.synchronized { memoryManager.notifyAll() }
}
} finally {
// Though we unset the ThreadLocal here, the context member variable itself is still queried
// directly in the TaskRunner to check for FetchFailedExceptions.
TaskContext.unset()
}
}
@@ -156,7 +153,7 @@
var epoch: Long = -1

// Task context, to be initialized in run().
@transient protected var context: TaskContextImpl = _
@transient var context: TaskContextImpl = _

// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _
core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -17,7 +17,7 @@

package org.apache.spark.shuffle

import org.apache.spark.{FetchFailed, TaskFailedReason}
import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils

@@ -26,6 +26,11 @@ import org.apache.spark.util.Utils
* back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage.
*
* Note that bmAddress can be null.
*
* To prevent user code from hiding this fetch failure, in the constructor we call
* [[TaskContext.setFetchFailed()]]. This means that you *must* throw this exception immediately
* after creating it -- you cannot create it, check some condition, and then decide to ignore it
* (or risk triggering any other exceptions). See SPARK-19276.
*/
private[spark] class FetchFailedException(
bmAddress: BlockManagerId,
@@ -45,6 +50,12 @@ private[spark] class FetchFailedException(
this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
}

// SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code
// which intercepts this exception (possibly wrapping it), the Executor can still tell there was
// a fetch failure, and send the correct error msg back to the driver. We wrap with an Option
// because the TaskContext is not defined in some test cases.
Option(TaskContext.get()).map(_.setFetchFailed(this))
Contributor:
Since creation of an Exception does not necessarily mean it should get thrown - we must explicitly add this expectation to the documentation/contract of FetchFailedException constructor - indicating that we expect it to be created only for it to be thrown immediately.
This should be fine since FetchFailedException is private[spark] right now.

Contributor (author):
yes, good point. I added to the docs, does it look OK?

I also considered making the call to TaskContext.setFetchFailed live outside of the constructor, so at each site it was created, it would have to be called -- but I thought that seemed more dangerous.
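(A minimal sketch of the constructor contract being discussed; reportMissingBlock and shouldIgnore are made-up names, not part of the PR, and like FetchFailedException itself this only compiles inside the org.apache.spark package. The constructor arguments follow the signature used in the test code below.)

import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.BlockManagerId

def reportMissingBlock(shouldIgnore: Boolean): Unit = {
  val bmAddress = BlockManagerId("1", "hostA", 1234)

  if (!shouldIgnore) {
    // OK: construct only at the point of throwing, so marking the TaskContext via
    // setFetchFailed and actually failing the task always happen together.
    throw new FetchFailedException(
      bmAddress, shuffleId = 0, mapId = 0, reduceId = 0, message = "missing shuffle block")
  }

  // Not OK (shown commented out): constructing the exception already marks the
  // TaskContext as fetch-failed, so creating it speculatively and then deciding
  // not to throw would still make the executor treat this task as a fetch failure.
  // val ffe = new FetchFailedException(
  //   bmAddress, shuffleId = 0, mapId = 0, reduceId = 0, message = "missing shuffle block")
  // if (!shouldIgnore) throw ffe
}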


def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
Utils.exceptionString(this))
}
139 changes: 134 additions & 5 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.executor

import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.lang.Thread.UncaughtExceptionHandler
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.{CountDownLatch, TimeUnit}
@@ -27,7 +28,7 @@ import scala.concurrent.duration._

import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.{inOrder, when}
import org.mockito.Mockito.{inOrder, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.concurrent.Eventually
@@ -37,9 +38,12 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.memory.MemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.{FakeTask, TaskDescription}
import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.BlockManagerId

class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {

@@ -123,6 +127,75 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar
}
}

test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") {
val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
sc = new SparkContext(conf)
val serializer = SparkEnv.get.closureSerializer.newInstance()
val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size

// Submit a job where a fetch failure is thrown, but user code has a try/catch which hides
// the fetch failure. The executor should still tell the driver that the task failed due to a
// fetch failure, not a generic exception from user code.
val inputRDD = new FetchFailureThrowingRDD(sc)
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
val task = new ResultTask(
stageId = 1,
stageAttemptId = 0,
taskBinary = taskBinary,
partition = secondRDD.partitions(0),
locs = Seq(),
outputId = 0,
localProperties = new Properties(),
serializedTaskMetrics = serializedTaskMetrics
)

val serTask = serializer.serialize(task)
val taskDescription = createFakeTaskDescription(serTask)

val failReason = runTaskAndGetFailReason(taskDescription)
Contributor:
can you add a comment about what's going on here? I think the FFE gets thrown because the shuffle map data was never generated? And then you're checking that it's correctly accounted for, even though the user RDD code wrapped the exception in something else?

assert(failReason.isInstanceOf[FetchFailed])
}

test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
// when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
// may be a false positive. And we should call the uncaught exception handler.
val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
sc = new SparkContext(conf)
val serializer = SparkEnv.get.closureSerializer.newInstance()
val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size

// Submit a job where a fetch failure is thrown, but user code has a try/catch which hides
// the fetch failure. The executor should still tell the driver that the task failed due to a
// fetch failure, not a generic exception from user code.
val inputRDD = new FetchFailureThrowingRDD(sc)
val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
val task = new ResultTask(
stageId = 1,
stageAttemptId = 0,
taskBinary = taskBinary,
partition = secondRDD.partitions(0),
locs = Seq(),
outputId = 0,
localProperties = new Properties(),
serializedTaskMetrics = serializedTaskMetrics
)

val serTask = serializer.serialize(task)
val taskDescription = createFakeTaskDescription(serTask)

val (failReason, uncaughtExceptionHandler) =
runTaskGetFailReasonAndExceptionHandler(taskDescription)
assert(failReason.isInstanceOf[ExceptionFailure])
val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
assert(exceptionCaptor.getAllValues.size === 1)
assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
}

test("Gracefully handle error in task deserialization") {
Contributor:
is this test related to this PR? (seems useful but like it should be in its own PR?)

Contributor (author):
@mridulm pointed out this bug in an earlier version of this pr, so I fixed the bug and added a test case. But in any case, I've separated this out into #16930 / https://issues.apache.org/jira/browse/SPARK-19597

val conf = new SparkConf
val serializer = new JavaSerializer(conf)
@@ -169,13 +242,20 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar
}

private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
}

private def runTaskGetFailReasonAndExceptionHandler(
taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
val mockBackend = mock[ExecutorBackend]
val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
var executor: Executor = null
try {
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
uncaughtExceptionHandler = mockUncaughtExceptionHandler)
// the task will be launched in a dedicated worker thread
executor.launchTask(mockBackend, taskDescription)
eventually(timeout(5 seconds), interval(10 milliseconds)) {
eventually(timeout(5.seconds), interval(10.milliseconds)) {
assert(executor.numRunningTasks === 0)
}
} finally {
@@ -193,7 +273,56 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar
assert(statusCaptor.getAllValues().get(0).remaining() === 0)
// second update is more interesting
val failureData = statusCaptor.getAllValues.get(1)
SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
val failReason =
SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
(failReason, mockUncaughtExceptionHandler)
}
}

class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
new Iterator[Int] {
override def hasNext: Boolean = true
override def next(): Int = {
throw new FetchFailedException(
bmAddress = BlockManagerId("1", "hostA", 1234),
shuffleId = 0,
mapId = 0,
reduceId = 0,
message = "fake fetch failure"
)
}
}
}
override protected def getPartitions: Array[Partition] = {
Array(new SimplePartition)
}
}

class SimplePartition extends Partition {
override def index: Int = 0
}

class FetchFailureHidingRDD(
sc: SparkContext,
val input: FetchFailureThrowingRDD,
throwOOM: Boolean) extends RDD[Int](input) {
override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
val inItr = input.compute(split, context)
try {
Iterator(inItr.size)
} catch {
case t: Throwable =>
if (throwOOM) {
throw new OutOfMemoryError("OOM while handling another exception")
} else {
throw new RuntimeException("User Exception that hides the original exception", t)
}
}
}

override protected def getPartitions: Array[Partition] = {
Array(new SimplePartition)
}
}

3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
@@ -55,6 +55,9 @@ object MimaExcludes {
// [SPARK-14272][ML] Add logLikelihood in GaussianMixtureSummary
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.GaussianMixtureSummary.this"),

// [SPARK-19276] Fetch Failure handling robust to user error handling
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.setFetchFailed"),

// [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$10"),