Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions napari_cellseg3d/code_models/worker_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,11 @@ def log_parameters(self):
self.log("*" * 20)

def train(
self, provided_model=None, provided_optimizer=None, provided_loss=None
self,
provided_model=None,
provided_optimizer=None,
provided_loss=None,
wandb_name_override=None,
):
"""Main training function.

Expand All @@ -412,6 +416,7 @@ def train(
provided_model (WNet, optional): A model to use for training. Defaults to None.
provided_optimizer (torch.optim.Optimizer, optional): An optimizer to use for training. Defaults to None.
provided_loss (torch.nn.Module, optional): A loss function to use for training. Defaults to None.
wandb_name_override (str, optional): A name to override the wandb run name. Defaults to None.
"""
try:
if self.config is None:
Expand All @@ -431,7 +436,9 @@ def train(
wandb.init(
config=config_dict,
project="CellSeg3D - WNet",
name=f"WNet_training - {utils.get_date_time()}",
name=f"WNet_training - {utils.get_date_time()}"
if wandb_name_override is None
else wandb_name_override,
mode=self.wandb_config.mode,
tags=["WNet", "training"],
)
Expand Down Expand Up @@ -1079,6 +1086,7 @@ def train(
provided_optimizer=None,
provided_loss=None,
provided_scheduler=None,
wandb_name_override=None,
):
"""Trains the PyTorch model for the given number of epochs.

Expand Down Expand Up @@ -1142,7 +1150,9 @@ def train(
wandb.init(
config=config_dict,
project="CellSeg3D",
name=f"{model_config.name}_supervised_training - {utils.get_date_time()}",
name=f"{model_config.name}_supervised_training - {utils.get_date_time()}"
if wandb_name_override is None
else wandb_name_override,
tags=[f"{model_config.name}", "supervised"],
mode=self.wandb_config.mode,
)
Expand Down Expand Up @@ -1203,7 +1213,10 @@ def train(
epoch_loss_values = []
val_metric_values = []

if len(self.config.train_data_dict) > 1:
if (
len(self.config.train_data_dict) > 1
and self.config.eval_data_dict is None
):
self.train_files, self.val_files = (
self.config.train_data_dict[
0 : int(
Expand All @@ -1218,6 +1231,11 @@ def train(
) :
],
)
elif self.config.eval_data_dict is not None:
# train files are used as is, validation files are from eval_data_dict
# not used in the plugin yet, only for training via the API
self.train_files = self.config.train_data_dict
self.val_files = self.config.eval_data_dict
else:
self.train_files = self.val_files = self.config.train_data_dict
msg = f"Only one image file was provided : {self.config.train_data_dict[0]['image']}.\n"
Expand Down Expand Up @@ -1591,6 +1609,10 @@ def get_patch_loader_func(num_samples):
val_data["image"].to(device),
val_data["label"].to(device),
)
if self.labels_not_semantic:
val_labels = torch.where(
val_labels > 1, 1, val_labels
)

try:
with torch.no_grad():
Expand Down Expand Up @@ -1624,7 +1646,11 @@ def get_patch_loader_func(num_samples):
EnsureType(),
]
) #
post_label = EnsureType()
post_label = Compose(
[
EnsureType(),
]
)

output_raw = [
RemapTensor(new_max=1, new_min=0)(t)
Expand Down
2 changes: 2 additions & 0 deletions napari_cellseg3d/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,13 @@ class SupervisedTrainingWorkerConfig(TrainingWorkerConfig):
"""Class to record config for Trainer plugin.

Args:
eval_data_dict (dict): dict of eval data as {"image": np.array, "labels": np.array}. Optional.
model_info (ModelInfo): model info
loss_function (callable): loss function
validation_percent (float): validation percent
"""

eval_data_dict: dict = None
model_info: ModelInfo = None
loss_function: callable = None
training_percent: float = 0.8
Expand Down