@@ -89,7 +89,14 @@ case class FilterExec(condition: Expression, child: SparkPlan)

  // Split out all the IsNotNulls from condition.
  private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
    case IsNotNull(a: NullIntolerant) if a.references.subsetOf(child.outputSet) => true
    case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
    case _ => false
  }

  // One expression is null intolerant iff it and its children are null intolerant
  private def isNullIntolerant(expr: Expression): Boolean = expr match {
    case e: NullIntolerant =>
      if (e.isInstanceOf[LeafExpression]) true else e.children.forall(isNullIntolerant)
Member

Thanks for fixing this!

This change is too conservative. Actually, we only need to consider a non-NullIntolerant expression when it contains attributes from the output. I think we can take a more aggressive approach, e.g.:

// Split out all the IsNotNulls from condition.
private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
  case IsNotNull(a: NullIntolerant) =>
    isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
  case _ => false
}

private def isNullIntolerant(expr: Expression): Boolean = {
  expr.find { e =>
    !e.isInstanceOf[NullIntolerant] && e.references.subsetOf(child.outputSet)
  }.isEmpty
}

Member Author

I just realized the original code was from your PR. In your code above, why do you still need to keep a.references.subsetOf(child.outputSet)? It looks confusing to me.

Member

Even if a passes the isNullIntolerant check, i.e., it has no non-NullIntolerant sub-expression wrapping output attributes, we still don't need it when it doesn't refer to any output attributes.

Member Author

Could you show me an example?

Member

IsNotNull(Rand() > 0.5)?
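
A minimal sketch of this example, assuming the Spark 2.x Catalyst expression API: the predicate contains no attribute references at all, so treating it as a not-null predicate could never make any output column non-nullable.

  import org.apache.spark.sql.catalyst.expressions._

  // IsNotNull(rand(42) > 0.5), built by hand.
  val pred = IsNotNull(GreaterThan(Rand(42L), Literal(0.5)))
  // It references no attributes, so there is no output column whose
  // nullability it could ever tighten.
  assert(pred.references.isEmpty)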

Member Author

uh, I see.

First, we definitely need test cases to cover each positive and negative scenario. Previously, we did not have any test case that checked the validity of nullability changes. Second, the code needs more comments where the variable/function names cannot explain the code on their own.

Member

nit: How about something like this for better readability:

  private def isNullIntolerant(expr: Expression): Boolean = expr match {
    case e: NullIntolerant with LeafExpression => true
    case e: NullIntolerant => e.children.forall(isNullIntolerant)
    case _ => false
  }

Member Author
@gatorsmile Nov 3, 2016

forall returns true when children is empty, so we can remove the first case. Now it becomes simpler. : )

  private def isNullIntolerant(expr: Expression): Boolean = expr match {
    case e: NullIntolerant => e.children.forall(isNullIntolerant)
    case _ => false
  }

Member

nice

    case _ => false
  }
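
For context, FilterExec uses notNullPreds to tighten the nullability of its output attributes. Roughly, as a sketch of that logic (not necessarily this revision's exact code):

  // Attributes that appear in some IsNotNull predicate over a
  // null-intolerant expression are known to be non-null after the filter.
  private val notNullAttributes =
    notNullPreds.flatMap(_.references).distinct.map(_.exprId)

  override def output: Seq[Attribute] = child.output.map { a =>
    if (a.nullable && notNullAttributes.contains(a.exprId)) {
      a.withNullability(false) // mark the column non-nullable downstream
    } else {
      a
    }
  }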

70 changes: 68 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -28,8 +28,8 @@ import org.scalatest.Matchers._

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, Union}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
import org.apache.spark.sql.functions._
@@ -1617,6 +1617,72 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
}

  private def verifyNullabilityInFilterExec(
      df: DataFrame,
      expr: String,
      expectedNonNullableColumns: Seq[String]): Unit = {
    val dfWithFilter = df.where(s"isnotnull($expr)").selectExpr(expr)
    // In the logical plan, all the output columns of input dataframe are nullable
    dfWithFilter.queryExecution.optimizedPlan.collect {
      case e: Filter => assert(e.output.forall(_.nullable))
    }

    dfWithFilter.queryExecution.executedPlan.collect {
      // When the child expression in isnotnull is null-intolerant (i.e. any null input will
      // result in null output), the involved columns are converted to not nullable;
      // otherwise, no change should be made.
      case e: FilterExec =>
        assert(e.output.forall { o =>
          if (expectedNonNullableColumns.contains(o.name)) !o.nullable else o.nullable
        })
    }
  }

test("SPARK-17957: no change on nullability in FilterExec output") {
val df = sparkContext.parallelize(Seq(
null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()

verifyNullabilityInFilterExec(df,
expr = "Rand()", expectedNonNullableColumns = Seq.empty[String])
verifyNullabilityInFilterExec(df,
expr = "coalesce(_1, _2)", expectedNonNullableColumns = Seq.empty[String])
verifyNullabilityInFilterExec(df,
expr = "coalesce(_1, 0) + Rand()", expectedNonNullableColumns = Seq.empty[String])
verifyNullabilityInFilterExec(df,
expr = "cast(coalesce(cast(coalesce(_1, _2) as double), 0.0) as int)",
expectedNonNullableColumns = Seq.empty[String])
}

test("SPARK-17957: set nullability to false in FilterExec output") {
val df = sparkContext.parallelize(Seq(
null.asInstanceOf[java.lang.Integer] -> new java.lang.Integer(3),
new java.lang.Integer(1) -> null.asInstanceOf[java.lang.Integer],
new java.lang.Integer(2) -> new java.lang.Integer(4))).toDF()

verifyNullabilityInFilterExec(df,
expr = "_1 + _2 * 3", expectedNonNullableColumns = Seq("_1", "_2"))
verifyNullabilityInFilterExec(df,
expr = "_1 + _2", expectedNonNullableColumns = Seq("_1", "_2"))
verifyNullabilityInFilterExec(df,
expr = "_1", expectedNonNullableColumns = Seq("_1"))
verifyNullabilityInFilterExec(df,
expr = "_2 + Rand()", expectedNonNullableColumns = Seq("_2"))
Member

This should not work with the current approach. It only works now because we infer redundant IsNotNull constraints. E.g., if a Filter has the constraint IsNotNull(_2 + Rand()), we also infer IsNotNull(_2) from it. Your approach decides that _2 is a non-nullable column based on that IsNotNull(_2), not on IsNotNull(_2 + Rand()).

I submitted another PR, #15653, for the redundant IsNotNull constraints. But I am not sure we want to fix that, since it doesn't affect correctness. I will leave that to @cloud-fan or @hvanhovell to decide.
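
To make the inference concrete, here is a rough sketch of the reference-based rule being described (only an illustration, not Spark's actual constructIsNotNullConstraints), assuming the Spark 2.x Catalyst API:

  import org.apache.spark.sql.catalyst.expressions._
  import org.apache.spark.sql.types.IntegerType

  // Sketch: from every IsNotNull constraint, also infer IsNotNull on each
  // attribute it references, e.g. IsNotNull(_2 + rand(42)) => IsNotNull(_2).
  def inferIsNotNull(constraints: Set[Expression]): Set[Expression] =
    constraints.collect {
      case notNull @ IsNotNull(_) =>
        notNull.references.map(a => IsNotNull(a): Expression).toSet
    }.flatten

  val col2 = AttributeReference("_2", IntegerType)()
  val constraint = IsNotNull(Add(col2, Rand(42L)))
  val inferred = inferIsNotNull(Set(constraint)) // contains IsNotNull(col2)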

Member Author

I already explained why my current solution works in my previous comment. Personally, I like simple code that is easy to understand and maintain, especially when it can cover all the cases. Result correctness and code maintainability are always more important.

If constructIsNotNullConstraints is changed by somebody else (i.e., it no longer provides the expected IsNotNull constraints), the test cases added by this PR will fail, and then we can modify the code.

Member

Yeah, that's why I said I will leave that to @cloud-fan or others to decide...

Member

I agree with the simplicity argument, but could you please add a comment here explaining why this particular case works because of the null-inference rule?

    verifyNullabilityInFilterExec(df,
      expr = "_2 * 3 + coalesce(_1, 0)", expectedNonNullableColumns = Seq("_2"))
    verifyNullabilityInFilterExec(df,
      expr = "cast((_1 + _2) as boolean)", expectedNonNullableColumns = Seq("_1", "_2"))
  }

test("SPARK-17957: outer join + na.fill") {
val df1 = Seq((1, 2), (2, 3)).toDF("a", "b")
val df2 = Seq((2, 5), (3, 4)).toDF("a", "c")
val joinedDf = df1.join(df2, Seq("a"), "outer").na.fill(0)
val df3 = Seq((3, 1)).toDF("a", "d")
checkAnswer(joinedDf.join(df3, "a"), Row(3, 0, 4, 1))
}

test("SPARK-17123: Performing set operations that combine non-scala native types") {
val dates = Seq(
(new Date(0), BigDecimal.valueOf(1), new Timestamp(2)),