[SPARK-23803][SQL] Support bucket pruning #20915
Changes from 7 commits
5be4dff
4ab1583
3bb7a2e
c45da4b
f0b84bd
cb36012
8f6bc28
0a4aeab
697784d
de9ecb6
f949128
72712c9
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources | |
|
|
||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.sql._ | ||
| import org.apache.spark.sql.catalyst.catalog.BucketSpec | ||
| import org.apache.spark.sql.catalyst.expressions | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.planning.PhysicalOperation | ||
| import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan | ||
| import org.apache.spark.sql.execution.FileSourceScanExec | ||
| import org.apache.spark.sql.execution.SparkPlan | ||
| import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} | ||
| import org.apache.spark.util.collection.BitSet | ||
|
|
||
| /** | ||
| * A strategy for planning scans over collections of files that might be partitioned or bucketed | ||
|
|
@@ -50,6 +51,85 @@ import org.apache.spark.sql.execution.SparkPlan | |
| * and add it. Proceed to the next file. | ||
| */ | ||
| object FileSourceStrategy extends Strategy with Logging { | ||
|
|
||
| // should prune buckets iff num buckets is greater than 1 and there is only one bucket column | ||
| private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = { | ||
| bucketSpec match { | ||
| case Some(spec) => spec.bucketColumnNames.length == 1 && spec.numBuckets > 1 | ||
| case None => false | ||
| } | ||
| } | ||
|
|
||
| private def getExpressionBuckets(expr: Expression, | ||
|
||
| bucketColumnName: String, | ||
| numBuckets: Int): BitSet = { | ||
|
|
||
| def getMatchedBucketBitSet(attr: Attribute, v: Any): BitSet = { | ||
| val matchedBuckets = new BitSet(numBuckets) | ||
| matchedBuckets.set(BucketingUtils.getBucketIdFromValue(attr, numBuckets, v)) | ||
| matchedBuckets | ||
| } | ||
|
|
||
| expr match { | ||
| case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => | ||
|
||
| getMatchedBucketBitSet(a, v) | ||
| case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName => | ||
| getMatchedBucketBitSet(a, v) | ||
| case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => | ||
| getMatchedBucketBitSet(a, v) | ||
| case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName => | ||
| getMatchedBucketBitSet(a, v) | ||
| case expressions.In(a: Attribute, list) | ||
|
Contributor comment: we should catch …
||
| if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => | ||
| val valuesSet = list.map(e => e.eval(EmptyRow)) | ||
| valuesSet | ||
| .map(v => getMatchedBucketBitSet(a, v)) | ||
| .fold(new BitSet(numBuckets))(_ | _) | ||
|
||
| case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => | ||
| getMatchedBucketBitSet(a, null) | ||
| case expressions.And(left, right) => | ||
| getExpressionBuckets(left, bucketColumnName, numBuckets) | | ||
|
||
| getExpressionBuckets(right, bucketColumnName, numBuckets) | ||
| case expressions.Or(left, right) => | ||
| val leftBuckets = getExpressionBuckets(left, bucketColumnName, numBuckets) | ||
| val rightBuckets = getExpressionBuckets(right, bucketColumnName, numBuckets) | ||
|
|
||
| // if some expression in OR condition requires all buckets, return an empty BitSet | ||
| if (leftBuckets.cardinality() == 0 || rightBuckets.cardinality() == 0) { | ||
| new BitSet(numBuckets) | ||
| } else { | ||
| // return a BitSet that includes all required buckets | ||
| leftBuckets | rightBuckets | ||
| } | ||
| case _ => new BitSet(numBuckets) | ||
| } | ||
| } | ||
|
|
||
| private def getBuckets(normalizedFilters: Seq[Expression], | ||
|
||
| bucketSpec: BucketSpec): Option[BitSet] = { | ||
|
|
||
| val bucketColumnName = bucketSpec.bucketColumnNames.head | ||
| val numBuckets = bucketSpec.numBuckets | ||
|
|
||
| val matchedBuckets = normalizedFilters | ||
|
||
| .map(f => getExpressionBuckets(f, bucketColumnName, numBuckets)) | ||
| .fold(new BitSet(numBuckets))(_ | _) | ||
|
|
||
| val numBucketsSelected = if (matchedBuckets.cardinality() != 0) { | ||
| matchedBuckets.cardinality() | ||
| } | ||
| else { | ||
| numBuckets | ||
| } | ||
|
|
||
| logInfo { | ||
| s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets." | ||
| } | ||
|
|
||
| // None means all the buckets need to be scanned | ||
| if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets) | ||
| } | ||
|
|
||
| def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { | ||
| case PhysicalOperation(projects, filters, | ||
| l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) => | ||
|
|
@@ -79,6 +159,13 @@ object FileSourceStrategy extends Strategy with Logging { | |
| ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) | ||
| logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") | ||
|
|
||
| val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec | ||
| val bucketSet = if (shouldPruneBuckets(bucketSpec)) { | ||
| getBuckets(normalizedFilters, bucketSpec.get) | ||
|
||
| } else { | ||
| None | ||
| } | ||
|
|
||
| val dataColumns = | ||
| l.resolve(fsRelation.dataSchema, fsRelation.sparkSession.sessionState.analyzer.resolver) | ||
|
|
||
|
|
@@ -108,6 +195,7 @@ object FileSourceStrategy extends Strategy with Logging { | |
| outputAttributes, | ||
| outputSchema, | ||
| partitionKeyFilters.toSeq, | ||
| bucketSet, | ||
| dataFilters, | ||
| table.map(_.identifier)) | ||
|
|
||
|
|
||
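A quick aside on the BitSet convention used by `getExpressionBuckets` above: an empty BitSet stands for "no pruning possible, scan every bucket". Under that convention, `And` can safely take the union of the buckets required by each side (a side that cannot prune contributes nothing), while `Or` has to give up and return an empty BitSet as soon as either side needs all buckets. The following is a minimal sketch of that semantics in plain Scala, not the actual Spark code — `bucketIdOf` is a made-up stand-in for the hash-based bucket id that `BucketingUtils.getBucketIdFromValue` computes:

```scala
// Illustrative model only: Set[Int] plays the role of the BitSet, and an empty
// set means "cannot prune, every bucket must be scanned".
object BucketPruningSemantics {

  // Hypothetical stand-in for the real hash-based bucket id computation.
  def bucketIdOf(value: Int, numBuckets: Int): Int =
    ((value.hashCode % numBuckets) + numBuckets) % numBuckets

  // AND: a side with no restriction (empty set) adds nothing, so a union is a
  // safe over-approximation of the buckets that can satisfy both sides.
  def andBuckets(left: Set[Int], right: Set[Int]): Set[Int] =
    left union right

  // OR: if either side needs every bucket, the whole predicate needs every
  // bucket; otherwise the union of both sides covers all matching rows.
  def orBuckets(left: Set[Int], right: Set[Int]): Set[Int] =
    if (left.isEmpty || right.isEmpty) Set.empty else left union right

  def main(args: Array[String]): Unit = {
    val numBuckets = 8
    val jEq3 = Set(bucketIdOf(3, numBuckets)) // j = 3 on the bucket column
    val jEq5 = Set(bucketIdOf(5, numBuckets)) // j = 5 on the bucket column
    val iGt0 = Set.empty[Int]                 // predicate on a non-bucket column

    println(orBuckets(jEq3, jEq5))  // two buckets survive
    println(orBuckets(jEq3, iGt0))  // Set() -> all buckets must be scanned
    println(andBuckets(jEq3, iGt0)) // only the bucket for j = 3 survives
  }
}
```

The three prints mirror the test cases added to `BucketedReadSuite` below: an OR of two bucket values prunes down to two buckets, an OR with a non-bucket predicate prunes nothing, and an AND inside a complex expression still prunes on the bucket value.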
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec | |
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning | ||
| import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} | ||
| import org.apache.spark.sql.execution.datasources.DataSourceStrategy | ||
| import org.apache.spark.sql.execution.datasources.BucketingUtils | ||
| import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec | ||
| import org.apache.spark.sql.execution.joins.SortMergeJoinExec | ||
| import org.apache.spark.sql.functions._ | ||
|
|
@@ -83,39 +83,43 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { | |
| // To verify if the bucket pruning works, this function checks two conditions: | ||
| // 1) Check if the pruned buckets (before filtering) are empty. | ||
| // 2) Verify the final result is the same as the expected one | ||
| private def checkPrunedAnswers( | ||
| bucketSpec: BucketSpec, | ||
| bucketValues: Seq[Integer], | ||
| filterCondition: Column, | ||
| originalDataFrame: DataFrame): Unit = { | ||
| private def checkPrunedAnswers(bucketSpec: BucketSpec, | ||
|
||
| bucketValues: Seq[Integer], | ||
| filterCondition: Column, | ||
| originalDataFrame: DataFrame): Unit = { | ||
| // This test verifies parts of the plan. Disable whole stage codegen. | ||
| withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { | ||
| val strategy = DataSourceStrategy(spark.sessionState.conf) | ||
| val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") | ||
| val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec | ||
| // Limit: bucket pruning only works when the bucket column has one and only one column | ||
| assert(bucketColumnNames.length == 1) | ||
| val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head) | ||
| val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) | ||
| val matchedBuckets = new BitSet(numBuckets) | ||
| bucketValues.foreach { value => | ||
| matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) | ||
| } | ||
|
|
||
| // Filter could hide the bug in bucket pruning. Thus, skipping all the filters | ||
| val plan = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan | ||
| val rdd = plan.find(_.isInstanceOf[DataSourceScanExec]) | ||
|
Contributor comment: nit: not introduced by this PR but this name is wrong, we should probably call it …
||
| assert(rdd.isDefined, plan) | ||
|
|
||
| val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => | ||
| if (matchedBuckets.get(index % numBuckets) && iter.nonEmpty) Iterator(index) else Iterator() | ||
| // if nothing should be pruned, skip the pruning test | ||
| if (bucketValues.nonEmpty) { | ||
| val matchedBuckets = new BitSet(numBuckets) | ||
| bucketValues.foreach { value => | ||
| matchedBuckets.set(BucketingUtils.getBucketIdFromValue(bucketColumn, numBuckets, value)) | ||
| } | ||
| val invalidBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) => | ||
| // return indexes of partitions that should have been pruned and are not empty | ||
| if (!matchedBuckets.get(index % numBuckets) && iter.nonEmpty) { | ||
| Iterator(index) | ||
| } else { | ||
| Iterator() | ||
| } | ||
| }.collect() | ||
|
|
||
| if (invalidBuckets.nonEmpty) { | ||
| fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan") | ||
| } | ||
| } | ||
| // TODO: These tests are not testing the right columns. | ||
| // // checking if all the pruned buckets are empty | ||
| // val invalidBuckets = checkedResult.collect().toList | ||
| // if (invalidBuckets.nonEmpty) { | ||
| // fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan") | ||
| // } | ||
|
|
||
| checkAnswer( | ||
| bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"), | ||
|
|
@@ -229,6 +233,27 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { | |
| bucketValues = j :: Nil, | ||
| filterCondition = $"j" === j && $"i" > j % 5, | ||
| df) | ||
|
|
||
| // check multiple bucket values OR condition | ||
| checkPrunedAnswers( | ||
| bucketSpec, | ||
| bucketValues = Seq(j, j + 1), | ||
| filterCondition = $"j" === j || $"j" === (j + 1), | ||
| df) | ||
|
|
||
| // check bucket value and none bucket value OR condition | ||
| checkPrunedAnswers( | ||
| bucketSpec, | ||
| bucketValues = Nil, | ||
| filterCondition = $"j" === j || $"i" === 0, | ||
| df) | ||
|
|
||
| // check AND condition in complex expression | ||
| checkPrunedAnswers( | ||
| bucketSpec, | ||
| bucketValues = Seq(j), | ||
| filterCondition = ($"i" === 0 || $"k" > $"j") && $"j" === j, | ||
| df) | ||
| } | ||
| } | ||
| } | ||
|
|
||
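For readers who want to see the pruning end-to-end rather than through the test harness, here is a rough usage sketch. Everything in it is illustrative: the table and column names simply mirror the suite above, and the local `SparkSession` setup is an assumption, not part of the PR.

```scala
import org.apache.spark.sql.SparkSession

// Write a table bucketed on "j", then filter on "j" so the scan can skip the
// buckets that cannot contain the requested value.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("bucket-pruning-demo")
  .getOrCreate()
import spark.implicits._

val df = (0 until 1000).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
df.write
  .format("parquet")
  .mode("overwrite")
  .bucketBy(8, "j")
  .sortBy("j")
  .saveAsTable("bucketed_table")

// With this patch, the FileSourceScanExec planned for the query below carries the
// optional bucket BitSet, so the files of pruned buckets are never read.
val pruned = spark.table("bucketed_table").filter($"j" === 3)
pruned.explain() // inspect the scan node; depending on the build it may report the selected bucket count
pruned.show()
```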
Reviewer comment: can we avoid calculating bucket id from file name twice?