-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-45036][SQL] SPJ: Simplify the logic to handle partially clustered distribution #42757
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Par | |
| import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} | ||
| import org.apache.spark.sql.connector.catalog.Table | ||
| import org.apache.spark.sql.connector.read._ | ||
| import org.apache.spark.sql.internal.SQLConf | ||
|
|
||
| /** | ||
| * Physical plan node for scanning a batch of data from a data source v2. | ||
|
|
@@ -101,7 +100,7 @@ case class BatchScanExec( | |
| "partition values that are not present in the original partitioning.") | ||
| } | ||
|
|
||
| groupPartitions(newPartitions).get.map(_._2) | ||
| groupPartitions(newPartitions).get.groupedParts.map(_.parts) | ||
|
|
||
| case _ => | ||
| // no validation is needed as the data source did not report any specific partitioning | ||
|
|
@@ -137,81 +136,63 @@ case class BatchScanExec( | |
|
|
||
| outputPartitioning match { | ||
| case p: KeyGroupedPartitioning => | ||
| if (conf.v2BucketingPushPartValuesEnabled && | ||
| conf.v2BucketingPartiallyClusteredDistributionEnabled) { | ||
| assert(filteredPartitions.forall(_.size == 1), | ||
| "Expect partitions to be not grouped when " + | ||
| s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " + | ||
| "is enabled") | ||
|
|
||
| val groupedPartitions = groupPartitions(finalPartitions.map(_.head), | ||
| groupSplits = true).get | ||
|
|
||
| // This means the input partitions are not grouped by partition values. We'll need to | ||
| // check `groupByPartitionValues` and decide whether to group and replicate splits | ||
| // within a partition. | ||
| if (spjParams.commonPartitionValues.isDefined && | ||
| spjParams.applyPartialClustering) { | ||
| // A mapping from the common partition values to how many splits the partition | ||
| // should contain. | ||
| val commonPartValuesMap = spjParams.commonPartitionValues | ||
| val groupedPartitions = filteredPartitions.map(splits => { | ||
| assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) | ||
| (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) | ||
| }) | ||
|
|
||
| // This means the input partitions are not grouped by partition values. We'll need to | ||
|
||
| // check `groupByPartitionValues` and decide whether to group and replicate splits | ||
| // within a partition. | ||
| if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) { | ||
| // A mapping from the common partition values to how many splits the partition | ||
| // should contain. | ||
| val commonPartValuesMap = spjParams.commonPartitionValues | ||
| .get | ||
| .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2)) | ||
| .toMap | ||
| val nestGroupedPartitions = groupedPartitions.map { | ||
| case (partValue, splits) => | ||
| // `commonPartValuesMap` should contain the part value since it's the super set. | ||
| val numSplits = commonPartValuesMap | ||
| .get(InternalRowComparableWrapper(partValue, p.expressions)) | ||
| assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + | ||
| "common partition values from Spark plan") | ||
|
|
||
| val newSplits = if (spjParams.replicatePartitions) { | ||
| // We need to also replicate partitions according to the other side of join | ||
| Seq.fill(numSplits.get)(splits) | ||
| } else { | ||
| // Not grouping by partition values: this could be the side with partially | ||
| // clustered distribution. Because of dynamic filtering, we'll need to check if | ||
| // the final number of splits of a partition is smaller than the original | ||
| // number, and fill with empty splits if so. This is necessary so that both | ||
| // sides of a join will have the same number of partitions & splits. | ||
| splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) | ||
| } | ||
| (InternalRowComparableWrapper(partValue, p.expressions), newSplits) | ||
| val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => | ||
| // `commonPartValuesMap` should contain the part value since it's the super set. | ||
| val numSplits = commonPartValuesMap | ||
| .get(InternalRowComparableWrapper(partValue, p.expressions)) | ||
| assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + | ||
| "common partition values from Spark plan") | ||
|
|
||
| val newSplits = if (spjParams.replicatePartitions) { | ||
| // We need to also replicate partitions according to the other side of join | ||
| Seq.fill(numSplits.get)(splits) | ||
| } else { | ||
| // Not grouping by partition values: this could be the side with partially | ||
| // clustered distribution. Because of dynamic filtering, we'll need to check if | ||
| // the final number of splits of a partition is smaller than the original | ||
| // number, and fill with empty splits if so. This is necessary so that both | ||
| // sides of a join will have the same number of partitions & splits. | ||
| splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) | ||
| } | ||
| (InternalRowComparableWrapper(partValue, p.expressions), newSplits) | ||
| } | ||
|
|
||
| // Now fill missing partition keys with empty partitions | ||
| val partitionMapping = nestGroupedPartitions.toMap | ||
| finalPartitions = spjParams.commonPartitionValues.get.flatMap { | ||
| case (partValue, numSplits) => | ||
| // Use empty partition for those partition values that are not present. | ||
| partitionMapping.getOrElse( | ||
| InternalRowComparableWrapper(partValue, p.expressions), | ||
| Seq.fill(numSplits)(Seq.empty)) | ||
| } | ||
| } else { | ||
| // either `commonPartitionValues` is not defined, or it is defined but | ||
| // `applyPartialClustering` is false. | ||
| val partitionMapping = groupedPartitions.map { case (row, parts) => | ||
| InternalRowComparableWrapper(row, p.expressions) -> parts | ||
| }.toMap | ||
|
|
||
| // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there | ||
| // could exist duplicated partition values, as partition grouping is not done | ||
| // at the beginning and postponed to this method. It is important to use unique | ||
| // partition values here so that grouped partitions won't get duplicated. | ||
| finalPartitions = p.uniquePartitionValues.map { partValue => | ||
| // Use empty partition for those partition values that are not present | ||
| // Now fill missing partition keys with empty partitions | ||
| val partitionMapping = nestGroupedPartitions.toMap | ||
| finalPartitions = spjParams.commonPartitionValues.get.flatMap { | ||
| case (partValue, numSplits) => | ||
| // Use empty partition for those partition values that are not present. | ||
| partitionMapping.getOrElse( | ||
| InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) | ||
| } | ||
| InternalRowComparableWrapper(partValue, p.expressions), | ||
| Seq.fill(numSplits)(Seq.empty)) | ||
| } | ||
| } else { | ||
| val partitionMapping = finalPartitions.map { parts => | ||
| val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey() | ||
| InternalRowComparableWrapper(row, p.expressions) -> parts | ||
| // either `commonPartitionValues` is not defined, or it is defined but | ||
| // `applyPartialClustering` is false. | ||
| val partitionMapping = groupedPartitions.map { case (partValue, splits) => | ||
| InternalRowComparableWrapper(partValue, p.expressions) -> splits | ||
| }.toMap | ||
| finalPartitions = p.partitionValues.map { partValue => | ||
|
|
||
| // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there | ||
| // could exist duplicated partition values, as partition grouping is not done | ||
| // at the beginning and postponed to this method. It is important to use unique | ||
| // partition values here so that grouped partitions won't get duplicated. | ||
| finalPartitions = p.uniquePartitionValues.map { partValue => | ||
| // Use empty partition for those partition values that are not present | ||
| partitionMapping.getOrElse( | ||
| InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,8 +62,9 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { | |
| redact(result) | ||
| } | ||
|
|
||
| def partitions: Seq[Seq[InputPartition]] = | ||
| groupedPartitions.map(_.map(_._2)).getOrElse(inputPartitions.map(Seq(_))) | ||
| def partitions: Seq[Seq[InputPartition]] = { | ||
| groupedPartitions.map(_.groupedParts.map(_.parts)).getOrElse(inputPartitions.map(Seq(_))) | ||
| } | ||
|
|
||
| /** | ||
| * Shorthand for calling redact() without specifying redacting rules | ||
|
|
@@ -94,16 +95,18 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { | |
| keyGroupedPartitioning match { | ||
| case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) => | ||
| groupedPartitions | ||
| .map { partitionValues => | ||
| KeyGroupedPartitioning(exprs, partitionValues.size, partitionValues.map(_._1)) | ||
| .map { keyGroupedPartsInfo => | ||
| val keyGroupedParts = keyGroupedPartsInfo.groupedParts | ||
| KeyGroupedPartitioning(exprs, keyGroupedParts.size, keyGroupedParts.map(_.value), | ||
| keyGroupedPartsInfo.originalParts.map(_.partitionKey())) | ||
| } | ||
| .getOrElse(super.outputPartitioning) | ||
| case _ => | ||
| super.outputPartitioning | ||
| } | ||
| } | ||
|
|
||
| @transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] = { | ||
| @transient lazy val groupedPartitions: Option[KeyGroupedPartitionInfo] = { | ||
| // Early check if we actually need to materialize the input partitions. | ||
| keyGroupedPartitioning match { | ||
| case Some(_) => groupPartitions(inputPartitions) | ||
|
|
@@ -117,24 +120,21 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { | |
| * - all input partitions implement [[HasPartitionKey]] | ||
| * - `keyGroupedPartitioning` is set | ||
| * | ||
| * The result, if defined, is a list of tuples where the first element is a partition value, | ||
| * and the second element is a list of input partitions that share the same partition value. | ||
| * The result, if defined, is a [[KeyGroupedPartitionInfo]] which contains a list of | ||
| * [[KeyGroupedPartition]], as well as a list of partition values from the original input splits, | ||
| * sorted according to the partition keys in ascending order. | ||
| * | ||
| * A non-empty result means each partition is clustered on a single key and therefore eligible | ||
| * for further optimizations to eliminate shuffling in some operations such as join and aggregate. | ||
| */ | ||
| def groupPartitions( | ||
| inputPartitions: Seq[InputPartition], | ||
| groupSplits: Boolean = !conf.v2BucketingPushPartValuesEnabled || | ||
| !conf.v2BucketingPartiallyClusteredDistributionEnabled): | ||
| Option[Seq[(InternalRow, Seq[InputPartition])]] = { | ||
|
|
||
| def groupPartitions(inputPartitions: Seq[InputPartition]): Option[KeyGroupedPartitionInfo] = { | ||
| if (!SQLConf.get.v2BucketingEnabled) return None | ||
|
|
||
| keyGroupedPartitioning.flatMap { expressions => | ||
| val results = inputPartitions.takeWhile { | ||
| case _: HasPartitionKey => true | ||
| case _ => false | ||
| }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p)) | ||
| }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p.asInstanceOf[HasPartitionKey])) | ||
|
|
||
| if (results.length != inputPartitions.length || inputPartitions.isEmpty) { | ||
| // Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here. | ||
|
|
@@ -143,32 +143,25 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { | |
| // also sort the input partitions according to their partition key order. This ensures | ||
| // a canonical order from both sides of a bucketed join, for example. | ||
| val partitionDataTypes = expressions.map(_.dataType) | ||
| val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = { | ||
| val partitionOrdering: Ordering[(InternalRow, InputPartition)] = { | ||
| RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1) | ||
| } | ||
|
|
||
| val partitions = if (groupSplits) { | ||
| // Group the splits by their partition value | ||
| results | ||
| val sortedKeyToPartitions = results.sorted(partitionOrdering) | ||
| val groupedPartitions = sortedKeyToPartitions | ||
| .map(t => (InternalRowComparableWrapper(t._1, expressions), t._2)) | ||
| .groupBy(_._1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem likely comes from this groupBy, as there are some differences between Scala 2.12 and Scala 2.13. For example:
We can see that with Scala 2.13.8 the order of the results has changed. A possible fix may be:
Perhaps there are other better ways to fix it?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @LuciferYang for the findings! Yes it's a bug as I was assuming the order will be preserved in the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I opened #42839 to fix this. |
||
| .toSeq | ||
| .map { | ||
| case (key, s) => (key.row, s.map(_._2)) | ||
| } | ||
| } else { | ||
| // No splits grouping, each split will become a separate Spark partition | ||
| results.map(t => (t._1, Seq(t._2))) | ||
| } | ||
| .map { case (key, s) => KeyGroupedPartition(key.row, s.map(_._2)) } | ||
|
|
||
| Some(partitions.sorted(partitionOrdering)) | ||
| Some(KeyGroupedPartitionInfo(groupedPartitions, sortedKeyToPartitions.map(_._2))) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| override def outputOrdering: Seq[SortOrder] = { | ||
| // when multiple partitions are grouped together, ordering inside partitions is not preserved | ||
| val partitioningPreservesOrdering = groupedPartitions.forall(_.forall(_._2.length <= 1)) | ||
| val partitioningPreservesOrdering = groupedPartitions | ||
| .forall(_.groupedParts.forall(_.parts.length <= 1)) | ||
| ordering.filter(_ => partitioningPreservesOrdering).getOrElse(super.outputOrdering) | ||
| } | ||
|
|
||
|
|
@@ -217,3 +210,19 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { | |
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * A key-grouped Spark partition, which could consist of multiple input splits | ||
| * | ||
| * @param value the partition value shared by all the input splits | ||
| * @param parts the input splits that are grouped into a single Spark partition | ||
| */ | ||
| private[v2] case class KeyGroupedPartition(value: InternalRow, parts: Seq[InputPartition]) | ||
|
|
||
| /** | ||
| * Information about key-grouped partitions, which contains a list of grouped partitions as well | ||
| * as the original input partitions before the grouping. | ||
| */ | ||
| private[v2] case class KeyGroupedPartitionInfo( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The name KeyGroupedPartitionInfo reads as if it refers to info about a single KeyGroupedPartition. How about KeyGroupedPartitionInfos?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought about it but is |
||
| groupedParts: Seq[KeyGroupedPartition], | ||
| originalParts: Seq[HasPartitionKey]) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think about adding 'final cluster keys' to the javadoc, to make it even clearer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure