Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2

import scala.collection.JavaConverters._

import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy}
import org.apache.spark.sql.{AnalysisException, Dataset, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable}
import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
Expand Down Expand Up @@ -56,9 +56,19 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
session.sharedState.cacheManager.recacheByPlan(session, r)
}

private def invalidateCache(r: ResolvedTable)(): Unit = {
private def invalidateCache(r: ResolvedTable, recacheTable: Boolean = false)(): Unit = {
val v2Relation = DataSourceV2Relation.create(r.table, Some(r.catalog), Some(r.identifier))
val cache = session.sharedState.cacheManager.lookupCachedData(v2Relation)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think `cache` is only needed inside the `if` block?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad - it should be checked in the `if` condition.

session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true)
if (recacheTable && cache.isDefined) {
// save the cache name and cache level for recreation
val cacheName = cache.get.cachedRepresentation.cacheBuilder.tableName
val cacheLevel = cache.get.cachedRepresentation.cacheBuilder.storageLevel

// recache with the same name and cache level.
val ds = Dataset.ofRows(session, v2Relation)
session.sharedState.cacheManager.cacheQuery(ds, cacheName, cacheLevel)
}
}

override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
Expand Down Expand Up @@ -137,7 +147,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
}

case RefreshTable(r: ResolvedTable) =>
RefreshTableExec(r.catalog, r.identifier, invalidateCache(r)) :: Nil
RefreshTableExec(r.catalog, r.identifier, invalidateCache(r, recacheTable = true)) :: Nil

case ReplaceTable(catalog, ident, schema, parts, props, orCreate) =>
val propsWithOwner = CatalogV2Util.withDefaultOwnership(props)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ case class RefreshTableExec(
catalog.invalidateTable(ident)

// invalidate all caches referencing the given table
// TODO(SPARK-33437): re-cache the table itself once we support caching a DSv2 table
invalidateCache()

Seq.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1749,6 +1749,25 @@ class DataSourceV2SQLSuite
}
}

test("SPARK-33653: REFRESH TABLE should recache the target table itself") {
  val tblName = "testcat.ns.t"

  // Re-resolves the table and queries the shared cache manager on every call,
  // so each assertion observes the cache state at that point in the test.
  def isCached: Boolean =
    spark.sharedState.cacheManager.lookupCachedData(spark.table(tblName)).isDefined

  withTable(tblName) {
    sql(s"CREATE TABLE $tblName (id bigint) USING foo")

    // An uncached table must remain uncached across a refresh.
    assert(!isCached)
    sql(s"REFRESH TABLE $tblName")
    assert(!isCached)

    sql(s"CACHE TABLE $tblName")

    // Once cached, the table must still have a cache entry after a refresh.
    assert(isCached)
    sql(s"REFRESH TABLE $tblName")
    assert(isCached)
  }
}

test("REPLACE TABLE: v1 table") {
val e = intercept[AnalysisException] {
sql(s"CREATE OR REPLACE TABLE tbl (a int) USING ${classOf[SimpleScanSource].getName}")
Expand Down