diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 48f5136d33998..e33cff2f14e17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -62,11 +62,13 @@ abstract class SubqueryExpression( object SubqueryExpression { /** - * Returns true when an expression contains an IN or EXISTS subquery and false otherwise. + * Returns true when an expression contains an IN or correlated EXISTS subquery + * and false otherwise. */ - def hasInOrExistsSubquery(e: Expression): Boolean = { + def hasInOrCorrelatedExistsSubquery(e: Expression): Boolean = { e.find { - case _: ListQuery | _: Exists => true + case _: ListQuery => true + case ex: Exists if ex.children.nonEmpty => true case _ => false }.isDefined } @@ -302,7 +304,10 @@ case class ListQuery( } /** - * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition. + * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition + * or some uncorrelated condition. + * + * 1. correlated condition: * * For example (SQL): * {{{ * SELECT * * FROM a * WHERE EXISTS (SELECT * * FROM b * WHERE b.id = a.id) * }}} + * + * 2. 
uncorrelated condition example: + * + * For example (SQL): + * {{{ + * SELECT * + * FROM a + * WHERE EXISTS (SELECT * + * FROM b + * WHERE b.id > 10) + * }}} */ case class Exists( plan: LogicalPlan, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9c08ca5201028..935d62015afa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -128,6 +128,7 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateSubqueryAliases, EliminateView, ReplaceExpressions, + RewriteNonCorrelatedExists, ComputeCurrentTime, GetCurrentDatabase(catalogManager), RewriteDistinctAggregates, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index f64b6e00373f6..c79bf3e20b776 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -52,6 +52,21 @@ object ReplaceExpressions extends Rule[LogicalPlan] { } } +/** + * Rewrites a non-correlated EXISTS subquery to use ScalarSubquery, e.g. + * WHERE EXISTS (SELECT A FROM TABLE B WHERE COL1 > 10) + * will be rewritten to + * WHERE (SELECT 1 FROM (SELECT A FROM TABLE B WHERE COL1 > 10) LIMIT 1) IS NOT NULL + */ +object RewriteNonCorrelatedExists extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case exists: Exists if exists.children.isEmpty => + IsNotNull( + ScalarSubquery( + plan = Limit(Literal(1), Project(Seq(Alias(Literal(1), "col")()), exists.plan)), + exprId = exists.exprId)) + } +} /** * Computes the current date and time to make 
sure we return the same result in a single query. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 8ac14264a9294..b6974624c6514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -96,7 +96,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Filter(condition, child) => val (withSubquery, withoutSubquery) = - splitConjunctivePredicates(condition).partition(SubqueryExpression.hasInOrExistsSubquery) + splitConjunctivePredicates(condition) + .partition(SubqueryExpression.hasInOrCorrelatedExistsSubquery) // Construct the pruned filter condition. val newFilter: LogicalPlan = withoutSubquery match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 2c97ec07bc577..85619beee0c99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Join, JoinStrategyHint, SHUFFLE_HASH} import org.apache.spark.sql.catalyst.util.DateTimeConstants -import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} +import org.apache.spark.sql.execution.{ExecSubqueryExpression, RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.functions._ @@ -89,10 +89,19 @@ class CachedTableSuite extends QueryTest with 
SQLTestUtils with SharedSparkSessi sum } + private def getNumInMemoryTablesInSubquery(plan: SparkPlan): Int = { + plan.expressions.flatMap(_.collect { + case sub: ExecSubqueryExpression => getNumInMemoryTablesRecursively(sub.plan) + }).sum + } + private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = { plan.collect { - case InMemoryTableScanExec(_, _, relation) => - getNumInMemoryTablesRecursively(relation.cachedPlan) + 1 + case inMemoryTable @ InMemoryTableScanExec(_, _, relation) => + getNumInMemoryTablesRecursively(relation.cachedPlan) + + getNumInMemoryTablesInSubquery(inMemoryTable) + 1 + case p => + getNumInMemoryTablesInSubquery(p) }.sum } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 5020c1047f8dd..2f0142f3a6c2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -891,9 +891,9 @@ class SubquerySuite extends QueryTest with SharedSparkSession { val sqlText = """ - |SELECT * FROM t1 + |SELECT * FROM t1 a |WHERE - |NOT EXISTS (SELECT * FROM t1) + |NOT EXISTS (SELECT * FROM t1 b WHERE a.i = b.i) """.stripMargin val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan val join = optimizedPlan.collectFirst { case j: Join => j }.get