diff --git a/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java new file mode 100644 index 0000000000000..b0aed4d08d387 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/MapShuffleLocations.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.api.shuffle; + +import org.apache.spark.annotation.Experimental; + +import java.io.Serializable; + +/** + * Represents metadata about where shuffle blocks were written in a single map task. + *

+ * This is optionally returned by shuffle writers. The inner shuffle locations may + * be accessed by shuffle readers. Shuffle locations are only necessary when the + * location of shuffle blocks needs to be managed by the driver; shuffle plugins + * may choose to use an external database or other metadata management systems to + * track the locations of shuffle blocks instead. + */ +@Experimental +public interface MapShuffleLocations extends Serializable { + + /** + * Get the location for a given shuffle block written by this map task. + */ + ShuffleLocation getLocationForBlock(int reduceId); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java new file mode 100644 index 0000000000000..87eb497098e0c --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleLocation.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +/** + * Marker interface representing a location of a shuffle block. Implementations of shuffle readers + * and writers are expected to cast this down to an implementation-specific representation. 
+ */ +public interface ShuffleLocation { +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java index 5119e34803a85..181701175d351 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleMapOutputWriter.java @@ -20,6 +20,7 @@ import java.io.IOException; import org.apache.spark.annotation.Experimental; +import org.apache.spark.api.java.Optional; /** * :: Experimental :: @@ -31,7 +32,7 @@ public interface ShuffleMapOutputWriter { ShufflePartitionWriter getNextPartitionWriter() throws IOException; - void commitAllPartitions() throws IOException; + Optional commitAllPartitions() throws IOException; void abort(Throwable error) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index aef133fe7d46a..434286175e415 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -25,6 +25,8 @@ import java.nio.channels.WritableByteChannel; import javax.annotation.Nullable; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; import scala.None$; import scala.Option; import scala.Product2; @@ -134,8 +136,11 @@ public void write(Iterator> records) throws IOException { try { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; - mapOutputWriter.commitAllPartitions(); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + Optional blockLocs = mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), + blockLocs.orNull(), + partitionLengths); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -168,8 +173,11 @@ public void write(Iterator> records) throws IOException { } partitionLengths = writePartitionedData(mapOutputWriter); - mapOutputWriter.commitAllPartitions(); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + Optional mapLocations = mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), + mapLocations.orNull(), + partitionLengths); } catch (Exception e) { try { mapOutputWriter.abort(e); @@ -178,6 +186,10 @@ public void write(Iterator> records) throws IOException { } throw e; } + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), + DefaultMapShuffleLocations.get(blockManager.shuffleServerId()), + partitionLengths); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java new file mode 100644 index 0000000000000..ffd97c0f26605 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultMapShuffleLocations.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; + +import org.apache.spark.api.shuffle.MapShuffleLocations; +import org.apache.spark.api.shuffle.ShuffleLocation; +import org.apache.spark.storage.BlockManagerId; + +import java.util.Objects; + +public class DefaultMapShuffleLocations implements MapShuffleLocations, ShuffleLocation { + + /** + * We borrow the cache size from the BlockManagerId's cache - around 1MB, which should be + * feasible. + */ + private static final LoadingCache + DEFAULT_SHUFFLE_LOCATIONS_CACHE = + CacheBuilder.newBuilder() + .maximumSize(BlockManagerId.blockManagerIdCacheSize()) + .build(new CacheLoader() { + @Override + public DefaultMapShuffleLocations load(BlockManagerId blockManagerId) { + return new DefaultMapShuffleLocations(blockManagerId); + } + }); + + private final BlockManagerId location; + + public DefaultMapShuffleLocations(BlockManagerId blockManagerId) { + this.location = blockManagerId; + } + + public static DefaultMapShuffleLocations get(BlockManagerId blockManagerId) { + return DEFAULT_SHUFFLE_LOCATIONS_CACHE.getUnchecked(blockManagerId); + } + + @Override + public ShuffleLocation getLocationForBlock(int reduceId) { + return this; + } + + public BlockManagerId getBlockManagerId() { + return location; + } + + @Override + public boolean equals(Object other) { + return other instanceof DefaultMapShuffleLocations + && Objects.equals(((DefaultMapShuffleLocations) other).location, location); + } + + @Override + public int hashCode() { + return Objects.hashCode(location); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index d7a6d6450ebc0..232b361313124 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -23,6 +23,8 @@ import java.nio.channels.WritableByteChannel; import java.util.Iterator; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -221,6 +223,7 @@ void closeAndWriteOutput() throws IOException { final ShuffleMapOutputWriter mapWriter = shuffleWriteSupport .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions()); final long[] partitionLengths; + Optional mapLocations; try { try { partitionLengths = mergeSpills(spills, mapWriter); @@ -231,7 +234,7 @@ void closeAndWriteOutput() throws IOException { } } } - mapWriter.commitAllPartitions(); + mapLocations = mapWriter.commitAllPartitions(); } catch (Exception e) { try { mapWriter.abort(e); @@ -240,7 +243,10 @@ void closeAndWriteOutput() throws IOException { } throw e; } - mapStatus = 
MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), + mapLocations.orNull(), + partitionLengths); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java index 76e87a6740259..f7ec202ef4b9d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -46,6 +46,6 @@ public ShuffleWriteSupport writes() { throw new IllegalStateException( "Executor components must be initialized before getting writers."); } - return new DefaultShuffleWriteSupport(sparkConf, blockResolver); + return new DefaultShuffleWriteSupport(sparkConf, blockResolver, blockManager.shuffleServerId()); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java index c84158e1891d7..7eb0d56776de9 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java @@ -24,6 +24,10 @@ import java.io.OutputStream; import java.nio.channels.FileChannel; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations; +import org.apache.spark.storage.BlockManagerId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,6 +53,7 @@ public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter { private final int bufferSize; private int currPartitionId = 0; private long currChannelPosition; + private final BlockManagerId shuffleServerId; private final File outputFile; private File outputTempFile; @@ -61,11 +66,13 @@ public DefaultShuffleMapOutputWriter( int shuffleId, int mapId, int numPartitions, + BlockManagerId shuffleServerId, ShuffleWriteMetricsReporter metrics, IndexShuffleBlockResolver blockResolver, SparkConf sparkConf) { this.shuffleId = shuffleId; this.mapId = mapId; + this.shuffleServerId = shuffleServerId; this.metrics = metrics; this.blockResolver = blockResolver; this.bufferSize = @@ -90,10 +97,11 @@ public ShufflePartitionWriter getNextPartitionWriter() throws IOException { } @Override - public void commitAllPartitions() throws IOException { + public Optional commitAllPartitions() throws IOException { cleanUp(); File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? 
outputTempFile : null; blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); + return Optional.of(DefaultMapShuffleLocations.get(shuffleServerId)); } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java index f8fadd0ecfa63..86f1583495689 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java @@ -22,17 +22,21 @@ import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; import org.apache.spark.api.shuffle.ShuffleWriteSupport; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.storage.BlockManagerId; public class DefaultShuffleWriteSupport implements ShuffleWriteSupport { private final SparkConf sparkConf; private final IndexShuffleBlockResolver blockResolver; + private final BlockManagerId shuffleServerId; public DefaultShuffleWriteSupport( SparkConf sparkConf, - IndexShuffleBlockResolver blockResolver) { + IndexShuffleBlockResolver blockResolver, + BlockManagerId shuffleServerId) { this.sparkConf = sparkConf; this.blockResolver = blockResolver; + this.shuffleServerId = shuffleServerId; } @Override @@ -41,7 +45,7 @@ public ShuffleMapOutputWriter createMapOutputWriter( int mapId, int numPartitions) { return new DefaultShuffleMapOutputWriter( - shuffleId, mapId, numPartitions, + shuffleId, mapId, numPartitions, shuffleServerId, TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf); } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1d4b1ef9c9a1c..74975019e7480 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -28,6 +28,7 @@ import scala.concurrent.duration.Duration import scala.reflect.ClassTag import scala.util.control.NonFatal +import org.apache.spark.api.shuffle.{MapShuffleLocations, ShuffleLocation} import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ @@ -281,9 +282,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } // For testing - def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { - getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) + def getMapSizesByShuffleLocation(shuffleId: Int, reduceId: Int) + : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { + getMapSizesByShuffleLocation(shuffleId, reduceId, reduceId + 1) } /** @@ -295,8 +296,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * and the second item is a sequence of (shuffle block id, shuffle block size) tuples * describing the shuffle blocks that are stored at that block manager. */ - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] + def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) + : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] /** * Deletes map output status information for the specified shuffle stage. @@ -645,8 +646,8 @@ private[spark] class MapOutputTrackerMaster( // Get blocks sizes by executor Id. 
Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) + : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => @@ -682,12 +683,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private val fetching = new HashSet[Int] // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. - override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + override def getMapSizesByShuffleLocation(shuffleId: Int, startPartition: Int, endPartition: Int) + : Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { - MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + MapOutputTracker.convertMapStatuses( + shuffleId, startPartition, endPartition, statuses) } catch { case e: MetadataFetchFailedException => // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: @@ -871,9 +873,9 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus]): Iterator[(ShuffleLocation, Seq[(BlockId, Long)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] + val splitsByAddress = new HashMap[ShuffleLocation, ListBuffer[(BlockId, Long)]] for ((status, mapId) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" @@ -883,7 +885,8 @@ private[spark] object MapOutputTracker extends Logging { for (part <- startPartition until endPartition) { val size = status.getSizeForBlock(part) if (size != 0) { - splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + val shuffleLoc = status.mapShuffleLocations.getLocationForBlock(part) + splitsByAddress.getOrElseUpdate(shuffleLoc, ListBuffer()) += ((ShuffleBlockId(shuffleId, mapId, part), size)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 64f0a060a247c..a61f9bd14ef2f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -24,7 +24,9 @@ import scala.collection.mutable import org.roaringbitmap.RoaringBitmap import org.apache.spark.SparkEnv +import org.apache.spark.api.shuffle.MapShuffleLocations import org.apache.spark.internal.config +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -33,7 +35,16 @@ import org.apache.spark.util.Utils * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. 
*/ private[spark] sealed trait MapStatus { - /** Location where this task was run. */ + + /** + * Locations where this task stored shuffle blocks. + * + * May be null if the MapOutputTracker is not tracking the location of shuffle blocks, leaving it + * up to the implementation of shuffle plugins to do so. + */ + def mapShuffleLocations: MapShuffleLocations + + /** Location where the task was run. */ def location: BlockManagerId /** @@ -56,11 +67,31 @@ private[spark] object MapStatus { .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) + // A temporary concession to the fact that we only expect implementations of shuffle provided by + // Spark to be storing shuffle locations in the driver, meaning we want to introduce as little + // serialization overhead as possible in such default cases. + // + // If more similar cases arise, consider adding a serialization API for these shuffle locations. + private val DEFAULT_MAP_SHUFFLE_LOCATIONS_ID: Byte = 0 + private val NON_DEFAULT_MAP_SHUFFLE_LOCATIONS_ID: Byte = 1 + + /** + * Visible for testing. + */ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + apply(loc, DefaultMapShuffleLocations.get(loc), uncompressedSizes) + } + + def apply( + loc: BlockManagerId, + mapShuffleLocs: MapShuffleLocations, + uncompressedSizes: Array[Long]): MapStatus = { if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus( + loc, mapShuffleLocs, uncompressedSizes) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus( + loc, mapShuffleLocs, uncompressedSizes) } } @@ -91,41 +122,89 @@ private[spark] object MapStatus { math.pow(LOG_BASE, compressedSize & 0xFF).toLong } } -} + def writeLocations( + loc: BlockManagerId, + mapShuffleLocs: MapShuffleLocations, + out: ObjectOutput): Unit = { + if (mapShuffleLocs != null) { + out.writeBoolean(true) + if (mapShuffleLocs.isInstanceOf[DefaultMapShuffleLocations] + && mapShuffleLocs.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId == loc) { + out.writeByte(MapStatus.DEFAULT_MAP_SHUFFLE_LOCATIONS_ID) + } else { + out.writeByte(MapStatus.NON_DEFAULT_MAP_SHUFFLE_LOCATIONS_ID) + out.writeObject(mapShuffleLocs) + } + } else { + out.writeBoolean(false) + } + loc.writeExternal(out) + } + + def readLocations(in: ObjectInput): (BlockManagerId, MapShuffleLocations) = { + if (in.readBoolean()) { + val locId = in.readByte() + if (locId == MapStatus.DEFAULT_MAP_SHUFFLE_LOCATIONS_ID) { + val blockManagerId = BlockManagerId(in) + (blockManagerId, DefaultMapShuffleLocations.get(blockManagerId)) + } else { + val mapShuffleLocations = in.readObject().asInstanceOf[MapShuffleLocations] + val blockManagerId = BlockManagerId(in) + (blockManagerId, mapShuffleLocations) + } + } else { + val blockManagerId = BlockManagerId(in) + (blockManagerId, null) + } + } +} /** * A [[MapStatus]] implementation that tracks the size of each block. Size for each block is * represented using a single byte. * - * @param loc location where the task is being executed. + * @param loc Location where the task is being executed. + * @param mapShuffleLocs locations where the task stored its shuffle blocks - may be null. + * @param compressedSizes size of the blocks, indexed by reduce partition id. 
*/ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, + private[this] var mapShuffleLocs: MapShuffleLocations, private[this] var compressedSizes: Array[Byte]) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + // For deserialization only + protected def this() = this(null, null, null.asInstanceOf[Array[Byte]]) - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this( + loc: BlockManagerId, + mapShuffleLocations: MapShuffleLocations, + uncompressedSizes: Array[Long]) { + this( + loc, + mapShuffleLocations, + uncompressedSizes.map(MapStatus.compressSize)) } override def location: BlockManagerId = loc + override def mapShuffleLocations: MapShuffleLocations = mapShuffleLocs + override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - loc.writeExternal(out) + MapStatus.writeLocations(loc, mapShuffleLocs, out) out.writeInt(compressedSizes.length) out.write(compressedSizes) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - loc = BlockManagerId(in) + val (deserializedLoc, deserializedMapShuffleLocs) = MapStatus.readLocations(in) + loc = deserializedLoc + mapShuffleLocs = deserializedMapShuffleLocs val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) @@ -138,6 +217,7 @@ private[spark] class CompressedMapStatus( * plus a bitmap for tracking which blocks are empty. * * @param loc location where the task is being executed + * @param mapShuffleLocs location where the task stored shuffle blocks - may be null * @param numNonEmptyBlocks the number of non-empty blocks * @param emptyBlocks a bitmap tracking which blocks are empty * @param avgSize average size of the non-empty and non-huge blocks @@ -145,6 +225,7 @@ private[spark] class CompressedMapStatus( */ private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, + private[this] var mapShuffleLocs: MapShuffleLocations, private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, @@ -155,10 +236,12 @@ private[spark] class HighlyCompressedMapStatus private ( require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, null, -1, null, -1, null) // For deserialization only override def location: BlockManagerId = loc + override def mapShuffleLocations: MapShuffleLocations = mapShuffleLocs + override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -172,7 +255,7 @@ private[spark] class HighlyCompressedMapStatus private ( } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { - loc.writeExternal(out) + MapStatus.writeLocations(loc, mapShuffleLocs, out) emptyBlocks.writeExternal(out) out.writeLong(avgSize) out.writeInt(hugeBlockSizes.size) @@ -183,7 +266,9 @@ private[spark] class HighlyCompressedMapStatus private ( } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { - loc = BlockManagerId(in) + val (deserializedLoc, deserializedMapShuffleLocs) = 
MapStatus.readLocations(in) + loc = deserializedLoc + mapShuffleLocs = deserializedMapShuffleLocs emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() @@ -199,7 +284,10 @@ private[spark] class HighlyCompressedMapStatus private ( } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + def apply( + loc: BlockManagerId, + mapShuffleLocs: MapShuffleLocations, + uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -239,7 +327,12 @@ private[spark] object HighlyCompressedMapStatus { } emptyBlocks.trim() emptyBlocks.runOptimize() - new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizes) + new HighlyCompressedMapStatus( + loc, + mapShuffleLocs, + numNonEmptyBlocks, + emptyBlocks, + avgSize, + hugeBlockSizes) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 2df133dd2b13a..ba8c92518f019 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -31,7 +31,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSe import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.io.{UnsafeInput => KryoUnsafeInput, UnsafeOutput => KryoUnsafeOutput} import com.esotericsoftware.kryo.pool.{KryoCallback, KryoFactory, KryoPool} -import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} +import com.esotericsoftware.kryo.serializers.{ExternalizableSerializer, JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} import org.roaringbitmap.RoaringBitmap @@ -152,6 +152,8 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) + kryo.register(classOf[CompressedMapStatus], new ExternalizableSerializer()) + kryo.register(classOf[HighlyCompressedMapStatus], new ExternalizableSerializer()) kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) @@ -485,8 +487,6 @@ private[serializer] object KryoSerializer { private val toRegister: Seq[Class[_]] = Seq( ByteBuffer.allocate(1).getClass, classOf[StorageLevel], - classOf[CompressedMapStatus], - classOf[HighlyCompressedMapStatus], classOf[CompactBuffer[_]], classOf[BlockManagerId], classOf[Array[Boolean]], diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index c5eefc7c5c049..d6f63e71f113c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -20,7 +20,8 @@ package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.internal.{config, Logging} import 
org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations +import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -47,7 +48,14 @@ private[spark] class BlockStoreShuffleReader[K, C]( context, blockManager.shuffleClient, blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), + mapOutputTracker.getMapSizesByShuffleLocation(handle.shuffleId, startPartition, endPartition) + .map { + case (loc: DefaultMapShuffleLocations, blocks: Seq[(BlockId, Long)]) => + (loc.getBlockManagerId, blocks) + case _ => + throw new UnsupportedOperationException("Using non-default map shuffle" + + " locations is not supported yet.") + }, serializerManager.wrapStream, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 62316f384b642..1fcae684b0052 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -67,8 +67,11 @@ private[spark] class SortShuffleWriter[K, V, C]( val mapOutputWriter = writeSupport.createMapOutputWriter( dep.shuffleId, mapId, dep.partitioner.numPartitions) val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) - mapOutputWriter.commitAllPartitions() - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + val mapLocations = mapOutputWriter.commitAllPartitions() + mapStatus = MapStatus( + blockManager.shuffleServerId, + mapLocations.orNull(), + partitionLengths) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index d4a59c33b974c..d72bd6f9af6bc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -133,12 +133,14 @@ private[spark] object BlockManagerId { getCachedBlockManagerId(obj) } + val blockManagerIdCacheSize = 10000 + /** * The max cache size is hardcoded to 10000, since the size of a BlockManagerId * object is about 48B, the total memory cost should be below 1MB which is feasible. 
*/ val blockManagerIdCache = CacheBuilder.newBuilder() - .maximumSize(10000) + .maximumSize(blockManagerIdCacheSize) .build(new CacheLoader[BlockManagerId, BlockManagerId]() { override def load(id: BlockManagerId) = id }) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 012dc5d21bce4..5f0de31bd25e3 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -172,6 +172,8 @@ public void setUp() throws IOException { when(shuffleDep.serializer()).thenReturn(serializer); when(shuffleDep.partitioner()).thenReturn(hashPartitioner); when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager); + when(blockManager.shuffleServerId()).thenReturn(BlockManagerId.apply( + "0", "localhost", 9099, Option.empty())); TaskContext$.MODULE$.setTaskContext(taskContext); } @@ -188,8 +190,7 @@ private UnsafeShuffleWriter createWriter( taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new DefaultShuffleWriteSupport(conf, shuffleBlockResolver) - ); + new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId())); } private void assertSpillFilesWereCleanedUp() { @@ -550,7 +551,7 @@ public void testPeakMemoryUsed() throws Exception { taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), - new DefaultShuffleWriteSupport(conf, shuffleBlockResolver)); + new DefaultShuffleWriteSupport(conf, shuffleBlockResolver, blockManager.shuffleServerId())); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. 
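The writer-side changes above share one pattern: each writer commits its partitions through the ShuffleMapOutputWriter, takes the optionally returned MapShuffleLocations, and folds it into the MapStatus (passing null when the plugin tracks block locations outside the driver). To illustrate the plugin side of that contract, here is a minimal sketch of a custom MapShuffleLocations that records one location per reduce partition; the class names and the host/port fields are illustrative assumptions, not part of this patch.

import java.io.Serializable;

import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.ShuffleLocation;

// Hypothetical location of a block pushed to an external shuffle service.
final class ExternalServiceLocation implements ShuffleLocation, Serializable {
  final String host;
  final int port;

  ExternalServiceLocation(String host, int port) {
    this.host = host;
    this.port = port;
  }
}

// Hypothetical per-map-task locations, indexed by reduce partition id.
final class PerPartitionShuffleLocations implements MapShuffleLocations {
  private final ExternalServiceLocation[] locationsByReduceId;

  PerPartitionShuffleLocations(ExternalServiceLocation[] locationsByReduceId) {
    this.locationsByReduceId = locationsByReduceId;
  }

  @Override
  public ShuffleLocation getLocationForBlock(int reduceId) {
    return locationsByReduceId[reduceId];
  }
}

A plugin's commitAllPartitions() could wrap such an object in the returned Optional, or return an empty Optional when it manages shuffle block locations in an external metadata store rather than in the driver.
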
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index d86975964b558..0a77c4f6d5838 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MA import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { @@ -67,10 +68,13 @@ class MapOutputTrackerSuite extends SparkFunSuite { Array(1000L, 10000L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(10000L, 1000L))) - val statuses = tracker.getMapSizesByExecutorId(10, 0) + val statuses = tracker.getMapSizesByShuffleLocation(10, 0) assert(statuses.toSet === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), - (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) + Seq( + (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), + (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)), + ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) .toSet) assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.stop() @@ -90,11 +94,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) + assert(tracker.getMapSizesByShuffleLocation(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.unregisterShuffle(10) assert(!tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty) + assert(tracker.getMapSizesByShuffleLocation(10, 0).isEmpty) tracker.stop() rpcEnv.shutdown() @@ -121,7 +125,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. - intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) } + intercept[FetchFailedException] { tracker.getMapSizesByShuffleLocation(10, 1) } tracker.stop() rpcEnv.shutdown() @@ -143,24 +147,26 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(10, 1) slaveTracker.updateEpoch(masterTracker.getEpoch) // This is expected to fail because no outputs have been registered for the shuffle. 
- intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) + assert(slaveTracker.getMapSizesByShuffleLocation(10, 0).toSeq === + Seq( + (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput) slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) } // failure should be cached - intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByShuffleLocation(10, 0) } assert(0 == masterTracker.getNumCachedSerializedBroadcast) masterTracker.stop() @@ -261,8 +267,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // being sent. masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => + val bmId = BlockManagerId("999", "mps", 1000) masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + bmId, + DefaultMapShuffleLocations.get(bmId), + Array.fill[Long](4000000)(0))) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -315,11 +324,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + assert(tracker.getMapSizesByShuffleLocation(10, 0, 4).toSeq === Seq( - (BlockManagerId("a", "hostA", 1000), + (DefaultMapShuffleLocations.get(BlockManagerId("a", "hostA", 1000)), Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), - (BlockManagerId("b", "hostB", 1000), + (DefaultMapShuffleLocations.get(BlockManagerId("b", "hostB", 1000)), Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) ) ) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 90c790cefcca2..83026c002f1b2 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -73,7 +73,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // All blocks must have non-zero size (0 until NUM_BLOCKS).foreach { id => - val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, id) assert(statuses.forall(_._2.forall(blockIdSizePair => blockIdSizePair._2 > 0))) } } @@ -112,7 +112,7 @@ abstract 
class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, id) statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -137,7 +137,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, id) statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index e17d264cced9f..14b93957734e4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -29,12 +29,14 @@ import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.api.shuffle.MapShuffleLocations import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils} @@ -700,8 +702,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -727,8 +729,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) // we can see both result blocks now - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === - HashSet("hostA", "hostB")) + assert(mapOutputTracker + .getMapSizesByShuffleLocation(shuffleId, 0) + .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .toSet === HashSet("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() @@ -766,11 +770,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi runEvent(ExecutorLost("exec-hostA", event)) if (expectFileLoss) { intercept[MetadataFetchFailedException] { - 
mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0) + mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0) } } else { - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) } } } @@ -1063,8 +1067,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) // The MapOutputTracker should know about both map output locations. - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === - HashSet("hostA", "hostB")) + assert(mapOutputTracker + .getMapSizesByShuffleLocation(shuffleId, 0) + .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(makeCompletionEvent( @@ -1193,10 +1199,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) // The MapOutputTracker should know about both map output locations. - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === - HashSet("hostA", "hostB")) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.host).toSet === - HashSet("hostA", "hostB")) + assert(mapOutputTracker + .getMapSizesByShuffleLocation(shuffleId, 0) + .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .toSet === HashSet("hostA", "hostB")) + assert(mapOutputTracker + .getMapSizesByShuffleLocation(shuffleId, 1) + .map(_._1.asInstanceOf[DefaultMapShuffleLocations].getBlockManagerId.host) + .toSet === HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. 
runEvent(makeCompletionEvent( @@ -1386,8 +1396,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi Success, makeMapStatus("hostA", reduceRdd.partitions.size))) assert(shuffleStage.numAvailableOutputs === 2) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostB"), makeShuffleLocation("hostA"))) // finish the next stage normally, which completes the job complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -1541,7 +1551,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi reduceIdx <- reduceIdxs } { // this would throw an exception if the map status hadn't been registered - val statuses = mapOutputTracker.getMapSizesByExecutorId(stage, reduceIdx) + val statuses = mapOutputTracker.getMapSizesByShuffleLocation(stage, reduceIdx) // really we should have already thrown an exception rather than fail either of these // asserts, but just to be extra defensive let's double check the statuses are OK assert(statuses != null) @@ -1593,7 +1603,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // check that we have all the map output for stage 0 (0 until reduceRdd.partitions.length).foreach { reduceIdx => - val statuses = mapOutputTracker.getMapSizesByExecutorId(0, reduceIdx) + val statuses = mapOutputTracker.getMapSizesByShuffleLocation(0, reduceIdx) // really we should have already thrown an exception rather than fail either of these // asserts, but just to be extra defensive let's double check the statuses are OK assert(statuses != null) @@ -1792,8 +1802,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) // Make sure that the reduce stage was now submitted. 
assert(taskSets.size === 3) @@ -2055,8 +2065,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) @@ -2101,8 +2111,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) - assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostA"))) // Reducer should run where RDD 2 has preferences, even though it also has a shuffle dep val reduceTaskSet = taskSets(1) @@ -2265,8 +2275,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) assert(listener1.results.size === 1) // When attempting the second stage, show a fetch failure @@ -2281,8 +2291,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) + assert(listener2.results.size === 0) // Second stage listener should still not have a result // Stage 1 should now be running as task set 3; make its first task succeed @@ -2290,8 +2301,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(3), Seq( (Success, makeMapStatus("hostB", rdd2.partitions.length)), (Success, makeMapStatus("hostD", rdd2.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep2.shuffleId, 0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostD"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep2.shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostB"), makeShuffleLocation("hostD"))) assert(listener2.results.size === 1) // Finally, the reduce job should be running as task set 4; make it see a fetch failure, @@ -2329,8 +2340,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", rdd1.partitions.length)), (Success, makeMapStatus("hostB", rdd1.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 
0).map(_._1).toSet === - HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === + HashSet(makeShuffleLocation("hostA"), makeShuffleLocation("hostB"))) assert(listener1.results.size === 1) // When attempting stage1, trigger a fetch failure. @@ -2355,8 +2366,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(2).stageId === 0) complete(taskSets(2), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)))) - assert(mapOutputTracker.getMapSizesByExecutorId(dep1.shuffleId, 0).map(_._1).toSet === - Set(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByShuffleLocation(dep1.shuffleId, 0).map(_._1).toSet === + Set(makeShuffleLocation("hostC"), makeShuffleLocation("hostB"))) // After stage0 is finished, stage1 will be submitted and found there is no missing // partitions in it. Then listener got triggered. @@ -2908,6 +2919,10 @@ object DAGSchedulerSuite { def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) + + def makeShuffleLocation(host: String): MapShuffleLocations = { + DefaultMapShuffleLocations.get(makeBlockManagerId(host)) + } } object FailThisAttempt { diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index c1e7fb9a1db16..3c786c0927bc6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.LocalSparkContext._ import org.apache.spark.internal.config import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.BlockManagerId class MapStatusSuite extends SparkFunSuite { @@ -61,7 +62,11 @@ class MapStatusSuite extends SparkFunSuite { stddev <- Seq(0.0, 0.01, 0.5, 1.0) ) { val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean) - val status = MapStatus(BlockManagerId("a", "b", 10), sizes) + val bmId = BlockManagerId("a", "b", 10) + val status = MapStatus( + bmId, + DefaultMapShuffleLocations.get(bmId), + sizes) val status1 = compressAndDecompressMapStatus(status) for (i <- 0 until numSizes) { if (sizes(i) != 0) { @@ -75,7 +80,7 @@ class MapStatusSuite extends SparkFunSuite { test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { val sizes = Array.fill[Long](2001)(150L) - val status = MapStatus(null, sizes) + val status = MapStatus(null, null, sizes) assert(status.isInstanceOf[HighlyCompressedMapStatus]) assert(status.getSizeForBlock(10) === 150L) assert(status.getSizeForBlock(50) === 150L) @@ -86,11 +91,13 @@ class MapStatusSuite extends SparkFunSuite { test("HighlyCompressedMapStatus: estimated size should be the average non-empty block size") { val sizes = Array.tabulate[Long](3000) { i => i.toLong } val avg = sizes.sum / sizes.count(_ != 0) - val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val bmId = BlockManagerId("a", "b", 10) + val loc = DefaultMapShuffleLocations.get(bmId) + val status = MapStatus(bmId, loc, sizes) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) - 
assert(status1.location == loc) + assert(status1.location == loc.getBlockManagerId) + assert(status1.mapShuffleLocations == loc) for (i <- 0 until 3000) { val estimate = status1.getSizeForBlock(i) if (sizes(i) > 0) { @@ -108,11 +115,13 @@ class MapStatusSuite extends SparkFunSuite { val sizes = (0L to 3000L).toArray val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) val avg = smallBlockSizes.sum / smallBlockSizes.length - val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val bmId = BlockManagerId("a", "b", 10) + val loc = DefaultMapShuffleLocations.get(bmId) + val status = MapStatus(bmId, loc, sizes) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) - assert(status1.location == loc) + assert(status1.location === bmId) + assert(status1.mapShuffleLocations === loc) for (i <- 0 until threshold) { val estimate = status1.getSizeForBlock(i) if (sizes(i) > 0) { @@ -165,7 +174,8 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) // Value of element in sizes is equal to the corresponding index. val sizes = (0L to 2000L).toArray - val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) + val bmId = BlockManagerId("exec-0", "host-0", 100) + val status1 = MapStatus(bmId, DefaultMapShuffleLocations.get(bmId), sizes) val arrayStream = new ByteArrayOutputStream(102400) val objectOutputStream = new ObjectOutputStream(arrayStream) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index aa6db8d0423a3..83305a96e6794 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -192,7 +192,8 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa shuffleId <- shuffleIds reduceIdx <- (0 until nParts) } { - val statuses = taskScheduler.mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceIdx) + val statuses = taskScheduler.mapOutputTracker.getMapSizesByShuffleLocation( + shuffleId, reduceIdx) // really we should have already thrown an exception rather than fail either of these // asserts, but just to be extra defensive let's double check the statuses are OK assert(statuses != null) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 16eec7e0bea1f..c523d0cb9ce80 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -36,8 +36,9 @@ import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Kryo._ import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.ThreadUtils class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer") @@ -350,8 +351,10 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val ser = new KryoSerializer(conf).newInstance() val denseBlockSizes = new 
Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) + val bmId = BlockManagerId("exec-1", "host", 1234) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + ser.serialize(HighlyCompressedMapStatus( + bmId, DefaultMapShuffleLocations.get(bmId), blockSizes)) } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 6d2ef17a7a790..b3073addb7ccc 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark._ import org.apache.spark.internal.config import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} /** @@ -102,14 +103,17 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn { + when(mapOutputTracker.getMapSizesByShuffleLocation( + shuffleId, reduceId, reduceId + 1)).thenReturn { // Test a scenario where all data is local, to avoid creating a bunch of additional mocks // for the code to read data over the network. val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) (shuffleBlockId, byteOutputStream.size().toLong) } - Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator + Seq( + (DefaultMapShuffleLocations.get(localBlockManagerId), shuffleBlockIdsAndSizes)) + .toIterator } // Create a mocked shuffle handle to pass into HashShuffleReader. 
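The MapStatus-related test updates above (MapStatusSuite and KryoSerializerSuite) all follow one pattern: build the status from a BlockManagerId plus its DefaultMapShuffleLocations wrapper, and expect both the block-manager view and the shuffle-locations view to be preserved, including after compression and serializer round trips. A condensed, hypothetical sketch of that pattern follows (placeholder ids and sizes, not part of the patch; it assumes Spark-internal test scope, since MapStatus and the BlockManagerId factory are not public API):

package org.apache.spark.scheduler

import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations
import org.apache.spark.storage.BlockManagerId

// Hypothetical example; names and values are illustrative only.
object MapStatusLocationsExample {
  def main(args: Array[String]): Unit = {
    val bmId = BlockManagerId("exec-1", "host-1", 7337)
    val locations = DefaultMapShuffleLocations.get(bmId)
    val sizes = Array.tabulate[Long](3000)(_.toLong)

    // The three-argument factory used throughout the updated suites:
    // block manager id, map shuffle locations, per-reduce block sizes.
    val status = MapStatus(bmId, locations, sizes)

    // The suites assert that both location views survive, even for the
    // highly compressed representation chosen for large size arrays.
    assert(status.location == bmId)
    assert(status.mapShuffleLocations == locations)
  }
}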
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala index 2690f1a515fcc..b39e37c1e3842 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BlockStoreShuffleReaderBenchmark.scala @@ -193,13 +193,13 @@ object BlockStoreShuffleReaderBenchmark extends BenchmarkBase { dataBlockId = remoteBlockManagerId } - when(mapOutputTracker.getMapSizesByExecutorId(0, 0, 1)) + when(mapOutputTracker.getMapSizesByShuffleLocation(0, 0, 1)) .thenReturn { val shuffleBlockIdsAndSizes = (0 until NUM_MAPS).map { mapId => val shuffleBlockId = ShuffleBlockId(0, mapId, 0) (shuffleBlockId, dataFileLength) } - Seq((dataBlockId, shuffleBlockIdsAndSizes)).toIterator + Seq((DefaultMapShuffleLocations.get(dataBlockId), shuffleBlockIdsAndSizes)).toIterator } when(dependency.serializer).thenReturn(serializer) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala index 69fe03e75606f..0b3394e88d9f1 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -20,6 +20,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.storage.BlockManagerId /** * Benchmark to measure performance for aggregate primitives. 
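The writer benchmark and suite updates in this part of the patch repeat one mechanical change, shown in the hunks below: DefaultShuffleWriteSupport is now constructed with the local BlockManagerId in addition to the conf and block resolver. A minimal sketch of the new shape, assuming the resolver is the IndexShuffleBlockResolver the benchmark fixtures already provide and using placeholder host/port values:

package org.apache.spark.shuffle.sort

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.IndexShuffleBlockResolver
import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport
import org.apache.spark.storage.BlockManagerId

// Hypothetical helper mirroring the benchmark setup; the resolver comes from the
// surrounding fixture and the BlockManagerId values are arbitrary placeholders.
object WriteSupportConstructionExample {
  def makeWriteSupport(
      conf: SparkConf,
      blockResolver: IndexShuffleBlockResolver): DefaultShuffleWriteSupport = {
    // The third argument is new in this patch: the block manager the map output is attributed to.
    new DefaultShuffleWriteSupport(conf, blockResolver, BlockManagerId("0", "localhost", 7090))
  }
}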
@@ -46,9 +47,10 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase def getWriter(transferTo: Boolean): BypassMergeSortShuffleWriter[String, String] = { val conf = new SparkConf(loadDefaults = false) + val shuffleWriteSupport = new DefaultShuffleWriteSupport( + conf, blockResolver, BlockManagerId("0", "localhost", 7090)) conf.set("spark.file.transferTo", String.valueOf(transferTo)) conf.set("spark.shuffle.file.buffer", "32k") - val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver) val shuffleWriter = new BypassMergeSortShuffleWriter[String, String]( blockManager, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 23f03957ae11a..538672e4bc738 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt, anyString} +import org.mockito.ArgumentMatchers.{any, anyInt} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer @@ -140,7 +140,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte metricsSystem = null, taskMetrics = taskMetrics)) - writeSupport = new DefaultShuffleWriteSupport(conf, blockResolver) + writeSupport = new DefaultShuffleWriteSupport( + conf, blockResolver, BlockManagerId("0", "localhost", 7090)) } override def afterEach(): Unit = { @@ -203,7 +204,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(taskMetrics.memoryBytesSpilled === 0) } - // TODO(ifilonenko): MAKE THIS PASS test("write with some empty partitions with transferTo") { def records: Iterator[(Int, Int)] = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala index 32257b0cc4b56..b0ff15cb1f790 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -23,6 +23,7 @@ import org.apache.spark.{Aggregator, SparkEnv, TaskContext} import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.BaseShuffleHandle import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.storage.BlockManagerId /** * Benchmark to measure performance for aggregate primitives. 
@@ -77,7 +78,10 @@ object SortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) TaskContext.setTaskContext(taskContext) - val writeSupport = new DefaultShuffleWriteSupport(defaultConf, blockResolver) + val writeSupport = new DefaultShuffleWriteSupport( + defaultConf, + blockResolver, + BlockManagerId("0", "localhost", 9099)) val shuffleWriter = new SortShuffleWriter[String, String, String]( blockResolver, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala index 20bf3eac95d84..0e659ff7cc5f3 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala @@ -19,6 +19,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport +import org.apache.spark.storage.BlockManagerId /** * Benchmark to measure performance for aggregate primitives. @@ -43,7 +44,8 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { def getWriter(transferTo: Boolean): UnsafeShuffleWriter[String, String] = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.file.transferTo", String.valueOf(transferTo)) - val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver) + val shuffleWriteSupport = new DefaultShuffleWriteSupport( + conf, blockResolver, BlockManagerId("0", "localhost", 9099)) TaskContext.setTaskContext(taskContext) new UnsafeShuffleWriter[String, String]( diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala index 22d52924a7c72..d704f72015ceb 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -87,7 +88,13 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft }).when(blockResolver) .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) mapOutputWriter = new DefaultShuffleMapOutputWriter( - 0, 0, NUM_PARTITIONS, shuffleWriteMetrics, blockResolver, conf) + 0, + 0, + NUM_PARTITIONS, + BlockManagerId("0", "localhost", 9099), + shuffleWriteMetrics, + blockResolver, + conf) } private def readRecordsFromFile(fromByte: Boolean): Array[Array[Int]] = {
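Finally, DefaultShuffleMapOutputWriterSuite constructs the writer with the same extra BlockManagerId argument. A sketch of the constructor shape the suite now uses, assuming the writer lives alongside DefaultShuffleWriteSupport in org.apache.spark.shuffle.sort.io and that the metrics and resolver arguments are the instances the suite builds in its setup:

package org.apache.spark.shuffle.sort.io

import org.apache.spark.SparkConf
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.shuffle.IndexShuffleBlockResolver
import org.apache.spark.storage.BlockManagerId

// Hypothetical helper mirroring the suite's setup; every collaborator is supplied by the
// caller, and the host/port in the BlockManagerId are arbitrary placeholders.
object MapOutputWriterConstructionExample {
  def makeWriter(
      numPartitions: Int,
      writeMetrics: ShuffleWriteMetrics,
      blockResolver: IndexShuffleBlockResolver,
      conf: SparkConf): DefaultShuffleMapOutputWriter = {
    new DefaultShuffleMapOutputWriter(
      0,                                      // leading id arguments, both 0 as in the suite
      0,
      numPartitions,
      BlockManagerId("0", "localhost", 9099), // new argument: the block manager id associated with this map output
      writeMetrics,
      blockResolver,
      conf)
  }
}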