Skip to content

Commit a2d72ef

Browse files
committed
Address comments and merge conflicts
Signed-off-by: Karen Feng <karen.feng@databricks.com>
2 parents ce3ac0e + 1b553da commit a2d72ef

226 files changed

Lines changed: 19416 additions & 18834 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

docs/sql-migration-guide.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ license: |
7171

7272
- In Spark 3.2, the dates subtraction expression such as `date1 - date2` returns values of `DayTimeIntervalType`. In Spark 3.1 and earlier, the returned type is `CalendarIntervalType`. To restore the behavior before Spark 3.2, you can set `spark.sql.legacy.interval.enabled` to `true`.
7373

74+
- In Spark 3.2, the timestamps subtraction expression such as `timestamp '2021-03-31 23:48:00' - timestamp '2021-01-01 00:00:00'` returns values of `DayTimeIntervalType`. In Spark 3.1 and earlier, the type of the same expression is `CalendarIntervalType`. To restore the behavior before Spark 3.2, you can set `spark.sql.legacy.interval.enabled` to `true`.
75+
7476
## Upgrading from Spark SQL 3.0 to 3.1
7577

7678
- In Spark 3.1, statistical aggregation functions, including `std`, `stddev`, `stddev_samp`, `variance`, `var_samp`, `skewness`, `kurtosis`, `covar_samp`, and `corr`, will return `NULL` instead of `Double.NaN` when `DivideByZero` occurs during expression evaluation, for example, when `stddev_samp` is applied on a single-element set. In Spark version 3.0 and earlier, they will return `Double.NaN` in such a case. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.statisticalAggregate` to `true`.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 149 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ class Analyzer(override val catalogManager: CatalogManager)
253253
ResolveTables ::
254254
ResolvePartitionSpec ::
255255
AddMetadataColumns ::
256-
DeduplicateRelations ::
257256
ResolveReferences ::
258257
ResolveCreateNamedStruct ::
259258
ResolveDeserializer ::
@@ -1374,30 +1373,124 @@ class Analyzer(override val catalogManager: CatalogManager)
13741373
* a logical plan node's children.
13751374
*/
13761375
object ResolveReferences extends Rule[LogicalPlan] {
1376+
/**
1377+
* Generate a new logical plan for the right child with different expression IDs
1378+
* for all conflicting attributes.
1379+
*/
1380+
private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
1381+
val conflictingAttributes = left.outputSet.intersect(right.outputSet)
1382+
logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
1383+
s"between $left and $right")
1384+
1385+
/**
1386+
* For LogicalPlans like MultiInstanceRelation, Project, Aggregate, etc., whose output doesn't
1387+
* inherit directly from their children, we can just stop collecting on them, because we can
1388+
* always replace all the lower conflicting attributes with the new attributes from the new
1389+
* plan. Theoretically, we should collect recursively for Generate and Window, but we leave
1390+
* it to the next batch to reduce possible overhead, because this should be a corner case.
1391+
*/
1392+
def collectConflictPlans(plan: LogicalPlan): Seq[(LogicalPlan, LogicalPlan)] = plan match {
1393+
// Handle base relations that might appear more than once.
1394+
case oldVersion: MultiInstanceRelation
1395+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1396+
val newVersion = oldVersion.newInstance()
1397+
newVersion.copyTagsFrom(oldVersion)
1398+
Seq((oldVersion, newVersion))
1399+
1400+
case oldVersion: SerializeFromObject
1401+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1402+
Seq((oldVersion, oldVersion.copy(
1403+
serializer = oldVersion.serializer.map(_.newInstance()))))
1404+
1405+
// Handle projects that create conflicting aliases.
1406+
case oldVersion @ Project(projectList, _)
1407+
if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
1408+
Seq((oldVersion, oldVersion.copy(projectList = newAliases(projectList))))
1409+
1410+
// We don't need to search the child plan recursively if the projectList of a Project
1411+
// is composed only of Aliases and doesn't contain any conflicting attributes,
1412+
// because, even if the child plan has some conflicting attributes, those attributes
1413+
// will be aliased to non-conflicting attributes by the Project at the end.
1414+
case _ @ Project(projectList, _)
1415+
if findAliases(projectList).size == projectList.size =>
1416+
Nil
1417+
1418+
case oldVersion @ Aggregate(_, aggregateExpressions, _)
1419+
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
1420+
Seq((oldVersion, oldVersion.copy(
1421+
aggregateExpressions = newAliases(aggregateExpressions))))
1422+
1423+
// We don't search the child plan recursively for the same reason as the above Project.
1424+
case _ @ Aggregate(_, aggregateExpressions, _)
1425+
if findAliases(aggregateExpressions).size == aggregateExpressions.size =>
1426+
Nil
1427+
1428+
case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
1429+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1430+
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1431+
1432+
case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
1433+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1434+
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1435+
1436+
case oldVersion @ MapInPandas(_, output, _)
1437+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1438+
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1439+
1440+
case oldVersion: Generate
1441+
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
1442+
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
1443+
Seq((oldVersion, oldVersion.copy(generatorOutput = newOutput)))
1444+
1445+
case oldVersion: Expand
1446+
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
1447+
val producedAttributes = oldVersion.producedAttributes
1448+
val newOutput = oldVersion.output.map { attr =>
1449+
if (producedAttributes.contains(attr)) {
1450+
attr.newInstance()
1451+
} else {
1452+
attr
1453+
}
1454+
}
1455+
Seq((oldVersion, oldVersion.copy(output = newOutput)))
1456+
1457+
case oldVersion @ Window(windowExpressions, _, _, child)
1458+
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
1459+
.nonEmpty =>
1460+
Seq((oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))))
1461+
1462+
case oldVersion @ ScriptTransformation(_, _, output, _, _)
1463+
if AttributeSet(output).intersect(conflictingAttributes).nonEmpty =>
1464+
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1465+
1466+
case _ => plan.children.flatMap(collectConflictPlans)
1467+
}
1468+
1469+
val conflictPlans = collectConflictPlans(right)
13771470

1378-
/** Return true if there're conflicting attributes among children's outputs of a plan */
1379-
def hasConflictingAttrs(p: LogicalPlan): Boolean = {
1380-
p.children.length > 1 && {
1381-
// Note that duplicated attributes are allowed within a single node,
1382-
// e.g., df.select($"a", $"a"), so we should only check conflicting
1383-
// attributes between nodes.
1384-
val uniqueAttrs = mutable.HashSet[ExprId]()
1385-
p.children.head.outputSet.foreach(a => uniqueAttrs.add(a.exprId))
1386-
p.children.tail.exists { child =>
1387-
val uniqueSize = uniqueAttrs.size
1388-
val childSize = child.outputSet.size
1389-
child.outputSet.foreach(a => uniqueAttrs.add(a.exprId))
1390-
uniqueSize + childSize > uniqueAttrs.size
1471+
/*
1472+
* Note that it's possible `conflictPlans` can be empty which implies that there
1473+
* is a logical plan node that produces new references that this rule cannot handle.
1474+
* When that is the case, there must be another rule that resolves these conflicts.
1475+
* Otherwise, the analysis will fail.
1476+
*/
1477+
if (conflictPlans.isEmpty) {
1478+
right
1479+
} else {
1480+
val planMapping = conflictPlans.toMap
1481+
right.transformUpWithNewOutput {
1482+
case oldPlan =>
1483+
val newPlanOpt = planMapping.get(oldPlan)
1484+
newPlanOpt.map { newPlan =>
1485+
newPlan -> oldPlan.output.zip(newPlan.output)
1486+
}.getOrElse(oldPlan -> Nil)
13911487
}
13921488
}
13931489
}
13941490

13951491
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
13961492
case p: LogicalPlan if !p.childrenResolved => p
13971493

1398-
// Wait for the rule `DeduplicateRelations` to resolve conflicting attrs first.
1399-
case p: LogicalPlan if hasConflictingAttrs(p) => p
1400-
14011494
// If the projection list contains Stars, expand it.
14021495
case p: Project if containsStar(p.projectList) =>
14031496
p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
@@ -1419,12 +1512,37 @@ class Analyzer(override val catalogManager: CatalogManager)
14191512
case g: Generate if containsStar(g.generator.children) =>
14201513
throw QueryCompilationErrors.invalidStarUsageError("explode/json_tuple/UDTF")
14211514

1515+
// To resolve duplicate expression IDs for Join and Intersect
1516+
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
1517+
j.copy(right = dedupRight(left, right))
14221518
case f @ FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, _, _, left, right) =>
14231519
val leftRes = leftAttributes
14241520
.map(x => resolveExpressionByPlanOutput(x, left).asInstanceOf[Attribute])
14251521
val rightRes = rightAttributes
14261522
.map(x => resolveExpressionByPlanOutput(x, right).asInstanceOf[Attribute])
14271523
f.copy(leftAttributes = leftRes, rightAttributes = rightRes)
1524+
// Intersect/Except will be rewritten to Join at the beginning of the optimizer. Here we need to
1525+
// deduplicate the right side plan, so that we won't produce an invalid self-join later.
1526+
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
1527+
i.copy(right = dedupRight(left, right))
1528+
case e @ Except(left, right, _) if !e.duplicateResolved =>
1529+
e.copy(right = dedupRight(left, right))
1530+
// Only after we finish by-name resolution for Union
1531+
case u: Union if !u.byName && !u.duplicateResolved =>
1532+
// Use projection-based de-duplication for Union to avoid breaking the checkpoint sharing
1533+
// feature in streaming.
1534+
val newChildren = u.children.foldRight(Seq.empty[LogicalPlan]) { (head, tail) =>
1535+
head +: tail.map {
1536+
case child if head.outputSet.intersect(child.outputSet).isEmpty =>
1537+
child
1538+
case child =>
1539+
val projectList = child.output.map { attr =>
1540+
Alias(attr, attr.name)()
1541+
}
1542+
Project(projectList, child)
1543+
}
1544+
}
1545+
u.copy(children = newChildren)
14281546

14291547
// When resolve `SortOrder`s in Sort based on child, don't report errors as
14301548
// we still have chance to resolve it based on its descendants
@@ -1478,6 +1596,9 @@ class Analyzer(override val catalogManager: CatalogManager)
14781596
// implementation and should be resolved based on the table schema.
14791597
o.copy(deleteExpr = resolveExpressionByPlanOutput(o.deleteExpr, o.table))
14801598

1599+
case m @ MergeIntoTable(targetTable, sourceTable, _, _, _) if !m.duplicateResolved =>
1600+
m.copy(sourceTable = dedupRight(targetTable, sourceTable))
1601+
14811602
case m @ MergeIntoTable(targetTable, sourceTable, _, _, _)
14821603
if !m.resolved && targetTable.resolved && sourceTable.resolved =>
14831604

@@ -1558,6 +1679,17 @@ class Analyzer(override val catalogManager: CatalogManager)
15581679
}
15591680
}
15601681

1682+
def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
1683+
expressions.map {
1684+
case a: Alias => Alias(a.child, a.name)()
1685+
case other => other
1686+
}
1687+
}
1688+
1689+
def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
1690+
AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
1691+
}
1692+
15611693
// This method is used to trim groupByExpressions/selectedGroupByExpressions's top-level
15621694
// GetStructField Alias. Since these expression are not NamedExpression originally,
15631695
// we are safe to trim top-level GetStructField Alias.

0 commit comments

Comments
 (0)