@@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.sources.{BaseRelation, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.BitSet

trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
val relation: BaseRelation
@@ -151,6 +152,7 @@ case class RowDataSourceScanExec(
* @param output Output attributes of the scan, including data attributes and partition attributes.
* @param requiredSchema Required schema of the underlying relation, excluding partition columns.
* @param partitionFilters Predicates to use for partition pruning.
* @param optionalBucketSet Bucket ids for bucket pruning; None means all buckets must be scanned.
* @param dataFilters Filters on non-partition columns.
* @param tableIdentifier identifier for the table in the metastore.
*/
@@ -159,6 +161,7 @@ case class FileSourceScanExec(
output: Seq[Attribute],
requiredSchema: StructType,
partitionFilters: Seq[Expression],
optionalBucketSet: Option[BitSet],
dataFilters: Seq[Expression],
override val tableIdentifier: Option[TableIdentifier])
extends DataSourceScanExec with ColumnarBatchScan {
@@ -286,7 +289,20 @@ case class FileSourceScanExec(
} getOrElse {
metadata
}
withOptPartitionCount

val withSelectedBucketsCount = relation.bucketSpec.map { spec =>
val numSelectedBuckets = optionalBucketSet.map { b =>
b.cardinality()
} getOrElse {
spec.numBuckets
}
withOptPartitionCount + ("SelectedBucketsCount" ->
s"$numSelectedBuckets out of ${spec.numBuckets}")
} getOrElse {
withOptPartitionCount
}

withSelectedBucketsCount
}
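For illustration only — the counts below are assumed, not taken from the diff — this is how the new metadata entry composes for a table with 8 buckets of which 2 survive pruning, and roughly how it surfaces in the plan description:

// Hypothetical counts, for illustration only.
val numSelectedBuckets = 2
val numBuckets = 8
val entry = "SelectedBucketsCount" -> s"$numSelectedBuckets out of $numBuckets"
// entry == ("SelectedBucketsCount", "2 out of 8"); explain() renders the scan's metadata
// roughly as "SelectedBucketsCount: 2 out of 8" on the FileScan node.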

private lazy val inputRDD: RDD[InternalRow] = {
@@ -371,14 +387,27 @@ case class FileSourceScanExec(
val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen)
PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen, hosts)
}
}.groupBy { f =>
BucketingUtils
.getBucketId(new Path(f.filePath).getName)
.getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
}

val prunedBucketed = if (optionalBucketSet.isDefined) {
val bucketSet = optionalBucketSet.get
bucketed.filter {
f => bucketSet.get(
BucketingUtils.getBucketId(new Path(f.filePath).getName)
.getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")))
}
} else {
bucketed
}

val filesGroupedToBuckets = prunedBucketed.groupBy { f =>
  BucketingUtils
    .getBucketId(new Path(f.filePath).getName)
    .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
}

[Review — Contributor] can we avoid calculating bucket id from file name twice? (one approach is sketched after this hunk)

val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil))
FilePartition(bucketId, filesGroupedToBuckets.getOrElse(bucketId, Nil))
}

new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions)
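A minimal sketch of one way to address the review comment above — a suggested rewrite, not the code as submitted — assuming the surrounding names from this hunk (bucketed, optionalBucketSet, bucketSpec): group the files by bucket id first, then prune the resulting map by key, so getBucketId runs once per file.

val filesGroupedToBuckets = bucketed.groupBy { f =>
  BucketingUtils
    .getBucketId(new Path(f.filePath).getName)
    .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
}

// Prune whole groups by their bucket-id key instead of re-deriving the id per file.
val prunedFilesGroupedToBuckets = optionalBucketSet match {
  case Some(bucketSet) =>
    filesGroupedToBuckets.filter { case (bucketId, _) => bucketSet.get(bucketId) }
  case None =>
    filesGroupedToBuckets
}

val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
  FilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Nil))
}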
@@ -503,6 +532,7 @@ case class FileSourceScanExec(
output.map(QueryPlan.normalizeExprId(_, output)),
requiredSchema,
QueryPlan.normalizePredicates(partitionFilters, output),
optionalBucketSet,
QueryPlan.normalizePredicates(dataFilters, output),
None)
}
@@ -17,6 +17,9 @@

package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

object BucketingUtils {
// The file name of bucketed data should have 3 parts:
// 1. some other information in the head of file name
@@ -35,5 +38,16 @@ object BucketingUtils {
case other => None
}

// Given bucketColumn, numBuckets and value, returns the corresponding bucketId
def getBucketIdFromValue(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
val mutableInternalRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
mutableInternalRow.update(0, value)

val bucketIdGenerator = UnsafeProjection.create(
HashPartitioning(Seq(bucketColumn), numBuckets).partitionIdExpression :: Nil,
bucketColumn :: Nil)
bucketIdGenerator(mutableInternalRow).getInt(0)
}

def bucketIdToString(id: Int): String = f"_$id%05d"
}
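A hedged usage sketch of the new helper — the attribute, bucket count, and value below are illustrative, not part of the patch: for a table bucketed into 8 buckets on an integer column j, getBucketIdFromValue computes the bucket a literal value hashes into, and bucketIdToString shows how that id is encoded into bucketed file names.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

val j = AttributeReference("j", IntegerType)()                // illustrative bucket column
val bucketId = BucketingUtils.getBucketIdFromValue(j, 8, 3)   // deterministic id in [0, 8)
val suffix = BucketingUtils.bucketIdToString(bucketId)        // e.g. "_00005" when bucketId == 5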
@@ -312,18 +312,6 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
case _ => Nil
}

// Get the bucket ID based on the bucketing values.
// Restriction: Bucket pruning works iff the bucketing column has one and only one column.
def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = {
val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType))
mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null)
val bucketIdGeneration = UnsafeProjection.create(
HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil,
bucketColumn :: Nil)

bucketIdGeneration(mutableRow).getInt(0)
}

// Based on Public API.
private def pruneFilterProject(
relation: LogicalRelation,
@@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.datasources

import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.util.collection.BitSet

/**
* A strategy for planning scans over collections of files that might be partitioned or bucketed
@@ -50,6 +51,85 @@ import org.apache.spark.sql.execution.SparkPlan
* and add it. Proceed to the next file.
*/
object FileSourceStrategy extends Strategy with Logging {

// Bucket pruning applies iff the number of buckets is greater than 1 and there is exactly one bucket column
private def shouldPruneBuckets(bucketSpec: Option[BucketSpec]): Boolean = {
bucketSpec match {
case Some(spec) => spec.bucketColumnNames.length == 1 && spec.numBuckets > 1
case None => false
}
}

[Review — Contributor] nit: code style; prefer:
    def xxx(
        para1: A,
        para2: B): XXX

private def getExpressionBuckets(expr: Expression,
    bucketColumnName: String,
    numBuckets: Int): BitSet = {

def getMatchedBucketBitSet(attr: Attribute, v: Any): BitSet = {
val matchedBuckets = new BitSet(numBuckets)
matchedBuckets.set(BucketingUtils.getBucketIdFromValue(attr, numBuckets, v))
matchedBuckets
}

expr match {
  case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
    getMatchedBucketBitSet(a, v)
  case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
    getMatchedBucketBitSet(a, v)
  case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
    getMatchedBucketBitSet(a, v)
  case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
    getMatchedBucketBitSet(a, v)

[Review — Contributor] use Equality to match both EqualTo and EqualNullSafe

  case expressions.In(a: Attribute, list)
      if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
    val valuesSet = list.map(e => e.eval(EmptyRow))
    valuesSet
      .map(v => getMatchedBucketBitSet(a, v))
      .fold(new BitSet(numBuckets))(_ | _)

[Review — Contributor] we should catch InSet as well
[Review — Contributor] can't we create one bit set for all the matched buckets, instead of creating many bit sets and merging them?

case expressions.IsNull(a: Attribute) if a.name == bucketColumnName =>
getMatchedBucketBitSet(a, null)
  case expressions.And(left, right) =>
    getExpressionBuckets(left, bucketColumnName, numBuckets) |
      getExpressionBuckets(right, bucketColumnName, numBuckets)

[Review — pwoody, Mar 29, 2018] Does this logic work given that unsupported filters are an empty bitset? Perhaps we can have unmatched filters go to a filled bitset (all buckets might match), remove cardinality checks for the ALL match, and swap this to &.

[Reply — sabanas, Mar 29, 2018] In an AND condition I'd like to get the union of both RHS and LHS BitSets. For example, if the table is bucketed on column j:
j == 0 && i == 0 -> will return a bit set matching j's bucket.
i == 0 && k == 0 -> will return the union of 2 empty bit sets, meaning an empty bit set (which is later on an indicator for reading all buckets).
j == 0 && j == 1 (which is effectively a false condition) -> will return a bit set matching both j values. I haven't implemented an optimization for this case (it could be done with & when both RHS and LHS are non-empty).
Can you think of an input that will not work here?

[Reply — pwoody] Sorry, yeah, I was mostly curious why not have the BitSet modeled as "this bucket might have data" and then do the natural & and | operations?

[Reply — sabanas] Right, makes sense. I changed it accordingly.

(A combined sketch incorporating the suggestions from these threads follows this method.)

case expressions.Or(left, right) =>
val leftBuckets = getExpressionBuckets(left, bucketColumnName, numBuckets)
val rightBuckets = getExpressionBuckets(right, bucketColumnName, numBuckets)

// if some expression in OR condition requires all buckets, return an empty BitSet
if (leftBuckets.cardinality() == 0 || rightBuckets.cardinality() == 0) {
new BitSet(numBuckets)
} else {
// return a BitSet that includes all required buckets
leftBuckets | rightBuckets
}
case _ => new BitSet(numBuckets)
}
}
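Pulling the review threads above together, here is a hedged sketch — not the code as submitted — of how getExpressionBuckets could look with the suggested changes applied: the Equality extractor covers both EqualTo and EqualNullSafe, an InSet case is added, matched ids are set into a single BitSet, the BitSet is modeled as "this bucket might contain matching rows" so unsupported expressions start fully set, and sub-results combine with & for And and | for Or. It assumes the imports already present in this file (catalyst expressions, BucketingUtils, BitSet).

private def getExpressionBuckets(
    expr: Expression,
    bucketColumnName: String,
    numBuckets: Int): BitSet = {

  // Collect the bucket ids of all given values into one BitSet.
  def getBucketSetFromIterable(attr: Attribute, iterable: Iterable[Any]): BitSet = {
    val matchedBuckets = new BitSet(numBuckets)
    iterable
      .map(v => BucketingUtils.getBucketIdFromValue(attr, numBuckets, v))
      .foreach(matchedBuckets.set)
    matchedBuckets
  }

  def getBucketSetFromValue(attr: Attribute, v: Any): BitSet =
    getBucketSetFromIterable(attr, v :: Nil)

  expr match {
    // Equality matches both EqualTo and EqualNullSafe.
    case expressions.Equality(a: Attribute, Literal(v, _)) if a.name == bucketColumnName =>
      getBucketSetFromValue(a, v)
    case expressions.Equality(Literal(v, _), a: Attribute) if a.name == bucketColumnName =>
      getBucketSetFromValue(a, v)
    case expressions.In(a: Attribute, list)
        if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName =>
      getBucketSetFromIterable(a, list.map(_.eval(EmptyRow)))
    // InSet stores already-evaluated values, so no extra eval is needed.
    case expressions.InSet(a: Attribute, hset) if a.name == bucketColumnName =>
      getBucketSetFromIterable(a, hset)
    case expressions.IsNull(a: Attribute) if a.name == bucketColumnName =>
      getBucketSetFromValue(a, null)
    case expressions.And(left, right) =>
      getExpressionBuckets(left, bucketColumnName, numBuckets) &
        getExpressionBuckets(right, bucketColumnName, numBuckets)
    case expressions.Or(left, right) =>
      getExpressionBuckets(left, bucketColumnName, numBuckets) |
        getExpressionBuckets(right, bucketColumnName, numBuckets)
    case _ =>
      // Unsupported expression: conservatively assume every bucket may contain matches.
      val matchedBuckets = new BitSet(numBuckets)
      matchedBuckets.setUntil(numBuckets)
      matchedBuckets
  }
}

With this modeling, the caller treats a fully set BitSet, rather than an empty one, as "no pruning possible"; the genBucketSet sketch after the next method follows that convention.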

[Review — Contributor] nit: code style

private def getBuckets(normalizedFilters: Seq[Expression],
    bucketSpec: BucketSpec): Option[BitSet] = {

val bucketColumnName = bucketSpec.bucketColumnNames.head
val numBuckets = bucketSpec.numBuckets

val matchedBuckets = normalizedFilters
  .map(f => getExpressionBuckets(f, bucketColumnName, numBuckets))
  .fold(new BitSet(numBuckets))(_ | _)

[Review — pwoody, Mar 29, 2018] Can we swap this to something like .reduce(And).map(getExpressionBuckets)? To consolidate the And logic in the other method.

[Reply — sabanas] done.

(See the genBucketSet sketch after this method.)

val numBucketsSelected = if (matchedBuckets.cardinality() != 0) {
  matchedBuckets.cardinality()
} else {
  numBuckets
}

logInfo {
s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets."
}

// None means all the buckets need to be scanned
if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets)
}
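A hedged sketch of the consolidated helper discussed in the threads above, using the reviewer-suggested name genBucketSet and reducing the filters with And so the And-combination logic lives entirely in getExpressionBuckets. It pairs with the previous sketch's modeling, where a fully set BitSet means no pruning is possible.

private def genBucketSet(
    normalizedFilters: Seq[Expression],
    bucketSpec: BucketSpec): Option[BitSet] = {
  if (normalizedFilters.isEmpty) {
    return None
  }

  val bucketColumnName = bucketSpec.bucketColumnNames.head
  val numBuckets = bucketSpec.numBuckets

  // Combine all pushed-down filters with And and derive the candidate buckets once.
  val matchedBuckets = getExpressionBuckets(
    normalizedFilters.reduce(expressions.And), bucketColumnName, numBuckets)

  val numBucketsSelected = matchedBuckets.cardinality()
  logInfo(s"Pruned ${numBuckets - numBucketsSelected} out of $numBuckets buckets.")

  // None means every bucket needs to be scanned.
  if (numBucketsSelected == numBuckets) None else Some(matchedBuckets)
}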

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(projects, filters,
l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
@@ -79,6 +159,13 @@ object FileSourceStrategy extends Strategy with Logging {
ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet)))
logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}")

val bucketSpec: Option[BucketSpec] = fsRelation.bucketSpec
val bucketSet = if (shouldPruneBuckets(bucketSpec)) {
  getBuckets(normalizedFilters, bucketSpec.get)
} else {
  None
}

[Review — Contributor] nit: maybe better to call it genBucketSet instead of getBuckets

val dataColumns =
l.resolve(fsRelation.dataSchema, fsRelation.sparkSession.sessionState.analyzer.resolver)

@@ -108,6 +195,7 @@ object FileSourceStrategy extends Strategy with Logging {
outputAttributes,
outputSchema,
partitionKeyFilters.toSeq,
bucketSet,
dataFilters,
table.map(_.identifier))

@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.datasources.BucketingUtils
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.functions._
@@ -83,39 +83,43 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
// To verify if the bucket pruning works, this function checks two conditions:
// 1) Check if the pruned buckets (before filtering) are empty.
// 2) Verify the final result is the same as the expected one
private def checkPrunedAnswers(
bucketSpec: BucketSpec,
bucketValues: Seq[Integer],
filterCondition: Column,
originalDataFrame: DataFrame): Unit = {
[Review — Contributor] nit: code style

private def checkPrunedAnswers(bucketSpec: BucketSpec,
    bucketValues: Seq[Integer],
    filterCondition: Column,
    originalDataFrame: DataFrame): Unit = {
// This test verifies parts of the plan. Disable whole stage codegen.
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
val strategy = DataSourceStrategy(spark.sessionState.conf)
val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k")
val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec
// Limitation: bucket pruning only works when there is exactly one bucket column
assert(bucketColumnNames.length == 1)
val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head)
val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex)
val matchedBuckets = new BitSet(numBuckets)
bucketValues.foreach { value =>
matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value))
}

// Filter could hide the bug in bucket pruning. Thus, skipping all the filters
val plan = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan
val rdd = plan.find(_.isInstanceOf[DataSourceScanExec])

[Review — Contributor] nit: not introduced by this PR, but this name is wrong; we should probably call it scanPlan

assert(rdd.isDefined, plan)

val checkedResult = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) =>
if (matchedBuckets.get(index % numBuckets) && iter.nonEmpty) Iterator(index) else Iterator()
// if nothing should be pruned, skip the pruning test
if (bucketValues.nonEmpty) {
val matchedBuckets = new BitSet(numBuckets)
bucketValues.foreach { value =>
matchedBuckets.set(BucketingUtils.getBucketIdFromValue(bucketColumn, numBuckets, value))
}
val invalidBuckets = rdd.get.execute().mapPartitionsWithIndex { case (index, iter) =>
// return indexes of partitions that should have been pruned and are not empty
if (!matchedBuckets.get(index % numBuckets) && iter.nonEmpty) {
Iterator(index)
} else {
Iterator()
}
}.collect()

if (invalidBuckets.nonEmpty) {
fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan")
}
}
// TODO: These tests are not testing the right columns.
// // checking if all the pruned buckets are empty
// val invalidBuckets = checkedResult.collect().toList
// if (invalidBuckets.nonEmpty) {
// fail(s"Buckets $invalidBuckets should have been pruned from:\n$plan")
// }

checkAnswer(
bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"),
@@ -229,6 +233,27 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
bucketValues = j :: Nil,
filterCondition = $"j" === j && $"i" > j % 5,
df)

// check multiple bucket values OR condition
checkPrunedAnswers(
bucketSpec,
bucketValues = Seq(j, j + 1),
filterCondition = $"j" === j || $"j" === (j + 1),
df)

// check OR condition mixing a bucket-column value and a non-bucket-column value
checkPrunedAnswers(
bucketSpec,
bucketValues = Nil,
filterCondition = $"j" === j || $"i" === 0,
df)

// check AND condition in complex expression
checkPrunedAnswers(
bucketSpec,
bucketValues = Seq(j),
filterCondition = ($"i" === 0 || $"k" > $"j") && $"j" === j,
df)
}
}
}