1717
1818package org .apache .spark .broadcast
1919
20- import java .io .{ByteArrayInputStream , ObjectInputStream , ObjectOutputStream }
20+ import java .io .{ByteArrayOutputStream , ByteArrayInputStream , InputStream ,
21+ ObjectInputStream , ObjectOutputStream , OutputStream }
2122
2223import scala .reflect .ClassTag
2324import scala .util .Random
2425
2526import org .apache .spark .{Logging , SparkConf , SparkEnv , SparkException }
27+ import org .apache .spark .io .CompressionCodec
2628import org .apache .spark .storage .{BroadcastBlockId , StorageLevel }
27- import org .apache .spark .util .Utils
2829
2930/**
3031 * A [[org.apache.spark.broadcast.Broadcast ]] implementation that uses a BitTorrent-like
@@ -214,11 +215,15 @@ private[broadcast] object TorrentBroadcast extends Logging {
214215 private lazy val BLOCK_SIZE = conf.getInt(" spark.broadcast.blockSize" , 4096 ) * 1024
215216 private var initialized = false
216217 private var conf : SparkConf = null
218+ private var compress : Boolean = false
219+ private var compressionCodec : CompressionCodec = null
217220
218221 def initialize (_isDriver : Boolean , conf : SparkConf ) {
219222 TorrentBroadcast .conf = conf // TODO: we might have to fix it in tests
220223 synchronized {
221224 if (! initialized) {
225+ compress = conf.getBoolean(" spark.broadcast.compress" , true )
226+ compressionCodec = CompressionCodec .createCodec(conf)
222227 initialized = true
223228 }
224229 }
@@ -228,8 +233,13 @@ private[broadcast] object TorrentBroadcast extends Logging {
228233 initialized = false
229234 }
230235
231- def blockifyObject [T ](obj : T ): TorrentInfo = {
232- val byteArray = Utils .serialize[T ](obj)
236+ def blockifyObject [T : ClassTag ](obj : T ): TorrentInfo = {
237+ val bos = new ByteArrayOutputStream ()
238+ val out : OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
239+ val ser = SparkEnv .get.serializer.newInstance()
240+ val serOut = ser.serializeStream(out)
241+ serOut.writeObject[T ](obj).close()
242+ val byteArray = bos.toByteArray
233243 val bais = new ByteArrayInputStream (byteArray)
234244
235245 var blockNum = byteArray.length / BLOCK_SIZE
@@ -255,7 +265,7 @@ private[broadcast] object TorrentBroadcast extends Logging {
255265 info
256266 }
257267
258- def unBlockifyObject [T ](
268+ def unBlockifyObject [T : ClassTag ](
259269 arrayOfBlocks : Array [TorrentBlock ],
260270 totalBytes : Int ,
261271 totalBlocks : Int ): T = {
@@ -264,7 +274,16 @@ private[broadcast] object TorrentBroadcast extends Logging {
264274 System .arraycopy(arrayOfBlocks(i).byteArray, 0 , retByteArray,
265275 i * BLOCK_SIZE , arrayOfBlocks(i).byteArray.length)
266276 }
267- Utils .deserialize[T ](retByteArray, Thread .currentThread.getContextClassLoader)
277+
278+ val in : InputStream = {
279+ val arrIn = new ByteArrayInputStream (retByteArray)
280+ if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn
281+ }
282+ val ser = SparkEnv .get.serializer.newInstance()
283+ val serIn = ser.deserializeStream(in)
284+ val obj = serIn.readObject[T ]()
285+ serIn.close()
286+ obj
268287 }
269288
270289 /**
0 commit comments