Skip to content

Commit 68ae0ed

Browse files
committed
Fix potential ThreadLocal leaks when using ForkJoinPool
1 parent ed8869e commit 68ae0ed

5 files changed

Lines changed: 17 additions & 31 deletions

File tree

core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import scala.concurrent.duration._
2424
import scala.util.control.NonFatal
2525

2626
import org.apache.spark.{SparkConf, SparkException}
27-
import org.apache.spark.util.Utils
27+
import org.apache.spark.util.{ThreadUtils, Utils}
2828

2929
/**
3030
* An exception thrown if RpcTimeout modifies a [[TimeoutException]].
@@ -77,9 +77,7 @@ private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: S
7777
throw new SparkException("Exception thrown in awaitResult", t)
7878
}
7979
try {
80-
// scalastyle:off awaitresult
81-
Await.result(future, duration)
82-
// scalastyle:on awaitresult
80+
ThreadUtils.awaitResult(future, duration)
8381
} catch addMessageIfTimeout.orElse(wrapAndRethrow)
8482
}
8583
}

core/src/main/scala/org/apache/spark/util/ThreadUtils.scala

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.util
1919

2020
import java.util.concurrent._
2121

22-
import scala.concurrent.{Await, Awaitable, ExecutionContext, ExecutionContextExecutor}
22+
import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor}
2323
import scala.concurrent.duration.Duration
2424
import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread}
2525
import scala.util.control.NonFatal
@@ -180,31 +180,20 @@ private[spark] object ThreadUtils {
180180

181181
// scalastyle:off awaitresult
182182
/**
183-
* Preferred alternative to `Await.result()`. This method wraps and re-throws any exceptions
184-
* thrown by the underlying `Await` call, ensuring that this thread's stack trace appears in
185-
* logs.
186-
*/
187-
@throws(classOf[SparkException])
188-
def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
189-
try {
190-
Await.result(awaitable, atMost)
191-
// scalastyle:on awaitresult
192-
} catch {
193-
case NonFatal(t) =>
194-
throw new SparkException("Exception thrown in awaitResult: ", t)
195-
}
196-
}
197-
198-
/**
199-
* Calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s `BlockingContext`, wraps
200-
 * and re-throws any exceptions with a nice stack trace.
183+
* Preferred alternative to `Await.result()`.
184+
*
185+
* This method wraps and re-throws any exceptions thrown by the underlying `Await` call, ensuring
186+
* that this thread's stack trace appears in logs.
201187
*
202-
* Codes running in the user's thread may be in a thread of Scala ForkJoinPool. As concurrent
203-
* executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this method
204-
* basically prevents ForkJoinPool from running other tasks in the current waiting thread.
188+
* In addition, it calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s
189+
 * `BlockingContext`. Code running in the user's thread may be in a thread of Scala ForkJoinPool.
190+
* As concurrent executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this
191+
* method basically prevents ForkJoinPool from running other tasks in the current waiting thread.
192+
* In general, we should use this method because many places in Spark use [[ThreadLocal]] and it's
193+
* hard to debug when [[ThreadLocal]]s leak to other tasks.
205194
*/
206195
@throws(classOf[SparkException])
207-
def awaitResultInForkJoinSafely[T](awaitable: Awaitable[T], atMost: Duration): T = {
196+
def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = {
208197
try {
209198
// `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
210199
// See SPARK-13747.
@@ -215,4 +204,5 @@ private[spark] object ThreadUtils {
215204
throw new SparkException("Exception thrown in awaitResult: ", t)
216205
}
217206
}
207+
// scalastyle:on awaitresult
218208
}

scalastyle-config.xml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,6 @@ This file is divided into 3 sections:
200200
// scalastyle:off awaitresult
201201
Await.result(...)
202202
// scalastyle:on awaitresult
203-
If your code uses ThreadLocal and may run in threads created by the user, use ThreadUtils.awaitResultInForkJoinSafely instead.
204203
]]></customMessage>
205204
</check>
206205

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
578578
}
579579

580580
override def executeCollect(): Array[InternalRow] = {
581-
ThreadUtils.awaitResultInForkJoinSafely(relationFuture, Duration.Inf)
581+
ThreadUtils.awaitResult(relationFuture, Duration.Inf)
582582
}
583583
}
584584

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,7 @@ case class BroadcastExchangeExec(
128128
}
129129

130130
override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
131-
ThreadUtils.awaitResultInForkJoinSafely(relationFuture, timeout)
132-
.asInstanceOf[broadcast.Broadcast[T]]
131+
ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]]
133132
}
134133
}
135134

0 commit comments

Comments
 (0)