@@ -730,7 +730,7 @@ case class PruneFilters(conf: SQLConf) extends Rule[LogicalPlan] with PredicateH
case f @ Filter(fc, p: LogicalPlan) =>
val (prunedPredicates, remainingPredicates) =
splitConjunctivePredicates(fc).partition { cond =>
- cond.deterministic && p.getConstraints(conf.constraintPropagationEnabled).contains(cond)
+ cond.deterministic && p.constraints.contains(cond)
}
if (prunedPredicates.isEmpty) {
f
@@ -129,8 +129,7 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred
}

private def buildNewJoinType(filter: Filter, join: Join): JoinType = {
- val conditions = splitConjunctivePredicates(filter.condition) ++
- filter.getConstraints(conf.constraintPropagationEnabled)
+ val conditions = splitConjunctivePredicates(filter.condition) ++ filter.constraints
val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet))
val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet))

@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreeNode
+ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}

abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
@@ -27,6 +28,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]

self: PlanType =>

+ def conf: SQLConf = SQLConf.get

def output: Seq[Attribute]

/**
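With `conf` now defined directly on `QueryPlan`, catalyst-side code can read session settings without threading an `SQLConf` through rule constructors. A minimal sketch of the new access pattern, using an illustrative rule that is not part of this diff:

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

// Hypothetical rule, shown only to illustrate reading the conf via SQLConf.get
// (or, equivalently, via plan.conf) instead of a constructor parameter.
object ExampleConfAwareRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    if (SQLConf.get.constraintPropagationEnabled) {
      // ... a constraint-based rewrite would go here ...
      plan
    } else {
      plan
    }
  }
}
```

Because `SQLConf.get` is resolved at rule execution time, the same rule object picks up whichever session is active on the calling thread.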
@@ -27,18 +27,20 @@ trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[Pl
* example, if this set contains the expression `a = 2` then that expression is guaranteed to
* evaluate to `true` for all rows produced.
*/
- lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))
-
- /**
- * Returns [[constraints]] depending on the config of enabling constraint propagation. If the
- * flag is disabled, simply returning an empty constraints.
- */
- def getConstraints(constraintPropagationEnabled: Boolean): ExpressionSet =
- if (constraintPropagationEnabled) {
- constraints
+ lazy val constraints: ExpressionSet = {
+ if (conf.constraintPropagationEnabled) {
+ ExpressionSet(
+ validConstraints
+ .union(inferAdditionalConstraints(validConstraints))
+ .union(constructIsNotNullConstraints(validConstraints))
+ .filter { c =>
+ c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
+ }
[Review comment, Member] should we also bound their size as you suggested?
[Review reply, Contributor Author] That belongs in a separate patch.
+ )
} else {
ExpressionSet(Set.empty)
}
+ }

/**
* This method can be overridden by any child class of QueryPlan to specify a set of constraints
@@ -50,19 +52,6 @@ trait QueryPlanConstraints[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[Pl
*/
protected def validConstraints: Set[Expression] = Set.empty

- /**
- * Extracts the relevant constraints from a given set of constraints based on the attributes that
- * appear in the [[outputSet]].
- */
- protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = {
- constraints
- .union(inferAdditionalConstraints(constraints))
- .union(constructIsNotNullConstraints(constraints))
- .filter(constraint =>
- constraint.references.nonEmpty && constraint.references.subsetOf(outputSet) &&
- constraint.deterministic)
- }

/**
* Infers a set of `isNotNull` constraints from null intolerant expressions as well as
* non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this
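On the size question raised in the inline comment above (deferred to a separate patch), a purely hypothetical sketch of what a cap could look like; the helper, the default limit, and the truncation strategy are all assumptions, not anything this PR implements:

```scala
import org.apache.spark.sql.catalyst.expressions.ExpressionSet

object ConstraintBounding {
  // Hypothetical helper: keeps at most `maxConstraints` constraints. A real change
  // would presumably read the limit from a new SQLConf entry and choose which
  // constraints to drop more carefully than a plain take().
  def bound(constraints: ExpressionSet, maxConstraints: Int = 200): ExpressionSet =
    if (constraints.size <= maxConstraints) constraints
    else ExpressionSet(constraints.toSeq.take(maxConstraints))
}
```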
@@ -19,6 +19,7 @@ package org.apache.spark.sql.internal

import java.util.{Locale, NoSuchElementException, Properties, TimeZone}
import java.util.concurrent.TimeUnit
+ import java.util.concurrent.atomic.AtomicReference

import scala.collection.JavaConverters._
import scala.collection.immutable
@@ -64,6 +65,47 @@ object SQLConf {
}
}

/**
* Default config. Only used when there is no active SparkSession for the thread.
* See [[get]] for more information.
*/
private val fallbackConf = new ThreadLocal[SQLConf] {
override def initialValue: SQLConf = new SQLConf
}

/** See [[get]] for more information. */
def getFallbackConf: SQLConf = fallbackConf.get()

/**
* Defines a getter that returns the SQLConf within scope.
* See [[get]] for more information.
*/
private val confGetter = new AtomicReference[() => SQLConf](() => fallbackConf.get())

/**
* Sets the active config object within the current scope.
* See [[get]] for more information.
*/
def setSQLConfGetter(getter: () => SQLConf): Unit = {
confGetter.set(getter)
}

/**
* Returns the active config object within the current scope. If there is an active SparkSession,
* the proper SQLConf associated with the thread's session is used.
*
* The way this works is a little bit convoluted, due to the fact that config was added initially
* only for physical plans (and as a result not in sql/catalyst module).
*
* The first time a SparkSession is instantiated, we set the [[confGetter]] to return the
* active SparkSession's config. If there is no active SparkSession, it returns using the thread
* local [[fallbackConf]]. The reason [[fallbackConf]] is a thread local (rather than just a conf)
* is to support setting different config options for different threads so we can potentially
* run tests in parallel. At the time this feature was implemented, this was a no-op since we
* run unit tests (that do not involve SparkSession) in serial order.
*/
def get: SQLConf = confGetter.get()()

val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
.internal()
.doc("The max number of iterations the optimizer and analyzer runs.")
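Because the fallback conf is a `ThreadLocal` rather than a single shared instance, threads without an active SparkSession each see their own independent defaults, which is what the comment above alludes to for parallel tests. A small standalone sketch of that behavior, assuming no SparkSession has been created in the JVM:

```scala
import org.apache.spark.sql.internal.SQLConf

// On the current thread, mutate the fallback conf returned by SQLConf.get.
SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
assert(!SQLConf.get.constraintPropagationEnabled)

// A different thread gets a fresh fallback conf with default values (the flag
// defaults to true), so the change above is not visible there.
val t = new Thread(new Runnable {
  override def run(): Unit = assert(SQLConf.get.constraintPropagationEnabled)
})
t.start()
t.join()
```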
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
+ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType}

class ConstraintPropagationSuite extends SparkFunSuite {
@@ -402,17 +403,19 @@ class ConstraintPropagationSuite extends SparkFunSuite {
val tr = LocalRelation('a.int, 'b.string, 'c.int)
val filterRelation = tr.where('a.attr > 10)

- verifyConstraints(
- filterRelation.analyze.getConstraints(constraintPropagationEnabled = true),
- filterRelation.analyze.constraints)
+ SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true)
+ assert(filterRelation.analyze.constraints.nonEmpty)

- assert(filterRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty)
+ SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+ assert(filterRelation.analyze.constraints.isEmpty)

val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
.groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3)

- verifyConstraints(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = true),
- aliasedRelation.analyze.constraints)
- assert(aliasedRelation.analyze.getConstraints(constraintPropagationEnabled = false).isEmpty)
+ SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true)
+ assert(aliasedRelation.analyze.constraints.nonEmpty)
+
+ SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+ assert(aliasedRelation.analyze.constraints.isEmpty)
[Review comment, Contributor] unset the config in the end
}
}
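To address the inline note above about unsetting the config, the test could restore the previous value after each assertion block so the thread-local conf does not leak into later tests. A sketch of one way to do it (the helper name is illustrative; Spark's test utilities offer similar withSQLConf helpers):

```scala
import org.apache.spark.sql.internal.SQLConf

// Illustrative helper: runs `body` with the constraint-propagation flag set to
// `enabled`, then restores the previous value so later tests start from a clean conf.
def withConstraintPropagation[T](enabled: Boolean)(body: => T): T = {
  val previous = SQLConf.get.constraintPropagationEnabled
  SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, enabled)
  try body finally SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, previous)
}

// For example, the first check above could become:
// withConstraintPropagation(enabled = false) {
//   assert(filterRelation.analyze.constraints.isEmpty)
// }
```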
@@ -87,6 +87,11 @@ class SparkSession private(

sparkContext.assertNotStopped()

+ // If there is no active SparkSession, uses the default SQL conf. Otherwise, use the session's.
+ SQLConf.setSQLConfGetter(() => {
+ SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(SQLConf.getFallbackConf)
[Review comment, Contributor] shall we move this to SparkSession.Builder.getOrCreate?
+ })

/**
* The version of Spark on which this application is running.
*
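A sketch of the alternative raised in the last review comment: installing the getter once in `SparkSession.Builder.getOrCreate()` instead of in the `SparkSession` constructor. The surrounding method is paraphrased; only the `setSQLConfGetter` call is taken from the diff:

```scala
// Inside SparkSession.Builder.getOrCreate(), after the session has been created
// (paraphrased; not the code this PR actually changes):
SQLConf.setSQLConfGetter(() =>
  SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(SQLConf.getFallbackConf))
```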