diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index e01a9609b9a0..3f5f2d87c95b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.Map
+import scala.collection.immutable.Set
 import scala.collection.mutable.{HashMap, HashSet, Stack}
 import scala.concurrent.duration._
 import scala.language.existentials
@@ -283,7 +284,9 @@ class DAGScheduler(
       case None =>
         // We are going to register ancestor shuffle dependencies
         getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
-          shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId)
+          if (!shuffleToMapStage.contains(dep.shuffleId)) {
+            shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId)
+          }
         }
         // Then register current shuffleDep
         val stage = newOrUsedShuffleStage(shuffleDep, firstJobId)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 653d41fc053c..3a9bf28b372d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1901,6 +1901,26 @@ class DAGSchedulerSuite
     assertDataStructuresEmpty()
   }
 
+  test("Eliminate creating duplicate stage") {
+    val rdd2 = new MyRDD(sc, 2, Nil)
+    val rdd1 = new MyRDD(sc, 2, Nil)
+    val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(1))
+    val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(1))
+    val rdd3 = new MyRDD(sc, 1, List(dep1, dep2), tracker = mapOutputTracker)
+    val dep3 = new ShuffleDependency(rdd3, new HashPartitioner(2))
+    val dep4 = new ShuffleDependency(rdd3, new HashPartitioner(2))
+    val rdd4 = new MyRDD(sc, 2, List(dep3), tracker = mapOutputTracker)
+    val rdd5 = new MyRDD(sc, 2, List(dep4), tracker = mapOutputTracker)
+    val dep5 = new ShuffleDependency(rdd4, new HashPartitioner(1))
+    val dep6 = new ShuffleDependency(rdd5, new HashPartitioner(1))
+    val rdd6 = new MyRDD(sc, 1, List(dep5, dep6), tracker = mapOutputTracker)
+    val dep7 = new ShuffleDependency(rdd6, new HashPartitioner(2))
+    val rdd7 = new MyRDD(sc, 2, List(dep7), tracker = mapOutputTracker)
+    submit(rdd7, Array(0, 1))
+    assert(scheduler.stageIdToStage.size == 8)
+    assert(scheduler.shuffleToMapStage.size == 7)
+  }
+
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
    * Note that this checks only the host and not the executor ID.