
Commit 1436c5a

[SPARK-44647][SQL] Support SPJ where join keys are less than cluster keys
### What changes were proposed in this pull request?

- Add a new conf `spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled`.
- Change the key-compatibility checks in EnsureRequirements: drop the requirement that all partition keys appear in the join keys, so `isKeyCompatible` can be true in this case (when the flag is enabled).
- "Project" partitions by join keys in KeyGroupedPartitioning/KeyGroupedShuffleSpec.
- Add join-key grouping to the partition grouping in BatchScanExec.

### Why are the changes needed?

- Support Storage Partition Join (SPJ) in cases where the join condition does not contain all the partition keys, but only some of them.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- Added tests in KeyGroupedPartitioningSuite.
- Because of apache#37886, we have to select all the join keys to trigger SPJ in this case; otherwise the DSV2 scan does not report KeyGroupedPartitioning and SPJ is not triggered. Relaxing this is left for a separate PR.
1 parent fd424ca commit 1436c5a
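
For illustration only (not part of this commit), here is a minimal sketch of the kind of query the new flag targets. The catalog, table, and column names below are hypothetical; only the configuration keys come from Spark's existing V2 bucketing confs and this commit.

```scala
// Hypothetical setup: both tables are partitioned on (id, dep).
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled", "true")

// The join uses only `id`, i.e. the join keys are a strict subset of the
// partition keys, which previously prevented storage-partitioned join.
val joined = spark.sql(
  """SELECT i.id, i.dep, p.price
    |FROM testcat.ns.items i
    |JOIN testcat.ns.purchases p ON i.id = p.id
    |""".stripMargin)

// With the flags above, the intent is that the join runs without an Exchange,
// with partitions on both sides grouped by `id` alone.
joined.explain()
```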

File tree

5 files changed (+375, -32 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 45 additions & 7 deletions

@@ -355,7 +355,14 @@ case class KeyGroupedPartitioning(
         } else {
           // We'll need to find leaf attributes from the partition expressions first.
           val attributes = expressions.flatMap(_.collectLeaves())
-          attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+
+          if (SQLConf.get.getConf(
+              SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)) {
+            requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) &&
+              expressions.forall(_.collectLeaves().size == 1)
+          } else {
+            attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+          }
         }

       case _ =>
@@ -364,8 +371,21 @@ case class KeyGroupedPartitioning(
     }
   }

-  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
-    KeyGroupedShuffleSpec(this, distribution)
+  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
+    var result = KeyGroupedShuffleSpec(this, distribution)
+    if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
+      // If allowing join keys to be subset of clustering keys, we should create a new
+      // `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as
+      // the returned shuffle spec.
+      val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
+      val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions,
+        partitionValues, originalPartitionValues)
+      result = result.copy(partitioning = projectedPartitioning, joinKeyPositions =
+        Some(joinKeyPositions))
+    }
+
+    result
+  }

   lazy val uniquePartitionValues: Seq[InternalRow] = {
     partitionValues
@@ -378,8 +398,25 @@ case class KeyGroupedPartitioning(
 object KeyGroupedPartitioning {
   def apply(
       expressions: Seq[Expression],
-      partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
-    KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues)
+      projectionPositions: Seq[Int],
+      partitionValues: Seq[InternalRow],
+      originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
+    val projectedExpressions = projectionPositions.map(expressions(_))
+    val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
+    val projectedOriginalPartitionValues =
+      originalPartitionValues.map(project(expressions, projectionPositions, _))
+
+    KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length,
+      projectedPartitionValues, projectedOriginalPartitionValues)
+  }
+
+  def project(
+      expressions: Seq[Expression],
+      positions: Seq[Int],
+      input: InternalRow): InternalRow = {
+    val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType))
+      .toArray
+    new GenericInternalRow(projectedValues)
   }

   def supportsExpressions(expressions: Seq[Expression]): Boolean = {
@@ -674,7 +711,8 @@ case class HashShuffleSpec(

 case class KeyGroupedShuffleSpec(
     partitioning: KeyGroupedPartitioning,
-    distribution: ClusteredDistribution) extends ShuffleSpec {
+    distribution: ClusteredDistribution,
+    joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec {

   /**
    * A sequence where each element is a set of positions of the partition expression to the cluster
@@ -709,7 +747,7 @@ case class KeyGroupedShuffleSpec(
     //   3.3 each pair of partition expressions at the same index must share compatible
     //       transform functions.
     // 4. the partition values from both sides are following the same order.
-    case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) =>
+    case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
       distribution.clustering.length == otherDistribution.clustering.length &&
         numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
         partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
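
As a rough, self-contained sketch (not the committed code) of what the projection above does, here is a plain-Scala analogue: `Seq[Any]` stands in for `InternalRow`, and `Set[Int]` for the bit sets held by `KeyGroupedShuffleSpec.keyPositions`.

```scala
object ProjectionSketch {
  // Analogue of KeyGroupedPartitioning.project: keep only the values at `positions`.
  def project(partitionValue: Seq[Any], positions: Seq[Int]): Seq[Any] =
    positions.map(partitionValue(_))

  def main(args: Array[String]): Unit = {
    // Partition expressions: Seq(years(ts), bucket(8, id)); the join clusters only
    // on `id`, so only expression 1 appears among the join (clustering) keys.
    val keyPositions: Seq[Set[Int]] = Seq(Set.empty, Set(0))
    val joinKeyPositions =
      keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
    println(joinKeyPositions)                        // List(1)
    println(project(Seq(2023, 5), joinKeyPositions)) // List(5): keyed on `id` only
  }
}
```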

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 15 additions & 0 deletions

@@ -1530,6 +1530,18 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)

+  val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS =
+    buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled")
+      .doc("Whether to allow storage-partition join in the case where join keys are " +
+        "a subset of the partition keys of the source tables. At planning time, " +
+        "Spark will group the partitions by only those keys that are in the join keys. " +
+        "This is currently enabled only if spark.sql.sources.v2.bucketing.pushPartValues.enabled " +
+        "is also enabled."
+      )
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
     .doc("The maximum number of buckets allowed.")
     .version("2.4.0")
@@ -4936,6 +4948,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def v2BucketingShuffleEnabled: Boolean =
     getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED)

+  def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
+    getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)
+
   def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
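
A small REPL-style sketch of how planner-side code can consult the new flag together with the existing push-part-values flag, since the doc text above states the subset behavior only takes effect when the latter is also enabled. Both accessors are real SQLConf methods (the first is added by this commit); the helper itself is hypothetical.

```scala
import org.apache.spark.sql.internal.SQLConf

// Hypothetical helper: true when SPJ on a join-key subset can actually kick in.
def spjOnJoinKeySubsetEnabled(conf: SQLConf): Boolean =
  conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys &&
    conf.v2BucketingPushPartValuesEnabled

spjOnJoinKeySubsetEnabled(SQLConf.get)
```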

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala

Lines changed: 39 additions & 17 deletions

@@ -120,7 +120,12 @@ case class BatchScanExec(
         val newPartValues = spjParams.commonPartitionValues.get.flatMap {
           case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
         }
-        k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
+        val expressions = spjParams.joinKeyPositions match {
+          case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i))
+          case _ => k.expressions
+        }
+        k.copy(expressions = expressions, numPartitions = newPartValues.length,
+          partitionValues = newPartValues)
       case p => p
     }
   }
@@ -132,14 +137,29 @@ case class BatchScanExec(
       // return an empty RDD with 1 partition if dynamic filtering removed the only split
      sparkContext.parallelize(Array.empty[InternalRow], 1)
     } else {
-      var finalPartitions = filteredPartitions
-
-      outputPartitioning match {
+      val finalPartitions = outputPartitioning match {
         case p: KeyGroupedPartitioning =>
-          val groupedPartitions = filteredPartitions.map(splits => {
-            assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
-            (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
-          })
+          assert(spjParams.keyGroupedPartitioning.isDefined)
+          val expressions = spjParams.keyGroupedPartitioning.get
+
+          // Re-group the input partitions if we are projecting on a subset of join keys
+          val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match {
+            case Some(projectPositions) =>
+              val projectedExpressions = projectPositions.map(i => expressions(i))
+              val parts = filteredPartitions.flatten.groupBy(part => {
+                val row = part.asInstanceOf[HasPartitionKey].partitionKey()
+                val projectedRow = KeyGroupedPartitioning.project(
+                  expressions, projectPositions, row)
+                InternalRowComparableWrapper(projectedRow, projectedExpressions)
+              }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq
+              (parts, projectedExpressions)
+            case _ =>
+              val groupedParts = filteredPartitions.map(splits => {
+                assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
+                (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
+              })
+              (groupedParts, expressions)
+          }

           // When partially clustered, the input partitions are not grouped by partition
           // values. Here we'll need to check `commonPartitionValues` and decide how to group
@@ -149,12 +169,12 @@ case class BatchScanExec(
             // should contain.
             val commonPartValuesMap = spjParams.commonPartitionValues
               .get
-              .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
+              .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
               .toMap
             val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) =>
               // `commonPartValuesMap` should contain the part value since it's the super set.
               val numSplits = commonPartValuesMap
-                .get(InternalRowComparableWrapper(partValue, p.expressions))
+                .get(InternalRowComparableWrapper(partValue, partExpressions))
               assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
                 "common partition values from Spark plan")

@@ -169,37 +189,37 @@ case class BatchScanExec(
                // sides of a join will have the same number of partitions & splits.
                splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
              }
-              (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
+              (InternalRowComparableWrapper(partValue, partExpressions), newSplits)
            }

            // Now fill missing partition keys with empty partitions
            val partitionMapping = nestGroupedPartitions.toMap
-            finalPartitions = spjParams.commonPartitionValues.get.flatMap {
+            spjParams.commonPartitionValues.get.flatMap {
              case (partValue, numSplits) =>
                // Use empty partition for those partition values that are not present.
                partitionMapping.getOrElse(
-                  InternalRowComparableWrapper(partValue, p.expressions),
+                  InternalRowComparableWrapper(partValue, partExpressions),
                  Seq.fill(numSplits)(Seq.empty))
            }
          } else {
            // either `commonPartitionValues` is not defined, or it is defined but
            // `applyPartialClustering` is false.
            val partitionMapping = groupedPartitions.map { case (partValue, splits) =>
-              InternalRowComparableWrapper(partValue, p.expressions) -> splits
+              InternalRowComparableWrapper(partValue, partExpressions) -> splits
            }.toMap

            // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
            // could exist duplicated partition values, as partition grouping is not done
            // at the beginning and postponed to this method. It is important to use unique
            // partition values here so that grouped partitions won't get duplicated.
-            finalPartitions = p.uniquePartitionValues.map { partValue =>
+            p.uniquePartitionValues.map { partValue =>
              // Use empty partition for those partition values that are not present
              partitionMapping.getOrElse(
-                InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
+                InternalRowComparableWrapper(partValue, partExpressions), Seq.empty)
            }
          }

-        case _ =>
+        case _ => filteredPartitions
       }

       new DataSourceRDD(
@@ -234,6 +254,7 @@ case class BatchScanExec(

 case class StoragePartitionJoinParams(
     keyGroupedPartitioning: Option[Seq[Expression]] = None,
+    joinKeyPositions: Option[Seq[Int]] = None,
     commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
     applyPartialClustering: Boolean = false,
     replicatePartitions: Boolean = false) {
@@ -247,6 +268,7 @@ case class StoragePartitionJoinParams(
   }

   override def hashCode(): Int = Objects.hashCode(
+    joinKeyPositions: Option[Seq[Int]],
     commonPartitionValues: Option[Seq[(InternalRow, Int)]],
     applyPartialClustering: java.lang.Boolean,
     replicatePartitions: java.lang.Boolean)
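
Below is a simplified, self-contained analogue (not Spark's code) of the re-grouping BatchScanExec now performs when `joinKeyPositions` is set: input splits are grouped by the projection of their partition key onto the join-key positions. `Seq[Any]` stands in for `InternalRow`, and the `InternalRowComparableWrapper` machinery is omitted.

```scala
// Toy stand-in for an input partition that reports a partition key.
final case class Split(partitionKey: Seq[Any], name: String)

def groupByJoinKeys(
    splits: Seq[Split],
    joinKeyPositions: Seq[Int]): Seq[(Seq[Any], Seq[Split])] =
  splits.groupBy(s => joinKeyPositions.map(s.partitionKey(_))).toSeq

val splits = Seq(
  Split(Seq(2023, 5), "a"), // partition key = (year, bucket)
  Split(Seq(2024, 5), "b"),
  Split(Seq(2024, 7), "c"))

// The join uses only the bucket column (position 1), so "a" and "b" collapse into
// one grouped partition even though their years differ.
groupByJoinKeys(splits, Seq(1))
// => ((Seq(5), [a, b]), (Seq(7), [c]))  (group order not guaranteed)
```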

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 9 additions & 6 deletions

@@ -380,7 +380,8 @@ case class EnsureRequirements(
         val rightSpec = specs(1)

         var isCompatible = false
-        if (!conf.v2BucketingPushPartValuesEnabled) {
+        if (!conf.v2BucketingPushPartValuesEnabled &&
+            !conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
           isCompatible = leftSpec.isCompatibleWith(rightSpec)
         } else {
           logInfo("Pushing common partition values for storage-partitioned join")
@@ -505,10 +506,10 @@ case class EnsureRequirements(
           }

           // Now we need to push-down the common partition key to the scan in each child
-          newLeft = populatePartitionValues(
-            left, mergedPartValues, applyPartialClustering, replicateLeftSide)
-          newRight = populatePartitionValues(
-            right, mergedPartValues, applyPartialClustering, replicateRightSide)
+          newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions,
+            applyPartialClustering, replicateLeftSide)
+          newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions,
+            applyPartialClustering, replicateRightSide)
         }
       }

@@ -530,19 +531,21 @@ case class EnsureRequirements(
   private def populatePartitionValues(
       plan: SparkPlan,
       values: Seq[(InternalRow, Int)],
+      joinKeyPositions: Option[Seq[Int]],
       applyPartialClustering: Boolean,
       replicatePartitions: Boolean): SparkPlan = plan match {
     case scan: BatchScanExec =>
       scan.copy(
         spjParams = scan.spjParams.copy(
           commonPartitionValues = Some(values),
+          joinKeyPositions = joinKeyPositions,
           applyPartialClustering = applyPartialClustering,
           replicatePartitions = replicatePartitions
         )
       )
     case node =>
       node.mapChildren(child => populatePartitionValues(
-        child, values, applyPartialClustering, replicatePartitions))
+        child, values, joinKeyPositions, applyPartialClustering, replicatePartitions))
   }

   /**
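
Finally, a minimal sketch (toy types, not Spark's plan nodes) of the traversal pattern `populatePartitionValues` follows: recurse through the plan and copy the merged partition values plus `joinKeyPositions` onto every leaf scan it finds.

```scala
// Hypothetical toy plan tree; only the recursion shape mirrors the real method.
sealed trait Plan
final case class Scan(
    commonPartitionValues: Option[Seq[(Seq[Any], Int)]] = None,
    joinKeyPositions: Option[Seq[Int]] = None) extends Plan
final case class Node(children: Seq[Plan]) extends Plan

def populate(
    plan: Plan,
    values: Seq[(Seq[Any], Int)],
    joinKeyPositions: Option[Seq[Int]]): Plan = plan match {
  case s: Scan =>
    s.copy(commonPartitionValues = Some(values), joinKeyPositions = joinKeyPositions)
  case Node(children) =>
    Node(children.map(populate(_, values, joinKeyPositions)))
}
```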
