Skip to content

Commit 2a965dc

Browse files
committed
feat: use spawned tasks to reduce call stack depth and avoid busy waiting
1 parent 1daa5ed commit 2a965dc

File tree

5 files changed

+68
-25
lines changed

5 files changed

+68
-25
lines changed

datafusion/common/src/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,12 @@ impl From<GenericError> for DataFusionError {
350350
}
351351
}
352352

353+
/// Converts a `JoinError` into a `DataFusionError` by wrapping it, unchanged,
/// in the `DataFusionError::ExecutionJoin` variant. Having this conversion
/// lets a `JoinError` be propagated with `?` from code whose error type is
/// `DataFusionError`.
impl From<JoinError> for DataFusionError {
354+
fn from(e: JoinError) -> Self {
355+
DataFusionError::ExecutionJoin(e)
356+
}
357+
}
358+
353359
impl Display for DataFusionError {
354360
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
355361
let error_prefix = self.error_prefix();

datafusion/physical-plan/src/joins/cross_join.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ use datafusion_execution::TaskContext;
4747
use datafusion_physical_expr::equivalence::join_equivalence_properties;
4848

4949
use async_trait::async_trait;
50+
use datafusion_common_runtime::SpawnedTask;
5051
use futures::{ready, Stream, StreamExt, TryStreamExt};
5152

5253
/// Data of the left side that is buffered into memory
@@ -303,12 +304,13 @@ impl ExecutionPlan for CrossJoinExec {
303304

304305
let left_fut = self.left_fut.try_once(|| {
305306
let left_stream = self.left.execute(0, context)?;
306-
307-
Ok(load_left_input(
308-
left_stream,
309-
join_metrics.clone(),
310-
reservation,
311-
))
307+
let task = load_left_input(left_stream, join_metrics.clone(), reservation);
308+
Ok(async move {
309+
// Spawn a task the first time the stream is polled for the build phase.
310+
// This ensures the consumer of the join does not poll unnecessarily
311+
// while the build is ongoing
312+
SpawnedTask::spawn(task).await?
313+
})
312314
})?;
313315

314316
if enforce_batch_size_in_joins {

datafusion/physical-plan/src/joins/hash_join.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ use datafusion_physical_expr::PhysicalExprRef;
8282
use datafusion_physical_expr_common::datum::compare_op_for_nested;
8383

8484
use ahash::RandomState;
85+
use datafusion_common_runtime::SpawnedTask;
8586
use datafusion_physical_expr_common::physical_expr::fmt_sql;
86-
use futures::{ready, Stream, StreamExt, TryStreamExt};
87+
use futures::{ready, FutureExt, Stream, StreamExt, TryStreamExt};
8788
use parking_lot::Mutex;
8889

8990
/// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions.
@@ -810,15 +811,22 @@ impl ExecutionPlan for HashJoinExec {
810811
let reservation =
811812
MemoryConsumer::new("HashJoinInput").register(context.memory_pool());
812813

813-
Ok(collect_left_input(
814+
let task = collect_left_input(
814815
self.random_state.clone(),
815816
left_stream,
816817
on_left.clone(),
817818
join_metrics.clone(),
818819
reservation,
819820
need_produce_result_in_final(self.join_type),
820821
self.right().output_partitioning().partition_count(),
821-
))
822+
);
823+
824+
Ok(async move {
825+
// Spawn a task the first time the stream is polled for the build phase.
826+
// This ensures the consumer of the join does not poll unnecessarily
827+
// while the build is ongoing
828+
SpawnedTask::spawn(task).await?
829+
})
822830
})?,
823831
PartitionMode::Partitioned => {
824832
let left_stream = self.left.execute(partition, Arc::clone(&context))?;
@@ -827,15 +835,22 @@ impl ExecutionPlan for HashJoinExec {
827835
MemoryConsumer::new(format!("HashJoinInput[{partition}]"))
828836
.register(context.memory_pool());
829837

830-
OnceFut::new(collect_left_input(
838+
let task = collect_left_input(
831839
self.random_state.clone(),
832840
left_stream,
833841
on_left.clone(),
834842
join_metrics.clone(),
835843
reservation,
836844
need_produce_result_in_final(self.join_type),
837845
1,
838-
))
846+
);
847+
848+
OnceFut::new(async move {
849+
// Spawn a task the first time the stream is polled for the build phase.
850+
// This ensures the consumer of the join does not poll unnecessarily
851+
// while the build is ongoing
852+
SpawnedTask::spawn(task).map(|r| r?).await
853+
})
839854
}
840855
PartitionMode::Auto => {
841856
return plan_err!(

datafusion/physical-plan/src/joins/nested_loop_join.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ use datafusion_physical_expr::equivalence::{
6161
join_equivalence_properties, ProjectionMapping,
6262
};
6363

64+
use datafusion_common_runtime::SpawnedTask;
6465
use futures::{ready, Stream, StreamExt, TryStreamExt};
6566
use parking_lot::Mutex;
6667

@@ -499,13 +500,19 @@ impl ExecutionPlan for NestedLoopJoinExec {
499500
let inner_table = self.inner_table.try_once(|| {
500501
let stream = self.left.execute(0, Arc::clone(&context))?;
501502

502-
Ok(collect_left_input(
503+
let task = collect_left_input(
503504
stream,
504505
join_metrics.clone(),
505506
load_reservation,
506507
need_produce_result_in_final(self.join_type),
507508
self.right().output_partitioning().partition_count(),
508-
))
509+
);
510+
Ok(async move {
511+
// Spawn a task the first time the stream is polled for the build phase.
512+
// This ensures the consumer of the join does not poll unnecessarily
513+
// while the build is ongoing
514+
SpawnedTask::spawn(task).await?
515+
})
509516
})?;
510517

511518
let batch_size = context.session_config().batch_size();

datafusion/physical-plan/src/sorts/sort.rs

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ use datafusion_execution::runtime_env::RuntimeEnv;
5454
use datafusion_execution::TaskContext;
5555
use datafusion_physical_expr::LexOrdering;
5656

57+
use datafusion_common_runtime::SpawnedTask;
5758
use futures::{StreamExt, TryStreamExt};
5859
use log::{debug, trace};
5960

@@ -1126,14 +1127,20 @@ impl ExecutionPlan for SortExec {
11261127
Ok(Box::pin(RecordBatchStreamAdapter::new(
11271128
self.schema(),
11281129
futures::stream::once(async move {
1129-
while let Some(batch) = input.next().await {
1130-
let batch = batch?;
1131-
topk.insert_batch(batch)?;
1132-
if topk.finished {
1133-
break;
1130+
// Spawn a task the first time the stream is polled for the sort phase.
1131+
// This ensures the consumer of the sort does not poll unnecessarily
1132+
// while the sort is ongoing
1133+
SpawnedTask::spawn(async move {
1134+
while let Some(batch) = input.next().await {
1135+
let batch = batch?;
1136+
topk.insert_batch(batch)?;
1137+
if topk.finished {
1138+
break;
1139+
}
11341140
}
1135-
}
1136-
topk.emit()
1141+
topk.emit()
1142+
})
1143+
.await?
11371144
})
11381145
.try_flatten(),
11391146
)))
@@ -1152,11 +1159,17 @@ impl ExecutionPlan for SortExec {
11521159
Ok(Box::pin(RecordBatchStreamAdapter::new(
11531160
self.schema(),
11541161
futures::stream::once(async move {
1155-
while let Some(batch) = input.next().await {
1156-
let batch = batch?;
1157-
sorter.insert_batch(batch).await?;
1158-
}
1159-
sorter.sort().await
1162+
// Spawn a task the first time the stream is polled for the sort phase.
1163+
// This ensures the consumer of the sort does not poll unnecessarily
1164+
// while the sort is ongoing
1165+
SpawnedTask::spawn(async move {
1166+
while let Some(batch) = input.next().await {
1167+
let batch = batch?;
1168+
sorter.insert_batch(batch).await?;
1169+
}
1170+
sorter.sort().await
1171+
})
1172+
.await?
11601173
})
11611174
.try_flatten(),
11621175
)))

0 commit comments

Comments
 (0)