ExternalCatalogUtils.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.util.Shell
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, BoundReference, Expression, InterpretedPredicate}

object ExternalCatalogUtils {
  // This duplicates default value of Hive `ConfVars.DEFAULTPARTITIONNAME`, since catalyst doesn't
@@ -125,6 +126,38 @@ object ExternalCatalogUtils {
    }
    escapePathName(col) + "=" + partitionString
  }

  def prunePartitionsByFilter(
      catalogTable: CatalogTable,
      inputPartitions: Seq[CatalogTablePartition],
      predicates: Seq[Expression],
      defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
    if (predicates.isEmpty) {
      inputPartitions
    } else {
      val partitionSchema = catalogTable.partitionSchema
      val partitionColumnNames = catalogTable.partitionColumnNames.toSet

      val nonPartitionPruningPredicates = predicates.filterNot {
        _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
      }
      if (nonPartitionPruningPredicates.nonEmpty) {
        throw new AnalysisException("Expected only partition pruning predicates: " +
          nonPartitionPruningPredicates)
      }

      val boundPredicate =
        InterpretedPredicate.create(predicates.reduce(And).transform {
          case att: AttributeReference =>
            val index = partitionSchema.indexWhere(_.name == att.name)
            BoundReference(index, partitionSchema(index).dataType, nullable = true)
        })

      inputPartitions.filter { p =>
        boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
      }
    }
  }
}

object CatalogUtils {
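As a reading aid (not part of the patch), here is a minimal, self-contained sketch of what the new `prunePartitionsByFilter` helper does; the table, partitions, and the "UTC" time zone below are invented for illustration:

```scala
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._

// A table partitioned by (a INT, b STRING) and two partitions of it.
val table = CatalogTable(
  identifier = TableIdentifier("tbl", Some("db")),
  tableType = CatalogTableType.MANAGED,
  storage = CatalogStorageFormat.empty,
  schema = new StructType()
    .add("col1", IntegerType)
    .add("a", IntegerType)
    .add("b", StringType),
  partitionColumnNames = Seq("a", "b"))

val partitions = Seq(
  CatalogTablePartition(Map("a" -> "1", "b" -> "x"), CatalogStorageFormat.empty),
  CatalogTablePartition(Map("a" -> "2", "b" -> "y"), CatalogStorageFormat.empty))

// The predicate is bound against the partition schema and evaluated on each
// partition's spec, keeping only Map(a -> 2, b -> y) here.
val pruned = ExternalCatalogUtils.prunePartitionsByFilter(table, partitions, Seq('a.int > 1), "UTC")
```

Predicates that reference anything other than partition columns are rejected with an `AnalysisException`, which the new `ExternalCatalogSuite` test below exercises via `'col1.int > 0`.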
InMemoryCatalog.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.types.StructType
@@ -556,9 +556,9 @@ class InMemoryCatalog(
      table: String,
      predicates: Seq[Expression],
      defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
    // TODO: Provide an implementation
    throw new UnsupportedOperationException(
      "listPartitionsByFilter is not implemented for InMemoryCatalog")
    val catalogTable = getTable(db, table)
    val allPartitions = listPartitions(db, table)
    prunePartitionsByFilter(catalogTable, allPartitions, predicates, defaultTimeZoneId)
  }

// --------------------------------------------------------------------------
ExternalCatalogSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.catalog

import java.net.URI
import java.util.TimeZone

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
@@ -28,6 +29,8 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -436,6 +439,44 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEach
    assert(catalog.listPartitions("db2", "tbl2", Some(Map("a" -> "unknown"))).isEmpty)
  }

  test("list partitions by filter") {
    val tz = TimeZone.getDefault.getID
    val catalog = newBasicCatalog()

    def checkAnswer(
        table: CatalogTable, filters: Seq[Expression], expected: Set[CatalogTablePartition])
      : Unit = {

      assertResult(expected.map(_.spec)) {
        catalog.listPartitionsByFilter(table.database, table.identifier.identifier, filters, tz)
          .map(_.spec).toSet
      }
    }

    val tbl2 = catalog.getTable("db2", "tbl2")

    checkAnswer(tbl2, Seq.empty, Set(part1, part2))
    checkAnswer(tbl2, Seq('a.int <= 1), Set(part1))
    checkAnswer(tbl2, Seq('a.int === 2), Set.empty)
    checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2))
    checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2))
    checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1))
    checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1))
    checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty)
    checkAnswer(tbl2, Seq('a.int === 1 || 'b.string === "x"), Set(part1))

    intercept[AnalysisException] {
      try {
        checkAnswer(tbl2, Seq('a.int > 0 && 'col1.int > 0), Set.empty)
      } catch {
        // HiveExternalCatalog may be the first one to notice and throw an exception, which will
        // then be caught and converted to a RuntimeException with a descriptive message.
        case ex: RuntimeException if ex.getMessage.contains("MetaException") =>
          throw new AnalysisException(ex.getMessage)
      }
    }
  }

  test("drop partitions") {
    val catalog = newBasicCatalog()
    assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2)))
DataSourceStrategy.scala
@@ -217,8 +217,6 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
    val table = r.tableMeta
    val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table)
    val cache = sparkSession.sessionState.catalog.tableRelationCache
    val withHiveSupport =
      sparkSession.sparkContext.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive"

    val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() {
      override def call(): LogicalPlan = {
@@ -233,8 +231,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
            bucketSpec = table.bucketSpec,
            className = table.provider.get,
            options = table.storage.properties ++ pathOption,
            // TODO: improve `InMemoryCatalog` and remove this limitation.
            catalogTable = if (withHiveSupport) Some(table) else None)
            catalogTable = Some(table))

        LogicalRelation(
          dataSource.resolveRelation(checkFilesExist = false),
HiveExternalCatalog.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -1039,37 +1039,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configuration
      defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient {
    val rawTable = getRawTable(db, table)
    val catalogTable = restoreTableMetadata(rawTable)
    val partitionColumnNames = catalogTable.partitionColumnNames.toSet
    val nonPartitionPruningPredicates = predicates.filterNot {
      _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
    }

    if (nonPartitionPruningPredicates.nonEmpty) {
      sys.error("Expected only partition pruning predicates: " +
        predicates.reduceLeft(And))
    }
    val partColNameMap = buildLowerCasePartColNameMap(catalogTable)

    val partitionSchema = catalogTable.partitionSchema
    val partColNameMap = buildLowerCasePartColNameMap(getTable(db, table))

    if (predicates.nonEmpty) {
      val clientPrunedPartitions = client.getPartitionsByFilter(rawTable, predicates).map { part =>
    val clientPrunedPartitions =
      client.getPartitionsByFilter(rawTable, predicates).map { part =>
Contributor:
if predicates.isEmpty, the previous code will run client.getPartitions. Can you double check there is no performance regression?

Contributor (author):
A similar optimization is done in the function itself: client.getPartitionsByFilter(), while Hive Shim_v0_12 delegates to getAllPartitions anyway.

I've now made the only piece of non-trivial code along that path lazy, so I think we're good.
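
As a reading aid (not part of the review thread), here is a self-contained sketch of the evaluation order this argument relies on: with `lazy val`, the right-hand side only runs when the value is first read, so an empty filter list never pays for inspecting the partition keys. `LazyValDemo`, `convert`, and the flag are invented names; the real change is the `lazy val varcharKeys` in the `HiveShim.scala` hunk further down.

```scala
object LazyValDemo extends App {
  var partitionKeysInspected = false

  // Stand-in for varcharKeys: forcing it flips the flag.
  lazy val varcharKeys: Set[String] = {
    partitionKeysInspected = true
    Set("c")
  }

  // Stand-in for convertFilters: varcharKeys is only referenced inside the
  // per-element closure, so it is never forced when `filters` is empty.
  def convert(filters: Seq[String]): String =
    filters.filterNot(f => varcharKeys.contains(f)).mkString(" and ")

  convert(Seq.empty)
  println(partitionKeysInspected)  // false: empty input never forces the lazy val

  convert(Seq("a = 1"))
  println(partitionKeysInspected)  // true: the first real filter pays the cost, exactly once
}
```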

        part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
      }
      val boundPredicate =
        InterpretedPredicate.create(predicates.reduce(And).transform {
          case att: AttributeReference =>
            val index = partitionSchema.indexWhere(_.name == att.name)
            BoundReference(index, partitionSchema(index).dataType, nullable = true)
        })
      clientPrunedPartitions.filter { p =>
        boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId))
      }
    } else {
      client.getPartitions(catalogTable).map { part =>
        part.copy(spec = restorePartitionSpec(part.spec, partColNameMap))
      }
    }
    prunePartitionsByFilter(catalogTable, clientPrunedPartitions, predicates, defaultTimeZoneId)
  }

// --------------------------------------------------------------------------
HiveShim.scala
@@ -584,7 +584,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
   */
  def convertFilters(table: Table, filters: Seq[Expression]): String = {
    // hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
    val varcharKeys = table.getPartitionKeys.asScala
    lazy val varcharKeys = table.getPartitionKeys.asScala
      .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) ||
        col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME))
      .map(col => col.getName).toSet
HiveExternalCatalogSuite.scala
@@ -22,7 +22,6 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.types.StructType

@@ -50,13 +49,6 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite {

  import utils._

  test("list partitions by filter") {
    val catalog = newBasicCatalog()
    val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1), "GMT")
    assert(selectedPartitions.length == 1)
    assert(selectedPartitions.head.spec == part1.spec)
  }

  test("SPARK-18647: do not put provider in table properties for Hive serde table") {
    val catalog = newBasicCatalog()
    val hiveTable = CatalogTable(