Skip to content

Commit 8dc5387

Browse files
authored
Merge pull request #35 from 36000/minmaxlen
[ENH] add min/max len
2 parents f7bf5f7 + 682ed06 commit 8dc5387

2 files changed

Lines changed: 26 additions & 5 deletions

File tree

cuslines/cuda_python/cu_propagate_seeds.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525

2626

2727
class SeedBatchPropagator:
28-
def __init__(self, gpu_tracker):
28+
def __init__(self, gpu_tracker, minlen=0, maxlen=np.inf):
2929
self.gpu_tracker = gpu_tracker
3030
self.ngpus = gpu_tracker.ngpus
31+
self.minlen = minlen
32+
self.maxlen = maxlen
3133

3234
self.nSlines_old = np.zeros(self.ngpus, dtype=np.int32)
3335
self.nSlines = np.zeros(self.ngpus, dtype=np.int32)
@@ -240,6 +242,8 @@ def get_buffer_size(self):
240242
for ii in range(self.ngpus):
241243
lens = self.sline_lens[ii]
242244
for jj in range(self.nSlines[ii]):
245+
if lens[jj] < self.minlen or lens[jj] > self.maxlen:
246+
continue
243247
buffer_size += lens[jj] * 3 * REAL_SIZE
244248
return math.ceil(buffer_size / MEGABYTE)
245249

@@ -252,9 +256,14 @@ def _yield_slines():
252256
for jj in range(self.nSlines[ii]):
253257
npts = this_len[jj]
254258

259+
if npts < self.minlen or npts > self.maxlen:
260+
continue
261+
255262
yield np.asarray(this_sls[jj], dtype=REAL_DTYPE)[:npts]
256263

257264
return _yield_slines()
258265

259266
def as_array_sequence(self):
260-
return ArraySequence(self.as_generator(), self.get_buffer_size())
267+
return ArraySequence(
268+
self.as_generator(),
269+
self.get_buffer_size())

cuslines/cuda_python/cu_tractography.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def __init__(
4444
sphere_edges: np.ndarray,
4545
max_angle: float = radians(60),
4646
step_size: float = 0.5,
47+
min_pts=0,
48+
max_pts=np.inf,
4749
relative_peak_thresh: float = 0.25,
4850
min_separation_angle: float = radians(45),
4951
ngpus: int = 1,
@@ -74,8 +76,14 @@ def __init__(
7476
Maximum angle (in radians) between steps
7577
default: radians(60)
7678
step_size : float, optional
77-
Step size for tracking
79+
Step size for tracking, in voxels
7880
default: 0.5
81+
min_pts : int, optional
82+
Minimum number of points in a streamline to be kept
83+
default: 0
84+
max_pts : int, optional
85+
Maximum number of points in a streamline to be kept
86+
default: np.inf
7987
relative_peak_thresh : float, optional
8088
Relative peak threshold for direction selection
8189
default: 0.25
@@ -136,7 +144,11 @@ def __init__(
136144
self.streams = []
137145
self.managed_data = []
138146

139-
self.seed_propagator = SeedBatchPropagator(gpu_tracker=self)
147+
self.seed_propagator = SeedBatchPropagator(
148+
gpu_tracker=self,
149+
minlen=min_pts,
150+
maxlen=max_pts
151+
)
140152
self._allocated = False
141153

142154
def __enter__(self):
@@ -260,7 +272,7 @@ def generate_trx(self, seeds, ref_img):
260272

261273
# Will resize by a factor of 2 if these are exceeded
262274
sl_len_guess = 100
263-
sl_per_seed_guess = 4
275+
sl_per_seed_guess = 2
264276
n_sls_guess = sl_per_seed_guess * seeds.shape[0]
265277

266278
# trx files use memory mapping

0 commit comments

Comments
 (0)