diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index ea6ac6ca92aa..1dd9f551ff8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy}
+import org.apache.spark.sql.{AnalysisException, Dataset, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable}
 import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, PredicateHelper, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -56,9 +56,19 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
     session.sharedState.cacheManager.recacheByPlan(session, r)
   }
 
-  private def invalidateCache(r: ResolvedTable)(): Unit = {
+  private def invalidateCache(r: ResolvedTable, recacheTable: Boolean = false)(): Unit = {
     val v2Relation = DataSourceV2Relation.create(r.table, Some(r.catalog), Some(r.identifier))
+    val cache = session.sharedState.cacheManager.lookupCachedData(v2Relation)
     session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true)
+    if (recacheTable && cache.isDefined) {
+      // save the cache name and cache level for recreation
+      val cacheName = cache.get.cachedRepresentation.cacheBuilder.tableName
+      val cacheLevel = cache.get.cachedRepresentation.cacheBuilder.storageLevel
+
+      // recache with the same name and cache level
+      val ds = Dataset.ofRows(session, v2Relation)
+      session.sharedState.cacheManager.cacheQuery(ds, cacheName, cacheLevel)
+    }
   }
 
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
@@ -137,7 +147,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
       }
 
     case RefreshTable(r: ResolvedTable) =>
-      RefreshTableExec(r.catalog, r.identifier, invalidateCache(r)) :: Nil
+      RefreshTableExec(r.catalog, r.identifier, invalidateCache(r, recacheTable = true)) :: Nil
 
     case ReplaceTable(catalog, ident, schema, parts, props, orCreate) =>
       val propsWithOwner = CatalogV2Util.withDefaultOwnership(props)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala
index 994583c1e338..e66f0a18a132 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RefreshTableExec.scala
@@ -29,7 +29,6 @@ case class RefreshTableExec(
     catalog.invalidateTable(ident)
 
     // invalidate all caches referencing the given table
-    // TODO(SPARK-33437): re-cache the table itself once we support caching a DSv2 table
     invalidateCache()
 
     Seq.empty
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 0c65e530f67d..638f06d61883 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -1749,6 +1749,25 @@ class DataSourceV2SQLSuite
     }
   }
 
+  test("SPARK-33653: REFRESH TABLE should recache the target table itself") {
+    val tblName = "testcat.ns.t"
+    withTable(tblName) {
+      sql(s"CREATE TABLE $tblName (id bigint) USING foo")
+
+      // if the table is not cached, refreshing it should not recache it
+      assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(tblName)).isEmpty)
+      sql(s"REFRESH TABLE $tblName")
+      assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(tblName)).isEmpty)
+
+      sql(s"CACHE TABLE $tblName")
+
+      // after caching, refreshing the table should recache it
+      assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(tblName)).isDefined)
+      sql(s"REFRESH TABLE $tblName")
+      assert(spark.sharedState.cacheManager.lookupCachedData(spark.table(tblName)).isDefined)
+    }
+  }
+
   test("REPLACE TABLE: v1 table") {
     val e = intercept[AnalysisException] {
       sql(s"CREATE OR REPLACE TABLE tbl (a int) USING ${classOf[SimpleScanSource].getName}")
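
For reference, a minimal end-to-end sketch of the new behavior (not part of the patch). It assumes a DSv2 catalog registered as "testcat" and a "foo" table provider, as configured in the test suite above; those names are illustrative only.

    // spark-shell sketch: REFRESH TABLE now re-caches a cached DSv2 table.
    spark.sql("CREATE TABLE testcat.ns.t (id BIGINT) USING foo")

    // Cache with an explicit storage level; invalidateCache() saves the
    // cache name and storage level before uncaching, then re-caches with both.
    spark.sql("CACHE TABLE testcat.ns.t OPTIONS ('storageLevel' 'MEMORY_ONLY')")

    // Before this change, REFRESH TABLE left the table uncached.
    spark.sql("REFRESH TABLE testcat.ns.t")

    // Verify the cache entry survived, the same way the test above does.
    assert(spark.sharedState.cacheManager
      .lookupCachedData(spark.table("testcat.ns.t")).isDefined)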