
Commit 0011cb2

Add opaque_reservation utility (#20885)
Using the rapidsmpf memory reservation system for opaque calls to `do_evaluate`.

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Tom Augspurger (https://github.com/TomAugspurger)

URL: #20885
1 parent 1dc82a9 commit 0011cb2

8 files changed

Lines changed: 279 additions & 123 deletions
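
The change applies one pattern throughout: before an opaque call such as `do_evaluate` runs, an estimated number of bytes is reserved from the rapidsmpf buffer resource (spilling other buffers if needed to make room), and the reservation is released once the result has been handed off. The real `opaque_reservation` helper is added in `cudf_polars/experimental/rapidsmpf/utils.py` (one of the changed files not shown in this excerpt) and takes the streaming context plus a byte estimate. The sketch below is illustrative only: it shows the reserve/evaluate/release shape of the pattern with a toy stand-in for the buffer resource, not the actual rapidsmpf API.

# Illustrative sketch only -- a toy stand-in, not the real rapidsmpf API or the
# actual opaque_reservation added in utils.py.
from collections.abc import Iterator
from contextlib import contextmanager


class ToyBufferResource:
    """Hypothetical stand-in that only tracks how many bytes are reserved."""

    def __init__(self) -> None:
        self.reserved_bytes = 0

    def reserve(self, nbytes: int) -> int:
        # The real buffer resource would spill other buffers here to make room.
        self.reserved_bytes += nbytes
        return nbytes

    def release(self, nbytes: int) -> None:
        self.reserved_bytes -= nbytes


@contextmanager
def opaque_reservation(br: ToyBufferResource, estimated_bytes: int) -> Iterator[None]:
    """Reserve headroom for an opaque call, then release it afterwards."""
    nbytes = br.reserve(estimated_bytes)
    try:
        yield
    finally:
        br.release(nbytes)


# Usage mirrors the pattern added in io.py and join.py: reserve an estimate of
# the chunk size, run the opaque evaluation, then let the reservation go.
br = ToyBufferResource()
with opaque_reservation(br, estimated_bytes=256 * 1024**2):
    pass  # e.g. df = scan.do_evaluate(...) would run here
assert br.reserved_bytes == 0

In the diffs below, the estimate is the executor's `target_partition_size` for the IO nodes and the measured device size of the inputs (`large_chunk_size + small_size`) for the broadcast join.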


python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/allgather.py

Lines changed: 3 additions & 5 deletions
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """AllGather logic for the RapidsMPF streaming runtime."""
 
@@ -59,6 +59,7 @@ def insert(self, sequence_number: int, chunk: TableChunk) -> None:
                 self.context.br(),
             ),
         )
+        del chunk
 
     def insert_finished(self) -> None:
         """Insert finished into the AllGatherManager."""
@@ -81,12 +82,9 @@ async def extract_concatenated(
         -------
         The concatenated AllGather result.
         """
-        partition_chunks = await self.allgather.extract_all(
-            self.context, ordered=ordered
-        )
         return await asyncio.to_thread(
             unpack_and_concat,
-            partitions=partition_chunks,
+            partitions=await self.allgather.extract_all(self.context, ordered=ordered),
             stream=stream,
             br=self.context.br(),
         )

python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/shuffle.py

Lines changed: 18 additions & 15 deletions
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """Shuffle logic for the RapidsMPF streaming runtime."""
 
@@ -176,13 +176,13 @@ async def shuffle_node(
 
         # Process input chunks
         while (msg := await ch_in.data.recv(context)) is not None:
-            # Extract TableChunk from message
-            chunk = TableChunk.from_message(msg).make_available_and_spill(
-                context.br(), allow_overbooking=True
+            # Extract TableChunk from message and insert into shuffler
+            shuffle.insert_chunk(
+                TableChunk.from_message(msg).make_available_and_spill(
+                    context.br(), allow_overbooking=True
+                )
             )
-
-            # Get the table view and insert into shuffler
-            shuffle.insert_chunk(chunk)
+            del msg
 
         # Insert finished
         await shuffle.insert_finished()
@@ -195,16 +195,19 @@ async def shuffle_node(
             num_partitions,
             context.comm().nranks,
         ):
-            # Create a new TableChunk with the result
-            output_chunk = TableChunk.from_pylibcudf_table(
-                table=await shuffle.extract_chunk(partition_id, stream),
-                stream=stream,
-                exclusive_view=True,
+            # Extract and send the output chunk
+            await ch_out.data.send(
+                context,
+                Message(
+                    partition_id,
+                    TableChunk.from_pylibcudf_table(
+                        table=await shuffle.extract_chunk(partition_id, stream),
+                        stream=stream,
+                        exclusive_view=True,
+                    ),
+                ),
            )
 
-            # Send the output chunk
-            await ch_out.data.send(context, Message(partition_id, output_chunk))
-
         await ch_out.data.drain(context)
 
 
python/cudf_polars/cudf_polars/experimental/rapidsmpf/io.py

Lines changed: 41 additions & 19 deletions
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """IO logic for the RapidsMPF streaming runtime."""
 
@@ -40,6 +40,7 @@
 from cudf_polars.experimental.rapidsmpf.utils import (
     ChannelManager,
     Metadata,
+    opaque_reservation,
 )
 
 if TYPE_CHECKING:
@@ -107,7 +108,7 @@ async def drain(self) -> None:
 
         # Forward any remaining buffered messages
         for seq in sorted(buffer.keys()):
-            await self.ch_out.send(self.context, buffer[seq])
+            await self.ch_out.send(self.context, buffer.pop(seq))
 
         await self.ch_out.drain(self.context)
 
@@ -142,6 +143,7 @@ async def dataframescan_node(
     *,
     num_producers: int,
     rows_per_partition: int,
+    estimated_chunk_bytes: int,
 ) -> None:
     """
     DataFrameScan node for rapidsmpf.
@@ -160,6 +162,9 @@ async def dataframescan_node(
         The number of producers to use for the DataFrameScan node.
     rows_per_partition
         The number of rows per partition.
+    estimated_chunk_bytes
+        Estimated size of each chunk in bytes. Used for memory reservation
+        with block spilling to avoid thrashing.
     """
     async with shutdown_on_error(context, ch_out.metadata, ch_out.data):
         # Find local partition count.
@@ -206,6 +211,7 @@ async def dataframescan_node(
                     seq_num,
                     ch_out.data,
                     ir_context,
+                    estimated_chunk_bytes,
                 )
             await ch_out.data.drain(context)
             return
@@ -230,6 +236,7 @@ async def _producer(producer_id: int, ch_out: Channel) -> None:
                     task_idx,
                     ch_out,
                     ir_context,
+                    estimated_chunk_bytes,
                 )
             await ch_out.drain(context)
 
@@ -250,6 +257,8 @@ def _(
     )
     rows_per_partition = config_options.executor.max_rows_per_partition
     num_producers = rec.state["max_io_threads"]
+    # Use target_partition_size as the estimated chunk size
+    estimated_chunk_bytes = config_options.executor.target_partition_size
 
     context = rec.state["context"]
     ir_context = rec.state["ir_context"]
@@ -263,6 +272,7 @@ def _(
                 channels[ir].reserve_input_slot(),
                 num_producers=num_producers,
                 rows_per_partition=rows_per_partition,
+                estimated_chunk_bytes=estimated_chunk_bytes,
             )
         ]
     }
@@ -307,6 +317,7 @@ async def read_chunk(
     seq_num: int,
     ch_out: Channel[TableChunk],
     ir_context: IRExecutionContext,
+    estimated_chunk_bytes: int,
 ) -> None:
     """
     Read a chunk from disk and send it to the output channel.
@@ -323,24 +334,27 @@ async def read_chunk(
         The output channel.
     ir_context
         The execution context for the IR node.
+    estimated_chunk_bytes
+        Estimated size of the chunk in bytes. Used for memory reservation
+        with block spilling to avoid thrashing.
     """
-    # Evaluate and send the Scan-node result
-    df = await asyncio.to_thread(
-        scan.do_evaluate,
-        *scan._non_child_args,
-        context=ir_context,
-    )
-    await ch_out.send(
-        context,
-        Message(
-            seq_num,
-            TableChunk.from_pylibcudf_table(
-                df.table,
-                df.stream,
-                exclusive_view=True,
+    with opaque_reservation(context, estimated_chunk_bytes):
+        df = await asyncio.to_thread(
+            scan.do_evaluate,
+            *scan._non_child_args,
+            context=ir_context,
+        )
+        await ch_out.send(
+            context,
+            Message(
+                seq_num,
+                TableChunk.from_pylibcudf_table(
+                    df.table,
+                    df.stream,
+                    exclusive_view=True,
+                ),
            ),
-        ),
-    )
+        )
 
 
 @define_py_node()
@@ -353,6 +367,7 @@ async def scan_node(
     num_producers: int,
     plan: IOPartitionPlan,
     parquet_options: ParquetOptions,
+    estimated_chunk_bytes: int,
 ) -> None:
     """
     Scan node for rapidsmpf.
@@ -373,6 +388,9 @@ async def scan_node(
         The partitioning plan.
     parquet_options
         The Parquet options.
+    estimated_chunk_bytes
+        Estimated size of each chunk in bytes. Used for memory reservation
+        with block spilling to avoid thrashing.
     """
     async with shutdown_on_error(context, ch_out.metadata, ch_out.data):
         # Build a list of local Scan operations
@@ -460,6 +478,7 @@ async def scan_node(
                     seq_num,
                     ch_out.data,
                     ir_context,
+                    estimated_chunk_bytes,
                 )
             await ch_out.data.drain(context)
             return
@@ -484,6 +503,7 @@ async def _producer(producer_id: int, ch_out: Channel) -> None:
                     task_idx,
                     ch_out,
                     ir_context,
+                    estimated_chunk_bytes,
                 )
             await ch_out.drain(context)
 
@@ -607,7 +627,8 @@ def _(
     ir: Scan, rec: SubNetGenerator
 ) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]:
     config_options = rec.state["config_options"]
-    assert config_options.executor.name == "streaming", (
+    executor = rec.state["config_options"].executor
+    assert executor.name == "streaming", (
         "'in-memory' executor not supported in 'generate_ir_sub_network'"
    )
     parquet_options = config_options.parquet_options
@@ -669,6 +690,7 @@ def _(
                 num_producers=num_producers,
                 plan=plan,
                 parquet_options=parquet_options,
+                estimated_chunk_bytes=executor.target_partition_size,
             )
         ]
     return nodes, channels

python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py

Lines changed: 24 additions & 19 deletions
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 """Join logic for the RapidsMPF streaming runtime."""
 
@@ -27,6 +27,7 @@
     Metadata,
     chunk_to_frame,
     empty_table_chunk,
+    opaque_reservation,
     process_children,
 )
 from cudf_polars.experimental.utils import _concat
@@ -128,6 +129,7 @@ async def broadcast_join_node(
                     context.br(), allow_overbooking=True
                 )
             )
+            del msg
            small_size += small_chunks[-1].data_alloc_size(MemoryType.DEVICE)
 
        # Allgather is a collective - all ranks must participate even with no local data
@@ -193,6 +195,7 @@ async def broadcast_join_node(
                context.br(), allow_overbooking=True
            )
            seq_num = msg.sequence_number
+            del msg
 
            large_df = DataFrame.from_table(
                large_chunk.table_view(),
@@ -207,10 +210,11 @@ async def broadcast_join_node(
                empty_small_chunk = empty_table_chunk(small_child, context, stream)
                small_dfs = [chunk_to_frame(empty_small_chunk, small_child)]
 
-            # Perform the join
-            df = _concat(
-                *[
-                    (
+            large_chunk_size = large_chunk.data_alloc_size(MemoryType.DEVICE)
+            input_bytes = large_chunk_size + small_size
+            with opaque_reservation(context, input_bytes):
+                df = _concat(
+                    *[
                        await asyncio.to_thread(
                            ir.do_evaluate,
                            *ir._non_child_args,
@@ -221,23 +225,24 @@ async def broadcast_join_node(
                            ),
                            context=ir_context,
                        )
-                    )
-                    for small_df in small_dfs
-                ],
-                context=ir_context,
-            )
+                        for small_df in small_dfs
+                    ],
+                    context=ir_context,
+                )
 
-            # Send output chunk
-            await ch_out.data.send(
-                context,
-                Message(
-                    seq_num,
-                    TableChunk.from_pylibcudf_table(
-                        df.table, df.stream, exclusive_view=True
+                # Send output chunk
+                await ch_out.data.send(
+                    context,
+                    Message(
+                        seq_num,
+                        TableChunk.from_pylibcudf_table(
+                            df.table, df.stream, exclusive_view=True
+                        ),
                    ),
-                ),
-            )
+                )
+                del df, large_df, large_chunk
 
+        del small_dfs, small_chunks
        await ch_out.data.drain(context)
 
 