7 changes: 3 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

@@ -3257,12 +3257,11 @@ class Dataset[T] private[sql](

   private[sql] def collectToPython(): Array[Any] = {
     EvaluatePython.registerPicklers()
-    withAction("collectToPython", queryExecution) { plan =>
+    val iter = withAction("collectToPython", queryExecution) { plan =>
       val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
-      val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
-        plan.executeCollect().iterator.map(toJava))
-      PythonRDD.serveIterator(iter, "serve-DataFrame")
+      new SerDeUtil.AutoBatchedPickler(plan.executeCollect().iterator.map(toJava))
     }
+    PythonRDD.serveIterator(iter, "serve-DataFrame")
   }

   private[sql] def getRowsToPython(
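For readability, here is the whole method as it reads after this change, reconstructed from the hunk above. The inline comment is our reading of why the move matters and is not part of the patch:

  private[sql] def collectToPython(): Array[Any] = {
    EvaluatePython.registerPicklers()
    val iter = withAction("collectToPython", queryExecution) { plan =>
      val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
      new SerDeUtil.AutoBatchedPickler(plan.executeCollect().iterator.map(toJava))
    }
    // Start the one-connection server only after withAction has finished, i.e.
    // after any QueryExecutionListener callbacks have run. Previously the server
    // was started inside withAction, so its 15s accept window could expire while
    // a slow listener was still running, before Python ever received the port.
    PythonRDD.serveIterator(iter, "serve-DataFrame")
  }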
33 changes: 31 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

@@ -18,19 +18,24 @@
 package org.apache.spark.sql

 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.net.{InetAddress, Socket}
 import java.sql.{Date, Timestamp}

-import org.apache.spark.SparkException
+import scala.io.Source
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
-import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec}
+import org.apache.spark.sql.execution.{LogicalRDD, QueryExecution, RDDScanExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.QueryExecutionListener

 case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
 case class TestDataPoint2(x: Int, s: String)
@@ -1586,6 +1591,30 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-34726: Fix collectToPython timeouts") {
+    val listener = new QueryExecutionListener {
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        // Longer than 15s in `PythonServer.setupOneConnectionServer`
Contributor: Let's extract the 15000 into a private[spark] var as a member, and add a "// visible for testing" comment above it. Then we can speed up this single test.

Contributor Author: Ok, added in 6b18cc7 and 8f6b811
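A minimal sketch of what that suggestion could look like; the object name, signature, and surrounding structure are assumed here, not taken from the patch. The hard-coded 15000 becomes a member that a test can lower before calling collectToPython:

package org.apache.spark.api.python

import java.net.{InetAddress, ServerSocket, Socket}

// Hypothetical stand-in for PythonServer; only the timeout handling is shown.
object PythonServerSketch {
  // visible for testing
  private[spark] var serverAcceptTimeoutMs: Int = 15000

  def setupOneConnectionServer(threadName: String)(func: Socket => Unit): Int = {
    val serverSocket =
      new ServerSocket(0, 1, InetAddress.getByAddress(Array[Byte](127, 0, 0, 1)))
    // accept() fails with SocketTimeoutException once this window elapses.
    serverSocket.setSoTimeout(serverAcceptTimeoutMs)
    val thread = new Thread(threadName) {
      override def run(): Unit = {
        var sock: Socket = null
        try {
          sock = serverSocket.accept()
          func(sock)
        } finally {
          if (sock != null) sock.close()
          serverSocket.close()
        }
      }
    }
    thread.setDaemon(true)
    thread.start()
    serverSocket.getLocalPort
  }
}

With something like this in place, the test could set serverAcceptTimeoutMs to, say, 2000 and sleep just over two seconds instead of twenty.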

+        Thread.sleep(20 * 1000)
Contributor: Just for my own understanding: does this mean the test waits 15 seconds to pass?

Contributor Author: Yes, we need to wait a bit longer than the timeout. Without the fix in Dataset.collectToPython this UT fails.
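To make the timing concrete, here is a self-contained illustration of the failure mode using plain JDK sockets (demo code, not Spark's): once the 15s SO_TIMEOUT elapses with no client connected, accept() throws, which is what happened when the 20s listener ran before Python received the port.

import java.net.{InetAddress, ServerSocket, SocketTimeoutException}

object AcceptTimeoutDemo {
  def main(args: Array[String]): Unit = {
    val server = new ServerSocket(0, 1, InetAddress.getByAddress(Array[Byte](127, 0, 0, 1)))
    server.setSoTimeout(15000) // same 15s window as PythonServer.setupOneConnectionServer
    try {
      server.accept() // no client connects in time -> SocketTimeoutException
    } catch {
      case e: SocketTimeoutException => println(s"timed out as expected: $e")
    } finally {
      server.close()
    }
  }
}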

+      }
+    }
+    spark.listenerManager.register(listener)
+
+    val Array(port: Int, secretToPython: String) = spark.range(5).toDF().collectToPython()
+
+    // Mimic Python side
+    val socket = new Socket(InetAddress.getByAddress(Array(127, 0, 0, 1)), port)
+    val authHelper = new SocketAuthHelper(new SparkConf()) {
+      override val secret: String = secretToPython
+    }
+    authHelper.authToServer(socket)
+    Source.fromInputStream(socket.getInputStream)
+
+    spark.listenerManager.unregister(listener)
Contributor: I would put this into a finally block, to be on the safe side.

Contributor Author: All right, added in 8f6b811
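A sketch of the suggested reshaping, reusing the test's own identifiers and imports; the point is that the listener is removed even if connecting, authenticating, or reading throws:

    spark.listenerManager.register(listener)
    try {
      val Array(port: Int, secretToPython: String) = spark.range(5).toDF().collectToPython()
      // Mimic Python side
      val socket = new Socket(InetAddress.getByAddress(Array(127, 0, 0, 1)), port)
      val authHelper = new SocketAuthHelper(new SparkConf()) {
        override val secret: String = secretToPython
      }
      authHelper.authToServer(socket)
      Source.fromInputStream(socket.getInputStream)
    } finally {
      // Unregister unconditionally so a failure cannot leak the slow
      // listener into later tests in the suite.
      spark.listenerManager.unregister(listener)
    }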

+  }
 }

 case class TestDataUnion(x: Int, y: Int, z: Int)