@@ -373,9 +373,10 @@ private[spark] class ExternalSorter[K, V, C](
373373
374374 // Create our file writers if we haven't done so yet
375375 if (partitionWriters == null ) {
376+ curWriteMetrics = new ShuffleWriteMetrics ()
376377 partitionWriters = Array .fill(numPartitions) {
377378 val (blockId, file) = diskBlockManager.createTempBlock()
378- blockManager.getDiskWriter(blockId, file, ser, fileBufferSize).open()
379+ blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics ).open()
379380 }
380381 }
381382
@@ -734,10 +735,6 @@ private[spark] class ExternalSorter[K, V, C](
734735 val offsets = new Array [Long ](numPartitions + 1 )
735736 val lengths = new Array [Long ](numPartitions)
736737
737- // Statistics
738- var totalBytes = 0L
739- var totalTime = 0L
740-
741738 if (bypassMergeSort && partitionWriters != null ) {
742739 // We decided to write separate files for each partition, so just concatenate them. To keep
743740 // this simple we spill out the current in-memory collection so that everything is in files.
@@ -769,27 +766,22 @@ private[spark] class ExternalSorter[K, V, C](
769766 // partition and just write everything directly.
770767 for ((id, elements) <- this .partitionedIterator) {
771768 if (elements.hasNext) {
772- val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize)
769+ val writer = blockManager.getDiskWriter(
770+ blockId, outputFile, ser, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get)
773771 for (elem <- elements) {
774772 writer.write(elem)
775773 }
776774 writer.commitAndClose()
777775 val segment = writer.fileSegment()
778776 offsets(id + 1 ) = segment.offset + segment.length
779777 lengths(id) = segment.length
780- totalTime += writer.timeWriting()
781- totalBytes += segment.length
782778 } else {
783779 // The partition is empty; don't create a new writer to avoid writing headers, etc
784780 offsets(id + 1 ) = offsets(id)
785781 }
786782 }
787783 }
788784
789- val shuffleMetrics = new ShuffleWriteMetrics
790- shuffleMetrics.shuffleBytesWritten = totalBytes
791- shuffleMetrics.shuffleWriteTime = totalTime
792- context.taskMetrics.shuffleWriteMetrics = Some (shuffleMetrics)
793785 context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
794786 context.taskMetrics.diskBytesSpilled += diskBytesSpilled
795787
0 commit comments