diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8726ee268a477..feb5d77210741 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1564,7 +1564,8 @@ class Dataset[T] private[sql]( /** @inheritdoc */ def storageLevel: StorageLevel = { sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData => - cachedData.cachedRepresentation.cacheBuilder.storageLevel + cachedData.cachedRepresentation.fold(CacheManager.inMemoryRelationExtractor, identity). + cacheBuilder.storageLevel }.getOrElse(StorageLevel.NONE) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index a3382c83e1f20..03585241d1d23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -17,15 +17,17 @@ package org.apache.spark.sql.execution +import scala.collection.mutable + import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.{LogEntry, Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.catalog.HiveTableRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, AttributeSet, Expression, NamedExpression, SubqueryExpression} import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint -import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LogicalPlan, ResolvedHint, View} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, IgnoreCachedData, LeafNode, LogicalPlan, Project, ResolvedHint, View} import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -37,19 +39,22 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK + /** Holds a cached logical plan and its data */ + case class CachedData( - // A normalized resolved plan (See QueryExecution#normalized). plan: LogicalPlan, - cachedRepresentation: InMemoryRelation) { + cachedRepresentation: Either[LogicalPlan, InMemoryRelation]) { + override def toString: String = s""" |CachedData( |logicalPlan=$plan - |InMemoryRelation=$cachedRepresentation) + |InMemoryRelation=${cachedRepresentation.merge}) |""".stripMargin } + /** * Provides support in a SQLContext for caching query results and automatically using these cached * results when subsequent queries are executed. Data is cached using byte buffers stored in an @@ -72,7 +77,8 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { /** Clears all cached tables. 
*/ def clearCache(): Unit = this.synchronized { - cachedData.foreach(_.cachedRepresentation.cacheBuilder.clearCache()) + cachedData.foreach(_.cachedRepresentation.fold(CacheManager.inMemoryRelationExtractor, identity) + .cacheBuilder.clearCache()) cachedData = IndexedSeq[CachedData]() CacheManager.logCacheOperation(log"Cleared all Dataframe cache entries") } @@ -125,7 +131,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { storageLevel: StorageLevel): Unit = { if (storageLevel == StorageLevel.NONE) { // Do nothing for StorageLevel.NONE since it will not actually cache any data. - } else if (lookupCachedDataInternal(normalizedPlan).nonEmpty) { + } else if (lookupCachedDataInternal(normalizedPlan).exists(_.cachedRepresentation.isRight)) { logWarning("Asked to cache already cached data.") } else { val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark) @@ -139,11 +145,11 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { } this.synchronized { - if (lookupCachedDataInternal(normalizedPlan).nonEmpty) { + if (lookupCachedDataInternal(normalizedPlan).exists(_.cachedRepresentation.isRight)) { logWarning("Data has already been cached.") } else { // the cache key is the normalized plan - val cd = CachedData(normalizedPlan, inMemoryRelation) + val cd = CachedData(normalizedPlan, Right(inMemoryRelation)) cachedData = cd +: cachedData CacheManager.logCacheOperation(log"Added Dataframe cache entry:" + log"${MDC(DATAFRAME_CACHE_ENTRY, cd)}") @@ -200,21 +206,46 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { uncacheQuery(spark, plan, cascade, blocking = false) } - // The `plan` should have been normalized. - private def uncacheQueryInternal( + /** + * Un-cache the given plan or all the cache entries that refer to the given plan. + * + * @param spark The Spark session. + * @param plan The plan to be un-cached. + * @param cascade If true, un-cache all the cache entries that refer to the given + * plan; otherwise un-cache the given plan only. + * @param blocking Whether to block until all blocks are deleted. 
+ */ + def uncacheQueryInternal( spark: SparkSession, plan: LogicalPlan, cascade: Boolean, - blocking: Boolean): Unit = { - uncacheByCondition(spark, _.sameResult(plan), cascade, blocking) + blocking: Boolean = false): Unit = { + val dummyCd = CachedData(plan, Left(plan)) + uncacheByCondition(spark, + (planToCheck: LogicalPlan, partialMatchOk: Boolean) => { + dummyCd.plan.sameResult(planToCheck) || (partialMatchOk && + (planToCheck match { + case p: Project => lookUpPartiallyMatchedCachedPlan(p, IndexedSeq(dummyCd)).isDefined + case _ => false + })) + }, cascade, blocking) } + def uncacheTableOrView(spark: SparkSession, name: Seq[String], cascade: Boolean): Unit = { uncacheByCondition( - spark, isMatchedTableOrView(_, name, spark.sessionState.conf), cascade, blocking = false) + spark, + isMatchedTableOrView(_, _, name, spark.sessionState.conf), + cascade, + blocking = false) } - private def isMatchedTableOrView(plan: LogicalPlan, name: Seq[String], conf: SQLConf): Boolean = { + + private def isMatchedTableOrView( + plan: LogicalPlan, + partialMatch: Boolean, + name: Seq[String], + conf: SQLConf): Boolean = { def isSameName(nameInCache: Seq[String]): Boolean = { nameInCache.length == name.length && nameInCache.zip(name).forall(conf.resolver.tupled) } @@ -239,20 +270,22 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { private def uncacheByCondition( spark: SparkSession, - isMatchedPlan: LogicalPlan => Boolean, + isMatchedPlan: (LogicalPlan, Boolean) => Boolean, cascade: Boolean, blocking: Boolean): Unit = { - val shouldRemove: LogicalPlan => Boolean = - if (cascade) { - _.exists(isMatchedPlan) + + val shouldRemove: LogicalPlan => Boolean = if (cascade) { + _.exists(isMatchedPlan(_, false)) } else { - isMatchedPlan + isMatchedPlan(_, false) } val plansToUncache = cachedData.filter(cd => shouldRemove(cd.plan)) this.synchronized { cachedData = cachedData.filterNot(cd => plansToUncache.exists(_ eq cd)) } - plansToUncache.foreach { _.cachedRepresentation.cacheBuilder.clearCache(blocking) } + plansToUncache.foreach { _.cachedRepresentation. + fold(CacheManager.inMemoryRelationExtractor, identity).cacheBuilder.clearCache(blocking) } + CacheManager.logCacheOperation(log"Removed ${MDC(SIZE, plansToUncache.size)} Dataframe " + log"cache entries, with logical plans being " + log"\n[${MDC(QUERY_PLAN, plansToUncache.map(_.plan).mkString(",\n"))}]") @@ -272,8 +305,10 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { // 2) The buffer has been cleared, but `isCachedColumnBuffersLoaded` returns true, then we // will keep it as it is. It means the physical plan has been re-compiled already in the // other thread. - val cacheAlreadyLoaded = cd.cachedRepresentation.cacheBuilder.isCachedColumnBuffersLoaded - cd.plan.exists(isMatchedPlan) && !cacheAlreadyLoaded + val cacheAlreadyLoaded = cd.cachedRepresentation. + fold(CacheManager.inMemoryRelationExtractor, identity).cacheBuilder. + isCachedColumnBuffersLoaded + !cacheAlreadyLoaded && cd.plan.exists(isMatchedPlan(_, true)) }) } } @@ -285,8 +320,9 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { column: Seq[Attribute]): Unit = { val relation = cachedData.cachedRepresentation val (rowCount, newColStats) = - CommandUtils.computeColumnStats(sparkSession, relation, column) - relation.updateStats(rowCount, newColStats) + CommandUtils.computeColumnStats(sparkSession, relation.merge, column) + relation.fold(CacheManager.inMemoryRelationExtractor, identity). 
+ updateStats(rowCount, newColStats) } /** @@ -310,15 +346,17 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { cachedData = cachedData.filterNot(cd => needToRecache.exists(_ eq cd)) } needToRecache.foreach { cd => - cd.cachedRepresentation.cacheBuilder.clearCache() + cd.cachedRepresentation.fold(CacheManager.inMemoryRelationExtractor, identity). + cacheBuilder.clearCache() val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark) val newCache = sessionWithConfigsOff.withActive { val qe = sessionWithConfigsOff.sessionState.executePlan(cd.plan) - InMemoryRelation(cd.cachedRepresentation.cacheBuilder, qe) + InMemoryRelation(cd.cachedRepresentation. + fold(CacheManager.inMemoryRelationExtractor, identity).cacheBuilder, qe) } - val recomputedPlan = cd.copy(cachedRepresentation = newCache) + val recomputedPlan = cd.copy(cachedRepresentation = Right(newCache)) this.synchronized { - if (lookupCachedDataInternal(recomputedPlan.plan).nonEmpty) { + if (lookupCachedDataInternal(recomputedPlan.plan).exists(_.cachedRepresentation.isRight)) { logWarning("While recaching, data was already added to cache.") } else { cachedData = recomputedPlan +: cachedData @@ -336,6 +374,34 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { lookupCachedDataInternal(query.queryExecution.normalized) } + /* + Partial match cases: + InComingPlan (case of add cols) cached plan InComing Plan ( case of rename) + Project P2 Project P1 Project P2 + attr1 attr1 attr1 + attr2 attr2 Alias2'(x, attr2) + Alias3 Alias3 Alias3'(y, Alias3-childExpr) + Alias4 Alias4 Alias4'(z, Alias4-childExpr) + Alias5 (k, f(attr1, attr2, al3, al4) + Alias6 (p, f(attr1, attr2, al3, al4) + */ + + /** Optionally returns cached data for the given [[LogicalPlan]]. */ + def lookupCachedDataInternal(plan: LogicalPlan): Option[CachedData] = { + val fullMatch = cachedData.find(cd => plan.sameResult(cd.plan)) + val result = fullMatch.map(Option(_)).getOrElse( + plan match { + case p: Project => lookUpPartiallyMatchedCachedPlan(p, cachedData) + case _ => None + }) + if (result.isDefined) { + CacheManager.logCacheOperation(log"Dataframe cache hit for input plan:" + + log"\n${MDC(QUERY_PLAN, plan)} matched with cache entry:" + + log"${MDC(DATAFRAME_CACHE_ENTRY, result.get)}") + } + result + } + /** * Optionally returns cached data for the given [[LogicalPlan]]. The given plan will be normalized * before being used further. @@ -345,29 +411,248 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { lookupCachedDataInternal(normalized) } - private def lookupCachedDataInternal(plan: LogicalPlan): Option[CachedData] = { - val result = cachedData.find(cd => plan.sameResult(cd.plan)) - if (result.isDefined) { - CacheManager.logCacheOperation(log"Dataframe cache hit for input plan:" + - log"\n${MDC(QUERY_PLAN, plan)} matched with cache entry:" + - log"${MDC(DATAFRAME_CACHE_ENTRY, result.get)}") + private def lookUpPartiallyMatchedCachedPlan( + incomingProject: Project, + cachedPlansToUse: IndexedSeq[CachedData]): Option[CachedData] = { + var foundMatch = false + var partialMatch: Option[CachedData] = None + val (incmngchild, incomingFilterChain) = + CompatibilityChecker.extractChildIgnoringFiltersFromIncomingProject(incomingProject) + for (cd <- cachedPlansToUse if !foundMatch) { + (incmngchild, incomingFilterChain, cd.plan) match { + case CompatibilityChecker(residualIncomingFilterChain, cdPlanProject) => + // since the child of both incoming and cached plan are same + // that is why we are here. 
for mapping and comparison purposes lets + // canonicalize the cachedPlan's project list in terms of the incoming plan's child + // so that we can map correctly. + val cdPlanToIncomngPlanChildOutputMapping = + cdPlanProject.child.output.zip(incmngchild.output).toMap + + val canonicalizedCdProjList = cdPlanProject.projectList.map(_.transformUp { + case attr: Attribute => cdPlanToIncomngPlanChildOutputMapping(attr) + }.asInstanceOf[NamedExpression]) + + // matchIndexInCdPlanProj remains -1 in the end, it indicates it is + // new cols created out of existing output attribs + val (directlyMappedincomingToCachedPlanIndx, inComingProjNoDirectMapping) = + getDirectAndIndirectMappingOfIncomingToCachedProjectAttribs( + incomingProject, canonicalizedCdProjList) + + // Now there is a possible case where a literal is present in IMR as attribute + // and the incoming project also has that literal somewhere in the alias. Though + // we do not need to read it but looks like the deserializer fails if we skip that + // literal in the projection enforced on IMR. so in effect even if we do not + // require an attribute it still needs to be present in the projection forced + // also its possible that some attribute from IMR can be used in subexpression + // of the incoming projection. so we have to handle that + val unusedAttribsOfCDPlanToGenIncomingAttr = + cdPlanProject.projectList.indices.filterNot(i => + directlyMappedincomingToCachedPlanIndx.exists(_._2 == i)).map(i => { + val cdAttrib = cdPlanProject.projectList(i) + i -> AttributeReference(cdAttrib.name, cdAttrib.dataType, + cdAttrib.nullable, cdAttrib.metadata)(qualifier = cdAttrib.qualifier) + }) + + // Because in case of rename multiple incmong named exprs ( attribute or aliases) + // will point to a common cdplan attrib, we need to ensure they do not create + // separate attribute in the the modifiedProject for incoming plan.. + // that is a single attribute ref is present in all mixes of rename and pass thru + // attributes. + // so we will use the first attribute ref in the incoming directly mapped project + // or if no attrib exists ( only case of rename) we will pick the child expr which + // is bound to be an attribute as the common ref. + val cdAttribToCommonAttribForIncmngNe = directlyMappedincomingToCachedPlanIndx.map { + case (inAttribIndex, cdAttribIndex) => + cdPlanProject.projectList(cdAttribIndex).toAttribute -> + incomingProject.projectList(inAttribIndex) + }.groupBy(_._1).map { + case (cdAttr, incomngSeq) => + val incmngCommonAttrib = incomngSeq.map(_._2).flatMap { + case attr: Attribute => Seq(attr) + case Alias(attr: Attribute, _) => Seq(attr) + case _ => Seq.empty + }.headOption.getOrElse( + AttributeReference(cdAttr.name, cdAttr.dataType, cdAttr.nullable)()) + cdAttr -> incmngCommonAttrib + } + + // If expressions of inComingProjNoDirectMapping can be expressed in terms of the + // incoming attribute refs or incoming alias exprs, which can be mapped directly + // to the CachedPlan's output, we are good. 
so lets transform such indirectly + // mappable named expressions in terms of mappable attributes of the incoming plan + val transformedIndirectlyMappableExpr = + transformIndirectlyMappedExpressionsToUseCachedPlanAttributes( + inComingProjNoDirectMapping, incomingProject, cdPlanProject, + directlyMappedincomingToCachedPlanIndx, cdAttribToCommonAttribForIncmngNe, + unusedAttribsOfCDPlanToGenIncomingAttr, canonicalizedCdProjList) + + val projectionToForceOnCdPlan = cdPlanProject.output.zipWithIndex.map { + case (cdAttr, i) => + cdAttribToCommonAttribForIncmngNe.getOrElse(cdAttr, + unusedAttribsOfCDPlanToGenIncomingAttr.find(_._1 == i).map(_._2).get) + } + val forcedAttribset = AttributeSet(projectionToForceOnCdPlan) + if (transformedIndirectlyMappableExpr.forall( + _._2.references.subsetOf(forcedAttribset))) { + val transformedIntermediateFilters = transformFilters(residualIncomingFilterChain, + projectionToForceOnCdPlan, canonicalizedCdProjList) + if (transformedIntermediateFilters.forall(_.references.subsetOf(forcedAttribset))) { + val modifiedInProj = replacementProjectListForIncomingProject(incomingProject, + directlyMappedincomingToCachedPlanIndx, cdPlanProject, + cdAttribToCommonAttribForIncmngNe, transformedIndirectlyMappableExpr) + // If InMemoryRelation (right is defined) it is the case of lookup or cache query + // Else it is a case of dummy CachedData partial lookup for finding out if the + // plan being checked uses the uncached plan + val newPartialPlan = if (cd.cachedRepresentation.isRight) { + val root = cd.cachedRepresentation.toOption.get.withOutput( + projectionToForceOnCdPlan) + if (transformedIntermediateFilters.isEmpty) { + Project(modifiedInProj, root) + } else { + val chainedFilter = CompatibilityChecker.combineFilterChainUsingRoot( + transformedIntermediateFilters, root) + Project(modifiedInProj, chainedFilter) + } + } else { + cd.cachedRepresentation.left.toOption.get + } + partialMatch = Option(cd.copy(cachedRepresentation = Left(newPartialPlan))) + foundMatch = true + } + } + case _ => + } + } + partialMatch + } + + private def transformFilters(skippedFilters: Seq[Filter], + projectionToForceOnCdPlan: Seq[Attribute], + canonicalizedCdProjList: Seq[NamedExpression]): Seq[Filter] = { + val canonicalizedCdProjAsExpr = canonicalizedCdProjList.map { + case Alias(child, _) => child + case x => x } - result + skippedFilters.map(f => { + val transformedCondn = f.condition.transformDown { + case expr => val matchedIndex = canonicalizedCdProjAsExpr.indexWhere(_ == expr) + if (matchedIndex != -1) { + projectionToForceOnCdPlan(matchedIndex) + } else { + expr + } + } + f.copy(condition = transformedCondn) + }) } - /** - * Replaces segments of the given logical plan with cached versions where possible. The input - * plan must be normalized. 
- */ - private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = { + private def replacementProjectListForIncomingProject( + incomingProject: Project, + directlyMappedincomingToCachedPlanIndx: Seq[(Int, Int)], + cdPlanProject: Project, + cdAttribToCommonAttribForIncmngNe: Map[Attribute, Attribute], + transformedIndirectlyMappableExpr: Map[Int, NamedExpression]): Seq[NamedExpression] = + { + incomingProject.projectList.zipWithIndex.map { + case (ne, indx) => + directlyMappedincomingToCachedPlanIndx.find(_._1 == indx).map { + case (_, cdIndex) => + ne match { + case attr: Attribute => attr + case al: Alias => + val cdAttr = cdPlanProject.projectList(cdIndex).toAttribute + al.copy(child = cdAttribToCommonAttribForIncmngNe(cdAttr))( + exprId = al.exprId, qualifier = al.qualifier, + explicitMetadata = al.explicitMetadata, + nonInheritableMetadataKeys = al.nonInheritableMetadataKeys + ) + } + }.getOrElse({ + transformedIndirectlyMappableExpr(indx) + }) + } + } + + private def transformIndirectlyMappedExpressionsToUseCachedPlanAttributes( + inComingProjNoDirectMapping: Seq[(Int, Int)], + incomingProject: Project, + cdPlanProject: Project, + directlyMappedincomingToCachedPlanIndx: Seq[(Int, Int)], + cdAttribToCommonAttribForIncmngNe: Map[Attribute, Attribute], + unusedAttribsOfCDPlanToGenIncomingAttr: Seq[(Int, AttributeReference)], + canonicalizedCdProjList: Seq[NamedExpression]): Map[Int, NamedExpression] = + { + inComingProjNoDirectMapping.map { + case (incomngIndex, _) => + val indirectIncmnNe = incomingProject.projectList(incomngIndex) + val modifiedNe = indirectIncmnNe.transformDown { + case expr => directlyMappedincomingToCachedPlanIndx.find { + case (incomingIndex, _) => + val directMappedNe = incomingProject.projectList(incomingIndex) + directMappedNe.toAttribute == expr || + directMappedNe.children.headOption.contains(expr) + }.map { + case (_, cdIndex) => + val cdAttrib = cdPlanProject.projectList(cdIndex).toAttribute + cdAttribToCommonAttribForIncmngNe(cdAttrib) + }.orElse( + unusedAttribsOfCDPlanToGenIncomingAttr.find { + case (i, _) => val cdNe = canonicalizedCdProjList(i) + cdNe.children.headOption.contains(expr) + }.map(_._2)). + map(ne => ne.toAttribute).getOrElse(expr) + }.asInstanceOf[NamedExpression] + + incomngIndex -> modifiedNe + }.toMap + } + + private def getDirectAndIndirectMappingOfIncomingToCachedProjectAttribs( + incomingProject: Project, + canonicalizedCdProjList: Seq[NamedExpression]): (Seq[(Int, Int)], Seq[(Int, Int)]) = + { + incomingProject.projectList.zipWithIndex.map { + case (inComingNE, index) => + // first check for equivalent named expressions..if index is != -1, that means + // it is pass thru Alias or pass thru - Attribute + var matchIndexInCdPlanProj = canonicalizedCdProjList.indexWhere(_ == inComingNE) + if (matchIndexInCdPlanProj == -1) { + // if match index is -1, that means it could be two possibilities: + // 1) it is a case of rename which means the incoming expr is an alias and + // its child is an attrib ref, which may have a direct attribref in the + // cdPlanProj, or it may actually have an alias whose ref matches the ref + // of incoming attribRef + // 2) the positions in the incoming project alias and the cdPlanProject are + // different. 
as a result the canonicalized alias of each would have + // relatively different exprIDs ( as their relative positions differ), but + // even in such cases as their child logical plans are same, so the child + // expression of each alias will have same canonicalized data + val incomingExprToCheck = inComingNE match { + case x: AttributeReference => x + case Alias(expr, _) => expr + } + matchIndexInCdPlanProj = canonicalizedCdProjList.indexWhere { + case Alias(expr, _) => expr == incomingExprToCheck + case x => x == incomingExprToCheck + } + } + index -> matchIndexInCdPlanProj + }.partition(_._2 != -1) + } + + /** Replaces segments of the given logical plan with cached versions where possible. */ + def useCachedData(plan: LogicalPlan): LogicalPlan = { val newPlan = plan transformDown { case command: IgnoreCachedData => command - case currentFragment => + case currentFragment if !currentFragment.isInstanceOf[InMemoryRelation] => lookupCachedDataInternal(currentFragment).map { cached => + // After cache lookup, we should still keep the hints from the input plan. val hints = EliminateResolvedHint.extractHintsFromPlan(currentFragment)._2 - val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output) + val cachedPlan = cached.cachedRepresentation.map(_.withOutput(currentFragment.output)). + merge + // The returned hint list is in top-down order, we should create the hint nodes from // right to left. hints.foldRight[LogicalPlan](cachedPlan) { case (hint, p) => @@ -485,6 +770,19 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { } object CacheManager extends Logging { + + val expressionRemapper: (Expression, AttributeMap[(NamedExpression, Expression)]) => Expression = + (expr, mappings) => { + expr transformUp { + case attr: AttributeReference => mappings.get(attr).map { + case (_, expr) => expr + }.getOrElse(attr) + } + } + + val inMemoryRelationExtractor: LogicalPlan => InMemoryRelation = + plan => plan.collectLeaves().head.asInstanceOf[InMemoryRelation] + def logCacheOperation(f: => LogEntry): Unit = { SQLConf.get.dataframeCacheLogLevel match { case "TRACE" => logTrace(f) @@ -496,3 +794,87 @@ object CacheManager extends Logging { } } } + +object CompatibilityChecker { + def unapply(data: (LogicalPlan, Seq[Filter], LogicalPlan)): Option[(Seq[Filter], Project)] = { + val(incomingChild, incomingFilterChain, cachedPlan) = data + cachedPlan match { + case p: Project if incomingChild.sameResult(p.child) => Option(incomingFilterChain -> p) + + case f: Filter => + val collectedFilters = mutable.ListBuffer[Filter](f) + var projectFound: Option[Project] = None + var child: LogicalPlan = f.child + var keepChecking = true + while (keepChecking) { + child match { + case x: Filter => child = x.child + collectedFilters += x + case p: Project => projectFound = Option(p) + keepChecking = false + case _ => keepChecking = false + } + } + if (collectedFilters.size <= incomingFilterChain.size && + projectFound.exists(_.child.sameResult(incomingChild))) { + val (residualIncomingFilterChain, otherFilterChain) = incomingFilterChain.splitAt( + incomingFilterChain.size - collectedFilters.size) + val isCompatible = if (otherFilterChain.isEmpty) { + true + } else { + // the other filter chain must be equal to the collected filter chain + // But we need to transform the collected Filter chain such that it is below + // the project of the cached plan, we have found, as the incoming filters are also below + // the incoming project. 
+ val mappingFilterExpr = AttributeMap(projectFound.get.projectList.flatMap { + case _: Attribute => Seq.empty[(Attribute, (NamedExpression, Expression))] + case al: Alias => Seq(al.toAttribute -> (al, al.child)) + }) + + val modifiedCdFilters = collectedFilters.map(f => + f.copy(condition = CacheManager.expressionRemapper( + f.condition, mappingFilterExpr))).toSeq + val chainedFilter1 = combineFilterChainUsingRoot(otherFilterChain, + EmptyRelation(incomingChild.output)) + val chainedFilter2 = combineFilterChainUsingRoot(modifiedCdFilters, + EmptyRelation(projectFound.map(_.child).get.output)) + chainedFilter1.sameResult(chainedFilter2) + } + if (isCompatible) { + Option(residualIncomingFilterChain -> projectFound.get) + } else { + None + } + } else { + None + } + + case _ => None + } + } + + def combineFilterChainUsingRoot(filters: Seq[Filter], root: LogicalPlan): Filter = { + val lastFilterNode = filters.last + val lastFilterMod = lastFilterNode.copy(child = root) + filters.dropRight(1).foldRight(lastFilterMod)((f, c) => f.copy(child = c)) + } + + def extractChildIgnoringFiltersFromIncomingProject(incomingProject: Project): + (LogicalPlan, Seq[Filter]) = { + val collectedFilters = mutable.ListBuffer[Filter]() + var child: LogicalPlan = incomingProject.child + var keepChecking = true + while (keepChecking) { + child match { + case f: Filter => child = f.child + collectedFilters += f + case _ => keepChecking = false + } + } + (child, collectedFilters.toSeq) + } + + case class EmptyRelation(output: Seq[Attribute]) extends LeafNode { + override def maxRows: Option[Long] = Some(0) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 23555c98135f6..6a69168251052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -62,12 +62,16 @@ case class AnalyzeColumnCommand( private def analyzeColumnInCachedData(plan: LogicalPlan, sparkSession: SparkSession): Boolean = { val cacheManager = sparkSession.sharedState.cacheManager val df = Dataset.ofRows(sparkSession, plan) - cacheManager.lookupCachedData(df).map { cachedData => - val columnsToAnalyze = getColumnsToAnalyze( - tableIdent, cachedData.cachedRepresentation, columnNames, allColumns) - cacheManager.analyzeColumnCacheQuery(sparkSession, cachedData, columnsToAnalyze) - cachedData - }.isDefined + cacheManager.lookupCachedData(df).exists { cachedData => + if (cachedData.cachedRepresentation.isRight) { + val columnsToAnalyze = getColumnsToAnalyze( + tableIdent, cachedData.cachedRepresentation.merge, columnNames, allColumns) + cacheManager.analyzeColumnCacheQuery(sparkSession, cachedData, columnsToAnalyze) + true + } else { + false + } + } } private def analyzeColumnInTempView(plan: LogicalPlan, sparkSession: SparkSession): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index 48d98c14c3889..f49a53b9e6d76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -241,7 +241,7 @@ object CommandUtils extends Logging { // Analyzes a catalog view if the view is cached val table = 
sparkSession.table(tableIdent.quotedString) val cacheManager = sparkSession.sharedState.cacheManager - if (cacheManager.lookupCachedData(table).isDefined) { + if (cacheManager.lookupCachedData(table).exists(_.cachedRepresentation.isRight)) { if (!noScan) { // To collect table stats, materializes an underlying columnar RDD table.count() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 84b73a74f3ab2..b1f27472340bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeed import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.execution.CacheManager import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -201,7 +202,8 @@ case class AlterTableRenameCommand( // If `optStorageLevel` is defined, the old table was cached. val optCachedData = sparkSession.sharedState.cacheManager.lookupCachedData( sparkSession.table(oldName.unquotedString)) - val optStorageLevel = optCachedData.map(_.cachedRepresentation.cacheBuilder.storageLevel) + val optStorageLevel = optCachedData.map(_.cachedRepresentation. + fold(CacheManager.inMemoryRelationExtractor, identity).cacheBuilder.storageLevel) if (optStorageLevel.isDefined) { CommandUtils.uncacheTableOrView(sparkSession, oldName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 499721fbae4e8..1340bad24766e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.connector.read.LocalScan import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.connector.write.V1Write import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.{CacheManager, FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelationWithTable, PushableColumnAndNestedColumn} import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} @@ -74,8 +74,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat val v2Relation = DataSourceV2Relation.create(r.table, Some(r.catalog), Some(r.identifier)) val cache = session.sharedState.cacheManager.lookupCachedData(session, v2Relation) 
session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true) - if (cache.isDefined) { - val cacheLevel = cache.get.cachedRepresentation.cacheBuilder.storageLevel + if (cache.exists(_.cachedRepresentation.isRight)) { + val cacheLevel = cache.get.cachedRepresentation. + fold(CacheManager.inMemoryRelationExtractor, identity).cacheBuilder.storageLevel Some(cacheLevel) } else { None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index bda8c7f26082f..81205d7fe2bde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql import org.scalatest.concurrent.TimeLimits import org.scalatest.time.SpanSugar._ +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.IntegerType import org.apache.spark.storage.StorageLevel import org.apache.spark.tags.SlowSQLTest @@ -244,13 +246,12 @@ class DatasetCacheSuite extends QueryTest case i: InMemoryRelation => i.cacheBuilder.cachedPlan } assert(df1LimitInnerPlan.isDefined && df1LimitInnerPlan.get == df1InnerPlan) - - // Verify that df2's cache has been re-cached, with a new physical plan rid of dependency - // on df, since df2's cache had not been loaded before df.unpersist(). val df2Limit = df2.limit(2) val df2LimitInnerPlan = df2Limit.queryExecution.withCachedData.collectFirst { case i: InMemoryRelation => i.cacheBuilder.cachedPlan } + // Verify that df2's cache has been re-cached, with a new physical plan rid of dependency + // on df, since df2's cache had not been loaded before df.unpersist(). assert(df2LimitInnerPlan.isDefined && !df2LimitInnerPlan.get.exists(_.isInstanceOf[InMemoryTableScanExec])) } @@ -275,6 +276,77 @@ class DatasetCacheSuite extends QueryTest } } + test("SPARK-47609. Partial match IMR with more columns than base") { + val baseDfCreator = () => spark.range(0, 50).selectExpr("id as a ", "id % 2 AS b", + "id % 3 AS c") + val testDfCreator = () => spark.range(0, 50).select($"id".as("a"), ($"id" % 2).as("b"), + ($"id" + 1).as("d"), ($"id" % 3).as("c") ) + checkIMRUseAndInvalidation(baseDfCreator, testDfCreator) + } + + test("SPARK-47609. Partial match IMR with less columns than base") { + val baseDfCreator = () => spark.range(0, 50).selectExpr("id as a ", "id % 2 AS b", + "id % 3 AS c") + val testDfCreator = () => spark.range(0, 50).select($"id".as("a"), ($"id" % 2).as("b")) + checkIMRUseAndInvalidation(baseDfCreator, testDfCreator) + } + + test("SPARK-47609. Partial match IMR with columns remapped to different value") { + val baseDfCreator = () => spark.range(0, 50).selectExpr("id as a ", "id % 2 AS b", + "id % 3 AS c") + val testDfCreator = () => spark.range(0, 50).select(($"id" + ($"id" % 3)).as("a"), + ($"id" % 2).as("b"), ($"id" % 3).as("c")) + checkIMRUseAndInvalidation(baseDfCreator, testDfCreator) + } + + test("SPARK-47609. 
Partial match IMR with extra columns as literal") { + val baseDfCreator = () => spark.range(0, 50).selectExpr("id as a ", "id % 2 AS b", + "id % 3 AS c") + val testDfCreator = () => spark.range(0, 50).select(($"id" + ($"id" % 3)).as("a"), + ($"id" % 2).as("b"), expr("100").cast(IntegerType).as("t")) + checkIMRUseAndInvalidation(baseDfCreator, testDfCreator) + } + + test("SPARK-47609. Partial match IMR with multiple columns remap") { + val baseDfCreator = () => spark.range(0, 50).selectExpr("id as a ", "id % 2 AS b", + "id % 3 AS c") + val testDfCreator = () => spark.range(0, 50).select((($"id" % 3) * ($"id" % 2)).as("c"), + ($"id" - ($"id" % 3)).as("a"), $"id".as("b")) + checkIMRUseAndInvalidation(baseDfCreator, testDfCreator) + } + + test("SPARK-47609. Partial match IMR with multiple columns remap and one remap as constant") { + val baseDfCreator = () => spark.range(0, 50).selectExpr("id as a ", "id % 2 AS b", + "id % 3 AS c") + val testDfCreator = () => spark.range(0, 50).select((($"id" % 3) * ($"id" % 2)).as("c"), + ($"id" - ($"id" % 3)).as("a"), expr("100").cast(IntegerType).as("b"), $"b".as("d")) + checkIMRUseAndInvalidation(baseDfCreator, testDfCreator) + } + + test("SPARK-47609. Partial match IMR with partial filters match") { + val baseDfCreator = () => spark.range(0, 100). + selectExpr("id as a ", "id % 2 AS b", "id % 3 AS c").filter($"a" > 7) + val testDfCreator = () => spark.range(0, 100).filter($"id"> 7).filter(($"id" % 3) > 8).select( + (($"id" % 3) * ($"id" % 2)).as("c"), ($"id" - ($"id" % 3)).as("a"), + expr("100").cast(IntegerType).as("b"), $"b".as("d")) + checkIMRUseAndInvalidation(baseDfCreator, testDfCreator) + } + + test("SPARK-47609. Because of filter mismatch partial match should not happen") { + val baseDfCreator = () => spark.range(0, 100). 
+ selectExpr("id as a ", "id % 2 AS b", "id % 3 AS c").filter($"a" > 7) + val testDfCreator = () => spark.range(0, 100).filter(($"id" % 3) > 8).select( + (($"id" % 3) * ($"id" % 2)).as("c"), ($"id" - ($"id" % 3)).as("a"), + expr("100").cast(IntegerType).as("b"), $"b".as("d")) + val baseDf = baseDfCreator() + baseDf.cache() + val testDf = testDfCreator() + verifyCacheDependency(baseDfCreator(), 1) + // cache should not be used + verifyCacheDependency(testDf, 0) + baseDf.unpersist(true) + } + test("SPARK-44653: non-trivial DataFrame unions should not break caching") { val df1 = Seq(1 -> 1).toDF("i", "j") val df2 = Seq(2 -> 2).toDF("i", "j") @@ -312,4 +384,40 @@ class DatasetCacheSuite extends QueryTest } } } + + protected def checkIMRUseAndInvalidation( + baseDfCreator: () => DataFrame, + testExec: () => DataFrame): Unit = { + // now check if the results of optimized dataframe and completely unoptimized dataframe are + // same + val baseDf = baseDfCreator() + val testDf = testExec() + val testDfRows = testDf.collect() + baseDf.cache() + verifyCacheDependency(baseDfCreator(), 1) + verifyCacheDependency(testExec(), 1) + baseDfCreator().unpersist(true) + verifyCacheDependency(baseDfCreator(), 0) + verifyCacheDependency(testExec(), 0) + baseDfCreator().cache() + val newTestDf = testExec() + // re-verify cache dependency + verifyCacheDependency(newTestDf, 1) + checkAnswer(newTestDf, testDfRows) + baseDfCreator().unpersist(true) + } + + def verifyCacheDependency(df: DataFrame, numOfCachesExpected: Int): Unit = { + def recurse(sparkPlan: SparkPlan): Int = { + val imrs = sparkPlan.collect { + case i: InMemoryTableScanExec => i + } + imrs.size + imrs.map(ime => recurse(ime.relation.cacheBuilder.cachedPlan)).sum + } + val cachedPlans = df.queryExecution.withCachedData.collect { + case i: InMemoryRelation => i.cacheBuilder.cachedPlan + } + val totalIMRs = cachedPlans.size + cachedPlans.map(ime => recurse(ime)).sum + assert(totalIMRs == numOfCachesExpected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala index 506b44741ab4b..d872fdac5cfbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterTableRenameSuiteBase.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.execution.CacheManager import org.apache.spark.storage.StorageLevel /** @@ -73,7 +74,8 @@ trait AlterTableRenameSuiteBase extends QueryTest with DDLCommandTestUtils { def getStorageLevel(tableName: String): StorageLevel = { val table = spark.table(tableName) val cachedData = spark.sharedState.cacheManager.lookupCachedData(table).get - cachedData.cachedRepresentation.cacheBuilder.storageLevel + cachedData.cachedRepresentation.fold(CacheManager.inMemoryRelationExtractor, identity). + cacheBuilder.storageLevel } sql(s"CREATE TABLE $src (c0 INT) $defaultUsing") sql(s"INSERT INTO $src SELECT 0")
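A rough standalone sketch of the scenario the new DatasetCacheSuite tests above exercise — not part of the patch; the object name and local master are illustrative, and whether the second plan actually hits the cache depends on this change being applied. The intended user-visible effect is that a projection which only renames or derives columns over an already-cached child is answered from the cached InMemoryRelation instead of being recomputed:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.columnar.InMemoryRelation

object PartialCacheMatchSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("partial-cache-match").getOrCreate()
    import spark.implicits._

    // Cache a base projection over range(0, 50).
    val base = spark.range(0, 50).selectExpr("id AS a", "id % 2 AS b", "id % 3 AS c")
    base.cache()
    base.count() // materialize the cached columnar buffers

    // Same child, but the columns are renamed/derived rather than matching the cached projection exactly.
    val derived = spark.range(0, 50)
      .select($"id".as("a"), ($"id" % 2).as("b"), ($"id" + 1).as("d"))

    // Count the InMemoryRelation nodes that the cache lookup substituted into the plan.
    val cacheHits = derived.queryExecution.withCachedData.collect {
      case r: InMemoryRelation => r
    }
    // Expected: 1 once the partial-match lookup applies, 0 without it.
    println(s"InMemoryRelation nodes in derived plan: ${cacheHits.size}")

    base.unpersist(blocking = true)
    spark.stop()
  }
}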
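Separately, the reason .fold(CacheManager.inMemoryRelationExtractor, identity) and .merge appear at so many call sites is that CachedData.cachedRepresentation is now Either[LogicalPlan, InMemoryRelation]: Left holds a partially matched plan wrapped over the relation, Right holds a directly cached relation. A minimal plain-Scala sketch of that pattern, using stand-in types rather than the real plan classes:

object EitherFoldSketch {
  // Stand-ins for LogicalPlan / InMemoryRelation, just to show the shape of the pattern.
  sealed trait Plan { def leaves: Seq[Plan] = Seq(this) }
  case class Relation(name: String) extends Plan
  case class Wrapper(child: Plan) extends Plan { override def leaves: Seq[Plan] = child.leaves }

  // Analogue of CacheManager.inMemoryRelationExtractor: dig the leaf relation out of a wrapping plan.
  val relationExtractor: Plan => Relation = p => p.leaves.head.asInstanceOf[Relation]

  def main(args: Array[String]): Unit = {
    val directHit: Either[Plan, Relation] = Right(Relation("cached"))
    val partialHit: Either[Plan, Relation] = Left(Wrapper(Relation("cached")))

    // fold applies the first function to a Left and the second to a Right,
    // so both cases collapse to the underlying relation.
    Seq(directHit, partialHit).foreach { e =>
      val rel: Relation = e.fold(relationExtractor, identity)
      println(s"extracted: $rel")
    }

    // merge collapses an Either whose two sides share a common supertype (here: Plan).
    val asPlan: Plan = partialHit.merge
    println(s"merged: $asPlan")
  }
}

Because Either is covariant and InMemoryRelation is itself a LogicalPlan, .merge collapses both sides to a LogicalPlan, while .fold is used where the concrete InMemoryRelation (and its cacheBuilder) is needed.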