@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
2222import org .apache .spark .sql .catalyst .expressions ._
2323import org .apache .spark .sql .catalyst .expressions .codegen .{CodegenContext , ExprCode , ExpressionCanonicalizer }
2424import org .apache .spark .sql .catalyst .plans .physical ._
25- import org .apache .spark .sql .execution .metric .{ LongSQLMetricValue , SQLMetrics }
25+ import org .apache .spark .sql .execution .metric .SQLMetrics
2626import org .apache .spark .sql .types .LongType
2727import org .apache .spark .util .random .PoissonSampler
2828
@@ -79,16 +79,20 @@ case class Filter(condition: Expression, child: SparkPlan)
7979
8080 // Split out all the IsNotNulls from condition.
8181 private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
82- case IsNotNull (a) if child.output.contains(a ) => true
82+ case IsNotNull (a) if child.output.exists(_.semanticEquals(a) ) => true
8383 case _ => false
8484 }
8585
8686 // The columns that will be filtered out by `IsNotNull` could be considered as not nullable.
8787 private val notNullAttributes = notNullPreds.flatMap(_.references)
8888
89+ // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
90+ // all the variables at the beginning to take advantage of short circuiting.
91+ override def usedInputs : AttributeSet = AttributeSet .empty
92+
8993 override def output : Seq [Attribute ] = {
9094 child.output.map { a =>
91- if (a.nullable && notNullAttributes.contains(a )) {
95+ if (a.nullable && notNullAttributes.exists(_.semanticEquals(a) )) {
9296 a.withNullability(false )
9397 } else {
9498 a
@@ -110,39 +114,81 @@ case class Filter(condition: Expression, child: SparkPlan)
110114 override def doConsume (ctx : CodegenContext , input : Seq [ExprCode ], row : String ): String = {
111115 val numOutput = metricTerm(ctx, " numOutputRows" )
112116
113- // filter out the nulls
114- val filterOutNull = notNullAttributes.map { a =>
115- val idx = child.output.indexOf(a)
116- s " if ( ${input(idx).isNull}) continue; "
117- }.mkString(" \n " )
117+ /**
118+ * Generates code for `c`, using `in` for input attributes and `attrs` for nullability.
119+ */
120+ def genPredicate (c : Expression , in : Seq [ExprCode ], attrs : Seq [Attribute ]): String = {
121+ val bound = BindReferences .bindReference(c, attrs)
122+ val evaluated = evaluateRequiredVariables(child.output, in, c.references)
118123
119- ctx.currentVars = input
120- val predicates = otherPreds.map { e =>
121- val bound = ExpressionCanonicalizer .execute(
122- BindReferences .bindReference(e, output))
123- val ev = bound.gen(ctx)
124+ // Generate the code for the predicate.
125+ val ev = ExpressionCanonicalizer .execute(bound).gen(ctx)
124126 val nullCheck = if (bound.nullable) {
125127 s " ${ev.isNull} || "
126128 } else {
127129 s " "
128130 }
131+
129132 s """
133+ | $evaluated
130134 | ${ev.code}
131135 |if ( ${nullCheck}! ${ev.value}) continue;
132136 """ .stripMargin
137+ }
138+
139+ ctx.currentVars = input
140+
141+ // To generate the predicates we will follow this algorithm.
142+ // For each predicate that is not IsNotNull, we will generate them one by one loading attributes
143+ // as necessary. For each of those attributes, if there is an IsNotNull predicate we will
144+ // generate that check *before* the predicate. After all of these predicates, we will generate
145+ // the remaining IsNotNull checks that were not part of other predicates.
146+ // This has the property of not doing redundant IsNotNull checks and taking better advantage of
147+ // short-circuiting, not loading attributes until they are needed.
148+ // This is very perf sensitive.
149+ // TODO: revisit this. We can consider reordering predicates as well.
150+ val generatedIsNotNullChecks = new Array [Boolean ](notNullPreds.length)
151+ val generated = otherPreds.map { c =>
152+ val nullChecks = c.references.map { r =>
153+ val idx = notNullPreds.indexWhere { n => n.asInstanceOf [IsNotNull ].child.semanticEquals(r)}
154+ if (idx != - 1 && ! generatedIsNotNullChecks(idx)) {
155+ // Use the child's output. The nullability is what the child produced.
156+ val code = genPredicate(notNullPreds(idx), input, child.output)
157+ generatedIsNotNullChecks(idx) = true
158+ code
159+ } else {
160+ " "
161+ }
162+ }.mkString(" \n " ).trim
163+
164+ // Here we use *this* operator's output with this output's nullability since we already
165+ // enforced them with the IsNotNull checks above.
166+ s """
167+ | $nullChecks
168+ | ${genPredicate(c, input, output)}
169+ """ .stripMargin.trim
170+ }.mkString(" \n " )
171+
172+ val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) =>
173+ if (! generatedIsNotNullChecks(idx)) {
174+ genPredicate(c, input, child.output)
175+ } else {
176+ " "
177+ }
133178 }.mkString(" \n " )
134179
135180 // Reset the isNull to false for the not-null columns, then the followed operators could
136- // generate better code (remove dead branches).
181+ // generate better code (remove dead branches).
137182 val resultVars = input.zipWithIndex.map { case (ev, i) =>
138- if (notNullAttributes.contains( child.output(i))) {
183+ if (notNullAttributes.exists(_.semanticEquals( child.output(i) ))) {
139184 ev.isNull = " false"
140185 }
141186 ev
142187 }
188+
143189 s """
144- | $filterOutNull
145- | $predicates
190+ | $generated
191+ | $nullChecks
146192 | $numOutput.add(1);
147193 | ${consume(ctx, resultVars)}
148194 """ .stripMargin
0 commit comments