@@ -20,6 +20,7 @@ package org.apache.spark.storage
2020import java .util .concurrent .LinkedBlockingQueue
2121
2222import scala .collection .mutable .{ArrayBuffer , HashSet , Queue }
23+ import scala .util .{Failure , Success , Try }
2324
2425import org .apache .spark .{Logging , TaskContext }
2526import org .apache .spark .network .{BlockFetchingListener , BlockTransferService }
@@ -54,7 +55,7 @@ final class ShuffleBlockFetcherIterator(
5455 blocksByAddress : Seq [(BlockManagerId , Seq [(BlockId , Long )])],
5556 serializer : Serializer ,
5657 maxBytesInFlight : Long )
57- extends Iterator [(BlockId , Option [Iterator [Any ]])] with Logging {
58+ extends Iterator [(BlockId , Try [Iterator [Any ]])] with Logging {
5859
5960 import ShuffleBlockFetcherIterator ._
6061
@@ -117,16 +118,18 @@ final class ShuffleBlockFetcherIterator(
117118 private [this ] def cleanup () {
118119 isZombie = true
119120 // Release the current buffer if necessary
120- if (currentResult != null && ! currentResult.failed) {
121- currentResult.buf.release()
121+ currentResult match {
122+ case SuccessFetchResult (_, _, buf) => buf.release()
123+ case _ =>
122124 }
123125
124126 // Release buffers in the results queue
125127 val iter = results.iterator()
126128 while (iter.hasNext) {
127129 val result = iter.next()
128- if (! result.failed) {
129- result.buf.release()
130+ result match {
131+ case SuccessFetchResult (_, _, buf) => buf.release()
132+ case _ =>
130133 }
131134 }
132135 }
@@ -149,7 +152,7 @@ final class ShuffleBlockFetcherIterator(
149152 // Increment the ref count because we need to pass this to a different thread.
150153 // This needs to be released after use.
151154 buf.retain()
152- results.put(new FetchResult (BlockId (blockId), sizeMap(blockId), buf))
155+ results.put(new SuccessFetchResult (BlockId (blockId), sizeMap(blockId), buf))
153156 shuffleMetrics.remoteBytesRead += buf.size
154157 shuffleMetrics.remoteBlocksFetched += 1
155158 }
@@ -158,7 +161,7 @@ final class ShuffleBlockFetcherIterator(
158161
159162 override def onBlockFetchFailure (blockId : String , e : Throwable ): Unit = {
160163 logError(s " Failed to get block(s) from ${req.address.host}: ${req.address.port}" , e)
161- results.put(new FetchResult (BlockId (blockId), - 1 , null ))
164+ results.put(new FailureFetchResult (BlockId (blockId), e ))
162165 }
163166 }
164167 )
@@ -229,12 +232,12 @@ final class ShuffleBlockFetcherIterator(
229232 val buf = blockManager.getBlockData(blockId)
230233 shuffleMetrics.localBlocksFetched += 1
231234 buf.retain()
232- results.put(new FetchResult (blockId, 0 , buf))
235+ results.put(new SuccessFetchResult (blockId, 0 , buf))
233236 } catch {
234237 case e : Exception =>
235238 // If we see an exception, stop immediately.
236239 logError(s " Error occurred while fetching local blocks " , e)
237- results.put(new FetchResult (blockId, - 1 , null ))
240+ results.put(new FailureFetchResult (blockId, e ))
238241 return
239242 }
240243 }
@@ -265,36 +268,39 @@ final class ShuffleBlockFetcherIterator(
265268
266269 override def hasNext : Boolean = numBlocksProcessed < numBlocksToFetch
267270
268- override def next (): (BlockId , Option [Iterator [Any ]]) = {
271+ override def next (): (BlockId , Try [Iterator [Any ]]) = {
269272 numBlocksProcessed += 1
270273 val startFetchWait = System .currentTimeMillis()
271274 currentResult = results.take()
272275 val result = currentResult
273276 val stopFetchWait = System .currentTimeMillis()
274277 shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
275- if (! result.failed) {
276- bytesInFlight -= result.size
278+
279+ result match {
280+ case SuccessFetchResult (_, size, _) => bytesInFlight -= size
281+ case _ =>
277282 }
278283 // Send fetch requests up to maxBytesInFlight
279284 while (fetchRequests.nonEmpty &&
280285 (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
281286 sendRequest(fetchRequests.dequeue())
282287 }
283288
284- val iteratorOpt : Option [Iterator [Any ]] = if (result.failed) {
285- None
286- } else {
287- val is = blockManager.wrapForCompression(result.blockId, result.buf.createInputStream())
288- val iter = serializer.newInstance().deserializeStream(is).asIterator
289- Some (CompletionIterator [Any , Iterator [Any ]](iter, {
290- // Once the iterator is exhausted, release the buffer and set currentResult to null
291- // so we don't release it again in cleanup.
292- currentResult = null
293- result.buf.release()
294- }))
289+ val iteratorTry : Try [Iterator [Any ]] = result match {
290+ case FailureFetchResult (_, e) => Failure (e)
291+ case SuccessFetchResult (blockId, _, buf) => {
292+ val is = blockManager.wrapForCompression(blockId, buf.createInputStream())
293+ val iter = serializer.newInstance().deserializeStream(is).asIterator
294+ Success (CompletionIterator [Any , Iterator [Any ]](iter, {
295+ // Once the iterator is exhausted, release the buffer and set currentResult to null
296+ // so we don't release it again in cleanup.
297+ currentResult = null
298+ buf.release()
299+ }))
300+ }
295301 }
296302
297- (result.blockId, iteratorOpt )
303+ (result.blockId, iteratorTry )
298304 }
299305}
300306
@@ -313,14 +319,30 @@ object ShuffleBlockFetcherIterator {
313319 }
314320
315321 /**
316- * Result of a fetch from a remote block. A failure is represented as size == -1.
322+ * Result of a fetch from a remote block.
323+ */
324+ trait FetchResult {
325+ val blockId : BlockId
326+ }
327+
328+ /**
329+ * Result of a fetch from a remote block successfully.
317330 * @param blockId block id
318331 * @param size estimated size of the block, used to calculate bytesInFlight.
319- * Note that this is NOT the exact bytes. -1 if failure is present.
320- * @param buf [[ManagedBuffer ]] for the content. null is error.
332+ * Note that this is NOT the exact bytes.
333+ * @param buf [[ManagedBuffer ]] for the content.
334+ */
335+ case class SuccessFetchResult (blockId : BlockId , size : Long , buf : ManagedBuffer )
336+ extends FetchResult {
337+ require(buf != null )
338+ require(size >= 0 )
339+ }
340+
341+ /**
342+ * Result of a fetch from a remote block unsuccessfully.
343+ * @param blockId block id
344+ * @param e the failure exception
321345 */
322- case class FetchResult (blockId : BlockId , size : Long , buf : ManagedBuffer ) {
323- def failed : Boolean = size == - 1
324- if (failed) assert(buf == null ) else assert(buf != null )
346+ case class FailureFetchResult (blockId : BlockId , e : Throwable ) extends FetchResult {
325347 }
326348}
0 commit comments