Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions cuslines/cuda_python/cu_propagate_seeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def get_buffer_size(self):
buffer_size += lens[jj] * 3 * REAL_SIZE
return math.ceil(buffer_size / MEGABYTE)

def as_generator(self):
def as_generator(self, minlen=0, maxlen=np.inf):
def _yield_slines():
for ii in range(self.ngpus):
this_sls = self.slines[ii]
Expand All @@ -252,9 +252,14 @@ def _yield_slines():
for jj in range(self.nSlines[ii]):
npts = this_len[jj]

if npts < minlen or npts > maxlen:
continue

yield np.asarray(this_sls[jj], dtype=REAL_DTYPE)[:npts]

return _yield_slines()

def as_array_sequence(self):
return ArraySequence(self.as_generator(), self.get_buffer_size())
def as_array_sequence(self, minlen=0, maxlen=np.inf):
return ArraySequence(
self.as_generator(minlen=minlen, maxlen=maxlen),
self.get_buffer_size())
Comment thread
36000 marked this conversation as resolved.
Outdated
14 changes: 8 additions & 6 deletions cuslines/cuda_python/cu_tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
Maximum angle (in radians) between steps
default: radians(60)
step_size : float, optional
Step size for tracking
Step size for tracking, in voxels
default: 0.5
relative_peak_thresh : float, optional
Relative peak threshold for direction selection
Expand Down Expand Up @@ -235,7 +235,7 @@ def _divide_chunks(self, seeds):
nchunks = (seeds.shape[0] + global_chunk_sz - 1) // global_chunk_sz
return global_chunk_sz, nchunks

def generate_sft(self, seeds, ref_img):
def generate_sft(self, seeds, ref_img, minlen=0, maxlen=np.inf):
global_chunk_sz, nchunks = self._divide_chunks(seeds)
buffer_size = 0
Comment thread
36000 marked this conversation as resolved.
Outdated
generators = []
Expand All @@ -246,7 +246,8 @@ def generate_sft(self, seeds, ref_img):
seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz]
)
buffer_size += self.seed_propagator.get_buffer_size()
generators.append(self.seed_propagator.as_generator())
generators.append(self.seed_propagator.as_generator(
minlen=minlen, maxlen=maxlen))
pbar.update(
seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz].shape[0]
)
Expand All @@ -255,12 +256,12 @@ def generate_sft(self, seeds, ref_img):
)
Comment thread
36000 marked this conversation as resolved.
return StatefulTractogram(array_sequence, ref_img, Space.VOX)

def generate_trx(self, seeds, ref_img):
def generate_trx(self, seeds, ref_img, minlen=0, maxlen=np.inf):
global_chunk_sz, nchunks = self._divide_chunks(seeds)

# Will resize by a factor of 2 if these are exceeded
sl_len_guess = 100
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change from sl_per_seed_guess = 4 to sl_per_seed_guess = 2 appears unrelated to adding min/max length filtering. This change affects memory allocation for TRX file generation and could impact performance if the actual number of streamlines per seed exceeds this guess (causing more frequent resizes). If this change is intentional and related to the filtering reducing expected streamlines per seed, it should be documented in the PR description or as a comment in the code.

Suggested change
sl_len_guess = 100
sl_len_guess = 100
# Heuristic: initial guess of how many streamlines we get per seed.
# With the current min/max length filtering in the GPU propagator we
# typically obtain fewer valid streamlines per seed than before, so
# we lower the guess from 4 to 2 to avoid over-allocating memory.
# If the filtering or seeding strategy changes (e.g. more accepted
# streamlines per seed), this value should be re-evaluated to balance
# memory usage against the cost of TRX internal resizes.

Copilot uses AI. Check for mistakes.
sl_per_seed_guess = 4
sl_per_seed_guess = 2
n_sls_guess = sl_per_seed_guess * seeds.shape[0]

# trx files use memory mapping
Expand All @@ -284,7 +285,8 @@ def generate_trx(self, seeds, ref_img):
seeds[idx * global_chunk_sz : (idx + 1) * global_chunk_sz]
)
tractogram = Tractogram(
self.seed_propagator.as_array_sequence(),
self.seed_propagator.as_array_sequence(
minlen=minlen, maxlen=maxlen),
affine_to_rasmm=ref_img.affine,
)
tractogram.to_world()
Expand Down
Loading