Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/cuda_cccl/cuda/compute/_odr_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from __future__ import annotations

import enum
import itertools
import textwrap
from typing import TYPE_CHECKING

Expand All @@ -36,6 +37,10 @@
if TYPE_CHECKING:
from numba.core.typing import Signature

# Global counter to generate unique symbol names even when the same function
# is used multiple times (e.g., as both selectors in `three_way_partition`).
_wrapper_name_counter = itertools.count()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this global counter needs a lock to avoid a race condition in the free-threaded interpreter.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, although I will say that cuda.compute as a whole is probably not thread-safe today. The caching mechanisms, among other things, haven't taken thread safety into account thus far (#6422).


__all__ = [
"create_op_void_ptr_wrapper",
"create_advance_void_ptr_wrapper",
Expand Down Expand Up @@ -146,9 +151,9 @@ def _create_void_ptr_wrapper(
arg_str = ", ".join(arg_names)
void_sig = types.void(*(types.voidptr for _ in arg_specs))

# Create unique wrapper name
# Create unique wrapper name using global counter
sanitized_name = sanitize_identifier(name)
unique_suffix = hex(id(func))[2:]
unique_suffix = next(_wrapper_name_counter)
wrapper_name = f"wrapped_{sanitized_name}_{unique_suffix}"

# We need exec() here because Numba's @intrinsic decorator requires:
Expand Down
6 changes: 1 addition & 5 deletions python/cuda_cccl/cuda/compute/algorithms/_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ def __init__(
self.discard_second = DiscardIterator(d_out)
self.discard_unselected = DiscardIterator(d_out)

# Create a predicate that always returns False
def _cccl_always_false(x):
return False

# Use three_way_partition internally
self.partitioner = make_three_way_partition(
d_in,
Expand All @@ -59,7 +55,7 @@ def _cccl_always_false(x):
self.discard_unselected, # unselected_out - discarded
d_num_selected_out,
cond, # select_first_part_op - user's select condition
_cccl_always_false, # select_second_part_op - always false
lambda x: False, # select_second_part_op - always false
)

def __call__(
Expand Down
30 changes: 30 additions & 0 deletions python/cuda_cccl/tests/compute/test_three_way_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,36 @@ def greater_equal_op(x):
np.testing.assert_array_equal(got_unselected, h_in)


def test_three_way_partition_same_predicate():
    """Partition with one predicate object supplied as both selectors.

    The same ``always_true`` function is passed for the first-part and the
    second-part selector. Every item must be counted in the first partition
    and none in the second.
    """
    dtype = np.int32
    num_items = 100
    h_in = random_array(num_items, dtype, max_value=100)

    def always_true(x):
        return True

    d_in = cp.asarray(h_in)
    # Three independent output buffers shaped like the input.
    d_first, d_second, d_unselected = (cp.empty_like(d_in) for _ in range(3))
    # Two counters: items in the first part, items in the second part.
    d_num_selected = cp.empty(2, dtype=np.int64)

    # The identical function object is deliberately used for both selectors.
    selector = always_true
    cuda.compute.three_way_partition(
        d_in,
        d_first,
        d_second,
        d_unselected,
        d_num_selected,
        selector,
        selector,
        num_items,
    )

    counts = d_num_selected.get()
    first_count = int(counts[0])
    second_count = int(counts[1])
    assert first_count == num_items
    assert second_count == 0


def test_three_way_partition_all_selected_first():
dtype = np.int32
num_items = 20_000
Expand Down
Loading