29 changes: 27 additions & 2 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -150,6 +150,8 @@ private[spark] class Executor(
// Whether to monitor killed / interrupted tasks
private val taskReaperEnabled = conf.get(TASK_REAPER_ENABLED)

private val killOnNestedFatalError = conf.get(EXECUTOR_KILL_ON_NESTED_FATAL_ERROR)

// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
@@ -648,7 +650,7 @@ private[spark] class Executor(
plugins.foreach(_.onTaskFailed(reason))
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))

case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
case t: Throwable if hasFetchFailure && !Executor.isFatalError(t, killOnNestedFatalError) =>
val reason = task.context.fetchFailed.get.toTaskFailedReason
if (!t.isInstanceOf[FetchFailedException]) {
// there was a fetch failure in the task, but some user code wrapped that exception
@@ -711,7 +713,7 @@ private[spark] class Executor(

// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (!t.isInstanceOf[SparkOutOfMemoryError] && Utils.isFatalError(t)) {
if (Executor.isFatalError(t, killOnNestedFatalError)) {
uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
}
} finally {
@@ -997,4 +999,27 @@ private[spark] object Executor {

// Used to store executorSource, for local mode only
var executorSourceLocalModeOnly: ExecutorSource = null

/**
* Whether a `Throwable` thrown from a task is a fatal error. We use this to decide whether to
* kill the executor.
*
* @param shouldDetectNestedFatalError whether to walk the exception chain to check whether a
* fatal error exists.
* @param depth the current depth of the recursive call. Returns `false` once it exceeds 5,
* to avoid a `StackOverflowError` when the exception chain contains a cycle.
*/
def isFatalError(t: Throwable, shouldDetectNestedFatalError: Boolean, depth: Int = 0): Boolean = {
if (depth <= 5) {
Member Author: Picked 5, which should be enough to cover most cases.

Member: Maybe just create a config with that default value, instead of the bool config spark.executor.killOnNestedFatalError plus this magic number?

Member Author: Good point!

(A sketch of that suggestion follows this hunk.)

t match {
case _: SparkOutOfMemoryError => false
Member: Just in case: are we sure an OOM cannot be caused by a fatal error, and that one cannot be present somewhere in the chain?

Member Author: This is existing behavior. #20014 added SparkOutOfMemoryError to avoid killing the executor when the OOM is not thrown by the JVM.


case e if Utils.isFatalError(e) => true
case e if e.getCause != null && shouldDetectNestedFatalError =>
isFatalError(e.getCause, shouldDetectNestedFatalError, depth + 1)
case _ => false
}
} else {
false
}
}
}
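The review thread above asks whether the boolean flag plus the hard-coded depth of 5 could be folded into a single depth config. A minimal sketch of what that could look like — the config name `spark.executor.killOnFatalError.depth`, its default, and the reworked `isFatalError` signature are assumptions for illustration against the same Spark files touched above, not part of this commit:

```scala
// Hypothetical replacement for spark.executor.killOnNestedFatalError: a single
// depth config, where 0 disables the nested check entirely.
private[spark] val EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH =
  ConfigBuilder("spark.executor.killOnFatalError.depth")
    .doc("The maximum depth of the exception chain to search for a fatal error when " +
      "deciding whether to kill an executor. 0 disables the check.")
    .internal()
    .intConf
    .createWithDefault(5)

// isFatalError would then count the remaining depth down instead of tracking the
// recursion depth upward against a magic number.
def isFatalError(t: Throwable, depthToCheck: Int): Boolean = {
  if (depthToCheck <= 0) {
    false
  } else {
    t match {
      case _: SparkOutOfMemoryError => false
      case e if Utils.isFatalError(e) => true
      case e if e.getCause != null => isFatalError(e.getCause, depthToCheck - 1)
      case _ => false
    }
  }
}
```

With that shape, `depthToCheck = 1` reproduces the old behavior of checking only the top-level throwable, and the caller would pass `conf.get(EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH)` instead of the boolean.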
@@ -1946,6 +1946,13 @@ package object config {
.booleanConf
.createWithDefault(false)

private[spark] val EXECUTOR_KILL_ON_NESTED_FATAL_ERROR =
ConfigBuilder("spark.executor.killOnNestedFatalError")
.doc("Whether to kill the executor when a nested fatal error is thrown from a task.")
.internal()
.booleanConf
.createWithDefault(true)

private[spark] val PUSH_BASED_SHUFFLE_ENABLED =
ConfigBuilder("spark.shuffle.push.enabled")
.doc("Set to 'true' to enable push-based shuffle on the client side and this works in " +
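Although the new flag is marked `internal()`, it can be set like any other Spark config; a small hedged example of opting out of the nested check — the setting itself comes from the diff above, the surrounding code is purely illustrative:

```scala
import org.apache.spark.SparkConf

// Illustrative only: disable walking the exception chain so that only the
// top-level throwable decides whether the executor is killed.
val conf = new SparkConf()
  .setAppName("nested-fatal-error-example") // hypothetical app name
  .set("spark.executor.killOnNestedFatalError", "false")
```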
72 changes: 70 additions & 2 deletions core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -28,6 +28,7 @@ import scala.collection.immutable
import scala.collection.mutable.{ArrayBuffer, Map}
import scala.concurrent.duration._

import com.google.common.cache.{CacheBuilder, CacheLoader}
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers.{any, eq => meq}
import org.mockito.Mockito.{inOrder, verify, when}
@@ -43,7 +44,7 @@ import org.apache.spark.TaskState.TaskState
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.config._
import org.apache.spark.internal.config.UI._
import org.apache.spark.memory.TestMemoryManager
import org.apache.spark.memory.{SparkOutOfMemoryError, TestMemoryManager}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.rdd.RDD
import org.apache.spark.resource.ResourceInformation
@@ -52,7 +53,7 @@ import org.apache.spark.scheduler.{DirectTaskResult, FakeTask, ResultTask, Task,
import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockManager, BlockManagerId}
import org.apache.spark.util.{LongAccumulator, UninterruptibleThread}
import org.apache.spark.util.{LongAccumulator, ThreadUtils, UninterruptibleThread}

class ExecutorSuite extends SparkFunSuite
with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester {
@@ -402,6 +403,73 @@ class ExecutorSuite extends SparkFunSuite
assert(taskMetrics.getMetricValue("JVMHeapMemory") > 0)
}

test("SPARK-33587: isFatalError") {
def errorInThreadPool(e: => Throwable): Throwable = {
Member Author: Trying to make this test cover the cases I mentioned in the description.


intercept[Throwable] {
val taskPool = ThreadUtils.newDaemonFixedThreadPool(1, "test")
try {
val f = taskPool.submit(new java.util.concurrent.Callable[String] {
override def call(): String = throw e
})
f.get()
} finally {
taskPool.shutdown()
}
}
}

def errorInGuavaCache(e: => Throwable): Throwable = {
val cache = CacheBuilder.newBuilder()
.build(new CacheLoader[String, String] {
override def load(key: String): String = throw e
})
intercept[Throwable] {
cache.get("test")
}
}

def testThrowable(
e: => Throwable,
shouldDetectNestedFatalError: Boolean,
isFatal: Boolean): Unit = {
import Executor.isFatalError
assert(isFatalError(e, shouldDetectNestedFatalError) == isFatal)
// Now check nested exceptions. We get `true` only if we need to check nested exceptions
// (`shouldDetectNestedFatalError` is `true`) and `e` is fatal.
val expected = shouldDetectNestedFatalError && isFatal
assert(isFatalError(errorInThreadPool(e), shouldDetectNestedFatalError) == expected)
assert(isFatalError(errorInGuavaCache(e), shouldDetectNestedFatalError) == expected)
assert(isFatalError(
errorInThreadPool(errorInGuavaCache(e)),
shouldDetectNestedFatalError) == expected)
assert(isFatalError(
errorInGuavaCache(errorInThreadPool(e)),
shouldDetectNestedFatalError) == expected)
assert(isFatalError(
new SparkException("Task failed while writing rows.", e),
shouldDetectNestedFatalError) == expected)
}

for (shouldDetectNestedFatalError <- true :: false :: Nil) {
testThrowable(new OutOfMemoryError(), shouldDetectNestedFatalError, isFatal = true)
testThrowable(new InterruptedException(), shouldDetectNestedFatalError, isFatal = false)
testThrowable(new RuntimeException("test"), shouldDetectNestedFatalError, isFatal = false)
testThrowable(
new SparkOutOfMemoryError("test"),
shouldDetectNestedFatalError,
isFatal = false)
}

val e1 = new Exception("test1")
val e2 = new Exception("test2")
e1.initCause(e2)
e2.initCause(e1)
for (shouldDetectNestedFatalError <- true :: false :: Nil) {
testThrowable(e1, shouldDetectNestedFatalError, isFatal = false)
testThrowable(e2, shouldDetectNestedFatalError, isFatal = false)
}
}

private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
val mockEnv = mock[SparkEnv]
val mockRpcEnv = mock[RpcEnv]
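The `errorInThreadPool` and `errorInGuavaCache` helpers in the test above exercise the wrapping this change is meant to see through: a thread pool surfaces anything thrown by a task as a `java.util.concurrent.ExecutionException` with the original throwable as its cause, and a Guava `LoadingCache` wraps a thrown `Error` in `com.google.common.util.concurrent.ExecutionError`. A standalone sketch of the thread-pool case (class and message names are illustrative):

```scala
import java.util.concurrent.{Callable, ExecutionException, Executors}

// A fatal OutOfMemoryError thrown inside a pooled thread reaches the caller as an
// ExecutionException whose cause is the original error, so a check on only the
// top-level throwable would not classify it as fatal.
object NestedFatalErrorDemo {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(1)
    try {
      val future = pool.submit(new Callable[String] {
        override def call(): String = throw new OutOfMemoryError("simulated")
      })
      future.get() // throws ExecutionException(cause = OutOfMemoryError)
    } catch {
      case e: ExecutionException =>
        // e is not fatal by itself, but its cause is.
        println(s"wrapper: ${e.getClass.getSimpleName}, cause: ${e.getCause.getClass.getSimpleName}")
    } finally {
      pool.shutdown()
    }
  }
}
```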