Skip to content

Commit 6550329

Browse files
davies authored and mateiz committed
[SPARK-3762] clear reference of SparkEnv after stop
SparkEnv is cached in a ThreadLocal object, so after stopping and creating a new SparkContext, the old SparkEnv is still used by some threads; this triggers many problems — for example, pyspark has problems after restarting a SparkContext, because py4j uses a thread pool for RPC. This patch clears all the references after stopping a SparkEnv. cc mateiz tdas pwendell Author: Davies Liu <[email protected]> Closes #2624 from davies/env and squashes the following commits: a69f30c [Davies Liu] deprecate getThreadLocal ba77ca4 [Davies Liu] remove getThreadLocal(), update docs ee62bb7 [Davies Liu] cleanup ThreadLocal of SparkEnv 4d0ea8b [Davies Liu] clear reference of SparkEnv after stop
1 parent 12e2551 commit 6550329

9 files changed

Lines changed: 8 additions & 21 deletions

File tree

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@ import org.apache.spark.util.{AkkaUtils, Utils}
4343
* :: DeveloperApi ::
4444
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
4545
* including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
46-
* Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these
47-
* objects needs to have the right SparkEnv set. You can get the current environment with
48-
* SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
46+
* Spark code finds the SparkEnv through a global variable, so all the threads can access the same
47+
* SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext).
4948
*
5049
* NOTE: This is not intended for external use. This is exposed for Shark and may be made private
5150
* in a future release.
@@ -119,30 +118,28 @@ class SparkEnv (
119118
}
120119

121120
object SparkEnv extends Logging {
122-
private val env = new ThreadLocal[SparkEnv]
123-
@volatile private var lastSetSparkEnv : SparkEnv = _
121+
@volatile private var env: SparkEnv = _
124122

125123
private[spark] val driverActorSystemName = "sparkDriver"
126124
private[spark] val executorActorSystemName = "sparkExecutor"
127125

128126
def set(e: SparkEnv) {
129-
lastSetSparkEnv = e
130-
env.set(e)
127+
env = e
131128
}
132129

133130
/**
134-
* Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv
135-
* previously set in any thread.
131+
* Returns the SparkEnv.
136132
*/
137133
def get: SparkEnv = {
138-
Option(env.get()).getOrElse(lastSetSparkEnv)
134+
env
139135
}
140136

141137
/**
142138
* Returns the ThreadLocal SparkEnv.
143139
*/
140+
@deprecated("Use SparkEnv.get instead", "1.2")
144141
def getThreadLocal: SparkEnv = {
145-
env.get()
142+
env
146143
}
147144

148145
private[spark] def create(

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ private[spark] class PythonRDD(
196196

197197
override def run(): Unit = Utils.logUncaughtExceptions {
198198
try {
199-
SparkEnv.set(env)
200199
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
201200
val dataOut = new DataOutputStream(stream)
202201
// Partition index

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ private[spark] class Executor(
148148

149149
override def run() {
150150
val startTime = System.currentTimeMillis()
151-
SparkEnv.set(env)
152151
Thread.currentThread.setContextClassLoader(replClassLoader)
153152
val ser = SparkEnv.get.closureSerializer.newInstance()
154153
logInfo(s"Running $taskName (TID $taskId)")
@@ -158,7 +157,6 @@ private[spark] class Executor(
158157
val startGCTime = gcTime
159158

160159
try {
161-
SparkEnv.set(env)
162160
Accumulators.clear()
163161
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
164162
updateDependencies(taskFiles, taskJars)

core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ private[spark] class PipedRDD[T: ClassTag](
131131
// Start a thread to feed the process input from our parent's iterator
132132
new Thread("stdin writer for " + command) {
133133
override def run() {
134-
SparkEnv.set(env)
135134
val out = new PrintWriter(proc.getOutputStream)
136135

137136
// input the pipe context firstly

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,6 @@ class DAGScheduler(
630630
protected def runLocallyWithinThread(job: ActiveJob) {
631631
var jobResult: JobResult = JobSucceeded
632632
try {
633-
SparkEnv.set(env)
634633
val rdd = job.finalStage.rdd
635634
val split = rdd.partitions(job.partitions(0))
636635
val taskContext =

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,6 @@ private[spark] class TaskSchedulerImpl(
216216
* that tasks are balanced across the cluster.
217217
*/
218218
def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
219-
SparkEnv.set(sc.env)
220-
221219
// Mark each slave as alive and remember its hostname
222220
// Also track if new executor is added
223221
var newExecAvail = false

streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
217217

218218
/** Generate jobs and perform checkpoint for the given `time`. */
219219
private def generateJobs(time: Time) {
220-
SparkEnv.set(ssc.env)
221220
Try(graph.generateJobs(time)) match {
222221
case Success(jobs) =>
223222
val receivedBlockInfo = graph.getReceiverInputStreams.map { stream =>

streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
138138
}
139139
jobSet.handleJobStart(job)
140140
logInfo("Starting job " + job.id + " from job set of time " + jobSet.time)
141-
SparkEnv.set(ssc.env)
142141
}
143142

144143
private def handleJobCompletion(job: Job) {

streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
202202
@transient val thread = new Thread() {
203203
override def run() {
204204
try {
205-
SparkEnv.set(env)
206205
startReceivers()
207206
} catch {
208207
case ie: InterruptedException => logInfo("ReceiverLauncher interrupted")

0 commit comments

Comments
 (0)