Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,12 @@ class Analyzer(
}
substituted.getOrElse(u)
case other =>
// This can't be done in ResolveSubquery because that does not know the CTE.
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
other transformExpressions {
case e: SubqueryExpression =>
e.withNewPlan(substituteCTE(e.query, cteRelations))
}
}

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,3 @@ case class Literal protected (value: Any, dataType: DataType)
case _ => value.toString
}
}

// TODO: Specialize
case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
extends LeafExpression with CodegenFallback {

def update(expression: Expression, input: InternalRow): Unit = {
value = expression.eval(input)
}

override def eval(input: InternalRow): Any = value
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* populated by the query planning infrastructure.
*/
@transient
protected[spark] final val sqlContext = SQLContext.getActive().getOrElse(null)
protected[spark] final val sqlContext = SQLContext.getActive().orNull

protected def sparkContext = sqlContext.sparkContext

Expand Down Expand Up @@ -120,44 +120,49 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}

// All the subqueries and their Future of results.
@transient private val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]()
/**
* List of (uncorrelated scalar subquery, future holding the subquery result) for this plan node.
* This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
*/
@transient
private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]

/**
* Collects all the subqueries and create a Future to take the first two rows of them.
* Finds scalar subquery expressions in this plan node and starts evaluating them.
* The list of subqueries are added to [[subqueryResults]].
*/
protected def prepareSubqueries(): Unit = {
val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
val futureResult = Future {
// We only need the first row, try to take two rows so we can throw an exception if there
// are more than one rows returned.
// Each subquery should return only one row (and one column). We take two here and throws
// an exception later if the number of rows is greater than one.
e.executedPlan.executeTake(2)
}(SparkPlan.subqueryExecutionContext)
queryResults += e -> futureResult
subqueryResults += e -> futureResult
}
}

/**
* Waits for all the subqueries to finish and updates the results.
* Blocks the thread until all subqueries finish evaluation and update the results.
*/
protected def waitForSubqueries(): Unit = {
// fill in the result of subqueries
queryResults.foreach {
case (e, futureResult) =>
val rows = Await.result(futureResult, Duration.Inf)
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column")
e.updateResult(rows(0).get(0, e.dataType))
} else {
// There is no rows returned, the result should be null.
e.updateResult(null)
}
subqueryResults.foreach { case (e, futureResult) =>
val rows = Await.result(futureResult, Duration.Inf)
if (rows.length > 1) {
sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
}
if (rows.length == 1) {
assert(rows(0).numFields == 1,
s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
e.updateResult(rows(0).get(0, e.dataType))
} else {
// If there is no rows returned, the result should be null.
e.updateResult(null)
}
}
queryResults.clear()
subqueryResults.clear()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ case class ScalarSubquery(
/**
* Convert the subquery from logical plan into executed plan.
*/
private[sql] case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
plan.transformAllExpressions {
case subquery: expressions.ScalarSubquery =>
Expand Down
33 changes: 18 additions & 15 deletions sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,36 +36,39 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
sql("select (select (select 1) + 1) + 1").collect()
}

// more than one columns
val error = intercept[AnalysisException] {
sql("select (select 1, 2) as b").collect()
}
assert(error.message contains "Scalar subquery must return only one column, but got 2")

// more than one rows
val error2 = intercept[RuntimeException] {
sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
}
assert(error2.getMessage contains
"more than one row returned by a subquery used as an expression")

// string type
assertResult(Array(Row("s"))) {
sql("select (select 's' as s) as b").collect()
}
}

// zero rows
test("uncorrelated scalar subquery should return null if there is 0 rows") {
assertResult(Array(Row(null))) {
sql("select (select 's' as s limit 0) as b").collect()
}
}

test("analysis error when the number of columns is not 1") {
val error = intercept[AnalysisException] {
sql("select (select 1, 2) as b").collect()
}
assert(error.message.contains("Scalar subquery must return only one column, but got 2"))
}

test("runtime error when the number of rows is greater than 1") {
val error2 = intercept[RuntimeException] {
sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
}
assert(error2.getMessage.contains(
"more than one row returned by a subquery used as an expression"))
}

test("uncorrelated scalar subquery on testData") {
// initialize test Data
testData

assertResult(Array(Row(5))) {
sql("select (select key from testData where key > 3 limit 1) + 1").collect()
sql("select (select key from testData where key > 3 order by key limit 1) + 1").collect()
}

assertResult(Array(Row(-100))) {
Expand Down