SparkSessionExtensions.scala

@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
  * <li>Analyzer Rules.</li>
  * <li>Check Analysis Rules.</li>
  * <li>Optimizer Rules.</li>
+ * <li>Data Source Rewrite Rules.</li>
  * <li>Planning Strategies.</li>
  * <li>Customized Parser.</li>
  * <li>(External) Catalog listeners.</li>
@@ -199,6 +200,21 @@ class SparkSessionExtensions {
     optimizerRules += builder
   }

A reviewer (Member) commented:
Could you update the description above, too?

* This current provides the following extension points:
*
* <ul>
* <li>Analyzer Rules.</li>
* <li>Check Analysis Rules.</li>
* <li>Optimizer Rules.</li>
* <li>Planning Strategies.</li>
* <li>Customized Parser.</li>
* <li>(External) Catalog listeners.</li>
* <li>Columnar Rules.</li>
* <li>Adaptive Query Stage Preparation Rules.</li>
* </ul>

@aokolnychyi (Contributor, Author) replied on Dec 7, 2020:
Done, thanks for catching this!

+  private[this] val dataSourceRewriteRules = mutable.Buffer.empty[RuleBuilder]
+
+  private[sql] def buildDataSourceRewriteRules(session: SparkSession): Seq[Rule[LogicalPlan]] = {
+    dataSourceRewriteRules.map(_.apply(session)).toSeq
+  }
+
+  /**
+   * Inject an optimizer `Rule` builder that rewrites data source plans into the [[SparkSession]].
+   * The injected rules will be executed after the operator optimization batch and before rules
+   * that depend on stats.
+   */
+  def injectDataSourceRewriteRule(builder: RuleBuilder): Unit = {
+    dataSourceRewriteRules += builder
+  }

   private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder]

   private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = {
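As an aside for readers of this diff, here is a minimal sketch of how an application might register a rule through this new extension point; the rule class, master setting, and session wiring below are illustrative only, not part of this PR:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Illustrative no-op rule; a real implementation would rewrite data source plans here.
case class MyDataSourceRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

val spark = SparkSession.builder()
  .master("local[*]")
  .withExtensions(_.injectDataSourceRewriteRule(MyDataSourceRewriteRule))
  .getOrCreate()

Per the scaladoc added in this change, a rule injected this way runs after the operator optimization batch and before rules that depend on stats.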
BaseSessionStateBuilder.scala

@@ -273,7 +273,9 @@ abstract class BaseSessionStateBuilder(
    *
    * Note that this may NOT depend on the `optimizer` function.
    */
-  protected def customDataSourceRewriteRules: Seq[Rule[LogicalPlan]] = Nil
+  protected def customDataSourceRewriteRules: Seq[Rule[LogicalPlan]] = {
+    extensions.buildDataSourceRewriteRules(session)
+  }

   /**
    * Planner that converts optimized logical plans to physical plans.
SparkSessionExtensionSuite.scala

@@ -88,6 +88,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
   }

+  test("SPARK-33621: inject data source rewrite rule") {
+    withSession(Seq(_.injectDataSourceRewriteRule(MyRule))) { session =>
+      assert(session.sessionState.optimizer.dataSourceRewriteRules.contains(MyRule(session)))
+    }
+  }
+
test("inject spark planner strategy") {
withSession(Seq(_.injectPlannerStrategy(MySparkStrategy))) { session =>
assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
Expand Down
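Beyond the unit test above, the same injection should also be reachable through the spark.sql.extensions configuration, which accepts implementations of SparkSessionExtensions => Unit. A hedged sketch, reusing the illustrative MyDataSourceRewriteRule from the earlier sketch (the class name here is hypothetical):

import org.apache.spark.sql.SparkSessionExtensions

// Hypothetical extensions class; it could be enabled with
//   --conf spark.sql.extensions=<fully qualified name of this class>
class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectDataSourceRewriteRule(MyDataSourceRewriteRule)
  }
}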