diff --git a/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java
new file mode 100644
index 0000000000000..b0aed4d08d387
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.api.shuffle;
+
+import org.apache.spark.annotation.Experimental;
+
+import java.io.Serializable;
+
+/**
+ * Represents metadata about where shuffle blocks were written in a single map task.
+ *
+ * This is optionally returned by shuffle writers. The inner shuffle locations may
+ * be accessed by shuffle readers. Shuffle locations are only necessary when the
+ * location of shuffle blocks needs to be managed by the driver; shuffle plugins
+ * may choose to use an external database or other metadata management systems to
+ * track the locations of shuffle blocks instead.
+ */
+@Experimental
+public interface MapShuffleLocations extends Serializable {
+
+ /**
+ * Get the location for a given shuffle block written by this map task.
+ */
+ ShuffleLocation getLocationForBlock(int reduceId);
+}
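For illustration only, a hedged sketch of how a shuffle plugin might implement this interface. The class names and the single-remote-host layout below are assumptions for the example, not part of this change:

```java
// Hypothetical sketch, not part of this patch: a plugin whose map output for a
// given task is all served from one remote host.
import java.io.Serializable;

import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.ShuffleLocation;

public final class RemoteHostMapShuffleLocations implements MapShuffleLocations {

  // Plugin-specific ShuffleLocation; Serializable so it can travel inside a MapStatus.
  public static final class HostPortShuffleLocation implements ShuffleLocation, Serializable {
    private final String host;
    private final int port;

    public HostPortShuffleLocation(String host, int port) {
      this.host = host;
      this.port = port;
    }

    public String host() { return host; }
    public int port() { return port; }
  }

  private final HostPortShuffleLocation location;

  public RemoteHostMapShuffleLocations(String host, int port) {
    this.location = new HostPortShuffleLocation(host, port);
  }

  @Override
  public ShuffleLocation getLocationForBlock(int reduceId) {
    // Every reduce partition written by this map task is served from the same host.
    return location;
  }
}
```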
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java
new file mode 100644
index 0000000000000..87eb497098e0c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+/**
+ * Marker interface representing a location of a shuffle block. Implementations of shuffle readers
+ * and writers are expected to cast this down to an implementation-specific representation.
+ */
+public interface ShuffleLocation {
+}
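A matching reader-side sketch of the down-cast described above, reusing the hypothetical types from the previous example:

```java
// Hypothetical sketch of the down-cast a reader-side plugin is expected to perform.
import org.apache.spark.api.shuffle.ShuffleLocation;

final class ShuffleLocationCasts {
  static String describe(ShuffleLocation location) {
    // Assumes the writer-side plugin produced HostPortShuffleLocation instances
    // (from the sketch above); any other location type is a wiring error.
    if (location instanceof RemoteHostMapShuffleLocations.HostPortShuffleLocation) {
      RemoteHostMapShuffleLocations.HostPortShuffleLocation hostPort =
          (RemoteHostMapShuffleLocations.HostPortShuffleLocation) location;
      return hostPort.host() + ":" + hostPort.port();
    }
    throw new IllegalArgumentException("Unexpected shuffle location type: " + location.getClass());
  }
}
```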
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java
index 5119e34803a85..181701175d351 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java
@@ -20,6 +20,7 @@
import java.io.IOException;
import org.apache.spark.annotation.Experimental;
+import org.apache.spark.api.java.Optional;
/**
* :: Experimental ::
@@ -31,7 +32,7 @@
public interface ShuffleMapOutputWriter {
ShufflePartitionWriter getNextPartitionWriter() throws IOException;
- void commitAllPartitions() throws IOException;
+ Optional<MapShuffleLocations> commitAllPartitions() throws IOException;
void abort(Throwable error) throws IOException;
}
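The new return value lets the default writer hand its locations back to the driver, while a plugin that tracks locations elsewhere can return an empty Optional. A hedged sketch of the latter case, with a hypothetical external metadata store:

```java
// Hypothetical sketch of a plugin's ShuffleMapOutputWriter that tracks block
// locations in an external metadata store instead of the driver.
import java.io.IOException;

import org.apache.spark.api.java.Optional;
import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;

public final class ExternallyTrackedMapOutputWriter implements ShuffleMapOutputWriter {

  @Override
  public ShufflePartitionWriter getNextPartitionWriter() throws IOException {
    // Open the next partition's output in the external storage system (omitted in this sketch).
    throw new UnsupportedOperationException("sketch only");
  }

  @Override
  public Optional<MapShuffleLocations> commitAllPartitions() throws IOException {
    // Register the committed blocks with the external metadata service (omitted),
    // then tell Spark there is nothing for the driver to track.
    return Optional.empty();
  }

  @Override
  public void abort(Throwable error) throws IOException {
    // Clean up any partially written output in the external system (omitted).
  }
}
```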
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index aef133fe7d46a..434286175e415 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -25,6 +25,8 @@
import java.nio.channels.WritableByteChannel;
import javax.annotation.Nullable;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.api.shuffle.MapShuffleLocations;
import scala.None$;
import scala.Option;
import scala.Product2;
@@ -134,8 +136,11 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
try {
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
- mapOutputWriter.commitAllPartitions();
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ Optional<MapShuffleLocations> blockLocs = mapOutputWriter.commitAllPartitions();
+ mapStatus = MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(),
+ blockLocs.orNull(),
+ partitionLengths);
return;
}
final SerializerInstance serInstance = serializer.newInstance();
@@ -168,8 +173,11 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
}
partitionLengths = writePartitionedData(mapOutputWriter);
- mapOutputWriter.commitAllPartitions();
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ Optional<MapShuffleLocations> mapLocations = mapOutputWriter.commitAllPartitions();
+ mapStatus = MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(),
+ mapLocations.orNull(),
+ partitionLengths);
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java
new file mode 100644
index 0000000000000..ffd97c0f26605
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+
+import org.apache.spark.api.shuffle.MapShuffleLocations;
+import org.apache.spark.api.shuffle.ShuffleLocation;
+import org.apache.spark.storage.BlockManagerId;
+
+import java.util.Objects;
+
+public class DefaultMapShuffleLocations implements MapShuffleLocations, ShuffleLocation {
+
+ /**
+ * We borrow the cache size from the BlockManagerId cache - around 1MB in total, which
+ * should be a negligible amount of memory.
+ */
+ private static final LoadingCache<BlockManagerId, DefaultMapShuffleLocations>
+     DEFAULT_SHUFFLE_LOCATIONS_CACHE =
+     CacheBuilder.newBuilder()
+         .maximumSize(BlockManagerId.blockManagerIdCacheSize())
+         .build(new CacheLoader<BlockManagerId, DefaultMapShuffleLocations>() {
+ @Override
+ public DefaultMapShuffleLocations load(BlockManagerId blockManagerId) {
+ return new DefaultMapShuffleLocations(blockManagerId);
+ }
+ });
+
+ private final BlockManagerId location;
+
+ public DefaultMapShuffleLocations(BlockManagerId blockManagerId) {
+ this.location = blockManagerId;
+ }
+
+ public static DefaultMapShuffleLocations get(BlockManagerId blockManagerId) {
+ return DEFAULT_SHUFFLE_LOCATIONS_CACHE.getUnchecked(blockManagerId);
+ }
+
+ @Override
+ public ShuffleLocation getLocationForBlock(int reduceId) {
+ return this;
+ }
+
+ public BlockManagerId getBlockManagerId() {
+ return location;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ return other instanceof DefaultMapShuffleLocations
+ && Objects.equals(((DefaultMapShuffleLocations) other).location, location);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(location);
+ }
+}
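A short usage sketch (standalone and illustrative only): `get` returns one shared instance per `BlockManagerId`, and that instance doubles as the `ShuffleLocation` for every reduce partition:

```java
// Minimal usage sketch; the executor id, host and port are placeholder values.
import org.apache.spark.api.shuffle.ShuffleLocation;
import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations;
import org.apache.spark.storage.BlockManagerId;
import scala.Option;

public final class DefaultLocationsExample {
  public static void main(String[] args) {
    BlockManagerId id = BlockManagerId.apply("exec-1", "localhost", 7337, Option.empty());

    // The loading cache hands back a single shared instance per BlockManagerId...
    DefaultMapShuffleLocations locations = DefaultMapShuffleLocations.get(id);
    assert locations == DefaultMapShuffleLocations.get(id);

    // ...and every reduce partition maps to that same location object.
    ShuffleLocation loc = locations.getLocationForBlock(0);
    assert loc == locations.getLocationForBlock(42);
    System.out.println(((DefaultMapShuffleLocations) loc).getBlockManagerId());
  }
}
```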
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index d7a6d6450ebc0..232b361313124 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -23,6 +23,8 @@
import java.nio.channels.WritableByteChannel;
import java.util.Iterator;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.api.shuffle.MapShuffleLocations;
import scala.Option;
import scala.Product2;
import scala.collection.JavaConverters;
@@ -221,6 +223,7 @@ void closeAndWriteOutput() throws IOException {
final ShuffleMapOutputWriter mapWriter = shuffleWriteSupport
.createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
final long[] partitionLengths;
+ Optional<MapShuffleLocations> mapLocations;
try {
try {
partitionLengths = mergeSpills(spills, mapWriter);
@@ -231,7 +234,7 @@ void closeAndWriteOutput() throws IOException {
}
}
}
- mapWriter.commitAllPartitions();
+ mapLocations = mapWriter.commitAllPartitions();
} catch (Exception e) {
try {
mapWriter.abort(e);
@@ -240,7 +243,10 @@ void closeAndWriteOutput() throws IOException {
}
throw e;
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ mapStatus = MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(),
+ mapLocations.orNull(),
+ partitionLengths);
}
@VisibleForTesting
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
index 76e87a6740259..f7ec202ef4b9d 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
@@ -46,6 +46,6 @@ public ShuffleWriteSupport writes() {
throw new IllegalStateException(
"Executor components must be initialized before getting writers.");
}
- return new DefaultShuffleWriteSupport(sparkConf, blockResolver);
+ return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId());
}
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
index c84158e1891d7..7eb0d56776de9 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
@@ -24,6 +24,10 @@
import java.io.OutputStream;
import java.nio.channels.FileChannel;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.api.shuffle.MapShuffleLocations;
+import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations;
+import org.apache.spark.storage.BlockManagerId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -49,6 +53,7 @@ public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter {
private final int bufferSize;
private int currPartitionId = 0;
private long currChannelPosition;
+ private final BlockManagerId shuffleServerId;
private final File outputFile;
private File outputTempFile;
@@ -61,11 +66,13 @@ public DefaultShuffleMapOutputWriter(
int shuffleId,
int mapId,
int numPartitions,
+ BlockManagerId shuffleServerId,
ShuffleWriteMetricsReporter metrics,
IndexShuffleBlockResolver blockResolver,
SparkConf sparkConf) {
this.shuffleId = shuffleId;
this.mapId = mapId;
+ this.shuffleServerId = shuffleServerId;
this.metrics = metrics;
this.blockResolver = blockResolver;
this.bufferSize =
@@ -90,10 +97,11 @@ public ShufflePartitionWriter getNextPartitionWriter() throws IOException {
}
@Override
- public void commitAllPartitions() throws IOException {
+ public Optional<MapShuffleLocations> commitAllPartitions() throws IOException {
cleanUp();
File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
+ return Optional.of(DefaultMapShuffleLocations.get(shuffleServerId));
}
@Override
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java
index f8fadd0ecfa63..86f1583495689 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java
@@ -22,17 +22,21 @@
import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShuffleWriteSupport;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.storage.BlockManagerId;
public class DefaultShuffleWriteSupport implements ShuffleWriteSupport {
private final SparkConf sparkConf;
private final IndexShuffleBlockResolver blockResolver;
+ private final BlockManagerId shuffleServerId;
public DefaultShuffleWriteSupport(
SparkConf sparkConf,
- IndexShuffleBlockResolver blockResolver) {
+ IndexShuffleBlockResolver blockResolver,
+ BlockManagerId shuffleServerId) {
this.sparkConf = sparkConf;
this.blockResolver = blockResolver;
+ this.shuffleServerId = shuffleServerId;
}
@Override
@@ -41,7 +45,7 @@ public ShuffleMapOutputWriter createMapOutputWriter(
int mapId,
int numPartitions) {
return new DefaultShuffleMapOutputWriter(
- shuffleId, mapId, numPartitions,
+ shuffleId, mapId, numPartitions, shuffleServerId,
TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf);
}
}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 1d4b1ef9c9a1c..74975019e7480 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -28,6 +28,7 @@ import scala.concurrent.duration.Duration
import scala.reflect.ClassTag
import scala.util.control.NonFatal
+import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleLocation}
import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
@@ -281,9 +282,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}
// For testing
- def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
- getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
+ def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int)
+ : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = {
+ getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1)
}
/**
@@ -295,8 +296,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
* and the second item is a sequence of (shuffle block id, shuffle block size) tuples
* describing the shuffle blocks that are stored at that block manager.
*/
- def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])]
+ def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
+ : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])]
/**
* Deletes map output status information for the specified shuffle stage.
@@ -645,8 +646,8 @@ private[spark] class MapOutputTrackerMaster(
// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
// This method is only called in local-mode.
- def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
+ : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
shuffleStatuses.get(shuffleId) match {
case Some (shuffleStatus) =>
@@ -682,12 +683,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
private val fetching = new HashSet[Int]
// Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
- override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
- : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int)
+ : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
val statuses = getStatuses(shuffleId)
try {
- MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+ MapOutputTracker.convertMapStatuses(
+ shuffleId, startPartition, endPartition, statuses)
} catch {
case e: MetadataFetchFailedException =>
// We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
@@ -871,9 +873,9 @@ private[spark] object MapOutputTracker extends Logging {
shuffleId: Int,
startPartition: Int,
endPartition: Int,
- statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+ statuses: Array[MapStatus]): Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = {
assert (statuses != null)
- val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]]
+ val splitsByAddress = new HashMap[ShuffleLocation, ListBuffer[(BlockId, Long)]]
for ((status, mapId) <- statuses.iterator.zipWithIndex) {
if (status == null) {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
@@ -883,7 +885,8 @@ private[spark] object MapOutputTracker extends Logging {
for (part <- startPartition until endPartition) {
val size = status.getSizeForBlock(part)
if (size != 0) {
- splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
+ val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part)
+ splitsByAddress.getOrElseUpdate(shuffleLoc, ListBuffer()) +=
((ShuffleBlockId(shuffleId, mapId, part), size))
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 64f0a060a247c..a61f9bd14ef2f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -24,7 +24,9 @@ import scala.collection.mutable
import org.roaringbitmap.RoaringBitmap
import org.apache.spark.SparkEnv
+import org.apache.spark.api.shuffle.MapShuffleLocations
import org.apache.spark.internal.config
+import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils
@@ -33,7 +35,16 @@ import org.apache.spark.util.Utils
* task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
*/
private[spark] sealed trait MapStatus {
- /** Location where this task was run. */
+
+ /**
+ * Locations where this task stored shuffle blocks.
+ *
+ * May be null if the MapOutputTracker is not tracking shuffle block locations, in which case
+ * the shuffle plugin implementation is responsible for tracking them.
+ */
+ def mapShuffleLocations: MapShuffleLocations
+
+ /** Location where the task was run. */
def location: BlockManagerId
/**
@@ -56,11 +67,31 @@ private[spark] object MapStatus {
.map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS))
.getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get)
+ // A temporary concession to the fact that we only expect Spark's built-in shuffle
+ // implementations to store shuffle locations in the driver, so we want to introduce as little
+ // serialization overhead as possible in this default case.
+ //
+ // If more similar cases arise, consider adding a serialization API for these shuffle locations.
+ private val DEFAULT_MAP_SHUFFLE_LOCATIONS_ID: Byte = 0
+ private val NON_DEFAULT_MAP_SHUFFLE_LOCATIONS_ID: Byte = 1
+
+ /**
+ * Visible for testing.
+ */
def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
+ apply(loc, DefaultMapShuffleLocations.get(loc), uncompressedSizes)
+ }
+
+ def apply(
+ loc: BlockManagerId,
+ mapShuffleLocs: MapShuffleLocations,
+ uncompressedSizes: Array[Long]): MapStatus = {
if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
- HighlyCompressedMapStatus(loc, uncompressedSizes)
+ HighlyCompressedMapStatus(
+ loc, mapShuffleLocs, uncompressedSizes)
} else {
- new CompressedMapStatus(loc, uncompressedSizes)
+ new CompressedMapStatus(
+ loc, mapShuffleLocs, uncompressedSizes)
}
}
@@ -91,41 +122,89 @@ private[spark] object MapStatus {
math.pow(LOG_BASE, compressedSize & 0xFF).toLong
}
}
-}
+ def writeLocations(
+ loc: BlockManagerId,
+ mapShuffleLocs: MapShuffleLocations,
+ out: ObjectOutput): Unit = {
+ if (mapShuffleLocs != null) {
+ out.writeBoolean(true)
+ if (mapShuffleLocs.isInstanceOf[DefaultMapShuffleLocations]
+ && mapShuffleLocs.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId == loc) {
+ out.writeByte(MapStatus.DEFAULT_MAP_SHUFFLE_LOCATIONS_ID)
+ } else {
+ out.writeByte(MapStatus.NON_DEFAULT_MAP_SHUFFLE_LOCATIONS_ID)
+ out.writeObject(mapShuffleLocs)
+ }
+ } else {
+ out.writeBoolean(false)
+ }
+ loc.writeExternal(out)
+ }
+
+ def readLocations(in: ObjectInput): (BlockManagerId, MapShuffleLocations) = {
+ if (in.readBoolean()) {
+ val locId = in.readByte()
+ if (locId == MapStatus.DEFAULT_MAP_SHUFFLE_LOCATIONS_ID) {
+ val blockManagerId = BlockManagerId(in)
+ (blockManagerId, DefaultMapShuffleLocations.get(blockManagerId))
+ } else {
+ val mapShuffleLocations = in.readObject().asInstanceOf[MapShuffleLocations]
+ val blockManagerId = BlockManagerId(in)
+ (blockManagerId, mapShuffleLocations)
+ }
+ } else {
+ val blockManagerId = BlockManagerId(in)
+ (blockManagerId, null)
+ }
+ }
+}
/**
* A [[MapStatus]] implementation that tracks the size of each block. Size for each block is
* represented using a single byte.
*
- * @param loc location where the task is being executed.
+ * @param loc location where the task is being executed.
+ * @param mapShuffleLocs locations where the task stored its shuffle blocks - may be null.
* @param compressedSizes size of the blocks, indexed by reduce partition id.
*/
private[spark] class CompressedMapStatus(
private[this] var loc: BlockManagerId,
+ private[this] var mapShuffleLocs: MapShuffleLocations,
private[this] var compressedSizes: Array[Byte])
extends MapStatus with Externalizable {
- protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only
+ // For deserialization only
+ protected def this() = this(null, null, null.asInstanceOf[Array[Byte]])
- def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
- this(loc, uncompressedSizes.map(MapStatus.compressSize))
+ def this(
+ loc: BlockManagerId,
+ mapShuffleLocations: MapShuffleLocations,
+ uncompressedSizes: Array[Long]) {
+ this(
+ loc,
+ mapShuffleLocations,
+ uncompressedSizes.map(MapStatus.compressSize))
}
override def location: BlockManagerId = loc
+ override def mapShuffleLocations: MapShuffleLocations = mapShuffleLocs
+
override def getSizeForBlock(reduceId: Int): Long = {
MapStatus.decompressSize(compressedSizes(reduceId))
}
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- loc.writeExternal(out)
+ MapStatus.writeLocations(loc, mapShuffleLocs, out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
- loc = BlockManagerId(in)
+ val (deserializedLoc, deserializedMapShuffleLocs) = MapStatus.readLocations(in)
+ loc = deserializedLoc
+ mapShuffleLocs = deserializedMapShuffleLocs
val len = in.readInt()
compressedSizes = new Array[Byte](len)
in.readFully(compressedSizes)
@@ -138,6 +217,7 @@ private[spark] class CompressedMapStatus(
* plus a bitmap for tracking which blocks are empty.
*
* @param loc location where the task is being executed
+ * @param mapShuffleLocs location where the task stored shuffle blocks - may be null
* @param numNonEmptyBlocks the number of non-empty blocks
* @param emptyBlocks a bitmap tracking which blocks are empty
* @param avgSize average size of the non-empty and non-huge blocks
@@ -145,6 +225,7 @@ private[spark] class CompressedMapStatus(
*/
private[spark] class HighlyCompressedMapStatus private (
private[this] var loc: BlockManagerId,
+ private[this] var mapShuffleLocs: MapShuffleLocations,
private[this] var numNonEmptyBlocks: Int,
private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long,
@@ -155,10 +236,12 @@ private[spark] class HighlyCompressedMapStatus private (
require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
"Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, -1, null, -1, null) // For deserialization only
+ protected def this() = this(null, null, -1, null, -1, null) // For deserialization only
override def location: BlockManagerId = loc
+ override def mapShuffleLocations: MapShuffleLocations = mapShuffleLocs
+
override def getSizeForBlock(reduceId: Int): Long = {
assert(hugeBlockSizes != null)
if (emptyBlocks.contains(reduceId)) {
@@ -172,7 +255,7 @@ private[spark] class HighlyCompressedMapStatus private (
}
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
- loc.writeExternal(out)
+ MapStatus.writeLocations(loc, mapShuffleLocs, out)
emptyBlocks.writeExternal(out)
out.writeLong(avgSize)
out.writeInt(hugeBlockSizes.size)
@@ -183,7 +266,9 @@ private[spark] class HighlyCompressedMapStatus private (
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
- loc = BlockManagerId(in)
+ val (deserializedLoc, deserializedMapShuffleLocs) = MapStatus.readLocations(in)
+ loc = deserializedLoc
+ mapShuffleLocs = deserializedMapShuffleLocs
emptyBlocks = new RoaringBitmap()
emptyBlocks.readExternal(in)
avgSize = in.readLong()
@@ -199,7 +284,10 @@ private[spark] class HighlyCompressedMapStatus private (
}
private[spark] object HighlyCompressedMapStatus {
- def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
+ def apply(
+ loc: BlockManagerId,
+ mapShuffleLocs: MapShuffleLocations,
+ uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
// We must keep track of which blocks are empty so that we don't report a zero-sized
// block as being non-empty (or vice-versa) when using the average block size.
var i = 0
@@ -239,7 +327,12 @@ private[spark] object HighlyCompressedMapStatus {
}
emptyBlocks.trim()
emptyBlocks.runOptimize()
- new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
- hugeBlockSizes)
+ new HighlyCompressedMapStatus(
+ loc,
+ mapShuffleLocs,
+ numNonEmptyBlocks,
+ emptyBlocks,
+ avgSize,
+ hugeBlockSizes)
}
}
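A hedged note on the NON_DEFAULT branch above: a custom `MapShuffleLocations` falls back to `out.writeObject`, so it must survive plain Java serialization. The standalone round-trip below is an illustrative stand-in (the `CustomLocations` class is hypothetical), not the actual `MapStatus` code path:

```java
// Sketch: a plugin's MapShuffleLocations is Serializable via its parent interface,
// and in the NON_DEFAULT case it is written with plain Java object serialization.
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.ShuffleLocation;

public final class LocationSerializationExample {

  // Hypothetical plugin locations class; Serializable through MapShuffleLocations.
  static final class CustomLocations implements MapShuffleLocations {
    private final String endpoint;
    CustomLocations(String endpoint) { this.endpoint = endpoint; }
    @Override
    public ShuffleLocation getLocationForBlock(int reduceId) { return null; /* omitted */ }
  }

  public static void main(String[] args) throws Exception {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
      out.writeObject(new CustomLocations("shuffle-service:7447"));
    }
    try (ObjectInputStream in =
        new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
      MapShuffleLocations roundTripped = (MapShuffleLocations) in.readObject();
      System.out.println(((CustomLocations) roundTripped).endpoint); // shuffle-service:7447
    }
  }
}
```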
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 2df133dd2b13a..ba8c92518f019 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -31,7 +31,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSe
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput}
import com.esotericsoftware.kryo.pool.{KryoCallback, KryoFactory, KryoPool}
-import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
+import com.esotericsoftware.kryo.serializers.{ExternalizableSerializer, JavaSerializer => KryoJavaSerializer}
import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.roaringbitmap.RoaringBitmap
@@ -152,6 +152,8 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer())
kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer())
kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
+ kryo.register(classOf[CompressedMapStatus], new ExternalizableSerializer())
+ kryo.register(classOf[HighlyCompressedMapStatus], new ExternalizableSerializer())
kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas))
kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas))
@@ -485,8 +487,6 @@ private[serializer] object KryoSerializer {
private val toRegister: Seq[Class[_]] = Seq(
ByteBuffer.allocate(1).getClass,
classOf[StorageLevel],
- classOf[CompressedMapStatus],
- classOf[HighlyCompressedMapStatus],
classOf[CompactBuffer[_]],
classOf[BlockManagerId],
classOf[Array[Boolean]],
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index c5eefc7c5c049..d6f63e71f113c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -20,7 +20,8 @@ package org.apache.spark.shuffle
import org.apache.spark._
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
+import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -47,7 +48,14 @@ private[spark] class BlockStoreShuffleReader[K, C](
context,
blockManager.shuffleClient,
blockManager,
- mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
+ mapOutputTracker.getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition)
+ .map {
+ case (loc: DefaultMapShuffleLocations, blocks: Seq[(BlockId, Long)]) =>
+ (loc.getBlockManagerId, blocks)
+ case _ =>
+ throw new UnsupportedOperationException("Not allowed to using non-default map shuffle" +
+ " locations yet.")
+ },
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 62316f384b642..1fcae684b0052 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -67,8 +67,11 @@ private[spark] class SortShuffleWriter[K, V, C](
val mapOutputWriter = writeSupport.createMapOutputWriter(
dep.shuffleId, mapId, dep.partitioner.numPartitions)
val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
- mapOutputWriter.commitAllPartitions()
- mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
+ val mapLocations = mapOutputWriter.commitAllPartitions()
+ mapStatus = MapStatus(
+ blockManager.shuffleServerId,
+ mapLocations.orNull(),
+ partitionLengths)
}
/** Close this writer, passing along whether the map completed */
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index d4a59c33b974c..d72bd6f9af6bc 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -133,12 +133,14 @@ private[spark] object BlockManagerId {
getCachedBlockManagerId(obj)
}
+ val blockManagerIdCacheSize = 10000
+
/**
* The max cache size is hardcoded to 10000, since the size of a BlockManagerId
* object is about 48B, the total memory cost should be below 1MB which is feasible.
*/
val blockManagerIdCache = CacheBuilder.newBuilder()
- .maximumSize(10000)
+ .maximumSize(blockManagerIdCacheSize)
.build(new CacheLoader[BlockManagerId, BlockManagerId]() {
override def load(id: BlockManagerId) = id
})
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 012dc5d21bce4..5f0de31bd25e3 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -172,6 +172,8 @@ public void setUp() throws IOException {
when(shuffleDep.serializer()).thenReturn(serializer);
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
+ when(blockManager.shuffleServerId()).thenReturn(BlockManagerId.apply(
+ "0", "localhost", 9099, Option.empty()));
TaskContext$.MODULE$.setTaskContext(taskContext);
}
@@ -188,8 +190,7 @@ private UnsafeShuffleWriter