diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9eee7c2b914a..b7c8f775b857 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -185,6 +185,9 @@ abstract class Optimizer(catalogManager: CatalogManager) RemoveLiteralFromGroupExpressions, RemoveRepetitionFromGroupExpressions) :: Nil ++ operatorOptimizationBatch) :+ + // This batch rewrites data source plans and should be run after the operator + // optimization batch and before any batches that depend on stats. + Batch("Data Source Rewrite Rules", Once, dataSourceRewriteRules: _*) :+ // This batch pushes filters and projections into scan nodes. Before this batch, the logical // plan may contain nodes that do not report stats. Anything that uses stats must run after // this batch. @@ -289,6 +292,12 @@ abstract class Optimizer(catalogManager: CatalogManager) */ def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = Nil + /** + * Override to provide additional rules for rewriting data source plans. Such rules will be + * applied after operator optimization rules and before any rules that depend on stats. + */ + def dataSourceRewriteRules: Seq[Rule[LogicalPlan]] = Nil + /** * Returns (defaultBatches - (excludedRules - nonExcludableRules)), the rule batches that * eventually run in the Optimizer. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 538a5408723b..e159d88bd822 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -240,6 +240,9 @@ abstract class BaseSessionStateBuilder( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = super.earlyScanPushDownRules ++ customEarlyScanPushDownRules + override def dataSourceRewriteRules: Seq[Rule[LogicalPlan]] = + super.dataSourceRewriteRules ++ customDataSourceRewriteRules + override def extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = super.extendedOperatorOptimizationRules ++ customOperatorOptimizationRules } @@ -263,6 +266,14 @@ abstract class BaseSessionStateBuilder( */ protected def customEarlyScanPushDownRules: Seq[Rule[LogicalPlan]] = Nil + /** + * Custom rules for rewriting data source plans to add to the Optimizer. Prefer overriding + * this instead of creating your own Optimizer. + * + * Note that this may NOT depend on the `optimizer` function. + */ + protected def customDataSourceRewriteRules: Seq[Rule[LogicalPlan]] = Nil + /** * Planner that converts optimized logical plans to physical plans. *