Skip to content

Commit 57987d1

Browse files
committed
Broadcast on executors.
1 parent 8f0c35a commit 57987d1

15 files changed

Lines changed: 318 additions & 54 deletions

File tree

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
4444
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
4545

4646
import org.apache.spark.annotation.DeveloperApi
47-
import org.apache.spark.broadcast.Broadcast
47+
import org.apache.spark.broadcast.{Broadcast, BroadcastMode}
4848
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
4949
import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat,
5050
WholeTextFileInputFormat}
@@ -1401,6 +1401,21 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
14011401
bc
14021402
}
14031403

1404+
/**
1405+
* Broadcast a read-only variable to the cluster, returning a
1406+
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
1407+
* The variable will be sent to each cluster only once.
1408+
*/
1409+
def broadcastRDDOnExecutor[T: ClassTag, U: ClassTag](
1410+
rdd: RDD[T], mode: BroadcastMode[T]): Broadcast[U] = {
1411+
assertNotStopped()
1412+
val bc = env.broadcastManager.newBroadcastOnExecutor[T, U](rdd, mode, isLocal)
1413+
val callSite = getCallSite
1414+
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
1415+
cleaner.foreach(_.registerBroadcastForCleanup(bc))
1416+
bc
1417+
}
1418+
14041419
/**
14051420
* Add a file to be downloaded with this Spark job on every node.
14061421
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported

core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ package org.apache.spark.broadcast
1919

2020
import scala.reflect.ClassTag
2121

22-
import org.apache.spark.SecurityManager
23-
import org.apache.spark.SparkConf
22+
import org.apache.spark.{SecurityManager, SparkConf}
23+
import org.apache.spark.rdd.RDD
2424

2525
/**
2626
* An interface for all the broadcast implementations in Spark (to allow
@@ -41,6 +41,21 @@ private[spark] trait BroadcastFactory {
4141
*/
4242
def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
4343

44+
/**
45+
* Creates a new broadcast variable which is broadcasted on executors without collecting first
46+
* to the driver.
47+
*
48+
* @param rdd the RDD to be broadcasted among executors
49+
* @param mode the broadcast mode used to transform the result of RDD to broadcasted object
50+
* @param isLocal whether we are in local mode (single JVM process)
51+
* @param id unique id representing this broadcast variable
52+
*/
53+
def newBroadcastOnExecutor[T: ClassTag, U: ClassTag](
54+
rdd: RDD[T],
55+
mode: BroadcastMode[T],
56+
isLocal: Boolean,
57+
id: Long): Broadcast[U]
58+
4459
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
4560

4661
def stop(): Unit

core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
2323

2424
import org.apache.spark.{SecurityManager, SparkConf}
2525
import org.apache.spark.internal.Logging
26+
import org.apache.spark.rdd.RDD
2627

2728
private[spark] class BroadcastManager(
2829
val isDriver: Boolean,
@@ -56,6 +57,14 @@ private[spark] class BroadcastManager(
5657
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
5758
}
5859

60+
def newBroadcastOnExecutor[T: ClassTag, U: ClassTag](
61+
rdd_ : RDD[T],
62+
mode: BroadcastMode[T],
63+
isLocal: Boolean): Broadcast[U] = {
64+
broadcastFactory.newBroadcastOnExecutor[T, U](rdd_, mode, isLocal,
65+
nextBroadcastId.getAndIncrement())
66+
}
67+
5968
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
6069
broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
6170
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.broadcast
19+
20+
/**
 * Marker trait identifying the shape in which tuples are broadcasted. Typical examples of this
 * are identity (tuples remain unchanged) or hashed (tuples are converted into some hash index).
 *
 * @tparam T the element type this mode operates on
 */
trait BroadcastMode[T] {

  /** Transforms the collected rows into the object that will actually be broadcast. */
  def transform(rows: Array[T]): Any

  /**
   * Returns true iff this [[BroadcastMode]] generates the same result as `other`.
   */
  def compatibleWith(other: BroadcastMode[T]): Boolean
}

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,12 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
177177
val blockManager = SparkEnv.get.blockManager
178178
blockManager.getLocalValues(broadcastId).map(_.data.next()) match {
179179
case Some(x) =>
180+
// Found broadcasted value in local [[BlockManager]]. Use it directly.
180181
releaseLock(broadcastId)
181182
x.asInstanceOf[T]
182183

183184
case None =>
185+
// Not found. Going to fetch the chunks of the broadcasted value from driver/executors.
184186
logInfo("Started reading broadcast variable " + id)
185187
val startTimeMs = System.currentTimeMillis()
186188
val blocks = readBlocks().flatMap(_.getChunks())

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ package org.apache.spark.broadcast
2020
import scala.reflect.ClassTag
2121

2222
import org.apache.spark.{SecurityManager, SparkConf}
23+
import org.apache.spark.rdd.RDD
2324

2425
/**
25-
* A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
26+
* A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a BitTorrent-like
2627
* protocol to do a distributed transfer of the broadcasted data to the executors. Refer to
2728
* [[org.apache.spark.broadcast.TorrentBroadcast]] for more details.
2829
*/
@@ -34,6 +35,13 @@ private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
3435
new TorrentBroadcast[T](value_, id)
3536
}
3637

38+
override def newBroadcastOnExecutor[T: ClassTag, U: ClassTag](
39+
rdd: RDD[T],
40+
mode: BroadcastMode[T],
41+
isLocal: Boolean, id: Long): Broadcast[U] = {
42+
new TorrentExecutorBroadcast[T, U](rdd.getNumPartitions, rdd.id, mode, id)
43+
}
44+
3745
override def stop() { }
3846

3947
/**
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.broadcast
19+
20+
import java.io.ObjectOutputStream
21+
22+
import scala.collection.mutable.ArrayBuffer
23+
import scala.reflect.ClassTag
24+
import scala.util.Random
25+
26+
import org.apache.spark._
27+
import org.apache.spark.internal.Logging
28+
import org.apache.spark.storage.{BlockId, BlockResult, BroadcastBlockId, RDDBlockId, StorageLevel}
29+
import org.apache.spark.util.Utils
30+
31+
/**
 * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
 *
 * Different to [[TorrentBroadcast]], this implementation doesn't divide the object to broadcast.
 * In contrast, this implementation performs the broadcast on the executor side for an RDD, so
 * the result of the RDD does not need to be collected back to the driver before broadcasting.
 *
 * The mechanism is as follows:
 *
 * On each executor, the executor first attempts to fetch the object from its BlockManager. If
 * it does not exist, it then uses remote fetches to fetch the blocks of the RDD from other
 * executors if available. Once it gets the blocks, it puts the blocks in its own BlockManager,
 * ready for other executors to fetch from.
 *
 * @tparam T The type of the element of RDD to be broadcasted.
 * @tparam U The type of object transformed from the collection of elements of the RDD.
 *
 * @param numBlocks Total number of blocks this broadcast variable contains (one per RDD
 *                  partition).
 * @param rddId The id of the RDD to be broadcasted on executors.
 * @param mode The [[BroadcastMode]] object used to transform the result of RDD to the object
 *             which will be stored in the BlockManager.
 * @param id A unique identifier for the broadcast variable.
 */
private[spark] class TorrentExecutorBroadcast[T: ClassTag, U: ClassTag](
    numBlocks: Int,
    rddId: Int,
    mode: BroadcastMode[T],
    id: Long) extends Broadcast[U](id) with Logging with Serializable {

  /**
   * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
   * which builds this value by reading blocks from other executors. Lazy and transient so the
   * fetch happens at most once per JVM, on first access after deserialization.
   */
  @transient private lazy val _value: U = readBroadcastBlock()

  // BlockManager key under which the fully-assembled value is cached on each executor.
  private val broadcastId = BroadcastBlockId(id)

  override protected def getValue() = {
    _value
  }

  /**
   * Fetch torrent blocks from other executors.
   *
   * Each block is the computed output of one RDD partition; blocks fetched remotely are
   * re-stored locally (and reported to the master) so this executor becomes a source for others.
   */
  private def readBlocks(): Array[T] = {
    // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and
    // reported to the driver, so other executors can pull these chunks from this executor
    // as well.
    val blocks = new Array[Array[T]](numBlocks)
    val bm = SparkEnv.get.blockManager

    // Randomized fetch order spreads load across executors holding different blocks.
    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
      val pieceId = RDDBlockId(rddId, pid)
      // First try getLocalValues because there is a chance that previous attempts to fetch the
      // broadcast blocks have already fetched some of the blocks. In that case, some blocks
      // would be available locally (on this executor).
      bm.getLocalValues(pieceId) match {
        case Some(block: BlockResult) =>
          // NOTE(review): the read lock acquired by getLocalValues is not explicitly released
          // on this path — confirm it is freed at task completion.
          blocks(pid) = block.data.asInstanceOf[Iterator[T]].toArray
        case None =>
          bm.get[T](pieceId) match {
            case Some(b) =>
              val data = b.data.asInstanceOf[Iterator[T]].toArray
              // We found the block from remote executors' BlockManager, so put the block
              // in this executor's BlockManager.
              if (!bm.putIterator(pieceId, data.toIterator,
                  StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
                throw new SparkException(
                  s"Failed to store $pieceId of $broadcastId in local BlockManager")
              }
              blocks(pid) = data
            case None =>
              throw new SparkException(s"Failed to get $pieceId of $broadcastId")
          }
      }
    }
    // Concatenate the per-partition arrays in partition order.
    blocks.flatMap(x => x)
  }

  /**
   * Remove all persisted state associated with this Torrent broadcast on the executors.
   */
  override protected def doUnpersist(blocking: Boolean) {
    TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
  }

  /**
   * Remove all persisted state associated with this Torrent broadcast on the executors
   * and driver.
   */
  override protected def doDestroy(blocking: Boolean) {
    TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
  }

  /** Used by the JVM when serializing this object. Fails fast if already destroyed. */
  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
    assertValid()
    out.defaultWriteObject()
  }

  /**
   * Assembles the broadcast value: returns the locally cached copy if present, otherwise
   * fetches all RDD blocks, applies `mode.transform`, and caches the result locally.
   * Synchronized on [[TorrentBroadcast]] so only one thread per JVM performs the fetch.
   */
  private def readBroadcastBlock(): U = Utils.tryOrIOException {
    TorrentBroadcast.synchronized {
      val blockManager = SparkEnv.get.blockManager
      blockManager.getLocalValues(broadcastId).map(_.data.next()) match {
        case Some(x) =>
          // Found broadcasted value in local BlockManager. Use it directly.
          releaseLock(broadcastId)
          x.asInstanceOf[U]

        case None =>
          // Not found. Going to fetch the chunks of the broadcasted value from executors.
          logInfo("Started reading broadcast variable " + id)
          val startTimeMs = System.currentTimeMillis()
          val rawInput = readBlocks()
          logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))

          val obj = mode.transform(rawInput.toArray).asInstanceOf[U]
          // Store the merged copy in BlockManager so other tasks on this executor don't
          // need to re-fetch it.
          val storageLevel = StorageLevel.MEMORY_AND_DISK
          if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
            throw new SparkException(s"Failed to store $broadcastId in BlockManager")
          }
          obj
      }
    }
  }

  /**
   * If running in a task, register the given block's locks for release upon task completion.
   * Otherwise, if not running in a task then immediately release the lock.
   */
  private def releaseLock(blockId: BlockId): Unit = {
    val blockManager = SparkEnv.get.blockManager
    Option(TaskContext.get()) match {
      case Some(taskContext) =>
        taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId))
      case None =>
        // This should only happen on the driver, where broadcast variables may be accessed
        // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow
        // broadcast variables to be garbage collected we need to free the reference here
        // which is slightly unsafe but is technically okay because broadcast variables aren't
        // stored off-heap.
        blockManager.releaseLock(blockId)
    }
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,17 @@
1717

1818
package org.apache.spark.sql.catalyst.plans.physical
1919

20+
import org.apache.spark.broadcast.BroadcastMode
2021
import org.apache.spark.sql.catalyst.InternalRow
2122

22-
/**
23-
* Marker trait to identify the shape in which tuples are broadcasted. Typical examples of this are
24-
* identity (tuples remain unchanged) or hashed (tuples are converted into some hash index).
25-
*/
26-
trait BroadcastMode {
27-
def transform(rows: Array[InternalRow]): Any
28-
29-
/**
30-
* Returns true iff this [[BroadcastMode]] generates the same result as `other`.
31-
*/
32-
def compatibleWith(other: BroadcastMode): Boolean
33-
}
34-
3523
/**
 * IdentityBroadcastMode requires that rows are broadcasted in their original form.
 */
case object IdentityBroadcastMode extends BroadcastMode[InternalRow] {
  // TODO: pack the UnsafeRows into single bytes array.
  override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows

  // Identity carries no configuration, so only the very same instance is compatible.
  override def compatibleWith(other: BroadcastMode[InternalRow]): Boolean = this eq other
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
package org.apache.spark.sql.catalyst.plans.physical
1919

20+
import org.apache.spark.broadcast.BroadcastMode
2021
import org.apache.spark.sql.catalyst.expressions._
22+
import org.apache.spark.sql.catalyst.InternalRow
2123
import org.apache.spark.sql.types.{DataType, IntegerType}
2224

2325
/**
@@ -79,7 +81,7 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
7981
* Represents data where tuples are broadcasted to every node. It is quite common that the
8082
* entire set of tuples is transformed into different data structure.
8183
*/
82-
case class BroadcastDistribution(mode: BroadcastMode) extends Distribution
84+
// Distribution in which every node receives the entire set of tuples, transformed by `mode`.
case class BroadcastDistribution(mode: BroadcastMode[InternalRow]) extends Distribution
8385

8486
/**
8587
* Describes how an operator's output is split across partitions. The `compatibleWith`,
@@ -365,7 +367,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
365367
* Represents a partitioning where rows are collected, transformed and broadcasted to each
366368
* node in the cluster.
367369
*/
368-
case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
370+
case class BroadcastPartitioning(mode: BroadcastMode[InternalRow]) extends Partitioning {
369371
override val numPartitions: Int = 1
370372

371373
override def satisfies(required: Distribution): Boolean = required match {

0 commit comments

Comments
 (0)