diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a28b6a0feb8f9..9d98739a95e96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -400,13 +400,24 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper // Push down deterministic projection through UNION ALL case p @ Project(projectList, Union(children)) => assert(children.nonEmpty) - if (projectList.forall(_.deterministic)) { - val newFirstChild = Project(projectList, children.head) + val (deterministicList, nonDeterministic) = projectList.partition(_.deterministic) + + if (deterministicList.nonEmpty) { + val newFirstChild = Project(deterministicList, children.head) val newOtherChildren = children.tail.map { child => val rewrites = buildRewrites(children.head, child) - Project(projectList.map(pushToRight(_, rewrites)), child) + Project(deterministicList.map(pushToRight(_, rewrites)), child) + } + val newUnion = Union(newFirstChild +: newOtherChildren) + if(nonDeterministic.nonEmpty) { + val newProjectList = projectList.collect { + case a: Alias if a.deterministic => a.toAttribute + case x => x + } + Project(newProjectList, newUnion) + } else { + newUnion } - Union(newFirstChild +: newOtherChildren) } else { p } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index aa8841109329c..63df91e496b21 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.{Literal, Rand} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -78,6 +78,58 @@ class SetOperationSuite extends PlanTest { comparePlans(unionOptimized, unionCorrectAnswer) } + test("SPARK-23356 union: expressions in project list are addition to each side") { + val unionQuery = testUnion.select(('a + 1).as("aa")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select(('a + 1).as("aa")) :: + testRelation2.select(('d + 1).as("aa")) :: + testRelation3.select(('g + 1).as("aa")) :: Nil).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-23356 union: expressions in project list are attribute addition to each side") { + val unionQuery = testUnion.select(('a + 'b).as("ab")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select(('a + 'b).as("ab")) :: + testRelation2.select(('d + 'e).as("ab")) :: + testRelation3.select(('g + 'h).as("ab")) :: Nil).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-23356 union: project to each side with non-deterministic expression") { + val unionQuery = testUnion.select('a, Rand(10).as("rnd")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select('a) :: + testRelation2.select('d) :: + testRelation3.select('g) :: Nil).select('a, Rand(10).as("rnd")).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-23356 union: project to each side with non-deterministic expression of alias") { + val unionQuery = testUnion.select('a.as("aa"), Rand(10).as("rnd")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select('a.as("aa")) :: + testRelation2.select('d.as("aa")) :: + testRelation3.select('g.as("aa")) :: Nil) + .select('aa, Rand(10).as("rnd")).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("SPARK-23356 union: with non-deterministic expression and addition expression") { + val unionQuery = testUnion.select(('a + 'b).as("ab"), Rand(10).as("rnd")) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Union(testRelation.select(('a + 'b).as("ab")) :: + testRelation2.select(('d + 'e).as("ab")) :: + testRelation3.select(('g + 'h).as("ab")) :: Nil) + .select('ab, Rand(10).as("rnd")).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + test("Remove unnecessary distincts in multiple unions") { val query1 = OneRowRelation() .select(Literal(1).as('a))