BlockStoreShuffleReader.scala
@@ -42,24 +42,20 @@ private[spark] class BlockStoreShuffleReader[K, C](

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

// Wrap the streams for compression and encryption based on configuration
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
serializerManager.wrapStream(blockId, inputStream)
}

val serializerInstance = dep.serializer.newInstance()

// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { wrappedStream =>
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
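For readers new to this code path, here is a toy, self-contained sketch of the pattern this hunk introduces: the compression/encryption wrapper is applied inside the iterator that produces the streams, and the consumer destructures (blockId, wrappedStream) pairs instead of mapping over raw streams afterwards. All names below (gzip, wrapStream, the record format) are illustrative stand-ins, not Spark APIs.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

object WrapInsideIteratorSketch extends App {
  // Toy stand-in for a fetched shuffle block: a gzip-compressed payload keyed by id.
  def gzip(payload: Array[Byte]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val gz = new GZIPOutputStream(bos)
    gz.write(payload)
    gz.close()
    bos.toByteArray
  }

  val blocks = Seq("shuffle_0_0_0" -> gzip("k1,v1\nk2,v2".getBytes("UTF-8")))

  // Plays the role of serializerManager.wrapStream: applied *inside* the
  // producing iterator, so consumers receive ready-to-read streams.
  val wrapStream: (String, InputStream) => InputStream =
    (_, in) => new GZIPInputStream(in)

  val wrappedStreams: Iterator[(String, InputStream)] = blocks.iterator.map {
    case (id, bytes) => id -> wrapStream(id, new ByteArrayInputStream(bytes))
  }

  // Consumer mirrors the new flatMap shape in read(): it destructures the pair
  // and reads records from the already-wrapped stream.
  val records = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
    scala.io.Source.fromInputStream(wrappedStream, "UTF-8").getLines()
  }
  records.foreach(println)
}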
ShuffleBlockFetcherIterator.scala
@@ -17,19 +17,22 @@

package org.apache.spark.storage

import java.io.InputStream
import java.io.{InputStream, IOException}
import java.nio.ByteBuffer
import java.util.concurrent.LinkedBlockingQueue
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
import scala.util.control.NonFatal

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.Utils
import org.apache.spark.util.io.{ChunkedByteBufferInputStream, ChunkedByteBufferOutputStream}
Review comment (Contributor): seems ChunkedByteBufferInputStream is not used here.

/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -56,6 +59,7 @@ final class ShuffleBlockFetcherIterator(
shuffleClient: ShuffleClient,
Review comment (Contributor): Could you update the Scaladoc to document the two new parameters here? I understand what streamWrapper means from context but it might be useful for new readers of this code.

Reply (Contributor Author): done

blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
streamWrapper: (BlockId, InputStream) => InputStream,
maxBytesInFlight: Long,
maxReqsInFlight: Int)
extends Iterator[(BlockId, InputStream)] with Logging {
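A hedged sketch of the kind of Scaladoc update the comment above asks for (presumably covering streamWrapper and maxReqsInFlight); the wording below is illustrative, not the text that was eventually committed.

/**
 * ...
 * @param streamWrapper a function to wrap the returned input stream, e.g. to apply
 *                      decompression and/or decryption (typically
 *                      SerializerManager.wrapStream).
 * @param maxReqsInFlight max number of remote block requests in flight at any given point.
 */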
@@ -108,6 +112,9 @@ final class ShuffleBlockFetcherIterator(
/** Current number of requests in flight */
private[this] var reqsInFlight = 0

/** The blocks that can't be decompressed successfully */
Review comment (Contributor): What about adding more explanation, for example:

/**
 * The blocks that can't be decompressed successfully.
 * It is used to guarantee that we retry at most once for those corrupted blocks.
 */

private[this] val corruptedBlocks = mutable.HashSet[String]()

private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()

/**
@@ -305,35 +312,82 @@
*/
override def next(): (BlockId, InputStream) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
val result = currentResult
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)

result match {
case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) =>
if (address != blockManager.blockManagerId) {
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
}
bytesInFlight -= size
if (isNetworkReqDone) {
reqsInFlight -= 1
logDebug("Number of requests in flight " + reqsInFlight)
}
case _ =>
var result: FetchResult = null
Review comment (Contributor): add documentation explaining what's going on here.

Review comment (Contributor): btw is there a way to refactor this function so it is testable? I do worry some of the logic here won't be tested at all.

var input: InputStream = null
while (result == null) {
val startFetchWait = System.currentTimeMillis()
result = results.take()
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)

result match {
case SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
if (address != blockManager.blockManagerId) {
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
}
bytesInFlight -= size
if (isNetworkReqDone) {
reqsInFlight -= 1
logDebug("Number of requests in flight " + reqsInFlight)
}

val in = try {
buf.createInputStream()
} catch {
// The exception could only be thrown by a local shuffle block
case e: IOException if buf.isInstanceOf[FileSegmentManagedBuffer] =>
logError("Failed to create input stream from local block", e)
buf.release()
result = FailureFetchResult(blockId, address, e)
null
}
if (in != null) {
input = streamWrapper(blockId, in)
// Only copy the stream if it's wrapped by compression or encryption, and the size of the
// block is small (the decompressed block is smaller than maxBytesInFlight).
if (!input.eq(in) && size < maxBytesInFlight / 3) {
val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
try {
// Decompress the whole block at once to detect any corruption, which could increase
// the memory usage and potentially increase the chance of OOM.
// TODO: manage the memory used here, and spill it into disk in case of OOM.
Utils.copyStream(input, out)
input = out.toChunkedByteBuffer.toInputStream(true)
} catch {
case e: IOException =>
buf.release()
if (buf.isInstanceOf[FileSegmentManagedBuffer]
|| corruptedBlocks.contains(blockId.toString)) {
result = FailureFetchResult(blockId, address, e)
} else {
logWarning(s"got an corrupted block $blockId from $address, fetch again")
fetchRequests += FetchRequest(address, Array((blockId, size)))
result = null
}
} finally {
// TODO: release the buf here (earlier)
in.close()
}
}
}

case _ =>
}

// Send fetch requests up to maxBytesInFlight
fetchUpToMaxBytes()
}
// Send fetch requests up to maxBytesInFlight
fetchUpToMaxBytes()
currentResult = result

result match {
case FailureFetchResult(blockId, address, e) =>
throwFetchFailedException(blockId, address, e)

case SuccessFetchResult(blockId, address, _, buf, _) =>
try {
(result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
(result.blockId, new BufferReleasingInputStream(input, this))
} catch {
case NonFatal(t) =>
throwFetchFailedException(blockId, address, t)
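Regarding the two comments above (documentation and testability): one way to make the corruption check exercisable in isolation would be to factor out the "drain the wrapped stream eagerly and treat IOException as corruption" step. The sketch below is a self-contained illustration of that idea, not the refactor this PR performs; all names are hypothetical.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException, InputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

object EagerCorruptionCheckSketch extends App {

  // Drain the wrapped (decompressed/decrypted) stream into memory. A corrupt
  // block surfaces here as Left(IOException) instead of failing later, deep
  // inside record iteration.
  def readFully(blockId: String,
                raw: InputStream,
                wrapper: (String, InputStream) => InputStream): Either[IOException, Array[Byte]] = {
    val out = new ByteArrayOutputStream()
    val buf = new Array[Byte](8192)
    try {
      val wrapped = wrapper(blockId, raw)
      var n = wrapped.read(buf)
      while (n != -1) {
        out.write(buf, 0, n)
        n = wrapped.read(buf)
      }
      Right(out.toByteArray)
    } catch {
      case e: IOException => Left(e)
    } finally {
      raw.close()
    }
  }

  // Build a gzip "block", then truncate it to simulate corruption in transit.
  val bos = new ByteArrayOutputStream()
  val gz = new GZIPOutputStream(bos)
  gz.write(Array.fill[Byte](1024)(42.toByte))
  gz.close()
  val good = bos.toByteArray
  val corrupt = good.take(good.length / 2)

  val gunzip: (String, InputStream) => InputStream = (_, in) => new GZIPInputStream(in)

  println(readFully("shuffle_0_0_0", new ByteArrayInputStream(good), gunzip).isRight)   // true
  println(readFully("shuffle_0_0_0", new ByteArrayInputStream(corrupt), gunzip).isLeft) // true
}

In next() above, the same kind of failure leads either to a single re-fetch (remote block, first failure) or to a FailureFetchResult (local block, or a block already recorded in corruptedBlocks).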
ChunkedByteBuffer.scala
@@ -151,7 +151,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
* @param dispose if true, [[ChunkedByteBuffer.dispose()]] will be called at the end of the stream
* in order to close any memory-mapped files which back the buffer.
*/
private class ChunkedByteBufferInputStream(
private[spark] class ChunkedByteBufferInputStream(
var chunkedByteBuffer: ChunkedByteBuffer,
dispose: Boolean)
extends InputStream {
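For context, a minimal sketch of the round trip that brings the fetcher into contact with this class: the wrapped stream is drained into a ChunkedByteBufferOutputStream and re-exposed as an InputStream with dispose = true, so the backing buffers are freed once the stream is fully read. These utility classes are private[spark], so this fragment only compiles inside Spark's own source tree; copyToMemory is a hypothetical name.

import java.io.InputStream
import java.nio.ByteBuffer

import org.apache.spark.util.Utils
import org.apache.spark.util.io.ChunkedByteBufferOutputStream

// `wrappedIn` stands for the decompressed/decrypted stream produced by streamWrapper.
def copyToMemory(wrappedIn: InputStream): InputStream = {
  val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
  Utils.copyStream(wrappedIn, out, closeStreams = true)
  // dispose = true: the backing buffers are released when the stream is exhausted.
  out.toChunkedByteBuffer.toInputStream(dispose = true)
}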
ShuffleBlockFetcherIteratorSuite.scala
@@ -99,6 +99,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
transfer,
blockManager,
blocksByAddress,
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue)

@@ -172,6 +173,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
transfer,
blockManager,
blocksByAddress,
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue)

@@ -235,6 +237,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
transfer,
blockManager,
blocksByAddress,
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue)
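The added (_, in) => in arguments are identity wrappers, so these existing tests keep their old behavior. A test aimed at the new corruption path could instead pass a wrapper that damages the stream; below is a self-contained sketch of such a wrapper (the class name and limit are hypothetical, not helpers from this suite).

import java.io.{IOException, InputStream}

import org.apache.spark.storage.BlockId

// Delegates to `in` but fails after `limit` bytes, simulating a block that
// cannot be fully decompressed.
class CorruptingInputStream(in: InputStream, limit: Long) extends InputStream {
  private var count = 0L
  override def read(): Int = {
    if (count >= limit) throw new IOException("simulated corrupt block")
    count += 1
    in.read()
  }
  override def close(): Unit = in.close()
}

// Drop-in for the identity wrapper in the constructor calls above.
val corruptingWrapper: (BlockId, InputStream) => InputStream =
  (_, in) => new CorruptingInputStream(in, limit = 10)

Because the wrapper returns a different object than the raw stream (so !input.eq(in) holds) and the block is small relative to maxBytesInFlight, next() would eagerly copy it, hit the IOException, and exercise the retry-then-fail logic.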
