Skip to content

Commit f5ab9cb

Browse files
author
Davies Liu
committed
fix resulting columns of outer join
1 parent bc1ff9f commit f5ab9cb

2 files changed

Lines changed: 36 additions & 9 deletions

File tree

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.analysis._
3636
import org.apache.spark.sql.catalyst.expressions._
3737
import org.apache.spark.sql.catalyst.expressions.aggregate._
3838
import org.apache.spark.sql.catalyst.plans.logical._
39-
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
39+
import org.apache.spark.sql.catalyst.plans._
4040
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
4141
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution}
4242
import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
@@ -499,10 +499,8 @@ class DataFrame private[sql](
499499
// Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
500500
// by creating a new instance for one of the branch.
501501
val joined = sqlContext.executePlan(
502-
Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join]
502+
Join(logicalPlan, right.logicalPlan, JoinType(joinType), None)).analyzed.asInstanceOf[Join]
503503

504-
// Project only one of the join columns.
505-
val joinedCols = usingColumns.map(col => withPlan(joined.right).resolve(col))
506504
val condition = usingColumns.map { col =>
507505
catalyst.expressions.EqualTo(
508506
withPlan(joined.left).resolve(col),
@@ -511,9 +509,26 @@ class DataFrame private[sql](
511509
catalyst.expressions.And(cond, eqTo)
512510
}
513511

512+
// Project only one of the join columns.
513+
val joinedCols = JoinType(joinType) match {
514+
case Inner | LeftOuter | LeftSemi =>
515+
usingColumns.map(col => withPlan(joined.left).resolve(col))
516+
case RightOuter =>
517+
usingColumns.map(col => withPlan(joined.right).resolve(col))
518+
case FullOuter =>
519+
usingColumns.map { col =>
520+
val leftCol = withPlan(joined.left).resolve(col)
521+
val rightCol = withPlan(joined.right).resolve(col)
522+
Alias(Coalesce(Seq(leftCol, rightCol)), col)()
523+
}
524+
}
525+
// The nullability of output of joined could be different than original column,
526+
// so we can only compare them by exprId
527+
val joinRefs = condition.map(_.references.toSeq.map(_.exprId)).getOrElse(Nil)
528+
val resultCols = joinedCols ++ joined.output.filterNot(e => joinRefs.contains(e.exprId))
514529
withPlan {
515530
Project(
516-
joined.output.filterNot(joinedCols.contains(_)),
531+
resultCols,
517532
Join(
518533
joined.left,
519534
joined.right,

sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,28 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
4343
}
4444

4545
test("join - join using multiple columns and specifying join type") {
46-
val df = Seq(1, 2, 3).map(i => (i, i + 1, i.toString)).toDF("int", "int2", "str")
47-
val df2 = Seq(1, 2, 3).map(i => (i, i + 1, (i + 1).toString)).toDF("int", "int2", "str")
46+
val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str")
47+
val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str")
48+
49+
checkAnswer(
50+
df.join(df2, Seq("int", "str"), "inner"),
51+
Row(1, "1", 2, 3) :: Nil)
4852

4953
checkAnswer(
5054
df.join(df2, Seq("int", "str"), "left"),
51-
Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil)
55+
Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Nil)
5256

5357
checkAnswer(
5458
df.join(df2, Seq("int", "str"), "right"),
55-
Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil)
59+
Row(1, "1", 2, 3) :: Row(5, "5", null, 6) :: Nil)
60+
61+
checkAnswer(
62+
df.join(df2, Seq("int", "str"), "outer"),
63+
Row(1, "1", 2, 3) :: Row(3, "3", 4, null) :: Row(5, "5", null, 6) :: Nil)
64+
65+
checkAnswer(
66+
df.join(df2, Seq("int", "str"), "left_semi"),
67+
Row(1, "1", 2) :: Nil)
5668
}
5769

5870
test("join - join using self join") {

0 commit comments

Comments
 (0)