Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 102 additions & 53 deletions cellpose/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _process_train_test(train_data=None, train_labels=None, train_files=None,
files=train_files, device=device)
if test_files is not None:
for k in trange(nimg_test):
tl = dynamics.labels_to_flows(io.imread(test_labels_files),
tl = dynamics.labels_to_flows(io.imread(testLabels_files),
files=test_files, device=device)

### compute diameters
Expand Down Expand Up @@ -314,7 +314,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
n_epochs=100, weight_decay=0.1, normalize=True, compute_flows=False,
save_path=None, save_every=100, save_each=False, nimg_per_epoch=None,
nimg_test_per_epoch=None, rescale=False, scale_range=None, bsize=256,
min_train_masks=5, model_name=None, class_weights=None):
min_train_masks=5, model_name=None, class_weights=None,
loss_callback=None, return_loss_arrays=True, early_stopping_patience=None):
"""
Train the network with images for segmentation.

Expand Down Expand Up @@ -346,10 +347,10 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to False.
min_train_masks (int, optional): Integer - minimum number of masks an image must have to use in the training set. Defaults to 5.
model_name (str, optional): String - name of the network. Defaults to None.

Returns:
tuple: A tuple containing the path to the saved model weights, training losses, and test losses.

loss_callback (callable, optional): Function called after each epoch with (epoch, train_loss, test_loss). Defaults to None.
return_loss_arrays (bool, optional): Whether to return full loss arrays or just the model path. Defaults to True.
early_stopping_patience (int, optional): Number of epochs without validation loss improvement before stopping.
If None, no early stopping. Defaults to None.
"""
if SGD:
train_logger.warning("SGD is deprecated, using AdamW instead")
Expand Down Expand Up @@ -432,7 +433,13 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
train_logger.info(f">>> saving model to {filename}")

lavg, nsum = 0, 0
train_losses, test_losses = np.zeros(n_epochs), np.zeros(n_epochs)
train_losses, test_losses = (np.zeros(n_epochs), np.zeros(n_epochs)) if return_loss_arrays else (None, None)

# Early stopping variables
best_val_loss = float('inf')
patience_counter = 0
best_model_path = None

for iepoch in range(n_epochs):
np.random.seed(iepoch)
if nimg != nimg_per_epoch:
Expand Down Expand Up @@ -481,53 +488,85 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
lavg += train_loss
nsum += len(imgi)
# per epoch training loss
train_losses[iepoch] += train_loss
train_losses[iepoch] /= nimg_per_epoch

if iepoch == 5 or iepoch % 10 == 0:
lavgt = 0.
if test_data is not None or test_files is not None:
np.random.seed(42)
if nimg_test != nimg_test_per_epoch:
rperm = np.random.choice(np.arange(0, nimg_test),
size=(nimg_test_per_epoch,), p=test_probs)
else:
rperm = np.random.permutation(np.arange(0, nimg_test))
for ibatch in range(0, len(rperm), batch_size):
with torch.no_grad():
net.eval()
inds = rperm[ibatch:ibatch + batch_size]
imgs, lbls = _get_batch(inds, data=test_data,
labels=test_labels, files=test_files,
labels_files=test_labels_files,
**kwargs)
diams = np.array([diam_test[i] for i in inds])
rsc = diams / net.diam_mean.item() if rescale else np.ones(
len(diams), "float32")
imgi, lbl = random_rotate_and_resize(
imgs, Y=lbls, rescale=rsc, scale_range=scale_range,
xy=(bsize, bsize))[:2]
X = torch.from_numpy(imgi).to(device)
lbl = torch.from_numpy(lbl).to(device)

if X.dtype != net.dtype:
X = X.to(net.dtype)
lbl = lbl.to(net.dtype)

y = net(X)[0]
loss = _loss_fn_seg(lbl, y, device)
if y.shape[1] > 3:
loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
loss += loss3
test_loss = loss.item()
test_loss *= len(imgi)
lavgt += test_loss
lavgt /= len(rperm)
if return_loss_arrays:
train_losses[iepoch] += train_loss
if return_loss_arrays:
train_losses[iepoch] /= nimg_per_epoch
epoch_train_loss = (lavg / nsum) if not return_loss_arrays else (train_losses[iepoch] / nimg_per_epoch if iepoch == 0 else train_losses[iepoch])

# Compute validation loss every epoch for real-time tracking
current_train_loss = lavg / nsum
lavgt = 0.

if test_data is not None or test_files is not None:
np.random.seed(42)
if nimg_test != nimg_test_per_epoch:
rperm = np.random.choice(np.arange(0, nimg_test),
size=(nimg_test_per_epoch,), p=test_probs)
else:
rperm = np.random.permutation(np.arange(0, nimg_test))
for ibatch in range(0, len(rperm), batch_size):
with torch.no_grad():
net.eval()
inds = rperm[ibatch:ibatch + batch_size]
imgs, lbls = _get_batch(inds, data=test_data,
labels=test_labels, files=test_files,
labels_files=test_labels_files,
**kwargs)
diams = np.array([diam_test[i] for i in inds])
rsc = diams / net.diam_mean.item() if rescale else np.ones(
len(diams), "float32")
imgi, lbl = random_rotate_and_resize(
imgs, Y=lbls, rescale=rsc, scale_range=scale_range,
xy=(bsize, bsize))[:2]
X = torch.from_numpy(imgi).to(device)
lbl = torch.from_numpy(lbl).to(device)

if X.dtype != net.dtype:
X = X.to(net.dtype)
lbl = lbl.to(net.dtype)

y = net(X)[0]
loss = _loss_fn_seg(lbl, y, device)
if y.shape[1] > 3:
loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
loss += loss3
test_loss = loss.item()
test_loss *= len(imgi)
lavgt += test_loss
lavgt /= len(rperm)
if return_loss_arrays:
test_losses[iepoch] = lavgt
lavg /= nsum

# Early stopping logic
if early_stopping_patience is not None:
if lavgt < best_val_loss:
best_val_loss = lavgt
patience_counter = 0
# Save best model
best_model_path = str(filename) + "_best"
train_logger.info(f"New best validation loss: {lavgt:.4f}, saving model to {best_model_path}")
net.save_model(best_model_path)
else:
patience_counter += 1
train_logger.info(f"No improvement in validation loss for {patience_counter} epoch(s)")

if patience_counter >= early_stopping_patience:
train_logger.info(f"Early stopping triggered after {iepoch + 1} epochs")
break

# Log every epoch (more frequent logging for real-time tracking)
if iepoch == 5 or iepoch % 10 == 0:
train_logger.info(
f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s"
f"{iepoch}, train_loss={current_train_loss:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s"
)

# Call the callback function every epoch with both train and test loss
if loss_callback is not None:
loss_callback(iepoch, current_train_loss, lavgt if (test_data is not None or test_files is not None) else None)

# Reset accumulators after logging
if iepoch == 5 or iepoch % 10 == 0:
lavg, nsum = 0, 0

if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
Expand All @@ -538,10 +577,20 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
train_logger.info(f"saving network parameters to {filename0}")
net.save_model(filename0)

net.save_model(filename)
# Save final model if not using early stopping or if we completed all epochs
if early_stopping_patience is None or patience_counter < early_stopping_patience:
net.save_model(filename)

# If early stopping was used and we have a best model, use that
if best_model_path is not None:
train_logger.info(f"Training finished. Best model saved at: {best_model_path}")
filename = best_model_path

if original_net_dtype is not None:
net.dtype = original_net_dtype
net.to(original_net_dtype)

return filename, train_losses, test_losses
if return_loss_arrays:
return filename, train_losses, test_losses
else:
return filename