diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index be597edecba98..60a60377d8a3f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql
 
 import java.io.Closeable
 import java.util.concurrent.TimeUnit._
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
 
 import scala.collection.JavaConverters._
 import scala.reflect.runtime.universe.TypeTag
@@ -49,7 +49,6 @@ import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.ExecutionListenerManager
 import org.apache.spark.util.{CallSite, Utils}
 
-
 /**
  * The entry point to programming Spark with the Dataset and DataFrame API.
  *
@@ -940,15 +939,7 @@ object SparkSession extends Logging {
         options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
         setDefaultSession(session)
         setActiveSession(session)
-
-        // Register a successfully instantiated context to the singleton. This should be at the
-        // end of the class definition so that the singleton is updated only if there is no
-        // exception in the construction of the instance.
-        sparkContext.addSparkListener(new SparkListener {
-          override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
-            defaultSession.set(null)
-          }
-        })
+        registerContextListener(sparkContext)
       }
 
       return session
@@ -1064,6 +1055,20 @@ object SparkSession extends Logging {
   // Private methods from now on
   ////////////////////////////////////////////////////////////////////////////////////////
 
+  private val listenerRegistered: AtomicBoolean = new AtomicBoolean(false)
+
+  /** Register the AppEnd listener onto the Context */
+  private def registerContextListener(sparkContext: SparkContext): Unit = {
+    if (!listenerRegistered.get()) {
+      sparkContext.addSparkListener(new SparkListener {
+        override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
+          defaultSession.set(null)
+        }
+      })
+      listenerRegistered.set(true)
+    }
+  }
+
   /** The active SparkSession for the current thread. */
   private val activeThreadSession = new InheritableThreadLocal[SparkSession]
 
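Note on the guard above: `listenerRegistered.get()` followed later by `set(true)` is a check-then-act pair, which is safe here only because `getOrCreate()` runs this code inside the `SparkSession.synchronized` block. A minimal sketch of the same register-once idiom in lock-free form, using `compareAndSet` instead of an enclosing lock (`ListenerGuardSketch` and `registerOnce` are illustrative names, not part of this patch):

```scala
import java.util.concurrent.atomic.AtomicBoolean

object ListenerGuardSketch {
  private val registered = new AtomicBoolean(false)

  /** Runs `register` at most once across all callers and threads. */
  def registerOnce(register: () => Unit): Unit = {
    // compareAndSet flips false -> true exactly once; every later
    // caller observes true and skips the registration body entirely.
    if (registered.compareAndSet(false, true)) {
      register()
    }
  }

  def main(args: Array[String]): Unit = {
    var count = 0
    (1 to 4).foreach(_ => registerOnce(() => count += 1))
    assert(count == 1) // repeated calls are no-ops
    println(s"registered $count time(s)")
  }
}
```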
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
index 7b76d0702d835..0a522fdbdeed8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala
@@ -169,6 +169,31 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach {
     assert(session.sessionState.conf.getConf(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31234")
   }
 
+  test("SPARK-31354: SparkContext only register one SparkSession ApplicationEnd listener") {
+    val conf = new SparkConf()
+      .setMaster("local")
+      .setAppName("test-app-SPARK-31354-1")
+    val context = new SparkContext(conf)
+    SparkSession
+      .builder()
+      .sparkContext(context)
+      .master("local")
+      .getOrCreate()
+    val postFirstCreation = context.listenerBus.listeners.size()
+    SparkSession.clearActiveSession()
+    SparkSession.clearDefaultSession()
+
+    SparkSession
+      .builder()
+      .sparkContext(context)
+      .master("local")
+      .getOrCreate()
+    val postSecondCreation = context.listenerBus.listeners.size()
+    SparkSession.clearActiveSession()
+    SparkSession.clearDefaultSession()
+    assert(postFirstCreation == postSecondCreation)
+  }
+
   test("SPARK-31532: should not propagate static sql configs to the existing" +
     " active/default SparkSession") {
     val session = SparkSession.builder()
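For context, this is the driver-side pattern the test guards against: before this patch, every `getOrCreate()` that reused a live `SparkContext` after the default session was cleared registered one more `ApplicationEnd` listener; with the `AtomicBoolean` guard, the listener is registered once. A minimal sketch, assuming spark-sql on the classpath; the object name and loop are illustrative, not part of the patch:

```scala
import org.apache.spark.sql.SparkSession

object ListenerLeakRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("repro").getOrCreate()
    val sc = spark.sparkContext

    (1 to 3).foreach { _ =>
      SparkSession.clearActiveSession()
      SparkSession.clearDefaultSession()
      // Builds a fresh SparkSession but reuses the running SparkContext,
      // so no additional ApplicationEnd listener should be registered.
      SparkSession.builder().master("local").getOrCreate()
    }
    sc.stop()
  }
}
```

The test asserts only that the listener count is unchanged between the first and second `getOrCreate()`, rather than an absolute count, so it stays robust against other listeners the context registers on startup.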