diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java index 1edf044225ccf..4fc20bad9938b 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleExecutorComponents.java @@ -27,7 +27,7 @@ */ @Experimental public interface ShuffleExecutorComponents { - void intitializeExecutor(String appId, String execId); + void initializeExecutor(String appId, String execId); ShuffleWriteSupport writes(); } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java index c043a6b3a4995..6a53803e5d117 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java @@ -17,6 +17,7 @@ package org.apache.spark.api.shuffle; +import java.io.Closeable; import java.io.IOException; import java.io.OutputStream; import java.nio.channels.Channels; @@ -26,17 +27,48 @@ /** * :: Experimental :: - * An interface for giving streams / channels for shuffle writes + * An interface for giving streams / channels for shuffle writes. * * @since 3.0.0 */ @Experimental -public interface ShufflePartitionWriter { - OutputStream openStream() throws IOException; +public interface ShufflePartitionWriter extends Closeable { - long getLength(); + /** + * Returns an underlying {@link OutputStream} that can write bytes to the underlying data store. + *
+ * Note that this stream itself is not closed by the caller; close the stream in the + * implementation of this interface's {@link #close()}. + */ + OutputStream toStream() throws IOException; - default WritableByteChannel openChannel() throws IOException { - return Channels.newChannel(openStream()); + /** + * Returns an underlying {@link WritableByteChannel} that can write bytes to the underlying data + * store. + *
+ * Note that this channel itself is not closed by the caller; close the channel in the + * implementation of this interface's {@link #close()}. + */ + default WritableByteChannel toChannel() throws IOException { + return Channels.newChannel(toStream()); } + + /** + * Get the number of bytes written by this writer's stream returned by {@link #toStream()} or + * the channel returned by {@link #toChannel()}. + */ + long getNumBytesWritten(); + + /** + * Close all resources created by this ShufflePartitionWriter, via calls to {@link #toStream()} + * or {@link #toChannel()}. + *
+ * This must always close any stream returned by {@link #toStream()}. + *
+ * Note that the default version of {@link #toChannel()} returns a {@link WritableByteChannel} + * that does not itself need to be closed up front; only the underlying output stream given by + * {@link #toStream()} must be closed. + */ + @Override + void close() throws IOException; } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java index 5ba5564bb46d0..6c69d5db9fd06 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java @@ -30,7 +30,6 @@ @Experimental public interface ShuffleWriteSupport { ShuffleMapOutputWriter createMapOutputWriter( - String appId, int shuffleId, int mapId, int numPartitions) throws IOException; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 32b446785a9f0..aef133fe7d46a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -19,8 +19,10 @@ import java.io.File; import java.io.FileInputStream; -import java.io.FileOutputStream; import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; import javax.annotation.Nullable; import scala.None$; @@ -34,6 +36,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; +import org.apache.spark.api.shuffle.ShufflePartitionWriter; +import org.apache.spark.api.shuffle.ShuffleWriteSupport; import org.apache.spark.internal.config.package$; import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; @@ -82,6 +87,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final int shuffleId; private final int mapId; private final Serializer serializer; + private final ShuffleWriteSupport shuffleWriteSupport; private final IndexShuffleBlockResolver shuffleBlockResolver; /** Array of file writers, one for each partition */ @@ -103,7 +109,8 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { BypassMergeSortShuffleHandle handle, int mapId, SparkConf conf, - ShuffleWriteMetricsReporter writeMetrics) { + ShuffleWriteMetricsReporter writeMetrics, + ShuffleWriteSupport shuffleWriteSupport) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); @@ -116,57 +123,61 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); this.shuffleBlockResolver = shuffleBlockResolver; + this.shuffleWriteSupport = shuffleWriteSupport; } @Override public void write(Iterator> records) throws IOException { assert (partitionWriters == null); - if (!records.hasNext()) { - partitionLengths = new long[numPartitions]; - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); - return; - } - final SerializerInstance serInstance = serializer.newInstance(); - final long openStartTime = 
System.nanoTime(); - partitionWriters = new DiskBlockObjectWriter[numPartitions]; - partitionWriterSegments = new FileSegment[numPartitions]; - for (int i = 0; i < numPartitions; i++) { - final Tuple2 tempShuffleBlockIdPlusFile = - blockManager.diskBlockManager().createTempShuffleBlock(); - final File file = tempShuffleBlockIdPlusFile._2(); - final BlockId blockId = tempShuffleBlockIdPlusFile._1(); - partitionWriters[i] = - blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); - } - // Creating the file to write to and creating a disk writer both involve interacting with - // the disk, and can take a long time in aggregate when we open many files, so should be - // included in the shuffle write time. - writeMetrics.incWriteTime(System.nanoTime() - openStartTime); - - while (records.hasNext()) { - final Product2 record = records.next(); - final K key = record._1(); - partitionWriters[partitioner.getPartition(key)].write(key, record._2()); - } + ShuffleMapOutputWriter mapOutputWriter = shuffleWriteSupport + .createMapOutputWriter(shuffleId, mapId, numPartitions); + try { + if (!records.hasNext()) { + partitionLengths = new long[numPartitions]; + mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new DiskBlockObjectWriter[numPartitions]; + partitionWriterSegments = new FileSegment[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2 tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. 
+ writeMetrics.incWriteTime(System.nanoTime() - openStartTime); - for (int i = 0; i < numPartitions; i++) { - try (DiskBlockObjectWriter writer = partitionWriters[i]) { - partitionWriterSegments[i] = writer.commitAndGet(); + while (records.hasNext()) { + final Product2 record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); } - } - File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); - File tmp = Utils.tempFileWith(output); - try { - partitionLengths = writePartitionedFile(tmp); - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); - } finally { - if (tmp.exists() && !tmp.delete()) { - logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + for (int i = 0; i < numPartitions; i++) { + try (DiskBlockObjectWriter writer = partitionWriters[i]) { + partitionWriterSegments[i] = writer.commitAndGet(); + } } + + partitionLengths = writePartitionedData(mapOutputWriter); + mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + } catch (Exception e) { + try { + mapOutputWriter.abort(e); + } catch (Exception e2) { + logger.error("Failed to abort the writer after failing to write map output.", e2); + } + throw e; } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @VisibleForTesting @@ -179,37 +190,54 @@ long[] getPartitionLengths() { * * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). */ - private long[] writePartitionedFile(File outputFile) throws IOException { + private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException { // Track location of the partition starts in the output file final long[] lengths = new long[numPartitions]; if (partitionWriters == null) { // We were passed an empty iterator return lengths; } - - final FileOutputStream out = new FileOutputStream(outputFile, true); final long writeStartTime = System.nanoTime(); - boolean threwException = true; try { for (int i = 0; i < numPartitions; i++) { final File file = partitionWriterSegments[i].file(); - if (file.exists()) { - final FileInputStream in = new FileInputStream(file); - boolean copyThrewException = true; - try { - lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + boolean copyThrewException = true; + ShufflePartitionWriter writer = null; + try { + writer = mapOutputWriter.getNextPartitionWriter(); + if (!file.exists()) { copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - } - if (!file.delete()) { - logger.error("Unable to delete file for partition {}", i); + } else { + if (transferToEnabled) { + WritableByteChannel outputChannel = writer.toChannel(); + FileInputStream in = new FileInputStream(file); + try (FileChannel inputChannel = in.getChannel()) { + Utils.copyFileStreamNIO(inputChannel, outputChannel, 0, inputChannel.size()); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + } else { + OutputStream tempOutputStream = writer.toStream(); + FileInputStream in = new FileInputStream(file); + try { + Utils.copyStream(in, tempOutputStream, false, false); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + } + if (!file.delete()) { + logger.error("Unable to delete file for partition {}", i); + } } + } finally { + Closeables.close(writer, copyThrewException); 
} + + lengths[i] = writer.getNumBytesWritten(); } - threwException = false; } finally { - Closeables.close(out, threwException); writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java new file mode 100644 index 0000000000000..906600c0f15fc --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleDataIO.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.shuffle.ShuffleExecutorComponents; +import org.apache.spark.api.shuffle.ShuffleDataIO; + +public class DefaultShuffleDataIO implements ShuffleDataIO { + + private final SparkConf sparkConf; + + public DefaultShuffleDataIO(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @Override + public ShuffleExecutorComponents executor() { + return new DefaultShuffleExecutorComponents(sparkConf); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java new file mode 100644 index 0000000000000..76e87a6740259 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleExecutorComponents.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.io; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import org.apache.spark.api.shuffle.ShuffleExecutorComponents; +import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.storage.BlockManager; + +public class DefaultShuffleExecutorComponents implements ShuffleExecutorComponents { + + private final SparkConf sparkConf; + private BlockManager blockManager; + private IndexShuffleBlockResolver blockResolver; + + public DefaultShuffleExecutorComponents(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @Override + public void initializeExecutor(String appId, String execId) { + blockManager = SparkEnv.get().blockManager(); + blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); + } + + @Override + public ShuffleWriteSupport writes() { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers."); + } + return new DefaultShuffleWriteSupport(sparkConf, blockResolver); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java new file mode 100644 index 0000000000000..0f7e5ed66bb76 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.io; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.FileChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; +import org.apache.spark.api.shuffle.ShufflePartitionWriter; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.internal.config.package$; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.storage.TimeTrackingOutputStream; +import org.apache.spark.util.Utils; + +public class DefaultShuffleMapOutputWriter implements ShuffleMapOutputWriter { + + private static final Logger log = + LoggerFactory.getLogger(DefaultShuffleMapOutputWriter.class); + + private final int shuffleId; + private final int mapId; + private final ShuffleWriteMetricsReporter metrics; + private final IndexShuffleBlockResolver blockResolver; + private final long[] partitionLengths; + private final int bufferSize; + private int currPartitionId = 0; + private long currChannelPosition; + + private final File outputFile; + private File outputTempFile; + private FileOutputStream outputFileStream; + private FileChannel outputFileChannel; + private TimeTrackingOutputStream ts; + private BufferedOutputStream outputBufferedFileStream; + + public DefaultShuffleMapOutputWriter( + int shuffleId, + int mapId, + int numPartitions, + ShuffleWriteMetricsReporter metrics, + IndexShuffleBlockResolver blockResolver, + SparkConf sparkConf) { + this.shuffleId = shuffleId; + this.mapId = mapId; + this.metrics = metrics; + this.blockResolver = blockResolver; + this.bufferSize = + (int) (long) sparkConf.get( + package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; + this.partitionLengths = new long[numPartitions]; + this.outputFile = blockResolver.getDataFile(shuffleId, mapId); + this.outputTempFile = null; + } + + @Override + public ShufflePartitionWriter getNextPartitionWriter() throws IOException { + if (outputTempFile == null) { + outputTempFile = Utils.tempFileWith(outputFile); + } + if (outputFileChannel != null) { + currChannelPosition = outputFileChannel.position(); + } else { + currChannelPosition = 0L; + } + return new DefaultShufflePartitionWriter(currPartitionId++); + } + + @Override + public void commitAllPartitions() throws IOException { + cleanUp(); + blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, outputTempFile); + } + + @Override + public void abort(Throwable error) { + try { + cleanUp(); + } catch (Exception e) { + log.error("Unable to close appropriate underlying file stream", e); + } + if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) { + log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath()); + } + } + + private void cleanUp() throws IOException { + if (outputBufferedFileStream != null) { + outputBufferedFileStream.close(); + } + + if (outputFileChannel != null) { + outputFileChannel.close(); + } + + if (outputFileStream != null) { + outputFileStream.close(); + } + } + + private void initStream() throws IOException { + if (outputFileStream == null) { + outputFileStream = new FileOutputStream(outputTempFile, true); + ts = new TimeTrackingOutputStream(metrics, outputFileStream); + } + if (outputBufferedFileStream == null) { + outputBufferedFileStream = new 
BufferedOutputStream(ts, bufferSize); + } + } + + private void initChannel() throws IOException { + if (outputFileStream == null) { + outputFileStream = new FileOutputStream(outputTempFile, true); + } + if (outputFileChannel == null) { + outputFileChannel = outputFileStream.getChannel(); + } + } + + private class DefaultShufflePartitionWriter implements ShufflePartitionWriter { + + private final int partitionId; + private PartitionWriterStream stream = null; + + private DefaultShufflePartitionWriter(int partitionId) { + this.partitionId = partitionId; + } + + @Override + public OutputStream toStream() throws IOException { + if (outputFileChannel != null) { + throw new IllegalStateException("Requested an output channel for a previous write but" + + " now an output stream has been requested. Should not be using both channels" + + " and streams to write."); + } + initStream(); + stream = new PartitionWriterStream(); + return stream; + } + + @Override + public FileChannel toChannel() throws IOException { + if (stream != null) { + throw new IllegalStateException("Requested an output stream for a previous write but" + + " now an output channel has been requested. Should not be using both channels" + + " and streams to write."); + } + initChannel(); + return outputFileChannel; + } + + @Override + public long getNumBytesWritten() { + if (outputFileChannel != null && stream == null) { + try { + long newPosition = outputFileChannel.position(); + return newPosition - currChannelPosition; + } catch (Exception e) { + log.error("The partition which failed is: {}", partitionId, e); + throw new IllegalStateException("Failed to calculate position of file channel", e); + } + } else if (stream != null) { + return stream.getCount(); + } else { + // Assume an empty partition if stream and channel are never created + return 0; + } + } + + @Override + public void close() throws IOException { + if (stream != null) { + stream.close(); + } + partitionLengths[partitionId] = getNumBytesWritten(); + } + } + + private class PartitionWriterStream extends OutputStream { + private int count = 0; + private boolean isClosed = false; + + public int getCount() { + return count; + } + + @Override + public void write(int b) throws IOException { + verifyNotClosed(); + outputBufferedFileStream.write(b); + count++; + } + + @Override + public void write(byte[] buf, int pos, int length) throws IOException { + verifyNotClosed(); + outputBufferedFileStream.write(buf, pos, length); + count += length; + } + + @Override + public void close() throws IOException { + flush(); + isClosed = true; + } + + @Override + public void flush() throws IOException { + if (!isClosed) { + outputBufferedFileStream.flush(); + } + } + + private void verifyNotClosed() { + if (isClosed) { + throw new IllegalStateException("Attempting to write to a closed block output stream."); + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java new file mode 100644 index 0000000000000..f8fadd0ecfa63 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleWriteSupport.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; +import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; + +public class DefaultShuffleWriteSupport implements ShuffleWriteSupport { + + private final SparkConf sparkConf; + private final IndexShuffleBlockResolver blockResolver; + + public DefaultShuffleWriteSupport( + SparkConf sparkConf, + IndexShuffleBlockResolver blockResolver) { + this.sparkConf = sparkConf; + this.blockResolver = blockResolver; + } + + @Override + public ShuffleMapOutputWriter createMapOutputWriter( + int shuffleId, + int mapId, + int numPartitions) { + return new DefaultShuffleMapOutputWriter( + shuffleId, mapId, numPartitions, + TaskContext.get().taskMetrics().shuffleWriteMetrics(), blockResolver, sparkConf); + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index fe3b5d98969e1..a852a06be9125 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -22,6 +22,7 @@ import java.util.concurrent.TimeUnit import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode} +import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils @@ -768,6 +769,12 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val SHUFFLE_IO_PLUGIN_CLASS = + ConfigBuilder("spark.shuffle.io.plugin.class") + .doc("Name of the class to use for shuffle IO.") + .stringConf + .createWithDefault(classOf[DefaultShuffleDataIO].getName) + private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index b59fa8e8a3ccd..5da7b5cb35e6d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -20,8 +20,10 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap import org.apache.spark._ -import org.apache.spark.internal.Logging +import org.apache.spark.api.shuffle.{ShuffleDataIO, ShuffleExecutorComponents} +import org.apache.spark.internal.{config, Logging} import org.apache.spark.shuffle._ +import 
org.apache.spark.util.Utils /** * In sort-based shuffle, incoming records are sorted according to their target partition ids, then @@ -68,6 +70,8 @@ import org.apache.spark.shuffle._ */ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + import SortShuffleManager._ + if (!conf.getBoolean("spark.shuffle.spill", true)) { logWarning( "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + @@ -79,6 +83,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager */ private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** @@ -148,7 +154,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager bypassMergeSortHandle, mapId, env.conf, - metrics) + metrics, + shuffleExecutorComponents.writes()) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) } @@ -205,6 +212,16 @@ private[spark] object SortShuffleManager extends Logging { true } } + + private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { + val configuredPluginClasses = conf.get(config.SHUFFLE_IO_PLUGIN_CLASS) + val maybeIO = Utils.loadExtensions( + classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf) + require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses") + val executorComponents = maybeIO.head.executor() + executorComponents.initializeExecutor(conf.getAppId, SparkEnv.get.executorId) + executorComponents + } } /** diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 37b21e2eddba6..a9f3c73e90ab8 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -24,7 +24,7 @@ import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer -import java.nio.channels.{Channels, FileChannel} +import java.nio.channels.{Channels, FileChannel, WritableByteChannel} import java.nio.charset.StandardCharsets import java.nio.file.Files import java.security.SecureRandom @@ -337,10 +337,14 @@ private[spark] object Utils extends Logging { def copyFileStreamNIO( input: FileChannel, - output: FileChannel, + output: WritableByteChannel, startPosition: Long, bytesToCopy: Long): Unit = { - val initialPos = output.position() + val outputInitialState = output match { + case outputFileChannel: FileChannel => + Some((outputFileChannel.position(), outputFileChannel)) + case _ => None + } var count = 0L // In case transferTo method transferred less data than we have required. while (count < bytesToCopy) { @@ -355,15 +359,17 @@ private[spark] object Utils extends Logging { // kernel version 2.6.32, this issue can be seen in // https://bugs.openjdk.java.net/browse/JDK-7052359 // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). 
- val finalPos = output.position() - val expectedPos = initialPos + bytesToCopy - assert(finalPos == expectedPos, - s""" - |Current position $finalPos do not equal to expected position $expectedPos - |after transferTo, please check your kernel version to see if it is 2.6.32, - |this is a kernel bug which will lead to unexpected behavior when using transferTo. - |You can set spark.file.transferTo = false to disable this NIO feature. - """.stripMargin) + outputInitialState.foreach { case (initialPos, outputFileChannel) => + val finalPos = outputFileChannel.position() + val expectedPos = initialPos + bytesToCopy + assert(finalPos == expectedPos, + s""" + |Current position $finalPos do not equal to expected position $expectedPos + |after transferTo, please check your kernel version to see if it is 2.6.32, + |this is a kernel bug which will lead to unexpected behavior when using transferTo. + |You can set spark.file.transferTo = false to disable this NIO feature. + """.stripMargin) + } } /** diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 8b1084a8edc76..90c790cefcca2 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListene import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId} -import org.apache.spark.util.{MutablePair, Utils} +import org.apache.spark.util.MutablePair abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { @@ -368,7 +368,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem) val writer1 = manager.getWriter[Int, Int]( shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics) - val data1 = (1 to 10).map { x => x -> x} + val data1 = (1 to 10).map { x => x -> x } // second attempt -- also successful. 
We'll write out different data, // just to simulate the fact that the records may get written differently @@ -383,13 +383,17 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // simultaneously, and everything is still OK def writeAndClose( - writer: ShuffleWriter[Int, Int])( + writer: ShuffleWriter[Int, Int], + taskContext: TaskContext)( iter: Iterator[(Int, Int)]): Option[MapStatus] = { + TaskContext.setTaskContext(taskContext) val files = writer.write(iter) - writer.stop(true) + val status = writer.stop(true) + TaskContext.unset + status } val interleaver = new InterleaveIterators( - data1, writeAndClose(writer1), data2, writeAndClose(writer2)) + data1, writeAndClose(writer1, context1), data2, writeAndClose(writer2, context2)) val (mapOutput1, mapOutput2) = interleaver.run() // check that we can read the map output and it has the right data diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala index 7f67affe56364..7eb867fc29fd2 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -19,7 +19,7 @@ package org.apache.spark.shuffle.sort import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark -import org.apache.spark.util.Utils +import org.apache.spark.shuffle.sort.io.{DefaultShuffleWriteSupport} /** * Benchmark to measure performance for aggregate primitives. @@ -46,6 +46,7 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase def getWriter(transferTo: Boolean): BypassMergeSortShuffleWriter[String, String] = { val conf = new SparkConf(loadDefaults = false) + val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver) conf.set("spark.file.transferTo", String.valueOf(transferTo)) conf.set("spark.shuffle.file.buffer", "32k") @@ -55,7 +56,8 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase shuffleHandle, 0, conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleWriteSupport ) shuffleWriter diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 7f956c26d0ff0..23f03957ae11a 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -18,23 +18,27 @@ package org.apache.spark.shuffle.sort import java.io.File -import java.util.UUID +import java.util.{Properties, UUID} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt} +import org.mockito.ArgumentMatchers.{any, anyInt, anyString} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach +import scala.util.Random import org.apache.spark._ +import org.apache.spark.api.shuffle.ShuffleWriteSupport import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.memory.{TaskMemoryManager, 
TestMemoryManager} import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -49,7 +53,9 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte private var taskMetrics: TaskMetrics = _ private var tempDir: File = _ private var outputFile: File = _ + private var writeSupport: ShuffleWriteSupport = _ private val conf: SparkConf = new SparkConf(loadDefaults = false) + .set("spark.app.id", "sampleApp") private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ @@ -118,9 +124,27 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get } }) + + val memoryManager = new TestMemoryManager(conf) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + stageAttemptNumber = 0, + partitionId = 0, + taskAttemptId = Random.nextInt(10000), + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + localProperties = new Properties, + metricsSystem = null, + taskMetrics = taskMetrics)) + + writeSupport = new DefaultShuffleWriteSupport(conf, blockResolver) } override def afterEach(): Unit = { + TaskContext.unset() try { Utils.deleteRecursively(tempDir) blockIdToFileMap.clear() @@ -137,7 +161,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte shuffleHandle, 0, // MapId conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport ) writer.write(Iterator.empty) writer.stop( /* success = */ true) @@ -153,6 +178,33 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte } test("write with some empty partitions") { + val transferConf = conf.clone.set("spark.file.transferTo", "false") + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val writer = new BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + blockResolver, + shuffleHandle, + 0, // MapId + transferConf, + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport + ) + writer.write(records) + writer.stop( /* success = */ true) + assert(temporaryFilesCreated.nonEmpty) + assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === records.length) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } + + // TODO(ifilonenko): MAKE THIS PASS + test("write with some empty partitions with transferTo") { def records: Iterator[(Int, Int)] = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) val writer = new BypassMergeSortShuffleWriter[Int, Int]( @@ 
-161,7 +213,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte shuffleHandle, 0, // MapId conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport ) writer.write(records) writer.stop( /* success = */ true) @@ -196,7 +249,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte shuffleHandle, 0, // MapId conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport ) intercept[SparkException] { @@ -218,7 +272,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte shuffleHandle, 0, // MapId conf, - taskContext.taskMetrics().shuffleWriteMetrics + taskContext.taskMetrics().shuffleWriteMetrics, + writeSupport ) intercept[SparkException] { writer.write((0 until 100000).iterator.map(i => { @@ -232,5 +287,4 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte writer.stop( /* success = */ false) assert(temporaryFilesCreated.count(_.exists()) === 0) } - } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala index 62cc13fa107f0..ce1abde421fca 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterBenchmark.scala @@ -22,7 +22,7 @@ import org.mockito.Mockito.when import org.apache.spark.{Aggregator, SparkEnv} import org.apache.spark.benchmark.Benchmark import org.apache.spark.shuffle.BaseShuffleHandle -import org.apache.spark.util.Utils +import org.apache.spark.shuffle.sort.io.DefaultShuffleDataIO /** * Benchmark to measure performance for aggregate primitives. diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala new file mode 100644 index 0000000000000..22d52924a7c72 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.io + +import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} +import java.math.BigInteger +import java.nio.ByteBuffer + +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} +import org.mockito.Mock +import org.mockito.Mockito.{doAnswer, doNothing, when} +import org.mockito.MockitoAnnotations +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.network.util.LimitedInputStream +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.util.Utils + +class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + @Mock(answer = RETURNS_SMART_NULLS) private var shuffleWriteMetrics: ShuffleWriteMetrics = _ + + private val NUM_PARTITIONS = 4 + private val D_LEN = 10 + private val data: Array[Array[Int]] = (0 until NUM_PARTITIONS).map { + p => (1 to D_LEN).map(_ + p).toArray }.toArray + + private var tempFile: File = _ + private var mergedOutputFile: File = _ + private var tempDir: File = _ + private var partitionSizesInMergedFile: Array[Long] = _ + private var conf: SparkConf = _ + private var mapOutputWriter: DefaultShuffleMapOutputWriter = _ + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + override def beforeEach(): Unit = { + MockitoAnnotations.initMocks(this) + tempDir = Utils.createTempDir(null, "test") + mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) + tempFile = File.createTempFile("tempfile", "", tempDir) + partitionSizesInMergedFile = null + conf = new SparkConf() + .set("spark.app.id", "example.spark.app") + .set("spark.shuffle.unsafe.file.output.buffer", "16k") + when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) + + doNothing().when(shuffleWriteMetrics).incWriteTime(anyLong) + + doAnswer(new Answer[Void] { + def answer(invocationOnMock: InvocationOnMock): Void = { + partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] + val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + mergedOutputFile.delete + tmp.renameTo(mergedOutputFile) + } + null + } + }).when(blockResolver) + .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) + mapOutputWriter = new DefaultShuffleMapOutputWriter( + 0, 0, NUM_PARTITIONS, shuffleWriteMetrics, blockResolver, conf) + } + + private def readRecordsFromFile(fromByte: Boolean): Array[Array[Int]] = { + var startOffset = 0L + val result = new Array[Array[Int]](NUM_PARTITIONS) + (0 until NUM_PARTITIONS).foreach { p => + val partitionSize = partitionSizesInMergedFile(p).toInt + lazy val inner = new Array[Int](partitionSize) + lazy val innerBytebuffer = ByteBuffer.allocate(partitionSize) + if (partitionSize > 0) { + val in = new FileInputStream(mergedOutputFile) + in.getChannel.position(startOffset) + val lin = new LimitedInputStream(in, partitionSize) + var nonEmpty = true + var count = 0 + while (nonEmpty) { + try { + val readBit = lin.read() + if (fromByte) { + innerBytebuffer.put(readBit.toByte) + } else { + inner(count) = readBit + } + count += 1 + } catch { + 
case _: Exception => + nonEmpty = false + } + } + in.close() + } + if (fromByte) { + result(p) = innerBytebuffer.array().sliding(4, 4).map { b => + new BigInteger(b).intValue() + }.toArray + } else { + result(p) = inner + } + startOffset += partitionSize + } + result + } + + test("writing to an outputstream") { + (0 until NUM_PARTITIONS).foreach{ p => + val writer = mapOutputWriter.getNextPartitionWriter + val stream = writer.toStream() + data(p).foreach { i => stream.write(i)} + stream.close() + intercept[IllegalStateException] { + stream.write(p) + } + assert(writer.getNumBytesWritten() == D_LEN) + writer.close + } + mapOutputWriter.commitAllPartitions() + val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toDouble}.toArray + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile(false)) + } + + test("writing to a channel") { + (0 until NUM_PARTITIONS).foreach{ p => + val writer = mapOutputWriter.getNextPartitionWriter + val channel = writer.toChannel() + val byteBuffer = ByteBuffer.allocate(D_LEN * 4) + val intBuffer = byteBuffer.asIntBuffer() + intBuffer.put(data(p)) + assert(channel.isOpen) + channel.write(byteBuffer) + // Bytes require * 4 + assert(writer.getNumBytesWritten == D_LEN * 4) + writer.close + } + mapOutputWriter.commitAllPartitions() + val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile(true)) + } + + test("copyStreams with an outputstream") { + (0 until NUM_PARTITIONS).foreach{ p => + val writer = mapOutputWriter.getNextPartitionWriter + val stream = writer.toStream() + val byteBuffer = ByteBuffer.allocate(D_LEN * 4) + val intBuffer = byteBuffer.asIntBuffer() + intBuffer.put(data(p)) + val in = new ByteArrayInputStream(byteBuffer.array()) + Utils.copyStream(in, stream, false, false) + in.close() + stream.close() + assert(writer.getNumBytesWritten == D_LEN * 4) + writer.close + } + mapOutputWriter.commitAllPartitions() + val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile(true)) + } + + test("copyStreamsWithNIO with a channel") { + (0 until NUM_PARTITIONS).foreach{ p => + val writer = mapOutputWriter.getNextPartitionWriter + val channel = writer.toChannel() + val byteBuffer = ByteBuffer.allocate(D_LEN * 4) + val intBuffer = byteBuffer.asIntBuffer() + intBuffer.put(data(p)) + val out = new FileOutputStream(tempFile) + out.write(byteBuffer.array()) + out.close() + val in = new FileInputStream(tempFile) + Utils.copyFileStreamNIO(in.getChannel, channel, 0, D_LEN * 4) + in.close() + assert(writer.getNumBytesWritten == D_LEN * 4) + writer.close + } + mapOutputWriter.commitAllPartitions() + val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile(true)) + } +}
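
Illustrative sketch (not part of the patch above): the code below shows one possible shape for a third-party plugin written against the org.apache.spark.api.shuffle interfaces introduced in this patch. All class names here (InMemoryShuffleDataIO, InMemoryExecutorComponents, InMemoryMapOutputWriter) are hypothetical, the example merely buffers each partition's bytes in memory, and it assumes ShuffleDataIO, ShuffleExecutorComponents, ShuffleWriteSupport, and ShuffleMapOutputWriter declare only the methods exercised in the patch (executor(), initializeExecutor(...), writes(), createMapOutputWriter(...), getNextPartitionWriter(), commitAllPartitions(), abort(...)). A real implementation would persist the partition data and expose it to the shuffle read path rather than discarding it.

// Hypothetical, non-authoritative sketch of a third-party ShuffleDataIO plugin.
// Nothing in this file is part of the patch itself.
package org.apache.spark.examples.shuffle; // assumed package, not in the patch

import java.io.ByteArrayOutputStream;
import java.io.OutputStream;

import org.apache.spark.SparkConf;
import org.apache.spark.api.shuffle.ShuffleDataIO;
import org.apache.spark.api.shuffle.ShuffleExecutorComponents;
import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;
import org.apache.spark.api.shuffle.ShuffleWriteSupport;

public class InMemoryShuffleDataIO implements ShuffleDataIO {

  private final SparkConf sparkConf; // kept for parity with DefaultShuffleDataIO

  // Utils.loadExtensions, as used by SortShuffleManager.loadShuffleExecutorComponents,
  // instantiates the plugin through a single-SparkConf constructor.
  public InMemoryShuffleDataIO(SparkConf sparkConf) {
    this.sparkConf = sparkConf;
  }

  @Override
  public ShuffleExecutorComponents executor() {
    return new InMemoryExecutorComponents();
  }

  private static class InMemoryExecutorComponents implements ShuffleExecutorComponents {

    @Override
    public void initializeExecutor(String appId, String execId) {
      // Nothing to set up for this toy example.
    }

    @Override
    public ShuffleWriteSupport writes() {
      // Assumes createMapOutputWriter(shuffleId, mapId, numPartitions) is the only
      // abstract method on ShuffleWriteSupport, as shown in the patch.
      return (shuffleId, mapId, numPartitions) -> new InMemoryMapOutputWriter(numPartitions);
    }
  }

  private static class InMemoryMapOutputWriter implements ShuffleMapOutputWriter {

    private final long[] partitionLengths;
    private int nextPartitionId = 0;

    InMemoryMapOutputWriter(int numPartitions) {
      this.partitionLengths = new long[numPartitions];
    }

    @Override
    public ShufflePartitionWriter getNextPartitionWriter() {
      final int partitionId = nextPartitionId++;
      final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
      return new ShufflePartitionWriter() {
        @Override
        public OutputStream toStream() {
          // The caller never closes this stream directly; see close() below.
          return buffer;
        }

        @Override
        public long getNumBytesWritten() {
          return buffer.size();
        }

        @Override
        public void close() {
          // Per the ShufflePartitionWriter contract, release the resources behind
          // toStream() here and record how many bytes this partition produced.
          partitionLengths[partitionId] = buffer.size();
        }
      };
    }

    @Override
    public void commitAllPartitions() {
      // A real plugin would durably commit the per-partition data and lengths here.
    }

    @Override
    public void abort(Throwable error) {
      // A real plugin would clean up any partial output here.
    }
  }
}

If the interfaces hold as written, such a class could be selected through the new spark.shuffle.io.plugin.class setting added in package.scala: SortShuffleManager.loadShuffleExecutorComponents resolves it with Utils.loadExtensions, constructs it with a single SparkConf argument, calls initializeExecutor(conf.getAppId, SparkEnv.get.executorId), and passes writes() to BypassMergeSortShuffleWriter.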