-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-27992][PYTHON] Allow Python to join with connection thread to propagate errors #24834
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
519926f
3a52960
b209e0a
c9f7fe9
2fddb43
785ce4f
ead8978
5fd8684
20eb748
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
|
|
||
| package org.apache.spark.security | ||
|
|
||
| import java.io.{BufferedOutputStream, OutputStream} | ||
| import java.net.{InetAddress, ServerSocket, Socket} | ||
|
|
||
| import scala.concurrent.Promise | ||
|
|
@@ -25,7 +26,7 @@ import scala.util.Try | |
|
|
||
| import org.apache.spark.SparkEnv | ||
| import org.apache.spark.network.util.JavaUtils | ||
| import org.apache.spark.util.ThreadUtils | ||
| import org.apache.spark.util.{ThreadUtils, Utils} | ||
|
|
||
|
|
||
| /** | ||
|
|
@@ -41,10 +42,30 @@ private[spark] abstract class SocketAuthServer[T]( | |
|
|
||
| private val promise = Promise[T]() | ||
|
|
||
| val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { sock => | ||
| promise.complete(Try(handleConnection(sock))) | ||
| private def startServer(): (Int, String) = { | ||
| val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) | ||
| // Close the socket if no connection in 15 seconds | ||
| serverSocket.setSoTimeout(15000) | ||
|
|
||
| new Thread(threadName) { | ||
| setDaemon(true) | ||
| override def run(): Unit = { | ||
| var sock: Socket = null | ||
| try { | ||
| sock = serverSocket.accept() | ||
| authHelper.authClient(sock) | ||
| promise.complete(Try(handleConnection(sock))) | ||
| } finally { | ||
| JavaUtils.closeQuietly(serverSocket) | ||
| JavaUtils.closeQuietly(sock) | ||
| } | ||
| } | ||
| }.start() | ||
| (serverSocket.getLocalPort, authHelper.secret) | ||
| } | ||
|
|
||
| val (port, secret) = startServer() | ||
|
|
||
| /** | ||
| * Handle a connection which has already been authenticated. Any error from this function | ||
| * will clean up this connection and the entire server, and get propagated to [[getResult]]. | ||
|
|
@@ -66,42 +87,45 @@ private[spark] abstract class SocketAuthServer[T]( | |
|
|
||
| } | ||
|
|
||
| /** | ||
| * Create a socket server class and run user function on the socket in a background thread. | ||
| * The user function is passed a socket that has already been connected and authenticated. | ||
| * Unlike a bare background thread, this creates a server object that can then be synced | ||
| * from Python. | ||
| */ | ||
| private[spark] class SocketFuncServer( | ||
| authHelper: SocketAuthHelper, | ||
| threadName: String, | ||
| func: Socket => Unit) extends SocketAuthServer[Unit](authHelper, threadName) { | ||
|
||
|
|
||
| override def handleConnection(sock: Socket): Unit = { | ||
| func(sock) | ||
| } | ||
| } | ||
|
|
||
| private[spark] object SocketAuthServer { | ||
|
|
||
| /** | ||
| * Create a socket server and run user function on the socket in a background thread. | ||
| * | ||
| * The socket server can only accept one connection, or close if no connection | ||
| * in 15 seconds. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please save this comment -- I guess move it to the class.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, that is a useful comment; I didn't intend to take this out. I'll put it back in. |
||
| * Convenience function to create a socket server and run a user function in a background | ||
| * thread to write to an output stream. | ||
| * | ||
| * The thread will terminate after the supplied user function, or if there are any exceptions. | ||
| * | ||
| * If you need to get a result of the supplied function, create a subclass of [[SocketAuthServer]] | ||
| * | ||
| * @return The port number of a local socket and the secret for authentication. | ||
| * @param threadName Name for the background serving thread. | ||
| * @param authHelper SocketAuthHelper for authentication | ||
| * @param writeFunc User function to write to a given OutputStream | ||
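| * @return Array containing the port number of the local socket, the secret for | ||
| *         authentication, and the socket auth server object. | ||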
| */ | ||
| def setupOneConnectionServer( | ||
| authHelper: SocketAuthHelper, | ||
| threadName: String) | ||
| (func: Socket => Unit): (Int, String) = { | ||
| val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) | ||
| // Close the socket if no connection in 15 seconds | ||
| serverSocket.setSoTimeout(15000) | ||
|
|
||
| new Thread(threadName) { | ||
| setDaemon(true) | ||
| override def run(): Unit = { | ||
| var sock: Socket = null | ||
| try { | ||
| sock = serverSocket.accept() | ||
| authHelper.authClient(sock) | ||
| func(sock) | ||
| } finally { | ||
| JavaUtils.closeQuietly(serverSocket) | ||
| JavaUtils.closeQuietly(sock) | ||
| } | ||
| def serveToStream( | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved this from |
||
| threadName: String, | ||
| authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = { | ||
| val handleFunc = (sock: Socket) => { | ||
| val out = new BufferedOutputStream(sock.getOutputStream()) | ||
| Utils.tryWithSafeFinally { | ||
| writeFunc(out) | ||
| } { | ||
| out.close() | ||
| } | ||
| }.start() | ||
| (serverSocket.getLocalPort, authHelper.secret) | ||
| } | ||
|
|
||
| val server = new SocketFuncServer(authHelper, threadName, handleFunc) | ||
| Array(server.port, server.secret, server) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1389,7 +1389,9 @@ private[spark] object Utils extends Logging { | |
| originalThrowable = cause | ||
| try { | ||
| logError("Aborting task", originalThrowable) | ||
| TaskContext.get().markTaskFailed(originalThrowable) | ||
| if (TaskContext.get() != null) { | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using this utility here https://github.com/apache/spark/pull/24834/files#diff-0a67bc4d171abe4df8eb305b0f4123a2R184, where the task fails and completes before hitting the |
||
| TaskContext.get().markTaskFailed(originalThrowable) | ||
| } | ||
| catchBlock | ||
| } catch { | ||
| case t: Throwable => | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -140,14 +140,29 @@ def _parse_memory(s): | |
|
|
||
|
|
||
| def _create_local_socket(sock_info): | ||
| (sockfile, sock) = local_connect_and_auth(*sock_info) | ||
| """ | ||
| Create a local socket that can be used to load deserialized data from the JVM | ||
|
|
||
| :param sock_info: Tuple containing port number and authentication secret for a local socket. | ||
| :return: sockfile file descriptor of the local socket | ||
| """ | ||
| port = sock_info[0] | ||
| auth_secret = sock_info[1] | ||
| sockfile, sock = local_connect_and_auth(port, auth_secret) | ||
| # The RDD materialization time is unpredictable; if we set a timeout for the socket read | ||
| # operation, it will very likely fail. See SPARK-18281. | ||
| sock.settimeout(None) | ||
| return sockfile | ||
|
|
||
|
|
||
| def _load_from_socket(sock_info, serializer): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @BryanCutler, what does
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Uggh, yeah I'm not too happy with this. Java returns a 3-tuple with (port, auth_secret, server) and most places only use the first 2, such as port, auth_secret, _ = ...and I don't think it really made things clearer. I'll try to think of something better and maybe do a followup. |
||
| """ | ||
| Connect to a local socket described by sock_info and use the given serializer to yield data | ||
|
|
||
| :param sock_info: Tuple containing port number and authentication secret for a local socket. | ||
| :param serializer: The PySpark serializer to use | ||
| :return: result of Serializer.load_stream, usually a generator that yields deserialized data | ||
| """ | ||
| sockfile = _create_local_socket(sock_info) | ||
| # The socket will be automatically closed when garbage-collected. | ||
| return serializer.load_stream(sockfile) | ||
|
|
@@ -159,7 +174,8 @@ class PyLocalIterable(object): | |
| """ Create a synchronous local iterable over a socket """ | ||
|
|
||
| def __init__(self, _sock_info, _serializer): | ||
| self._sockfile = _create_local_socket(_sock_info) | ||
| port, auth_secret, self.jserver_obj = _sock_info | ||
|
||
| self._sockfile = _create_local_socket((port, auth_secret)) | ||
| self._serializer = _serializer | ||
| self._read_iter = iter([]) # Initialize as empty iterator | ||
| self._read_status = 1 | ||
|
|
@@ -179,11 +195,9 @@ def __iter__(self): | |
| for item in self._read_iter: | ||
| yield item | ||
|
|
||
| # An error occurred, read error message and raise it | ||
| # An error occurred, join serving thread and raise any exceptions from the JVM | ||
| elif self._read_status == -1: | ||
| error_msg = UTF8Deserializer().loads(self._sockfile) | ||
| raise RuntimeError("An error occurred while reading the next element from " | ||
| "toLocalIterator: {}".format(error_msg)) | ||
| self.jserver_obj.getResult() | ||
|
|
||
| def __del__(self): | ||
| # If local iterator is not fully consumed, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems we don't have setupOneConnectionServer anymore.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops, good catch