11package edu .berkeley .cs .amplab .sparkr
22
33import java .io ._
4+ import java .net .{ServerSocket }
45import java .util .{Map => JMap }
56
67import scala .collection .JavaConversions ._
78import scala .io .Source
89import scala .reflect .ClassTag
10+ import scala .util .Try
911
1012import org .apache .spark .{SparkEnv , Partition , SparkException , TaskContext , SparkConf }
1113import org .apache .spark .api .java .{JavaSparkContext , JavaRDD , JavaPairRDD }
1214import org .apache .spark .broadcast .Broadcast
1315import org .apache .spark .rdd .RDD
1416
17+
1518private abstract class BaseRRDD [T : ClassTag , U : ClassTag ](
1619 parent : RDD [T ],
1720 numPartitions : Int ,
@@ -27,21 +30,35 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
2730
2831 override def compute (split : Partition , context : TaskContext ): Iterator [U ] = {
2932
33+ // The parent may be also an RRDD, so we should launch it first.
3034 val parentIterator = firstParent[T ].iterator(split, context)
3135
32- val pb = rWorkerProcessBuilder()
36+ // we expect two connections
37+ val serverSocket = new ServerSocket (0 , 2 )
38+ val listenPort = serverSocket.getLocalPort()
39+
40+ val pb = rWorkerProcessBuilder(listenPort)
41+ pb.redirectErrorStream() // redirect stderr into stdout
3342 val proc = pb.start()
43+ val errThread = startStdoutThread(proc)
44+
45+ // We use two sockets to separate input and output, then it's easy to manage
46+ // the lifecycle of them to avoid deadlock.
47+ // TODO: optimize it to use one socket
3448
35- val errThread = startStderrThread(proc)
49+ // the socket used to send out the input of task
50+ serverSocket.setSoTimeout(10000 )
51+ val inSocket = serverSocket.accept()
52+ startStdinThread(inSocket.getOutputStream(), parentIterator, split.index)
3653
37- val tempFile = startStdinThread(proc, parentIterator, split.index)
54+ // the socket used to receive the output of task
55+ val outSocket = serverSocket.accept()
56+ val inputStream = new BufferedInputStream (outSocket.getInputStream)
57+ val dataStream = openDataStream(inputStream)
3858
39- // Return an iterator that read lines from the process's stdout
40- val inputStream = new BufferedReader (new InputStreamReader (proc.getInputStream))
59+ serverSocket.close()
4160
4261 try {
43- val stdOutFileName = inputStream.readLine().trim()
44- val dataStream = openDataStream(stdOutFileName)
4562
4663 return new Iterator [U ] {
4764 def next (): U = {
@@ -57,9 +74,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
5774 def hasNext (): Boolean = {
5875 val hasMore = (_nextObj != null )
5976 if (! hasMore) {
60- // Delete the temporary file we created as we are done reading it
6177 dataStream.close()
62- tempFile.delete()
6378 }
6479 hasMore
6580 }
@@ -73,7 +88,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
7388 /**
7489 * ProcessBuilder used to launch worker R processes.
7590 */
76- private def rWorkerProcessBuilder () = {
91+ private def rWorkerProcessBuilder (port : Int ) = {
7792 val rCommand = " Rscript"
7893 val rOptions = " --vanilla"
7994 val rExecScript = rLibDir + " /SparkR/worker/worker.R"
@@ -82,47 +97,42 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
8297 // This is set by R CMD check as startup.Rs
8398 // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
8499 // and confuses worker script which tries to load a non-existent file
85- pb.environment().put(" R_TESTS" , " " );
100+ pb.environment().put(" R_TESTS" , " " )
101+ pb.environment().put(" SPARKR_WORKER_PORT" , port.toString)
86102 pb
87103 }
88104
89105 /**
90106 * Start a thread to print the process's stderr to ours
91107 */
92- private def startStderrThread (proc : Process ): BufferedStreamThread = {
93- val ERR_BUFFER_SIZE = 100
94- val errThread = new BufferedStreamThread (proc.getErrorStream , " stderr reader for R" ,
95- ERR_BUFFER_SIZE )
96- errThread .start()
97- errThread
108+ private def startStdoutThread (proc : Process ): BufferedStreamThread = {
109+ val BUFFER_SIZE = 100
110+ val thread = new BufferedStreamThread (proc.getInputStream , " stdout reader for R" , BUFFER_SIZE )
111+ thread.setDaemon( true )
112+ thread .start()
113+ thread
98114 }
99115
100116 /**
101117 * Start a thread to write RDD data to the R process.
102118 */
103119 private def startStdinThread [T ](
104- proc : Process ,
120+ output : OutputStream ,
105121 iter : Iterator [T ],
106- splitIndex : Int ) : File = {
122+ splitIndex : Int ) = {
107123
108124 val env = SparkEnv .get
109- val conf = env.conf
110- val tempDir = RRDD .getLocalDir(conf)
111- val tempFile = File .createTempFile(" rSpark" , " out" , new File (tempDir))
112- val tempFileIn = File .createTempFile(" rSpark" , " in" , new File (tempDir))
113-
114- val tempFileName = tempFile.getAbsolutePath()
115125 val bufferSize = System .getProperty(" spark.buffer.size" , " 65536" ).toInt
126+ val stream = new BufferedOutputStream (output, bufferSize)
116127
117- // Start a thread to feed the process input from our parent's iterator
118- new Thread (" stdin writer for R" ) {
128+ new Thread (" writer for R" ) {
119129 override def run () {
120130 try {
121131 SparkEnv .set(env)
122- val stream = new BufferedOutputStream (new FileOutputStream (tempFileIn), bufferSize)
123132 val printOut = new PrintStream (stream)
124- val dataOut = new DataOutputStream (stream )
133+ printOut.println(rLibDir )
125134
135+ val dataOut = new DataOutputStream (stream)
126136 dataOut.writeInt(splitIndex)
127137
128138 dataOut.writeInt(func.length)
@@ -166,35 +176,21 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
166176 printOut.println(elem)
167177 }
168178 }
169-
170- printOut.flush()
171- dataOut.flush()
172179 stream.flush()
173- stream.close()
174-
175- // NOTE: We need to write out the temp file before writing out the
176- // file name to stdin. Otherwise the R process could read partial state
177- val streamStd = new BufferedOutputStream (proc.getOutputStream, bufferSize)
178- val printOutStd = new PrintStream (streamStd)
179- printOutStd.println(tempFileName)
180- printOutStd.println(rLibDir)
181- printOutStd.println(tempFileIn.getAbsolutePath())
182- printOutStd.flush()
183-
184- streamStd.close()
185180 } catch {
186181 // TODO: We should propogate this error to the task thread
187182 case e : Exception =>
188183 System .err.println(" R Writer thread got an exception " + e)
189184 e.printStackTrace()
185+ } finally {
186+ Try (output.close())
190187 }
191188 }
192189 }.start()
193-
194- tempFile
195190 }
196191
197- protected def openDataStream (stdOutFileName : String ): Closeable
192+ protected def openDataStream (input : InputStream ): Closeable
193+
198194 protected def read (): U
199195}
200196
@@ -217,8 +213,8 @@ private class PairwiseRRDD[T: ClassTag](
217213
218214 private var dataStream : DataInputStream = _
219215
220- override protected def openDataStream (stdOutFileName : String ) = {
221- dataStream = new DataInputStream (new FileInputStream (stdOutFileName) )
216+ override protected def openDataStream (input : InputStream ) = {
217+ dataStream = new DataInputStream (input )
222218 dataStream
223219 }
224220
@@ -261,9 +257,9 @@ private class RRDD[T: ClassTag](
261257 broadcastVars.map(x => x.asInstanceOf [Broadcast [Object ]])) {
262258
263259 private var dataStream : DataInputStream = _
264-
265- override protected def openDataStream (stdOutFileName : String ) = {
266- dataStream = new DataInputStream (new FileInputStream (stdOutFileName) )
260+
261+ override protected def openDataStream (input : InputStream ) = {
262+ dataStream = new DataInputStream (input )
267263 dataStream
268264 }
269265
@@ -305,9 +301,8 @@ private class StringRRDD[T: ClassTag](
305301
306302 private var dataStream : BufferedReader = _
307303
308- override protected def openDataStream (stdOutFileName : String ) = {
309- dataStream = new BufferedReader (
310- new InputStreamReader (new FileInputStream (stdOutFileName)))
304+ override protected def openDataStream (input : InputStream ) = {
305+ dataStream = new BufferedReader (new InputStreamReader (input))
311306 dataStream
312307 }
313308
@@ -334,6 +329,7 @@ private class BufferedStreamThread(
334329 for (line <- Source .fromInputStream(in).getLines) {
335330 lines(lineIdx) = line
336331 lineIdx = (lineIdx + 1 ) % errBufferSize
332+ // TODO: user logger
337333 System .err.println(line)
338334 }
339335 }
0 commit comments