diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 3242ac21ab32..c242320635a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -85,11 +85,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { childrenIndexes.map(children).filterNot(_.isInstanceOf[ShuffleExchangeExec]) .map(_.outputPartitioning.numPartitions) val expectedChildrenNumPartitions = if (nonShuffleChildrenNumPartitions.nonEmpty) { - // Here we pick the max number of partitions among these non-shuffle children as the - // expected number of shuffle partitions. However, if it's smaller than - // `conf.numShufflePartitions`, we pick `conf.numShufflePartitions` as the - // expected number of shuffle partitions. - math.max(nonShuffleChildrenNumPartitions.max, conf.defaultNumShufflePartitions) + if (nonShuffleChildrenNumPartitions.length == childrenIndexes.length) { + // Here we pick the max number of partitions among these non-shuffle children. + nonShuffleChildrenNumPartitions.max + } else { + // Here we pick the max number of partitions among these non-shuffle children as the + // expected number of shuffle partitions. However, if it's smaller than + // `conf.numShufflePartitions`, we pick `conf.numShufflePartitions` as the + // expected number of shuffle partitions. + math.max(nonShuffleChildrenNumPartitions.max, conf.defaultNumShufflePartitions) + } } else { childrenNumPartitions.max } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 14ba008e8996..558bfd7f4675 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -843,4 +843,32 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { } } } + + test("SPARK-32767 Bucket join should work if SHUFFLE_PARTITIONS larger than bucket number") { + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "9", + SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10") { + + val testSpec1 = BucketedTableTestSpec( + Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))), + numPartitions = 1, + expectedShuffle = false, + expectedSort = false) + val testSpec2 = BucketedTableTestSpec( + Some(BucketSpec(6, Seq("i", "j"), Seq("i", "j"))), + numPartitions = 1, + expectedShuffle = true, + expectedSort = true) + Seq(false, true).foreach { enableAdaptive => + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> s"$enableAdaptive") { + Seq((testSpec1, testSpec2), (testSpec2, testSpec1)).foreach { specs => + testBucketing( + bucketedTableTestSpecLeft = specs._1, + bucketedTableTestSpecRight = specs._2, + joinCondition = joinCondition(Seq("i", "j"))) + } + } + } + } + } }