43 changes: 39 additions & 4 deletions core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -113,7 +113,7 @@ private[spark] class PipedRDD[T: ClassTag](
val childThreadException = new AtomicReference[Throwable](null)

// Start a thread to print the process's stderr to ours
new Thread(s"stderr reader for $command") {
val stderrReaderThread = new Thread(s"${PipedRDD.STDERR_READER_THREAD_PREFIX} $command") {
override def run(): Unit = {
val err = proc.getErrorStream
try {
@@ -128,10 +128,11 @@ private[spark] class PipedRDD[T: ClassTag](
err.close()
}
}
}.start()
}
stderrReaderThread.start()

// Start a thread to feed the process input from our parent's iterator
new Thread(s"stdin writer for $command") {
val stdinWriterThread = new Thread(s"${PipedRDD.STDIN_WRITER_THREAD_PREFIX} $command") {
override def run(): Unit = {
TaskContext.setTaskContext(context)
val out = new PrintWriter(new BufferedWriter(
@@ -156,7 +157,34 @@ private[spark] class PipedRDD[T: ClassTag](
out.close()
}
}
}.start()
}
stdinWriterThread.start()

def cleanUpIOThreads(): Unit = {
if (proc.isAlive) {
proc.destroy()
}
if (stdinWriterThread.isAlive) {
stdinWriterThread.stop()
}

if (stderrReaderThread.isAlive) {
stderrReaderThread.stop()
}
}

// Stops the stdin writer and stderr reader threads when the corresponding task finishes, as a
// safety belt. Otherwise, these threads could outlive the task's lifetime. For example:
//   val pipeRDD = sc.range(1, 100).pipe(Seq("cat"))
//   val abnormalRDD = pipeRDD.mapPartitions(_ => Iterator.empty)
// The iterator generated by PipedRDD is never consumed. If the parent RDD's iterator is time
// consuming to generate (ShuffledRDD's shuffle operation, for example), the outlived stdin writer
// thread will consume significant memory and CPU time. Also, there's a race condition for
// ShuffledRDD + PipedRDD if the subprocess command fails: the task will be marked as failed and
// ShuffleBlockFetcherIterator will be cleaned up at task completion, which may hang the
// ShuffleBlockFetcherIterator.next() call. The failed task's stdin writer never exits and leaks
// the significant memory held in the ShuffleReader.
context.addTaskCompletionListener[Unit](_ => cleanUpIOThreads())

// Return an iterator that read lines from the process's stdout
val lines = Source.fromInputStream(proc.getInputStream)(encoding).getLines
@@ -185,6 +213,10 @@ private[spark] class PipedRDD[T: ClassTag](
}

private def cleanup(): Unit = {
// Interrupt the stdin writer thread, so that the stdin writer thread for a ShuffledRDD parent
// can exit with an InterruptedException if it is waiting at results.take()
stdinWriterThread.interrupt()

// cleanup task working directory if used
if (workInTaskDirectory) {
scala.util.control.Exception.ignoring(classOf[IOException]) {
@@ -219,4 +251,7 @@ private object PipedRDD {
}
buf
}

val STDIN_WRITER_THREAD_PREFIX = "stdin writer for"
val STDERR_READER_THREAD_PREFIX = "stderr reader for"
}
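For context, a minimal sketch of the leak scenario that the new comment describes: a ShuffledRDD parent piped through a subprocess, with a downstream transformation that never consumes the piped iterator. This assumes a local SparkContext named sc and uses illustrative sizes; it is not code from this PR.

// Sketch only: without the task-completion cleanup added above, the stdin writer thread spawned
// by pipe() would keep pulling from the expensive shuffle iterator after the task has finished.
val shuffled = sc.range(1, 1000000)
  .map(x => (x % 10, x))
  .reduceByKey(_ + _) // the parent is a ShuffledRDD, slow to iterate
val piped = shuffled.pipe(Seq("cat"))
// mapPartitions ignores the piped iterator, so the task completes almost immediately, while the
// stdin writer thread would otherwise outlive it and hold on to shuffle-reader memory.
piped.mapPartitions(_ => Iterator.empty).count()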
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -141,7 +141,7 @@ final class ShuffleBlockFetcherIterator(

/**
* Whether the iterator is still active. If isZombie is true, the callback interface will no
* longer place fetched blocks into [[results]].
* longer place fetched blocks into [[results]] and the iterator is marked as fully consumed.
*/
@GuardedBy("this")
private[this] var isZombie = false
@@ -372,7 +372,7 @@ final class ShuffleBlockFetcherIterator(
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}

override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
override def hasNext: Boolean = !isZombie && (numBlocksProcessed < numBlocksToFetch)
Member

Is ShuffleBlockFetcherIterator's change related to PipedRDD?

Contributor Author

Kind of. The root cause of the OOM I described in SPARK-26713 is ShuffledRDD + PipedRDD.

I found that PipedRDD's stderr writer hangs at ShuffleBlockFetcherIterator.next() and leaks memory. I believe this change to ShuffleBlockFetcherIterator could reduce the possibility of the race condition, and it seems right to mark the iterator as fully consumed if it is already cleaned up.

Member

After the change to PipedRDD in this PR, won't the stdin writer thread (you wrote stderr writer but I think it is a typo) be interrupted? If so, it stops consuming ShuffleBlockFetcherIterator. Isn't that enough to solve the problem?

Contributor Author

(you wrote stderr writer but I think it is a typo)

Sorry for the typo.

If so, it stops consuming ShuffleBlockFetcherIterator. Isn't that enough to solve the problem?

Yes, for the ShuffledRDD + PipedRDD case, the cleanup logic in PipedRDD is enough to solve the potential leak.
However, since a ShuffledRDD can be transformed by arbitrary operations, there may be other cases where ShuffleBlockFetcherIterator is cleaned up but still being consumed. So this change makes ShuffleBlockFetcherIterator defensive.

Member

OK, it sounds good to me. You can make the related comment more general and move it to ShuffleBlockFetcherIterator.


/**
* Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers
@@ -395,7 +395,7 @@ final class ShuffleBlockFetcherIterator(
// then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
// is also corrupt, so the previous stage could be retried.
// For local shuffle block, throw FailureFetchResult for the first IOException.
while (result == null) {
while (!isZombie && result == null) {
Contributor

Is it possible that hasNext returns true and next throws NoSuchElementException? isZombie may get changed by other threads?

Member

Yeah that can happen. Right now I think it's 'worse' in that the iterator might be cleaned up and yet next() will keep querying the iterator that's being drained by cleanup().

To really tighten it up I think more or all of next() and cleanup() would have to be synchronized (?) and I'm not sure what the implications are of that.

We could follow this up with small things like making hasNext() synchronized at least, as isZombie is marked GuardedBy("this"). That still doesn't prevent this from happening but is a little tighter.

@advancedxy what do you think? I think the argument is merely that this fixes the potential issue in 99% of cases, not 100%.

Contributor Author

Is it possible that hasNext returns true and next throws NoSuchElementException? isZombie may get changed by other threads?

@cloud-fan Yeah, it can happen. But I agree with @srowen. The isZombie flag indicates the whole task is finished, so there's no point for the consumer of the iterator to still be active. This changes the semantics of Iterator in rare cases, but I think it is acceptable.

We could follow this up with small things like making hasNext() synchronized at least, as isZombie is marked GuardedBy("this"). That still doesn't prevent this from happening but is a little tighter.

Maybe. But I would leave it as it is if it's up to me. Like you said, this doesn't prevent the semantics from changing, but it is a little tighter.

val startFetchWait = System.currentTimeMillis()
result = results.take()
val stopFetchWait = System.currentTimeMillis()
@@ -489,8 +489,12 @@ final class ShuffleBlockFetcherIterator(
fetchUpToMaxBytes()
}

currentResult = result.asInstanceOf[SuccessFetchResult]
(currentResult.blockId, new BufferReleasingInputStream(input, this))
if (result != null) {
currentResult = result.asInstanceOf[SuccessFetchResult]
(currentResult.blockId, new BufferReleasingInputStream(input, this))
} else { // the iterator has already been closed
throw new NoSuchElementException
Member

Tiny nit: new NoSuchElementException(). You could also ...

if (result == null) {
  throw ..
}

but it doesn't really matter, maybe just slightly cleaner to follow.

Contributor Author

ok.

L387 in this file is also throw new NoSuchElementException. Shall I modify that too?

}
}

private def fetchUpToMaxBytes(): Unit = {
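As a rough illustration of the follow-up discussed in the review thread above (taking the iterator's lock before reading isZombie, since that field is annotated @GuardedBy("this")), here is a hypothetical sketch, not part of this change. It narrows, but does not close, the window in which hasNext returns true and a concurrent cleanup() zombifies the iterator before next() runs.

// Hypothetical synchronized variant of hasNext (not in this PR).
override def hasNext: Boolean = synchronized {
  !isZombie && (numBlocksProcessed < numBlocksToFetch)
}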
24 changes: 24 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd

import java.io.File

import scala.collection.JavaConverters._
import scala.collection.Map
import scala.io.Codec

@@ -83,6 +84,29 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext {
}
}

test("stdin writer thread should be exited when task is finished") {
assume(TestUtils.testCommandAvailable("cat"))
val nums = sc.makeRDD(Array(1, 2, 3, 4), 1).map { x =>
val obj = new Object
obj.synchronized {
obj.wait() // make the thread wait here.
}
x
}

val piped = nums.pipe(Seq("cat"))

val result = piped.mapPartitions(_ => Array.emptyIntArray.iterator)

assert(result.collect().length === 0)

// collect the stdin writer threads spawned by PipedRDD
val stdinWriterThreads = Thread.getAllStackTraces.keySet().asScala
.filter { _.getName.startsWith(PipedRDD.STDIN_WRITER_THREAD_PREFIX) }

assert(stdinWriterThreads.size === 0)
}

test("advanced pipe") {
assume(TestUtils.testCommandAvailable("cat"))
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -217,6 +217,65 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
}

test("iterator is all consumed if task completes early") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
doReturn(localBmId).when(blockManager).blockManagerId

// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())

// Semaphore to coordinate event sequence in two different threads.
val sem = new Semaphore(0)

val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
// Return the first two blocks, and wait till task completion before returning the 3rd one
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
sem.acquire()
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
}
}
})

val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator

val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
blockManager,
blocksByAddress,
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
Int.MaxValue,
true,
taskContext.taskMetrics.createTempShuffleReadMetrics())


assert(iterator.hasNext)
iterator.next()

taskContext.markTaskCompleted(None)
sem.release()
assert(iterator.hasNext === false)
}

test("fail all blocks if any of the remote request fails") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)