91 changes: 51 additions & 40 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -38,7 +38,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.BUFFER_SIZE
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
import org.apache.spark.util._


@@ -137,8 +137,9 @@ private[spark] object PythonRDD extends Logging {
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
* @return 3-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, the secret for authentication, and a server object
* that can be used to sync the JVM serving thread in Python.
*/
def runJob(
sc: SparkContext,
@@ -156,8 +157,9 @@
/**
* A helper function to collect an RDD as an iterator, then serve it via socket.
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
* @return 3-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, the secret for authentication, and a server object
* that can be used to sync the JVM serving thread in Python.
*/
def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
@@ -171,55 +173,59 @@
* and -1 meaning an error occurred during collection. This function is used by
* pyspark.rdd._local_iterator_from_socket().
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from these jobs, and the secret for authentication.
* @return 3-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, the secret for authentication, and a server object
* that can be used to sync the JVM serving thread in Python.
*/
def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
val (port, secret) = SocketAuthServer.setupOneConnectionServer(
authHelper, "serve toLocalIterator") { s =>
val out = new DataOutputStream(s.getOutputStream)
val in = new DataInputStream(s.getInputStream)
Utils.tryWithSafeFinally {

val handleFunc = (sock: Socket) => {
val out = new DataOutputStream(sock.getOutputStream)
val in = new DataInputStream(sock.getInputStream)
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
}

// Read request for data and send next partition if nonzero
// Write data until iteration is complete, client stops iteration, or error occurs
var complete = false
while (!complete && in.readInt() != 0) {
if (collectPartitionIter.hasNext) {
try {
// Attempt to collect the next partition
val partitionArray = collectPartitionIter.next()

// Send response there is a partition to read
out.writeInt(1)

// Write the next object and signal end of data for this iteration
writeIteratorToStream(partitionArray.toIterator, out)
out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
out.flush()
} catch {
case e: SparkException =>
// Send response that an error occurred followed by error message
out.writeInt(-1)
writeUTF(e.getMessage, out)
complete = true
}
while (!complete) {

// Read request for data
if (in.readInt() == 0) {

// Client requested to stop iteration
complete = true
} else if (collectPartitionIter.hasNext) {

// Client requested more data, attempt to collect the next partition
val partitionArray = collectPartitionIter.next()

// Send response there is a partition to read
out.writeInt(1)

// Write the next object and signal end of data for this iteration
writeIteratorToStream(partitionArray.toIterator, out)
out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
out.flush()
} else {

// Send response there are no more partitions to read and close
out.writeInt(0)
complete = true
}
}
} {
})(catchBlock = {
// Send response that an error occurred, original exception is re-thrown
out.writeInt(-1)
}, finallyBlock = {
out.close()
in.close()
}
})
}
Array(port, secret)

val server = new SocketFuncServer(authHelper, "serve toLocalIterator", handleFunc)
Array(server.port, server.secret, server)
}
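
The handler above defines a simple request/response protocol between the Python client and the JVM serving thread: the client writes a nonzero int to request the next partition; the server answers 1 followed by the serialized partition and an end-of-data marker, 0 once all partitions are exhausted, or -1 if collection failed (the exception itself now propagates through the server object's getResult). The following is a minimal, hypothetical client-side sketch of that loop -- not part of this change -- and it assumes writeIteratorToStream's usual framing of length-prefixed byte blocks terminated by a negative sentinel:

```scala
import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket

object LocalIteratorClientSketch {
  // Drive the partition-serving protocol over an already-authenticated socket.
  def consume(sock: Socket): Unit = {
    val in = new DataInputStream(sock.getInputStream)
    val out = new DataOutputStream(sock.getOutputStream)
    var done = false
    while (!done) {
      out.writeInt(1) // request the next partition; a client stopping early writes 0 instead
      out.flush()
      in.readInt() match {
        case 1 =>
          // A partition follows: drain length-prefixed blocks until the negative end marker.
          var len = in.readInt()
          while (len >= 0) {
            val buf = new Array[Byte](len)
            in.readFully(buf) // a real client would hand these bytes to its deserializer
            len = in.readInt()
          }
        case 0 =>
          done = true // no more partitions to read
        case _ =>
          done = true // error signalled; the JVM exception is surfaced via the server object
      }
    }
  }
}
```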

def readRDDFromFile(
@@ -443,8 +449,9 @@ private[spark] object PythonRDD extends Logging {
*
* The thread will terminate after all the data are sent or any exceptions happen.
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, and the secret for authentication.
* @return 3-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, the secret for authentication, and a server object
* that can be used to sync the JVM serving thread in Python.
*/
def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
serveToStream(threadName) { out =>
@@ -464,10 +471,14 @@
*
* The thread will terminate after the block of code is executed or any
* exceptions happen.
*
* @return 3-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from this job, the secret for authentication, and a server object
* that can be used to sync the JVM serving thread in Python.
*/
private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
SocketAuthHelper.serveToStream(threadName, authHelper)(writeFunc)
SocketAuthServer.serveToStream(threadName, authHelper)(writeFunc)
}

private def getMergedConf(confAsMap: java.util.HashMap[String, String],
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -29,7 +29,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
import org.apache.spark.security.SocketAuthServer

private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
parent: RDD[T],
Expand Down Expand Up @@ -166,7 +166,7 @@ private[spark] object RRDD {

private[spark] def serveToStream(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
SocketAuthHelper.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
SocketAuthServer.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
}
}

core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -17,7 +17,7 @@

package org.apache.spark.security

import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, OutputStream}
import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
import java.nio.charset.StandardCharsets.UTF_8

@@ -113,21 +113,4 @@ private[spark] class SocketAuthHelper(conf: SparkConf) {
dout.write(bytes, 0, bytes.length)
dout.flush()
}

}

private[spark] object SocketAuthHelper {
def serveToStream(
threadName: String,
authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { s =>
val out = new BufferedOutputStream(s.getOutputStream())
Utils.tryWithSafeFinally {
writeFunc(out)
} {
out.close()
}
}
Array(port, secret)
}
}
core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -17,6 +17,7 @@

package org.apache.spark.security

import java.io.{BufferedOutputStream, OutputStream}
import java.net.{InetAddress, ServerSocket, Socket}

import scala.concurrent.Promise
@@ -25,7 +26,7 @@ import scala.util.Try

import org.apache.spark.SparkEnv
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.{ThreadUtils, Utils}


/**
@@ -41,10 +42,30 @@ private[spark] abstract class SocketAuthServer[T](

private val promise = Promise[T]()

val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { sock =>
promise.complete(Try(handleConnection(sock)))
private def startServer(): (Int, String) = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)

new Thread(threadName) {
setDaemon(true)
override def run(): Unit = {
var sock: Socket = null
try {
sock = serverSocket.accept()
authHelper.authClient(sock)
promise.complete(Try(handleConnection(sock)))
} finally {
JavaUtils.closeQuietly(serverSocket)
JavaUtils.closeQuietly(sock)
}
}
}.start()
(serverSocket.getLocalPort, authHelper.secret)
}

val (port, secret) = startServer()

/**
* Handle a connection which has already been authenticated. Any error from this function
* will clean up this connection and the entire server, and get propagated to [[getResult]].
@@ -66,42 +87,45 @@

}

/**
* Create a socket server class and run user function on the socket in a background thread.
* This is the same as calling SocketAuthServer.setupOneConnectionServer except it creates
Member: Seems we don't have setupOneConnectionServer anymore.

Member Author: Oops, good catch.

* a server object that can then be synced from Python.
*/
private[spark] class SocketFuncServer(
authHelper: SocketAuthHelper,
threadName: String,
func: Socket => Unit) extends SocketAuthServer[Unit](authHelper, threadName) {
Member Author: I don't think we need SocketAuthServer.setupOneConnectionServer if we have this also, so it could be cleaned up.

Member Author: I removed SocketAuthServer.setupOneConnectionServer and replaced its usage with SocketFuncServer.


override def handleConnection(sock: Socket): Unit = {
func(sock)
}
}
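
SocketFuncServer is a thin adapter: the caller hands it a Socket => Unit handler and gets back the port, the secret, and the server itself, so Python can both connect and later join the serving thread. A hypothetical Spark-internal caller might look like the sketch below; the object and method names are made up, and these classes are private[spark], so real usage has to live inside Spark:

```scala
import java.io.DataOutputStream
import java.net.Socket

import org.apache.spark.SparkEnv
import org.apache.spark.security.{SocketAuthHelper, SocketFuncServer}

object SocketFuncServerSketch {
  // Serve a single UTF string to whichever client connects and authenticates.
  def serveHello(): Array[Any] = {
    val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
    val handle = (sock: Socket) => {
      val out = new DataOutputStream(sock.getOutputStream)
      out.writeUTF("hello")
      out.close()
    }
    val server = new SocketFuncServer(authHelper, "serve hello sketch", handle)
    // Callers can later block on server.getResult() to re-raise any handler exception.
    Array(server.port, server.secret, server)
  }
}
```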

private[spark] object SocketAuthServer {

/**
* Create a socket server and run user function on the socket in a background thread.
*
* The socket server can only accept one connection, or close if no connection
* in 15 seconds.
Contributor: Please save this comment -- I guess move it to the class SocketAuthServer. In particular, it's helpful to note that this only accepts one connection; it's not a long-lived thing which is reused.

Member Author: Yeah, that is a useful comment; I didn't intend to take this out. I'll put it back in.

* Convenience function to create a socket server and run a user function in a background
* thread to write to an output stream.
*
* The thread will terminate after the supplied user function returns, or if any exceptions occur.
*
* If you need to get a result of the supplied function, create a subclass of [[SocketAuthServer]]
*
* @return The port number of a local socket and the secret for authentication.
* @param threadName Name for the background serving thread.
* @param authHelper SocketAuthHelper for authentication
* @param writeFunc User function to write to a given OutputStream
* @return
*/
def setupOneConnectionServer(
authHelper: SocketAuthHelper,
threadName: String)
(func: Socket => Unit): (Int, String) = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)

new Thread(threadName) {
setDaemon(true)
override def run(): Unit = {
var sock: Socket = null
try {
sock = serverSocket.accept()
authHelper.authClient(sock)
func(sock)
} finally {
JavaUtils.closeQuietly(serverSocket)
JavaUtils.closeQuietly(sock)
}
def serveToStream(
Member Author: Moved this from SocketAuthHelper because it seemed more fitting to be here.

threadName: String,
authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
val handleFunc = (sock: Socket) => {
val out = new BufferedOutputStream(sock.getOutputStream())
Utils.tryWithSafeFinally {
writeFunc(out)
} {
out.close()
}
}.start()
(serverSocket.getLocalPort, authHelper.secret)
}

val server = new SocketFuncServer(authHelper, threadName, handleFunc)
Array(server.port, server.secret, server)
}
}
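
For the common "just write bytes" case, serveToStream wraps all of the above: the caller supplies only an OutputStream => Unit function and gets back the same [port, secret, server] triple to hand to Python. A usage sketch under the same assumptions (Spark-internal code, illustrative object and thread names):

```scala
import java.nio.charset.StandardCharsets

import org.apache.spark.SparkEnv
import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}

object ServeToStreamSketch {
  def serveGreeting(): Array[Any] = {
    val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
    // The write function runs on a daemon thread once a client connects and authenticates;
    // the returned server object lets the caller wait for it and surface any exception.
    SocketAuthServer.serveToStream("serve greeting sketch", authHelper) { out =>
      out.write("hello from the JVM".getBytes(StandardCharsets.UTF_8))
    }
  }
}
```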
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1389,7 +1389,9 @@ private[spark] object Utils extends Logging {
originalThrowable = cause
try {
logError("Aborting task", originalThrowable)
TaskContext.get().markTaskFailed(originalThrowable)
if (TaskContext.get() != null) {
Member Author: Using this utility here https://github.com/apache/spark/pull/24834/files#diff-0a67bc4d171abe4df8eb305b0f4123a2R184, where the task fails and completes before hitting the catchBlock, so TaskContext.get() returns null.

TaskContext.get().markTaskFailed(originalThrowable)
}
catchBlock
} catch {
case t: Throwable =>
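
Per the author's note above, tryWithSafeFinallyAndFailureCallbacks can now be reached when TaskContext.get() returns null, hence the guard before markTaskFailed. As a reminder of the call shape this utility expects -- a sketch only, assuming the Spark-internal signature with block, catchBlock, and finallyBlock parameters, and a made-up helper name:

```scala
import java.io.DataOutputStream

import org.apache.spark.util.Utils

object TryWithCallbacksSketch {
  // block runs first; catchBlock runs only if block throws (the original exception is
  // re-thrown afterwards); finallyBlock always runs.
  def writeOrSignalError(out: DataOutputStream, payload: Array[Byte]): Unit = {
    Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
      out.writeInt(payload.length)
      out.write(payload)
      out.flush()
    })(catchBlock = {
      out.writeInt(-1) // tell the client something went wrong; the exception still propagates
    }, finallyBlock = {
      out.close()
    })
  }
}
```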
26 changes: 20 additions & 6 deletions python/pyspark/rdd.py
@@ -140,14 +140,29 @@ def _parse_memory(s):


def _create_local_socket(sock_info):
(sockfile, sock) = local_connect_and_auth(*sock_info)
"""
Create a local socket that can be used to load deserialized data from the JVM

:param sock_info: Tuple containing port number and authentication secret for a local socket.
:return: sockfile file descriptor of the local socket
"""
port = sock_info[0]
auth_secret = sock_info[1]
sockfile, sock = local_connect_and_auth(port, auth_secret)
# The RDD materialization time is unpredictable, if we set a timeout for socket reading
# operation, it will very possibly fail. See SPARK-18281.
sock.settimeout(None)
return sockfile


def _load_from_socket(sock_info, serializer):
Member: @BryanCutler, what does sock_info expect to be? Seems it can be both a 2-tuple and a 3-tuple (with server).

Member Author: Ugh, yeah, I'm not too happy with this. Java returns a 3-tuple with (port, auth_secret, server) and most places only use the first 2, such as _load_from_socket. It gets a little confusing, so I thought it might be better to expand the values returned by Java for serveToStream etc., but it ended up with a lot of changes where the third value is ignored, like this:

port, auth_secret, _ = ...

and I don't think it really made things clearer. I'll try to think of something better and maybe do a follow-up.

"""
Connect to a local socket described by sock_info and use the given serializer to yield data

:param sock_info: Tuple containing port number and authentication secret for a local socket.
:param serializer: The PySpark serializer to use
:return: result of Serializer.load_stream, usually a generator that yields deserialized data
"""
sockfile = _create_local_socket(sock_info)
# The socket will be automatically closed when garbage-collected.
return serializer.load_stream(sockfile)
@@ -159,7 +174,8 @@ class PyLocalIterable(object):
""" Create a synchronous local iterable over a socket """

def __init__(self, _sock_info, _serializer):
self._sockfile = _create_local_socket(_sock_info)
port, auth_secret, self.jserver_obj = _sock_info
Contributor: Just a general comment, and it might not make the most sense to address in this particular PR -- I'd find it really helpful if the Python code dealing with Java objects annotated (somehow) the Java types. It's hard for me to figure out whether jserver_obj is a ServerSocket, a SocketAuthServer, a Py4JJavaServer, etc.

Member Author: Sure, I can rename it to something more fitting, and I agree it should be clear from the name what the variable is.

self._sockfile = _create_local_socket((port, auth_secret))
self._serializer = _serializer
self._read_iter = iter([]) # Initialize as empty iterator
self._read_status = 1
@@ -179,11 +195,9 @@ def __iter__(self):
for item in self._read_iter:
yield item

# An error occurred, read error message and raise it
# An error occurred, join serving thread and raise any exceptions from the JVM
elif self._read_status == -1:
error_msg = UTF8Deserializer().loads(self._sockfile)
raise RuntimeError("An error occurred while reading the next element from "
"toLocalIterator: {}".format(error_msg))
self.jserver_obj.getResult()

def __del__(self):
# If local iterator is not fully consumed,