
Commit 1ad1abd

reworked ComplexFutureAction and AsyncRDDActions.takeAsync to be non-blocking
1 parent 4210aa6 commit 1ad1abd

File tree

2 files changed: +43 -68 lines changed


core/src/main/scala/org/apache/spark/FutureAction.scala

Lines changed: 19 additions & 51 deletions
@@ -22,11 +22,11 @@ import java.util.concurrent.TimeUnit
 
 import org.apache.spark.api.java.JavaFutureAction
 import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobWaiter
 
 import scala.concurrent._
 import scala.concurrent.duration.Duration
-import scala.util.{Failure, Try}
+import scala.util.Try
 
 /**
  * A future for the result of an action to support cancellation. This is an extension of the
@@ -148,44 +148,25 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
  */
 class ComplexFutureAction[T] extends FutureAction[T] {
 
-  // Pointer to the thread that is executing the action. It is set when the action is run.
-  @volatile private var thread: Thread = _
+  @volatile private var _cancelled = false
 
-  // A flag indicating whether the future has been cancelled. This is used in case the future
-  // is cancelled before the action was even run (and thus we have no thread to interrupt).
-  @volatile private var _cancelled: Boolean = false
-
-  @volatile private var jobs: Seq[Int] = Nil
+  @volatile private var subActions: List[FutureAction[_]] = Nil
 
   // A promise used to signal the future.
-  private val p = promise[T]()
+  private val p = Promise[T]()
 
-  override def cancel(): Unit = this.synchronized {
+  override def cancel(): Unit = synchronized {
     _cancelled = true
-    if (thread != null) {
-      thread.interrupt()
-    }
+    p.tryFailure(new SparkException("Action has been cancelled"))
+    subActions foreach {_.cancel()}
   }
 
   /**
    * Executes some action enclosed in the closure. To properly enable cancellation, the closure
    * should use runJob implementation in this promise. See takeAsync for example.
    */
-  def run(func: => T)(implicit executor: ExecutionContext): this.type = {
-    scala.concurrent.future {
-      thread = Thread.currentThread
-      try {
-        p.success(func)
-      } catch {
-        case e: Exception => p.failure(e)
-      } finally {
-        // This lock guarantees when calling `thread.interrupt()` in `cancel`,
-        // thread won't be set to null.
-        ComplexFutureAction.this.synchronized {
-          thread = null
-        }
-      }
-    }
+  def run(func: => Future[T])(implicit executor: ExecutionContext): this.type = {
+    p tryCompleteWith func
    this
  }
 
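The rework replaces thread interruption with promise semantics: cancel() and normal completion now race to complete the same Promise, and tryFailure / tryCompleteWith make whichever side arrives second a harmless no-op. A minimal, Spark-free sketch of that pattern (CancellableResult and its exception message are illustrative, not part of this patch):

import scala.concurrent.{Future, Promise}

// Sketch: completion and cancellation race on a single Promise, so cancel()
// never needs a thread to interrupt and run() never parks a thread.
class CancellableResult[T] {
  private val p = Promise[T]()

  def run(work: Future[T]): Future[T] = {
    p.tryCompleteWith(work) // first completion wins; later attempts are no-ops
    p.future
  }

  def cancel(): Unit =
    p.tryFailure(new RuntimeException("cancelled")) // no-op if already completed
}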
@@ -198,28 +179,15 @@ class ComplexFutureAction[T] extends FutureAction[T] {
       processPartition: Iterator[T] => U,
       partitions: Seq[Int],
       resultHandler: (Int, U) => Unit,
-      resultFunc: => R) {
+      resultFunc: => R)(implicit executor: ExecutionContext) = synchronized {
     // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
     // command need to be in an atomic block.
-    val job = this.synchronized {
-      if (!isCancelled) {
-        rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
-      } else {
-        throw new SparkException("Action has been cancelled")
-      }
-    }
-
-    this.jobs = jobs ++ job.jobIds
-
-    // Wait for the job to complete. If the action is cancelled (with an interrupt),
-    // cancel the job and stop the execution. This is not in a synchronized block because
-    // Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
-    try {
-      Await.ready(job, Duration.Inf)
-    } catch {
-      case e: InterruptedException =>
-        job.cancel()
-        throw new SparkException("Action has been cancelled")
+    if (!isCancelled) {
+      val job = rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
+      subActions = job::subActions
+      job
+    } else {
+      throw new SparkException("Action has been cancelled")
     }
   }
 
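runJob changes shape in the same spirit: instead of submitting and then parking the calling thread in Await.ready, it now performs the cancellation check, the submission, and the subActions bookkeeping in one synchronized block, and hands the resulting future straight back to the caller. A Spark-free sketch of that submit-and-record shape (SubmissionTracker and its names are illustrative only):

import scala.concurrent.Future

// Sketch: check cancelled, submit, and record atomically; return the Future
// instead of awaiting it, so a later cancel pass can reach every sub-action.
class SubmissionTracker {
  @volatile private var cancelled = false
  @volatile private var submitted: List[Future[_]] = Nil

  def submit[T](start: () => Future[T]): Future[T] = synchronized {
    if (cancelled) throw new IllegalStateException("cancelled")
    val f = start()            // kick off the work; do not wait for it
    submitted = f :: submitted // remember it for cancellation bookkeeping
    f
  }

  def cancel(): Unit = synchronized { cancelled = true }
}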
@@ -245,7 +213,7 @@ class ComplexFutureAction[T] extends FutureAction[T] {
 
   override def value: Option[Try[T]] = p.future.value
 
-  def jobIds: Seq[Int] = jobs
+  def jobIds: Seq[Int] = subActions flatMap {_.jobIds}
 
 }
 
@@ -272,7 +240,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S
     Await.ready(futureAction, timeout)
     futureAction.value.get match {
       case scala.util.Success(value) => converter(value)
-      case Failure(exception) =>
+      case scala.util.Failure(exception) =>
         if (isCancelled) {
           throw new CancellationException("Job cancelled").initCause(exception)
         } else {

core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala

Lines changed: 24 additions & 17 deletions
@@ -22,10 +22,10 @@ import java.util.concurrent.atomic.AtomicLong
 import org.apache.spark.util.ThreadUtils
 
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
+import scala.concurrent.{Future, ExecutionContext}
 import scala.reflect.ClassTag
 
-import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
+import org.apache.spark.{SimpleFutureAction, ComplexFutureAction, FutureAction, Logging}
 
 /**
  * A set of asynchronous RDD actions available through an implicit conversion.
@@ -66,14 +66,22 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
    */
   def takeAsync(num: Int): FutureAction[Seq[T]] = self.withScope {
     val f = new ComplexFutureAction[Seq[T]]
-
-    f.run {
-      // This is a blocking action so we should use "AsyncRDDActions.futureExecutionContext" which
-      // is a cached thread pool.
-      val results = new ArrayBuffer[T](num)
-      val totalParts = self.partitions.length
-      var partsScanned = 0
-      while (results.size < num && partsScanned < totalParts) {
+    // Cached thread pool to handle aggregation of subtasks.
+    implicit val executionContext = AsyncRDDActions.futureExecutionContext
+    val results = new ArrayBuffer[T](num)
+    val totalParts = self.partitions.length
+
+    /*
+      Recursively triggers jobs to scan partitions until either the requested
+      number of elements are retrieved, or the partitions to scan are exhausted.
+      This implementation is non-blocking, asynchronously handling the
+      results of each job and triggering the next job using callbacks on futures.
+    */
+    def continue(partsScanned : Int) : Future[Seq[T]] =
+      if (results.size >= num || partsScanned >= totalParts) {
+        Future.successful(results.toSeq)
+      }
+      else {
       // The number of partitions to try in this iteration. It is ok for this number to be
       // greater than totalParts because we actually cap it at totalParts in runJob.
       var numPartsToTry = 1
@@ -95,19 +103,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
         val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
 
         val buf = new Array[Array[T]](p.size)
-        f.runJob(self,
+        val job = f.runJob(self,
           (it: Iterator[T]) => it.take(left).toArray,
           p,
           (index: Int, data: Array[T]) => buf(index) = data,
           Unit)
-
-        buf.foreach(results ++= _.take(num - results.size))
-        partsScanned += numPartsToTry
+        job flatMap {case _ =>
+          buf.foreach(results ++= _.take(num - results.size))
+          continue(partsScanned + numPartsToTry)
+        }
       }
-      results.toSeq
-    }(AsyncRDDActions.futureExecutionContext)
 
-    f
+    f.run {continue(0)}
   }
 
   /**
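The net effect in takeAsync is that the old while loop over partitions becomes a recursive chain of futures: each completed job's callback decides whether enough elements have been collected and, if not, triggers the next job via flatMap. A Spark-free sketch of that loop, where takeNonBlocking and fetchBatch are illustrative stand-ins for takeAsync and f.runJob:

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

// Sketch: an asynchronous loop built from flatMap callbacks. No thread waits
// between batches; the next step runs only when the previous future completes.
def takeNonBlocking[T](num: Int, totalParts: Int)
                      (fetchBatch: Int => Future[Seq[T]]): Future[Seq[T]] = {
  val results = new ArrayBuffer[T](num)

  def continue(partsScanned: Int): Future[Seq[T]] =
    if (results.size >= num || partsScanned >= totalParts) {
      Future.successful(results.toSeq) // base case: done without ever blocking
    } else {
      fetchBatch(partsScanned).flatMap { batch => // callback, not Await.ready
        results ++= batch.take(num - results.size)
        continue(partsScanned + 1)
      }
    }

  continue(0)
}

Because each step is attached as a callback on the previous future, the "recursion" unwinds through the completed-future chain rather than the call stack, and no pool thread is held while a job runs.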
