diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 0b5bf3f48b593..d8b88d64273be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql

 import java.io.Closeable
 import java.util.concurrent.TimeUnit._
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}

 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -711,12 +711,15 @@ class SparkSession private(
   // scalastyle:on

   /**
-   * Stop the underlying `SparkContext`.
+   * Stop the underlying `SparkContext` if there are no active sessions remaining.
    *
    * @since 2.0.0
    */
   def stop(): Unit = {
-    sparkContext.stop()
+    SparkSession.clearActiveSession()
+    if (SparkSession.numActiveSessions.get() == 0) {
+      sparkContext.stop()
+    }
   }

   /**
@@ -776,6 +779,8 @@ class SparkSession private(
 @Stable
 object SparkSession extends Logging {

+  private[spark] val numActiveSessions: AtomicInteger = new AtomicInteger(0)
+
   /**
    * Builder for [[SparkSession]].
    */
@@ -958,6 +963,8 @@ object SparkSession extends Logging {
         sparkContext.addSparkListener(new SparkListener {
           override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
             defaultSession.set(null)
+            // Unregister this listener once the application-end event has fired.
+            sparkContext.removeSparkListener(this)
           }
         })
       }
@@ -981,17 +988,26 @@ object SparkSession extends Logging {
    * @since 2.0.0
    */
   def setActiveSession(session: SparkSession): Unit = {
-    activeThreadSession.set(session)
+    if (session == null) {
+      clearActiveSession()
+    } else if (!getActiveSession.contains(session)) {
+      numActiveSessions.getAndIncrement()
+      activeThreadSession.set(session)
+    }
   }

   /**
-   * Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will
-   * return the first created context instead of a thread-local override.
+   * Clears the active SparkSession for the current thread, if one is defined.
+   * Subsequent calls to getOrCreate will return the first created context
+   * instead of a thread-local override.
    *
    * @since 2.0.0
    */
   def clearActiveSession(): Unit = {
-    activeThreadSession.remove()
+    if (getActiveSession.isDefined) {
+      activeThreadSession.remove()
+      numActiveSessions.decrementAndGet()
+    }
   }

   /**
@@ -1004,12 +1020,15 @@ object SparkSession extends Logging {
   }

   /**
-   * Clears the default SparkSession that is returned by the builder.
-   *
+   * Clears the default SparkSession that is returned by the builder,
+   * if it is not null.
+   *
    * @since 2.0.0
    */
   def clearDefaultSession(): Unit = {
-    defaultSession.set(null)
+    if (getDefaultSession.isDefined) {
+      defaultSession.set(null)
+    }
   }

   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
index 10b17571d2aaa..615d32226db13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
@@ -152,4 +152,16 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach {
       session.sparkContext.hadoopConfiguration.unset(mySpecialKey)
     }
   }
+
+  test("SPARK-27958: SparkContext stopped when last SparkSession is stopped") {
+    val conf = new SparkConf().setAppName("test").setMaster("local").set("key1", "value1")
+    val newSC = new SparkContext(conf)
+    val session1 = SparkSession.builder().sparkContext(newSC).master("local").getOrCreate()
+    assert(!session1.sparkContext.isStopped)
+    val session2 = SparkSession.builder().sparkContext(newSC).master("local").getOrCreate()
+    session1.stop()
+    assert(!session1.sparkContext.isStopped)
+    session2.stop()
+    assert(session1.sparkContext.isStopped)
+  }
 }
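The change above is plain reference counting over a shared resource: each session registration bumps `numActiveSessions`, each `stop()` decrements it, and the shared `SparkContext` is torn down only when the count reaches zero. Below is a minimal, self-contained sketch of that pattern; `SharedResource` and `Handle` are hypothetical stand-ins for `SparkContext` and `SparkSession`, not Spark API.

```scala
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical stand-in for SparkContext: a shared resource that can be stopped.
class SharedResource {
  @volatile private var stopped = false
  def stop(): Unit = stopped = true
  def isStopped: Boolean = stopped
}

// Hypothetical stand-in for SparkSession: constructing a handle increments the
// shared counter, mirroring what setActiveSession does in the diff above.
class Handle(resource: SharedResource, active: AtomicInteger) {
  active.getAndIncrement()

  def stop(): Unit = {
    // Tear down the shared resource only when the last handle is closed,
    // the same guard the diff adds to SparkSession.stop().
    if (active.decrementAndGet() == 0) {
      resource.stop()
    }
  }
}

object RefCountDemo extends App {
  val counter = new AtomicInteger(0)
  val res = new SharedResource
  val h1 = new Handle(res, counter)
  val h2 = new Handle(res, counter)
  h1.stop()
  assert(!res.isStopped) // h2 is still open, so the resource survives
  h2.stop()
  assert(res.isStopped)  // last handle closed: the shared resource is stopped
}
```

One design note worth flagging: in the patch, the counter is maintained through the thread-local active session, so `clearActiveSession()` decrements only when the calling thread currently has an active session. A session stopped from a thread where it is not active would leave the count unchanged, letting it drift from the true number of live sessions.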