diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b775e4036760..a280e2225e81 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -878,6 +878,9 @@ private[spark] abstract class PythonServer[T](
 
 private[spark] object PythonServer {
 
+  // Accept timeout in milliseconds; a var (not a val) only so tests can lower it
+  private[spark] var timeout = 15000
+
   /**
    * Create a socket server and run user function on the socket in a background thread.
    *
@@ -896,7 +899,7 @@ private[spark] object PythonServer {
       (func: Socket => Unit): (Int, String) = {
     val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
-    // Close the socket if no connection in 15 seconds
-    serverSocket.setSoTimeout(15000)
+    // Close the socket if no connection arrives within the configured timeout
+    serverSocket.setSoTimeout(timeout)
 
     new Thread(threadName) {
       setDaemon(true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 6e4577591dab..bfa49a8d81a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3257,12 +3257,11 @@ class Dataset[T] private[sql](
 
   private[sql] def collectToPython(): Array[Any] = {
     EvaluatePython.registerPicklers()
-    withAction("collectToPython", queryExecution) { plan =>
+    val iter = withAction("collectToPython", queryExecution) { plan =>
       val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
-      val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
-        plan.executeCollect().iterator.map(toJava))
-      PythonRDD.serveIterator(iter, "serve-DataFrame")
+      new SerDeUtil.AutoBatchedPickler(plan.executeCollect().iterator.map(toJava))
     }
+    PythonRDD.serveIterator(iter, "serve-DataFrame")
   }
 
   private[sql] def getRowsToPython(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 08ebf8b10fef..aa33f79e2191 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -18,19 +18,25 @@
 package org.apache.spark.sql
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.net.{InetAddress, Socket}
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.SparkException
+import scala.io.Source
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.api.python.PythonServer
+import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
-import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec}
+import org.apache.spark.sql.execution.{LogicalRDD, QueryExecution, RDDScanExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.QueryExecutionListener
 
 case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
 case class TestDataPoint2(x: Int, s: String)
@@ -1586,6 +1592,37 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
test("SPARK-34726: Fix collectToPython timeouts") { + // Lower `PythonServer.setupOneConnectionServer` timeout for this test + val oldTimeout = PythonServer.timeout + PythonServer.timeout = 1000 + + val listener = new QueryExecutionListener { + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + // Wait longer than `PythonServer.setupOneConnectionServer` timeout + Thread.sleep(PythonServer.timeout + 1000) + } + } + try { + spark.listenerManager.register(listener) + + val Array(port: Int, secretToPython: String) = spark.range(5).toDF().collectToPython() + + // Mimic Python side + val socket = new Socket(InetAddress.getByAddress(Array(127, 0, 0, 1)), port) + val authHelper = new SocketAuthHelper(new SparkConf()) { + override val secret: String = secretToPython + } + authHelper.authToServer(socket) + Source.fromInputStream(socket.getInputStream) + } finally { + spark.listenerManager.unregister(listener) + PythonServer.timeout = oldTimeout + } + } } case class TestDataUnion(x: Int, y: Int, z: Int)