Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(self, size: int, stream: Optional[StreamLike] = None):
self, _finalize_buffer, self._ptr, self._stream_handle
)

@property
def __cuda_array_interface__(self):
return {
"data": (self._ptr, False),
Expand Down
2 changes: 1 addition & 1 deletion python/cuda_cccl/cuda/compute/iterators/_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
DiscardIterator as _DiscardIterator,
)
from ._iterators import (
make_permutation_iterator,
make_reverse_iterator,
make_transform_iterator,
)
from ._permutation_iterator import make_permutation_iterator
from ._zip_iterator import make_zip_iterator


Expand Down
199 changes: 1 addition & 198 deletions python/cuda_cccl/cuda/compute/iterators/_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
from llvmlite import ir
from numba import cuda, types
from numba.core.extending import as_numba_type, intrinsic, overload
from numba.core.extending import intrinsic, overload
from numba.core.typing.ctypes_utils import to_ctypes
from numba.cuda.dispatcher import CUDADispatcher

Expand Down Expand Up @@ -657,200 +657,3 @@ def _get_last_element_ptr(device_array) -> int:

ptr = get_data_pointer(device_array)
return ptr + offset_to_last_element


class PermutationIteratorKind(IteratorKind):
    """Iterator kind used to tag permutation iterators.

    Adds no behavior of its own on top of ``IteratorKind``; it exists so that
    permutation iterators get a distinct kind type. ``PermutationIterator``
    instantiates it with ``(value_type, values kind, indices kind)`` and the
    combined state type.
    """


def make_permutation_iterator(values, indices):
    """
    Create a PermutationIterator that accesses values through an index mapping.

    The permutation iterator accesses elements from `values` using indices from
    `indices`, effectively computing values[indices[i]] at position i.

    Args:
        values: The values array or iterator to permute. An object exposing
            ``__cuda_array_interface__`` is wrapped with ``pointer``; anything
            else must be an ``IteratorBase``.
        indices: The indices array or iterator specifying the permutation,
            accepted in the same forms as ``values``.

    Returns:
        PermutationIterator: Iterator that yields permuted values. It can be
        used as an output iterator only when ``values`` provides a non-None
        ``output_dereference``.

    Raises:
        TypeError: If ``values`` or ``indices`` is neither a device array nor
            an iterator.
    """
    from numba.core.datamodel.registry import default_manager

    from ..struct import make_struct_type

    # Convert arrays to iterators if needed
    if hasattr(values, "__cuda_array_interface__"):
        values = pointer(values, numba.from_dtype(get_dtype(values)))
    elif not isinstance(values, IteratorBase):
        raise TypeError("values must be a device array or iterator")

    if hasattr(indices, "__cuda_array_interface__"):
        indices = pointer(indices, numba.from_dtype(get_dtype(indices)))
    elif not isinstance(indices, IteratorBase):
        raise TypeError("indices must be an iterator or device array")

    # JIT compile value advance/dereference methods
    values_state_type = values.state_type
    index_type = indices.value_type
    value_advance = cuda.jit(values.advance, device=True)
    value_input_dereference = cuda.jit(values.input_dereference, device=True)

    # Only the attribute lookup is guarded: an iterator signals "no output
    # support" by raising AttributeError from (or returning None for)
    # `output_dereference`. Keeping cuda.jit outside the try ensures an
    # unrelated AttributeError raised while building the device function is
    # not mistaken for "values is input-only".
    try:
        output_deref = values.output_dereference
    except AttributeError:
        output_deref = None
    values_is_output_iterator = output_deref is not None
    if values_is_output_iterator:
        value_output_dereference = cuda.jit(output_deref, device=True)

    # JIT compile index advance/dereference methods
    index_advance = cuda.jit(indices.advance, device=True)
    index_input_dereference = cuda.jit(indices.input_dereference, device=True)

    # The cvalue and state for PermutationIterator are
    # structs composed of the cvalues and states of the
    # value and index iterators.
    class PermutationCValueStruct(ctypes.Structure):
        _fields_ = [
            ("value_state", values.cvalue.__class__),
            ("index_state", indices.cvalue.__class__),
        ]

    PermutationState = make_struct_type(
        "PermutationState",
        field_names=("value_state", "index_state"),
        field_types=(values_state_type, indices.state_type),
    )

    cvalue = PermutationCValueStruct(values.cvalue, indices.cvalue)
    state_type = as_numba_type(PermutationState)
    value_type = values.value_type

    def make_field_ptr_intrinsic(field_index):
        """Build an intrinsic returning a pointer to struct member `field_index`."""

        @intrinsic
        def get_field_ptr(context, struct_ptr_type):
            def codegen(context, builder, sig, args):
                struct_ptr = args[0]
                # GEP [0, field_index]: step through the pointer, then select
                # the requested member of the struct.
                return builder.gep(
                    struct_ptr,
                    [
                        ir.Constant(ir.IntType(32), 0),
                        ir.Constant(ir.IntType(32), field_index),
                    ],
                )

            # The member's numba type comes from the struct's data model so
            # the returned pointer is typed as CPointer(<member type>).
            struct_model = default_manager.lookup(struct_ptr_type.dtype)
            field_type = struct_model._members[field_index]
            return types.CPointer(field_type)(struct_ptr_type), codegen

        return get_field_ptr

    def make_alloca_intrinsic(numba_type):
        """Build an intrinsic that allocas one `numba_type` and returns its pointer."""

        @intrinsic
        def alloca_temp(context):
            def codegen(context, builder, sig, args):
                return builder.alloca(context.get_value_type(numba_type))

            return types.CPointer(numba_type)(), codegen

        return alloca_temp

    # Field 0 is value_state and field 1 is index_state, matching the order of
    # `field_names` passed to make_struct_type above.
    get_value_state_field_ptr = make_field_ptr_intrinsic(0)
    get_index_state_field_ptr = make_field_ptr_intrinsic(1)

    # Temporary storage for the dereferenced index and for a scratch copy of
    # the value iterator's state.
    alloca_temp_for_index_type = make_alloca_intrinsic(index_type)
    alloca_temp_for_value_state = make_alloca_intrinsic(values_state_type)

    class PermutationIterator(IteratorBase):
        """Iterator whose element i is the value iterator advanced by indices[i]."""

        iterator_kind_type = PermutationIteratorKind

        def __init__(self, values_it, indices_it):
            self._values = values_it
            self._indices = indices_it
            super().__init__(
                cvalue=cvalue,
                state_type=state_type,
                value_type=value_type,
            )
            self._kind = self.__class__.iterator_kind_type(
                (value_type, values_it.kind, indices_it.kind), state_type
            )

        @property
        def advance(self):
            return PermutationIterator._advance

        @property
        def input_dereference(self):
            return PermutationIterator._input_dereference

        @property
        def output_dereference(self):
            if not values_is_output_iterator:
                raise AttributeError(
                    "PermutationIterator cannot be used as output iterator "
                    "when values iterator does not support output"
                )
            return PermutationIterator._output_dereference

        @staticmethod
        def _advance(state, distance):
            # Only the index iterator moves; the stored value state is never
            # advanced in place.
            index_state_ptr = get_index_state_field_ptr(state)
            index_advance(index_state_ptr, distance)

        # NOTE: _input_dereference and _output_dereference intentionally repeat
        # the "load index, copy value state, advance copy" sequence inline.
        # Factoring it into a shared helper would mean returning a pointer to
        # stack (alloca) storage from that helper.

        @staticmethod
        def _input_dereference(state, result):
            # dereference index to get the index value
            index_state_ptr = get_index_state_field_ptr(state)
            temp_index = alloca_temp_for_index_type()
            index_input_dereference(index_state_ptr, temp_index)

            # copy the value state (which always points to position 0)
            # and advance it by the index value
            value_state_ptr = get_value_state_field_ptr(state)
            temp_value_state = alloca_temp_for_value_state()
            temp_value_state[0] = value_state_ptr[0]
            value_advance(temp_value_state, temp_index[0])
            value_input_dereference(temp_value_state, result)

        @staticmethod
        def _output_dereference(state, x):
            # dereference index to get the index value
            index_state_ptr = get_index_state_field_ptr(state)
            temp_index = alloca_temp_for_index_type()
            index_input_dereference(index_state_ptr, temp_index)

            # copy the value state (which always points to position 0)
            # and advance it by the index value
            value_state_ptr = get_value_state_field_ptr(state)
            temp_value_state = alloca_temp_for_value_state()
            temp_value_state[0] = value_state_ptr[0]
            value_advance(temp_value_state, temp_index[0])
            value_output_dereference(temp_value_state, x)

    return PermutationIterator(values, indices)
Loading
Loading