1717
1818package org .apache .spark .api .r
1919
20- import java .io .{DataOutputStream , File , FileOutputStream , IOException }
21- import java .net .{InetAddress , InetSocketAddress , ServerSocket }
20+ import java .io .{DataInputStream , DataOutputStream , File , FileOutputStream , IOException }
21+ import java .net .{InetAddress , InetSocketAddress , ServerSocket , Socket }
2222import java .util .concurrent .TimeUnit
2323
2424import io .netty .bootstrap .ServerBootstrap
@@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler
3232
3333import org .apache .spark .SparkConf
3434import org .apache .spark .internal .Logging
35+ import org .apache .spark .network .util .JavaUtils
36+ import org .apache .spark .util .Utils
3537
3638/**
3739 * Netty-based backend server that is used to communicate between R and Java.
@@ -45,14 +47,15 @@ private[spark] class RBackend {
4547 /** Tracks JVM objects returned to R for this RBackend instance. */
4648 private [r] val jvmObjectTracker = new JVMObjectTracker
4749
48- def init (): Int = {
50+ def init (): ( Int , RAuthHelper ) = {
4951 val conf = new SparkConf ()
5052 val backendConnectionTimeout = conf.getInt(
5153 " spark.r.backendConnectionTimeout" , SparkRDefaults .DEFAULT_CONNECTION_TIMEOUT )
5254 bossGroup = new NioEventLoopGroup (
5355 conf.getInt(" spark.r.numRBackendThreads" , SparkRDefaults .DEFAULT_NUM_RBACKEND_THREADS ))
5456 val workerGroup = bossGroup
5557 val handler = new RBackendHandler (this )
58+ val authHelper = new RAuthHelper (conf)
5659
5760 bootstrap = new ServerBootstrap ()
5861 .group(bossGroup, workerGroup)
@@ -71,13 +74,16 @@ private[spark] class RBackend {
7174 new LengthFieldBasedFrameDecoder (Integer .MAX_VALUE , 0 , 4 , 0 , 4 ))
7275 .addLast(" decoder" , new ByteArrayDecoder ())
7376 .addLast(" readTimeoutHandler" , new ReadTimeoutHandler (backendConnectionTimeout))
77+ .addLast(new RBackendAuthHandler (authHelper.secret))
7478 .addLast(" handler" , handler)
7579 }
7680 })
7781
7882 channelFuture = bootstrap.bind(new InetSocketAddress (" localhost" , 0 ))
7983 channelFuture.syncUninterruptibly()
80- channelFuture.channel().localAddress().asInstanceOf [InetSocketAddress ].getPort()
84+
85+ val port = channelFuture.channel().localAddress().asInstanceOf [InetSocketAddress ].getPort()
86+ (port, authHelper)
8187 }
8288
8389 def run (): Unit = {
@@ -116,7 +122,7 @@ private[spark] object RBackend extends Logging {
116122 val sparkRBackend = new RBackend ()
117123 try {
118124 // bind to random port
119- val boundPort = sparkRBackend.init()
125+ val ( boundPort, authHelper) = sparkRBackend.init()
120126 val serverSocket = new ServerSocket (0 , 1 , InetAddress .getByName(" localhost" ))
121127 val listenPort = serverSocket.getLocalPort()
122128 // Connection timeout is set by socket client. To make it configurable we will pass the
@@ -133,6 +139,7 @@ private[spark] object RBackend extends Logging {
133139 dos.writeInt(listenPort)
134140 SerDe .writeString(dos, RUtils .rPackages.getOrElse(" " ))
135141 dos.writeInt(backendConnectionTimeout)
142+ SerDe .writeString(dos, authHelper.secret)
136143 dos.close()
137144 f.renameTo(new File (path))
138145
@@ -144,12 +151,35 @@ private[spark] object RBackend extends Logging {
144151 val buf = new Array [Byte ](1024 )
145152 // shutdown JVM if R does not connect back in 10 seconds
146153 serverSocket.setSoTimeout(10000 )
154+
155+ // Wait for the R process to connect back, ignoring any failed auth attempts. Allow
156+ // a max number of connection attempts to avoid looping forever.
147157 try {
148- val inSocket = serverSocket.accept()
158+ var remainingAttempts = 10
159+ var inSocket : Socket = null
160+ while (inSocket == null ) {
161+ inSocket = serverSocket.accept()
162+ try {
163+ authHelper.authClient(inSocket)
164+ } catch {
165+ case e : Exception =>
166+ remainingAttempts -= 1
167+ if (remainingAttempts == 0 ) {
168+ val msg = " Too many failed authentication attempts."
169+ logError(msg)
170+ throw new IllegalStateException (msg)
171+ }
172+ logInfo(" Client connection failed authentication." )
173+ inSocket = null
174+ }
175+ }
176+
149177 serverSocket.close()
178+
150179 // wait for the end of socket, closed if R process die
151180 inSocket.getInputStream().read(buf)
152181 } finally {
182+ serverSocket.close()
153183 sparkRBackend.close()
154184 System .exit(0 )
155185 }
@@ -165,4 +195,5 @@ private[spark] object RBackend extends Logging {
165195 }
166196 System .exit(0 )
167197 }
198+
168199}
0 commit comments