Skip to content

Commit 4d5965e

Browse files
Nikita Gorbachevsky authored and choojoyq committed
[SPARK-28709][DSTREAMS] - Fix StreamingContext leak through StreamingJobProgressListener on stop
1 parent 247bebc commit 4d5965e

7 files changed

Lines changed: 63 additions & 30 deletions

File tree

core/src/main/scala/org/apache/spark/ui/SparkUI.scala

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -138,6 +138,9 @@ private[spark] class SparkUI private (
138138
streamingJobProgressListener = Option(sparkListener)
139139
}
140140

141+
def clearStreamingJobProgressListener(): Unit = {
142+
streamingJobProgressListener = None
143+
}
141144
}
142145

143146
private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String)

core/src/main/scala/org/apache/spark/ui/WebUI.scala

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,7 @@ private[spark] abstract class WebUI(
9393
attachHandler(renderJsonHandler)
9494
val handlers = pageToHandlers.getOrElseUpdate(page, ArrayBuffer[ServletContextHandler]())
9595
handlers += renderHandler
96+
handlers += renderJsonHandler
9697
}
9798

9899
/** Attaches a handler to this UI. */

streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala

Lines changed: 24 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,6 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
3838
import org.apache.spark.deploy.SparkHadoopUtil
3939
import org.apache.spark.input.FixedLengthBinaryInputFormat
4040
import org.apache.spark.internal.Logging
41-
import org.apache.spark.internal.config.UI._
4241
import org.apache.spark.rdd.{RDD, RDDOperationScope}
4342
import org.apache.spark.scheduler.LiveListenerBus
4443
import org.apache.spark.serializer.SerializationDebugger
@@ -189,10 +188,9 @@ class StreamingContext private[streaming] (
189188
private[streaming] val progressListener = new StreamingJobProgressListener(this)
190189

191190
private[streaming] val uiTab: Option[StreamingTab] =
192-
if (conf.get(UI_ENABLED)) {
193-
Some(new StreamingTab(this))
194-
} else {
195-
None
191+
sparkContext.ui match {
192+
case Some(ui) => Some(new StreamingTab(this, ui))
193+
case None => None
196194
}
197195

198196
/* Initializing a streamingSource to register metrics */
@@ -511,6 +509,10 @@ class StreamingContext private[streaming] (
511509
scheduler.listenerBus.addListener(streamingListener)
512510
}
513511

512+
def removeStreamingListener(streamingListener: StreamingListener): Unit = {
513+
scheduler.listenerBus.removeListener(streamingListener)
514+
}
515+
514516
private def validate() {
515517
assert(graph != null, "Graph is null")
516518
graph.validate()
@@ -575,6 +577,8 @@ class StreamingContext private[streaming] (
575577
try {
576578
validate()
577579

580+
registerProgressListener()
581+
578582
// Start the streaming scheduler in a new thread, so that thread local properties
579583
// like call sites and job groups can be reset without affecting those of the
580584
// current thread.
@@ -690,6 +694,9 @@ class StreamingContext private[streaming] (
690694
Utils.tryLogNonFatalError {
691695
uiTab.foreach(_.detach())
692696
}
697+
Utils.tryLogNonFatalError {
698+
unregisterProgressListener()
699+
}
693700
StreamingContext.setActiveContext(null)
694701
Utils.tryLogNonFatalError {
695702
waiter.notifyStop()
@@ -716,6 +723,18 @@ class StreamingContext private[streaming] (
716723
// Do not stop SparkContext, let its own shutdown hook stop it
717724
stop(stopSparkContext = false, stopGracefully = stopGracefully)
718725
}
726+
727+
private def registerProgressListener(): Unit = {
728+
addStreamingListener(progressListener)
729+
sc.addSparkListener(progressListener)
730+
sc.ui.foreach(_.setStreamingJobProgressListener(progressListener))
731+
}
732+
733+
private def unregisterProgressListener(): Unit = {
734+
removeStreamingListener(progressListener)
735+
sc.removeSparkListener(progressListener)
736+
sc.ui.foreach(_.clearStreamingJobProgressListener())
737+
}
719738
}
720739

721740
/**

streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala

Lines changed: 7 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.streaming.ui
1919

20-
import org.apache.spark.SparkException
2120
import org.apache.spark.internal.Logging
2221
import org.apache.spark.streaming.StreamingContext
2322
import org.apache.spark.ui.{SparkUI, SparkUITab}
@@ -26,37 +25,24 @@ import org.apache.spark.ui.{SparkUI, SparkUITab}
2625
* Spark Web UI tab that shows statistics of a streaming job.
2726
* This assumes the given SparkContext has enabled its SparkUI.
2827
*/
29-
private[spark] class StreamingTab(val ssc: StreamingContext)
30-
extends SparkUITab(StreamingTab.getSparkUI(ssc), "streaming") with Logging {
31-
32-
import StreamingTab._
28+
private[spark] class StreamingTab(val ssc: StreamingContext, sparkUI: SparkUI)
29+
extends SparkUITab(sparkUI, "streaming") with Logging {
3330

3431
private val STATIC_RESOURCE_DIR = "org/apache/spark/streaming/ui/static"
3532

36-
val parent = getSparkUI(ssc)
33+
val parent = sparkUI
3734
val listener = ssc.progressListener
3835

39-
ssc.addStreamingListener(listener)
40-
ssc.sc.addSparkListener(listener)
41-
parent.setStreamingJobProgressListener(listener)
4236
attachPage(new StreamingPage(this))
4337
attachPage(new BatchPage(this))
4438

4539
def attach() {
46-
getSparkUI(ssc).attachTab(this)
47-
getSparkUI(ssc).addStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming")
40+
parent.attachTab(this)
41+
parent.addStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming")
4842
}
4943

5044
def detach() {
51-
getSparkUI(ssc).detachTab(this)
52-
getSparkUI(ssc).detachHandler("/static/streaming")
53-
}
54-
}
55-
56-
private object StreamingTab {
57-
def getSparkUI(ssc: StreamingContext): SparkUI = {
58-
ssc.sc.ui.getOrElse {
59-
throw new SparkException("Parent SparkUI to attach this tab to not found!")
60-
}
45+
parent.detachTab(this)
46+
parent.detachHandler("/static/streaming")
6147
}
6248
}

streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -52,8 +52,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
5252

5353
// Set up the streaming context and input streams
5454
withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
55-
ssc.addStreamingListener(ssc.progressListener)
56-
5755
val input = Seq(1, 2, 3, 4, 5)
5856
// Use "batchCount" to make sure we check the result after all batches finish
5957
val batchCounter = new BatchCounter(ssc)
@@ -106,8 +104,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
106104
testServer.start()
107105

108106
withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
109-
ssc.addStreamingListener(ssc.progressListener)
110-
111107
val batchCounter = new BatchCounter(ssc)
112108
val networkStream = ssc.socketTextStream(
113109
"localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)

streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ import org.scalatest.time.SpanSugar._
3434

3535
import org.apache.spark._
3636
import org.apache.spark.internal.Logging
37+
import org.apache.spark.internal.config.UI.UI_ENABLED
3738
import org.apache.spark.metrics.MetricsSystem
3839
import org.apache.spark.metrics.source.Source
3940
import org.apache.spark.storage.StorageLevel
@@ -392,6 +393,29 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL
392393
assert(!sourcesAfterStop.contains(streamingSourceAfterStop))
393394
}
394395

396+
test("SPARK-28709 registering and de-registering of progressListener") {
397+
val conf = new SparkConf().setMaster(master).setAppName(appName)
398+
conf.set(UI_ENABLED, true)
399+
400+
ssc = new StreamingContext(conf, batchDuration)
401+
402+
assert(ssc.sc.ui.isDefined, "Spark UI is not started!")
403+
val sparkUI = ssc.sc.ui.get
404+
405+
addInputStream(ssc).register()
406+
ssc.start()
407+
408+
assert(ssc.scheduler.listenerBus.listeners.contains(ssc.progressListener))
409+
assert(ssc.sc.listenerBus.listeners.contains(ssc.progressListener))
410+
assert(sparkUI.getStreamingJobProgressListener.get == ssc.progressListener)
411+
412+
ssc.stop()
413+
414+
assert(!ssc.scheduler.listenerBus.listeners.contains(ssc.progressListener))
415+
assert(!ssc.sc.listenerBus.listeners.contains(ssc.progressListener))
416+
assert(sparkUI.getStreamingJobProgressListener.isEmpty)
417+
}
418+
395419
test("awaitTermination") {
396420
ssc = new StreamingContext(master, appName, batchDuration)
397421
val inputStream = addInputStream(ssc)

streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,8 @@ class UISeleniumSuite
9797

9898
val sparkUI = ssc.sparkContext.ui.get
9999

100+
sparkUI.getHandlers.count(_.getContextPath.contains("/streaming")) should be (5)
101+
100102
eventually(timeout(10.seconds), interval(50.milliseconds)) {
101103
go to (sparkUI.webUrl.stripSuffix("/"))
102104
find(cssSelector( """ul li a[href*="streaming"]""")) should not be (None)
@@ -196,6 +198,8 @@ class UISeleniumSuite
196198

197199
ssc.stop(false)
198200

201+
sparkUI.getHandlers.count(_.getContextPath.contains("/streaming")) should be (0)
202+
199203
eventually(timeout(10.seconds), interval(50.milliseconds)) {
200204
go to (sparkUI.webUrl.stripSuffix("/"))
201205
find(cssSelector( """ul li a[href*="streaming"]""")) should be(None)

0 commit comments

Comments (0)