SPARK-43242: add a ShuffleBlockBatchId branch to the BlockId match in
ShuffleBlockFetcherIterator. The new branch logs a warning and returns the
"diagnosis is skipped" message; previously such an id reached the
`case unexpected: BlockId` branch, which throws SparkException.internalError.
The suite gains a test that feeds a corrupt ShuffleBlockBatchId(0, 0, 0, 3)
buffer and checks for the FetchFailedException and the two log messages.

NOTE(review): the line structure of this patch was restored from a
whitespace-collapsed copy. Leading indentation follows the usual 2-space
Scala layout and should be confirmed against the target files (or applied
with `git apply --ignore-whitespace`).

diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index b9365f45a11a..17407f4ee21f 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -1142,6 +1142,12 @@ final class ShuffleBlockFetcherIterator(
           s"diagnosis is skipped due to lack of shuffle checksum support for push-based shuffle."
         logWarning(diagnosisResponse)
         diagnosisResponse
+      case shuffleBlockBatch: ShuffleBlockBatchId =>
+        val diagnosisResponse = s"BlockBatch $shuffleBlockBatch is corrupted " +
+          s"but corruption diagnosis is skipped due to lack of shuffle checksum support for " +
+          s"ShuffleBlockBatchId"
+        logWarning(diagnosisResponse)
+        diagnosisResponse
       case unexpected: BlockId =>
         throw SparkException.internalError(
           s"Unexpected type of BlockId, $unexpected", category = "STORAGE")
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index f2d5f27a66cc..a9902cb4ccb4 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -1941,4 +1941,31 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
     assert(err2.getMessage.contains("corrupt at reset"))
   }
+
+  test("SPARK-43242: Fix throw 'Unexpected type of BlockId' in shuffle corruption diagnose") {
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val blocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockBatchId(0, 0, 0, 3) -> createMockManagedBuffer())
+    answerFetchBlocks { invocation =>
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+      listener.onBlockFetchSuccess(ShuffleBlockBatchId(0, 0, 0, 3).toString, mockCorruptBuffer())
+    }
+
+    val logAppender = new LogAppender("diagnose corruption")
+    withLogAppender(logAppender) {
+      val iterator = createShuffleBlockIteratorWithDefaults(
+        Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+        streamWrapperLimitSize = Some(100)
+      )
+      intercept[FetchFailedException](iterator.next())
+      verify(transfer, times(2))
+        .fetchBlocks(any(), any(), any(), any(), any(), any())
+      assert(logAppender.loggingEvents.count(
+        _.getMessage.getFormattedMessage.contains("Start corruption diagnosis")) === 1)
+      assert(logAppender.loggingEvents.exists(
+        _.getMessage.getFormattedMessage.contains("shuffle_0_0_0_3 is corrupted " +
+          "but corruption diagnosis is skipped due to lack of " +
+          "shuffle checksum support for ShuffleBlockBatchId")))
+    }
+  }
 }