[SPARK-25299] shuffle reader API #523
New file: org/apache/spark/api/shuffle/ShuffleBlockInfo.java (+49 lines)

```java
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.api.shuffle;

public final class ShuffleBlockInfo {
  private final int shuffleId;
  private final int mapId;
  private final int reduceId;
  private final long length;

  public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length) {
    this.shuffleId = shuffleId;
    this.mapId = mapId;
    this.reduceId = reduceId;
    this.length = length;
  }

  public int getShuffleId() {
    return shuffleId;
  }

  public int getMapId() {
    return mapId;
  }

  public int getReduceId() {
    return reduceId;
  }

  public long getLength() {
    return length;
  }
}
```
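For orientation, here is a small hypothetical usage sketch (the coordinates and length are made up) showing how a block described by Spark's internal `ShuffleBlockId` maps onto this new API type; the `DefaultShuffleReadSupport.toShuffleBlockInfo` helper further down in this PR performs the same conversion:

```scala
import org.apache.spark.api.shuffle.ShuffleBlockInfo
import org.apache.spark.storage.ShuffleBlockId

// Made-up coordinates: shuffle 0, map task 3, reduce partition 7, 1 KiB of data.
val blockId = ShuffleBlockId(shuffleId = 0, mapId = 3, reduceId = 7)
val info = new ShuffleBlockInfo(blockId.shuffleId, blockId.mapId, blockId.reduceId, 1024L)

assert(info.getReduceId == 7)
assert(info.getLength == 1024L)
```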
New file: org/apache/spark/api/shuffle/ShuffleReadSupport.java (+29 lines)

```java
/* (Apache License 2.0 header, identical to the one above.) */

package org.apache.spark.api.shuffle;

import java.io.IOException;
import java.io.InputStream;

/**
 * :: Experimental ::
 * An interface for reading shuffle records.
 */
public interface ShuffleReadSupport {
  Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata)
      throws IOException;
}
```
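The contract is deliberately small: given metadata for the blocks a reduce task needs, return one `InputStream` per block. As a sketch only (not part of this PR; the class name, file layout, and `rootDir` parameter are invented), a plugin that serves every block from local disk could look like this:

```scala
import java.io.{FileInputStream, InputStream}
import java.nio.file.Paths

import scala.collection.JavaConverters._

import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}

// Hypothetical plugin: resolves each requested block to a file keyed by its
// (shuffleId, mapId, reduceId) coordinates under a fixed root directory.
class LocalFileShuffleReadSupport(rootDir: String) extends ShuffleReadSupport {

  override def getPartitionReaders(
      blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = {
    blockMetadata.asScala.map { block =>
      val file = Paths.get(rootDir,
        s"shuffle_${block.getShuffleId}_${block.getMapId}_${block.getReduceId}.data").toFile
      // Ascribe to InputStream so the resulting Iterable has the interface's element type.
      new FileInputStream(file): InputStream
    }.asJava
  }
}
```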
Modified: org/apache/spark/shuffle/BlockStoreShuffleReader.scala

```diff
@@ -17,10 +17,14 @@
 
 package org.apache.spark.shuffle
 
+import java.util.Optional
+
+import scala.collection.JavaConverters._
+
 import org.apache.spark._
-import org.apache.spark.internal.{config, Logging}
-import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
+import org.apache.spark.internal.Logging
+import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
@@ -34,33 +38,31 @@ private[spark] class BlockStoreShuffleReader[K, C](
     endPartition: Int,
     context: TaskContext,
     readMetrics: ShuffleReadMetricsReporter,
-    serializerManager: SerializerManager = SparkEnv.get.serializerManager,
-    blockManager: BlockManager = SparkEnv.get.blockManager,
+    shuffleReadSupport: ShuffleReadSupport,
     mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C] with Logging {
 
   private val dep = handle.dependency
 
   /** Read the combined key-values for this reduce task */
+  val blocksIterator =
+    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
+      .flatMap(blockManagerIdInfo => {
+        blockManagerIdInfo._2.map(
+          blockInfo => {
+            val block = blockInfo._1.asInstanceOf[ShuffleBlockId]
+            new ShuffleBlockInfo(block.shuffleId, block.mapId, block.reduceId, blockInfo._2)
+          }
+        )
+      })
   override def read(): Iterator[Product2[K, C]] = {
-    val wrappedStreams = new ShuffleBlockFetcherIterator(
-      context,
-      blockManager.shuffleClient,
-      blockManager,
-      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
-      serializerManager.wrapStream,
-      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
-      SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
-      SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
-      SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
-      SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
-      SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
-      readMetrics).toCompletionIterator
+    val wrappedStreams =
+      shuffleReadSupport.getPartitionReaders(blocksIterator.toIterable.asJava).asScala
 
     val serializerInstance = dep.serializer.newInstance()
 
     // Create a key/value iterator for each stream
-    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
+    val recordIter = wrappedStreams.flatMap { case wrappedStream =>
       // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
       // NextIterator. The NextIterator makes sure that close() is called on the
       // underlying InputStream when all records have been read.
@@ -72,7 +74,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
       recordIter.map { record =>
         readMetrics.incRecordsRead(1)
         record
-      },
+      }.toIterator,
       context.taskMetrics().mergeShuffleReadMetrics())
 
     // An interruptible iterator must be used here in order to support task cancellation
```
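The net effect is that `BlockStoreShuffleReader` no longer talks to `BlockManager` directly: it translates map-output metadata into `ShuffleBlockInfo`s and asks the injected `ShuffleReadSupport` for streams. A rough sketch of how a `ShuffleManager.getReader` override might construct the reader after this change (the method signature and the `asInstanceOf` cast mirror Spark's `SortShuffleManager`, but the exact code in this fork may differ):

```scala
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.shuffle.{BaseShuffleHandle, BlockStoreShuffleReader, ShuffleHandle, ShuffleReader}
import org.apache.spark.shuffle.io.DefaultShuffleReadSupport

// Hypothetical wiring inside a ShuffleManager.getReader override.
def getReader[K, C](
    handle: ShuffleHandle,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext): ShuffleReader[K, C] = {
  val env = SparkEnv.get
  val readSupport = new DefaultShuffleReadSupport(
    env.blockManager, env.serializerManager, env.mapOutputTracker)
  new BlockStoreShuffleReader[K, C](
    handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
    startPartition,
    endPartition,
    context,
    context.taskMetrics().createTempShuffleReadMetrics(),
    readSupport) // mapOutputTracker keeps its SparkEnv.get default
}
```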
New file: org/apache/spark/shuffle/io/DefaultShuffleReadSupport.scala (+87 lines)

```scala
/* (Apache License 2.0 header, identical to the one above.) */

package org.apache.spark.shuffle.io

import java.io.InputStream

import scala.collection.JavaConverters._

import org.apache.spark.{MapOutputTracker, SparkEnv, TaskContext}
import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
import org.apache.spark.internal.config
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator, ShuffleBlockId}

class DefaultShuffleReadSupport(
    blockManager: BlockManager,
    serializerManager: SerializerManager,
    mapOutputTracker: MapOutputTracker) extends ShuffleReadSupport {

  val maxBytesInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024
  val maxReqsInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT)
  val maxBlocksInFlightPerAddress =
    SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
  val maxReqSizeShuffleToMem = SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
  val detectCorrupt = SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT)

  override def getPartitionReaders(
      blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = {

    val minReduceId = blockMetadata.asScala.map(block => block.getReduceId).min
    val maxReduceId = blockMetadata.asScala.map(block => block.getReduceId).max
    val shuffleId = blockMetadata.asScala.head.getShuffleId

    val shuffleBlockFetchIterator = new ShuffleBlockFetcherIterator(
      TaskContext.get(),
      blockManager.shuffleClient,
      blockManager,
      mapOutputTracker.getMapSizesByExecutorId(shuffleId, minReduceId, maxReduceId + 1),
      serializerManager.wrapStream,
      maxBytesInFlight,
      maxReqsInFlight,
      maxBlocksInFlightPerAddress,
      maxReqSizeShuffleToMem,
      detectCorrupt,
      shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics()
    ).toCompletionIterator

    new ShuffleBlockInputStreamIterator(shuffleBlockFetchIterator).toIterable.asJava
  }

  private class ShuffleBlockInputStreamIterator(
      blockFetchIterator: Iterator[(BlockId, InputStream)])
    extends Iterator[InputStream] {

    override def hasNext: Boolean = blockFetchIterator.hasNext

    override def next(): InputStream = {
      blockFetchIterator.next()._2
    }
  }

  private[spark] object DefaultShuffleReadSupport {
    def toShuffleBlockInfo(blockId: BlockId, length: Long): ShuffleBlockInfo = {
      assert(blockId.isInstanceOf[ShuffleBlockId])
      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
      new ShuffleBlockInfo(
        shuffleBlockId.shuffleId,
        shuffleBlockId.mapId,
        shuffleBlockId.reduceId,
        length)
    }
  }
}
```
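One detail worth noting: the default implementation does not fetch block by block. It takes the minimum and maximum reduce IDs across all requested blocks and issues a single ranged `getMapSizesByExecutorId` call. A small worked example with made-up metadata:

```scala
import org.apache.spark.api.shuffle.ShuffleBlockInfo

// Two blocks from shuffle 0, landing in reduce partitions 2 and 4.
val blocks = Seq(
  new ShuffleBlockInfo(0, 0, 2, 100L),
  new ShuffleBlockInfo(0, 1, 4, 100L))

val minReduceId = blocks.map(_.getReduceId).min // 2
val maxReduceId = blocks.map(_.getReduceId).max // 4

// The fetch range passed to getMapSizesByExecutorId is [2, 5): the end bound
// is exclusive, hence the maxReduceId + 1 in the code above.
assert((minReduceId, maxReduceId + 1) == (2, 5))
```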
Modified: org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala

```diff
@@ -20,13 +20,16 @@ package org.apache.spark.shuffle
 import java.io.{ByteArrayOutputStream, InputStream}
 import java.nio.ByteBuffer
 
-import org.mockito.Mockito.{mock, when}
+import org.mockito.Mockito.{doReturn, mock, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.{Answer, Stubber}
 
 import org.apache.spark._
 import org.apache.spark.internal.config
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
-import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.shuffle.io.DefaultShuffleReadSupport
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockId}
 
 /**
  * Wrapper for a managed buffer that keeps track of how many times retain and release are called.
@@ -101,16 +104,19 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
 
     // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
     // shuffle data to read.
-    val mapOutputTracker = mock(classOf[MapOutputTracker])
-    when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn {
-      // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
-      // for the code to read data over the network.
-      val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
-        val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
-        (shuffleBlockId, byteOutputStream.size().toLong)
-      }
-      Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator
-    }
+    val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
+      val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
+      (shuffleBlockId, byteOutputStream.size().toLong)
+    }
+    val blocksToRetrieve = Seq((localBlockManagerId, shuffleBlockIdsAndSizes))
+    val mapOutputTracker = mock(classOf[MapOutputTracker])
+    when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1))
+      .thenAnswer(new Answer[Iterator[(BlockManagerId, Seq[(BlockId, Long)])]] {
+        def answer(invocationOnMock: InvocationOnMock):
+            Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+          blocksToRetrieve.iterator
+        }
+      })
 
     // Create a mocked shuffle handle to pass into HashShuffleReader.
     val shuffleHandle = {
@@ -128,15 +134,18 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
       .set(config.SHUFFLE_SPILL_COMPRESS, false))
 
     val taskContext = TaskContext.empty()
+    TaskContext.setTaskContext(taskContext)
     val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
 
+    val shuffleReadSupport =
+      new DefaultShuffleReadSupport(blockManager, serializerManager, mapOutputTracker)
     val shuffleReader = new BlockStoreShuffleReader(
       shuffleHandle,
       reduceId,
       reduceId + 1,
       taskContext,
       metrics,
-      serializerManager,
-      blockManager,
+      shuffleReadSupport,
       mapOutputTracker)
 
     assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)
```
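The switch from `thenReturn` to `thenAnswer` in the mock matters because the stub hands back an `Iterator`, which is stateful: `thenReturn` evaluates its argument once at stubbing time, so every call would share the same, soon-exhausted iterator, whereas `thenAnswer` rebuilds it on each invocation (plausibly needed now that the tracker is queried by both the reader and `DefaultShuffleReadSupport`). Condensed illustration of the pattern, reusing the names from the diff above:

```scala
// thenReturn would compute Seq(...).toIterator once and reuse it, so a second
// caller would see an empty iterator. thenAnswer re-evaluates per call:
when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1))
  .thenAnswer(new Answer[Iterator[(BlockManagerId, Seq[(BlockId, Long)])]] {
    override def answer(invocation: InvocationOnMock):
        Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = blocksToRetrieve.iterator
  })
```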
Review comment (on `public final class ShuffleBlockInfo`): Nit: Spark doesn't usually put `final` modifiers.