@@ -327,11 +327,14 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
* @param numPartitions the number of partitions
* @param partitionValues the values for the cluster keys of the distribution, must be
Member: What do you think about adding "final cluster keys" to the javadoc, to make it even clearer?

Member Author: Sure.

* in ascending order.
* @param originalPartitionValues the original input partition values before any grouping has been
* applied, must be in ascending order.
*/
case class KeyGroupedPartitioning(
expressions: Seq[Expression],
numPartitions: Int,
partitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {
partitionValues: Seq[InternalRow] = Seq.empty,
originalPartitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {

override def satisfies0(required: Distribution): Boolean = {
super.satisfies0(required) || {
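To make the new field concrete, here is a minimal, self-contained Scala sketch (plain Ints stand in for InternalRow, and the values are made up) of how the grouped partitionValues relate to originalPartitionValues:

```scala
// Hypothetical partition keys of the input splits, already sorted ascending.
val originalPartitionValues = Seq(10, 10, 20, 30, 30, 30)

// Grouping collapses duplicate keys, producing one Spark partition per distinct key.
val partitionValues = originalPartitionValues.distinct // List(10, 20, 30)

// Both lists stay in ascending order; the un-grouped list additionally records
// how many input splits each key contributed (10 -> 2, 20 -> 1, 30 -> 3),
// which is what partially clustered distribution later relies on.
```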
@@ -368,7 +371,7 @@ object KeyGroupedPartitioning {
def apply(
expressions: Seq[Expression],
partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues)
KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues)
}

def supportsExpressions(expressions: Seq[Expression]): Boolean = {
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Par
import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.internal.SQLConf

/**
* Physical plan node for scanning a batch of data from a data source v2.
@@ -101,7 +100,7 @@ case class BatchScanExec(
"partition values that are not present in the original partitioning.")
}

groupPartitions(newPartitions).get.map(_._2)
groupPartitions(newPartitions).get.groupedParts.map(_.parts)

case _ =>
// no validation is needed as the data source did not report any specific partitioning
@@ -137,81 +136,63 @@ case class BatchScanExec(

outputPartitioning match {
case p: KeyGroupedPartitioning =>
if (conf.v2BucketingPushPartValuesEnabled &&
conf.v2BucketingPartiallyClusteredDistributionEnabled) {
assert(filteredPartitions.forall(_.size == 1),
"Expect partitions to be not grouped when " +
s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
"is enabled")

val groupedPartitions = groupPartitions(finalPartitions.map(_.head),
groupSplits = true).get

// This means the input partitions are not grouped by partition values. We'll need to
// check `groupByPartitionValues` and decide whether to group and replicate splits
// within a partition.
if (spjParams.commonPartitionValues.isDefined &&
spjParams.applyPartialClustering) {
// A mapping from the common partition values to how many splits the partition
// should contain.
val commonPartValuesMap = spjParams.commonPartitionValues
val groupedPartitions = filteredPartitions.map(splits => {
assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
(splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
})

// This means the input partitions are not grouped by partition values. We'll need to
Member: Nit: can we clarify "this"? E.g., "When partially clustered, the input partitions are not grouped by partition values." Also, `groupByPartitionValues` seems to never actually be defined; can we fix it? Does it refer to `groupedPartitions`?

Member Author: Yes, let me improve these comments too.

// check `groupByPartitionValues` and decide whether to group and replicate splits
// within a partition.
if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) {
// A mapping from the common partition values to how many splits the partition
// should contain.
val commonPartValuesMap = spjParams.commonPartitionValues
.get
.map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
.toMap
val nestGroupedPartitions = groupedPartitions.map {
case (partValue, splits) =>
// `commonPartValuesMap` should contain the part value since it's the super set.
val numSplits = commonPartValuesMap
.get(InternalRowComparableWrapper(partValue, p.expressions))
assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
"common partition values from Spark plan")

val newSplits = if (spjParams.replicatePartitions) {
// We need to also replicate partitions according to the other side of join
Seq.fill(numSplits.get)(splits)
} else {
// Not grouping by partition values: this could be the side with partially
// clustered distribution. Because of dynamic filtering, we'll need to check if
// the final number of splits of a partition is smaller than the original
// number, and fill with empty splits if so. This is necessary so that both
// sides of a join will have the same number of partitions & splits.
splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
}
(InternalRowComparableWrapper(partValue, p.expressions), newSplits)
val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) =>
// `commonPartValuesMap` should contain the part value since it's the super set.
val numSplits = commonPartValuesMap
.get(InternalRowComparableWrapper(partValue, p.expressions))
assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
"common partition values from Spark plan")

val newSplits = if (spjParams.replicatePartitions) {
// We need to also replicate partitions according to the other side of join
Seq.fill(numSplits.get)(splits)
} else {
// Not grouping by partition values: this could be the side with partially
// clustered distribution. Because of dynamic filtering, we'll need to check if
// the final number of splits of a partition is smaller than the original
// number, and fill with empty splits if so. This is necessary so that both
// sides of a join will have the same number of partitions & splits.
splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
}
(InternalRowComparableWrapper(partValue, p.expressions), newSplits)
}

// Now fill missing partition keys with empty partitions
val partitionMapping = nestGroupedPartitions.toMap
finalPartitions = spjParams.commonPartitionValues.get.flatMap {
case (partValue, numSplits) =>
// Use empty partition for those partition values that are not present.
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, p.expressions),
Seq.fill(numSplits)(Seq.empty))
}
} else {
// either `commonPartitionValues` is not defined, or it is defined but
// `applyPartialClustering` is false.
val partitionMapping = groupedPartitions.map { case (row, parts) =>
InternalRowComparableWrapper(row, p.expressions) -> parts
}.toMap

// In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
// could exist duplicated partition values, as partition grouping is not done
// at the beginning and postponed to this method. It is important to use unique
// partition values here so that grouped partitions won't get duplicated.
finalPartitions = p.uniquePartitionValues.map { partValue =>
// Use empty partition for those partition values that are not present
// Now fill missing partition keys with empty partitions
val partitionMapping = nestGroupedPartitions.toMap
finalPartitions = spjParams.commonPartitionValues.get.flatMap {
case (partValue, numSplits) =>
// Use empty partition for those partition values that are not present.
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
}
InternalRowComparableWrapper(partValue, p.expressions),
Seq.fill(numSplits)(Seq.empty))
}
} else {
val partitionMapping = finalPartitions.map { parts =>
val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey()
InternalRowComparableWrapper(row, p.expressions) -> parts
// either `commonPartitionValues` is not defined, or it is defined but
// `applyPartialClustering` is false.
val partitionMapping = groupedPartitions.map { case (partValue, splits) =>
InternalRowComparableWrapper(partValue, p.expressions) -> splits
}.toMap
finalPartitions = p.partitionValues.map { partValue =>

// In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
// could exist duplicated partition values, as partition grouping is not done
// at the beginning and postponed to this method. It is important to use unique
// partition values here so that grouped partitions won't get duplicated.
finalPartitions = p.uniquePartitionValues.map { partValue =>
// Use empty partition for those partition values that are not present
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
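A tiny plain-Scala sketch of the padding step in the non-replicating branch above (the split names and count are made up): each split becomes its own single-element Spark partition, and empty splits are appended so both sides of the join end up with the same number of partitions:

```scala
// Splits that survived dynamic filtering for one partition key (hypothetical).
val splits = Seq("split-1", "split-2")
// The split count for this key, taken from the join's common partition values.
val numSplits = 4

val padded: Seq[Seq[String]] = splits.map(Seq(_)).padTo(numSplits, Seq.empty[String])
// => List(List(split-1), List(split-2), List(), List())
```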
@@ -62,8 +62,9 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
redact(result)
}

def partitions: Seq[Seq[InputPartition]] =
groupedPartitions.map(_.map(_._2)).getOrElse(inputPartitions.map(Seq(_)))
def partitions: Seq[Seq[InputPartition]] = {
groupedPartitions.map(_.groupedParts.map(_.parts)).getOrElse(inputPartitions.map(Seq(_)))
}

/**
* Shorthand for calling redact() without specifying redacting rules
@@ -94,16 +95,18 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
keyGroupedPartitioning match {
case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) =>
groupedPartitions
.map { partitionValues =>
KeyGroupedPartitioning(exprs, partitionValues.size, partitionValues.map(_._1))
.map { keyGroupedPartsInfo =>
val keyGroupedParts = keyGroupedPartsInfo.groupedParts
KeyGroupedPartitioning(exprs, keyGroupedParts.size, keyGroupedParts.map(_.value),
keyGroupedPartsInfo.originalParts.map(_.partitionKey()))
}
.getOrElse(super.outputPartitioning)
case _ =>
super.outputPartitioning
}
}

@transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] = {
@transient lazy val groupedPartitions: Option[KeyGroupedPartitionInfo] = {
// Early check if we actually need to materialize the input partitions.
keyGroupedPartitioning match {
case Some(_) => groupPartitions(inputPartitions)
@@ -117,24 +120,21 @@
* - all input partitions implement [[HasPartitionKey]]
* - `keyGroupedPartitioning` is set
*
* The result, if defined, is a list of tuples where the first element is a partition value,
* and the second element is a list of input partitions that share the same partition value.
* The result, if defined, is a [[KeyGroupedPartitionInfo]] which contains a list of
* [[KeyGroupedPartition]], as well as a list of partition values from the original input splits,
* sorted according to the partition keys in ascending order.
*
* A non-empty result means each partition is clustered on a single key and therefore eligible
* for further optimizations to eliminate shuffling in some operations such as join and aggregate.
*/
def groupPartitions(
inputPartitions: Seq[InputPartition],
groupSplits: Boolean = !conf.v2BucketingPushPartValuesEnabled ||
!conf.v2BucketingPartiallyClusteredDistributionEnabled):
Option[Seq[(InternalRow, Seq[InputPartition])]] = {

def groupPartitions(inputPartitions: Seq[InputPartition]): Option[KeyGroupedPartitionInfo] = {
if (!SQLConf.get.v2BucketingEnabled) return None

keyGroupedPartitioning.flatMap { expressions =>
val results = inputPartitions.takeWhile {
case _: HasPartitionKey => true
case _ => false
}.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p))
}.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p.asInstanceOf[HasPartitionKey]))

if (results.length != inputPartitions.length || inputPartitions.isEmpty) {
// Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here.
@@ -143,32 +143,25 @@
// also sort the input partitions according to their partition key order. This ensures
// a canonical order from both sides of a bucketed join, for example.
val partitionDataTypes = expressions.map(_.dataType)
val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = {
val partitionOrdering: Ordering[(InternalRow, InputPartition)] = {
RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1)
}

val partitions = if (groupSplits) {
// Group the splits by their partition value
results
val sortedKeyToPartitions = results.sorted(partitionOrdering)
val groupedPartitions = sortedKeyToPartitions
.map(t => (InternalRowComparableWrapper(t._1, expressions), t._2))
.groupBy(_._1)
LuciferYang (Contributor) commented on Sep 6, 2023:

The problem likely comes from this groupBy, as there are some differences between Scala 2.12 and Scala 2.13. For example:

Scala 2.12.18:

```scala
Welcome to Scala 2.12.18 (OpenJDK 64-Bit Server VM, Java 1.8.0_382).
Type in expressions for evaluation. Or try :help.

scala> val input = Seq((50,50),(51,51),(52,52))
input: Seq[(Int, Int)] = List((50,50), (51,51), (52,52))

scala> input.groupBy(_._1).toSeq
res0: Seq[(Int, Seq[(Int, Int)])] = Vector((50,List((50,50))), (51,List((51,51))), (52,List((52,52))))
```

Scala 2.13.8:

```scala
Welcome to Scala 2.13.8 (OpenJDK 64-Bit Server VM, Java 1.8.0_382).
Type in expressions for evaluation. Or try :help.

scala> val input = Seq((50,50),(51,51),(52,52))
val input: Seq[(Int, Int)] = List((50,50), (51,51), (52,52))

scala> input.groupBy(_._1).toSeq
val res0: Seq[(Int, Seq[(Int, Int)])] = List((52,List((52,52))), (50,List((50,50))), (51,List((51,51))))
```

We can see that with Scala 2.13.8 the order of the results has changed. Possible fixes:

1. Replace groupBy with a function that maintains the output order, such as a foldLeft over a LinkedHashMap?
2. Re-sort the groupedPartitions?

Perhaps there are other, better ways to fix it.

sunchao (Member, Author) commented on Sep 6, 2023:

Thanks @LuciferYang for the finding! Yes, it's a bug: I was assuming the order would be preserved by the groupBy. Let me open a follow-up PR to fix this.

sunchao (Member, Author):

I opened #42839 to fix this.

.toSeq
.map {
case (key, s) => (key.row, s.map(_._2))
}
} else {
// No splits grouping, each split will become a separate Spark partition
results.map(t => (t._1, Seq(t._2)))
}
.map { case (key, s) => KeyGroupedPartition(key.row, s.map(_._2)) }

Some(partitions.sorted(partitionOrdering))
Some(KeyGroupedPartitionInfo(groupedPartitions, sortedKeyToPartitions.map(_._2)))
}
}
}

override def outputOrdering: Seq[SortOrder] = {
// when multiple partitions are grouped together, ordering inside partitions is not preserved
val partitioningPreservesOrdering = groupedPartitions.forall(_.forall(_._2.length <= 1))
val partitioningPreservesOrdering = groupedPartitions
.forall(_.groupedParts.forall(_.parts.length <= 1))
ordering.filter(_ => partitioningPreservesOrdering).getOrElse(super.outputOrdering)
}
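On the groupBy ordering issue @LuciferYang raised above, here is a minimal sketch of the first suggested alternative: accumulating into a LinkedHashMap so that grouped keys keep the encounter order of the (already sorted) input. This only illustrates the idea; the actual follow-up fix landed in #42839 and may differ:

```scala
import scala.collection.mutable

// Groups an already-sorted sequence of (key, value) pairs while preserving the
// encounter order of keys, which groupBy does not guarantee on Scala 2.13.
def groupPreservingOrder[K, V](sorted: Seq[(K, V)]): Seq[(K, Seq[V])] = {
  val acc = mutable.LinkedHashMap.empty[K, mutable.ArrayBuffer[V]]
  sorted.foreach { case (k, v) =>
    acc.getOrElseUpdate(k, mutable.ArrayBuffer.empty[V]) += v
  }
  acc.toSeq.map { case (k, vs) => (k, vs.toSeq) }
}

// groupPreservingOrder(Seq((50, "a"), (50, "b"), (51, "c")))
// => List((50, List(a, b)), (51, List(c))) on both Scala 2.12 and 2.13
```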

@@ -217,3 +210,19 @@
}
}
}

/**
* A key-grouped Spark partition, which could consist of multiple input splits
*
* @param value the partition value shared by all the input splits
* @param parts the input splits that are grouped into a single Spark partition
*/
private[v2] case class KeyGroupedPartition(value: InternalRow, parts: Seq[InputPartition])

/**
* Information about key-grouped partitions, which contains a list of grouped partitions as well
* as the original input partitions before the grouping.
*/
private[v2] case class KeyGroupedPartitionInfo(
Member: It seems like it would refer to info about one KeyGroupedPartition. How about KeyGroupedPartitionInfos?

Member Author: I thought about it, but is "Infos" a proper plural noun?

groupedParts: Seq[KeyGroupedPartition],
originalParts: Seq[HasPartitionKey])
@@ -288,12 +288,12 @@ case class EnsureRequirements(
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, leftPartitioning, None))
case (Some(KeyGroupedPartitioning(clustering, _, _)), _) =>
case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, None, rightPartitioning))
case (_, Some(KeyGroupedPartitioning(clustering, _, _))) =>
case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
.orElse(reorderJoinKeysRecursively(
@@ -483,7 +483,10 @@
s"'$joinType'. Skipping partially clustered distribution.")
replicateRightSide = false
} else {
val partValues = if (replicateLeftSide) rightPartValues else leftPartValues
// In partially clustered distribution, we should use un-grouped partition values
val spec = if (replicateLeftSide) rightSpec else leftSpec
val partValues = spec.partitioning.originalPartitionValues

val numExpectedPartitions = partValues
.map(InternalRowComparableWrapper(_, partitionExprs))
.groupBy(identity)
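A rough stand-in illustration of what the count above computes (strings in place of InternalRowComparableWrapper keys; the values are made up): counting occurrences in the un-grouped partition values gives the number of partitions each key must expand to on the replicated side:

```scala
// Original (pre-grouping) partition values of one side, hypothetical.
val originalPartValues = Seq("a", "a", "b", "c", "c", "c")

val numExpectedPartitions =
  originalPartValues.groupBy(identity).map { case (k, vs) => k -> vs.size }
// => Map(a -> 2, b -> 1, c -> 3): key "c" must expand to three partitions.
```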
@@ -301,7 +301,7 @@ object ShuffleExchangeExec {
ascending = true,
samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
case SinglePartition => new ConstantPartitioner
case k @ KeyGroupedPartitioning(expressions, n, _) =>
case k @ KeyGroupedPartitioning(expressions, n, _, _) =>
val valueMap = k.uniquePartitionValues.zipWithIndex.map {
case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index)
}.toMap
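A small stand-in sketch of the lookup table built here (strings instead of partition-value rows): each distinct partition value gets a stable index, which the shuffle then uses as the target partition id for a row:

```scala
val uniquePartitionValues = Seq("a", "b", "c")
val valueMap: Map[String, Int] = uniquePartitionValues.zipWithIndex.toMap
// => Map(a -> 0, b -> 1, c -> 2)

// Routing a row: look up its partition value to get the partition id.
def partitionIdFor(value: String): Int = valueMap(value)
assert(partitionIdFor("b") == 1)
```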
@@ -332,7 +332,7 @@
val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
case SinglePartition => identity
case KeyGroupedPartitioning(expressions, _, _) =>
case KeyGroupedPartitioning(expressions, _, _, _) =>
row => bindReferences(expressions, outputAttributes).map(_.eval(row))
case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
}
@@ -51,9 +51,9 @@ abstract class DistributionAndOrderingSuiteBase
plan: QueryPlan[T]): Partitioning = partitioning match {
case HashPartitioning(exprs, numPartitions) =>
HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions)
case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) =>
KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions,
partitionValues)
case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) =>
KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues,
originalPartValues)
case PartitioningCollection(partitionings) =>
PartitioningCollection(partitionings.map(resolvePartitioning(_, plan)))
case RangePartitioning(ordering, numPartitions) =>
@@ -131,7 +131,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
// Has exactly one partition.
val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v)))
checkQueryPlan(df, distribution,
physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues))
physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues, partitionValues))
}

test("non-clustered distribution: no V2 catalog") {