8 changes: 8 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
@@ -183,6 +183,14 @@ def test_filtered_frame(self):
         self.assertEqual(pdf.columns[0], "i")
         self.assertTrue(pdf.empty)
 
+    def test_no_partition_frame(self):
+        schema = StructType([StructField("field1", StringType(), True)])
+        df = self.spark.createDataFrame(self.sc.emptyRDD(), schema)
+        pdf = df.toPandas()
+        self.assertEqual(len(pdf.columns), 1)
+        self.assertEqual(pdf.columns[0], "field1")
+        self.assertTrue(pdf.empty)
+
     def _createDataFrame_toggle(self, pdf, schema=None):
         with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
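Note (not part of the diff): the case this new test covers is a DataFrame whose underlying RDD has zero partitions. A minimal Scala sketch of the same scenario, assuming only an active SparkSession named spark:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Sketch only: an RDD built with emptyRDD has zero partitions, so the Arrow
// collect path receives no partition results at all (see the review discussion below).
val schema = StructType(Seq(StructField("field1", StringType, nullable = true)))
val df = spark.createDataFrame(spark.sparkContext.emptyRDD[Row], schema)
assert(df.rdd.getNumPartitions == 0)  // expected to hold for this construction
assert(df.collect().isEmpty)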
26 changes: 17 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3310,23 +3310,31 @@ class Dataset[T] private[sql](

           // After last batch, end the stream and write batch order indices
           if (partitionCount == numPartitions) {
-            batchWriter.end()
-            out.writeInt(batchOrder.length)
-            // Sort by (index of partition, batch index in that partition) tuple to get the
-            // overall_batch_index from 0 to N-1 batches, which can be used to put the
-            // transferred batches in the correct order
-            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
-              out.writeInt(overallBatchIndex)
-            }
-            out.flush()
+            doAfterLastPartition()
           }
         }
 
+        def doAfterLastPartition(): Unit = {
+          batchWriter.end()
+          out.writeInt(batchOrder.length)
+          // Sort by (index of partition, batch index in that partition) tuple to get the
+          // overall_batch_index from 0 to N-1 batches, which can be used to put the
+          // transferred batches in the correct order
+          batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+            out.writeInt(overallBatchIndex)
+          }
+          out.flush()
+        }
+
         sparkSession.sparkContext.runJob(
           arrowBatchRdd,
           (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
           0 until numPartitions,
           handlePartitionBatches)
+
+        if (numPartitions == 0) {

This method is well-commented. Can you add another comment noting that we should end the stream when the partitions are empty?

Also, I would do:

val partitions = 0 until numPartitions
sparkSession.sparkContext.runJob(
  arrowBatchRdd,
  (ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
  partitions,
  handlePartitionBatches)

if (partitions.isEmpty) {
  // Currently the result handler is not called when the given partitions are empty.
  // Therefore, we should end the stream here.
  doAfterLastPartition()
}
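For context, a small standalone illustration (not from this PR) of the behavior the comment above relies on: when runJob is given an empty partition list, no tasks are launched and the result handler is never invoked, so the stream has to be ended separately. Assumes an active SparkSession named spark.

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD

// Sketch only: the handler below is never called because the partition list is empty.
val rdd: RDD[Int] = spark.sparkContext.parallelize(Seq(1, 2, 3), numSlices = 2)
spark.sparkContext.runJob(
  rdd,
  (ctx: TaskContext, it: Iterator[Int]) => it.toArray,
  Seq.empty[Int],                                   // no partitions to run
  (index: Int, result: Array[Int]) => println(s"handler called for partition $index"))
// Nothing is printed: zero partitions means zero result-handler calls.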

+          doAfterLastPartition()
+        }
       }
     }
   }
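Aside (not part of the diff): the reordering scheme described in the comment inside doAfterLastPartition can be sketched with plain collections. Batches are streamed to Python in arrival order, and the indices written afterwards tell the reader where each logical (partition index, batch index) batch sits in the stream. The arrival order below is hypothetical.

import scala.collection.mutable.ArrayBuffer

// Sketch only: batchOrder records (partition index, batch index in that partition)
// in the order the batches were transferred.
val batchOrder = ArrayBuffer((2, 0), (0, 0), (0, 1), (1, 0))

// zipWithIndex attaches each batch's stream position; sorting by the (partition, batch)
// key and keeping those positions yields, for the i-th logical batch, its position in
// the transferred stream.
val overallOrder = batchOrder.zipWithIndex.sortBy(_._1).map(_._2)
// overallOrder == ArrayBuffer(1, 2, 3, 0): e.g. the first batch transferred, (2, 0),
// is logically last.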