Skip to content

Commit ee95cae

Browse files
authored
feat: allow to spawn/spawn_blocking on a provided runtime in RecordBatchReceiverStreamBuilder (#17239)
1 parent be3842b commit ee95cae

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

datafusion/physical-plan/src/stream.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ use futures::stream::BoxStream;
3838
use futures::{Future, Stream, StreamExt};
3939
use log::debug;
4040
use pin_project_lite::pin_project;
41+
use tokio::runtime::Handle;
4142
use tokio::sync::mpsc::{Receiver, Sender};
4243

4344
/// Creates a stream from a collection of producing tasks, routing panics to the stream.
@@ -84,6 +85,15 @@ impl<O: Send + 'static> ReceiverStreamBuilder<O> {
8485
self.join_set.spawn(task);
8586
}
8687

88+
/// Same as [`Self::spawn`] but it spawns the task on the provided runtime
89+
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
90+
where
91+
F: Future<Output = Result<()>>,
92+
F: Send + 'static,
93+
{
94+
self.join_set.spawn_on(task, handle);
95+
}
96+
8797
/// Spawn a blocking task that will be aborted if this builder (or the stream
8898
/// built from it) are dropped.
8999
///
@@ -97,6 +107,15 @@ impl<O: Send + 'static> ReceiverStreamBuilder<O> {
97107
self.join_set.spawn_blocking(f);
98108
}
99109

110+
/// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime
111+
pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
112+
where
113+
F: FnOnce() -> Result<()>,
114+
F: Send + 'static,
115+
{
116+
self.join_set.spawn_blocking_on(f, handle);
117+
}
118+
100119
/// Create a stream of all data written to `tx`
101120
pub fn build(self) -> BoxStream<'static, Result<O>> {
102121
let Self {
@@ -248,6 +267,15 @@ impl RecordBatchReceiverStreamBuilder {
248267
self.inner.spawn(task)
249268
}
250269

270+
/// Same as [`Self::spawn`] but it spawns the task on the provided runtime.
271+
pub fn spawn_on<F>(&mut self, task: F, handle: &Handle)
272+
where
273+
F: Future<Output = Result<()>>,
274+
F: Send + 'static,
275+
{
276+
self.inner.spawn_on(task, handle)
277+
}
278+
251279
/// Spawn a blocking task tied to the builder and stream.
252280
///
253281
/// # Drop / Cancel Behavior
@@ -275,6 +303,15 @@ impl RecordBatchReceiverStreamBuilder {
275303
self.inner.spawn_blocking(f)
276304
}
277305

306+
/// Same as [`Self::spawn_blocking`] but it spawns the blocking task on the provided runtime.
307+
pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle)
308+
where
309+
F: FnOnce() -> Result<()>,
310+
F: Send + 'static,
311+
{
312+
self.inner.spawn_blocking_on(f, handle)
313+
}
314+
278315
/// Runs the `partition` of the `input` ExecutionPlan on the
279316
/// tokio thread pool and writes its outputs to this stream
280317
///
@@ -822,4 +859,67 @@ mod test {
822859
);
823860
}
824861
}
862+
863+
/// Checks that tasks registered through `spawn_on` and `spawn_blocking_on`
/// run on an explicitly supplied runtime and that their output reaches the
/// built stream.
///
/// This is a plain `#[test]`: the test body never enters a Tokio runtime
/// itself, it only hands out the handle of the runtime it builds.
#[test]
fn record_batch_receiver_stream_builder_spawn_on_runtime() {
    let tokio_runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .unwrap();

    let mut builder =
        RecordBatchReceiverStreamBuilder::new(Arc::new(Schema::empty()), 10);

    // Async producer: sends one empty batch from a task on the runtime.
    let tx1 = builder.tx();
    builder.spawn_on(
        async move {
            tx1.send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                .await
                .unwrap();

            Ok(())
        },
        tokio_runtime.handle(),
    );

    // Blocking producer: sends one empty batch from a blocking task.
    let tx2 = builder.tx();
    builder.spawn_blocking_on(
        move || {
            tx2.blocking_send(Ok(RecordBatch::new_empty(Arc::new(Schema::empty()))))
                .unwrap();

            Ok(())
        },
        tokio_runtime.handle(),
    );

    let mut stream = builder.build();

    // The waker is a no-op, so nothing will ever notify this thread; the
    // stream has to be polled repeatedly until it reports completion. The
    // context does not change between polls, so it is created once.
    let mut cx = Context::from_waker(futures::task::noop_waker_ref());
    let mut number_of_batches = 0;

    loop {
        match stream.poll_next_unpin(&mut cx) {
            Poll::Ready(None) => break,
            Poll::Ready(Some(Ok(batch))) => {
                number_of_batches += 1;
                assert_eq!(batch.num_rows(), 0);
            }
            Poll::Ready(Some(Err(e))) => panic!("Unexpected error: {e}"),
            // Give the producers a chance to run instead of spinning at
            // full speed on this thread.
            Poll::Pending => std::thread::yield_now(),
        }
    }

    // One batch from the async task plus one from the blocking task.
    assert_eq!(
        number_of_batches, 2,
        "Should have received exactly two empty batches"
    );
}
825925
}

0 commit comments

Comments
 (0)