53 changes: 19 additions & 34 deletions core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -108,63 +108,48 @@ trait FutureAction[T] extends Future[T] {
 class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
   extends FutureAction[T] {
 
+  // Note: `resultFunc` is a closure which may contain references to state that's updated by the
+  // JobWaiter's result handler function. It should only be evaluated once the job has succeeded.
+
   @volatile private var _cancelled: Boolean = false
+  private[this] val jobWaiterFuture: Future[Unit] = jobWaiter.toFuture
+  private[this] lazy val resultFuncOutput: T = {
+    assert(isCompleted, "resultFunc should only be evaluated after the job has completed")
+    resultFunc
+  }
 
   override def cancel() {
     _cancelled = true
     jobWaiter.cancel()
   }
 
   override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
-    if (!atMost.isFinite()) {
-      awaitResult()
-    } else jobWaiter.synchronized {
-      val finishTime = System.currentTimeMillis() + atMost.toMillis
-      while (!isCompleted) {
-        val time = System.currentTimeMillis()
-        if (time >= finishTime) {
-          throw new TimeoutException
-        } else {
-          jobWaiter.wait(finishTime - time)
-        }
-      }
-    }
+    jobWaiterFuture.ready(atMost)(permit) // Throws exception if the job failed.
     this
   }
 
   @throws(classOf[Exception])
   override def result(atMost: Duration)(implicit permit: CanAwait): T = {
-    ready(atMost)(permit)
-    awaitResult() match {
-      case scala.util.Success(res) => res
-      case scala.util.Failure(e) => throw e
-    }
+    jobWaiterFuture.result(atMost)(permit) // Throws exception if the job failed.
+    resultFuncOutput // This function is safe to evaluate because the job must have succeeded.
   }
 
-  override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
-    executor.execute(new Runnable {
-      override def run() {
-        func(awaitResult())
-      }
-    })
+  override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
+    jobWaiterFuture.map { _ => resultFuncOutput }.onComplete(func)
   }
 
   override def isCompleted: Boolean = jobWaiter.jobFinished
 
   override def isCancelled: Boolean = _cancelled
 
   override def value: Option[Try[T]] = {
-    if (jobWaiter.jobFinished) {
-      Some(awaitResult())
-    } else {
+    if (!isCompleted) {
       None
-    }
-  }
-
-  private def awaitResult(): Try[T] = {
-    jobWaiter.awaitResult() match {
-      case JobSucceeded => scala.util.Success(resultFunc)
-      case JobFailed(e: Exception) => scala.util.Failure(e)
+    } else {
+      jobWaiter.awaitResult() match {
Review comment (Contributor): This part seems like a bad hack, using awaitResult to get the result. Rather, there should be a JobWaiter.jobResult (made public) that returns Option[JobResult], and value should use that. (A standalone sketch of this idea follows this file's diff.)

+        case JobSucceeded => Some(scala.util.Success(resultFuncOutput))
+        case JobFailed(e) => Some(scala.util.Failure(e))
+      }
     }
   }

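To make that jobResult suggestion concrete, here is a minimal standalone sketch. All names (WaiterModel, finish, JobResultDemo) are hypothetical stand-ins for Spark's private types, not code from this PR:

import scala.util.{Failure, Success, Try}

// Standalone model: JobResult/JobSucceeded/JobFailed mirror Spark's
// scheduler types; WaiterModel stands in for JobWaiter.
sealed trait JobResult
case object JobSucceeded extends JobResult
case class JobFailed(exception: Exception) extends JobResult

class WaiterModel {
  private var finished = false
  private var result: JobResult = _

  def finish(r: JobResult): Unit = synchronized { result = r; finished = true }

  // The reviewer's suggestion: a public, non-blocking accessor.
  def jobResult: Option[JobResult] = synchronized {
    if (finished) Some(result) else None
  }
}

object JobResultDemo extends App {
  val waiter = new WaiterModel
  // SimpleFutureAction.value could then avoid awaitResult() entirely:
  def value: Option[Try[Int]] = waiter.jobResult.map {
    case JobSucceeded => Success(42) // stand-in for resultFuncOutput
    case JobFailed(e) => Failure(e)
  }
  println(value) // None: the job is still running
  waiter.finish(JobSucceeded)
  println(value) // Some(Success(42))
}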
17 changes: 17 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -17,6 +17,9 @@

 package org.apache.spark.scheduler
 
+import scala.concurrent.{Future, Promise}
+import scala.util.Success
+
 /**
  * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
  * results to the given handler function.
@@ -28,12 +31,18 @@ private[spark] class JobWaiter[T](
     resultHandler: (Int, T) => Unit)
   extends JobListener {
 
+  private val promise = Promise[Unit]
+
Review comment (Contributor): Stepping back: if we are using a promise anyway, why do we need a separate variable called "jobFinished"? The promise is sufficient for keeping the state of whether the job has finished or not:

    val promise = Promise[JobResult]

    def jobFinished: Boolean = promise.isCompleted

The rest of the code would then use jobFinished instead of _jobFinished. (A minimal model of this follows below.)
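As a minimal standalone model of that suggestion, with the promise as the only completion state and jobFinished derived from it (hypothetical names, not this PR's code):

import scala.concurrent.Promise

// Promise-only state: no separate @volatile _jobFinished flag to keep in sync.
class PromiseOnlyWaiter[T] {
  private val promise = Promise[T]()

  // Derived from the promise rather than stored separately.
  def jobFinished: Boolean = promise.isCompleted

  def succeed(value: T): Unit = promise.trySuccess(value)
}

object PromiseOnlyDemo extends App {
  val waiter = new PromiseOnlyWaiter[Unit]
  println(waiter.jobFinished) // false: nothing has completed the promise
  waiter.succeed(())
  println(waiter.jobFinished) // true: completion state lives in the promise
}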
   private var finishedTasks = 0
 
   // Is the job as a whole finished (succeeded or failed)?
   @volatile
   private var _jobFinished = totalTasks == 0
 
+  if (_jobFinished) {
Review comment (Contributor Author): @zsxwing, this if statement fixes a subtle bug that we found in your JobWaiter future: if a job has no tasks, then it is marked as finished immediately and taskSucceeded will never be called, so we need to complete the promise here. We noticed this because a test in AsyncRDDActionsSuite was hanging. (A standalone illustration of the hang appears just after these added lines.)

+    promise.complete(Success(Unit))
+  }
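To illustrate the zero-task case the author describes, here is a standalone model with hypothetical names, where the promise is normally completed only from taskSucceeded:

import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._
import scala.util.Success

// With totalTasks == 0, taskSucceeded() is never invoked, so without the
// eager completion below the future would never finish.
class MiniWaiter(totalTasks: Int, eagerFix: Boolean) {
  private val promise = Promise[Unit]()
  private var finishedTasks = 0

  // The fix under discussion: complete immediately for empty jobs.
  if (eagerFix && totalTasks == 0) promise.complete(Success(()))

  def taskSucceeded(): Unit = synchronized {
    finishedTasks += 1
    if (finishedTasks == totalTasks) promise.trySuccess(())
  }

  def toFuture = promise.future
}

object ZeroTaskDemo extends App {
  // With the fix, an empty job's future completes immediately:
  Await.result(new MiniWaiter(0, eagerFix = true).toFuture, 1.second)
  println("empty job completed")
  // Without it, this line would throw a TimeoutException (the observed hang):
  // Await.result(new MiniWaiter(0, eagerFix = false).toFuture, 1.second)
}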

def jobFinished: Boolean = _jobFinished

// If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero
@@ -58,13 +67,15 @@ private[spark] class JobWaiter[T](
if (finishedTasks == totalTasks) {
_jobFinished = true
jobResult = JobSucceeded
+      promise.trySuccess()
this.notifyAll()
Review comment (Member): This line can be removed, right?
}
}

override def jobFailed(exception: Exception): Unit = synchronized {
_jobFinished = true
jobResult = JobFailed(exception)
+    promise.tryFailure(exception)
this.notifyAll()
Review comment (Member): This line can be removed too.
}

@@ -74,4 +85,10 @@ private[spark] class JobWaiter[T](
}
return jobResult
}

+  /**
+   * Returns a Future for monitoring the job's success or failure. You can use this method to
+   * avoid blocking your thread.
+   */
+  def toFuture: Future[Unit] = promise.future
}
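For orientation, a hedged sketch of the non-blocking pattern this enables; promise.future below plays the role of jobWaiter.toFuture, and nothing here is Spark code:

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.{Future, Promise}
import scala.util.{Failure, Success}

object ToFutureDemo extends App {
  val promise = Promise[Unit]()
  val jobFuture: Future[Unit] = promise.future // stands in for jobWaiter.toFuture

  // Non-blocking: register a callback instead of parking a thread in awaitResult().
  jobFuture.onComplete {
    case Success(_) => println("job succeeded")
    case Failure(e) => println(s"job failed: $e")
  }

  promise.trySuccess(()) // simulate the last task finishing
  Thread.sleep(100) // give the global ExecutionContext time to run the callback
}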
19 changes: 18 additions & 1 deletion core/src/test/scala/org/apache/spark/FutureActionSuite.scala
@@ -17,11 +17,12 @@

package org.apache.spark

-import scala.concurrent.Await
+import scala.concurrent.{ExecutionContext, Await}
import scala.concurrent.duration.Duration

import org.scalatest.{BeforeAndAfter, Matchers}

+import org.apache.spark.util.ThreadUtils

class FutureActionSuite
extends SparkFunSuite
@@ -49,4 +50,20 @@ class FutureActionSuite
job.jobIds.size should be (2)
}

test("simple async action callbacks should not tie up execution context threads (SPARK-9026)") {
val rdd = sc.parallelize(1 to 10, 2).map(_ => Thread.sleep(1000 * 1000))
val pool = ThreadUtils.newDaemonCachedThreadPool("SimpleFutureActionTest")
val executionContext = ExecutionContext.fromExecutorService(pool)
val job = rdd.countAsync()
try {
for (_ <- 1 to 10) {
job.onComplete(_ => ())(executionContext)
assert(pool.getLargestPoolSize < 10)
Review comment (Contributor): This looks flaky. Even though the callbacks are non-blocking, there is NO guarantee that any of the 10 scheduled _ => () functions will finish by the end of this loop. So it may happen that in the 10th iteration the previous 9 scheduled functions have still not finished, the 10th one gets scheduled, and therefore the pool size reaches 10.
Review comment (Contributor Author): My intention when writing this test was to demonstrate the eagerly-create-a-thread-per-callback problem in the old implementation of SimpleFutureAction. I don't think it is flaky, but I also don't think it adds much value, since we're unlikely to ever switch back to the old, inefficient implementation. I'll just drop this test, since it isn't adding any real value right now. (A standalone model of that old behavior appears at the end of this review.)
+      }
+    } finally {
+      job.cancel()
+      executionContext.shutdownNow()
+    }
+  }

}
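To ground the discussion above, a standalone model of the pre-patch onComplete behavior (hypothetical names; the CountDownLatch stands in for the blocking awaitResult call):

import java.util.concurrent.{CountDownLatch, Executors, ThreadPoolExecutor}

// Model of the old SimpleFutureAction.onComplete: each callback is wrapped
// in a Runnable that blocks until the job finishes, so N callbacks pin N
// pool threads for the whole duration of the job.
object EagerCallbackDemo extends App {
  val jobDone = new CountDownLatch(1) // stands in for jobWaiter.awaitResult()
  val pool = Executors.newCachedThreadPool().asInstanceOf[ThreadPoolExecutor]

  // Old behavior: every registered callback immediately submits a blocking Runnable.
  for (i <- 1 to 10) {
    pool.execute(new Runnable {
      override def run(): Unit = {
        jobDone.await() // one pinned thread per registered callback
        println(s"callback $i ran")
      }
    })
  }

  Thread.sleep(200) // let the pool spin up threads
  println(s"largest pool size: ${pool.getLargestPoolSize}") // reaches 10
  jobDone.countDown() // the "job" finishes; all callbacks now fire
  pool.shutdown()
}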