Skip to content
136 changes: 122 additions & 14 deletions igneous/tasks/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(
cross_sectional_area_smoothing_window:int = 1,
cross_sectional_area_shape_delta:int = 150,
cross_sectional_area_repair_sec_per_label:int = 0, # default disabled
cross_sectional_area_low_memory_threshold:int = int(8e9),
dry_run:bool = False,
strip_integer_attributes:bool = True,
fix_autapses:bool = False,
Expand All @@ -98,8 +99,11 @@ def __init__(
dust_threshold, progress, parallel,
fill_missing, bool(sharded), frag_path, bool(spatial_index),
spatial_grid_shape, synapses, bool(dust_global),
bool(cross_sectional_area), int(cross_sectional_area_smoothing_window),
int(cross_sectional_area_shape_delta), int(cross_sectional_area_repair_sec_per_label),
bool(cross_sectional_area),
int(cross_sectional_area_smoothing_window),
int(cross_sectional_area_shape_delta),
int(cross_sectional_area_repair_sec_per_label),
int(cross_sectional_area_low_memory_threshold),
bool(dry_run), bool(strip_integer_attributes),
bool(fix_autapses), timestamp,
root_ids_cloudpath,
Expand All @@ -124,7 +128,7 @@ def execute(self):
lru_bytes = 0
lru_encoding = 'same'

if self.cross_sectional_area:
if self.cross_sectional_area and (self.cross_sectional_area_repair_sec_per_label > 0 or self.parallel == 1):
lru_bytes = self.bounds.size() + 2 * self.cross_sectional_area_shape_delta
lru_bytes = int(lru_bytes[0]) * int(lru_bytes[1]) * int(lru_bytes[2]) * 8 // 50
lru_encoding = 'crackle'
Expand Down Expand Up @@ -220,7 +224,13 @@ def decompress_all_labels():
skel.id = sid

if self.cross_sectional_area: # This is expensive!
skeletons = self.compute_cross_sectional_area(vol, bbox, skeletons)
if self.should_use_low_memory(bbox):
skeletons = self.compute_cross_sectional_area_low_mem(vol, bbox, skeletons)
else:
skeletons = self.compute_cross_sectional_area(vol, bbox, skeletons)

if self.cross_sectional_area_repair_sec_per_label != 0:
skeletons = self.repair_cross_sectional_area_contacts(vol, bbox, skeletons)

# voxel centered (+0.5) and uses more accurate bounding box from mip 0
corrected_offset = (bbox.minpt.astype(np.float32) - vol.meta.voxel_offset(self.mip) + 0.5) * vol.meta.resolution(self.mip)
Expand Down Expand Up @@ -255,6 +265,13 @@ def decompress_all_labels():
if self.spatial_index:
self.upload_spatial_index(vol, path, index_bbox, skeletons)

def should_use_low_memory(self, bbox:Bbox) -> bool:
bigger_bbx = bbox.clone()
bigger_bbx.grow(self.cross_sectional_area_shape_delta)
# Factor of 6 is based on observed behavior on 2025-10-06
# with delta +250
return bigger_bbx.volume() * 6 > self.cross_sectional_area_low_memory_threshold

def _compute_fill_holes(self, all_labels):
filled_labels, hole_labels_set = fastmorph.fill_holes(
all_labels,
Expand Down Expand Up @@ -392,16 +409,15 @@ def compute_cross_sectional_area(self, vol, bbox, skeletons):
big_bbox = bbox.clone()
big_bbox.grow(delta)
big_bbox = Bbox.clamp(big_bbox, vol.bounds)

big_bbox.minpt -= self.hole_filling_padding
big_bbox.maxpt += self.hole_filling_padding
big_bbox.grow(self.hole_filling_padding)

true_delta = bbox.minpt - big_bbox.minpt

# place the skeletons in exactly the same position
# in the enlarged image
for skel in skeletons.values():
skel.vertices += true_delta * vol.resolution
for label in skeletons.keys():
skeletons[label] = skeletons[label].voxel_space()
skeletons[label].vertices += true_delta

mapping = {}

Expand Down Expand Up @@ -452,14 +468,106 @@ def do_cross_section(labels):
skel.id = sid

# move the vertices back to their old smaller image location
for skel in skeletons.values():
skel.vertices -= true_delta * vol.resolution
for label in skeletons.keys():
skel = skeletons[label]
skel.vertices -= true_delta # move the vertices back to their old smaller image location
skeletons[label] = skel.physical_space()

if self.cross_sectional_area_repair_sec_per_label != 0:
return self.repair_cross_sectional_area_contacts(vol, bbox, skeletons)
else:
return skeletons

def compute_cross_sectional_area_low_mem(self, vol, bbox, skeletons):
if len(skeletons) == 0:
return skeletons

# Why redownload a bigger image? In order to avoid clipping the
# cross sectional areas on the edges.
delta = int(self.cross_sectional_area_shape_delta)

big_bbox = bbox.clone()
big_bbox.grow(delta)
big_bbox = Bbox.clamp(big_bbox, vol.bounds)
big_bbox.grow(self.hole_filling_padding)

true_delta = bbox.minpt - big_bbox.minpt

# place the skeletons in exactly the same position
# in the enlarged image
for label in skeletons.keys():
skeletons[label] = skeletons[label].voxel_space()
skeletons[label].vertices += true_delta

all_labels = vol.download(big_bbox, crackle=True)
all_labels.parallel = self.parallel

bbxes = all_labels.bounding_boxes(no_slice_conversion=True)
skel_labels = [ int(x) for x in skeletons.keys() ]

# For parallel=1, this gives better performance
# than decoding the crackle binary because
# it exploits the chunked nature of the precomputed
# representation to avoid decoding many chunks
# but for parallel > 1, there's sufficient firepower
# to go faster decoding the crackle volume natively.
class BinaryImageIterator:
def __len__(self):
return len(skel_labels)
def __iter__(self):
for label in skel_labels:
bbx = Bbox.from_list(bbxes[label])
bbx.maxpt += 1
bbx = bbx.clone()
bbx += big_bbox.minpt
yield label, vol.download(bbx, label=label)[...,0]

if self.parallel == 1:
iterator = BinaryImageIterator()
del all_labels
else:
iterator = all_labels.each(
crop=True,
labels=skel_labels,
)

with tqdm(
iterator,
disable=(not self.progress),
desc="Cross Sectional Area Analysis",
) as pbar:
for label, binimg in pbar:
pbar.set_postfix(label=str(label))

if self.fill_holes > 0:
binimg = fastmorph.fill_holes(
binimg,
remove_enclosed=True,
fix_borders=(self.fill_holes >= 2),
morphological_closing=(self.fill_holes >= 3),
progress=False,
)

if self.fill_holes >= 3:
hp = self.hole_filling_padding
binimg = np.asfortranarray(binimg[hp:-hp,hp:-hp,hp:-hp])

bbx = Bbox.from_list(bbxes[label])
bbx.maxpt += 1

skeletons[label] = kimimaro.cross_sectional_area_single(
binimg, skeletons[label],
anisotropy=vol.resolution,
smoothing_window=self.cross_sectional_area_smoothing_window,
progress=False,
in_place=True,
roi=bbx,
)

for label in skeletons.keys():
skel = skeletons[label]
skel.vertices -= true_delta # move the vertices back to their old smaller image location
skeletons[label] = skel.physical_space()

return skeletons

def repair_cross_sectional_area_contacts(self, vol, bbox, skeletons):
from dbscan import DBSCAN

Expand Down