diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java
index 1edf044225ccf..4fc20bad9938b 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java
@@ -27,7 +27,7 @@
*/
@Experimental
public interface ShuffleExecutorComponents {
- void intitializeExecutor(String appId, String execId);
+ void initializeExecutor(String appId, String execId);
ShuffleWriteSupport writes();
}
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java
index c043a6b3a4995..6a53803e5d117 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java
@@ -17,6 +17,7 @@
package org.apache.spark.api.shuffle;
+import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.Channels;
@@ -26,17 +27,48 @@
/**
* :: Experimental ::
- * An interface for giving streams / channels for shuffle writes
+ * An interface for providing streams or channels to which shuffle bytes can be written.
*
* @since 3.0.0
*/
@Experimental
-public interface ShufflePartitionWriter {
- OutputStream openStream() throws IOException;
+public interface ShufflePartitionWriter extends Closeable {
- long getLength();
+ /**
+ * Returns an {@link OutputStream} that can write bytes to the underlying data store.
+ *
+ * Note that the caller is not expected to close this stream; implementations should close it in
+ * their {@link #close()} method.
+ */
+ OutputStream toStream() throws IOException;
- default WritableByteChannel openChannel() throws IOException {
- return Channels.newChannel(openStream());
+ /**
+ * Returns a {@link WritableByteChannel} that can write bytes to the underlying data store.
+ *
+ * Note that the caller is not expected to close this channel; implementations should release it
+ * in their {@link #close()} method.
+ */
+ default WritableByteChannel toChannel() throws IOException {
+ return Channels.newChannel(toStream());
}
+
+ /**
+ * Returns the number of bytes written by this writer, whether through the stream returned by
+ * {@link #toStream()} or the channel returned by {@link #toChannel()}.
+ */
+ long getNumBytesWritten();
+
+ /**
+ * Closes all resources created by this ShufflePartitionWriter via calls to {@link #toStream()}
+ * or {@link #toChannel()}.
+ *
+ * This must always close any stream returned by {@link #toStream()}.
+ *
+ * Note that the default implementation of {@link #toChannel()} returns a
+ * {@link WritableByteChannel} that does not itself need to be closed; closing the underlying
+ * output stream from {@link #toStream()} is sufficient.
+ */
+ @Override
+ void close() throws IOException;
}
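
To make the lifecycle contract above concrete, here is a minimal sketch (not part of this patch) of a partition writer that appends to a local file: the caller only writes through the stream and then calls close(); the implementation owns the stream and closes it there. Guava's CountingOutputStream is used purely to track the byte count, and the default toChannel() wrapping still works because it delegates to toStream().

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;

import com.google.common.io.CountingOutputStream;

import org.apache.spark.api.shuffle.ShufflePartitionWriter;

// Hypothetical example implementation, shown only to illustrate the contract above.
class LocalFilePartitionWriter implements ShufflePartitionWriter {
  private final File target;
  private CountingOutputStream stream;

  LocalFilePartitionWriter(File target) {
    this.target = target;
  }

  @Override
  public OutputStream toStream() throws IOException {
    if (stream == null) {
      stream = new CountingOutputStream(new FileOutputStream(target));
    }
    return stream; // callers must not close this directly
  }

  @Override
  public long getNumBytesWritten() {
    return stream == null ? 0L : stream.getCount();
  }

  @Override
  public void close() throws IOException {
    if (stream != null) {
      stream.close(); // the implementation owns the stream's lifecycle
    }
  }
}
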
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java
index 5ba5564bb46d0..6c69d5db9fd06 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java
@@ -30,7 +30,6 @@
@Experimental
public interface ShuffleWriteSupport {
ShuffleMapOutputWriter createMapOutputWriter(
- String appId,
int shuffleId,
int mapId,
int numPartitions) throws IOException;
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 32b446785a9f0..aef133fe7d46a 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -19,8 +19,10 @@
import java.io.File;
import java.io.FileInputStream;
-import java.io.FileOutputStream;
import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
import javax.annotation.Nullable;
import scala.None$;
@@ -34,6 +36,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
+import org.apache.spark.api.shuffle.ShufflePartitionWriter;
+import org.apache.spark.api.shuffle.ShuffleWriteSupport;
import org.apache.spark.internal.config.package$;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
@@ -82,6 +87,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final int shuffleId;
private final int mapId;
private final Serializer serializer;
+ private final ShuffleWriteSupport shuffleWriteSupport;
private final IndexShuffleBlockResolver shuffleBlockResolver;
/** Array of file writers, one for each partition */
@@ -103,7 +109,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
BypassMergeSortShuffleHandle handle,
int mapId,
SparkConf conf,
- ShuffleWriteMetricsReporter writeMetrics) {
+ ShuffleWriteMetricsReporter writeMetrics,
+ ShuffleWriteSupport shuffleWriteSupport) {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -116,57 +123,61 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.writeMetrics = writeMetrics;
this.serializer = dep.serializer();
this.shuffleBlockResolver = shuffleBlockResolver;
+ this.shuffleWriteSupport = shuffleWriteSupport;
}
@Override
public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
- if (!records.hasNext()) {
- partitionLengths = new long[numPartitions];
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
- return;
- }
- final SerializerInstance serInstance = serializer.newInstance();
- final long openStartTime = System.nanoTime();
- partitionWriters = new DiskBlockObjectWriter[numPartitions];
- partitionWriterSegments = new FileSegment[numPartitions];
- for (int i = 0; i < numPartitions; i++) {
- final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
- blockManager.diskBlockManager().createTempShuffleBlock();
- final File file = tempShuffleBlockIdPlusFile._2();
- final BlockId blockId = tempShuffleBlockIdPlusFile._1();
- partitionWriters[i] =
- blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
- }
- // Creating the file to write to and creating a disk writer both involve interacting with
- // the disk, and can take a long time in aggregate when we open many files, so should be
- // included in the shuffle write time.
- writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
-
- while (records.hasNext()) {
- final Product2<K, V> record = records.next();
- final K key = record._1();
- partitionWriters[partitioner.getPartition(key)].write(key, record._2());
- }
+ ShuffleMapOutputWriter mapOutputWriter = shuffleWriteSupport
+ .createMapOutputWriter(shuffleId, mapId, numPartitions);
+ try {
+ if (!records.hasNext()) {
+ partitionLengths = new long[numPartitions];
+ mapOutputWriter.commitAllPartitions();
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ return;
+ }
+ final SerializerInstance serInstance = serializer.newInstance();
+ final long openStartTime = System.nanoTime();
+ partitionWriters = new DiskBlockObjectWriter[numPartitions];
+ partitionWriterSegments = new FileSegment[numPartitions];
+ for (int i = 0; i < numPartitions; i++) {
+ final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = tempShuffleBlockIdPlusFile._2();
+ final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+ partitionWriters[i] =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+ }
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, and can take a long time in aggregate when we open many files, so should be
+ // included in the shuffle write time.
+ writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
- for (int i = 0; i < numPartitions; i++) {
- try (DiskBlockObjectWriter writer = partitionWriters[i]) {
- partitionWriterSegments[i] = writer.commitAndGet();
+ while (records.hasNext()) {
+ final Product2<K, V> record = records.next();
+ final K key = record._1();
+ partitionWriters[partitioner.getPartition(key)].write(key, record._2());
}
- }
- File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
- File tmp = Utils.tempFileWith(output);
- try {
- partitionLengths = writePartitionedFile(tmp);
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
- } finally {
- if (tmp.exists() && !tmp.delete()) {
- logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
+ for (int i = 0; i < numPartitions; i++) {
+ try (DiskBlockObjectWriter writer = partitionWriters[i]) {
+ partitionWriterSegments[i] = writer.commitAndGet();
+ }
}
+
+ partitionLengths = writePartitionedData(mapOutputWriter);
+ mapOutputWriter.commitAllPartitions();
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ } catch (Exception e) {
+ try {
+ mapOutputWriter.abort(e);
+ } catch (Exception e2) {
+ logger.error("Failed to abort the writer after failing to write map output.", e2);
+ }
+ throw e;
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
@VisibleForTesting
@@ -179,37 +190,54 @@ long[] getPartitionLengths() {
*
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
*/
- private long[] writePartitionedFile(File outputFile) throws IOException {
+ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException {
// Track location of the partition starts in the output file
final long[] lengths = new long[numPartitions];
if (partitionWriters == null) {
// We were passed an empty iterator
return lengths;
}
-
- final FileOutputStream out = new FileOutputStream(outputFile, true);
final long writeStartTime = System.nanoTime();
- boolean threwException = true;
try {
for (int i = 0; i < numPartitions; i++) {
final File file = partitionWriterSegments[i].file();
- if (file.exists()) {
- final FileInputStream in = new FileInputStream(file);
- boolean copyThrewException = true;
- try {
- lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+ boolean copyThrewException = true;
+ ShufflePartitionWriter writer = null;
+ try {
+ writer = mapOutputWriter.getNextPartitionWriter();
+ if (!file.exists()) {
copyThrewException = false;
- } finally {
- Closeables.close(in, copyThrewException);
- }
- if (!file.delete()) {
- logger.error("Unable to delete file for partition {}", i);
+ } else {
+ if (transferToEnabled) {
+ WritableByteChannel outputChannel = writer.toChannel();
+ FileInputStream in = new FileInputStream(file);
+ try (FileChannel inputChannel = in.getChannel()) {
+ Utils.copyFileStreamNIO(inputChannel, outputChannel, 0, inputChannel.size());
+ copyThrewException = false;
+ } finally {
+ Closeables.close(in, copyThrewException);
+ }
+ } else {
+ OutputStream tempOutputStream = writer.toStream();
+ FileInputStream in = new FileInputStream(file);
+ try {
+ Utils.copyStream(in, tempOutputStream, false, false);
+ copyThrewException = false;
+ } finally {
+ Closeables.close(in, copyThrewException);
+ }
+ }
+ if (!file.delete()) {
+ logger.error("Unable to delete file for partition {}", i);
+ }
}
+ } finally {
+ Closeables.close(writer, copyThrewException);
}
+
+ lengths[i] = writer.getNumBytesWritten();
}
- threwException = false;
} finally {
- Closeables.close(out, threwException);
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
}
partitionWriters = null;
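
The restructured write() above follows a create / commit / abort protocol against the plugin API. A condensed, self-contained sketch of that protocol (hypothetical helper names; the real method also manages the per-partition disk writers and metrics):

import java.io.IOException;

import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShuffleWriteSupport;

final class CommitOrAbortSketch {
  interface WriteBody {
    void run(ShuffleMapOutputWriter writer) throws IOException;
  }

  // Mirrors the error handling in write(): commit on success, abort (best effort) on failure.
  static void writeWithCommitOrAbort(
      ShuffleWriteSupport writeSupport,
      int shuffleId,
      int mapId,
      int numPartitions,
      WriteBody body) throws IOException {
    ShuffleMapOutputWriter mapOutputWriter =
        writeSupport.createMapOutputWriter(shuffleId, mapId, numPartitions);
    try {
      body.run(mapOutputWriter);              // write every partition
      mapOutputWriter.commitAllPartitions();  // finalize the data and index files
    } catch (Exception e) {
      try {
        mapOutputWriter.abort(e);             // clean up any partial output
      } catch (Exception suppressed) {
        // Log-and-swallow in the real code; the original failure is what should propagate.
      }
      throw e;
    }
  }
}
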
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java
new file mode 100644
index 0000000000000..906600c0f15fc
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
+import org.apache.spark.api.shuffle.ShuffleDataIO;
+
+public class DefaultShuffleDataIO implements ShuffleDataIO {
+
+ private final SparkConf sparkConf;
+
+ public DefaultShuffleDataIO(SparkConf sparkConf) {
+ this.sparkConf = sparkConf;
+ }
+
+ @Override
+ public ShuffleExecutorComponents executor() {
+ return new DefaultShuffleExecutorComponents(sparkConf);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
new file mode 100644
index 0000000000000..76e87a6740259
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
+import org.apache.spark.api.shuffle.ShuffleWriteSupport;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.storage.BlockManager;
+
+public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents {
+
+ private final SparkConf sparkConf;
+ private BlockManager blockManager;
+ private IndexShuffleBlockResolver blockResolver;
+
+ public DefaultShuffleExecutorComponents(SparkConf sparkConf) {
+ this.sparkConf = sparkConf;
+ }
+
+ @Override
+ public void initializeExecutor(String appId, String execId) {
+ blockManager = SparkEnv.get().blockManager();
+ blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager);
+ }
+
+ @Override
+ public ShuffleWriteSupport writes() {
+ if (blockResolver == null) {
+ throw new IllegalStateException(
+ "Executor components must be initialized before getting writers.");
+ }
+ return new DefaultShuffleWriteSupport(sparkConf, blockResolver);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
new file mode 100644
index 0000000000000..0f7e5ed66bb76
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.FileChannel;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
+import org.apache.spark.api.shuffle.ShufflePartitionWriter;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.storage.TimeTrackingOutputStream;
+import org.apache.spark.util.Utils;
+
+public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter {
+
+ private static final Logger log =
+ LoggerFactory.getLogger(DefaultShuffleMapOutputWriter.class);
+
+ private final int shuffleId;
+ private final int mapId;
+ private final ShuffleWriteMetricsReporter metrics;
+ private final IndexShuffleBlockResolver blockResolver;
+ private final long[] partitionLengths;
+ private final int bufferSize;
+ private int currPartitionId = 0;
+ private long currChannelPosition;
+
+ private final File outputFile;
+ private File outputTempFile;
+ private FileOutputStream outputFileStream;
+ private FileChannel outputFileChannel;
+ private TimeTrackingOutputStream ts;
+ private BufferedOutputStream outputBufferedFileStream;
+
+ public DefaultShuffleMapOutputWriter(
+ int shuffleId,
+ int mapId,
+ int numPartitions,
+ ShuffleWriteMetricsReporter metrics,
+ IndexShuffleBlockResolver blockResolver,
+ SparkConf sparkConf) {
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ this.metrics = metrics;
+ this.blockResolver = blockResolver;
+ this.bufferSize =
+ (int) (long) sparkConf.get(
+ package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024;
+ this.partitionLengths = new long[numPartitions];
+ this.outputFile = blockResolver.getDataFile(shuffleId, mapId);
+ this.outputTempFile = null;
+ }
+
+ @Override
+ public ShufflePartitionWriter getNextPartitionWriter() throws IOException {
+ if (outputTempFile == null) {
+ outputTempFile = Utils.tempFileWith(outputFile);
+ }
+ if (outputFileChannel != null) {
+ currChannelPosition = outputFileChannel.position();
+ } else {
+ currChannelPosition = 0L;
+ }
+ return new DefaultShufflePartitionWriter(currPartitionId++);
+ }
+
+ @Override
+ public void commitAllPartitions() throws IOException {
+ cleanUp();
+ blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, outputTempFile);
+ }
+
+ @Override
+ public void abort(Throwable error) {
+ try {
+ cleanUp();
+ } catch (Exception e) {
+ log.error("Unable to close appropriate underlying file stream", e);
+ }
+ if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) {
+ log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath());
+ }
+ }
+
+ private void cleanUp() throws IOException {
+ if (outputBufferedFileStream != null) {
+ outputBufferedFileStream.close();
+ }
+
+ if (outputFileChannel != null) {
+ outputFileChannel.close();
+ }
+
+ if (outputFileStream != null) {
+ outputFileStream.close();
+ }
+ }
+
+ private void initStream() throws IOException {
+ if (outputFileStream == null) {
+ outputFileStream = new FileOutputStream(outputTempFile, true);
+ ts = new TimeTrackingOutputStream(metrics, outputFileStream);
+ }
+ if (outputBufferedFileStream == null) {
+ outputBufferedFileStream = new BufferedOutputStream(ts, bufferSize);
+ }
+ }
+
+ private void initChannel() throws IOException {
+ if (outputFileStream == null) {
+ outputFileStream = new FileOutputStream(outputTempFile, true);
+ }
+ if (outputFileChannel == null) {
+ outputFileChannel = outputFileStream.getChannel();
+ }
+ }
+
+ private class DefaultShufflePartitionWriter implements ShufflePartitionWriter {
+
+ private final int partitionId;
+ private PartitionWriterStream stream = null;
+
+ private DefaultShufflePartitionWriter(int partitionId) {
+ this.partitionId = partitionId;
+ }
+
+ @Override
+ public OutputStream toStream() throws IOException {
+ if (outputFileChannel != null) {
+ throw new IllegalStateException("Requested an output channel for a previous write but" +
+ " now an output stream has been requested. Should not be using both channels" +
+ " and streams to write.");
+ }
+ initStream();
+ stream = new PartitionWriterStream();
+ return stream;
+ }
+
+ @Override
+ public FileChannel toChannel() throws IOException {
+ if (stream != null) {
+ throw new IllegalStateException("Requested an output stream for a previous write but" +
+ " now an output channel has been requested. Should not be using both channels" +
+ " and streams to write.");
+ }
+ initChannel();
+ return outputFileChannel;
+ }
+
+ @Override
+ public long getNumBytesWritten() {
+ if (outputFileChannel != null && stream == null) {
+ try {
+ long newPosition = outputFileChannel.position();
+ return newPosition - currChannelPosition;
+ } catch (Exception e) {
+ log.error("The partition which failed is: {}", partitionId, e);
+ throw new IllegalStateException("Failed to calculate position of file channel", e);
+ }
+ } else if (stream != null) {
+ return stream.getCount();
+ } else {
+ // Assume an empty partition if stream and channel are never created
+ return 0;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (stream != null) {
+ stream.close();
+ }
+ partitionLengths[partitionId] = getNumBytesWritten();
+ }
+ }
+
+ private class PartitionWriterStream extends OutputStream {
+ private int count = 0;
+ private boolean isClosed = false;
+
+ public int getCount() {
+ return count;
+ }
+
+ @Override
+ public void write(int b) throws IOException {
+ verifyNotClosed();
+ outputBufferedFileStream.write(b);
+ count++;
+ }
+
+ @Override
+ public void write(byte[] buf, int pos, int length) throws IOException {
+ verifyNotClosed();
+ outputBufferedFileStream.write(buf, pos, length);
+ count += length;
+ }
+
+ @Override
+ public void close() throws IOException {
+ flush();
+ isClosed = true;
+ }
+
+ @Override
+ public void flush() throws IOException {
+ if (!isClosed) {
+ outputBufferedFileStream.flush();
+ }
+ }
+
+ private void verifyNotClosed() {
+ if (isClosed) {
+ throw new IllegalStateException("Attempting to write to a closed block output stream.");
+ }
+ }
+ }
+}
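
For reference, the calling sequence this class expects (and which BypassMergeSortShuffleWriter follows above) is: one getNextPartitionWriter() per partition in partition order, writing through either the stream or the channel but never both, closing each partition writer, and finally commitAllPartitions(). A minimal usage sketch with hypothetical in-memory partition data:

import java.io.IOException;
import java.io.OutputStream;

import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;

final class MapOutputWriterUsageSketch {
  // Writes one byte[] per reduce partition and returns the recorded lengths.
  static long[] writeAllPartitions(ShuffleMapOutputWriter mapOutputWriter, byte[][] partitions)
      throws IOException {
    long[] lengths = new long[partitions.length];
    for (int i = 0; i < partitions.length; i++) {
      ShufflePartitionWriter writer = mapOutputWriter.getNextPartitionWriter();
      try {
        OutputStream out = writer.toStream();   // use either toStream() or toChannel(), not both
        out.write(partitions[i]);
      } finally {
        writer.close();                         // closes the stream it handed out
      }
      lengths[i] = writer.getNumBytesWritten(); // read after close(), as the bypass writer does
    }
    mapOutputWriter.commitAllPartitions();      // writes the index file and commits the data file
    return lengths;
  }
}
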
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java
new file mode 100644
index 0000000000000..f8fadd0ecfa63
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
+import org.apache.spark.api.shuffle.ShuffleWriteSupport;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+
+public class DefaultShuffleWriteSupport implements ShuffleWriteSupport {
+
+ private final SparkConf sparkConf;
+ private final IndexShuffleBlockResolver blockResolver;
+
+ public DefaultShuffleWriteSupport(
+ SparkConf sparkConf,
+ IndexShuffleBlockResolver blockResolver) {
+ this.sparkConf = sparkConf;
+ this.blockResolver = blockResolver;
+ }
+
+ @Override
+ public ShuffleMapOutputWriter createMapOutputWriter(
+ int shuffleId,
+ int mapId,
+ int numPartitions) {
+ return new DefaultShuffleMapOutputWriter(
+ shuffleId, mapId, numPartitions,
+ TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf);
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index fe3b5d98969e1..a852a06be9125 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode}
+import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO
import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy}
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.util.Utils
@@ -768,6 +769,12 @@ package object config {
.booleanConf
.createWithDefault(false)
+ private[spark] val SHUFFLE_IO_PLUGIN_CLASS =
+ ConfigBuilder("spark.shuffle.io.plugin.class")
+ .doc("Name of the class to use for shuffle IO.")
+ .stringConf
+ .createWithDefault(classOf[DefaultShuffleDataIO].getName)
+
private[spark] val SHUFFLE_FILE_BUFFER_SIZE =
ConfigBuilder("spark.shuffle.file.buffer")
.doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " +
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index b59fa8e8a3ccd..5da7b5cb35e6d 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -20,8 +20,10 @@ package org.apache.spark.shuffle.sort
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark._
-import org.apache.spark.internal.Logging
+import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleExecutorComponents}
+import org.apache.spark.internal.{config, Logging}
import org.apache.spark.shuffle._
+import org.apache.spark.util.Utils
/**
* In sort-based shuffle, incoming records are sorted according to their target partition ids, then
@@ -68,6 +70,8 @@ import org.apache.spark.shuffle._
*/
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+ import SortShuffleManager._
+
if (!conf.getBoolean("spark.shuffle.spill", true)) {
logWarning(
"spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." +
@@ -79,6 +83,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
*/
private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]()
+ private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
+
override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
/**
@@ -148,7 +154,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
bypassMergeSortHandle,
mapId,
env.conf,
- metrics)
+ metrics,
+ shuffleExecutorComponents.writes())
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
}
@@ -205,6 +212,16 @@ private[spark] object SortShuffleManager extends Logging {
true
}
}
+
+ private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
+ val configuredPluginClasses = conf.get(config.SHUFFLE_IO_PLUGIN_CLASS)
+ val maybeIO = Utils.loadExtensions(
+ classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
+ require(maybeIO.size == 1, s"Failed to load exactly one ShuffleDataIO plugin for $configuredPluginClasses")
+ val executorComponents = maybeIO.head.executor()
+ executorComponents.initializeExecutor(conf.getAppId, SparkEnv.get.executorId)
+ executorComponents
+ }
}
/**
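
loadShuffleExecutorComponents relies on Utils.loadExtensions, which instantiates the configured class reflectively (preferring a single SparkConf constructor when one exists). A third-party plugin therefore only needs to ship a ShuffleDataIO implementation on the executor classpath. A minimal skeleton, assuming the interface declares only the executor() method exercised by this patch, might look like:

import org.apache.spark.SparkConf;
import org.apache.spark.api.shuffle.ShuffleDataIO;
import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
import org.apache.spark.shuffle.sort.io.DefaultShuffleExecutorComponents;

// Hypothetical third-party plugin. A real plugin would return its own executor
// components; delegating to the defaults here just keeps the sketch runnable.
public class MyShuffleDataIO implements ShuffleDataIO {

  private final SparkConf sparkConf;

  public MyShuffleDataIO(SparkConf sparkConf) {
    this.sparkConf = sparkConf;
  }

  @Override
  public ShuffleExecutorComponents executor() {
    return new DefaultShuffleExecutorComponents(sparkConf);
  }
}
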
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 37b21e2eddba6..a9f3c73e90ab8 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -24,7 +24,7 @@ import java.lang.reflect.InvocationTargetException
import java.math.{MathContext, RoundingMode}
import java.net._
import java.nio.ByteBuffer
-import java.nio.channels.{Channels, FileChannel}
+import java.nio.channels.{Channels, FileChannel, WritableByteChannel}
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.security.SecureRandom
@@ -337,10 +337,14 @@ private[spark] object Utils extends Logging {
def copyFileStreamNIO(
input: FileChannel,
- output: FileChannel,
+ output: WritableByteChannel,
startPosition: Long,
bytesToCopy: Long): Unit = {
- val initialPos = output.position()
+ val outputInitialState = output match {
+ case outputFileChannel: FileChannel =>
+ Some((outputFileChannel.position(), outputFileChannel))
+ case _ => None
+ }
var count = 0L
// In case transferTo method transferred less data than we have required.
while (count < bytesToCopy) {
@@ -355,15 +359,17 @@ private[spark] object Utils extends Logging {
// kernel version 2.6.32, this issue can be seen in
// https://bugs.openjdk.java.net/browse/JDK-7052359
// This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
- val finalPos = output.position()
- val expectedPos = initialPos + bytesToCopy
- assert(finalPos == expectedPos,
- s"""
- |Current position $finalPos do not equal to expected position $expectedPos
- |after transferTo, please check your kernel version to see if it is 2.6.32,
- |this is a kernel bug which will lead to unexpected behavior when using transferTo.
- |You can set spark.file.transferTo = false to disable this NIO feature.
- """.stripMargin)
+ outputInitialState.foreach { case (initialPos, outputFileChannel) =>
+ val finalPos = outputFileChannel.position()
+ val expectedPos = initialPos + bytesToCopy
+ assert(finalPos == expectedPos,
+ s"""
+ |Current position $finalPos does not equal the expected position $expectedPos
+ |after transferTo, please check your kernel version to see if it is 2.6.32,
+ |this is a kernel bug which will lead to unexpected behavior when using transferTo.
+ |You can set spark.file.transferTo = false to disable this NIO feature.
+ """.stripMargin)
+ }
}
/**
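
Because the destination is now only a WritableByteChannel, the post-copy position check can be made solely when the destination happens to be a FileChannel; a channel wrapping an arbitrary OutputStream exposes no position() at all. A self-contained sketch (hypothetical names) of the same transferTo loop against such a destination:

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;

final class TransferToSketch {
  // Copies a file into an in-memory sink through a generic WritableByteChannel,
  // looping because transferTo may copy fewer bytes than requested.
  static byte[] copyToBytes(File source) throws IOException {
    ByteArrayOutputStream sink = new ByteArrayOutputStream();
    WritableByteChannel out = Channels.newChannel(sink); // no position() to assert on
    try (FileInputStream in = new FileInputStream(source);
         FileChannel inChannel = in.getChannel()) {
      long size = inChannel.size();
      long copied = 0L;
      while (copied < size) {
        copied += inChannel.transferTo(copied, size - copied, out);
      }
    } finally {
      out.close();
    }
    return sink.toByteArray();
  }
}
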
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 8b1084a8edc76..90c790cefcca2 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListene
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.shuffle.ShuffleWriter
import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId}
-import org.apache.spark.util.{MutablePair, Utils}
+import org.apache.spark.util.MutablePair
abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext {
@@ -368,7 +368,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)
val writer1 = manager.getWriter[Int, Int](
shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics)
- val data1 = (1 to 10).map { x => x -> x}
+ val data1 = (1 to 10).map { x => x -> x }
// second attempt -- also successful. We'll write out different data,
// just to simulate the fact that the records may get written differently
@@ -383,13 +383,17 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// simultaneously, and everything is still OK
def writeAndClose(
- writer: ShuffleWriter[Int, Int])(
+ writer: ShuffleWriter[Int, Int],
+ taskContext: TaskContext)(
iter: Iterator[(Int, Int)]): Option[MapStatus] = {
+ TaskContext.setTaskContext(taskContext)
val files = writer.write(iter)
- writer.stop(true)
+ val status = writer.stop(true)
+ TaskContext.unset()
+ status
}
val interleaver = new InterleaveIterators(
- data1, writeAndClose(writer1), data2, writeAndClose(writer2))
+ data1, writeAndClose(writer1, context1), data2, writeAndClose(writer2, context2))
val (mapOutput1, mapOutput2) = interleaver.run()
// check that we can read the map output and it has the right data
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala
index 7f67affe56364..7eb867fc29fd2 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala
@@ -19,7 +19,7 @@ package org.apache.spark.shuffle.sort
import org.apache.spark.SparkConf
import org.apache.spark.benchmark.Benchmark
-import org.apache.spark.util.Utils
+import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport
/**
* Benchmark to measure performance for aggregate primitives.
@@ -46,6 +46,7 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase
def getWriter(transferTo: Boolean): BypassMergeSortShuffleWriter[String, String] = {
val conf = new SparkConf(loadDefaults = false)
+ val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver)
conf.set("spark.file.transferTo", String.valueOf(transferTo))
conf.set("spark.shuffle.file.buffer", "32k")
@@ -55,7 +56,8 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase
shuffleHandle,
0,
conf,
- taskContext.taskMetrics().shuffleWriteMetrics
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ shuffleWriteSupport
)
shuffleWriter
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 7f956c26d0ff0..23f03957ae11a 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -18,23 +18,27 @@
package org.apache.spark.shuffle.sort
import java.io.File
-import java.util.UUID
+import java.util.{Properties, UUID}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.mockito.{Mock, MockitoAnnotations}
import org.mockito.Answers.RETURNS_SMART_NULLS
-import org.mockito.ArgumentMatchers.{any, anyInt}
+import org.mockito.ArgumentMatchers.{any, anyInt, anyString}
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfterEach
+import scala.util.Random
import org.apache.spark._
+import org.apache.spark.api.shuffle.ShuffleWriteSupport
import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport
import org.apache.spark.storage._
import org.apache.spark.util.Utils
@@ -49,7 +53,9 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
private var taskMetrics: TaskMetrics = _
private var tempDir: File = _
private var outputFile: File = _
+ private var writeSupport: ShuffleWriteSupport = _
private val conf: SparkConf = new SparkConf(loadDefaults = false)
+ .set("spark.app.id", "sampleApp")
private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _
@@ -118,9 +124,27 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get
}
})
+
+ val memoryManager = new TestMemoryManager(conf)
+ val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
+ when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
+
+ TaskContext.setTaskContext(new TaskContextImpl(
+ stageId = 0,
+ stageAttemptNumber = 0,
+ partitionId = 0,
+ taskAttemptId = Random.nextInt(10000),
+ attemptNumber = 0,
+ taskMemoryManager = taskMemoryManager,
+ localProperties = new Properties,
+ metricsSystem = null,
+ taskMetrics = taskMetrics))
+
+ writeSupport = new DefaultShuffleWriteSupport(conf, blockResolver)
}
override def afterEach(): Unit = {
+ TaskContext.unset()
try {
Utils.deleteRecursively(tempDir)
blockIdToFileMap.clear()
@@ -137,7 +161,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
shuffleHandle,
0, // MapId
conf,
- taskContext.taskMetrics().shuffleWriteMetrics
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ writeSupport
)
writer.write(Iterator.empty)
writer.stop( /* success = */ true)
@@ -153,6 +178,33 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
}
test("write with some empty partitions") {
+ val transferConf = conf.clone.set("spark.file.transferTo", "false")
+ def records: Iterator[(Int, Int)] =
+ Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
+ val writer = new BypassMergeSortShuffleWriter[Int, Int](
+ blockManager,
+ blockResolver,
+ shuffleHandle,
+ 0, // MapId
+ transferConf,
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ writeSupport
+ )
+ writer.write(records)
+ writer.stop( /* success = */ true)
+ assert(temporaryFilesCreated.nonEmpty)
+ assert(writer.getPartitionLengths.sum === outputFile.length())
+ assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files
+ assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
+ val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics
+ assert(shuffleWriteMetrics.bytesWritten === outputFile.length())
+ assert(shuffleWriteMetrics.recordsWritten === records.length)
+ assert(taskMetrics.diskBytesSpilled === 0)
+ assert(taskMetrics.memoryBytesSpilled === 0)
+ }
+
+ // TODO(ifilonenko): MAKE THIS PASS
+ test("write with some empty partitions with transferTo") {
def records: Iterator[(Int, Int)] =
Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
val writer = new BypassMergeSortShuffleWriter[Int, Int](
@@ -161,7 +213,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
shuffleHandle,
0, // MapId
conf,
- taskContext.taskMetrics().shuffleWriteMetrics
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ writeSupport
)
writer.write(records)
writer.stop( /* success = */ true)
@@ -196,7 +249,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
shuffleHandle,
0, // MapId
conf,
- taskContext.taskMetrics().shuffleWriteMetrics
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ writeSupport
)
intercept[SparkException] {
@@ -218,7 +272,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
shuffleHandle,
0, // MapId
conf,
- taskContext.taskMetrics().shuffleWriteMetrics
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ writeSupport
)
intercept[SparkException] {
writer.write((0 until 100000).iterator.map(i => {
@@ -232,5 +287,4 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
writer.stop( /* success = */ false)
assert(temporaryFilesCreated.count(_.exists()) === 0)
}
-
}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala
index 62cc13fa107f0..ce1abde421fca 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala
@@ -22,7 +22,7 @@ import org.mockito.Mockito.when
import org.apache.spark.{Aggregator, SparkEnv}
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.shuffle.BaseShuffleHandle
-import org.apache.spark.util.Utils
+import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO
/**
* Benchmark to measure performance for aggregate primitives.
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala
new file mode 100644
index 0000000000000..22d52924a7c72
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io
+
+import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
+import java.math.BigInteger
+import java.nio.ByteBuffer
+
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.{any, anyInt, anyLong}
+import org.mockito.Mock
+import org.mockito.Mockito.{doAnswer, doNothing, when}
+import org.mockito.MockitoAnnotations
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.network.util.LimitedInputStream
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.util.Utils
+
+class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+ @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var shuffleWriteMetrics: ShuffleWriteMetrics = _
+
+ private val NUM_PARTITIONS = 4
+ private val D_LEN = 10
+ private val data: Array[Array[Int]] = (0 until NUM_PARTITIONS).map {
+ p => (1 to D_LEN).map(_ + p).toArray }.toArray
+
+ private var tempFile: File = _
+ private var mergedOutputFile: File = _
+ private var tempDir: File = _
+ private var partitionSizesInMergedFile: Array[Long] = _
+ private var conf: SparkConf = _
+ private var mapOutputWriter: DefaultShuffleMapOutputWriter = _
+
+ override def afterEach(): Unit = {
+ try {
+ Utils.deleteRecursively(tempDir)
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ override def beforeEach(): Unit = {
+ MockitoAnnotations.initMocks(this)
+ tempDir = Utils.createTempDir(null, "test")
+ mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir)
+ tempFile = File.createTempFile("tempfile", "", tempDir)
+ partitionSizesInMergedFile = null
+ conf = new SparkConf()
+ .set("spark.app.id", "example.spark.app")
+ .set("spark.shuffle.unsafe.file.output.buffer", "16k")
+ when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile)
+
+ doNothing().when(shuffleWriteMetrics).incWriteTime(anyLong)
+
+ doAnswer(new Answer[Void] {
+ def answer(invocationOnMock: InvocationOnMock): Void = {
+ partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]]
+ val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ if (tmp != null) {
+ mergedOutputFile.delete()
+ tmp.renameTo(mergedOutputFile)
+ }
+ null
+ }
+ }).when(blockResolver)
+ .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))
+ mapOutputWriter = new DefaultShuffleMapOutputWriter(
+ 0, 0, NUM_PARTITIONS, shuffleWriteMetrics, blockResolver, conf)
+ }
+
+ private def readRecordsFromFile(fromByte: Boolean): Array[Array[Int]] = {
+ var startOffset = 0L
+ val result = new Array[Array[Int]](NUM_PARTITIONS)
+ (0 until NUM_PARTITIONS).foreach { p =>
+ val partitionSize = partitionSizesInMergedFile(p).toInt
+ lazy val inner = new Array[Int](partitionSize)
+ lazy val innerBytebuffer = ByteBuffer.allocate(partitionSize)
+ if (partitionSize > 0) {
+ val in = new FileInputStream(mergedOutputFile)
+ in.getChannel.position(startOffset)
+ val lin = new LimitedInputStream(in, partitionSize)
+ var nonEmpty = true
+ var count = 0
+ while (nonEmpty) {
+ try {
+ val readBit = lin.read()
+ if (fromByte) {
+ innerBytebuffer.put(readBit.toByte)
+ } else {
+ inner(count) = readBit
+ }
+ count += 1
+ } catch {
+ case _: Exception =>
+ nonEmpty = false
+ }
+ }
+ in.close()
+ }
+ if (fromByte) {
+ result(p) = innerBytebuffer.array().sliding(4, 4).map { b =>
+ new BigInteger(b).intValue()
+ }.toArray
+ } else {
+ result(p) = inner
+ }
+ startOffset += partitionSize
+ }
+ result
+ }
+
+ test("writing to an outputstream") {
+ (0 until NUM_PARTITIONS).foreach { p =>
+ val writer = mapOutputWriter.getNextPartitionWriter
+ val stream = writer.toStream()
+ data(p).foreach { i => stream.write(i) }
+ stream.close()
+ intercept[IllegalStateException] {
+ stream.write(p)
+ }
+ assert(writer.getNumBytesWritten() == D_LEN)
+ writer.close()
+ }
+ mapOutputWriter.commitAllPartitions()
+ val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toLong }.toArray
+ assert(partitionSizesInMergedFile === partitionLengths)
+ assert(mergedOutputFile.length() === partitionLengths.sum)
+ assert(data === readRecordsFromFile(false))
+ }
+
+ test("writing to a channel") {
+ (0 until NUM_PARTITIONS).foreach { p =>
+ val writer = mapOutputWriter.getNextPartitionWriter
+ val channel = writer.toChannel()
+ val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
+ val intBuffer = byteBuffer.asIntBuffer()
+ intBuffer.put(data(p))
+ assert(channel.isOpen)
+ channel.write(byteBuffer)
+ // Bytes require * 4
+ assert(writer.getNumBytesWritten == D_LEN * 4)
+ writer.close()
+ }
+ mapOutputWriter.commitAllPartitions()
+ val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toLong }.toArray
+ assert(partitionSizesInMergedFile === partitionLengths)
+ assert(mergedOutputFile.length() === partitionLengths.sum)
+ assert(data === readRecordsFromFile(true))
+ }
+
+ test("copyStreams with an outputstream") {
+ (0 until NUM_PARTITIONS).foreach { p =>
+ val writer = mapOutputWriter.getNextPartitionWriter
+ val stream = writer.toStream()
+ val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
+ val intBuffer = byteBuffer.asIntBuffer()
+ intBuffer.put(data(p))
+ val in = new ByteArrayInputStream(byteBuffer.array())
+ Utils.copyStream(in, stream, false, false)
+ in.close()
+ stream.close()
+ assert(writer.getNumBytesWritten == D_LEN * 4)
+ writer.close()
+ }
+ mapOutputWriter.commitAllPartitions()
+ val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toLong }.toArray
+ assert(partitionSizesInMergedFile === partitionLengths)
+ assert(mergedOutputFile.length() === partitionLengths.sum)
+ assert(data === readRecordsFromFile(true))
+ }
+
+ test("copyStreamsWithNIO with a channel") {
+ (0 until NUM_PARTITIONS).foreach { p =>
+ val writer = mapOutputWriter.getNextPartitionWriter
+ val channel = writer.toChannel()
+ val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
+ val intBuffer = byteBuffer.asIntBuffer()
+ intBuffer.put(data(p))
+ val out = new FileOutputStream(tempFile)
+ out.write(byteBuffer.array())
+ out.close()
+ val in = new FileInputStream(tempFile)
+ Utils.copyFileStreamNIO(in.getChannel, channel, 0, D_LEN * 4)
+ in.close()
+ assert(writer.getNumBytesWritten == D_LEN * 4)
+ writer.close()
+ }
+ mapOutputWriter.commitAllPartitions()
+ val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toLong }.toArray
+ assert(partitionSizesInMergedFile === partitionLengths)
+ assert(mergedOutputFile.length() === partitionLengths.sum)
+ assert(data === readRecordsFromFile(true))
+ }
+}