Skip to content

Commit 88c6d0b

Browse files
committed
SPARK-42500: ConstantPropagation support more case
1 parent 8cfd5bf commit 88c6d0b

2 files changed

Lines changed: 42 additions & 7 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,14 +200,20 @@ object ConstantPropagation extends Rule[LogicalPlan] {
200200

201201
private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
202202
: Expression = {
203-
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
204-
val predicates = equalityPredicates.map(_._2).toSet
205-
def replaceConstants0(expression: Expression) = expression transform {
203+
val allConstantsMap = AttributeMap(equalityPredicates.map(_._1))
204+
val allPredicates = equalityPredicates.map(_._2).toSet
205+
def replaceConstants0(
206+
expression: Expression, constantsMap: AttributeMap[Literal]) = expression transform {
206207
case a: AttributeReference => constantsMap.getOrElse(a, a)
207208
}
208209
condition transform {
209-
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
210-
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
210+
case b: BinaryComparison =>
211+
if (!allPredicates.contains(b)) {
212+
replaceConstants0(b, allConstantsMap)
213+
} else {
214+
val excludedEqualityPredicates = equalityPredicates.filterNot(_._2.semanticEquals(b))
215+
replaceConstants0(b, AttributeMap(excludedEqualityPredicates.map(_._1)))
216+
}
211217
}
212218
}
213219
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.dsl.plans._
2323
import org.apache.spark.sql.catalyst.expressions._
24+
import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
2425
import org.apache.spark.sql.catalyst.plans.PlanTest
2526
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
2627
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -159,8 +160,9 @@ class ConstantPropagationSuite extends PlanTest {
159160
columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3)))
160161

161162
val correctAnswer = testRelation
162-
.select(columnA)
163-
.where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze
163+
.select(columnA, columnB)
164+
.where(FalseLiteral)
165+
.select(columnA).analyze
164166

165167
comparePlans(Optimize.execute(query.analyze), correctAnswer)
166168
}
@@ -186,4 +188,31 @@ class ConstantPropagationSuite extends PlanTest {
186188
.analyze
187189
comparePlans(Optimize.execute(query2), correctAnswer2)
188190
}
191+
192+
test("SPARK-42500: ConstantPropagation supports more cases") {
193+
comparePlans(
194+
Optimize.execute(testRelation.where(columnA === 1 && columnB > columnA + 2).analyze),
195+
testRelation.where(columnA === 1 && columnB > 3).analyze)
196+
197+
comparePlans(
198+
Optimize.execute(testRelation.where(columnA === 1 && columnA === 2).analyze),
199+
testRelation.where(FalseLiteral).analyze)
200+
201+
comparePlans(
202+
Optimize.execute(testRelation.where(columnA === 1 && columnA === columnA + 2).analyze),
203+
testRelation.where(FalseLiteral).analyze)
204+
205+
comparePlans(
206+
Optimize.execute(
207+
testRelation.where((columnA === 1 || columnB === 2) && columnB === 1).analyze),
208+
testRelation.where(columnA === 1 && columnB === 1).analyze)
209+
210+
comparePlans(
211+
Optimize.execute(testRelation.where(columnA === 1 && columnA === 1).analyze),
212+
testRelation.where(columnA === 1).analyze)
213+
214+
comparePlans(
215+
Optimize.execute(testRelation.where(Not(columnA === 1 && columnA === columnA + 2)).analyze),
216+
testRelation.where(Not(columnA === 1) || Not(columnA === columnA + 2)).analyze)
217+
}
189218
}

0 commit comments

Comments
 (0)