Skip to content

Commit 3244707

Browse files
squito (Imran Rashid) authored and Marcelo Vanzin committed
[SPARK-24309][CORE] AsyncEventQueue should stop on interrupt.
EventListeners can interrupt the event queue thread. In particular, when the EventLoggingListener writes to hdfs, hdfs can interrupt the thread. When there is an interrupt, the queue should be removed and stop accepting any more events. Before this change, the queue would continue to take more events (till it was full), and then would not stop when the application was complete because the PoisonPill couldn't be added. Added a unit test which failed before this change. Author: Imran Rashid <irashid@cloudera.com> Closes apache#21356 from squito/SPARK-24309.
1 parent b550b2a commit 3244707

4 files changed

Lines changed: 98 additions & 17 deletions

File tree

core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ import org.apache.spark.util.Utils
3434
* Delivery will only begin when the `start()` method is called. The `stop()` method should be
3535
* called when no more events need to be delivered.
3636
*/
37-
private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics)
37+
private class AsyncEventQueue(
38+
val name: String,
39+
conf: SparkConf,
40+
metrics: LiveListenerBusMetrics,
41+
bus: LiveListenerBus)
3842
extends SparkListenerBus
3943
with Logging {
4044

@@ -81,23 +85,18 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
8185
}
8286

8387
private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) {
84-
try {
85-
var next: SparkListenerEvent = eventQueue.take()
86-
while (next != POISON_PILL) {
87-
val ctx = processingTime.time()
88-
try {
89-
super.postToAll(next)
90-
} finally {
91-
ctx.stop()
92-
}
93-
eventCount.decrementAndGet()
94-
next = eventQueue.take()
88+
var next: SparkListenerEvent = eventQueue.take()
89+
while (next != POISON_PILL) {
90+
val ctx = processingTime.time()
91+
try {
92+
super.postToAll(next)
93+
} finally {
94+
ctx.stop()
9595
}
9696
eventCount.decrementAndGet()
97-
} catch {
98-
case ie: InterruptedException =>
99-
logInfo(s"Stopping listener queue $name.", ie)
97+
next = eventQueue.take()
10098
}
99+
eventCount.decrementAndGet()
101100
}
102101

103102
override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = {
@@ -130,7 +129,11 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
130129
eventCount.incrementAndGet()
131130
eventQueue.put(POISON_PILL)
132131
}
133-
dispatchThread.join()
132+
// this thread might be trying to stop itself as part of error handling -- we can't join
133+
// in that case.
134+
if (Thread.currentThread() != dispatchThread) {
135+
dispatchThread.join()
136+
}
134137
}
135138

136139
def post(event: SparkListenerEvent): Unit = {
@@ -187,6 +190,12 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
187190
true
188191
}
189192

193+
override def removeListenerOnError(listener: SparkListenerInterface): Unit = {
194+
// the listener failed in an unrecoverably way, we want to remove it from the entire
195+
// LiveListenerBus (potentially stopping a queue if it is empty)
196+
bus.removeListener(listener)
197+
}
198+
190199
}
191200

192201
private object AsyncEventQueue {

core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) {
102102
queue.addListener(listener)
103103

104104
case None =>
105-
val newQueue = new AsyncEventQueue(queue, conf, metrics)
105+
val newQueue = new AsyncEventQueue(queue, conf, metrics, this)
106106
newQueue.addListener(listener)
107107
if (started.get()) {
108108
newQueue.start(sparkContext)

core/src/main/scala/org/apache/spark/util/ListenerBus.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,15 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
6060
}
6161
}
6262

63+
/**
64+
* This can be overriden by subclasses if there is any extra cleanup to do when removing a
65+
* listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus.
66+
*/
67+
def removeListenerOnError(listener: L): Unit = {
68+
removeListener(listener)
69+
}
70+
71+
6372
/**
6473
* Post the event to all registered listeners. The `postToAll` caller should guarantee calling
6574
* `postToAll` in the same thread for all events.
@@ -80,7 +89,16 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
8089
}
8190
try {
8291
doPostEvent(listener, event)
92+
if (Thread.interrupted()) {
93+
// We want to throw the InterruptedException right away so we can associate the interrupt
94+
// with this listener, as opposed to waiting for a queue.take() etc. to detect it.
95+
throw new InterruptedException()
96+
}
8397
} catch {
98+
case ie: InterruptedException =>
99+
logError(s"Interrupted while posting to ${Utils.getFormattedClassName(listener)}. " +
100+
s"Removing that listener.", ie)
101+
removeListenerOnError(listener)
84102
case NonFatal(e) if !isIgnorableException(e) =>
85103
logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e)
86104
} finally {

core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,48 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
489489
assert(bus.findListenersByClass[BasicJobCounter]().isEmpty)
490490
}
491491

492+
Seq(true, false).foreach { throwInterruptedException =>
493+
val suffix = if (throwInterruptedException) "throw interrupt" else "set Thread interrupted"
494+
test(s"interrupt within listener is handled correctly: $suffix") {
495+
val conf = new SparkConf(false)
496+
.set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 5)
497+
val bus = new LiveListenerBus(conf)
498+
val counter1 = new BasicJobCounter()
499+
val counter2 = new BasicJobCounter()
500+
val interruptingListener1 = new InterruptingListener(throwInterruptedException)
501+
val interruptingListener2 = new InterruptingListener(throwInterruptedException)
502+
bus.addToSharedQueue(counter1)
503+
bus.addToSharedQueue(interruptingListener1)
504+
bus.addToStatusQueue(counter2)
505+
bus.addToEventLogQueue(interruptingListener2)
506+
assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE, EVENT_LOG_QUEUE))
507+
assert(bus.findListenersByClass[BasicJobCounter]().size === 2)
508+
assert(bus.findListenersByClass[InterruptingListener]().size === 2)
509+
510+
bus.start(mockSparkContext, mockMetricsSystem)
511+
512+
// after we post one event, both interrupting listeners should get removed, and the
513+
// event log queue should be removed
514+
bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))
515+
bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
516+
assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE))
517+
assert(bus.findListenersByClass[BasicJobCounter]().size === 2)
518+
assert(bus.findListenersByClass[InterruptingListener]().size === 0)
519+
assert(counter1.count === 1)
520+
assert(counter2.count === 1)
521+
522+
// posting more events should be fine, they'll just get processed from the OK queue.
523+
(0 until 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
524+
bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
525+
assert(counter1.count === 6)
526+
assert(counter2.count === 6)
527+
528+
// Make sure stopping works -- this requires putting a poison pill in all active queues, which
529+
// would fail if our interrupted queue was still active, as its queue would be full.
530+
bus.stop()
531+
}
532+
}
533+
492534
/**
493535
* Assert that the given list of numbers has an average that is greater than zero.
494536
*/
@@ -547,6 +589,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
547589
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception }
548590
}
549591

592+
/**
593+
* A simple listener that interrupts on job end.
594+
*/
595+
private class InterruptingListener(val throwInterruptedException: Boolean) extends SparkListener {
596+
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
597+
if (throwInterruptedException) {
598+
throw new InterruptedException("got interrupted")
599+
} else {
600+
Thread.currentThread().interrupt()
601+
}
602+
}
603+
}
550604
}
551605

552606
// These classes can't be declared inside of the SparkListenerSuite class because we don't want

0 commit comments

Comments (0)