Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,4 @@ cov.syspath.txt
#include docs images
!docs/source/logo/*
!docs/source/images/*
napari_cellseg3d/dev_scripts/wandb
38 changes: 38 additions & 0 deletions docker/Dockerfile.cellseg3d
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# original file by Steffen Schneider https://github.com/stes/docker/tree/main
FROM nvidia/cuda:11.7.0-runtime-ubuntu20.04
ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update -yy
RUN apt-get install -yy --no-install-recommends \
git curl wget build-essential libhdf5-dev \
libgl1-mesa-glx libglib2.0-0 software-properties-common

ENV PYTHON_VERSION 3.8
RUN add-apt-repository ppa:deadsnakes/ppa
RUN apt-cache policy python$${PYTHON_VERSION}
RUN apt-get install -yy --no-install-recommends \
python${PYTHON_VERSION} \
python3-pip \
python${PYTHON_VERSION}-dev

RUN apt-get clean
RUN rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir torch==2.0.0 --extra-index-url https://download.pytorch.org/whl/cu117

RUN apt-get update -yy \
&& apt-get install -yy git \
&& apt-get install -yy vim

RUN git clone git+https://github.com/AdaptiveMotorControlLab/CellSeg3d@cy/jupyter-books-docs \
&& cd CellSeg3d \
&& pip3 install -e .[wandb]

# create user session
RUN useradd -ms /bin/bash cyril
USER cyril
WORKDIR /home/cellseg3d


# docker build -f Dockerfile.cellseg3d -t cyril/cellseg3d .
# docker run -it --rm --gpus device=3 --shm-size=4gb -v "$(pwd)":/workspace/cellseg3d_results --name CellSeg3D-GPU3
2 changes: 1 addition & 1 deletion napari_cellseg3d/dev_scripts/colab_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
)

self.dice_metric = DiceMetric(
include_background=True, reduction="mean", get_not_nans=False
include_background=False, reduction="mean", get_not_nans=False
)
self.normalize_function = utils.remap_image
self.start_time = time.time()
Expand Down
138 changes: 138 additions & 0 deletions napari_cellseg3d/dev_scripts/remote_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Showcases how to train a model without napari."""

from pathlib import Path

from napari_cellseg3d import config as cfg
from napari_cellseg3d.code_models.worker_training import (
SupervisedTrainingWorker,
)
from napari_cellseg3d.utils import LOGGER as logger

TRAINING_SPLIT = 0.2 # 0.4, 0.8
MODEL_NAME = "SegResNet" # "SwinUNetR"
BATCH_SIZE = 10 if MODEL_NAME == "SegResNet" else 5
# BATCH_SIZE = 1

SPLIT_FOLDER = "1_c15" # "2_c1_c4_visual" "3_c1245_visual"
RESULTS_PATH = (
Path("/data/cyril")
/ "CELLSEG_BENCHMARK/cellseg3d_train"
/ f"{MODEL_NAME}_{SPLIT_FOLDER}_{int(TRAINING_SPLIT*100)}"
)

IMAGES = (
Path("/data/cyril")
/ f"CELLSEG_BENCHMARK/TPH2_mesospim/SPLITS/{SPLIT_FOLDER}"
)
LABELS = (
Path("/data/cyril")
/ f"CELLSEG_BENCHMARK/TPH2_mesospim/SPLITS/{SPLIT_FOLDER}/labels/semantic"
)


class LogFixture:
"""Fixture for napari-less logging, replaces napari_cellseg3d.interface.Log in model_workers.

This allows to redirect the output of the workers to stdout instead of a specialized widget.
"""

def __init__(self):
"""Creates a LogFixture object."""
super(LogFixture, self).__init__()

def print_and_log(self, text, printing=None):
"""Prints and logs text."""
print(text)

def warn(self, warning):
"""Logs warning."""
logger.warning(warning)

def error(self, e):
"""Logs error."""
raise (e)


def prepare_data(images_path, labels_path):
"""Prepares data for training."""
assert images_path.exists(), f"Images path does not exist: {images_path}"
assert labels_path.exists(), f"Labels path does not exist: {labels_path}"
if not RESULTS_PATH.exists():
RESULTS_PATH.mkdir(parents=True, exist_ok=True)

images = sorted(Path.glob(images_path, "*.tif"))
labels = sorted(Path.glob(labels_path, "*.tif"))

print(f"Images paths: {images}")
print(f"Labels paths: {labels}")

logger.info("Images :\n")
for file in images:
logger.info(Path(file).name)
logger.info("*" * 10)
logger.info("Labels :\n")
for file in images:
logger.info(Path(file).name)

assert len(images) == len(
labels
), "Number of images and labels must be the same"

return [
{"image": str(image_path), "label": str(label_path)}
for image_path, label_path in zip(images, labels)
]


def remote_training():
"""Function to train a model without napari."""
# print(f"Results path: {RESULTS_PATH.resolve()}")

wandb_config = cfg.WandBConfig(
mode="online",
save_model_artifact=True,
)

deterministic_config = cfg.DeterministicConfig(
seed=34936339,
)

worker_config = cfg.SupervisedTrainingWorkerConfig(
device="cuda:0",
max_epochs=50,
learning_rate=0.001, # 1e-3
validation_interval=2,
batch_size=BATCH_SIZE, # 10 for SegResNet
deterministic_config=deterministic_config,
scheduler_factor=0.5,
scheduler_patience=10, # use default scheduler
weights_info=cfg.WeightsInfo(), # no pretrained weights
results_path_folder=str(RESULTS_PATH),
sampling=False,
do_augmentation=True,
train_data_dict=prepare_data(IMAGES, LABELS),
# supervised specific
model_info=cfg.ModelInfo(
name=MODEL_NAME,
model_input_size=(64, 64, 64),
),
loss_function="Generalized Dice",
training_percent=TRAINING_SPLIT,
)

worker = SupervisedTrainingWorker(worker_config)
worker.wandb_config = wandb_config
######### SET LOG
log = LogFixture()
worker.log_signal.connect(log.print_and_log)
worker.warn_signal.connect(log.warn)
worker.error_signal.connect(log.error)

results = []
for result in worker.train():
results.append(result)
print("Training finished")


if __name__ == "__main__":
results = remote_training()