[SPARK-54663][CORE] Computes RowBasedChecksum in ShuffleWriters #50230
Changes from 5 commits
RowBasedChecksum.scala (new file)

@@ -0,0 +1,125 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.checksum

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.util.zip.Checksum

import scala.util.control.NonFatal

import org.apache.spark.internal.Logging
import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper

/**
 * A class for computing a checksum for input (key, value) pairs. The checksum is independent
 * of the order of the input (key, value) pairs: a checksum is first computed for each row,
 * and the row checksums are then combined with XOR.
 */
abstract class RowBasedChecksum() extends Serializable with Logging {
  private var hasError: Boolean = false
  private var checksumValue: Long = 0

  /**
   * Returns the computed checksum value. It returns the default checksum value (0) if any
   * errors were encountered during the checksum computation.
   */
  def getValue: Long = {
    if (!hasError) checksumValue else 0
  }

  /** Updates the row-based checksum with the given (key, value) pair. */
  def update(key: Any, value: Any): Unit = {
    if (!hasError) {
      try {
        val rowChecksumValue = calculateRowChecksum(key, value)
        checksumValue = checksumValue ^ rowChecksumValue
      } catch {
        case NonFatal(e) =>
          logInfo("Checksum computation encountered error: ", e)
Review comment with a suggested change:
-          logInfo("Checksum computation encountered error: ", e)
+          logError("Checksum computation encountered error: ", e)
Author: Done
Review (outdated): Just a question: I don't see this used anywhere except in tests; why not have it in core/src/test instead?
Author: Moved to test package.
Review (outdated): We have another case of MyByteArrayOutputStream for this purpose. Refactor to reuse it.
Author: Done
Review (outdated): NIT: add a comment saying this is for testing only, or better, move it to a helper class used only in tests.
Reviewer: Maybe into ShuffleChecksumTestHelper? But its name suggests it is only for shuffle checksums, so what about renaming it to ChecksumTestHelper?
Reviewer: Oh, I see the comment above: "Note that this checksum computation is very expensive, and it is used only in tests in the core component. A much cheaper implementation of RowBasedChecksum is in UnsafeRowChecksum." And I can see your comment: "I can't use UnsafeRowChecksum.scala in the test because the test is in core, while UnsafeRow is in sql. So I added OutputStreamRowBasedChecksum for the tests in core." But you can move this class and object to the test code of the core module, can't you?
Reviewer: I see you cannot move the whole object, so let's just move the method.
Author: Moved the method to ShuffleChecksumTestHelper; didn't rename the class, as all the newly added classes/components are currently in the shuffle package.
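To make the order-independence property concrete, here is a minimal sketch of a RowBasedChecksum subclass. It is not taken from this PR: it assumes the abstract hook invoked in update() is declared as def calculateRowChecksum(key: Any, value: Any): Long, and the class name and string-based row encoding are illustrative only (the PR's real implementations are OutputStreamRowBasedChecksum for core tests and UnsafeRowChecksum in sql, per the thread above).

import java.nio.charset.StandardCharsets
import java.util.zip.CRC32

// Illustrative subclass: checksums each row's string form with CRC32.
class Crc32RowChecksum extends RowBasedChecksum {
  override def calculateRowChecksum(key: Any, value: Any): Long = {
    val crc = new CRC32()
    val bytes = s"$key:$value".getBytes(StandardCharsets.UTF_8)
    crc.update(bytes, 0, bytes.length)
    crc.getValue
  }
}

// Because XOR is commutative and associative, feeding the same rows in any
// order yields the same aggregate checksum:
val a = new Crc32RowChecksum
a.update("k1", "v1")
a.update("k2", "v2")
val b = new Crc32RowChecksum
b.update("k2", "v2")
b.update("k1", "v1")
assert(a.getValue == b.getValue)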
BypassMergeSortShuffleWriter.java

@@ -53,6 +53,7 @@
 import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.checksum.RowBasedChecksum;
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.*;
@@ -104,6 +105,14 @@ final class BypassMergeSortShuffleWriter<K, V>
   private long[] partitionLengths;
   /** Checksum calculator for each partition. Empty when shuffle checksum disabled. */
   private final Checksum[] partitionChecksums;
+  /**
+   * Checksum calculator for each partition. Unlike the Checksum above,
+   * RowBasedChecksum is independent of the input row order; it is used to
+   * detect whether different task attempts of the same partition produce
+   * different output data.
+   */
+  private final RowBasedChecksum[] rowBasedChecksums;
+  private final SparkConf conf;

   /**
    * Are we in the process of stopping? Because map tasks can call stop() with success = true

@@ -132,6 +141,8 @@ final class BypassMergeSortShuffleWriter<K, V>
     this.serializer = dep.serializer();
     this.shuffleExecutorComponents = shuffleExecutorComponents;
     this.partitionChecksums = createPartitionChecksums(numPartitions, conf);
+    this.rowBasedChecksums = dep.rowBasedChecksums();
+    this.conf = conf;
   }

   @Override
@@ -144,7 +155,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
       partitionLengths = mapOutputWriter.commitAllPartitions(
         ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
       mapStatus = MapStatus$.MODULE$.apply(
-        blockManager.shuffleServerId(), partitionLengths, mapId);
+        blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
       return;
     }
     final SerializerInstance serInstance = serializer.newInstance();
@@ -171,7 +182,11 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
     while (records.hasNext()) {
       final Product2<K, V> record = records.next();
       final K key = record._1();
-      partitionWriters[partitioner.getPartition(key)].write(key, record._2());
+      final int partitionId = partitioner.getPartition(key);
+      partitionWriters[partitionId].write(key, record._2());
+      if (rowBasedChecksums.length > 0) {
+        rowBasedChecksums[partitionId].update(key, record._2());
+      }
     }

     for (int i = 0; i < numPartitions; i++) {
@@ -182,7 +197,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
       partitionLengths = writePartitionedData(mapOutputWriter);
       mapStatus = MapStatus$.MODULE$.apply(
-        blockManager.shuffleServerId(), partitionLengths, mapId);
+        blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
     } catch (Exception e) {
       try {
         mapOutputWriter.abort(e);
@@ -199,6 +214,15 @@ public long[] getPartitionLengths() {
     return partitionLengths;
   }

+  public RowBasedChecksum[] getRowBasedChecksums() {
+    return rowBasedChecksums;
+  }
+
+  public long getAggregatedChecksumValue() {
+    final long checksum = RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums);
+    return checksum;
+  }
Review comment with a suggested change:
-    final long checksum = RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums);
-    return checksum;
+    return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums);
Author: Done
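As a self-contained model of the flow these hunks add: each record updates the order-independent checksum of its target partition, and a single aggregated long is then attached to the map status. None of the names below are Spark's, and how RowBasedChecksum.getAggregatedChecksumValue actually combines the per-partition values is defined outside this diff, so plain XOR stands in for it here.

// Stand-in partitioner and row checksum; both are illustrative only.
val numPartitions = 4
val partitionChecksums = Array.fill(numPartitions)(0L)

def rowChecksum(key: String, value: Int): Long = (key, value).hashCode().toLong

Seq(("a", 1), ("b", 2), ("c", 3)).foreach { case (key, value) =>
  val partitionId = Math.floorMod(key.hashCode, numPartitions)
  // Mirrors rowBasedChecksums[partitionId].update(key, record._2()) above.
  partitionChecksums(partitionId) = partitionChecksums(partitionId) ^ rowChecksum(key, value)
}

// Stand-in for getAggregatedChecksumValue(): one long summarizing all partitions,
// reported to the driver via MapStatus.
val aggregated = partitionChecksums.reduce(_ ^ _)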
UnsafeShuffleWriter.java

@@ -60,6 +60,7 @@
 import org.apache.spark.shuffle.api.ShufflePartitionWriter;
 import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
 import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
+import org.apache.spark.shuffle.checksum.RowBasedChecksum;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
@@ -103,6 +104,13 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream
   private MyByteArrayOutputStream serBuffer;
   private SerializationStream serOutputStream;

+  /**
+   * RowBasedChecksum calculator for each partition. RowBasedChecksum is independent
+   * of the input row order; it is used to detect whether different task attempts
+   * of the same partition produce different output data.
+   */
+  private final RowBasedChecksum[] rowBasedChecksums;
+
   /**
    * Are we in the process of stopping? Because map tasks can call stop() with success = true
    * and then call stop() with success = false if they get an exception, we want to make sure
@@ -142,6 +150,7 @@ public UnsafeShuffleWriter(
       (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
     this.mergeBufferSizeInBytes =
       (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_MERGE_BUFFER_SIZE()) * 1024;
+    this.rowBasedChecksums = dep.rowBasedChecksums();
     open();
   }
@@ -163,6 +172,13 @@ public long getPeakMemoryUsedBytes() {
     return peakMemoryUsedBytes;
   }

+  public RowBasedChecksum[] getRowBasedChecksums() {
+    return rowBasedChecksums;
+  }
+
+  public long getAggregatedChecksumValue() {
+    return RowBasedChecksum.getAggregatedChecksumValue(rowBasedChecksums);
+  }

   /**
    * This convenience method should only be called in test code.
    */
@@ -234,7 +250,7 @@ void closeAndWriteOutput() throws IOException {
       }
     }
     mapStatus = MapStatus$.MODULE$.apply(
-      blockManager.shuffleServerId(), partitionLengths, mapId);
+      blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue());
   }

   @VisibleForTesting
@@ -252,6 +268,9 @@ void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
     sorter.insertRecord(
       serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+    if (rowBasedChecksums.length > 0) {
+      rowBasedChecksums[partitionId].update(key, record._2());
+    }
   }

   @VisibleForTesting
Dependency.scala

@@ -29,6 +29,7 @@ import org.apache.spark.internal.LogKeys._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
+import org.apache.spark.shuffle.checksum.RowBasedChecksum
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.Utils
@@ -74,6 +75,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
  * @param aggregator map/reduce-side aggregator for RDD's shuffle
  * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine)
  * @param shuffleWriterProcessor the processor to control the write behavior in ShuffleMapTask
+ * @param rowBasedChecksums the row-based checksums for each shuffle partition
  */
 @DeveloperApi
 class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
@@ -83,9 +85,30 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     val keyOrdering: Option[Ordering[K]] = None,
     val aggregator: Option[Aggregator[K, V, C]] = None,
     val mapSideCombine: Boolean = false,
-    val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor)
+    val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor,
+    val rowBasedChecksums: Array[RowBasedChecksum] = Array.empty)
   extends Dependency[Product2[K, V]] with Logging {

+  def this(
+      rdd: RDD[_ <: Product2[K, V]],
+      partitioner: Partitioner,
+      serializer: Serializer,
+      keyOrdering: Option[Ordering[K]],
+      aggregator: Option[Aggregator[K, V, C]],
+      mapSideCombine: Boolean,
+      shuffleWriterProcessor: ShuffleWriteProcessor) = {
+    this(
+      rdd,
+      partitioner,
+      serializer,
+      keyOrdering,
+      aggregator,
+      mapSideCombine,
+      shuffleWriterProcessor,
+      Array.empty)
+  }
Review comment with a suggested change:
-      Array.empty
+      EMPTY_ROW_BASED_CHECKSUMS
Author: Done
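A hedged usage sketch (not from this PR) of the new parameter: a caller supplies one checksum per shuffle partition and leaves the other defaulted arguments untouched. Crc32RowChecksum is the hypothetical subclass sketched earlier; rdd and conf are assumed to come from the surrounding job.

import org.apache.spark.{HashPartitioner, ShuffleDependency}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.checksum.RowBasedChecksum

val numPartitions = 4
// One RowBasedChecksum per shuffle (reduce-side) partition, matching how the
// writers index rowBasedChecksums by partitioner.getPartition(key).
val checksums: Array[RowBasedChecksum] =
  Array.fill(numPartitions)(new Crc32RowChecksum()) // hypothetical subclass
val dep = new ShuffleDependency[Int, String, String](
  rdd, // assumed: an RDD[(Int, String)] from the surrounding job
  new HashPartitioner(numPartitions),
  new JavaSerializer(conf), // assumed: conf is the job's SparkConf
  rowBasedChecksums = checksums) // remaining parameters keep their defaults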
MapOutputTracker.scala

@@ -23,7 +23,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE
 import java.util.concurrent.locks.ReentrantReadWriteLock

 import scala.collection
-import scala.collection.mutable.{HashMap, ListBuffer, Map}
+import scala.collection.mutable.{HashMap, ListBuffer, Map, Set}
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration
 import scala.jdk.CollectionConverters._
@@ -99,6 +99,11 @@ private class ShuffleStatus(
    */
   val mapStatusesDeleted = new Array[MapStatus](numPartitions)

+  /**
+   * Keeps the indices of the map tasks whose checksums differ across retries.
+   */
+  private[this] val checksumMismatchIndices: Set[Int] = Set()
+
   /**
    * MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the
    * array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for
@@ -169,6 +174,12 @@ private class ShuffleStatus(
     } else {
       mapIdToMapIndex.remove(currentMapStatus.mapId)
     }

+    val preStatus =
+      if (mapStatuses(mapIndex) != null) mapStatuses(mapIndex) else mapStatusesDeleted(mapIndex)
+    if (preStatus != null && preStatus.checksumValue != status.checksumValue) {
+      checksumMismatchIndices.add(mapIndex)
+    }
Reviewer: There are three main cases here: […]. For the latter two, we don't need to track it in checksumMismatchIndices.
Author: Yes, for case 1, we need to track the mismatches. The usage of checksumMismatchIndices is that (in the next PR) we will roll back the downstream stages if we detect checksum mismatches for their upstream stages. For case 2, if downstream stages have not consumed the output, they have not started; in that case the rollback is a no-op, and it doesn't hurt to record the mismatches here. For case 3, I think we need to record the mismatches. Consider a situation where all partitions of a stage have finished while some speculative tasks are still running. As all outputs have been produced, the downstream stage can start and read the data. Later, some speculative tasks finish, and a new mapStatus overrides the old mapStatus with a new data location. For the downstream stage, the not-yet-started or retried tasks would read the new data, while the finished and running tasks would read the old data, resulting in inconsistency.
Reviewer: It is unclear how […]. That is fair; this is indeed possible.
     mapStatuses(mapIndex) = status
     mapIdToMapIndex(status.mapId) = mapIndex
   }
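To isolate what the added block does, here is a runnable model of the bookkeeping (simplified types; Status stands in for MapStatus, and the surrounding locking and cleanup are omitted): when a new status arrives for a map index that already has a current or deleted status, a differing checksum flags that index as a mismatch across attempts.

import scala.collection.mutable

case class Status(mapId: Long, checksumValue: Long) // stand-in for MapStatus

val numPartitions = 4
val mapStatuses = new Array[Status](numPartitions)
val mapStatusesDeleted = new Array[Status](numPartitions)
val checksumMismatchIndices = mutable.Set[Int]()

def addMapOutput(mapIndex: Int, status: Status): Unit = {
  // Compare against the previous status, even if it was already deleted.
  val preStatus =
    if (mapStatuses(mapIndex) != null) mapStatuses(mapIndex) else mapStatusesDeleted(mapIndex)
  if (preStatus != null && preStatus.checksumValue != status.checksumValue) {
    checksumMismatchIndices.add(mapIndex) // a retried attempt produced different output
  }
  mapStatuses(mapIndex) = status
}

addMapOutput(0, Status(mapId = 1L, checksumValue = 42L))
addMapOutput(0, Status(mapId = 2L, checksumValue = 43L)) // retry, different output
assert(checksumMismatchIndices == mutable.Set(0))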
Review: Let's update the classdoc. We now also leverage the sum to handle duplicated values better.
Author: Updated.
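For context on why a sum helps here (a hedged reading of the comment above, since the updated classdoc itself is not shown in this diff): XOR alone cancels any row that appears an even number of times, so an output containing a duplicated row can collide with an output missing both copies; carrying a sum alongside the XOR breaks that ambiguity.

// x ^ x == 0, so duplicating a row XORs to the same value as dropping both
// copies. The sum distinguishes the two cases.
def xorAll(xs: Seq[Long]): Long = xs.foldLeft(0L)(_ ^ _)

val row = 0x1234L
val duplicated = Seq(row, row)
val missing = Seq.empty[Long]

assert(xorAll(duplicated) == xorAll(missing)) // XOR cannot tell these apart
assert(duplicated.sum != missing.sum)         // the sum can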