
Commit 2340afe

review

1 parent b0b5531 commit 2340afe

3 files changed: +39 -58 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 27 additions & 51 deletions

@@ -326,48 +326,29 @@ object TypeCoercion {
    *
    * This rule is only applied to Union/Except/Intersect
    */
-  object WidenSetOperationTypes extends Rule[LogicalPlan] {
-
-    def apply(plan: LogicalPlan): LogicalPlan = {
-      val exprIdMapArray = mutable.ArrayBuffer[(ExprId, Attribute)]()
-      val newPlan = plan resolveOperatorsUp {
-        case s @ Except(left, right, isAll) if s.childrenResolved &&
-            left.output.length == right.output.length && !s.resolved =>
-          val (newChildren, newExprIds) = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-          exprIdMapArray ++= newExprIds
-          assert(newChildren.length == 2)
-          Except(newChildren.head, newChildren.last, isAll)
-
-        case s @ Intersect(left, right, isAll) if s.childrenResolved &&
-            left.output.length == right.output.length && !s.resolved =>
-          val (newChildren, newExprIds) = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-          exprIdMapArray ++= newExprIds
-          assert(newChildren.length == 2)
-          Intersect(newChildren.head, newChildren.last, isAll)
-
-        case s: Union if s.childrenResolved && !s.byName &&
+  object WidenSetOperationTypes extends TypeCoercionRule {
+
+    override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
+      case s @ Except(left, right, isAll) if s.childrenResolved &&
+          left.output.length == right.output.length && !s.resolved =>
+        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+        assert(newChildren.length == 2)
+        Except(newChildren.head, newChildren.last, isAll)
+
+      case s @ Intersect(left, right, isAll) if s.childrenResolved &&
+          left.output.length == right.output.length && !s.resolved =>
+        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+        assert(newChildren.length == 2)
+        Intersect(newChildren.head, newChildren.last, isAll)
+
+      case s: Union if s.childrenResolved && !s.byName &&
          s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
-          val (newChildren, newExprIds) = buildNewChildrenWithWiderTypes(s.children)
-          exprIdMapArray ++= newExprIds
-          s.copy(children = newChildren)
-      }
-
-      // Re-maps existing references to the new ones (exprId and dataType)
-      // for aliases added when widening columns' data types.
-      val exprIdMap = exprIdMapArray.toMap
-      newPlan resolveOperatorsUp {
-        case p if p.childrenResolved && p.missingInput.nonEmpty =>
-          p.mapExpressions { _.transform {
-            case a: AttributeReference if p.missingInput.contains(a) &&
-                exprIdMap.contains(a.exprId) => exprIdMap(a.exprId)
-          }
-        }
-      }
+        val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
+        s.copy(children = newChildren)
     }

     /** Build new children with the widest types for each attribute among all the children */
-    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan])
-      : (Seq[LogicalPlan], Seq[(ExprId, Attribute)]) = {
+    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
       require(children.forall(_.output.length == children.head.output.length))

       // Get a sequence of data types, each of which is the widest type of this specific attribute

@@ -377,11 +358,10 @@ object TypeCoercion {

       if (targetTypes.nonEmpty) {
         // Add an extra Project if the targetTypes are different from the original types.
-        val (newChildren, newExprIds) = children.map(widenTypes(_, targetTypes)).unzip
-        (newChildren, newExprIds.flatten)
+        children.map(widenTypes(_, targetTypes))
       } else {
         // Unable to find a target type to widen, then just return the original set.
-        (children, Nil)
+        children
       }
     }

@@ -405,16 +385,12 @@ object TypeCoercion {
     }

     /** Given a plan, add an extra project on top to widen some columns' data types. */
-    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType])
-      : (LogicalPlan, Seq[(ExprId, Attribute)]) = {
-      val (casted, newExprIds) = plan.output.zip(targetTypes).map {
-        case (e, dt) if e.dataType != dt =>
-          val alias = Alias(Cast(e, dt), e.name)()
-          (alias, Some(e.exprId -> alias.toAttribute))
-        case (e, _) =>
-          (e, None)
-      }.unzip
-      (Project(casted, plan), newExprIds.flatten)
+    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
+      val casted = plan.output.zip(targetTypes).map {
+        case (e, dt) if e.dataType != dt => Alias(Cast(e, dt), e.name)(exprId = e.exprId)
+        case (e, _) => e
+      }
+      Project(casted, plan)
     }
   }
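The substantive change is in widenTypes: the widening Alias now reuses the child attribute's exprId via Alias(Cast(e, dt), e.name)(exprId = e.exprId) instead of minting a fresh one, so references above the set operation stay valid by construction. That is what makes the whole exprId-remapping pass (exprIdMapArray and the second resolveOperatorsUp) unnecessary and lets the rule shrink to a plain TypeCoercionRule. To see the widening itself, here is a minimal sketch, assuming a spark-shell session with spark.implicits._ in scope; the names intSide and doubleSide are illustrative:

// UNION of an IntegerType column with a DoubleType column: the rule
// inserts a Project with a Cast above the narrower side.
val intSide = Seq(1, 2).toDF("v")    // v: int
val doubleSide = Seq(3.5).toDF("v")  // v: double

val unioned = intSide.union(doubleSide)
unioned.printSchema()  // v: double -- the integer column was widened
unioned.explain(true)  // analyzed plan shows Project [cast(v as double) AS v]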

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 1 deletion

@@ -477,7 +477,8 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
 object RemoveNoopOperators extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // Eliminate no-op Projects
-    case p @ Project(_, child) if child.sameOutput(p) => child
+    case Project(projList, child) if projList.length == child.output.length &&
+      projList.zip(child.output).forall { case (e1, e2) => e1.semanticEquals(e2) } => child

     // Eliminate no-op Window
     case w: Window if w.windowExpressions.isEmpty => w.child
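Why the stricter guard: once WidenSetOperationTypes reuses child exprIds, a Project can expose exactly the same exprIds as its child while still doing real work (a Cast). The old check, child.sameOutput(p), looks only at the output attributes the Project exposes; the new one compares every project-list expression against the corresponding child attribute with semanticEquals, position by position. A hedged sketch of the distinction, built directly against the Catalyst expression classes:

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast}
import org.apache.spark.sql.types.{DoubleType, IntegerType}

val a = AttributeReference("a", IntegerType)()
// A widening alias as now produced by WidenSetOperationTypes: it reuses a's exprId.
val widened = Alias(Cast(a, DoubleType), "a")(exprId = a.exprId)

widened.toAttribute.exprId == a.exprId  // true: exprId-level identity is preserved
widened.semanticEquals(a)               // false: the expression level still sees the Cast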

sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala

Lines changed: 10 additions & 6 deletions

@@ -17,12 +17,11 @@

 package org.apache.spark.sql.execution

-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, PartialMerge}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase
-import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.internal.SQLConf

 /**

@@ -85,14 +84,19 @@ case class RemoveRedundantProjects(conf: SQLConf) extends Rule[SparkPlan] {
       // to convert the rows to UnsafeRow. See DataSourceV2Strategy for more details.
       case d: DataSourceV2ScanExecBase if !d.supportsColumnar => false
       case _ =>
+        def semanticEquals(exprs1: Seq[Expression], exprs2: Seq[Expression]): Boolean = {
+          exprs1.length == exprs2.length && exprs1.zip(exprs2).forall {
+            case (e1, e2) => e1.semanticEquals(e2)
+          }
+        }
         if (requireOrdering) {
-          project.output.map(_.exprId.id) == child.output.map(_.exprId.id) &&
+          semanticEquals(project.projectList, child.output) &&
            checkNullability(project.output, child.output)
         } else {
-          val orderedProjectOutput = project.output.sortBy(_.exprId.id)
+          val orderedProjectList = project.projectList.sortBy(_.exprId.id)
           val orderedChildOutput = child.output.sortBy(_.exprId.id)
-          orderedProjectOutput.map(_.exprId.id) == orderedChildOutput.map(_.exprId.id) &&
-            checkNullability(orderedProjectOutput, orderedChildOutput)
+          semanticEquals(orderedProjectList, orderedChildOutput) &&
+            checkNullability(orderedProjectList.map(_.toAttribute), orderedChildOutput)
         }
     }
   }
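This is the physical-plan counterpart of the same hardening. The old test matched projects by exprId alone, and a casting ProjectExec that reuses its child's exprIds would now pass that test and be removed, losing the Cast; comparing the project list to the child output semantically keeps it. A small sketch of the false positive the old check would admit (projectList and childOutput are illustrative stand-ins for project.projectList and child.output):

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression}
import org.apache.spark.sql.types.{DoubleType, IntegerType}

// The helper, as defined in the patch.
def semanticEquals(exprs1: Seq[Expression], exprs2: Seq[Expression]): Boolean = {
  exprs1.length == exprs2.length && exprs1.zip(exprs2).forall {
    case (e1, e2) => e1.semanticEquals(e2)
  }
}

val a = AttributeReference("a", IntegerType)()
val childOutput = Seq(a)
val projectList = Seq(Alias(Cast(a, DoubleType), "a")(exprId = a.exprId))

// Old check: the exprIds line up, so the casting project looked redundant.
projectList.map(_.exprId.id) == childOutput.map(_.exprId.id)  // true (false positive)
// New check: semantic comparison sees the Cast and keeps the project.
semanticEquals(projectList, childOutput)                      // false (correct)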
