Merged
Commits
58 commits
864d1cd
initial API
yifeih Mar 20, 2019
c88751c
wip
yifeih Mar 21, 2019
9af216f
wip
yifeih Mar 22, 2019
a35b826
initial implementation of reader
yifeih Mar 26, 2019
14c47ae
fix based on comments
yifeih Mar 26, 2019
5bb4c32
fix java lang import and delete unneeded class
yifeih Mar 26, 2019
584e6c8
address initial comments
yifeih Mar 27, 2019
0292fe2
fix unit tests
yifeih Mar 27, 2019
71c2cc7
java checkstyle
yifeih Mar 27, 2019
43c377c
fix tests
yifeih Mar 27, 2019
9fc6a60
address some comments
yifeih Mar 28, 2019
45172a5
blah
yifeih Mar 28, 2019
4e5652b
address more comments
yifeih Mar 28, 2019
a35d8fe
Use decorators to customize how the read metrics reporter is instanti…
mccheah Apr 1, 2019
1a09ebe
blah
yifeih Apr 2, 2019
c149d24
initial tests
yifeih Apr 2, 2019
672d473
Revert "initial tests"
yifeih Apr 3, 2019
e0a3289
initial impl
yifeih Apr 3, 2019
1e89b3f
get shuffle reader tests to pass
yifeih Apr 3, 2019
495c7bd
update
yifeih Apr 3, 2019
88a03cb
tests
yifeih Apr 3, 2019
741deed
style
yifeih Apr 3, 2019
76c0381
Merge branch 'spark-25299' into yh/reader-api
yifeih Apr 3, 2019
c7c52b0
hook up executor components
yifeih Apr 4, 2019
897c0bf
fix compile
yifeih Apr 4, 2019
34eaaf6
remove unnecessary fields
yifeih Apr 4, 2019
0548800
remove unused
yifeih Apr 4, 2019
0637e70
refactor retrying iterator
yifeih Apr 4, 2019
f069dc1
remove unused import
yifeih Apr 4, 2019
0bba677
fix some comments
yifeih Apr 5, 2019
a82a725
null check
yifeih Apr 5, 2019
ac392a1
refactor interface
yifeih Apr 5, 2019
53dd94b
refactor API
yifeih Apr 5, 2019
4c0c791
shuffle iterator style
yifeih Apr 5, 2019
84f7931
add some javadocs for interfaces
yifeih Apr 5, 2019
b59efb5
attach apache headers
yifeih Apr 5, 2019
aba8a94
remove unused imports
yifeih Apr 5, 2019
5ef59b6
remove another import
yifeih Apr 5, 2019
49a1901
fix reader
yifeih Apr 5, 2019
8c6c09c
fix imports
yifeih Apr 5, 2019
6370b41
add exception comment for retry API
yifeih Apr 10, 2019
c442b63
address some comments
yifeih Apr 10, 2019
2c1272a
address comments
yifeih Apr 10, 2019
2758a5c
Merge branch 'spark-25299' into yh/reader-api
yifeih Apr 19, 2019
bd349ca
resolve conflicts
yifeih Apr 19, 2019
653f67c
style
yifeih Apr 19, 2019
9f53839
address some comments
yifeih Apr 19, 2019
94275fd
style
yifeih Apr 20, 2019
26e97c1
refactor API
yifeih Apr 20, 2019
91db776
cleanup
yifeih Apr 20, 2019
f0fa7b8
fix tests and style
yifeih Apr 22, 2019
50c8fc3
style
yifeih Apr 22, 2019
4aa4b6e
reorder result for test?
yifeih Apr 22, 2019
7d23f47
wip
yifeih Apr 26, 2019
363d4ab
address comments
yifeih Apr 29, 2019
bb7fa4c
style
yifeih Apr 29, 2019
711109b
cleanup tests
yifeih Apr 29, 2019
04a135c
Remove unused class
mccheah Apr 30, 2019
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.shuffle;

public final class ShuffleBlockInfo {
Nit: Spark doesn't usually put final modifiers.

private final int shuffleId;
private final int mapId;
private final int reduceId;
private final long length;

public ShuffleBlockInfo(int shuffleId, int mapId, int reduceId, long length) {
this.shuffleId = shuffleId;
this.mapId = mapId;
this.reduceId = reduceId;
this.length = length;
}

public int getShuffleId() {
return shuffleId;
}

public int getMapId() {
return mapId;
}

public int getReduceId() {
return reduceId;
}

public long getLength() {
return length;
}

}
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.shuffle;

import java.io.IOException;
import java.io.InputStream;

/**
* :: Experimental ::
* An interface for reading shuffle records
*/
public interface ShuffleReadSupport {
Iterable<InputStream> getPartitionReaders(Iterable<ShuffleBlockInfo> blockMetadata) throws IOException;
}
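
To make the plugin point concrete, the following is a minimal, hypothetical sketch of a custom ShuffleReadSupport implementation written against this interface. The class name, base directory, and one-file-per-block naming scheme are assumptions for illustration only and are not part of this PR.

import java.io.{File, FileInputStream, InputStream}
import java.lang.{Iterable => JIterable}

import scala.collection.JavaConverters._

import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}

// Hypothetical plugin: serves each requested block from a local file keyed by
// (shuffleId, mapId, reduceId). The file layout is assumed for the example.
class LocalFileShuffleReadSupport(baseDir: File) extends ShuffleReadSupport {
  override def getPartitionReaders(
      blockMetadata: JIterable[ShuffleBlockInfo]): JIterable[InputStream] = {
    blockMetadata.asScala.map { block =>
      val file = new File(baseDir,
        s"shuffle_${block.getShuffleId}_${block.getMapId}_${block.getReduceId}.data")
      new FileInputStream(file): InputStream
    }.asJava
  }
}

A real implementation would likely also need to handle fetch failures and stream decompression/decryption, which the default path does via serializerManager.wrapStream.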
@@ -17,10 +17,14 @@

package org.apache.spark.shuffle

import java.util.Optional

import scala.collection.JavaConverters._

import org.apache.spark._
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
import org.apache.spark.internal.Logging
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

@@ -34,33 +38,31 @@ private[spark] class BlockStoreShuffleReader[K, C](
endPartition: Int,
context: TaskContext,
readMetrics: ShuffleReadMetricsReporter,
serializerManager: SerializerManager = SparkEnv.get.serializerManager,
blockManager: BlockManager = SparkEnv.get.blockManager,
shuffleReadSupport: ShuffleReadSupport,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
extends ShuffleReader[K, C] with Logging {

private val dep = handle.dependency

/** Read the combined key-values for this reduce task */
val blocksIterator =
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
.flatMap(blockManagerIdInfo => {
blockManagerIdInfo._2.map(
blockInfo => {
val block = blockInfo._1.asInstanceOf[ShuffleBlockId]
new ShuffleBlockInfo(block.shuffleId, block.mapId, block.reduceId, blockInfo._2)
}
)
})
override def read(): Iterator[Product2[K, C]] = {
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
readMetrics).toCompletionIterator
val wrappedStreams =
nit: the name is misleading now since these streams are not really wrapped (you moved the wrapStream call to later).

shuffleReadSupport.getPartitionReaders(blocksIterator.toIterable.asJava).asScala

val serializerInstance = dep.serializer.newInstance()

// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
val recordIter = wrappedStreams.flatMap { case wrappedStream =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
@@ -72,7 +74,7 @@
recordIter.map { record =>
readMetrics.incRecordsRead(1)
record
},
}.toIterator,
context.taskMetrics().mergeShuffleReadMetrics())

// An interruptible iterator must be used here in order to support task cancellation
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.io

import java.io.InputStream

import scala.collection.JavaConverters._

import org.apache.spark.{MapOutputTracker, SparkEnv, TaskContext}
import org.apache.spark.api.shuffle.{ShuffleBlockInfo, ShuffleReadSupport}
import org.apache.spark.internal.config
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator, ShuffleBlockId}

class DefaultShuffleReadSupport(
blockManager: BlockManager,
serializerManager: SerializerManager,
mapOutputTracker: MapOutputTracker) extends ShuffleReadSupport {

val maxBytesInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024
Should these be private?

Author: hmmm yes xD

val maxReqsInFlight = SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT)
val maxBlocksInFlightPerAddress =
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
val maxReqSizeShuffleToMem = SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
val detectCorrupt = SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT)

override def getPartitionReaders(
blockMetadata: java.lang.Iterable[ShuffleBlockInfo]): java.lang.Iterable[InputStream] = {

val minReduceId = blockMetadata.asScala.map(block => block.getReduceId).min
val maxReduceId = blockMetadata.asScala.map(block => block.getReduceId).max
val shuffleId = blockMetadata.asScala.head.getShuffleId

val shuffleBlockFetchIterator = new ShuffleBlockFetcherIterator(
TaskContext.get(),
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(shuffleId, minReduceId, maxReduceId + 1),
This was already called in BlockStoreShuffleReader; you shouldn't need to call it again. Shouldn't the API just change so it's getPartitionReaders(Iterator[(BlockManagerId, Seq[(BlockId, Long)])]), as returned by getMapSizesByExecutorId?

Also @mccheah -- I don't think this should be called ...byExecutorId anymore, it should be ...byShuffleStoreId or something along those lines, as the executorId which originally wrote the data is no longer meaningful. I think we discussed this on another PR, not sure whether that is still pending or I am forgetting something?

btw I want to make sure this doesn't get lost in the discussion about failure handling -- the point here of organizing by shuffle store id is actually for something else, it's for efficiency. Spark bundles requests to one host together to avoid having to make so many requests:

https://github.com/apache/spark/blob/9ed60c2c33737d4017ab8fb2628c40f8b14f3c5c/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala#L303-L331

But here the reader will be making requests to the shuffle store, not to the executor which generated the data -- so we want things to be grouped by shuffle store node id.

This code is the plugin implementation, so this is reading from executors or shuffle services directly, not from an arbitrary storage location. This should be a replication of what is currently done, right?

The overridable plugin is called by BlockStoreShuffleReader. By the time we're in this code path, we know we're reading using the pre-existing code paths.

sorry I mixed up two different ideas in this comment thread so now this discussion might be confused ... My second comment was a follow-on to renaming getMapSizesByExecutorId to getMapSizesByShuffleStoreId. Yes that renaming is only useful for other cases, not this implementation -- but because it's a shared API, it would be reflected here.

mccheah Apr 12, 2019
The notion of locations is handled in #517. This code path without PR 517 just asks the map output tracker directly for the locations. Eventually the BlockStoreShuffleReader will be pushing down the locations to this layer instead.

ah ok, yeah I brought this up here: #517 (comment) and I guess you addressed that already.

OK, so the naming part is handled -- but the other part I mentioned in the first comment is still here:

This was already called in BlockStoreShuffleReader; you shouldn't need to call it again. Shouldn't the API just change so it's getPartitionReaders(Iterator[(BlockManagerId, Seq[(BlockId, Long)])]), as returned by getMapSizesByExecutorId?

serializerManager.wrapStream,
maxBytesInFlight,
maxReqsInFlight,
maxBlocksInFlightPerAddress,
maxReqSizeShuffleToMem,
detectCorrupt,
shuffleMetrics = TaskContext.get().taskMetrics().createTempShuffleReadMetrics()
).toCompletionIterator

new ShuffleBlockInputStreamIterator(shuffleBlockFetchIterator).toIterable.asJava
}

private class ShuffleBlockInputStreamIterator(
Do we need the explicit iterator subclass? I'd think we can just call blockFetchIterator.map(_._2).asJava.

Author: Heh scala magic that I am still learning 🙃

blockFetchIterator: Iterator[(BlockId, InputStream)])
extends Iterator[InputStream] {
override def hasNext: Boolean = blockFetchIterator.hasNext

override def next(): InputStream = {
blockFetchIterator.next()._2
}
}
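
As a minimal sketch of the simplification suggested in the comment above, assuming the fetch iterator yields (BlockId, InputStream) pairs, a plain map that drops the BlockId is enough; the object and method names here are illustrative only.

import java.io.InputStream

import scala.collection.JavaConverters._

import org.apache.spark.storage.BlockId

// Illustrative only: keeps just the InputStream half of each (BlockId, InputStream)
// pair, which is all the explicit Iterator subclass above does today.
object ShuffleStreamConversions {
  def streamsOnly(
      blockFetchIterator: Iterator[(BlockId, InputStream)]): java.lang.Iterable[InputStream] = {
    blockFetchIterator.map(_._2).toIterable.asJava
  }
}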

private[spark] object DefaultShuffleReadSupport {
def toShuffleBlockInfo(blockId: BlockId, length: Long): ShuffleBlockInfo = {
assert(blockId.isInstanceOf[ShuffleBlockId])
val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
new ShuffleBlockInfo(
shuffleBlockId.shuffleId,
shuffleBlockId.mapId,
shuffleBlockId.reduceId,
length)
}
}
}
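
For reference, here is a minimal sketch of the alternative API shape raised in the review thread above, in which the reader passes down the map output information it has already fetched rather than having the plugin query the MapOutputTracker a second time. The trait name is illustrative; this is a discussion aid, not the shape adopted in this PR.

import java.io.InputStream
import java.lang.{Iterable => JIterable}

import org.apache.spark.storage.{BlockId, BlockManagerId}

// Sketch of the suggested signature: blocks arrive already grouped by the location
// that serves them, as returned by MapOutputTracker.getMapSizesByExecutorId.
trait ShuffleReadSupportByAddress {
  def getPartitionReaders(
      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): JIterable[InputStream]
}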
@@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentHashMap
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle._
import org.apache.spark.shuffle.io.DefaultShuffleReadSupport

/**
* In sort-based shuffle, incoming records are sorted according to their target partition ids, then
@@ -116,9 +117,14 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
// TODO: remove this from here once ShuffleExecutorComponents is introduced
val readSupport = new DefaultShuffleReadSupport(
blockManager = SparkEnv.get.blockManager,
serializerManager = SparkEnv.get.serializerManager,
mapOutputTracker = SparkEnv.get.mapOutputTracker)
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
startPartition, endPartition, context, metrics)
startPartition, endPartition, context, metrics, readSupport)
}

/** Get a writer for a given partition. Called on executors by map tasks. */
@@ -20,13 +20,16 @@ package org.apache.spark.shuffle
import java.io.{ByteArrayOutputStream, InputStream}
import java.nio.ByteBuffer

import org.mockito.Mockito.{mock, when}
import org.mockito.Mockito.{doReturn, mock, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.{Answer, Stubber}

import org.apache.spark._
import org.apache.spark.internal.config
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
import org.apache.spark.shuffle.io.DefaultShuffleReadSupport
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockId}

/**
* Wrapper for a managed buffer that keeps track of how many times retain and release are called.
@@ -101,16 +104,19 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext

// Make a mocked MapOutputTracker for the shuffle reader to use to determine what
// shuffle data to read.
val mapOutputTracker = mock(classOf[MapOutputTracker])
when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn {
// Test a scenario where all data is local, to avoid creating a bunch of additional mocks
// for the code to read data over the network.
val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
(shuffleBlockId, byteOutputStream.size().toLong)
}
Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator
val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
(shuffleBlockId, byteOutputStream.size().toLong)
}
val blocksToRetrieve = Seq((localBlockManagerId, shuffleBlockIdsAndSizes))
val mapOutputTracker = mock(classOf[MapOutputTracker])
when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1))
.thenAnswer(new Answer[Iterator[(BlockManagerId, Seq[(BlockId, Long)])]] {
def answer(invocationOnMock: InvocationOnMock):
Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
4-space indent this from def

blocksToRetrieve.iterator
}
})

// Create a mocked shuffle handle to pass into HashShuffleReader.
val shuffleHandle = {
@@ -128,15 +134,18 @@
.set(config.SHUFFLE_SPILL_COMPRESS, false))

val taskContext = TaskContext.empty()
TaskContext.setTaskContext(taskContext)
val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()

val shuffleReadSupport =
new DefaultShuffleReadSupport(blockManager, serializerManager, mapOutputTracker)
val shuffleReader = new BlockStoreShuffleReader(
shuffleHandle,
reduceId,
reduceId + 1,
taskContext,
metrics,
serializerManager,
blockManager,
shuffleReadSupport,
mapOutputTracker)

assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)