@@ -326,48 +326,29 @@ object TypeCoercion {
326326 *
327327 * This rule is only applied to Union/Except/Intersect
328328 */
329- object WidenSetOperationTypes extends Rule [LogicalPlan ] {
330-
331- def apply (plan : LogicalPlan ): LogicalPlan = {
332- val exprIdMapArray = mutable.ArrayBuffer [(ExprId , Attribute )]()
333- val newPlan = plan resolveOperatorsUp {
334- case s @ Except (left, right, isAll) if s.childrenResolved &&
335- left.output.length == right.output.length && ! s.resolved =>
336- val (newChildren, newExprIds) = buildNewChildrenWithWiderTypes(left :: right :: Nil )
337- exprIdMapArray ++= newExprIds
338- assert(newChildren.length == 2 )
339- Except (newChildren.head, newChildren.last, isAll)
340-
341- case s @ Intersect (left, right, isAll) if s.childrenResolved &&
342- left.output.length == right.output.length && ! s.resolved =>
343- val (newChildren, newExprIds) = buildNewChildrenWithWiderTypes(left :: right :: Nil )
344- exprIdMapArray ++= newExprIds
345- assert(newChildren.length == 2 )
346- Intersect (newChildren.head, newChildren.last, isAll)
347-
348- case s : Union if s.childrenResolved && ! s.byName &&
329+ object WidenSetOperationTypes extends TypeCoercionRule {
330+
331+ override def coerceTypes (plan : LogicalPlan ): LogicalPlan = plan resolveOperatorsUp {
332+ case s @ Except (left, right, isAll) if s.childrenResolved &&
333+ left.output.length == right.output.length && ! s.resolved =>
334+ val newChildren : Seq [LogicalPlan ] = buildNewChildrenWithWiderTypes(left :: right :: Nil )
335+ assert(newChildren.length == 2 )
336+ Except (newChildren.head, newChildren.last, isAll)
337+
338+ case s @ Intersect (left, right, isAll) if s.childrenResolved &&
339+ left.output.length == right.output.length && ! s.resolved =>
340+ val newChildren : Seq [LogicalPlan ] = buildNewChildrenWithWiderTypes(left :: right :: Nil )
341+ assert(newChildren.length == 2 )
342+ Intersect (newChildren.head, newChildren.last, isAll)
343+
344+ case s : Union if s.childrenResolved && ! s.byName &&
349345 s.children.forall(_.output.length == s.children.head.output.length) && ! s.resolved =>
350- val (newChildren, newExprIds) = buildNewChildrenWithWiderTypes(s.children)
351- exprIdMapArray ++= newExprIds
352- s.copy(children = newChildren)
353- }
354-
355- // Re-maps existing references to the new ones (exprId and dataType)
356- // for aliases added when widening columns' data types.
357- val exprIdMap = exprIdMapArray.toMap
358- newPlan resolveOperatorsUp {
359- case p if p.childrenResolved && p.missingInput.nonEmpty =>
360- p.mapExpressions { _.transform {
361- case a : AttributeReference if p.missingInput.contains(a) &&
362- exprIdMap.contains(a.exprId) => exprIdMap(a.exprId)
363- }
364- }
365- }
346+ val newChildren : Seq [LogicalPlan ] = buildNewChildrenWithWiderTypes(s.children)
347+ s.copy(children = newChildren)
366348 }
367349
368350 /** Build new children with the widest types for each attribute among all the children */
369- private def buildNewChildrenWithWiderTypes (children : Seq [LogicalPlan ])
370- : (Seq [LogicalPlan ], Seq [(ExprId , Attribute )]) = {
351+ private def buildNewChildrenWithWiderTypes (children : Seq [LogicalPlan ]): Seq [LogicalPlan ] = {
371352 require(children.forall(_.output.length == children.head.output.length))
372353
373354 // Get a sequence of data types, each of which is the widest type of this specific attribute
@@ -377,11 +358,10 @@ object TypeCoercion {
377358
378359 if (targetTypes.nonEmpty) {
379360 // Add an extra Project if the targetTypes are different from the original types.
380- val (newChildren, newExprIds) = children.map(widenTypes(_, targetTypes)).unzip
381- (newChildren, newExprIds.flatten)
361+ children.map(widenTypes(_, targetTypes))
382362 } else {
383363 // Unable to find a target type to widen, then just return the original set.
384- ( children, Nil )
364+ children
385365 }
386366 }
387367
@@ -405,16 +385,12 @@ object TypeCoercion {
405385 }
406386
407387 /** Given a plan, add an extra project on top to widen some columns' data types. */
408- private def widenTypes (plan : LogicalPlan , targetTypes : Seq [DataType ])
409- : (LogicalPlan , Seq [(ExprId , Attribute )]) = {
410- val (casted, newExprIds) = plan.output.zip(targetTypes).map {
411- case (e, dt) if e.dataType != dt =>
412- val alias = Alias (Cast (e, dt), e.name)()
413- (alias, Some (e.exprId -> alias.toAttribute))
414- case (e, _) =>
415- (e, None )
416- }.unzip
417- (Project (casted, plan), newExprIds.flatten)
388+ private def widenTypes (plan : LogicalPlan , targetTypes : Seq [DataType ]): LogicalPlan = {
389+ val casted = plan.output.zip(targetTypes).map {
390+ case (e, dt) if e.dataType != dt => Alias (Cast (e, dt), e.name)(exprId = e.exprId)
391+ case (e, _) => e
392+ }
393+ Project (casted, plan)
418394 }
419395 }
420396
0 commit comments