File tree Expand file tree Collapse file tree
main/scala/org/apache/spark/sql
test/scala/org/apache/spark/sql Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -226,6 +226,21 @@ class SparkSessionExtensions {
226226 preCBORules += builder
227227 }
228228
229+ private [this ] val earlyScanPushDownRules = mutable.Buffer .empty[RuleBuilder ]
230+
231+ private [sql] def buildEarlyScanPushDownRules (session : SparkSession ): Seq [Rule [LogicalPlan ]] = {
232+ earlyScanPushDownRules.map(_.apply(session)).toSeq
233+ }
234+
235+ /**
236+ * Inject an optimizer `Rule` builder that rewrites logical plans into the [[SparkSession ]].
237+ * The injected rules will be executed once after the operator optimization batch and
238+ * after any push down optimization rules.
239+ */
240+ def injectEarlyScanPushDownRules (builder : RuleBuilder ): Unit = {
241+ earlyScanPushDownRules += builder
242+ }
243+
229244 private [this ] val plannerStrategyBuilders = mutable.Buffer .empty[StrategyBuilder ]
230245
231246 private [sql] def buildPlannerStrategies (session : SparkSession ): Seq [Strategy ] = {
Original file line number Diff line number Diff line change @@ -270,7 +270,9 @@ abstract class BaseSessionStateBuilder(
270270 *
271271 * Note that this may NOT depend on the `optimizer` function.
272272 */
273- protected def customEarlyScanPushDownRules : Seq [Rule [LogicalPlan ]] = Nil
273+ protected def customEarlyScanPushDownRules : Seq [Rule [LogicalPlan ]] = {
274+ extensions.buildEarlyScanPushDownRules(session)
275+ }
274276
275277 /**
276278 * Custom rules for rewriting plans after operator optimization and before CBO.
Original file line number Diff line number Diff line change @@ -95,6 +95,12 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
9595 }
9696 }
9797
98+ test(" SPARK-37518: inject a early scan push down rule" ) {
99+ withSession(Seq (_.injectEarlyScanPushDownRules(MyRule ))) { session =>
100+ assert(session.sessionState.optimizer.earlyScanPushDownRules.contains(MyRule (session)))
101+ }
102+ }
103+
98104 test(" inject spark planner strategy" ) {
99105 withSession(Seq (_.injectPlannerStrategy(MySparkStrategy ))) { session =>
100106 assert(session.sessionState.planner.strategies.contains(MySparkStrategy (session)))
You can’t perform that action at this time.
0 commit comments