Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 41 additions & 22 deletions cellpose/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import numpy as np
from cellpose import io, utils, models, dynamics
from cellpose.transforms import normalize_img, random_rotate_and_resize
from cellpose.transforms import normalize_img, random_rotate_and_resize, convert_image
from pathlib import Path
import torch
from torch import nn
Expand All @@ -19,15 +19,15 @@ def _loss_fn_class(lbl, y, class_weights=None):
Args:
lbl (numpy.ndarray): True labels (cellprob, flowsY, flowsX).
y (torch.Tensor): Predicted values (flowsY, flowsX, cellprob).

Returns:
torch.Tensor: Loss value.

"""

criterion3 = nn.CrossEntropyLoss(reduction="mean", weight=class_weights)
loss3 = criterion3(y[:, :-3], lbl[:, 0].long())

return loss3

def _loss_fn_seg(lbl, y, device):
Expand Down Expand Up @@ -69,7 +69,7 @@ def _reshape_norm(data, channel_axis=None, normalize_params={"normalize": False}
for td in data:
if td.ndim == 3:
channel_axis0 = channel_axis if channel_axis is not None else np.array(td.shape).argmin()
# put channel axis first
# put channel axis first
td = np.moveaxis(td, channel_axis0, 0)
td = td[:3] # keep at most 3 channels
if td.ndim == 2 or (td.ndim == 3 and td.shape[0] == 1):
Expand All @@ -85,7 +85,7 @@ def _reshape_norm(data, channel_axis=None, normalize_params={"normalize": False}
]
return data

def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
def _get_batch(inds, data=None, labels=None, files=None, labels_files=None, channel_axis=None,
normalize_params={"normalize": False}):
"""
Get a batch of images and labels.
Expand All @@ -96,6 +96,7 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
labels (list or None): List of label data. If None, labels will be loaded from files.
files (list or None): List of file paths for images.
labels_files (list or None): List of file paths for labels.
channel_axis (int or None): Axis of channel dimension.
normalize_params (dict): Dictionary of parameters for image normalization (when loading from files, it is faster to pre-normalize the images).

Returns:
Expand All @@ -104,7 +105,7 @@ def _get_batch(inds, data=None, labels=None, files=None, labels_files=None,
if data is None:
lbls = None
imgs = [io.imread(files[i]) for i in inds]
imgs = _reshape_norm(imgs, normalize_params=normalize_params)
imgs = _reshape_norm(imgs, channel_axis=channel_axis, normalize_params=normalize_params)
if labels_files is not None:
lbls = [io.imread(labels_files[i])[1:] for i in inds]
else:
Expand Down Expand Up @@ -140,7 +141,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None,
train_labels_files=None, train_probs=None, test_data=None,
test_labels=None, test_files=None, test_labels_files=None,
test_probs=None, load_files=True, min_train_masks=5,
compute_flows=False, normalize_params={"normalize": False},
compute_flows=False, normalize_params={"normalize": False},
channel_axis=None, device=None):
"""
Process train and test data.
Expand Down Expand Up @@ -170,7 +171,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None,
"""
if device == None:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None

if train_data is not None and train_labels is not None:
# if data is loaded
nimg = len(train_data)
Expand Down Expand Up @@ -229,13 +230,22 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None,
test_labels = dynamics.labels_to_flows(test_labels, files=test_files,
device=device)
elif compute_flows:
for k in trange(nimg):
tl = dynamics.labels_to_flows(io.imread(train_labels_files),
files=train_files, device=device)
# Compute flows one image at a time to avoid loading all into memory
for k in trange(nimg) if nimg > 1 else range(nimg):
dynamics.labels_to_flows([io.imread(train_labels_files[k])],
files=[train_files[k]], device=device)
# Update train_labels_files to point to the newly created flow files
train_labels_files = [
os.path.splitext(str(tf))[0] + "_flows.tif" for tf in train_files
]
if test_files is not None:
for k in trange(nimg_test):
tl = dynamics.labels_to_flows(io.imread(test_labels_files),
files=test_files, device=device)
for k in trange(nimg_test) if nimg_test > 1 else range(nimg_test):
dynamics.labels_to_flows([io.imread(test_labels_files[k])],
files=[test_files[k]], device=device)
# Update test_labels_files to point to the newly created flow files
test_labels_files = [
os.path.splitext(str(tf))[0] + "_flows.tif" for tf in test_files
]

### compute diameters
nmasks = np.zeros(nimg)
Expand Down Expand Up @@ -268,6 +278,8 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None,
f"{nremove} train images with number of masks less than min_train_masks ({min_train_masks}), removing from train set"
)
ikeep = np.nonzero(nmasks >= min_train_masks)[0]

# Filter all arrays/lists first
if train_data is not None:
train_data = [train_data[i] for i in ikeep]
train_labels = [train_labels[i] for i in ikeep]
Expand All @@ -278,7 +290,14 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None,
if train_probs is not None:
train_probs = train_probs[ikeep]
diam_train = diam_train[ikeep]
nimg = len(train_data)

# Recompute nimg after filtering
if train_data is not None:
nimg = len(train_data)
elif train_files is not None:
nimg = len(train_files)
elif train_labels_files is not None:
nimg = len(train_labels_files)

### normalize probabilities
train_probs = 1. / nimg * np.ones(nimg,
Expand All @@ -294,7 +313,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None,
if normalize_params["normalize"]:
train_logger.info(f">>> normalizing {normalize_params}")
if train_data is not None:
train_data = _reshape_norm(train_data, channel_axis=channel_axis,
train_data = _reshape_norm(train_data, channel_axis=channel_axis,
normalize_params=normalize_params)
normed = True
if test_data is not None:
Expand Down Expand Up @@ -349,7 +368,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,

Returns:
tuple: A tuple containing the path to the saved model weights, training losses, and test losses.

"""
if SGD:
train_logger.warning("SGD is deprecated, using AdamW instead")
Expand All @@ -359,7 +378,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
original_net_dtype = None
if device.type == 'mps' and net.dtype == torch.bfloat16:
# NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype
original_net_dtype = torch.bfloat16
original_net_dtype = torch.bfloat16
train_logger.warning("Training with bfloat16 on MPS is not supported, using float32 network instead")
net.dtype = torch.float32
net.to(torch.float32)
Expand Down Expand Up @@ -391,7 +410,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
kwargs = {}
else:
kwargs = {"normalize_params": normalize_params, "channel_axis": channel_axis}

net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device)

if class_weights is not None and isinstance(class_weights, (list, np.ndarray, tuple)):
Expand Down Expand Up @@ -513,12 +532,12 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
if X.dtype != net.dtype:
X = X.to(net.dtype)
lbl = lbl.to(net.dtype)

y = net(X)[0]
loss = _loss_fn_seg(lbl, y, device)
if y.shape[1] > 3:
loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
loss += loss3
loss += loss3
test_loss = loss.item()
test_loss *= len(imgi)
lavgt += test_loss
Expand All @@ -537,7 +556,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
filename0 = filename
train_logger.info(f"saving network parameters to {filename0}")
net.save_model(filename0)

net.save_model(filename)

if original_net_dtype is not None:
Expand Down
Loading