3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1564,7 +1564,8 @@ class Dataset[T] private[sql](
/** @inheritdoc */
def storageLevel: StorageLevel = {
sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData =>
cachedData.cachedRepresentation.cacheBuilder.storageLevel
cachedData.cachedRepresentation.fold(CacheManager.inMemoryRelationExtractor, identity).
cacheBuilder.storageLevel
}.getOrElse(StorageLevel.NONE)
}

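Note: the pattern above recurs throughout this PR. `lookupCachedData(...).cachedRepresentation` is apparently no longer an `InMemoryRelation` directly but an `Either`, and both sides are collapsed with `fold(CacheManager.inMemoryRelationExtractor, identity)` before reaching `cacheBuilder.storageLevel`. Below is a minimal, self-contained sketch of that fold-then-access shape; `CachedPlan`, `InMemoryRel` and the local `inMemoryRelationExtractor` are made-up stand-ins, not Spark's actual types.

object EitherFoldSketch {
  final case class StorageLevel(name: String)
  final case class CacheBuilder(storageLevel: StorageLevel)
  final case class InMemoryRel(cacheBuilder: CacheBuilder)
  final case class CachedPlan(underlying: InMemoryRel)

  // Stand-in for CacheManager.inMemoryRelationExtractor in the diff above.
  val inMemoryRelationExtractor: CachedPlan => InMemoryRel = _.underlying

  // fold collapses both sides of the Either to an InMemoryRel, mirroring
  // cachedRepresentation.fold(CacheManager.inMemoryRelationExtractor, identity).
  def storageLevelOf(rep: Either[CachedPlan, InMemoryRel]): StorageLevel =
    rep.fold(inMemoryRelationExtractor, identity).cacheBuilder.storageLevel

  def main(args: Array[String]): Unit = {
    val direct  = Right(InMemoryRel(CacheBuilder(StorageLevel("MEMORY_AND_DISK"))))
    val wrapped = Left(CachedPlan(InMemoryRel(CacheBuilder(StorageLevel("DISK_ONLY")))))
    println(storageLevelOf(direct))  // StorageLevel(MEMORY_AND_DISK)
    println(storageLevelOf(wrapped)) // StorageLevel(DISK_ONLY)
  }
}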

Large diffs are not rendered by default.

@@ -62,12 +62,16 @@ case class AnalyzeColumnCommand(
private def analyzeColumnInCachedData(plan: LogicalPlan, sparkSession: SparkSession): Boolean = {
val cacheManager = sparkSession.sharedState.cacheManager
val df = Dataset.ofRows(sparkSession, plan)
cacheManager.lookupCachedData(df).map { cachedData =>
val columnsToAnalyze = getColumnsToAnalyze(
tableIdent, cachedData.cachedRepresentation, columnNames, allColumns)
cacheManager.analyzeColumnCacheQuery(sparkSession, cachedData, columnsToAnalyze)
cachedData
}.isDefined
cacheManager.lookupCachedData(df).exists { cachedData =>
if (cachedData.cachedRepresentation.isRight) {
val columnsToAnalyze = getColumnsToAnalyze(
tableIdent, cachedData.cachedRepresentation.merge, columnNames, allColumns)
cacheManager.analyzeColumnCacheQuery(sparkSession, cachedData, columnsToAnalyze)
true
} else {
false
}
}
}

private def analyzeColumnInTempView(plan: LogicalPlan, sparkSession: SparkSession): Unit = {
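A smaller point on the two standard-library idioms the rewrite above leans on: `Option.exists` replaces the old `map { ... }.isDefined`, so the body can now report a miss (return false) when the cached representation is a Left, and `Either.merge` collapses the two sides to their common supertype (presumably fine here because the Right side, an InMemoryRelation, is itself a LogicalPlan). A plain-Scala illustration, with no Spark types involved:

object OptionEitherIdioms {
  def main(args: Array[String]): Unit = {
    val cached: Option[Either[String, Int]] = Some(Right(42))

    // exists lets the body decide: only a Right-sided entry counts as a hit,
    // analogous to checking cachedData.cachedRepresentation.isRight above.
    val analyzed: Boolean = cached.exists {
      case Right(_) => true
      case Left(_)  => false
    }
    println(analyzed) // true

    // merge collapses an Either to the least upper bound of its two sides,
    // which is how the new code obtains a single value to hand on.
    val merged: Any = cached.get.merge
    println(merged) // 42
  }
}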
@@ -241,7 +241,7 @@ object CommandUtils extends Logging {
// Analyzes a catalog view if the view is cached
val table = sparkSession.table(tableIdent.quotedString)
val cacheManager = sparkSession.sharedState.cacheManager
if (cacheManager.lookupCachedData(table).isDefined) {
if (cacheManager.lookupCachedData(table).exists(_.cachedRepresentation.isRight)) {
if (!noScan) {
// To collect table stats, materializes an underlying columnar RDD
table.count()
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeed
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.CacheManager
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
@@ -201,7 +202,8 @@ case class AlterTableRenameCommand(
// If `optStorageLevel` is defined, the old table was cached.
val optCachedData = sparkSession.sharedState.cacheManager.lookupCachedData(
sparkSession.table(oldName.unquotedString))
val optStorageLevel = optCachedData.map(_.cachedRepresentation.cacheBuilder.storageLevel)
val optStorageLevel = optCachedData.map(_.cachedRepresentation.
fold(CacheManager.inMemoryRelationExtractor, identity).cacheBuilder.storageLevel)
if (optStorageLevel.isDefined) {
CommandUtils.uncacheTableOrView(sparkSession, oldName)
}
@@ -42,7 +42,7 @@ import org.apache.spark.sql.connector.read.LocalScan
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.connector.write.V1Write
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.{CacheManager, FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command.CommandUtils
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelationWithTable, PushableColumnAndNestedColumn}
import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
@@ -74,8 +74,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
val v2Relation = DataSourceV2Relation.create(r.table, Some(r.catalog), Some(r.identifier))
val cache = session.sharedState.cacheManager.lookupCachedData(session, v2Relation)
session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true)
if (cache.isDefined) {
val cacheLevel = cache.get.cachedRepresentation.cacheBuilder.storageLevel
if (cache.exists(_.cachedRepresentation.isRight)) {
val cacheLevel = cache.get.cachedRepresentation.
fold(CacheManager.inMemoryRelationExtractor, identity).cacheBuilder.storageLevel
Some(cacheLevel)
} else {
None
114 changes: 111 additions & 3 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -20,11 +20,13 @@ package org.apache.spark.sql
import org.scalatest.concurrent.TimeLimits
import org.scalatest.time.SpanSugar._

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.storage.StorageLevel
import org.apache.spark.tags.SlowSQLTest

@@ -244,13 +246,12 @@ class DatasetCacheSuite extends QueryTest
case i: InMemoryRelation => i.cacheBuilder.cachedPlan
}
assert(df1LimitInnerPlan.isDefined && df1LimitInnerPlan.get == df1InnerPlan)

// Verify that df2's cache has been re-cached, with a new physical plan rid of dependency
// on df, since df2's cache had not been loaded before df.unpersist().
val df2Limit = df2.limit(2)
val df2LimitInnerPlan = df2Limit.queryExecution.withCachedData.collectFirst {
case i: InMemoryRelation => i.cacheBuilder.cachedPlan
}
// Verify that df2's cache has been re-cached, with a new physical plan rid of dependency
// on df, since df2's cache had not been loaded before df.unpersist().
assert(df2LimitInnerPlan.isDefined &&
!df2LimitInnerPlan.get.exists(_.isInstanceOf[InMemoryTableScanExec]))
}
@@ -275,6 +276,77 @@ class DatasetCacheSuite extends QueryTest
}
}

test("SPARK-47609. Partial match IMR with more columns than base") {
val baseDfCreator = () => spark.range(0, 50).selectExpr("id as a ", "id % 2 AS b",
"id % 3 AS c")
val testDfCreator = () => spark.range(0, 50).select($"id".as("a"), ($"id" % 2).as("b"),
($"id" + 1).as("d"), ($"id" % 3).as("c") )
checkIMRUseAndInvalidation(baseDfCreator, testDfCreator)
}

test("SPARK-47609. Partial match IMR with less columns than base") {
val baseDfCreator = () => spark.range(0, 50).selectExpr("id as a ", "id % 2 AS b",
"id % 3 AS c")
val testDfCreator = () => spark.range(0, 50).select($"id".as("a"), ($"id" % 2).as("b"))
checkIMRUseAndInvalidation(baseDfCreator, testDfCreator)
}

test("SPARK-47609. Partial match IMR with columns remapped to different value") {
val baseDfCreator = () => spark.range(0, 50).selectExpr("id as a ", "id % 2 AS b",
"id % 3 AS c")
val testDfCreator = () => spark.range(0, 50).select(($"id" + ($"id" % 3)).as("a"),
($"id" % 2).as("b"), ($"id" % 3).as("c"))
checkIMRUseAndInvalidation(baseDfCreator, testDfCreator)
}

test("SPARK-47609. Partial match IMR with extra columns as literal") {
val baseDfCreator = () => spark.range(0, 50).selectExpr("id as a ", "id % 2 AS b",
"id % 3 AS c")
val testDfCreator = () => spark.range(0, 50).select(($"id" + ($"id" % 3)).as("a"),
($"id" % 2).as("b"), expr("100").cast(IntegerType).as("t"))
checkIMRUseAndInvalidation(baseDfCreator, testDfCreator)
}

test("SPARK-47609. Partial match IMR with multiple columns remap") {
val baseDfCreator = () => spark.range(0, 50).selectExpr("id as a ", "id % 2 AS b",
"id % 3 AS c")
val testDfCreator = () => spark.range(0, 50).select((($"id" % 3) * ($"id" % 2)).as("c"),
($"id" - ($"id" % 3)).as("a"), $"id".as("b"))
checkIMRUseAndInvalidation(baseDfCreator, testDfCreator)
}

test("SPARK-47609. Partial match IMR with multiple columns remap and one remap as constant") {
val baseDfCreator = () => spark.range(0, 50).selectExpr("id as a ", "id % 2 AS b",
"id % 3 AS c")
val testDfCreator = () => spark.range(0, 50).select((($"id" % 3) * ($"id" % 2)).as("c"),
($"id" - ($"id" % 3)).as("a"), expr("100").cast(IntegerType).as("b"), $"b".as("d"))
checkIMRUseAndInvalidation(baseDfCreator, testDfCreator)
}

test("SPARK-47609. Partial match IMR with partial filters match") {
val baseDfCreator = () => spark.range(0, 100).
selectExpr("id as a ", "id % 2 AS b", "id % 3 AS c").filter($"a" > 7)
val testDfCreator = () => spark.range(0, 100).filter($"id"> 7).filter(($"id" % 3) > 8).select(
(($"id" % 3) * ($"id" % 2)).as("c"), ($"id" - ($"id" % 3)).as("a"),
expr("100").cast(IntegerType).as("b"), $"b".as("d"))
checkIMRUseAndInvalidation(baseDfCreator, testDfCreator)
}

test("SPARK-47609. Because of filter mismatch partial match should not happen") {
val baseDfCreator = () => spark.range(0, 100).
selectExpr("id as a ", "id % 2 AS b", "id % 3 AS c").filter($"a" > 7)
val testDfCreator = () => spark.range(0, 100).filter(($"id" % 3) > 8).select(
(($"id" % 3) * ($"id" % 2)).as("c"), ($"id" - ($"id" % 3)).as("a"),
expr("100").cast(IntegerType).as("b"), $"b".as("d"))
val baseDf = baseDfCreator()
baseDf.cache()
val testDf = testDfCreator()
verifyCacheDependency(baseDfCreator(), 1)
// cache should not be used
verifyCacheDependency(testDf, 0)
baseDf.unpersist(true)
}

test("SPARK-44653: non-trivial DataFrame unions should not break caching") {
val df1 = Seq(1 -> 1).toDF("i", "j")
val df2 = Seq(2 -> 2).toDF("i", "j")
@@ -312,4 +384,40 @@ class DatasetCacheSuite extends QueryTest
}
}
}

protected def checkIMRUseAndInvalidation(
baseDfCreator: () => DataFrame,
testExec: () => DataFrame): Unit = {
// Verify that the cache-backed (optimized) DataFrame and the completely
// uncached (unoptimized) DataFrame produce the same results.
val baseDf = baseDfCreator()
val testDf = testExec()
val testDfRows = testDf.collect()
baseDf.cache()
verifyCacheDependency(baseDfCreator(), 1)
verifyCacheDependency(testExec(), 1)
baseDfCreator().unpersist(true)
verifyCacheDependency(baseDfCreator(), 0)
verifyCacheDependency(testExec(), 0)
baseDfCreator().cache()
val newTestDf = testExec()
// re-verify cache dependency
verifyCacheDependency(newTestDf, 1)
checkAnswer(newTestDf, testDfRows)
baseDfCreator().unpersist(true)
}

def verifyCacheDependency(df: DataFrame, numOfCachesExpected: Int): Unit = {
def recurse(sparkPlan: SparkPlan): Int = {
val imrs = sparkPlan.collect {
case i: InMemoryTableScanExec => i
}
imrs.size + imrs.map(ime => recurse(ime.relation.cacheBuilder.cachedPlan)).sum
}
val cachedPlans = df.queryExecution.withCachedData.collect {
case i: InMemoryRelation => i.cacheBuilder.cachedPlan
}
val totalIMRs = cachedPlans.size + cachedPlans.map(ime => recurse(ime)).sum
assert(totalIMRs == numOfCachesExpected)
}
}
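The verifyCacheDependency helper counts cache scans not only in the plan itself but also inside the cached plan each scan reads from, so nested cache reuse is counted as well. A stand-alone sketch of that counting strategy, where Node and CacheScan are made-up stand-ins for SparkPlan and InMemoryTableScanExec:

object NestedCacheCount {
  sealed trait Node { def children: Seq[Node] }
  final case class Plain(children: Seq[Node] = Nil) extends Node
  // A cache scan carries the cached plan it reads, much like
  // InMemoryTableScanExec.relation.cacheBuilder.cachedPlan.
  final case class CacheScan(cachedPlan: Node, children: Seq[Node] = Nil) extends Node

  // Collect the cache scans reachable in this plan without descending into
  // the cached plans they wrap.
  def collectCacheScans(n: Node): Seq[CacheScan] = {
    val here = n match {
      case c: CacheScan => Seq(c)
      case _            => Nil
    }
    here ++ n.children.flatMap(collectCacheScans)
  }

  // Size at this level plus, recursively, the counts inside each cached plan,
  // mirroring the imrs.size + imrs.map(... recurse ...).sum shape in the suite.
  def countCacheScans(plan: Node): Int = {
    val scans = collectCacheScans(plan)
    scans.size + scans.map(s => countCacheScans(s.cachedPlan)).sum
  }

  def main(args: Array[String]): Unit = {
    // A scan whose cached plan itself reads another cache counts as two.
    val plan = Plain(Seq(CacheScan(cachedPlan = CacheScan(cachedPlan = Plain()))))
    println(countCacheScans(plan)) // 2
  }
}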
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.command

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.execution.CacheManager
import org.apache.spark.storage.StorageLevel

/**
@@ -73,7 +74,8 @@ trait AlterTableRenameSuiteBase extends QueryTest with DDLCommandTestUtils {
def getStorageLevel(tableName: String): StorageLevel = {
val table = spark.table(tableName)
val cachedData = spark.sharedState.cacheManager.lookupCachedData(table).get
cachedData.cachedRepresentation.cacheBuilder.storageLevel
cachedData.cachedRepresentation.fold(CacheManager.inMemoryRelationExtractor, identity).
cacheBuilder.storageLevel
}
sql(s"CREATE TABLE $src (c0 INT) $defaultUsing")
sql(s"INSERT INTO $src SELECT 0")