2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1164,7 +1164,7 @@ class Column(val expr: Expression) extends Logging {
   * @since 2.0.0
   */
  def name(alias: String): Column = withExpr {
-    Alias(normalizedExpr(), alias)()
+    Alias(expr, alias)()
  }

  /**
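Why this one-line change matters for SPARK-33536: `name` (which `as` and `alias` delegate to) previously aliased a normalized copy of the column's expression, whereas wrapping the original `expr` keeps the alias attached to the resolved attribute it came from. A minimal sketch of the behaviour this preserves, with hypothetical DataFrames not taken from the patch:

```scala
// Minimal sketch, assuming a running SparkSession `spark`.
val df1 = spark.range(3).toDF("key")
val df2 = spark.range(3).toDF("key")

// After this change the alias still carries the identity of the column it
// wraps, so self-join analysis can tell that "k" originated in df1, not df2.
val joined = df1.join(df2, df1("key") === df2("key"))
  .select(df1("key").as("k"))
```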
36 changes: 21 additions & 15 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -259,15 +259,16 @@ class Dataset[T] private[sql](
  private[sql] def resolve(colName: String): NamedExpression = {
    val resolver = sparkSession.sessionState.analyzer.resolver
    queryExecution.analyzed.resolveQuoted(colName, resolver)
-      .getOrElse {
-        val fields = schema.fieldNames
-        val extraMsg = if (fields.exists(resolver(_, colName))) {
-          s"; did you mean to quote the `$colName` column?"
-        } else ""
-        val fieldsStr = fields.mkString(", ")
-        val errorMsg = s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}"""
-        throw new AnalysisException(errorMsg)
-      }
+      .getOrElse(throw resolveException(colName, schema.fieldNames))
  }

+  private def resolveException(colName: String, fields: Array[String]): AnalysisException = {
+    val extraMsg = if (fields.exists(sparkSession.sessionState.analyzer.resolver(_, colName))) {
+      s"; did you mean to quote the `$colName` column?"
+    } else ""
+    val fieldsStr = fields.mkString(", ")
+    val errorMsg = s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}"""
+    new AnalysisException(errorMsg)
+  }
+
  private[sql] def numericColumns: Seq[Expression] = {
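The refactor above is behavior-preserving: `resolveException` builds the same `AnalysisException` the inlined `getOrElse` block used to, so the join path further down can reuse it. A hedged sketch of the message it produces, using a hypothetical DataFrame whose column name contains a dot:

```scala
// Hypothetical example: a column literally named "a.b".
val df = Seq((1, 2)).toDF("a.b", "c")

// Unquoted, "a.b" is read as field `b` of a column `a`, so resolution fails
// and the message carries the quoting hint assembled in resolveException:
//   Cannot resolve column name "a.b" among (a.b, c); did you mean to quote
//   the `a.b` column?
df.select(df("a.b"))     // throws AnalysisException with the hint above
df.select(df("`a.b`"))   // the backtick-quoted name resolves correctly
```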
@@ -1083,26 +1084,31 @@ class Dataset[T] private[sql](
    }

    // If left/right have no output set intersection, return the plan.
-    val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed
-    val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed
+    val lanalyzed = this.queryExecution.analyzed
+    val ranalyzed = right.queryExecution.analyzed
    if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
      return withPlan(plan)
    }

    // Otherwise, find the trivially true predicates and automatically resolves them to both sides.
    // By the time we get here, since we have already run analysis, all attributes should've been
    // resolved and become AttributeReference.
+    val resolver = sparkSession.sessionState.analyzer.resolver
    val cond = plan.condition.map { _.transform {
      case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
          if a.sameRef(b) =>
        catalyst.expressions.EqualTo(
-          withPlan(plan.left).resolve(a.name),
-          withPlan(plan.right).resolve(b.name))
+          plan.left.resolveQuoted(a.name, resolver)
+            .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
+          plan.right.resolveQuoted(b.name, resolver)
+            .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
      case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference)
          if a.sameRef(b) =>
        catalyst.expressions.EqualNullSafe(
-          withPlan(plan.left).resolve(a.name),
-          withPlan(plan.right).resolve(b.name))
+          plan.left.resolveQuoted(a.name, resolver)
+            .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
+          plan.right.resolveQuoted(b.name, resolver)
+            .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
    }}

    withPlan {
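This hunk is the core of SPARK-33536. `withPlan(...)` constructs a fresh `Dataset` just to call `resolve`, and, as the PR title says, building those throwaway Datasets in `join()` changed the dataset_id recorded on the logical plan, which is exactly the bookkeeping `DetectAmbiguousSelfJoin` depends on. Resolving directly against `plan.left` / `plan.right` with `resolveQuoted` leaves that state untouched. A sketch of the self-join shape this keeps detectable (it mirrors the new test below; the data is illustrative):

```scala
// Illustrative data; the shape mirrors the SPARK-33536 test added below.
val emp1 = Seq((1, "sales"), (2, "personnel")).toDF("key", "value")
val emp2 = Seq((1, "sales")).toDF("key", "value")

// emp3 embeds emp1's output, so emp1("key") below could bind to either side.
val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*"))

// With this fix the plan still carries emp1's original dataset_id, so
// DetectAmbiguousSelfJoin fails analysis instead of silently picking a side.
emp1.join(emp3, emp1("key") === emp3("key"), "left_outer")
  .select(emp1("*"), emp3("key").as("e2"))
```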
17 changes: 17 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.test.SQLTestData.TestData

class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
import testImplicits._
@@ -219,4 +220,20 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
Seq((1, 2), (1, 2), (2, 4), (2, 4)).map(Row.fromTuple))
}
}

test("SPARK-33536: Avoid changing dataset_id of LogicalPlan in join() " +
"to not break DetectAmbiguousSelfJoin") {
val emp1 = Seq[TestData](
TestData(1, "sales"),
TestData(2, "personnel"),
TestData(3, "develop"),
TestData(4, "IT")).toDS()
val emp2 = Seq[TestData](
TestData(1, "sales"),
TestData(2, "personnel"),
TestData(3, "develop")).toDS()
val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*"))
assertAmbiguousSelfJoin(emp1.join(emp3, emp1.col("key") === emp3.col("key"),
"left_outer").select(emp1.col("*"), emp3.col("key").as("e2")))
}
}
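For context, `assertAmbiguousSelfJoin` is a helper defined earlier in this suite, outside the hunk shown. A plausible sketch of what it checks, assuming the usual ScalaTest intercept pattern (the real definition may differ):

```scala
// Hypothetical sketch of the suite helper used above, not copied from the PR:
// the wrapped query must fail analysis with the ambiguous-self-join error
// rather than silently binding the column to one side of the join.
private def assertAmbiguousSelfJoin(df: => DataFrame): Unit = {
  val e = intercept[AnalysisException](df)
  assert(e.message.contains("ambiguous"))
}
```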