Skip to content

Commit cc152fc

Browse files
committed
Don't cache the RDD broadcast variable.
1 parent d256b45 commit cc152fc

6 files changed

Lines changed: 13 additions & 28 deletions

File tree

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1225,7 +1225,7 @@ abstract class RDD[T: ClassTag](
12251225
* might modify state of objects referenced in their closures. This is necessary in Hadoop
12261226
* where the JobConf/Configuration object is not thread-safe.
12271227
*/
1228-
@transient private[spark] lazy val broadcasted: Broadcast[Array[Byte]] = {
1228+
@transient private[spark] def createBroadcastBinary(): Broadcast[Array[Byte]] = synchronized {
12291229
val ser = SparkEnv.get.closureSerializer.newInstance()
12301230
val bytes = ser.serialize(this).array()
12311231
val size = Utils.bytesToString(bytes.length)

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -694,18 +694,21 @@ class DAGScheduler(
694694
// Get our pending tasks and remember them in our pendingTasks entry
695695
stage.pendingTasks.clear()
696696
var tasks = ArrayBuffer[Task[_]]()
697+
val broadcastRddBinary = stage.rdd.createBroadcastBinary()
697698
if (stage.isShuffleMap) {
698699
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
699700
val locs = getPreferredLocs(stage.rdd, p)
700-
tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
701+
val part = stage.rdd.partitions(p)
702+
tasks += new ShuffleMapTask(stage.id, broadcastRddBinary, stage.shuffleDep.get, part, locs)
701703
}
702704
} else {
703705
// This is a final stage; figure out its job's missing partitions
704706
val job = stage.resultOfJob.get
705707
for (id <- 0 until job.numPartitions if !job.finished(id)) {
706-
val partition = job.partitions(id)
707-
val locs = getPreferredLocs(stage.rdd, partition)
708-
tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
708+
val p: Int = job.partitions(id)
709+
val part = stage.rdd.partitions(p)
710+
val locs = getPreferredLocs(stage.rdd, p)
711+
tasks += new ResultTask(stage.id, broadcastRddBinary, job.func, part, locs, id)
709712
}
710713
}
711714

core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,6 @@ private[spark] class ResultTask[T, U](
5050
// TODO: Should we also broadcast func? For that we would need a place to
5151
// keep a reference to it (perhaps in DAGScheduler's job object).
5252

53-
def this(
54-
stageId: Int,
55-
rdd: RDD[T],
56-
func: (TaskContext, Iterator[T]) => U,
57-
partitionId: Int,
58-
locs: Seq[TaskLocation],
59-
outputId: Int) = {
60-
this(stageId, rdd.broadcasted, func, rdd.partitions(partitionId), locs, outputId)
61-
}
62-
6353
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
6454
if (locs == null) Nil else locs.toSet.toSeq
6555
}

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,6 @@ private[spark] class ShuffleMapTask(
4747
// TODO: Should we also broadcast the ShuffleDependency? For that we would need a place to
4848
// keep a reference to it (perhaps in Stage).
4949

50-
def this(
51-
stageId: Int,
52-
rdd: RDD[_],
53-
dep: ShuffleDependency[_, _, _],
54-
partitionId: Int,
55-
locs: Seq[TaskLocation]) = {
56-
this(stageId, rdd.broadcasted, dep, rdd.partitions(partitionId), locs)
57-
}
58-
5950
/** A constructor used only in test suites. This does not require passing in an RDD. */
6051
def this(partitionId: Int) {
6152
this(0, null, null, new Partition { override def index = 0 }, null)

core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
146146

147147
// Test that GC causes broadcast task data cleanup after dereferencing the RDD.
148148
val postGCTester = new CleanerTester(sc,
149-
broadcastIds = Seq(rdd.broadcasted.id, rdd.firstParent.broadcasted.id))
149+
broadcastIds = Seq(rdd.createBroadcastBinary.id, rdd.firstParent.createBroadcastBinary.id))
150150
rdd = null
151151
runGC()
152152
postGCTester.assertCleanup()

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,14 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
3838
sys.error("failed")
3939
}
4040
}
41-
val func = (c: TaskContext, i: Iterator[String]) => i.next
42-
val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0)
41+
val func = (c: TaskContext, i: Iterator[String]) => i.next()
42+
val task = new ResultTask[String, String](
43+
0, rdd.createBroadcastBinary(), func, rdd.partitions(0), Seq(), 0)
4344
intercept[RuntimeException] {
4445
task.run(0)
4546
}
4647
assert(completed === true)
4748
}
4849

49-
case class StubPartition(val index: Int) extends Partition
50+
case class StubPartition(index: Int) extends Partition
5051
}

0 commit comments

Comments (0)