@@ -24,8 +24,10 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
2424
2525import scala .collection .JavaConverters ._
2626import scala .collection .mutable
27+ import scala .concurrent .Promise
28+ import scala .concurrent .duration .Duration
2729import scala .language .existentials
28- import scala .util .control . NonFatal
30+ import scala .util .Try
2931
3032import org .apache .hadoop .conf .Configuration
3133import org .apache .hadoop .io .compress .CompressionCodec
@@ -37,6 +39,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
3739import org .apache .spark .broadcast .Broadcast
3840import org .apache .spark .input .PortableDataStream
3941import org .apache .spark .internal .Logging
42+ import org .apache .spark .network .util .JavaUtils
4043import org .apache .spark .rdd .RDD
4144import org .apache .spark .security .SocketAuthHelper
4245import org .apache .spark .util ._
@@ -168,27 +171,34 @@ private[spark] object PythonRDD extends Logging {
168171
/**
 * Builds an RDD of byte arrays from the length-prefixed records stored in `filename`.
 *
 * Opens the file and delegates to [[readRDDFromInputStream]], which also closes the stream.
 */
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
    JavaRDD[Array[Byte]] = {
  val fileIn = new FileInputStream(filename)
  readRDDFromInputStream(sc.sc, fileIn, parallelism)
}
176+
/**
 * Reads length-prefixed byte records from `in` and parallelizes them into an RDD.
 *
 * Each record is an Int length (as read by `DataInputStream.readInt`) followed by that many
 * bytes. Reading continues until an `EOFException` is raised by either the length read or the
 * payload read. The stream is always closed before returning.
 */
def readRDDFromInputStream(
    sc: SparkContext,
    in: InputStream,
    parallelism: Int): JavaRDD[Array[Byte]] = {
  val dataIn = new DataInputStream(in)
  try {
    val records = new mutable.ArrayBuffer[Array[Byte]]
    try {
      while (true) {
        val record = new Array[Byte](dataIn.readInt())
        dataIn.readFully(record)
        records += record
      }
    } catch {
      // The EOFException is what terminates the otherwise endless read loop.
      case _: EOFException =>
    }
    JavaRDD.fromRDD(sc.parallelize(records, parallelism))
  } finally {
    dataIn.close()
  }
}
189199
190- def readBroadcastFromFile ( sc : JavaSparkContext , path : String ): Broadcast [ PythonBroadcast ] = {
191- sc.broadcast( new PythonBroadcast (path) )
/** Creates a [[PythonBroadcast]] wrapper around the data stored at `path`. */
def setupBroadcast(path: String): PythonBroadcast = new PythonBroadcast(path)
193203
194204 def writeIteratorToStream [T ](iter : Iterator [T ], dataOut : DataOutputStream ) {
@@ -398,34 +408,15 @@ private[spark] object PythonRDD extends Logging {
398408 * data collected from this job, and the secret for authentication.
399409 */
def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
  // The one-connection server handles accept/auth; we only stream the items to the socket.
  val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock =>
    val buffered = new BufferedOutputStream(sock.getOutputStream())
    val dataOut = new DataOutputStream(buffered)
    Utils.tryWithSafeFinally {
      writeIteratorToStream(items, dataOut)
    } {
      dataOut.close()
    }
  }
  Array(port, secret)
}
430421
431422 private def getMergedConf (confAsMap : java.util.HashMap [String , String ],
@@ -643,13 +634,11 @@ private[spark] class PythonAccumulatorV2(
643634 }
644635}
645636
646- /**
647- * A Wrapper for Python Broadcast, which is written into disk by Python. It also will
648- * write the data into disk after deserialization, then Python can read it from disks.
649- */
650637// scalastyle:off no.finalize
651638private [spark] class PythonBroadcast (@ transient var path : String ) extends Serializable
652- with Logging {
639+ with Logging {
640+
641+ private var encryptionServer : PythonServer [Unit ] = null
653642
654643 /**
655644 * Read data from disks, then copy it to `out`
@@ -692,5 +681,233 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
692681 }
693682 super .finalize()
694683 }
684+
/**
 * Starts a one-connection server that receives chunked broadcast data over a socket and
 * writes it, wrapped for encryption, to a fresh temp file in a Spark local dir. `path` is
 * updated to point at that file.
 *
 * @return an array holding the server's port and its authentication secret
 */
def setupEncryptionServer(): Array[Any] = {
  val server = new PythonServer[Unit]("broadcast-encrypt-server") {
    override def handleConnection(sock: Socket): Unit = {
      val sparkEnv = SparkEnv.get
      val socketIn = sock.getInputStream()
      val localDir = new File(Utils.getLocalDir(sparkEnv.conf))
      val tempFile = File.createTempFile("broadcast", "", localDir)
      path = tempFile.getAbsolutePath
      val encryptedOut =
        sparkEnv.serializerManager.wrapForEncryption(new FileOutputStream(path))
      // Closes both streams when the copy finishes or fails.
      DechunkedInputStream.dechunkAndCopyToOutput(socketIn, encryptedOut)
    }
  }
  encryptionServer = server
  Array(server.port, server.secret)
}
699+
/** Blocks until the server started by `setupEncryptionServer` has finished its connection. */
def waitTillDataReceived(): Unit = {
  encryptionServer.getResult()
}
695701}
696702// scalastyle:on no.finalize
703+
/**
 * The inverse of pyspark's ChunkedStream for sending broadcast data.
 * Tested from python tests.
 *
 * The wrapped stream carries a sequence of chunks, each an Int length (as read by
 * `DataInputStream.readInt`) followed by that many bytes, and is terminated by a length of -1.
 * The first chunk length is read eagerly when this stream is constructed.
 */
private[spark] class DechunkedInputStream(wrapped: InputStream) extends InputStream with Logging {
  private val din = new DataInputStream(wrapped)
  // Bytes left in the current chunk; -1 once the end-of-data marker has been read.
  private var remainingInChunk = din.readInt()

  /** Reads a single byte, returning it as a value in 0-255, or -1 at end of data. */
  override def read(): Int = {
    val into = new Array[Byte](1)
    val n = read(into, 0, 1)
    if (n == -1) {
      -1
    } else {
      // if you just cast a byte to an int, then anything > 127 is negative, which is interpreted
      // as an EOF
      into(0) & 0xFF
    }
  }

  /**
   * Fills `dest` with up to `len` bytes of dechunked data, crossing chunk boundaries as needed.
   *
   * @return the number of bytes copied, or -1 if the end-of-data marker was reached before any
   *         byte could be copied
   * @throws EOFException if the wrapped stream ends before the -1 terminator is seen
   */
  override def read(dest: Array[Byte], off: Int, len: Int): Int = {
    if (remainingInChunk == -1) {
      -1
    } else {
      var destSpace = len
      var destPos = off
      while (destSpace > 0 && remainingInChunk != -1) {
        val toCopy = math.min(remainingInChunk, destSpace)
        val read = din.read(dest, destPos, toCopy)
        if (read == -1) {
          // The wrapped stream ended inside a chunk. Treating -1 as a byte count would corrupt
          // the position bookkeeping below, so fail loudly instead.
          throw new EOFException(
            s"Stream ended with $remainingInChunk bytes still expected in the current chunk")
        }
        destPos += read
        destSpace -= read
        remainingInChunk -= read
        if (remainingInChunk == 0) {
          remainingInChunk = din.readInt()
        }
      }
      assert(destSpace == 0 || remainingInChunk == -1)
      val copied = destPos - off
      // If the terminator was hit before anything was copied, report EOF rather than 0 so the
      // single-byte read() never returns a byte that was not actually read.
      if (copied == 0 && len > 0) -1 else copied
    }
  }

  override def close(): Unit = wrapped.close()
}
746+
/**
 * The inverse of pyspark's ChunkedStream for sending data of unknown size.
 *
 * We might be serializing a really large object from python -- we don't want
 * python to buffer the whole thing in memory, nor can it write to a file,
 * so we don't know the length in advance. So python writes it in chunks, each chunk
 * preceded by a length, till we get a "length" of -1 which serves as EOF.
 *
 * Tested from python tests.
 */
private[spark] object DechunkedInputStream {

  /**
   * Dechunks the input, copies to output, and closes both input and the output safely.
   *
   * @param chunked the raw chunked stream
   * @param out destination for the dechunked bytes
   */
  def dechunkAndCopyToOutput(chunked: InputStream, out: OutputStream): Unit = {
    // Constructing the dechunked stream already reads the first chunk length and may throw, so
    // it must happen inside the guarded block; otherwise `out` would leak on that failure.
    var dechunked: DechunkedInputStream = null
    Utils.tryWithSafeFinally {
      dechunked = new DechunkedInputStream(chunked)
      Utils.copyStream(dechunked, out)
    } {
      JavaUtils.closeQuietly(out)
      // Fall back to the raw input if the wrapper was never created.
      JavaUtils.closeQuietly(if (dechunked != null) dechunked else chunked)
    }
  }
}
772+
/**
 * Creates a server in the jvm to communicate with python for handling one batch of data, with
 * authentication and error handling.
 *
 * The server thread is started as part of constructing an instance; `port` and `secret` are
 * what the client needs in order to connect and authenticate.
 */
private[spark] abstract class PythonServer[T](
    authHelper: SocketAuthHelper,
    threadName: String) {

  def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
  def this(threadName: String) = this(SparkEnv.get, threadName)

  // Must be initialized BEFORE the server thread is started below. Scala initializes vals in
  // declaration order, and the connection callback dereferences this field from another thread.
  val promise = Promise[T]()

  val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock =>
    promise.complete(Try(handleConnection(sock)))
  }

  /**
   * Handle a connection which has already been authenticated. Any error from this function
   * will clean up this connection and the entire server, and get propagated to [[getResult]].
   */
  def handleConnection(sock: Socket): T

  /**
   * Blocks indefinitely for [[handleConnection]] to finish, and returns that result. If
   * handleConnection throws an exception, this will throw an exception which includes the original
   * exception as a cause.
   *
   * NOTE(review): the promise is only completed once a connection has been accepted and
   * authenticated; if that never happens this call does not return.
   */
  def getResult(): T = {
    getResult(Duration.Inf)
  }

  /** Same as the no-arg variant, but waits at most `wait` for the result. */
  def getResult(wait: Duration): T = {
    ThreadUtils.awaitResult(promise.future, wait)
  }

}
810+
private[spark] object PythonServer {

  /**
   * Opens a server socket on the loopback address (ephemeral port, backlog of 1) and hands the
   * first authenticated connection to `func` on a daemon thread named `threadName`.
   *
   * The accept call gives up after 15 seconds without a connection. Whether `func` returns
   * normally or anything throws, both the server socket and the accepted socket are closed.
   *
   * If you need to get a result of the supplied function, create a subclass of [[PythonServer]]
   *
   * @return The port number of a local socket and the secret for authentication.
   */
  def setupOneConnectionServer(
      authHelper: SocketAuthHelper,
      threadName: String)
      (func: Socket => Unit): (Int, String) = {
    val loopback = InetAddress.getByAddress(Array(127, 0, 0, 1))
    val serverSocket = new ServerSocket(0, 1, loopback)
    // Close the socket if no connection in 15 seconds
    serverSocket.setSoTimeout(15000)

    val connectionHandler = new Runnable {
      override def run(): Unit = {
        var accepted: Socket = null
        try {
          accepted = serverSocket.accept()
          authHelper.authClient(accepted)
          func(accepted)
        } finally {
          JavaUtils.closeQuietly(serverSocket)
          JavaUtils.closeQuietly(accepted)
        }
      }
    }
    val serverThread = new Thread(connectionHandler, threadName)
    serverThread.setDaemon(true)
    serverThread.start()

    (serverSocket.getLocalPort, authHelper.secret)
  }
}
850+
/**
 * Sends decrypted broadcast data to python worker. See [[PythonRunner]] for entire protocol.
 *
 * @param env supplies the serializer manager used to unwrap the encrypted files
 * @param idsAndFiles pairs of (broadcast id, path of the encrypted file holding its data)
 */
private[spark] class EncryptedPythonBroadcastServer(
    val env: SparkEnv,
    val idsAndFiles: Seq[(Long, String)])
    extends PythonServer[Unit]("broadcast-decrypt-server") with Logging {

  override def handleConnection(socket: Socket): Unit = {
    val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()))
    var socketIn: InputStream = null
    Utils.tryWithSafeFinally {
      // For every broadcast: write its id, then the decrypted payload. No length prefix is
      // sent, since the python pickle module just needs a stream.
      idsAndFiles.foreach { idAndFile =>
        val (id, path) = idAndFile
        out.writeLong(id)
        val decrypted = env.serializerManager.wrapForEncryption(new FileInputStream(path))
        Utils.tryWithSafeFinally {
          // `false`: the streams are closed explicitly by the surrounding finally blocks.
          Utils.copyStream(decrypted, out, false)
        } {
          decrypted.close()
        }
      }
      logTrace("waiting for python to accept broadcast data over socket")
      out.flush()
      socketIn = socket.getInputStream()
      // Blocks until the peer sends a byte or closes its side; the value itself is ignored.
      socketIn.read()
      logTrace("done serving broadcast data")
    } {
      JavaUtils.closeQuietly(socketIn)
      JavaUtils.closeQuietly(out)
    }
  }

  /** Blocks until [[handleConnection]] has completed. */
  def waitTillBroadcastDataSent(): Unit = {
    getResult()
  }
}
889+
/**
 * Helper for making RDD[Array[Byte]] from some python data, by reading the data from python
 * over a socket. This is used in preference to writing data to a file when encryption is enabled.
 *
 * Subclasses decide how the dechunked byte stream becomes an RDD via `streamToRDD`.
 */
private[spark] abstract class PythonRDDServer
    extends PythonServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {

  def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
    // Python sends the data chunked; strip the chunk framing before building the RDD.
    val dechunked: InputStream = new DechunkedInputStream(sock.getInputStream())
    JavaRDD.fromRDD(streamToRDD(dechunked))
  }

  /** Turns the dechunked input into an RDD of serialized records. */
  protected def streamToRDD(input: InputStream): RDD[Array[Byte]]

}
906+
/**
 * A [[PythonRDDServer]] that parallelizes the records received from python into `parallelism`
 * slices using the given SparkContext.
 */
private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
    extends PythonRDDServer {

  override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
    val javaRdd = PythonRDD.readRDDFromInputStream(sc, input, parallelism)
    javaRdd.rdd
  }
}
0 commit comments