Skip to content

Commit 09dd34c

Browse files
committed
[PYSPARK] Updates to pyspark broadcast
1 parent a2a54a5 commit 09dd34c

10 files changed

Lines changed: 695 additions & 84 deletions

File tree

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 257 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
2424

2525
import scala.collection.JavaConverters._
2626
import scala.collection.mutable
27+
import scala.concurrent.Promise
28+
import scala.concurrent.duration.Duration
2729
import scala.language.existentials
28-
import scala.util.control.NonFatal
30+
import scala.util.Try
2931

3032
import org.apache.hadoop.conf.Configuration
3133
import org.apache.hadoop.io.compress.CompressionCodec
@@ -37,6 +39,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
3739
import org.apache.spark.broadcast.Broadcast
3840
import org.apache.spark.input.PortableDataStream
3941
import org.apache.spark.internal.Logging
42+
import org.apache.spark.network.util.JavaUtils
4043
import org.apache.spark.rdd.RDD
4144
import org.apache.spark.security.SocketAuthHelper
4245
import org.apache.spark.util._
@@ -168,27 +171,34 @@ private[spark] object PythonRDD extends Logging {
168171

169172
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
170173
JavaRDD[Array[Byte]] = {
171-
val file = new DataInputStream(new FileInputStream(filename))
174+
readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism)
175+
}
176+
177+
def readRDDFromInputStream(
178+
sc: SparkContext,
179+
in: InputStream,
180+
parallelism: Int): JavaRDD[Array[Byte]] = {
181+
val din = new DataInputStream(in)
172182
try {
173183
val objs = new mutable.ArrayBuffer[Array[Byte]]
174184
try {
175185
while (true) {
176-
val length = file.readInt()
186+
val length = din.readInt()
177187
val obj = new Array[Byte](length)
178-
file.readFully(obj)
188+
din.readFully(obj)
179189
objs += obj
180190
}
181191
} catch {
182192
case eof: EOFException => // No-op
183193
}
184-
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
194+
JavaRDD.fromRDD(sc.parallelize(objs, parallelism))
185195
} finally {
186-
file.close()
196+
din.close()
187197
}
188198
}
189199

190-
def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = {
191-
sc.broadcast(new PythonBroadcast(path))
200+
def setupBroadcast(path: String): PythonBroadcast = {
201+
new PythonBroadcast(path)
192202
}
193203

194204
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
@@ -398,34 +408,15 @@ private[spark] object PythonRDD extends Logging {
398408
* data collected from this job, and the secret for authentication.
399409
*/
400410
def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
401-
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
402-
// Close the socket if no connection in 15 seconds
403-
serverSocket.setSoTimeout(15000)
404-
405-
new Thread(threadName) {
406-
setDaemon(true)
407-
override def run() {
408-
try {
409-
val sock = serverSocket.accept()
410-
authHelper.authClient(sock)
411-
412-
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
413-
Utils.tryWithSafeFinally {
414-
writeIteratorToStream(items, out)
415-
} {
416-
out.close()
417-
sock.close()
418-
}
419-
} catch {
420-
case NonFatal(e) =>
421-
logError(s"Error while sending iterator", e)
422-
} finally {
423-
serverSocket.close()
424-
}
411+
val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
412+
val out = new DataOutputStream(new BufferedOutputStream(s.getOutputStream()))
413+
Utils.tryWithSafeFinally {
414+
writeIteratorToStream(items, out)
415+
} {
416+
out.close()
425417
}
426-
}.start()
427-
428-
Array(serverSocket.getLocalPort, authHelper.secret)
418+
}
419+
Array(port, secret)
429420
}
430421

431422
private def getMergedConf(confAsMap: java.util.HashMap[String, String],
@@ -643,13 +634,11 @@ private[spark] class PythonAccumulatorV2(
643634
}
644635
}
645636

646-
/**
647-
* A Wrapper for Python Broadcast, which is written into disk by Python. It also will
648-
* write the data into disk after deserialization, then Python can read it from disks.
649-
*/
650637
// scalastyle:off no.finalize
651638
private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
652-
with Logging {
639+
with Logging {
640+
641+
private var encryptionServer: PythonServer[Unit] = null
653642

654643
/**
655644
* Read data from disks, then copy it to `out`
@@ -692,5 +681,233 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
692681
}
693682
super.finalize()
694683
}
684+
685+
def setupEncryptionServer(): Array[Any] = {
686+
encryptionServer = new PythonServer[Unit]("broadcast-encrypt-server") {
687+
override def handleConnection(sock: Socket): Unit = {
688+
val env = SparkEnv.get
689+
val in = sock.getInputStream()
690+
val dir = new File(Utils.getLocalDir(env.conf))
691+
val file = File.createTempFile("broadcast", "", dir)
692+
path = file.getAbsolutePath
693+
val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path))
694+
DechunkedInputStream.dechunkAndCopyToOutput(in, out)
695+
}
696+
}
697+
Array(encryptionServer.port, encryptionServer.secret)
698+
}
699+
700+
def waitTillDataReceived(): Unit = encryptionServer.getResult()
695701
}
696702
// scalastyle:on no.finalize
703+
704+
/**
705+
* The inverse of pyspark's ChunkedStream for sending broadcast data.
706+
* Tested from python tests.
707+
*/
708+
private[spark] class DechunkedInputStream(wrapped: InputStream) extends InputStream with Logging {
709+
private val din = new DataInputStream(wrapped)
710+
private var remainingInChunk = din.readInt()
711+
712+
override def read(): Int = {
713+
val into = new Array[Byte](1)
714+
val n = read(into, 0, 1)
715+
if (n == -1) {
716+
-1
717+
} else {
718+
// if you just cast a byte to an int, then anything > 127 is negative, which is interpreted
719+
// as an EOF
720+
into(0) & 0xFF
721+
}
722+
}
723+
724+
override def read(dest: Array[Byte], off: Int, len: Int): Int = {
725+
if (remainingInChunk == -1) {
726+
return -1
727+
}
728+
var destSpace = len
729+
var destPos = off
730+
while (destSpace > 0 && remainingInChunk != -1) {
731+
val toCopy = math.min(remainingInChunk, destSpace)
732+
val read = din.read(dest, destPos, toCopy)
733+
destPos += read
734+
destSpace -= read
735+
remainingInChunk -= read
736+
if (remainingInChunk == 0) {
737+
remainingInChunk = din.readInt()
738+
}
739+
}
740+
assert(destSpace == 0 || remainingInChunk == -1)
741+
return destPos - off
742+
}
743+
744+
override def close(): Unit = wrapped.close()
745+
}
746+
747+
/**
748+
* The inverse of pyspark's ChunkedStream for sending data of unknown size.
749+
*
750+
* We might be serializing a really large object from python -- we don't want
751+
* python to buffer the whole thing in memory, nor can it write to a file,
752+
* so we don't know the length in advance. So python writes it in chunks, each chunk
753+
* preceeded by a length, till we get a "length" of -1 which serves as EOF.
754+
*
755+
* Tested from python tests.
756+
*/
757+
private[spark] object DechunkedInputStream {
758+
759+
/**
760+
* Dechunks the input, copies to output, and closes both input and the output safely.
761+
*/
762+
def dechunkAndCopyToOutput(chunked: InputStream, out: OutputStream): Unit = {
763+
val dechunked = new DechunkedInputStream(chunked)
764+
Utils.tryWithSafeFinally {
765+
Utils.copyStream(dechunked, out)
766+
} {
767+
JavaUtils.closeQuietly(out)
768+
JavaUtils.closeQuietly(dechunked)
769+
}
770+
}
771+
}
772+
773+
/**
774+
* Creates a server in the jvm to communicate with python for handling one batch of data, with
775+
* authentication and error handling.
776+
*/
777+
private[spark] abstract class PythonServer[T](
778+
authHelper: SocketAuthHelper,
779+
threadName: String) {
780+
781+
def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
782+
def this(threadName: String) = this(SparkEnv.get, threadName)
783+
784+
val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock =>
785+
promise.complete(Try(handleConnection(sock)))
786+
}
787+
788+
/**
789+
* Handle a connection which has already been authenticated. Any error from this function
790+
* will clean up this connection and the entire server, and get propogated to [[getResult]].
791+
*/
792+
def handleConnection(sock: Socket): T
793+
794+
val promise = Promise[T]()
795+
796+
/**
797+
* Blocks indefinitely for [[handleConnection]] to finish, and returns that result. If
798+
* handleConnection throws an exception, this will throw an exception which includes the original
799+
* exception as a cause.
800+
*/
801+
def getResult(): T = {
802+
getResult(Duration.Inf)
803+
}
804+
805+
def getResult(wait: Duration): T = {
806+
ThreadUtils.awaitResult(promise.future, wait)
807+
}
808+
809+
}
810+
811+
private[spark] object PythonServer {
812+
813+
/**
814+
* Create a socket server and run user function on the socket in a background thread.
815+
*
816+
* The socket server can only accept one connection, or close if no connection
817+
* in 15 seconds.
818+
*
819+
* The thread will terminate after the supplied user function, or if there are any exceptions.
820+
*
821+
* If you need to get a result of the supplied function, create a subclass of [[PythonServer]]
822+
*
823+
* @return The port number of a local socket and the secret for authentication.
824+
*/
825+
def setupOneConnectionServer(
826+
authHelper: SocketAuthHelper,
827+
threadName: String)
828+
(func: Socket => Unit): (Int, String) = {
829+
val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
830+
// Close the socket if no connection in 15 seconds
831+
serverSocket.setSoTimeout(15000)
832+
833+
new Thread(threadName) {
834+
setDaemon(true)
835+
override def run(): Unit = {
836+
var sock: Socket = null
837+
try {
838+
sock = serverSocket.accept()
839+
authHelper.authClient(sock)
840+
func(sock)
841+
} finally {
842+
JavaUtils.closeQuietly(serverSocket)
843+
JavaUtils.closeQuietly(sock)
844+
}
845+
}
846+
}.start()
847+
(serverSocket.getLocalPort, authHelper.secret)
848+
}
849+
}
850+
851+
/**
852+
* Sends decrypted broadcast data to python worker. See [[PythonRunner]] for entire protocol.
853+
*/
854+
private[spark] class EncryptedPythonBroadcastServer(
855+
val env: SparkEnv,
856+
val idsAndFiles: Seq[(Long, String)])
857+
extends PythonServer[Unit]("broadcast-decrypt-server") with Logging {
858+
859+
override def handleConnection(socket: Socket): Unit = {
860+
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()))
861+
var socketIn: InputStream = null
862+
// send the broadcast id, then the decrypted data. We don't need to send the length, the
863+
// the python pickle module just needs a stream.
864+
Utils.tryWithSafeFinally {
865+
(idsAndFiles).foreach { case (id, path) =>
866+
out.writeLong(id)
867+
val in = env.serializerManager.wrapForEncryption(new FileInputStream(path))
868+
Utils.tryWithSafeFinally {
869+
Utils.copyStream(in, out, false)
870+
} {
871+
in.close()
872+
}
873+
}
874+
logTrace("waiting for python to accept broadcast data over socket")
875+
out.flush()
876+
socketIn = socket.getInputStream()
877+
socketIn.read()
878+
logTrace("done serving broadcast data")
879+
} {
880+
JavaUtils.closeQuietly(socketIn)
881+
JavaUtils.closeQuietly(out)
882+
}
883+
}
884+
885+
def waitTillBroadcastDataSent(): Unit = {
886+
getResult()
887+
}
888+
}
889+
890+
/**
891+
* Helper for making RDD[Array[Byte]] from some python data, by reading the data from python
892+
* over a socket. This is used in preference to writing data to a file when encryption is enabled.
893+
*/
894+
private[spark] abstract class PythonRDDServer
895+
extends PythonServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
896+
897+
def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
898+
val in = sock.getInputStream()
899+
val dechunkedInput: InputStream = new DechunkedInputStream(in)
900+
streamToRDD(dechunkedInput)
901+
}
902+
903+
protected def streamToRDD(input: InputStream): RDD[Array[Byte]]
904+
905+
}
906+
907+
private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
908+
extends PythonRDDServer {
909+
910+
override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
911+
PythonRDD.readRDDFromInputStream(sc, input, parallelism)
912+
}
913+
}

0 commit comments

Comments
 (0)