1818package org .apache .spark .sql .connect .service
1919
2020import scala .collection .JavaConverters ._
21- import scala .util .control .NonFatal
2221
2322import com .google .protobuf .ByteString
2423import io .grpc .stub .StreamObserver
2524
26- import org .apache .spark .SparkException
2725import org .apache .spark .connect .proto
2826import org .apache .spark .connect .proto .{ExecutePlanRequest , ExecutePlanResponse }
2927import org .apache .spark .internal .Logging
@@ -34,7 +32,6 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
3432import org .apache .spark .sql .execution .adaptive .{AdaptiveSparkPlanExec , AdaptiveSparkPlanHelper , QueryStageExec }
3533import org .apache .spark .sql .execution .arrow .ArrowConverters
3634import org .apache .spark .sql .types .StructType
37- import org .apache .spark .util .ThreadUtils
3835
3936class SparkConnectStreamHandler (responseObserver : StreamObserver [ExecutePlanResponse ])
4037 extends Logging {
@@ -57,75 +54,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
5754 // Extract the plan from the request and convert it to a logical plan
5855 val planner = new SparkConnectPlanner (session)
5956 val dataframe = Dataset .ofRows(session, planner.transformRelation(request.getPlan.getRoot))
60- try {
61- processAsArrowBatches(request.getClientId, dataframe)
62- } catch {
63- case e : Exception =>
64- logWarning(e.getMessage)
65- processAsJsonBatches(request.getClientId, dataframe)
66- }
67- }
68-
69- def processAsJsonBatches (clientId : String , dataframe : DataFrame ): Unit = {
70- // Only process up to 10MB of data.
71- val sb = new StringBuilder
72- var rowCount = 0
73- dataframe.toJSON
74- .collect()
75- .foreach(row => {
76-
77- // There are a few cases to cover here.
78- // 1. The aggregated buffer size is larger than the MAX_BATCH_SIZE
79- // -> send the current batch and reset.
80- // 2. The aggregated buffer size is smaller than the MAX_BATCH_SIZE
81- // -> append the row to the buffer.
82- // 3. The row in question is larger than the MAX_BATCH_SIZE
83- // -> fail the query.
84-
85- // Case 3. - Fail
86- if (row.size > MAX_BATCH_SIZE ) {
87- throw SparkException .internalError(
88- s " Serialized row is larger than MAX_BATCH_SIZE: ${row.size} > ${MAX_BATCH_SIZE }" )
89- }
90-
91- // Case 1 - FLush and send.
92- if (sb.size + row.size > MAX_BATCH_SIZE ) {
93- val response = proto.ExecutePlanResponse .newBuilder().setClientId(clientId)
94- val batch = proto.ExecutePlanResponse .JSONBatch
95- .newBuilder()
96- .setData(ByteString .copyFromUtf8(sb.toString()))
97- .setRowCount(rowCount)
98- .build()
99- response.setJsonBatch(batch)
100- responseObserver.onNext(response.build())
101- sb.clear()
102- sb.append(row)
103- rowCount = 1
104- } else {
105- // Case 2 - Append.
106- // Make sure to put the newline delimiters only between items and not at the end.
107- if (rowCount > 0 ) {
108- sb.append(" \n " )
109- }
110- sb.append(row)
111- rowCount += 1
112- }
113- })
114-
115- // If the last batch is not empty, send out the data to the client.
116- if (sb.size > 0 ) {
117- val response = proto.ExecutePlanResponse .newBuilder().setClientId(clientId)
118- val batch = proto.ExecutePlanResponse .JSONBatch
119- .newBuilder()
120- .setData(ByteString .copyFromUtf8(sb.toString()))
121- .setRowCount(rowCount)
122- .build()
123- response.setJsonBatch(batch)
124- responseObserver.onNext(response.build())
125- }
126-
127- responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
128- responseObserver.onCompleted()
57+ processAsArrowBatches(request.getClientId, dataframe)
12958 }
13059
13160 def processAsArrowBatches (clientId : String , dataframe : DataFrame ): Unit = {
@@ -142,83 +71,20 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
14271 var numSent = 0
14372
14473 if (numPartitions > 0 ) {
145- type Batch = (Array [Byte ], Long )
146-
14774 val batches = rows.mapPartitionsInternal(
14875 SparkConnectStreamHandler
14976 .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))
15077
151- val signal = new Object
152- val partitions = collection.mutable.Map .empty[Int , Array [Batch ]]
153- var error : Throwable = null
154-
155- val processPartition = (iter : Iterator [Batch ]) => iter.toArray
156-
157- // This callback is executed by the DAGScheduler thread.
158- // After fetching a partition, it inserts the partition into the Map, and then
159- // wakes up the main thread.
160- val resultHandler = (partitionId : Int , partition : Array [Batch ]) => {
161- signal.synchronized {
162- partitions(partitionId) = partition
163- signal.notify()
164- }
165- ()
166- }
167-
168- val future = spark.sparkContext.submitJob(
169- rdd = batches,
170- processPartition = processPartition,
171- partitions = Seq .range(0 , numPartitions),
172- resultHandler = resultHandler,
173- resultFunc = () => ())
174-
175- // Collect errors and propagate them to the main thread.
176- future.onComplete { result =>
177- result.failed.foreach { throwable =>
178- signal.synchronized {
179- error = throwable
180- signal.notify()
181- }
182- }
183- }(ThreadUtils .sameThread)
184-
185- // The main thread will wait until 0-th partition is available,
186- // then send it to client and wait for the next partition.
187- // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends
188- // the arrow batches in main thread to avoid DAGScheduler thread been blocked for
189- // tasks not related to scheduling. This is particularly important if there are
190- // multiple users or clients running code at the same time.
191- var currentPartitionId = 0
192- while (currentPartitionId < numPartitions) {
193- val partition = signal.synchronized {
194- var result = partitions.remove(currentPartitionId)
195- while (result.isEmpty && error == null ) {
196- signal.wait()
197- result = partitions.remove(currentPartitionId)
198- }
199- error match {
200- case NonFatal (e) =>
201- responseObserver.onError(error)
202- logError(" Error while processing query." , e)
203- return
204- case fatal : Throwable => throw fatal
205- case null => result.get
206- }
207- }
208-
209- partition.foreach { case (bytes, count) =>
210- val response = proto.ExecutePlanResponse .newBuilder().setClientId(clientId)
211- val batch = proto.ExecutePlanResponse .ArrowBatch
212- .newBuilder()
213- .setRowCount(count)
214- .setData(ByteString .copyFrom(bytes))
215- .build()
216- response.setArrowBatch(batch)
217- responseObserver.onNext(response.build())
218- numSent += 1
219- }
220-
221- currentPartitionId += 1
78+ batches.collect().foreach { case (bytes, count) =>
79+ val response = proto.ExecutePlanResponse .newBuilder().setClientId(clientId)
80+ val batch = proto.ExecutePlanResponse .ArrowBatch
81+ .newBuilder()
82+ .setRowCount(count)
83+ .setData(ByteString .copyFrom(bytes))
84+ .build()
85+ response.setArrowBatch(batch)
86+ responseObserver.onNext(response.build())
87+ numSent += 1
22288 }
22389 }
22490
0 commit comments