@@ -253,7 +253,6 @@ class Analyzer(override val catalogManager: CatalogManager)
253253 ResolveTables ::
254254 ResolvePartitionSpec ::
255255 AddMetadataColumns ::
256- DeduplicateRelations ::
257256 ResolveReferences ::
258257 ResolveCreateNamedStruct ::
259258 ResolveDeserializer ::
@@ -1374,30 +1373,124 @@ class Analyzer(override val catalogManager: CatalogManager)
13741373 * a logical plan node's children.
13751374 */
13761375 object ResolveReferences extends Rule [LogicalPlan ] {
1376+ /**
1377+ * Generates a new logical plan for the right child with different expression IDs for all
1378+ * attributes that conflict with the output of left; returns right as-is if none is collected.
1379+ */
1380+ private def dedupRight (left : LogicalPlan , right : LogicalPlan ): LogicalPlan = {
1381+ val conflictingAttributes = left.outputSet.intersect(right.outputSet)
1382+ logDebug(s " Conflicting attributes ${conflictingAttributes.mkString(" ," )} " +
1383+ s " between $left and $right" )
1384+
1385+ /**
1386+ * For LogicalPlans like MultiInstanceRelation, Project, Aggregate, etc., whose output doesn't
1387+ * inherit directly from their children, we can stop collecting at that node, because we can
1388+ * always replace all the lower conflicting attributes with the new attributes from the new
1389+ * plan. Theoretically, we should collect recursively for Generate and Window, but we leave
1390+ * that to the next batch to reduce possible overhead because this should be a corner case.
1391+ */
1392+ def collectConflictPlans (plan : LogicalPlan ): Seq [(LogicalPlan , LogicalPlan )] = plan match {
1393+ // Handle base relations that might appear more than once.
1394+ case oldVersion : MultiInstanceRelation
1395+ if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1396+ val newVersion = oldVersion.newInstance()
1397+ newVersion.copyTagsFrom(oldVersion)
1398+ Seq ((oldVersion, newVersion))
1399+
1400+ case oldVersion : SerializeFromObject
1401+ if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1402+ Seq ((oldVersion, oldVersion.copy(
1403+ serializer = oldVersion.serializer.map(_.newInstance()))))
1404+
1405+ // Handle projects that create conflicting aliases.
1406+ case oldVersion @ Project (projectList, _)
1407+ if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
1408+ Seq ((oldVersion, oldVersion.copy(projectList = newAliases(projectList))))
1409+
1410+ // We don't need to search the child plan recursively if the projectList of a Project
1411+ // is composed only of Aliases and doesn't contain any conflicting attributes,
1412+ // because even if the child plan has some conflicting attributes, those attributes
1413+ // will be aliased to non-conflicting attributes by the Project at the end.
1414+ case _ @ Project (projectList, _)
1415+ if findAliases(projectList).size == projectList.size =>
1416+ Nil
1417+
1418+ case oldVersion @ Aggregate (_, aggregateExpressions, _)
1419+ if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
1420+ Seq ((oldVersion, oldVersion.copy(
1421+ aggregateExpressions = newAliases(aggregateExpressions))))
1422+
1423+ // We don't search the child plan recursively, for the same reason as for the Project above.
1424+ case _ @ Aggregate (_, aggregateExpressions, _)
1425+ if findAliases(aggregateExpressions).size == aggregateExpressions.size =>
1426+ Nil
1427+
1428+ case oldVersion @ FlatMapGroupsInPandas (_, _, output, _)
1429+ if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1430+ Seq ((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1431+
1432+ case oldVersion @ FlatMapCoGroupsInPandas (_, _, _, output, _, _)
1433+ if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1434+ Seq ((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1435+
1436+ case oldVersion @ MapInPandas (_, output, _)
1437+ if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1438+ Seq ((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1439+
1440+ case oldVersion : Generate
1441+ if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
1442+ val newOutput = oldVersion.generatorOutput.map(_.newInstance())
1443+ Seq ((oldVersion, oldVersion.copy(generatorOutput = newOutput)))
1444+
1445+ case oldVersion : Expand
1446+ if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
1447+ val producedAttributes = oldVersion.producedAttributes
1448+ val newOutput = oldVersion.output.map { attr =>
1449+ if (producedAttributes.contains(attr)) {
1450+ attr.newInstance()
1451+ } else {
1452+ attr
1453+ }
1454+ }
1455+ Seq ((oldVersion, oldVersion.copy(output = newOutput)))
1456+
1457+ case oldVersion @ Window (windowExpressions, _, _, child)
1458+ if AttributeSet (windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
1459+ .nonEmpty =>
1460+ Seq ((oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))))
1461+
1462+ case oldVersion @ ScriptTransformation (_, _, output, _, _)
1463+ if AttributeSet (output).intersect(conflictingAttributes).nonEmpty =>
1464+ Seq ((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1465+
1466+ case _ => plan.children.flatMap(collectConflictPlans)
1467+ }
1468+
1469+ val conflictPlans = collectConflictPlans(right)
13771470
1378- /** Return true if there're conflicting attributes among children's outputs of a plan */
1379- def hasConflictingAttrs (p : LogicalPlan ): Boolean = {
1380- p.children.length > 1 && {
1381- // Note that duplicated attributes are allowed within a single node,
1382- // e.g., df.select($"a", $"a"), so we should only check conflicting
1383- // attributes between nodes.
1384- val uniqueAttrs = mutable.HashSet [ExprId ]()
1385- p.children.head.outputSet.foreach(a => uniqueAttrs.add(a.exprId))
1386- p.children.tail.exists { child =>
1387- val uniqueSize = uniqueAttrs.size
1388- val childSize = child.outputSet.size
1389- child.outputSet.foreach(a => uniqueAttrs.add(a.exprId))
1390- uniqueSize + childSize > uniqueAttrs.size
1471+ /*
1472+ * Note that `conflictPlans` can be empty, which implies that there
1473+ * is a logical plan node that produces new references that this rule cannot handle.
1474+ * When that is the case, there must be another rule that resolves these conflicts.
1475+ * Otherwise, the analysis will fail.
1476+ */
1477+ if (conflictPlans.isEmpty) {
1478+ right
1479+ } else {
1480+ val planMapping = conflictPlans.toMap
1481+ right.transformUpWithNewOutput {
1482+ case oldPlan =>
1483+ val newPlanOpt = planMapping.get(oldPlan)
1484+ newPlanOpt.map { newPlan =>
1485+ newPlan -> oldPlan.output.zip(newPlan.output)
1486+ }.getOrElse(oldPlan -> Nil )
13911487 }
13921488 }
13931489 }
13941490
13951491 def apply (plan : LogicalPlan ): LogicalPlan = plan.resolveOperatorsUp {
13961492 case p : LogicalPlan if ! p.childrenResolved => p
13971493
1398- // Wait for the rule `DeduplicateRelations` to resolve conflicting attrs first.
1399- case p : LogicalPlan if hasConflictingAttrs(p) => p
1400-
14011494 // If the projection list contains Stars, expand it.
14021495 case p : Project if containsStar(p.projectList) =>
14031496 p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
@@ -1419,12 +1512,37 @@ class Analyzer(override val catalogManager: CatalogManager)
14191512 case g : Generate if containsStar(g.generator.children) =>
14201513 throw QueryCompilationErrors .invalidStarUsageError(" explode/json_tuple/UDTF" )
14211514
1515+ // To resolve duplicate expression IDs for Join and Intersect
1516+ case j @ Join (left, right, _, _, _) if ! j.duplicateResolved =>
1517+ j.copy(right = dedupRight(left, right))
14221518 case f @ FlatMapCoGroupsInPandas (leftAttributes, rightAttributes, _, _, left, right) =>
14231519 val leftRes = leftAttributes
14241520 .map(x => resolveExpressionByPlanOutput(x, left).asInstanceOf [Attribute ])
14251521 val rightRes = rightAttributes
14261522 .map(x => resolveExpressionByPlanOutput(x, right).asInstanceOf [Attribute ])
14271523 f.copy(leftAttributes = leftRes, rightAttributes = rightRes)
1524+ // intersect/except will be rewritten to join at the beginning of optimizer. Here we need to
1525+ // deduplicate the right side plan, so that we won't produce an invalid self-join later.
1526+ case i @ Intersect (left, right, _) if ! i.duplicateResolved =>
1527+ i.copy(right = dedupRight(left, right))
1528+ case e @ Except (left, right, _) if ! e.duplicateResolved =>
1529+ e.copy(right = dedupRight(left, right))
1530+ // Only after we finish by-name resolution for Union
1531+ case u : Union if ! u.byName && ! u.duplicateResolved =>
1532+ // Use projection-based de-duplication for Union to avoid breaking the checkpoint sharing
1533+ // feature in streaming.
1534+ val newChildren = u.children.foldRight(Seq .empty[LogicalPlan ]) { (head, tail) =>
1535+ head +: tail.map {
1536+ case child if head.outputSet.intersect(child.outputSet).isEmpty =>
1537+ child
1538+ case child =>
1539+ val projectList = child.output.map { attr =>
1540+ Alias (attr, attr.name)()
1541+ }
1542+ Project (projectList, child)
1543+ }
1544+ }
1545+ u.copy(children = newChildren)
14281546
14291547 // When resolve `SortOrder`s in Sort based on child, don't report errors as
14301548 // we still have chance to resolve it based on its descendants
@@ -1478,6 +1596,9 @@ class Analyzer(override val catalogManager: CatalogManager)
14781596 // implementation and should be resolved based on the table schema.
14791597 o.copy(deleteExpr = resolveExpressionByPlanOutput(o.deleteExpr, o.table))
14801598
1599+ case m @ MergeIntoTable (targetTable, sourceTable, _, _, _) if ! m.duplicateResolved =>
1600+ m.copy(sourceTable = dedupRight(targetTable, sourceTable))
1601+
14811602 case m @ MergeIntoTable (targetTable, sourceTable, _, _, _)
14821603 if ! m.resolved && targetTable.resolved && sourceTable.resolved =>
14831604
@@ -1558,6 +1679,17 @@ class Analyzer(override val catalogManager: CatalogManager)
15581679 }
15591680 }
15601681
// Rebuilds every Alias in the given list from its child and name only, leaving the second
// parameter list at its defaults; any non-Alias expression is returned unchanged.
// NOTE(review): presumably the defaults assign a new expression ID, which is what the
// deduplication of aliased outputs relies on -- confirm against the Alias definition.
1682+ def newAliases (expressions : Seq [NamedExpression ]): Seq [NamedExpression ] = {
1683+ expressions.map {
1684+ case a : Alias => Alias (a.child, a.name)()
1685+ case other => other
1686+ }
1687+ }
1688+
// Collects the output attribute (toAttribute) of every Alias in the given list into an
// AttributeSet; entries that are not Aliases are ignored.
1689+ def findAliases (projectList : Seq [NamedExpression ]): AttributeSet = {
1690+ AttributeSet (projectList.collect { case a : Alias => a.toAttribute })
1691+ }
1692+
15611693 // This method is used to trim groupByExpressions/selectedGroupByExpressions's top-level
15621694 // GetStructField Alias. Since these expressions are not NamedExpressions originally,
15631695 // we are safe to trim top-level GetStructField Alias.
0 commit comments