Skip to content

Commit e5ef33a

Browse files
committed
adding test and cleanup
1 parent e5563a7 commit e5ef33a

2 files changed

Lines changed: 38 additions & 17 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -939,19 +939,12 @@ object SparkSession extends Logging {
939939
options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
940940
setDefaultSession(session)
941941
setActiveSession(session)
942-
registerSessionListenerOnContext(sparkContext)
942+
registerContextListener(sparkContext)
943943
}
944944

945945
return session
946946
}
947947

948-
private def registerSessionListenerOnContext(sparkContext: SparkContext): Unit = {
949-
if (!SparkSession.sessionListenerRegistered.get()) {
950-
sparkContext.addSparkListener(_sessionListener)
951-
SparkSession.sessionListenerRegistered.set(true)
952-
}
953-
}
954-
955948
private def applyModifiableSettings(session: SparkSession): Unit = {
956949
val (staticConfs, otherConfs) =
957950
options.partition(kv => SQLConf.staticConfKeys.contains(kv._1))
@@ -1061,16 +1054,21 @@ object SparkSession extends Logging {
10611054
////////////////////////////////////////////////////////////////////////////////////////
10621055
// Private methods from now on
10631056
////////////////////////////////////////////////////////////////////////////////////////
1064-
/** Default listener on SparkContext */
1065-
private val _sessionListener: SparkListener = new SparkListener {
1066-
override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
1067-
defaultSession.set(null)
1057+
1058+
private val listenerRegistered: AtomicBoolean = new AtomicBoolean(false)
1059+
1060+
/** Register the AppEnd listener onto the Context */
1061+
private def registerContextListener(sparkContext: SparkContext): Unit = {
1062+
if (!SparkSession.listenerRegistered.get()) {
1063+
sparkContext.addSparkListener(new SparkListener {
1064+
override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
1065+
defaultSession.set(null)
1066+
}
1067+
})
1068+
SparkSession.listenerRegistered.set(true)
10681069
}
10691070
}
10701071

1071-
/** Whether the app end listener has been registered on the context */
1072-
private val sessionListenerRegistered: AtomicBoolean = new AtomicBoolean(false)
1073-
10741072
/** The active SparkSession for the current thread. */
10751073
private val activeThreadSession = new InheritableThreadLocal[SparkSession]
10761074

@@ -1087,8 +1085,6 @@ object SparkSession extends Logging {
10871085
}
10881086
}
10891087

1090-
private[spark] def isSessionListenerRegistered: Boolean = sessionListenerRegistered.get
1091-
10921088
private def assertOnDriver(): Unit = {
10931089
if (Utils.isTesting && TaskContext.get != null) {
10941090
// we're accessing it during task execution, fail.

sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,31 @@ class SparkSessionBuilderSuite extends SparkFunSuite with BeforeAndAfterEach {
169169
assert(session.sessionState.conf.getConf(GLOBAL_TEMP_DATABASE) === "globaltempdb-spark-31234")
170170
}
171171

172+
test("SPARK-31354: SparkContext only register one SparkSession ApplicationEnd listener") {
173+
val conf = new SparkConf()
174+
.setMaster("local")
175+
.setAppName("test-app-SPARK-31354-1")
176+
val context = new SparkContext(conf)
177+
SparkSession
178+
.builder()
179+
.sparkContext(context)
180+
.master("local")
181+
.getOrCreate()
182+
val postFirstCreation = context.listenerBus.listeners.size()
183+
SparkSession.clearActiveSession()
184+
SparkSession.clearDefaultSession()
185+
186+
SparkSession
187+
.builder()
188+
.sparkContext(context)
189+
.master("local")
190+
.getOrCreate()
191+
val postSecondCreation = context.listenerBus.listeners.size()
192+
SparkSession.clearActiveSession()
193+
SparkSession.clearDefaultSession()
194+
assert(postFirstCreation == postSecondCreation)
195+
}
196+
172197
test("SPARK-31532: should not propagate static sql configs to the existing" +
173198
" active/default SparkSession") {
174199
val session = SparkSession.builder()

0 commit comments

Comments
 (0)