diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 37fcc93c62fa8..b30bd74812b36 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark
 import java.io.File
 import java.net.{MalformedURLException, URI}
 import java.nio.charset.StandardCharsets
-import java.util.concurrent.TimeUnit
+import java.util.concurrent.{Semaphore, TimeUnit}
 
 import scala.concurrent.duration._
 
@@ -499,6 +499,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
   test("Cancelling stages/jobs with custom reasons.") {
     sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
     val REASON = "You shall not pass"
+    val slices = 10
 
     val listener = new SparkListener {
       override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
@@ -508,6 +509,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
           }
           sc.cancelStage(taskStart.stageId, REASON)
           SparkContextSuite.cancelStage = false
+          SparkContextSuite.semaphore.release(slices)
         }
       }
 
@@ -518,21 +520,25 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
           }
           sc.cancelJob(jobStart.jobId, REASON)
           SparkContextSuite.cancelJob = false
+          SparkContextSuite.semaphore.release(slices)
         }
       }
     }
     sc.addSparkListener(listener)
 
     for (cancelWhat <- Seq("stage", "job")) {
+      SparkContextSuite.semaphore.drainPermits()
       SparkContextSuite.isTaskStarted = false
       SparkContextSuite.cancelStage = (cancelWhat == "stage")
       SparkContextSuite.cancelJob = (cancelWhat == "job")
 
       val ex = intercept[SparkException] {
-        sc.range(0, 10000L).mapPartitions { x =>
-          org.apache.spark.SparkContextSuite.isTaskStarted = true
+        sc.range(0, 10000L, numSlices = slices).mapPartitions { x =>
+          SparkContextSuite.isTaskStarted = true
+          // Block waiting for the listener to cancel the stage or job.
+          SparkContextSuite.semaphore.acquire()
           x
-        }.cartesian(sc.range(0, 10L))count()
+        }.count()
       }
 
       ex.getCause() match {
@@ -636,4 +642,5 @@ object SparkContextSuite {
   @volatile var isTaskStarted = false
   @volatile var taskKilled = false
   @volatile var taskSucceeded = false
+  val semaphore = new Semaphore(0)
 }