Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.Map
import scala.collection.immutable.Set
import scala.collection.mutable.{HashMap, HashSet, Stack}
import scala.concurrent.duration._
import scala.language.existentials
Expand Down Expand Up @@ -283,7 +284,9 @@ class DAGScheduler(
case None =>
// We are going to register ancestor shuffle dependencies
getAncestorShuffleDependencies(shuffleDep.rdd).foreach { dep =>
shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId)
if (!shuffleToMapStage.contains(dep.shuffleId)) {
shuffleToMapStage(dep.shuffleId) = newOrUsedShuffleStage(dep, firstJobId)
}
}
// Then register current shuffleDep
val stage = newOrUsedShuffleStage(shuffleDep, firstJobId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1901,6 +1901,26 @@ class DAGSchedulerSuite
assertDataStructuresEmpty()
}

test("Eliminate creating duplicate stage") {
val rdd2 = new MyRDD(sc, 2, Nil)
val rdd1 = new MyRDD(sc, 2, Nil)
val dep1 = new ShuffleDependency(rdd1, new HashPartitioner(1))
val dep2 = new ShuffleDependency(rdd2, new HashPartitioner(1))
val rdd3 = new MyRDD(sc, 1, List(dep1, dep2), tracker = mapOutputTracker)
val dep3 = new ShuffleDependency(rdd3, new HashPartitioner(2))
val dep4 = new ShuffleDependency(rdd3, new HashPartitioner(2))
val rdd4 = new MyRDD(sc, 2, List(dep3), tracker = mapOutputTracker)
val rdd5 = new MyRDD(sc, 2, List(dep4), tracker = mapOutputTracker)
val dep5 = new ShuffleDependency(rdd4, new HashPartitioner(1))
val dep6 = new ShuffleDependency(rdd5, new HashPartitioner(1))
val rdd6 = new MyRDD(sc, 1, List(dep5, dep6), tracker = mapOutputTracker)
val dep7 = new ShuffleDependency(rdd6, new HashPartitioner(2))
val rdd7 = new MyRDD(sc, 2, List(dep7), tracker = mapOutputTracker)
submit(rdd7, Array(0, 1))
assert(scheduler.stageIdToStage.size == 8)
assert(scheduler.shuffleToMapStage.size == 7)
}

/**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
Expand Down