@@ -20,9 +20,10 @@ package org.apache.spark.sql.hive
import org.apache.spark.annotation.{Experimental, InterfaceStability}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlanner
import org.apache.spark.sql.execution.{SparkOptimizer, SparkPlanner}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}
@@ -88,6 +89,20 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
customCheckRules
}

  /**
   * Logical query plan optimizer that takes into account Hive-specific rules.
   */
  override lazy val optimizer: Optimizer =
    new SparkOptimizer(catalog, experimentalMethods) {
      override def postHocOptimizationBatches: Seq[Batch] = Seq(
        Batch("Prune Hive Table Partitions", Once,
          new PruneHiveTablePartitions(session))
      )

      override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] =
        super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules
    }

  /**
   * Planner that takes into account Hive-specific strategies.
   */
@@ -21,16 +21,17 @@ import java.io.IOException
import java.util.Locale

import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.common.StatsSetupConst

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTable, LogicalPlan,
ScriptTransformation}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, InsertIntoDir, InsertIntoTable,
LogicalPlan, ScriptTransformation}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
import org.apache.spark.sql.execution.command.{CommandUtils, CreateTableCommand, DDLUtils}
import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
import org.apache.spark.sql.hive.execution._
@@ -139,6 +140,62 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] {
}
}

/**
 * Prune the partitions of a partitioned Hive table using the partition-column predicates in the
 * query filter, and use the total size of the matching partitions as the relation's statistics.
 *
 * TODO: merge this with PruneFileSourcePartitions after we completely make Hive a data source.
 */
case class PruneHiveTablePartitions(

Contributor:
Add a TODO that we should merge this rule with PruneFileSourcePartitions after we completely make Hive a data source.

    session: SparkSession) extends Rule[LogicalPlan] with PredicateHelper {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    case filter @ Filter(condition, relation: HiveTableRelation) if relation.isPartitioned =>
      val predicates = splitConjunctivePredicates(condition)
      val normalizedFilters = predicates.map { e =>
        e transform {
          case a: AttributeReference =>
            a.withName(relation.output.find(_.semanticEquals(a)).get.name)
        }
      }
      val partitionSet = AttributeSet(relation.partitionCols)
      val pruningPredicates = normalizedFilters.filter { predicate =>
        !predicate.references.isEmpty &&
          predicate.references.subsetOf(partitionSet)
      }
      if (pruningPredicates.nonEmpty && session.sessionState.conf.fallBackToHdfsForStatsEnabled &&
          session.sessionState.conf.metastorePartitionPruning) {
        val prunedPartitions = session.sharedState.externalCatalog.listPartitionsByFilter(
          relation.tableMeta.database,
          relation.tableMeta.identifier.table,
          pruningPredicates,
          session.sessionState.conf.sessionLocalTimeZone)
        val sizeInBytes = try {
Contributor:
What if we already have partition-level statistics in the Hive metastore?

Contributor Author:
Even if we already have partition-level statistics, we cannot know the total number of partitions here, so the statistics for the pruned partitions cannot be computed from them alone.

          prunedPartitions.map { part =>
            val totalSize = part.parameters.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong)
            val rawDataSize = part.parameters.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong)
            if (totalSize.isDefined && totalSize.get > 0L) {
Contributor:
I think we should use rawDataSize first, because a 1 MB ORC file is roughly equal to a 5 MB text file... (See the sketch after this rule.)

Contributor Author:
@cenyuhai Yes, I think what you said is right. Thanks.

              totalSize.get
            } else if (rawDataSize.isDefined && rawDataSize.get > 0) {
              rawDataSize.get
            } else {
              CommandUtils.calculateLocationSize(
                session.sessionState, relation.tableMeta.identifier, part.storage.locationUri)
            }
          }.sum
Contributor:
If there are too many partitions, this will be very slow. Can you add a check for whether the running sum is already larger than the threshold, and break out early if it is? (See the sketch after this rule.)

        } catch {
          case e: IOException =>
            logWarning("Failed to get table size from hdfs.", e)
            session.sessionState.conf.defaultSizeInBytes
        }
        val withStats = relation.tableMeta.copy(
          stats = Some(CatalogStatistics(sizeInBytes = BigInt(sizeInBytes))))
        val prunedCatalogRelation = relation.copy(tableMeta = withStats)
        val filterExpression = predicates.reduceLeft(And)
        Filter(filterExpression, prunedCatalogRelation)
      } else {
        filter
      }
  }
}
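The reviewer's suggestion about preferring rawDataSize can be sketched against the imports already in this file (CatalogTablePartition, StatsSetupConst). This is only an illustration of the suggested ordering, not part of the patch; estimatePartitionSize and fallbackSize are hypothetical names.

// Sketch only, not part of the patch: prefer rawDataSize (uncompressed) over totalSize
// (on-disk, possibly compressed) when estimating a partition's size.
// `estimatePartitionSize` and `fallbackSize` are hypothetical names.
def estimatePartitionSize(part: CatalogTablePartition, fallbackSize: Long): Long = {
  val rawDataSize = part.parameters.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong)
  val totalSize = part.parameters.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong)
  rawDataSize.filter(_ > 0L)           // uncompressed size, if the metastore recorded it
    .orElse(totalSize.filter(_ > 0L))  // otherwise fall back to the on-disk size
    .getOrElse(fallbackSize)           // otherwise use the caller-provided fallback
}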

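Similarly, the reviewer's concern about tables with very many partitions could be addressed by stopping the size accumulation once the running total already exceeds the broadcast threshold. Again a sketch under the same assumptions, not part of the patch; sizeUpToThreshold and sizeOf are hypothetical names.

// Sketch only, not part of the patch: stop accumulating once the running total exceeds the
// threshold, since the exact total no longer changes the broadcast decision.
def sizeUpToThreshold(
    partitions: Seq[CatalogTablePartition],
    threshold: Long,
    sizeOf: CatalogTablePartition => Long): Long = {
  var total = 0L
  val it = partitions.iterator
  while (it.hasNext && total <= threshold) {
    total += sizeOf(it.next())
  }
  total
}
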
/**
* Replaces generic operations with specific variants that are designed to work with Hive.
*
@@ -1261,4 +1261,42 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
}

}

test("auto converts to broadcast join by size estimate of scanned partitions " +
"for partitioned table") {
withTempView("tempTbl", "largeTbl") {
withTable("partTbl") {
spark.range(0, 1000, 1, 2).selectExpr("id as col1", "id as col2")
.createOrReplaceTempView("tempTbl")
spark.range(0, 100000, 1, 2).selectExpr("id as col1", "id as col2")
.createOrReplaceTempView("largeTbl")
sql("CREATE TABLE partTbl (col1 INT, col2 STRING) " +
"PARTITIONED BY (part1 STRING, part2 INT) STORED AS textfile")
for (part1 <- Seq("a", "b", "c", "d"); part2 <- Seq(1, 2)) {
sql(
s"""
|INSERT OVERWRITE TABLE partTbl PARTITION (part1='$part1',part2='$part2')
|select col1, col2 from tempTbl
""".stripMargin)
}
val query = "select * from largeTbl join partTbl on (largeTbl.col1 = partTbl.col1 " +
"and partTbl.part1 = 'a' and partTbl.part2 = 1)"
withSQLConf(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "8001") {

withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> "true") {
val broadcastJoins =
sql(query).queryExecution.sparkPlan.collect { case j: BroadcastHashJoinExec => j }
assert(broadcastJoins.nonEmpty)
}

withSQLConf(SQLConf.HIVE_METASTORE_PARTITION_PRUNING.key -> "false") {
val broadcastJoins =
sql(query).queryExecution.sparkPlan.collect { case j: BroadcastHashJoinExec => j }
assert(broadcastJoins.isEmpty)
}
}
}
}
}
}