AliasAwareOutputExpression.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution

-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}

/**
@@ -25,20 +25,14 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
trait AliasAwareOutputExpression extends UnaryExecNode {
  protected def outputExpressions: Seq[NamedExpression]

-  protected def hasAlias: Boolean = outputExpressions.collectFirst { case _: Alias => }.isDefined
+  lazy val aliasMap = AttributeMap(outputExpressions.collect {
+    case a @ Alias(child: AttributeReference, _) => (child, a.toAttribute)
+  })

-  protected def replaceAliases(exprs: Seq[Expression]): Seq[Expression] = {
-    exprs.map {
-      case a: AttributeReference => replaceAlias(a).getOrElse(a)
-      case other => other
-    }
-  }
+  protected def hasAlias: Boolean = aliasMap.nonEmpty

  protected def replaceAlias(attr: AttributeReference): Option[Attribute] = {
-    outputExpressions.collectFirst {
-      case a @ Alias(child: AttributeReference, _) if child.semanticEquals(attr) =>
-        a.toAttribute
-    }
+    aliasMap.get(attr)
  }
}

@@ -48,13 +42,13 @@ trait AliasAwareOutputExpression
 */
trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
  final override def outputPartitioning: Partitioning = {
-    if (hasAlias) {
-      child.outputPartitioning match {
-        case h: HashPartitioning => h.copy(expressions = replaceAliases(h.expressions))
-        case other => other
-      }
-    } else {
-      child.outputPartitioning
+    child.outputPartitioning match {
+      case e: Expression if hasAlias =>
+        val normalizedExp = e.transformDown {
+          case attr: AttributeReference => replaceAlias(attr).getOrElse(attr)
+        }
+        normalizedExp.asInstanceOf[Partitioning]
+      case other => other
    }
  }
}
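In short, the change above does two things: the alias lookup now goes through an AttributeMap built once (AttributeMap keys attributes by exprId, which is what the old semanticEquals scan effectively matched, so the behavior is preserved while hasAlias and replaceAlias become constant-time), and the outputPartitioning rewrite now applies transformDown to any child partitioning that is itself an Expression (for example a PartitioningCollection coming out of a join, not only a HashPartitioning), so aliases are substituted at any depth of the partitioning expressions. A minimal standalone sketch of the transformDown idea (assumption: Expr, Attr, Add, and the alias map below are toy stand-ins for illustration, not Spark's Catalyst classes):

// Toy stand-ins for Catalyst expressions (simplified model for illustration
// only; Spark's real classes live in catalyst.expressions).
sealed trait Expr {
  def children: Seq[Expr]
  def withChildren(cs: Seq[Expr]): Expr
  // Pre-order rewrite, analogous to Catalyst's transformDown.
  def transformDown(rule: PartialFunction[Expr, Expr]): Expr = {
    val applied = rule.applyOrElse(this, identity[Expr])
    applied.withChildren(applied.children.map(_.transformDown(rule)))
  }
}
final case class Attr(name: String) extends Expr {
  def children: Seq[Expr] = Nil
  def withChildren(cs: Seq[Expr]): Expr = this
}
final case class Add(left: Expr, right: Expr) extends Expr {
  def children: Seq[Expr] = Seq(left, right)
  def withChildren(cs: Seq[Expr]): Expr = Add(cs(0), cs(1))
}

object Demo extends App {
  // The Project renamed t1.id to t1id; the alias map records child -> output attr.
  val aliasMap: Map[Expr, Expr] = Map(Attr("t1.id") -> Attr("t1id"))

  // A partitioning expression with the attribute nested inside a function.
  val partExprs: Seq[Expr] = Seq(Add(Attr("t1.id"), Attr("t2.id")))

  // Old approach: rewrite only top-level attributes -- the nested one is missed.
  val topLevelOnly = partExprs.map {
    case a: Attr => aliasMap.getOrElse(a, a)
    case other   => other
  }
  println(topLevelOnly) // List(Add(Attr(t1.id),Attr(t2.id))) -- unchanged

  // New approach: transformDown reaches attributes at any depth.
  val normalized = partExprs.map(_.transformDown {
    case a: Attr if aliasMap.contains(a) => aliasMap(a)
  })
  println(normalized) // List(Add(Attr(t1id),Attr(t2.id)))
}

With the old top-level match, an attribute that survived a Project only under a new name kept its pre-Project identity in the reported partitioning, so EnsureRequirements could insert an avoidable exchange above the Project; the tests below exercise exactly those plan shapes.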
PlannerSuite.scala
@@ -895,6 +895,73 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
    }
  }

test("No extra exchanges in case of [Inner Join -> Project with aliases -> Inner join]") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
withTempView("t1", "t2", "t3") {
spark.range(10).repartition($"id").createTempView("t1")
spark.range(20).repartition($"id").createTempView("t2")
spark.range(30).repartition($"id").createTempView("t3")
val planned = sql(
"""
|SELECT t2id, t3.id as t3id
|FROM (
| SELECT t1.id as t1id, t2.id as t2id
| FROM t1, t2
| WHERE t1.id = t2.id
|) t12, t3
|WHERE t1id = t3.id
""".stripMargin).queryExecution.executedPlan
val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
assert(exchanges.size == 3)
}
}
}
}

test("No extra exchanges in case of [LeftSemi Join -> Project with aliases -> Inner join]") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("t1", "t2", "t3") {
spark.range(10).repartition($"id").createTempView("t1")
spark.range(20).repartition($"id").createTempView("t2")
spark.range(30).repartition($"id").createTempView("t3")
val planned = sql(
"""
|SELECT t1id, t3.id as t3id
|FROM (
| SELECT t1.id as t1id
| FROM t1 LEFT SEMI JOIN t2
| ON t1.id = t2.id
|) t12 INNER JOIN t3
|WHERE t1id = t3.id
""".stripMargin).queryExecution.executedPlan
val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
assert(exchanges.size == 3)
}
}
}

test("No extra exchanges in case of [Inner Join -> Project with aliases -> HashAggregate]") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("t1", "t2") {
spark.range(10).repartition($"id").createTempView("t1")
spark.range(20).repartition($"id").createTempView("t2")
val planned = sql(
"""
|SELECT t1id, t2id
|FROM (
| SELECT t1.id as t1id, t2.id as t2id
| FROM t1 INNER JOIN t2
| WHERE t1.id = t2.id
|) t12
|GROUP BY t1id, t2id
""".stripMargin).queryExecution.executedPlan
val exchanges = planned.collect { case s: ShuffleExchangeExec => s }
assert(exchanges.size == 2)
}
}
}

test("aliases to expressions should not be replaced") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("df1", "df2") {
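For a rough way to observe the effect outside PlannerSuite, here is a spark-shell sketch mirroring the HashAggregate test above (assumptions: a Spark build that includes this change; spark, $"..." syntax, and the implicits come from the shell):

// Rough spark-shell check (assumes a build with this patch applied).
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
spark.conf.set("spark.sql.adaptive.enabled", "false") // keep the plan inspectable

spark.range(10).repartition($"id").createOrReplaceTempView("t1")
spark.range(20).repartition($"id").createOrReplaceTempView("t2")

val plan = spark.sql(
  """
    |SELECT t1id, t2id
    |FROM (
    |  SELECT t1.id as t1id, t2.id as t2id
    |  FROM t1 INNER JOIN t2
    |  WHERE t1.id = t2.id
    |) t12
    |GROUP BY t1id, t2id
  """.stripMargin).queryExecution.executedPlan

// The two repartition($"id") calls account for both shuffles; with the aliased
// partitioning propagated, the GROUP BY no longer adds a third.
assert(plan.collect { case s: ShuffleExchangeExec => s }.size == 2)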