diff --git a/cellpose/train.py b/cellpose/train.py index 401c0efc..c78c814f 100644 --- a/cellpose/train.py +++ b/cellpose/train.py @@ -2,7 +2,7 @@ import os import numpy as np from cellpose import io, utils, models, dynamics -from cellpose.transforms import normalize_img, random_rotate_and_resize +from cellpose.transforms import normalize_img, random_rotate_and_resize, convert_image from pathlib import Path import torch from torch import nn @@ -19,7 +19,7 @@ def _loss_fn_class(lbl, y, class_weights=None): Args: lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX). y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob). - + Returns: torch.Tensor: Loss value. @@ -27,7 +27,7 @@ def _loss_fn_class(lbl, y, class_weights=None): criterion3 = nn.CrossEntropyLoss(reduction="mean", weight=class_weights) loss3 = criterion3(y[:, :-3], lbl[:, 0].long()) - + return loss3 def _loss_fn_seg(lbl, y, device): @@ -69,7 +69,7 @@ def _reshape_norm(data, channel_axis=None, normalize_params={"normalize": False} for td in data: if td.ndim == 3: channel_axis0 = channel_axis if channel_axis is not None else np.array(td.shape).argmin() - # put channel axis first + # put channel axis first td = np.moveaxis(td, channel_axis0, 0) td = td[:3] # keep at most 3 channels if td.ndim == 2 or (td.ndim == 3 and td.shape[0] == 1): @@ -85,7 +85,7 @@ def _reshape_norm(data, channel_axis=None, normalize_params={"normalize": False} ] return data -def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, +def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, channel_axis=None, normalize_params={"normalize": False}): """ Get a batch of images and labels. @@ -96,6 +96,7 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, labels (list or None): List of label data. If None, labels will be loaded from files. files (list or None): List of file paths for images. labels_files (list or None): List of file paths for labels. + channel_axis (int or None): Axis of channel dimension. normalize_params (dict): Dictionary of parameters for image normalization (will be faster, if loading from files to pre-normalize). Returns: @@ -104,7 +105,7 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, if data is None: lbls = None imgs = [io.imread(files[i]) for i in inds] - imgs = _reshape_norm(imgs, normalize_params=normalize_params) + imgs = _reshape_norm(imgs, channel_axis=channel_axis, normalize_params=normalize_params) if labels_files is not None: lbls = [io.imread(labels_files[i])[1:] for i in inds] else: @@ -140,7 +141,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, train_labels_files=None, train_probs=None, test_data=None, test_labels=None, test_files=None, test_labels_files=None, test_probs=None, load_files=True, min_train_masks=5, - compute_flows=False, normalize_params={"normalize": False}, + compute_flows=False, normalize_params={"normalize": False}, channel_axis=None, device=None): """ Process train and test data. @@ -170,7 +171,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, """ if device == None: device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None - + if train_data is not None and train_labels is not None: # if data is loaded nimg = len(train_data) @@ -229,13 +230,22 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, test_labels = dynamics.labels_to_flows(test_labels, files=test_files, device=device) elif compute_flows: - for k in trange(nimg): - tl = dynamics.labels_to_flows(io.imread(train_labels_files), - files=train_files, device=device) + # Compute flows one image at a time to avoid loading all into memory + for k in trange(nimg) if nimg > 1 else range(nimg): + dynamics.labels_to_flows([io.imread(train_labels_files[k])], + files=[train_files[k]], device=device) + # Update train_labels_files to point to the newly created flow files + train_labels_files = [ + os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files + ] if test_files is not None: - for k in trange(nimg_test): - tl = dynamics.labels_to_flows(io.imread(test_labels_files), - files=test_files, device=device) + for k in trange(nimg_test) if nimg_test > 1 else range(nimg_test): + dynamics.labels_to_flows([io.imread(test_labels_files[k])], + files=[test_files[k]], device=device) + # Update test_labels_files to point to the newly created flow files + test_labels_files = [ + os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files + ] ### compute diameters nmasks = np.zeros(nimg) @@ -268,6 +278,8 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, f"{nremove} train images with number of masks less than min_train_masks ({min_train_masks}), removing from train set" ) ikeep = np.nonzero(nmasks >= min_train_masks)[0] + + # Filter all arrays/lists first if train_data is not None: train_data = [train_data[i] for i in ikeep] train_labels = [train_labels[i] for i in ikeep] @@ -278,7 +290,14 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, if train_probs is not None: train_probs = train_probs[ikeep] diam_train = diam_train[ikeep] - nimg = len(train_data) + + # Recompute nimg after filtering + if train_data is not None: + nimg = len(train_data) + elif train_files is not None: + nimg = len(train_files) + elif train_labels_files is not None: + nimg = len(train_labels_files) ### normalize probabilities train_probs = 1. / nimg * np.ones(nimg, @@ -294,7 +313,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None, if normalize_params["normalize"]: train_logger.info(f">>> normalizing {normalize_params}") if train_data is not None: - train_data = _reshape_norm(train_data, channel_axis=channel_axis, + train_data = _reshape_norm(train_data, channel_axis=channel_axis, normalize_params=normalize_params) normed = True if test_data is not None: @@ -349,7 +368,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, Returns: tuple: A tuple containing the path to the saved model weights, training losses, and test losses. - + """ if SGD: train_logger.warning("SGD is deprecated, using AdamW instead") @@ -359,7 +378,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, original_net_dtype = None if device.type == 'mps' and net.dtype == torch.bfloat16: # NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype \ - original_net_dtype = torch.bfloat16 + original_net_dtype = torch.bfloat16 train_logger.warning("Training with bfloat16 on MPS is not supported, using float32 network instead") net.dtype = torch.float32 net.to(torch.float32) @@ -391,7 +410,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, kwargs = {} else: kwargs = {"normalize_params": normalize_params, "channel_axis": channel_axis} - + net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device) if class_weights is not None and isinstance(class_weights, (list, np.ndarray, tuple)): @@ -513,12 +532,12 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, if X.dtype != net.dtype: X = X.to(net.dtype) lbl = lbl.to(net.dtype) - + y = net(X)[0] loss = _loss_fn_seg(lbl, y, device) if y.shape[1] > 3: loss3 = _loss_fn_class(lbl, y, class_weights=class_weights) - loss += loss3 + loss += loss3 test_loss = loss.item() test_loss *= len(imgi) lavgt += test_loss @@ -537,7 +556,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None, filename0 = filename train_logger.info(f"saving network parameters to {filename0}") net.save_model(filename0) - + net.save_model(filename) if original_net_dtype is not None: