diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index fe25c3aac81b..5b80e149b38a 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -38,7 +38,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
+import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer, SocketFuncServer}
 import org.apache.spark.util._
@@ -137,8 +137,9 @@ private[spark] object PythonRDD extends Logging {
    * (effectively a collect()), but allows you to run on a certain subset of partitions,
    * or to enable local execution.
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
-   *         data collected from this job, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   def runJob(
       sc: SparkContext,
@@ -156,8 +157,9 @@ private[spark] object PythonRDD extends Logging {
   /**
    * A helper function to collect an RDD as an iterator, then serve it via socket.
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
-   *         data collected from this job, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
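[Reviewer note] The @return changes above are the visible half of the contract change: the third array element is the server object itself, so the Python caller can later join the JVM serving thread and re-raise whatever it failed with (via getResult(), see SocketAuthServer below). A minimal, self-contained sketch of that join-and-rethrow idea in plain Scala; MiniServer and its Duration.Inf wait are illustrative stand-ins, not the patch's API:

```scala
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Promise}
import scala.util.Try

// Toy stand-in for SocketAuthServer/SocketFuncServer: a background "serving"
// thread completes a Promise, and getResult() lets the caller join that thread
// and re-throw whatever it failed with.
class MiniServer[T](threadName: String)(work: () => T) {
  private val promise = Promise[T]()

  new Thread(threadName) {
    setDaemon(true)
    override def run(): Unit = promise.complete(Try(work()))
  }.start()

  // Blocks until the serving thread finishes; re-throws its exception on failure.
  def getResult(): T = Await.result(promise.future, Duration.Inf)
}

object MiniServerDemo {
  def main(args: Array[String]): Unit = {
    val ok = new MiniServer("serve ok")(() => 42)
    println(ok.getResult())                                  // 42

    val boom = new MiniServer[Int]("serve boom")(() => sys.error("collection failed"))
    try boom.getResult()
    catch { case e: RuntimeException => println(s"propagated to the caller: ${e.getMessage}") }
  }
}
```

Spark's real getResult() goes through ThreadUtils.awaitResult, which may wrap the failure in a SparkException, but the observable effect is the same: the caller gets either the handler's result or its exception.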
@@ -168,58 +170,59 @@ private[spark] object PythonRDD extends Logging {
    * are collected as separate jobs, by order of index. Partition data is first requested by a
    * non-zero integer to start a collection job. The response is prefaced by an integer with 1
    * meaning partition data will be served, 0 meaning the local iterator has been consumed,
-   * and -1 meaining an error occurred during collection. This function is used by
+   * and -1 meaning an error occurred during collection. This function is used by
    * pyspark.rdd._local_iterator_from_socket().
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
-   *         data collected from these jobs, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from these jobs, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
-    val (port, secret) = SocketAuthServer.setupOneConnectionServer(
-      authHelper, "serve toLocalIterator") { s =>
-      val out = new DataOutputStream(s.getOutputStream)
-      val in = new DataInputStream(s.getInputStream)
-      Utils.tryWithSafeFinally {
-
+    val handleFunc = (sock: Socket) => {
+      val out = new DataOutputStream(sock.getOutputStream)
+      val in = new DataInputStream(sock.getInputStream)
+      Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Collects a partition on each iteration
         val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
           rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
         }

-        // Read request for data and send next partition if nonzero
+        // Write data until iteration is complete, the client stops iteration, or an error occurs
         var complete = false
-        while (!complete && in.readInt() != 0) {
-          if (collectPartitionIter.hasNext) {
-            try {
-              // Attempt to collect the next partition
-              val partitionArray = collectPartitionIter.next()
-
-              // Send response there is a partition to read
-              out.writeInt(1)
-
-              // Write the next object and signal end of data for this iteration
-              writeIteratorToStream(partitionArray.toIterator, out)
-              out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
-              out.flush()
-            } catch {
-              case e: SparkException =>
-                // Send response that an error occurred followed by error message
-                out.writeInt(-1)
-                writeUTF(e.getMessage, out)
-                complete = true
-            }
+        while (!complete) {
+
+          // Read request for data: a value of zero stops iteration, any other value continues
+          if (in.readInt() == 0) {
+            complete = true
+          } else if (collectPartitionIter.hasNext) {
+
+            // Client requested more data, attempt to collect the next partition
+            val partitionArray = collectPartitionIter.next()
+
+            // Send response there is a partition to read
+            out.writeInt(1)
+
+            // Write the next object and signal end of data for this iteration
+            writeIteratorToStream(partitionArray.toIterator, out)
+            out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+            out.flush()
           } else {
             // Send response there are no more partitions to read and close
             out.writeInt(0)
             complete = true
           }
         }
-      } {
+      })(catchBlock = {
+        // Send response that an error occurred; the original exception is re-thrown
+        out.writeInt(-1)
+      }, finallyBlock = {
         out.close()
         in.close()
-      }
+      })
     }
-    Array(port, secret)
+
+    val server = new SocketFuncServer(authHelper, "serve toLocalIterator", handleFunc)
+    Array(server.port, server.secret, server)
   }

   def readRDDFromFile(
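[Reviewer note] The rewritten error path above only works because Utils.tryWithSafeFinallyAndFailureCallbacks runs the catch block on failure, always runs the finally block, and re-throws the original exception, which then completes the server's Promise and eventually reaches getResult() on the Python side. A simplified, self-contained model of that control flow; this is not Spark's actual Utils implementation (which also marks the task failed and handles fatal errors more carefully), just the shape the new code relies on:

```scala
object SafeFinallySketch {
  // Simplified model of tryWithSafeFinallyAndFailureCallbacks: run `block`;
  // on failure run `catchBlock` and re-throw the original error; always run
  // `finallyBlock`. Callback failures are attached as suppressed exceptions.
  def tryWithCallbacks[T](block: => T)(catchBlock: => Unit = (), finallyBlock: => Unit = ()): T = {
    try {
      block
    } catch {
      case t: Throwable =>
        try catchBlock catch { case inner: Throwable => t.addSuppressed(inner) }
        throw t
    } finally {
      finallyBlock
    }
  }

  def main(args: Array[String]): Unit = {
    try {
      tryWithCallbacks {
        println("writing partitions...")
        throw new RuntimeException("job failed")          // like a failing runJob
      }(catchBlock = println("signal client with -1"),    // like out.writeInt(-1)
        finallyBlock = println("close streams"))          // like out.close(); in.close()
    } catch {
      case e: RuntimeException => println(s"re-thrown to the serving thread: ${e.getMessage}")
    }
  }
}
```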
@@ -443,8 +446,9 @@ private[spark] object PythonRDD extends Logging {
    *
    * The thread will terminate after all the data are sent or any exceptions happen.
    *
-   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
-   *         data collected from this job, and the secret for authentication.
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
     serveToStream(threadName) { out =>
@@ -464,10 +468,14 @@ private[spark] object PythonRDD extends Logging {
    *
    * The thread will terminate after the block of code is executed or any
    * exceptions happen.
+   *
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
-    SocketAuthHelper.serveToStream(threadName, authHelper)(writeFunc)
+    SocketAuthServer.serveToStream(threadName, authHelper)(writeFunc)
   }

   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 07f84057abd5..892e69bfce5c 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -29,7 +29,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
+import org.apache.spark.security.SocketAuthServer

 private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     parent: RDD[T],
@@ -166,7 +166,7 @@ private[spark] object RRDD {

   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
-    SocketAuthHelper.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
+    SocketAuthServer.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
index 3a107c076492..dbcb37690533 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -17,7 +17,7 @@ package org.apache.spark.security

-import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, OutputStream}
+import java.io.{DataInputStream, DataOutputStream}
 import java.net.Socket
 import java.nio.charset.StandardCharsets.UTF_8
@@ -113,21 +113,4 @@ private[spark] class SocketAuthHelper(conf: SparkConf) {
     dout.write(bytes, 0, bytes.length)
     dout.flush()
   }
-
-}
-
-private[spark] object SocketAuthHelper {
-  def serveToStream(
-      threadName: String,
-      authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
-    val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { s =>
-      val out = new BufferedOutputStream(s.getOutputStream())
-      Utils.tryWithSafeFinally {
-        writeFunc(out)
-      } {
-        out.close()
-      }
-    }
-    Array(port, secret)
-  }
 }
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index e616d239ce8d..548fd1b07ddc 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -17,6 +17,7 @@ package org.apache.spark.security

+import java.io.{BufferedOutputStream, OutputStream}
 import java.net.{InetAddress, ServerSocket, Socket}

 import scala.concurrent.Promise
@@ -25,12 +26,15 @@ import scala.util.Try

 import org.apache.spark.SparkEnv
 import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.{ThreadUtils, Utils}

 /**
  * Creates a server in the JVM to communicate with external processes (e.g., Python and R) for
  * handling one batch of data, with authentication and error handling.
+ *
+ * The socket server can only accept one connection, or close if no connection
+ * in 15 seconds.
  */
 private[spark] abstract class SocketAuthServer[T](
     authHelper: SocketAuthHelper,
@@ -41,10 +45,30 @@ private[spark] abstract class SocketAuthServer[T](
   private val promise = Promise[T]()

-  val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { sock =>
-    promise.complete(Try(handleConnection(sock)))
+  private def startServer(): (Int, String) = {
+    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+    // Close the socket if no connection in 15 seconds
+    serverSocket.setSoTimeout(15000)
+
+    new Thread(threadName) {
+      setDaemon(true)
+      override def run(): Unit = {
+        var sock: Socket = null
+        try {
+          sock = serverSocket.accept()
+          authHelper.authClient(sock)
+          promise.complete(Try(handleConnection(sock)))
+        } finally {
+          JavaUtils.closeQuietly(serverSocket)
+          JavaUtils.closeQuietly(sock)
+        }
+      }
+    }.start()
+    (serverSocket.getLocalPort, authHelper.secret)
   }

+  val (port, secret) = startServer()
+
   /**
    * Handle a connection which has already been authenticated. Any error from this function
    * will clean up this connection and the entire server, and get propagated to [[getResult]].
@@ -66,42 +90,50 @@ private[spark] abstract class SocketAuthServer[T](
 }

+/**
+ * Create a socket server class and run user function on the socket in a background thread
+ * that can read and write to the socket input/output streams. The function is passed in a
+ * socket that has been connected and authenticated.
+ */
+private[spark] class SocketFuncServer(
+    authHelper: SocketAuthHelper,
+    threadName: String,
+    func: Socket => Unit) extends SocketAuthServer[Unit](authHelper, threadName) {
+
+  override def handleConnection(sock: Socket): Unit = {
+    func(sock)
+  }
+}
+
 private[spark] object SocketAuthServer {

   /**
-   * Create a socket server and run user function on the socket in a background thread.
+   * Convenience function to create a socket server and run a user function in a background
+   * thread to write to an output stream.
    *
    * The socket server can only accept one connection, or close if no connection
    * in 15 seconds.
    *
-   * The thread will terminate after the supplied user function, or if there are any exceptions.
-   *
-   * If you need to get a result of the supplied function, create a subclass of [[SocketAuthServer]]
-   *
-   * @return The port number of a local socket and the secret for authentication.
+   * @param threadName Name for the background serving thread.
+   * @param authHelper SocketAuthHelper for authentication
+   * @param writeFunc User function to write to a given OutputStream
+   * @return 3-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from this job, the secret for authentication, and a socket auth
+   *         server object that can be used to join the JVM serving thread in Python.
    */
-  def setupOneConnectionServer(
-      authHelper: SocketAuthHelper,
-      threadName: String)
-      (func: Socket => Unit): (Int, String) = {
-    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
-    // Close the socket if no connection in 15 seconds
-    serverSocket.setSoTimeout(15000)
-
-    new Thread(threadName) {
-      setDaemon(true)
-      override def run(): Unit = {
-        var sock: Socket = null
-        try {
-          sock = serverSocket.accept()
-          authHelper.authClient(sock)
-          func(sock)
-        } finally {
-          JavaUtils.closeQuietly(serverSocket)
-          JavaUtils.closeQuietly(sock)
-        }
+  def serveToStream(
+      threadName: String,
+      authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
+    val handleFunc = (sock: Socket) => {
+      val out = new BufferedOutputStream(sock.getOutputStream())
+      Utils.tryWithSafeFinally {
+        writeFunc(out)
+      } {
+        out.close()
       }
-    }.start()
-    (serverSocket.getLocalPort, authHelper.secret)
+    }
+
+    val server = new SocketFuncServer(authHelper, threadName, handleFunc)
+    Array(server.port, server.secret, server)
   }
 }
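[Reviewer note] For anyone skimming the refactor above, startServer() boils down to: bind an ephemeral loopback port that accepts a single connection, give up after 15 seconds if nobody connects, and run the handler on a daemon thread whose outcome lands in the Promise. A self-contained sketch of that shape using plain JDK sockets; authentication is omitted, the bind address uses getLoopbackAddress instead of the patch's getByAddress(Array(127, 0, 0, 1)), and the names start/handle are illustrative only:

```scala
import java.io.{BufferedReader, InputStreamReader, PrintWriter}
import java.net.{InetAddress, ServerSocket, Socket}

import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Promise}
import scala.util.Try

object OneConnectionServerSketch {

  // Bind an ephemeral loopback port that accepts exactly one connection, give up
  // if nobody connects within 15 seconds, and run `handle` on a daemon thread.
  // The handler's outcome (value or exception) is captured in the returned Promise.
  def start[T](threadName: String)(handle: Socket => T): (Int, Promise[T]) = {
    val serverSocket = new ServerSocket(0, 1, InetAddress.getLoopbackAddress)
    serverSocket.setSoTimeout(15000)
    val promise = Promise[T]()

    new Thread(threadName) {
      setDaemon(true)
      override def run(): Unit = {
        var sock: Socket = null
        try {
          sock = serverSocket.accept()       // the real server authenticates the client here
          promise.complete(Try(handle(sock)))
        } finally {
          if (sock != null) sock.close()
          serverSocket.close()
        }
      }
    }.start()

    (serverSocket.getLocalPort, promise)
  }

  def main(args: Array[String]): Unit = {
    val (port, result) = start("serve one client") { sock =>
      val in = new BufferedReader(new InputStreamReader(sock.getInputStream))
      val out = new PrintWriter(sock.getOutputStream, true)
      out.println(s"hello, ${in.readLine()}")
      "served exactly one connection"
    }

    val client = new Socket(InetAddress.getLoopbackAddress, port)
    new PrintWriter(client.getOutputStream, true).println("world")
    println(new BufferedReader(new InputStreamReader(client.getInputStream)).readLine())
    client.close()

    println(Await.result(result.future, Duration.Inf))   // join the serving thread
  }
}
```

The daemon flag plus the accept timeout keep an abandoned server from blocking JVM shutdown or leaking a thread indefinitely.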
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 00135c3259c8..80d70a1d4850 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1389,7 +1389,9 @@ private[spark] object Utils extends Logging {
         originalThrowable = cause
         try {
           logError("Aborting task", originalThrowable)
-          TaskContext.get().markTaskFailed(originalThrowable)
+          if (TaskContext.get() != null) {
+            TaskContext.get().markTaskFailed(originalThrowable)
+          }
           catchBlock
         } catch {
           case t: Throwable =>
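[Reviewer note] On the Utils.scala hunk above: tryWithSafeFinallyAndFailureCallbacks is now also invoked from a driver-side serving thread (the new handleFunc in PythonRDD), where TaskContext.get() returns null, so the unguarded markTaskFailed call would throw a NullPointerException and mask the real error. A toy illustration of why the null check matters; currentTask is just a stand-in ThreadLocal, not Spark's TaskContext:

```scala
object TaskContextNullGuardSketch {
  // Stand-in for TaskContext.get(): a thread-local that is only populated on
  // executor task threads, and null everywhere else (e.g. a driver-side serving thread).
  private val currentTask = new ThreadLocal[String]()

  // Pre-patch shape: assumes a task is always present.
  def markTaskFailedUnsafe(error: Throwable): Unit =
    println(s"${currentTask.get().toUpperCase} failed: ${error.getMessage}")  // NPE off-task

  // Post-patch shape: tolerate running where no task context exists.
  def markTaskFailedGuarded(error: Throwable): Unit =
    if (currentTask.get() != null) {
      println(s"${currentTask.get()} failed: ${error.getMessage}")
    }

  def main(args: Array[String]): Unit = {
    val boom = new RuntimeException("collection error")

    markTaskFailedGuarded(boom)             // no task on this thread: silently skipped
    try markTaskFailedUnsafe(boom) catch {  // unguarded call off-task: NPE masks the real error
      case _: NullPointerException => println("NPE instead of the original error")
    }

    currentTask.set("task 0.0 in stage 1.0")
    markTaskFailedGuarded(boom)             // task 0.0 in stage 1.0 failed: collection error
  }
}
```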
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 395abc841827..fa4609dc5ba1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -140,7 +140,15 @@ def _parse_memory(s):


 def _create_local_socket(sock_info):
-    (sockfile, sock) = local_connect_and_auth(*sock_info)
+    """
+    Create a local socket that can be used to load deserialized data from the JVM
+
+    :param sock_info: Tuple containing port number and authentication secret for a local socket.
+    :return: sockfile file descriptor of the local socket
+    """
+    port = sock_info[0]
+    auth_secret = sock_info[1]
+    sockfile, sock = local_connect_and_auth(port, auth_secret)
     # The RDD materialization time is unpredictable, if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
@@ -148,6 +156,13 @@ def _create_local_socket(sock_info):


 def _load_from_socket(sock_info, serializer):
+    """
+    Connect to a local socket described by sock_info and use the given serializer to yield data
+
+    :param sock_info: Tuple containing port number and authentication secret for a local socket.
+    :param serializer: The PySpark serializer to use
+    :return: result of Serializer.load_stream, usually a generator that yields deserialized data
+    """
     sockfile = _create_local_socket(sock_info)
     # The socket will be automatically closed when garbage-collected.
     return serializer.load_stream(sockfile)
@@ -159,7 +174,8 @@ class PyLocalIterable(object):
         """ Create a synchronous local iterable over a socket """

         def __init__(self, _sock_info, _serializer):
-            self._sockfile = _create_local_socket(_sock_info)
+            port, auth_secret, self.jsocket_auth_server = _sock_info
+            self._sockfile = _create_local_socket((port, auth_secret))
             self._serializer = _serializer
             self._read_iter = iter([])  # Initialize as empty iterator
             self._read_status = 1
@@ -179,11 +195,9 @@ def __iter__(self):
                     for item in self._read_iter:
                         yield item

-                # An error occurred, read error message and raise it
+                # An error occurred, join serving thread and raise any exceptions from the JVM
                 elif self._read_status == -1:
-                    error_msg = UTF8Deserializer().loads(self._sockfile)
-                    raise RuntimeError("An error occurred while reading the next element from "
-                                       "toLocalIterator: {}".format(error_msg))
+                    self.jsocket_auth_server.getResult()

         def __del__(self):
             # If local iterator is not fully consumed,
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 6ba740ddc7a5..8b0e06d9fcab 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2200,10 +2200,16 @@ def _collectAsArrow(self):
         .. note:: Experimental.
         """
         with SCCallSiteSync(self._sc) as css:
-            sock_info = self._jdf.collectAsArrowToPython()
+            port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()

             # Collect list of un-ordered batches where last element is a list of correct order indices
-            results = list(_load_from_socket(sock_info, ArrowCollectSerializer()))
+            try:
+                results = list(_load_from_socket((port, auth_secret), ArrowCollectSerializer()))
+            finally:
+                # Join serving thread and raise any exceptions from collectAsArrowToPython
+                jsocket_auth_server.getResult()
+
+            # Separate RecordBatches from batch order indices in results
             batches = results[:-1]
             batch_order = results[-1]
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 067113722adb..870f5a8f910d 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -214,7 +214,7 @@ def raise_exception():
         exception_udf = udf(raise_exception, IntegerType())
         df = df.withColumn("error", exception_udf())
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(RuntimeError, 'My error'):
+            with self.assertRaisesRegexp(Exception, 'My error'):
                 df.toPandas()

     def _createDataFrame_toggle(self, pdf, schema=None):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a80aadebe353..45ec7dcb07a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal

 import org.apache.commons.lang3.StringUtils

-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.TaskContext
 import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
@@ -3321,34 +3321,24 @@ class Dataset[T] private[sql](
         }
       }

-      var sparkException: Option[SparkException] = None
-      try {
+      Utils.tryWithSafeFinally {
         val arrowBatchRdd = toArrowBatchRdd(plan)
         sparkSession.sparkContext.runJob(
           arrowBatchRdd,
           (it: Iterator[Array[Byte]]) => it.toArray,
           handlePartitionBatches)
-      } catch {
-        case e: SparkException =>
-          sparkException = Some(e)
-      }
-
-      // After processing all partitions, end the batch stream
-      batchWriter.end()
-      sparkException match {
-        case Some(exception) =>
-          // Signal failure and write error message
-          out.writeInt(-1)
-          PythonRDD.writeUTF(exception.getMessage, out)
-        case None =>
-          // Write batch order indices
-          out.writeInt(batchOrder.length)
-          // Sort by (index of partition, batch index in that partition) tuple to get the
-          // overall_batch_index from 0 to N-1 batches, which can be used to put the
-          // transferred batches in the correct order
-          batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
-            out.writeInt(overallBatchIndex)
-          }
+      } {
+        // After processing all partitions, end the batch stream
+        batchWriter.end()
+
+        // Write batch order indices
+        out.writeInt(batchOrder.length)
+        // Sort by (index of partition, batch index in that partition) tuple to get the
+        // overall_batch_index from 0 to N-1 batches, which can be used to put the
+        // transferred batches in the correct order
+        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+          out.writeInt(overallBatchIndex)
+        }
       }
     }
   }
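[Reviewer note] The index block that now always runs in the finally is easy to misread, so here is a tiny, self-contained walk-through of what batchOrder.zipWithIndex.sortBy(_._1) produces and how the Python side (dataframe.py above, which keeps batch_order = results[-1] and then reorders with something like [batches[i] for i in batch_order]) consumes it. The concrete partition and batch ids are made up for illustration:

```scala
object BatchOrderSketch {
  def main(args: Array[String]): Unit = {
    // Stream arrival order of Arrow batches, recorded as (partitionId, batchIdWithinPartition).
    // Partition 2 happened to finish before partition 0 here.
    val batchOrder = Seq((2, 0), (0, 0), (0, 1), (1, 0))

    // What the patched code sends after the data: for each batch in logical
    // (partition, batch) order, the position it occupied in the stream.
    val indices = batchOrder.zipWithIndex.sortBy(_._1).map { case (_, overallBatchIndex) =>
      overallBatchIndex
    }
    println(indices)  // List(1, 2, 3, 0)

    // The client can then pick batches by those stream positions to restore
    // (0,0), (0,1), (1,0), (2,0) order.
    val arrived = Seq("b(2,0)", "b(0,0)", "b(0,1)", "b(1,0)")
    println(indices.map(arrived))  // List(b(0,0), b(0,1), b(1,0), b(2,0))
  }
}
```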