@@ -20,9 +20,10 @@ package org.apache.spark.sql.hive
 import org.apache.spark.annotation.{Experimental, InterfaceStability}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.Analyzer
+import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlanner
+import org.apache.spark.sql.execution.{SparkOptimizer, SparkPlanner}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.hive.client.HiveClient
 import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}
@@ -88,6 +89,20 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
         customCheckRules
   }
 
+  /**
+   * Logical query plan optimizer that takes into account Hive.
+   */
+  override lazy val optimizer: Optimizer =
+    new SparkOptimizer(catalog, experimentalMethods) {
+      override def postHocOptimizationBatches: Seq[Batch] = Seq(
+        Batch("Prune Hive Table Partitions", Once,
+          new PruneHiveTablePartitions(session))
+      )
+
+      override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] =
+        super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules
+    }
+
   /**
    * Planner that takes into account Hive-specific strategies.
    */
@@ -26,17 +26,17 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
-import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, ScriptTransformation}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
 import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
+import org.apache.spark.sql.hive.client.HiveClientImpl
 import org.apache.spark.sql.hive.execution._
 import org.apache.spark.sql.hive.orc.OrcFileFormat
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
-
 
 /**
  * Determine the database, serde/format and schema of the Hive serde table, according to the storage
  * properties.
@@ -138,6 +138,54 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] {
   }
 }
 
+case class PruneHiveTablePartitions(

Review comment (Contributor):
Add a TODO that we should merge this rule with PruneFileSourcePartitions once we completely make Hive a data source.

+    session: SparkSession) extends Rule[LogicalPlan] with PredicateHelper {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    case filter @ Filter(condition, relation: HiveTableRelation)
+        if DDLUtils.isHiveTable(relation.tableMeta) && relation.isPartitioned =>

Review comment (Contributor):
Is this only for Hive tables? What about data source tables?

Reply (Contributor, author):
I think PruneFileSourcePartitions can handle data source tables now.

Review comment (Contributor):
DDLUtils.isHiveTable(relation.tableMeta) is no longer needed.
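
Illustratively, with that check dropped the match in apply() would reduce to the guard below; matching on HiveTableRelation already guarantees a Hive table, so only the partitioning test remains. The body is elided and stands for the pruning logic in the rule above.

override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
  case filter @ Filter(_, relation: HiveTableRelation) if relation.isPartitioned =>
    // ...same size-estimation and pruning logic as in the rule above...
    filter
}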

+      val predicates = splitConjunctivePredicates(condition)
+      val normalizedFilters = predicates.map { e =>
+        e transform {
+          case a: AttributeReference =>
+            a.withName(relation.output.find(_.semanticEquals(a)).get.name)
+        }
+      }
+      val partitionSet = AttributeSet(relation.partitionCols)
+      val pruningPredicates = normalizedFilters.filter { predicate =>
+        !predicate.references.isEmpty &&
+          predicate.references.subsetOf(partitionSet)
+      }
+      if (pruningPredicates.nonEmpty && session.sessionState.conf.fallBackToHdfsForStatsEnabled &&
+          session.sessionState.conf.metastorePartitionPruning) {
+        val prunedPartitions = session.sharedState.externalCatalog.listPartitionsByFilter(
+          relation.tableMeta.database,
+          relation.tableMeta.identifier.table,
+          pruningPredicates,
+          session.sessionState.conf.sessionLocalTimeZone)
+        val hiveTable = HiveClientImpl.toHiveTable(relation.tableMeta)
+        val partitions = prunedPartitions.map(HiveClientImpl.toHivePartition(_, hiveTable))

Review comment (Contributor):
Do we need to do this? All we need is the partition data location, and we can get it from CatalogTablePartition.storage.locationUri.
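
A minimal sketch of this suggestion, assuming the pruned CatalogTablePartition objects and a Hadoop Configuration are in scope (the helper name is illustrative): sum partition sizes straight from each partition's storage location, without converting to Hive client objects.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.catalyst.catalog.CatalogTablePartition

// Sum the on-disk size of the pruned partitions using only catalog metadata.
def sizeFromPartitionLocations(
    partitions: Seq[CatalogTablePartition],
    hadoopConf: Configuration): Long = {
  partitions.flatMap(_.storage.locationUri).map { uri =>
    val path = new Path(uri)
    // getContentSummary walks the partition directory and returns its total length in bytes.
    path.getFileSystem(hadoopConf).getContentSummary(path).getLength
  }.sum
}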

+        val sizeInBytes = try {

Review comment (Contributor):
What if we already have partition-level statistics in the Hive metastore?

Reply (Contributor, author):
Even if we already have partition-level statistics, we cannot know the total number of partitions, so we cannot compute the statistics for the pruned partitions from them.
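
One option here (not part of this PR) would be to use the metastore's partition-level statistics when every pruned partition has them, and only fall back to the HDFS scan otherwise. A rough sketch, assuming CatalogTablePartition.stats is populated; the helper name is illustrative.

import org.apache.spark.sql.catalyst.catalog.CatalogTablePartition

// Total size of the pruned partitions from catalog statistics, or None if any
// partition is missing them (the caller then falls back to the HDFS scan).
def sizeFromPartitionStats(partitions: Seq[CatalogTablePartition]): Option[BigInt] = {
  val sizes = partitions.map(_.stats.map(_.sizeInBytes))
  if (sizes.forall(_.isDefined)) Some(sizes.flatten.sum) else None
}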

+          val hadoopConf = session.sessionState.newHadoopConf()
+          partitions.map { partition =>
+            val fs: FileSystem = partition.getDataLocation.getFileSystem(hadoopConf)
+            fs.getContentSummary(partition.getDataLocation).getLength
+          }.sum

Review comment (Contributor):
If there are too many partitions, this will be very slow. Can you add a check on whether the sum is already larger than the threshold, and break early if it is?
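
A minimal sketch of that early exit, with illustrative names: stop issuing HDFS calls once the running total already exceeds the threshold that matters (for example the broadcast-join threshold), since any larger value no longer changes the planner's decision.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

// Accumulate partition sizes but stop early once the total passes `threshold`.
def sizeWithEarlyStop(
    locations: Seq[Path],
    hadoopConf: Configuration,
    threshold: Long): Long = {
  var total = 0L
  val it = locations.iterator
  while (it.hasNext && total <= threshold) {
    val path = it.next()
    total += path.getFileSystem(hadoopConf).getContentSummary(path).getLength
  }
  total
}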

+        } catch {
+          case e: IOException =>
+            logWarning("Failed to get table size from hdfs.", e)
+            session.sessionState.conf.defaultSizeInBytes
+        }
+        val withStats = relation.tableMeta.copy(
+          stats = Some(CatalogStatistics(sizeInBytes = BigInt(sizeInBytes))))
+        val prunedCatalogRelation = relation.copy(tableMeta = withStats)
+        val filterExpression = predicates.reduceLeft(And)
+        Filter(filterExpression, prunedCatalogRelation)
+      } else {
+        filter
+      }
+  }
+}
+
 /**
  * Replaces generic operations with specific variants that are designed to work with Hive.
  *
@@ -1269,4 +1269,42 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
     }
 
   }
+
+  test("auto converts to broadcast join by size estimate of scanned partitions " +
+    "for partitioned table") {
+    withTempView("tempTbl", "largeTbl") {
+      withTable("partTbl") {
+        spark.range(0, 1000, 1, 2).selectExpr("id as col1", "id as col2")
+          .createOrReplaceTempView("tempTbl")
+        spark.range(0, 100000, 1, 2).selectExpr("id as col1", "id as col2")
+          .createOrReplaceTempView("largeTbl")
+        sql("CREATE TABLE partTbl (col1 INT, col2 STRING) " +
+          "PARTITIONED BY (part1 STRING, part2 INT) STORED AS textfile")
+        for (part1 <- Seq("a", "b", "c", "d"); part2 <- Seq(1, 2)) {
+          sql(
+            s"""
+               |INSERT OVERWRITE TABLE partTbl PARTITION (part1='$part1',part2='$part2')
+               |select col1, col2 from tempTbl
+             """.stripMargin)
+        }
+        val query = "select * from largeTbl join partTbl on (largeTbl.col1 = partTbl.col1 " +
+          "and partTbl.part1 = 'a' and partTbl.part2 = 1)"
+        withSQLConf(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key -> "true",
+          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "8001") {
+
+          withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> "true") {
+            val broadcastJoins =
+              sql(query).queryExecution.sparkPlan.collect { case j: BroadcastHashJoinExec => j }
+            assert(broadcastJoins.nonEmpty)
+          }
+
+          withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> "false") {
+            val broadcastJoins =
+              sql(query).queryExecution.sparkPlan.collect { case j: BroadcastHashJoinExec => j }
+            assert(broadcastJoins.isEmpty)
+          }
+        }
+      }
+    }
+  }
 }