Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ on:
push:
branches:
- main
- cy/utils
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would you add a temporary branch here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this PR targets another branch that is not main, the "on PR" condition is not fulfilled in Actions. So I'm overriding manually, this will be removed when I start merging all the PRs (which depend on each other)

tags:
- "v*" # Push events to matching v*, i.e. v1.0, v20.15.10
pull_request:
Expand Down Expand Up @@ -51,7 +52,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install setuptools tox tox-gh-actions
# pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf
# pip install git+https://github.com/kodalli/pydensecrf.git@master#egg=pydensecrf

# this runs the platform-specific tests declared in tox.ini
- name: Test with tox
Expand Down
17 changes: 13 additions & 4 deletions napari_cellseg3d/_tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,27 @@ def test_inference_on_folder():
config.images_filepaths = [
str(Path(__file__).resolve().parent / "res/test.tif")
]
config.sliding_window_config.window_size = 64

def mock_work(x):
return x
class mock_work:
def __call__(self, x):
return x

def eval(self):
return None

worker = InferenceWorker(worker_config=config)
worker.aniso_transform = mock_work
worker.aniso_transform = mock_work()

image = torch.Tensor(rand_gen.random((1, 1, 64, 64, 64)))
res = worker.inference_on_folder(
{"image": image}, 0, model=mock_work, post_process_transforms=mock_work
{"image": image},
0,
model=mock_work(),
post_process_transforms=mock_work(),
)
assert isinstance(res, InferenceResult)
assert res.result is not None


def test_post_processing():
Expand Down
2 changes: 1 addition & 1 deletion napari_cellseg3d/_tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_pretrained_weights_compatibility():
for model_name in MODEL_LIST:
file_name = MODEL_LIST[model_name].weights_file
WeightsDownloader().download_weights(model_name, file_name)
model = MODEL_LIST[model_name](input_img_size=[128, 128, 128])
model = MODEL_LIST[model_name](input_img_size=[64, 64, 64])
try:
model.load_state_dict(
torch.load(
Expand Down
2 changes: 1 addition & 1 deletion napari_cellseg3d/code_models/crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Philipp Krähenbühl and Vladlen Koltun
NIPS 2011

Implemented using the pydense libary available at https://github.com/lucasb-eyer/pydensecrf.
Implemented using the pydense libary available at https://github.com/kodalli/pydensecrf.
"""
from warnings import warn

Expand Down
2 changes: 1 addition & 1 deletion napari_cellseg3d/code_models/models/model_SegResNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class SegResNet_(SegResNetVAE):
use_default_training = True
weights_file = "SegResNet.pth"
weights_file = "SegResNet_latest.pth"

def __init__(
self, input_img_size, out_channels=1, dropout_prob=0.3, **kwargs
Expand Down
2 changes: 1 addition & 1 deletion napari_cellseg3d/code_models/models/model_SwinUNetR.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class SwinUNETR_(SwinUNETR):
use_default_training = True
weights_file = "Swin64_best_metric.pth"
weights_file = "SwinUNetR_latest.pth"

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

class TRAILMAP_MS_(UNet3D):
use_default_training = True
weights_file = "TRAILMAP_MS_best_metric_epoch_26.pth"
weights_file = "TRAILMAP_MS_best_metric.pth"

# original model from Liqun Luo lab, transferred to pytorch and trained on mesoSPIM-acquired data (mostly TPH2 as of July 2022)

Expand Down
2 changes: 1 addition & 1 deletion napari_cellseg3d/code_models/models/model_VNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class VNet_(VNet):
use_default_training = True
weights_file = "VNet_40e.pth"
weights_file = "VNet_latest.pth"

def __init__(self, in_channels=1, out_channels=1, **kwargs):
try:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"TRAILMAP_MS": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/TRAILMAP_MS.tar.gz",
"SegResNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/SegResNet.tar.gz",
"VNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/VNet.tar.gz",
"SwinUNetR": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/Swin64.tar.gz",
"TRAILMAP_MS": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/TRAILMAP_latest.tar.gz",
"SegResNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/SegResNet_latest.tar.gz",
"VNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/VNet_latest.tar.gz",
"SwinUNetR": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/SwinUNetR_latest.tar.gz",
"WNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/wnet.tar.gz",
"WNet_ONNX": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/wnet_onnx.tar.gz",
"test": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/test.tar.gz"
Expand Down
2 changes: 2 additions & 0 deletions napari_cellseg3d/code_models/worker_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,8 @@ def save_image(
+ f"_{time}"
+ filetype
)
if not Path(self.config.results_path).exists():
Path(self.config.results_path).mkdir(parents=True, exist_ok=True)
try:
imwrite(file_path, image)
except ValueError as e:
Expand Down
6 changes: 5 additions & 1 deletion napari_cellseg3d/code_models/workers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __init__(self, file_location):
except ImportError as e:
logger.error("ONNX is not installed but ONNX model was loaded")
logger.error(e)
msg = "PLEASE INSTALL ONNX CPU OR GPU USING pip install napari-cellseg3d[onnx-cpu] OR napari-cellseg3d[onnx-gpu]"
msg = "PLEASE INSTALL ONNX CPU OR GPU USING: pip install napari-cellseg3d[onnx-cpu] OR pip install napari-cellseg3d[onnx-gpu]"
logger.error(msg)
raise ImportError(msg) from e

Expand All @@ -177,6 +177,8 @@ def to(self, device):


class QuantileNormalizationd(MapTransform):
"""MONAI-style dict transform to normalize each image in a batch individually by quantile normalization."""

def __init__(self, keys, allow_missing_keys: bool = False):
super().__init__(keys, allow_missing_keys)

Expand All @@ -199,6 +201,8 @@ def normalizer(self, image: torch.Tensor):


class QuantileNormalization(Transform):
"""MONAI-style transform to normalize each image in a batch individually by quantile normalization."""

def __call__(self, img):
return utils.quantile_normalization(img)

Expand Down
37 changes: 36 additions & 1 deletion napari_cellseg3d/code_plugins/plugin_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,10 @@ def _update_default_paths(self, path=None):
self.extract_dataset_paths(self.labels_filepaths),
self.results_path,
]
return
return utils.parse_default_path(self._default_path)
if Path(path).is_dir():
self._default_path.append(path)
return utils.parse_default_path(self._default_path)

@staticmethod
def extract_dataset_paths(paths):
Expand All @@ -458,3 +459,37 @@ def extract_dataset_paths(paths):
if paths[0] is None:
return None
return str(Path(paths[0]).parent)


class BasePluginUtils(BasePluginFolder):
"""Small subclass used to have centralized widgets layer and result path selection in utilities"""

save_path = None
utils_default_paths = [Path.home() / "cellseg3d"]

def __init__(
self,
viewer: napari.viewer.Viewer,
parent=None,
loads_images=True,
loads_labels=True,
):
super().__init__(
viewer=viewer,
loads_images=loads_images,
loads_labels=loads_labels,
parent=parent,
)
if parent is not None:
self.setParent(parent)
self.parent = parent

self.layer = None
"""Should contain the layer associated with the results of the utility widget"""

def _update_default_paths(self, path=None):
"""Override to also update utilities' pool of default paths"""
default_path = super()._update_default_paths(path)
logger.debug(f"Trying to update default with {default_path}")
if default_path is not None:
self.utils_default_paths.append(default_path)
Loading