Skip to content
5 changes: 3 additions & 2 deletions datafusion/physical-plan/src/sorts/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use std::cmp::Ordering;
use std::sync::Arc;

use arrow::array::{
types::ByteArrayType, Array, ArrowPrimitiveType, GenericByteArray,
Expand Down Expand Up @@ -151,7 +152,7 @@ impl<T: CursorValues> Ord for Cursor<T> {
/// Used for sorting when there are multiple columns in the sort key
#[derive(Debug)]
pub struct RowValues {
rows: Rows,
rows: Arc<Rows>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense, thank you @Dandandan !


/// Tracks the memory used by the `Rows` of this
/// cursor. Freed on drop
Expand All @@ -164,7 +165,7 @@ impl RowValues {
///
/// Panics if the reservation is not for exactly `rows.size()`
/// bytes or if `rows` is empty.
pub fn new(rows: Rows, reservation: MemoryReservation) -> Self {
pub fn new(rows: Arc<Rows>, reservation: MemoryReservation) -> Self {
assert_eq!(
rows.size(),
reservation.size(),
Expand Down
48 changes: 42 additions & 6 deletions datafusion/physical-plan/src/sorts/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use crate::{PhysicalExpr, PhysicalSortExpr};
use arrow::array::Array;
use arrow::datatypes::Schema;
use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, SortField};
use datafusion_common::Result;
use arrow::row::{RowConverter, Rows, SortField};
use datafusion_common::{DataFusionError, Result};
use datafusion_execution::memory_pool::MemoryReservation;
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use futures::stream::{Fuse, StreamExt};
Expand Down Expand Up @@ -78,6 +78,8 @@ impl FusedStreams {

/// A [`PartitionedStream`] that wraps a set of [`SendableRecordBatchStream`]
/// and computes [`RowValues`] based on the provided [`PhysicalSortExpr`]
/// Note: the stream returns an error if the consumer buffers more than one RowValues (i.e. holds on to two RowValues
/// from the same partition at the same time).
#[derive(Debug)]
pub struct RowCursorStream {
/// Converter to convert output of physical expressions
Expand All @@ -88,6 +90,9 @@ pub struct RowCursorStream {
streams: FusedStreams,
/// Tracks the memory used by `converter`
reservation: MemoryReservation,
/// Allocated rows for each partition, we keep two to allow for buffering one
/// in the consumer of the stream
rows: Vec<[Option<Arc<Rows>>; 2]>,
}

impl RowCursorStream {
Expand All @@ -105,26 +110,57 @@ impl RowCursorStream {
})
.collect::<Result<Vec<_>>>()?;

let streams = streams.into_iter().map(|s| s.fuse()).collect();
let streams: Vec<_> = streams.into_iter().map(|s| s.fuse()).collect();
let converter = RowConverter::new(sort_fields)?;
let mut rows = Vec::with_capacity(streams.len());
for _ in &streams {
// Initialize each stream with an empty Rows
rows.push([
Some(Arc::new(converter.empty_rows(0, 0))),
Some(Arc::new(converter.empty_rows(0, 0))),
]);
}
Ok(Self {
converter,
reservation,
column_expressions: expressions.iter().map(|x| Arc::clone(&x.expr)).collect(),
streams: FusedStreams(streams),
rows,
})
}

fn convert_batch(&mut self, batch: &RecordBatch) -> Result<RowValues> {
fn convert_batch(
&mut self,
batch: &RecordBatch,
stream_idx: usize,
) -> Result<RowValues> {
let cols = self
.column_expressions
.iter()
.map(|expr| expr.evaluate(batch)?.into_array(batch.num_rows()))
.collect::<Result<Vec<_>>>()?;

let rows = self.converter.convert_columns(&cols)?;
// At this point, ownership of this Rows should be unique
let mut rows = Arc::try_unwrap(self.rows[stream_idx][1].take().unwrap())
.map_err(|_| {
DataFusionError::Internal(
"Rows from RowCursorStream is still in use by consumer".to_string(),
)
})?;

rows.clear();

self.converter.append(&mut rows, &cols)?;
self.reservation.try_resize(self.converter.size())?;

let rows = Arc::new(rows);

self.rows[stream_idx][1] = Some(Arc::clone(&rows));

// swap the current with the previous one, so that the next poll can reuse the Rows from the previous poll
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @Dandandan , this implementation is really clever:

Double‐buffer swap for smooth handoff
After appending the new data into the “cur” slot, swapping the two slots with std::mem::swap transparently rotates which buffer will be reused next. This means you always have one slot holding the “previous” data for downstream consumers and an idle slot ready for your next try_unwrap.

let [a, b] = &mut self.rows[stream_idx];
std::mem::swap(a, b);

// track the memory in the newly created Rows.
let mut rows_reservation = self.reservation.new_empty();
rows_reservation.try_grow(rows.size())?;
Expand All @@ -146,7 +182,7 @@ impl PartitionedStream for RowCursorStream {
) -> Poll<Option<Self::Output>> {
Poll::Ready(ready!(self.streams.poll_next(cx, stream_idx)).map(|r| {
r.and_then(|batch| {
let cursor = self.convert_batch(&batch)?;
let cursor = self.convert_batch(&batch, stream_idx)?;
Ok((cursor, batch))
})
}))
Expand Down