Skip to content

Commit c79de07

Browse files
committed
address comments
1 parent fb3b31e commit c79de07

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -422,11 +422,14 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
422422
for (int partition = 0; partition < numPartitions; partition++) {
423423
for (int i = 0; i < spills.length; i++) {
424424
final long partitionLengthInSpill = spills[i].partitionLengths[partition];
425-
long bytesToTransfer = partitionLengthInSpill;
426425
final FileChannel spillInputChannel = spillInputChannels[i];
427426
final long writeStartTime = System.nanoTime();
428-
Utils.copyFileStreamNIO(spillInputChannel, mergedFileOutputChannel, bytesToTransfer);
429-
spillInputChannelPositions[i] += bytesToTransfer;
427+
Utils.copyFileStreamNIO(
428+
spillInputChannel,
429+
mergedFileOutputChannel,
430+
spillInputChannelPositions[i],
431+
partitionLengthInSpill);
432+
spillInputChannelPositions[i] += partitionLengthInSpill;
430433
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
431434
bytesWrittenToMergedFile += partitionLengthInSpill;
432435
partitionLengths[partition] += partitionLengthInSpill;

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ private[spark] object Utils extends Logging {
330330
val inChannel = in.asInstanceOf[FileInputStream].getChannel()
331331
val outChannel = out.asInstanceOf[FileOutputStream].getChannel()
332332
val size = inChannel.size()
333-
copyFileStreamNIO(inChannel, outChannel, size)
333+
copyFileStreamNIO(inChannel, outChannel, 0, size)
334334
size
335335
} else {
336336
var count = 0L
@@ -356,13 +356,19 @@ private[spark] object Utils extends Logging {
356356
}
357357
}
358358

359-
def copyFileStreamNIO(input: FileChannel, output: FileChannel, bytesToCopy: Long): Unit = {
359+
def copyFileStreamNIO(
360+
input: FileChannel,
361+
output: FileChannel,
362+
startPosition: Long,
363+
bytesToCopy: Long): Unit = {
360364
val initialPos = output.position()
361365
var count = 0L
362366
// In case transferTo method transferred less data than we have required.
363367
while (count < bytesToCopy) {
364-
count += input.transferTo(count, bytesToCopy - count, output)
368+
count += input.transferTo(count + startPosition, bytesToCopy - count, output)
365369
}
370+
assert(count == bytesToCopy,
371+
s"request to copy $bytesToCopy bytes, but actually copied $count bytes.")
366372

367373
// Check the position after transferTo loop to see if it is in the right position and
368374
// give user information if not.

0 commit comments

Comments
 (0)