diff --git a/cellpose/train.py b/cellpose/train.py
index 401c0efc..3d5a4994 100644
--- a/cellpose/train.py
+++ b/cellpose/train.py
@@ -314,7 +314,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
               n_epochs=100, weight_decay=0.1, normalize=True, compute_flows=False,
               save_path=None, save_every=100, save_each=False, nimg_per_epoch=None,
               nimg_test_per_epoch=None, rescale=False, scale_range=None, bsize=256,
-              min_train_masks=5, model_name=None, class_weights=None):
+              min_train_masks=5, model_name=None, class_weights=None,
+              loss_callback=None, return_loss_arrays=True, early_stopping_patience=None):
     """
     Train the network with images for segmentation.
 
@@ -346,10 +347,15 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
         rescale (bool, optional): Boolean - whether or not to rescale images during training. Defaults to False.
         min_train_masks (int, optional): Integer - minimum number of masks an image must have to use in the training set. Defaults to 5.
         model_name (str, optional): String - name of the network. Defaults to None.
+        loss_callback (callable, optional): Function called after every epoch as loss_callback(epoch, train_loss, test_loss), where train_loss is that epoch's mean training loss and test_loss is None when no test data is given. Defaults to None.
+        return_loss_arrays (bool, optional): Boolean - whether to return the per-epoch loss arrays along with the model path. Defaults to True.
+        early_stopping_patience (int, optional): Integer - number of consecutive epochs without test loss improvement before training stops.
+            Requires test data and is ignored without it. If None, no early stopping. Defaults to None.
 
     Returns:
-        tuple: A tuple containing the path to the saved model weights, training losses, and test losses.
+        tuple or path: (path to the saved model weights, training losses, test losses) if return_loss_arrays is True,
+            otherwise only the path. The path is the "_best" checkpoint when early stopping saved one.
 
     """
     if SGD:
         train_logger.warning("SGD is deprecated, using AdamW instead")
@@ -432,7 +438,12 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
     train_logger.info(f">>> saving model to {filename}")
 
     lavg, nsum = 0, 0
     train_losses, test_losses = np.zeros(n_epochs), np.zeros(n_epochs)
+
+    # early stopping state: lowest test loss so far, epochs since it improved, and its checkpoint
+    best_val_loss = float("inf")
+    patience_counter = 0
+    best_model_path = None
     for iepoch in range(n_epochs):
         np.random.seed(iepoch)
         if nimg != nimg_per_epoch:
@@ -481,53 +492,78 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
             lavg += train_loss
             nsum += len(imgi)
             # per epoch training loss
             train_losses[iepoch] += train_loss
         train_losses[iepoch] /= nimg_per_epoch
 
-        if iepoch == 5 or iepoch % 10 == 0:
-            lavgt = 0.
-            if test_data is not None or test_files is not None:
-                np.random.seed(42)
-                if nimg_test != nimg_test_per_epoch:
-                    rperm = np.random.choice(np.arange(0, nimg_test),
-                                             size=(nimg_test_per_epoch,), p=test_probs)
-                else:
-                    rperm = np.random.permutation(np.arange(0, nimg_test))
-                for ibatch in range(0, len(rperm), batch_size):
-                    with torch.no_grad():
-                        net.eval()
-                        inds = rperm[ibatch:ibatch + batch_size]
-                        imgs, lbls = _get_batch(inds, data=test_data,
-                                                labels=test_labels, files=test_files,
-                                                labels_files=test_labels_files,
-                                                **kwargs)
-                        diams = np.array([diam_test[i] for i in inds])
-                        rsc = diams / net.diam_mean.item() if rescale else np.ones(
-                            len(diams), "float32")
-                        imgi, lbl = random_rotate_and_resize(
-                            imgs, Y=lbls, rescale=rsc, scale_range=scale_range,
-                            xy=(bsize, bsize))[:2]
-                        X = torch.from_numpy(imgi).to(device)
-                        lbl = torch.from_numpy(lbl).to(device)
-
-                        if X.dtype != net.dtype:
-                            X = X.to(net.dtype)
-                            lbl = lbl.to(net.dtype)
-
-                        y = net(X)[0]
-                        loss = _loss_fn_seg(lbl, y, device)
-                        if y.shape[1] > 3:
-                            loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
-                            loss += loss3
-                        test_loss = loss.item()
-                        test_loss *= len(imgi)
-                        lavgt += test_loss
-                lavgt /= len(rperm)
-                test_losses[iepoch] = lavgt
+        # the test loss is evaluated every epoch (not only on logging epochs) so that
+        # loss_callback and early stopping always see a current value
+        has_test = test_data is not None or test_files is not None
+        lavgt = 0.
+        if has_test:
+            np.random.seed(42)
+            if nimg_test != nimg_test_per_epoch:
+                rperm = np.random.choice(np.arange(0, nimg_test),
+                                         size=(nimg_test_per_epoch,), p=test_probs)
+            else:
+                rperm = np.random.permutation(np.arange(0, nimg_test))
+            for ibatch in range(0, len(rperm), batch_size):
+                with torch.no_grad():
+                    net.eval()
+                    inds = rperm[ibatch:ibatch + batch_size]
+                    imgs, lbls = _get_batch(inds, data=test_data,
+                                            labels=test_labels, files=test_files,
+                                            labels_files=test_labels_files,
+                                            **kwargs)
+                    diams = np.array([diam_test[i] for i in inds])
+                    rsc = diams / net.diam_mean.item() if rescale else np.ones(
+                        len(diams), "float32")
+                    imgi, lbl = random_rotate_and_resize(
+                        imgs, Y=lbls, rescale=rsc, scale_range=scale_range,
+                        xy=(bsize, bsize))[:2]
+                    X = torch.from_numpy(imgi).to(device)
+                    lbl = torch.from_numpy(lbl).to(device)
+
+                    if X.dtype != net.dtype:
+                        X = X.to(net.dtype)
+                        lbl = lbl.to(net.dtype)
+
+                    y = net(X)[0]
+                    loss = _loss_fn_seg(lbl, y, device)
+                    if y.shape[1] > 3:
+                        loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
+                        loss += loss3
+                    test_loss = loss.item()
+                    test_loss *= len(imgi)
+                    lavgt += test_loss
+            lavgt /= len(rperm)
+            test_losses[iepoch] = lavgt
+
+        # log the running training average on epoch 5 and every tenth epoch, then reset it
+        if iepoch == 5 or iepoch % 10 == 0:
             lavg /= nsum
             train_logger.info(
                 f"{iepoch}, train_loss={lavg:.4f}, test_loss={lavgt:.4f}, LR={LR[iepoch]:.6f}, time {time.time()-t0:.2f}s"
             )
             lavg, nsum = 0, 0
+
+        # report this epoch's own losses; runs before the early-stopping break so the
+        # stopping epoch is reported too
+        if loss_callback is not None:
+            loss_callback(iepoch, float(train_losses[iepoch]),
+                          lavgt if has_test else None)
+
+        # early stopping needs a real test loss; without test data lavgt stays 0
+        if early_stopping_patience is not None and has_test:
+            if lavgt < best_val_loss:
+                best_val_loss = lavgt
+                patience_counter = 0
+                best_model_path = str(filename) + "_best"
+                train_logger.info(f"New best validation loss: {lavgt:.4f}, saving model to {best_model_path}")
+                net.save_model(best_model_path)
+            else:
+                patience_counter += 1
+                train_logger.info(f"No improvement in validation loss for {patience_counter} epoch(s)")
+            if patience_counter >= early_stopping_patience:
+                train_logger.info(f"Early stopping triggered after {iepoch + 1} epochs")
+                break
 
         if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
@@ -538,10 +574,20 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
             train_logger.info(f"saving network parameters to {filename0}")
             net.save_model(filename0)
 
-    net.save_model(filename)
+    # after an early stop the network holds the last (not the best) weights, so only
+    # write it to `filename` when training was not cut short
+    if early_stopping_patience is None or patience_counter < early_stopping_patience:
+        net.save_model(filename)
+
+    # a best checkpoint only exists when early stopping tracked one; return its path
+    if best_model_path is not None:
+        train_logger.info(f"Training finished. Best model saved at: {best_model_path}")
+        filename = best_model_path
 
     if original_net_dtype is not None:
         net.dtype = original_net_dtype
         net.to(original_net_dtype)
 
-    return filename, train_losses, test_losses
+    if return_loss_arrays:
+        return filename, train_losses, test_losses
+    return filename