Skip to content

Commit 4b1e311

Browse files
committed
Fix
1 parent ab981f1 commit 4b1e311

13 files changed

Lines changed: 65 additions & 70 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,14 @@ object MutableProjection
8989
}
9090

9191
/**
92-
* Returns an MutableProjection for given sequence of bound Expressions.
92+
* Returns a MutableProjection for given sequence of bound Expressions.
9393
*/
9494
def create(exprs: Seq[Expression]): MutableProjection = {
9595
createObject(exprs)
9696
}
9797

9898
/**
99-
* Returns an MutableProjection for given sequence of Expressions, which will be bound to
99+
* Returns a MutableProjection for given sequence of Expressions, which will be bound to
100100
* `inputSchema`.
101101
*/
102102
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,31 +20,17 @@ package org.apache.spark.sql.catalyst.expressions.codegen
2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions._
2222

23-
/**
24-
* Interface for generated predicate
25-
*/
26-
abstract class Predicate {
27-
def eval(r: InternalRow): Boolean
28-
29-
/**
30-
* Initializes internal states given the current partition index.
31-
* This is used by nondeterministic expressions to set initial states.
32-
* The default implementation does nothing.
33-
*/
34-
def initialize(partitionIndex: Int): Unit = {}
35-
}
36-
3723
/**
3824
* Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]].
3925
*/
40-
object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
26+
object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] {
4127

4228
protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in)
4329

4430
protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
4531
BindReferences.bindReference(in, inputSchema)
4632

47-
protected def create(predicate: Expression): Predicate = {
33+
protected def create(predicate: Expression): BasePredicate = {
4834
val ctx = newCodeGenContext()
4935
val eval = predicate.genCode(ctx)
5036

@@ -53,7 +39,7 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
5339
return new SpecificPredicate(references);
5440
}
5541

56-
class SpecificPredicate extends ${classOf[Predicate].getName} {
42+
class SpecificPredicate extends ${classOf[BasePredicate].getName} {
5743
private final Object[] references;
5844
${ctx.declareMutableStates()}
5945

@@ -79,6 +65,6 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
7965
logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")
8066

8167
val (clazz, _) = CodeGenerator.compile(code)
82-
clazz.generate(ctx.references.toArray).asInstanceOf[Predicate]
68+
clazz.generate(ctx.references.toArray).asInstanceOf[BasePredicate]
8369
}
8470
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,30 @@ import scala.collection.immutable.TreeSet
2121

2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
24+
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
2425
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
25-
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
26+
import org.apache.spark.sql.catalyst.expressions.codegen._
2627
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2728
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
2829
import org.apache.spark.sql.catalyst.util.TypeUtils
2930
import org.apache.spark.sql.internal.SQLConf
3031
import org.apache.spark.sql.types._
3132

3233

34+
/**
35+
* Interface for generated/interpreted predicate
36+
*/
37+
abstract class BasePredicate {
38+
def eval(r: InternalRow): Boolean
39+
40+
/**
41+
* Initializes internal states given the current partition index.
42+
* This is used by nondeterministic expressions to set initial states.
43+
* The default implementation does nothing.
44+
*/
45+
def initialize(partitionIndex: Int): Unit = {}
46+
}
47+
3348
object InterpretedPredicate {
3449
def create(expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate =
3550
create(BindReferences.bindReference(expression, inputSchema))
@@ -56,6 +71,26 @@ trait Predicate extends Expression {
5671
override def dataType: DataType = BooleanType
5772
}
5873

74+
/**
75+
* The factory object for `BasePredicate`.
76+
*/
77+
object Predicate extends CodeGeneratorWithInterpretedFallback[Expression, BasePredicate] {
78+
79+
override protected def createCodeGeneratedObject(in: Expression): BasePredicate = {
80+
GeneratePredicate.generate(in)
81+
}
82+
83+
override protected def createInterpretedObject(in: Expression): BasePredicate = {
84+
InterpretedPredicate.create(in)
85+
}
86+
87+
/**
88+
* Returns a BasePredicate for an Expression, which will be bound to `inputSchema`.
89+
*/
90+
def create(exprs: Expression, inputSchema: Seq[Attribute]): BasePredicate = {
91+
createObject(bindReference(exprs, inputSchema))
92+
}
93+
}
5994

6095
trait PredicateHelper {
6196
protected def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {

sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ case class FileSourceScanExec(
230230
// call the file index for the files matching all filters except dynamic partition filters
231231
val predicate = dynamicPartitionFilters.reduce(And)
232232
val partitionColumns = relation.partitionSchema
233-
val boundPredicate = newPredicate(predicate.transform {
233+
val boundPredicate = Predicate.create(predicate.transform {
234234
case a: AttributeReference =>
235235
val index = partitionColumns.indexWhere(a.name == _.name)
236236
BoundReference(index, partitionColumns(index).dataType, nullable = true)

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
2121
import java.util.concurrent.atomic.AtomicInteger
2222

2323
import scala.collection.mutable.ArrayBuffer
24-
import scala.concurrent.ExecutionContext
2524

2625
import org.codehaus.commons.compiler.CompileException
2726
import org.codehaus.janino.InternalCompilerException
@@ -33,7 +32,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope}
3332
import org.apache.spark.sql.{Row, SparkSession}
3433
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
3534
import org.apache.spark.sql.catalyst.expressions._
36-
import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
35+
import org.apache.spark.sql.catalyst.expressions.codegen._
3736
import org.apache.spark.sql.catalyst.plans.QueryPlan
3837
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
3938
import org.apache.spark.sql.catalyst.plans.physical._
@@ -471,28 +470,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
471470
MutableProjection.create(expressions, inputSchema)
472471
}
473472

474-
private def genInterpretedPredicate(
475-
expression: Expression, inputSchema: Seq[Attribute]): InterpretedPredicate = {
476-
val str = expression.toString
477-
val logMessage = if (str.length > 256) {
478-
str.substring(0, 256 - 3) + "..."
479-
} else {
480-
str
481-
}
482-
logWarning(s"Codegen disabled for this expression:\n $logMessage")
483-
InterpretedPredicate.create(expression, inputSchema)
484-
}
485-
486-
protected def newPredicate(
487-
expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = {
488-
try {
489-
GeneratePredicate.generate(expression, inputSchema)
490-
} catch {
491-
case _ @ (_: InternalCompilerException | _: CompileException) if codeGenFallBack =>
492-
genInterpretedPredicate(expression, inputSchema)
493-
}
494-
}
495-
496473
protected def newOrdering(
497474
order: Seq[SortOrder], inputSchema: Seq[Attribute]): Ordering[InternalRow] = {
498475
GenerateOrdering.generate(order, inputSchema)

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.InternalRow
2929
import org.apache.spark.sql.catalyst.expressions._
3030
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
3131
import org.apache.spark.sql.catalyst.expressions.codegen._
32-
import org.apache.spark.sql.catalyst.plans.QueryPlan
3332
import org.apache.spark.sql.catalyst.plans.physical._
3433
import org.apache.spark.sql.execution.metric.SQLMetrics
3534
import org.apache.spark.sql.types.{LongType, StructType}
@@ -227,7 +226,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
227226
protected override def doExecute(): RDD[InternalRow] = {
228227
val numOutputRows = longMetric("numOutputRows")
229228
child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
230-
val predicate = newPredicate(condition, child.output)
229+
val predicate = Predicate.create(condition, child.output)
231230
predicate.initialize(0)
232231
iter.filter { row =>
233232
val r = predicate.eval(row)

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ case class InMemoryTableScanExec(
310310
val buffers = relation.cacheBuilder.cachedColumnBuffers
311311

312312
buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
313-
val partitionFilter = newPredicate(
313+
val partitionFilter = Predicate.create(
314314
partitionFilters.reduceOption(And).getOrElse(Literal(true)),
315315
schema)
316316
partitionFilter.initialize(index)

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@ package org.apache.spark.sql.execution.joins
1919

2020
import org.apache.spark.broadcast.Broadcast
2121
import org.apache.spark.rdd.RDD
22-
import org.apache.spark.sql.AnalysisException
2322
import org.apache.spark.sql.catalyst.InternalRow
2423
import org.apache.spark.sql.catalyst.expressions._
2524
import org.apache.spark.sql.catalyst.plans._
2625
import org.apache.spark.sql.catalyst.plans.physical._
2726
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
2827
import org.apache.spark.sql.execution.metric.SQLMetrics
29-
import org.apache.spark.sql.internal.SQLConf
3028
import org.apache.spark.util.collection.{BitSet, CompactBuffer}
3129

3230
case class BroadcastNestedLoopJoinExec(
@@ -84,7 +82,7 @@ case class BroadcastNestedLoopJoinExec(
8482

8583
@transient private lazy val boundCondition = {
8684
if (condition.isDefined) {
87-
newPredicate(condition.get, streamed.output ++ broadcast.output).eval _
85+
Predicate.create(condition.get, streamed.output ++ broadcast.output).eval _
8886
} else {
8987
(r: InternalRow) => true
9088
}

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ package org.apache.spark.sql.execution.joins
2020
import org.apache.spark._
2121
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
2222
import org.apache.spark.sql.catalyst.InternalRow
23-
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
23+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, Predicate, UnsafeRow}
2424
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
25-
import org.apache.spark.sql.catalyst.plans.QueryPlan
2625
import org.apache.spark.sql.execution.{BinaryExecNode, ExplainUtils, ExternalAppendOnlyUnsafeRowArray, SparkPlan}
2726
import org.apache.spark.sql.execution.metric.SQLMetrics
2827
import org.apache.spark.util.CompletionIterator
@@ -93,7 +92,7 @@ case class CartesianProductExec(
9392
pair.mapPartitionsWithIndexInternal { (index, iter) =>
9493
val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
9594
val filtered = if (condition.isDefined) {
96-
val boundCondition = newPredicate(condition.get, left.output ++ right.output)
95+
val boundCondition = Predicate.create(condition.get, left.output ++ right.output)
9796
boundCondition.initialize(index)
9897
val joined = new JoinedRow
9998

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ trait HashJoin {
9999
UnsafeProjection.create(streamedKeys)
100100

101101
@transient private[this] lazy val boundCondition = if (condition.isDefined) {
102-
newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval _
102+
Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _
103103
} else {
104104
(r: InternalRow) => true
105105
}

0 commit comments

Comments (0)