-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-32945][SQL] Avoid collapsing projects if reaching max allowed common exprs #29950
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
f418714
98843dd
1b567e7
76509b3
43eb50d
4bf4dc2
9bfafc7
c2c01e4
4990375
e8f18f8
58e71d8
bbaae3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -732,10 +732,12 @@ object ColumnPruning extends Rule[LogicalPlan] { | |
| * `GlobalLimit(LocalLimit)` pattern is also considered. | ||
| */ | ||
| object CollapseProject extends Rule[LogicalPlan] { | ||
|
|
||
| def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { | ||
| case p1 @ Project(_, p2: Project) => | ||
| if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) { | ||
| val maxCommonExprs = SQLConf.get.maxCommonExprsInCollapseProject | ||
|
|
||
| if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList) || | ||
| getLargestNumOfCommonOutput(p1.projectList, p2.projectList) >= maxCommonExprs) { | ||
|
||
| p1 | ||
| } else { | ||
| p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)) | ||
|
|
@@ -766,6 +768,23 @@ object CollapseProject extends Rule[LogicalPlan] { | |
| }) | ||
| } | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could extend to other cases like |
||
|
|
||
| // Counts for the largest times common outputs from lower operator are used in upper operators. | ||
| private def getLargestNumOfCommonOutput( | ||
|
||
| upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Int = { | ||
| val aliases = collectAliases(lower) | ||
| val exprMap = mutable.HashMap.empty[Attribute, Int] | ||
|
|
||
| upper.foreach(_.collect { | ||
| case a: Attribute if aliases.contains(a) => exprMap.update(a, exprMap.getOrElse(a, 0) + 1) | ||
| }) | ||
|
|
||
| if (exprMap.size > 0) { | ||
viirya marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| exprMap.maxBy(_._2)._2 | ||
| } else { | ||
| 0 | ||
| } | ||
| } | ||
|
|
||
| private def haveCommonNonDeterministicOutput( | ||
| upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = { | ||
| // Create a map of Aliases to their values from the lower projection. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1926,6 +1926,19 @@ object SQLConf { | |
| .booleanConf | ||
| .createWithDefault(true) | ||
|
|
||
| val MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT = | ||
| buildConf("spark.sql.optimizer.maxCommonExprsInCollapseProject") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we set this value to 1, all the existing tests can pass?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess not. We might have at least a few common expressions in a collapsed projection. If set to 1, any duplicated expression is not allowed. |
||
| .doc("An integer number indicates the maximum allowed number of a common expression " + | ||
|
||
| "can be collapsed into upper Project from lower Project by optimizer rule " + | ||
| "`CollapseProject`. Normally `CollapseProject` will collapse adjacent Project " + | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Just a comment) Even if we set
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, but currently if we exclude
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hm I see. Yea, updating the doc sounds nice to me. |
||
| "and merge expressions. But in some edge cases, expensive expressions might be " + | ||
| "duplicated many times in merged Project by this optimization. This config sets " + | ||
| "a maximum number. Once an expression is duplicated equal to or more than this number " + | ||
| "if merging two Project, Spark SQL will skip the merging.") | ||
| .version("3.1.0") | ||
| .intConf | ||
dongjoon-hyun marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| .createWithDefault(20) | ||
|
||
|
|
||
| val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = | ||
| buildConf("spark.sql.decimalOperations.allowPrecisionLoss") | ||
| .internal() | ||
|
|
@@ -3289,6 +3302,8 @@ class SQLConf extends Serializable with Logging { | |
|
|
||
| def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) | ||
|
|
||
| def maxCommonExprsInCollapseProject: Int = getConf(MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT) | ||
|
|
||
| def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) | ||
|
|
||
| def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer | |
| import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases | ||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.dsl.plans._ | ||
| import org.apache.spark.sql.catalyst.expressions.{Alias, Rand} | ||
| import org.apache.spark.sql.catalyst.expressions._ | ||
| import org.apache.spark.sql.catalyst.plans.PlanTest | ||
| import org.apache.spark.sql.catalyst.plans.logical._ | ||
| import org.apache.spark.sql.catalyst.rules.RuleExecutor | ||
| import org.apache.spark.sql.types.MetadataBuilder | ||
| import org.apache.spark.sql.internal.SQLConf | ||
| import org.apache.spark.sql.types.{MetadataBuilder, StructType} | ||
|
|
||
| class CollapseProjectSuite extends PlanTest { | ||
| object Optimize extends RuleExecutor[LogicalPlan] { | ||
|
|
@@ -170,4 +171,34 @@ class CollapseProjectSuite extends PlanTest { | |
| val expected = Sample(0.0, 0.6, false, 11L, relation.select('a as 'c)).analyze | ||
| comparePlans(optimized, expected) | ||
| } | ||
|
|
||
| test("SPARK-32945: avoid collapsing projects if reaching max allowed common exprs") { | ||
dongjoon-hyun marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| val options = Map.empty[String, String] | ||
| val schema = StructType.fromDDL("a int, b int, c string, d long") | ||
|
|
||
| Seq("1", "2", "3", "4").foreach { maxCommonExprs => | ||
| withSQLConf(SQLConf.MAX_COMMON_EXPRS_IN_COLLAPSE_PROJECT.key -> maxCommonExprs) { | ||
| // If we collapse two Projects, `JsonToStructs` will be repeated three times. | ||
| val relation = LocalRelation('json.string) | ||
| val query = relation.select( | ||
| JsonToStructs(schema, options, 'json).as("struct")) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. indentation? Maybe, the following is better? - val query1 = relation.select(
- JsonToStructs(schema, options, 'json).as("struct"))
- .select(
+ val query1 = relation.select(JsonToStructs(schema, options, 'json).as("struct"))
+ .select( |
||
| .select( | ||
| GetStructField('struct, 0).as("a"), | ||
| GetStructField('struct, 1).as("b"), | ||
| GetStructField('struct, 2).as("c")).analyze | ||
|
||
| val optimized = Optimize.execute(query) | ||
|
|
||
| if (maxCommonExprs.toInt <= 3) { | ||
| val expected = query | ||
| comparePlans(optimized, expected) | ||
| } else { | ||
| val expected = relation.select( | ||
| GetStructField(JsonToStructs(schema, options, 'json), 0).as("a"), | ||
| GetStructField(JsonToStructs(schema, options, 'json), 1).as("b"), | ||
| GetStructField(JsonToStructs(schema, options, 'json), 2).as("c")).analyze | ||
| comparePlans(optimized, expected) | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.