-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Stein's Unbiased Risk Estimator (SURE) loss and Conjugate Gradient #7308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 28 commits
Commits
Show all changes
41 commits
Select commit
Hold shift + click to select a range
73c1295
add sure loss, its test functions and its documents
cxlcl 74205dc
modified docs
cxlcl ad597e5
add conjugate gradient: class, unit test and doc
cxlcl daa5889
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6cc6a7a
change the doc conjugate_gradient
cxlcl a30b4c2
Merge branch 'Project-MONAI:dev' into sure_loss-cg
cxlcl 2546377
fix CI error
cxlcl 2b5b895
fix CI error
cxlcl 29dbaa0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 17f5ea0
fix CI error
cxlcl 7679837
Merge branch 'sure_loss-cg' of github.com:cxlcl/MONAI into sure_loss-cg
cxlcl 1f75eb1
fix CI error
cxlcl e8b34e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7b0e982
fix CI error
cxlcl ea81474
fix CI: after running ./runtest --autofix
cxlcl ca45f5f
fix CI: after running ./runtest --autofix
cxlcl 5dea18a
Merge branch 'sure_loss-cg' of github.com:cxlcl/MONAI into sure_loss-cg
cxlcl 58ee712
fix CI: after running ./runtest --autofix
cxlcl 9755aa5
fix CI: after running ./runtest --autofix
cxlcl dd72c52
Merge branch 'dev' into sure_loss-cg
cxlcl 11302c6
Merge branch 'dev' into sure_loss-cg
KumoLiu e2aad74
Update monai/losses/sure_loss.py
cxlcl f9f8e6f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b995cb4
modifications based on revision
cxlcl 7a5c712
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 007e779
Modifications based on revision
cxlcl 659d0c0
Merge branch 'sure_loss-cg' of github.com:cxlcl/MONAI into sure_loss-cg
cxlcl 701342e
Merge branch 'Project-MONAI:dev' into sure_loss-cg
cxlcl b7198fb
Update docs/source/losses.rst
cxlcl 5590abc
Update monai/losses/sure_loss.py
cxlcl 281b780
Update monai/losses/sure_loss.py
cxlcl 36ca0d4
Update monai/losses/sure_loss.py
cxlcl 7b8c39d
Update tests/test_conjugate_gradient.py
cxlcl 20e81a7
Update tests/test_sure_loss.py
cxlcl 008920b
Update monai/networks/layers/conjugate_gradient.py
cxlcl e5b9d13
Update tests/test_sure_loss.py
cxlcl 31fb456
Update tests/test_sure_loss.py
cxlcl 3971581
Update tests/test_conjugate_gradient.py
cxlcl 5e309c7
Merge branch 'dev' into sure_loss-cg
KumoLiu 81b9434
fix flake8
KumoLiu 2c78f4c
fix ci
KumoLiu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,194 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Callable, Optional | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| from torch.nn.modules.loss import _Loss | ||
|
|
||
|
|
||
| def complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| First compute the difference in the complex domain, | ||
| then get the absolute value and take the mse | ||
| Args: | ||
| x, y - B, 2, H, W real valued tensors representing complex numbers | ||
| or B,1,H,W complex valued tensors | ||
| Returns: | ||
| l2_loss - scalar | ||
| """ | ||
| if not x.is_complex(): | ||
| x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous()) | ||
| if not y.is_complex(): | ||
| y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous()) | ||
|
|
||
| diff = torch.abs(x - y) | ||
| return nn.functional.mse_loss(diff, torch.zeros_like(diff), reduction="mean") | ||
|
|
||
|
|
||
| def sure_loss_function( | ||
| operator: Callable, | ||
| x: torch.Tensor, | ||
| y_pseudo_gt: torch.Tensor, | ||
| y_ref: Optional[torch.Tensor] = None, | ||
| eps: Optional[float] = -1.0, | ||
| perturb_noise: Optional[torch.Tensor] = None, | ||
| complex_input: Optional[bool] = False, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
| operator (function): The operator function that takes in an input | ||
| tensor x and returns an output tensor y. We will use this to compute | ||
| the divergence. More specifically, we will perturb the input x by a | ||
| small amount and compute the divergence between the perturbed output | ||
| and the reference output | ||
| x (torch.Tensor): The input tensor of shape (B, C, H, W) to the | ||
| operator. For complex input, the shape is (B, 2, H, W) aka C=2 real. | ||
| For real input, the shape is (B, 1, H, W) real. | ||
| y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape | ||
| (B, C, H, W) used to compute the L2 loss. For complex input, the shape is | ||
| (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W) | ||
| real. | ||
| y_ref (torch.Tensor, optional): The reference output tensor of shape | ||
| (B, C, H, W) used to compute the divergence. Defaults to None. For | ||
| complex input, the shape is (B, 2, H, W) aka C=2 real. For real input, | ||
| the shape is (B, 1, H, W) real. | ||
| eps (float, optional): The perturbation scalar. Set to -1 to set it | ||
| automatically estimated based on y_pseudo_gtk | ||
| perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W). | ||
| Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. | ||
| For real input, the shape is (B, 1, H, W) real. | ||
| complex_input(bool, optional): Whether the input is complex or not. | ||
| Defaults to False. | ||
| Returns: | ||
| sure_loss (torch.Tensor): The SURE loss scalar. | ||
| """ | ||
| # perturb input | ||
| if perturb_noise is None: | ||
| perturb_noise = torch.randn_like(x) | ||
| if eps == -1.0: | ||
| eps = float(torch.abs(y_pseudo_gt.max())) / 1000 | ||
| # get y_ref if not provided | ||
| if y_ref is None: | ||
| y_ref = operator(x) | ||
|
|
||
| # get perturbed output | ||
| x_perturbed = x + eps * perturb_noise | ||
| y_perturbed = operator(x_perturbed) | ||
| # divergence | ||
| divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # type: ignore | ||
| # l2 loss between y_ref, y_pseudo_gt | ||
| if complex_input: | ||
| l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt) | ||
| else: | ||
| # real input | ||
| l2_loss = nn.functional.mse_loss(y_ref, y_pseudo_gt, reduction="mean") | ||
|
|
||
| # sure loss | ||
| sure_loss = l2_loss * divergence / (x.shape[0] * x.shape[2] * x.shape[3]) | ||
| return sure_loss | ||
|
|
||
|
|
||
| class SURELoss(_Loss): | ||
| """ | ||
| Calculate the Stein's Unbiased Risk Estimator (SURE) loss for a given operator. | ||
| This is a differentiable loss function that can be used to train/giude an | ||
cxlcl marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| operator (e.g. neural network), where the pseudo ground truth is available | ||
| but the reference ground truth is not. For example, in the MRI | ||
| reconstruction, the pseudo ground truth is the zero-filled reconstruction | ||
| and the reference ground truth is the fully sampled reconstruction. Often, | ||
| the reference ground truth is not available due to the lack of fully sampled | ||
| data. | ||
| The original SURE loss is proposed in [1]. The SURE loss used for guiding | ||
| the diffusion model based MRI reconstruction is proposed in [2]. | ||
| Reference | ||
| [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics | ||
| [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models. | ||
| (https://arxiv.org/pdf/2310.01799.pdf) | ||
| """ | ||
|
|
||
| def __init__(self, perturb_noise: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> None: | ||
| """ | ||
| Args: | ||
| perturb_noise (torch.Tensor, optional): The noise vector of shape | ||
| (B, C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real. | ||
| For real input, the shape is (B, 1, H, W) real. | ||
| eps (float, optional): The perturbation scalar. Defaults to None. | ||
| """ | ||
| super().__init__() | ||
| self.perturb_noise = perturb_noise | ||
| self.eps = eps | ||
|
|
||
| def forward( | ||
| self, | ||
| operator: Callable, | ||
| x: torch.Tensor, | ||
| y_pseudo_gt: torch.Tensor, | ||
| y_ref: Optional[torch.Tensor] = None, | ||
| complex_input: Optional[bool] = False, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Args: | ||
| operator (function): The operator function that takes in an input | ||
| tensor x and returns an output tensor y. We will use this to compute | ||
| the divergence. More specifically, we will perturb the input x by a | ||
| small amount and compute the divergence between the perturbed output | ||
| and the reference output | ||
| x (torch.Tensor): The input tensor of shape (B, C, H, W) to the | ||
| operator. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka | ||
| C=2 real. For real input, the shape is (B, 1, H, W) real. | ||
| y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape | ||
| (B, C, H, W) used to compute the L2 loss. C=1 or 2: For complex | ||
| input, the shape is (B, 2, H, W) aka C=2 real. For real input, the | ||
| shape is (B, 1, H, W) real. | ||
| y_ref (torch.Tensor, optional): The reference output tensor of the | ||
| same shape as y_pseudo_gt | ||
| Returns: | ||
| sure_loss (torch.Tensor): The SURE loss scalar. | ||
| """ | ||
|
|
||
| # check inputs shapes | ||
| if x.dim() != 4: | ||
| raise ValueError("Input tensor x should be 4D.") | ||
cxlcl marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if y_pseudo_gt.dim() != 4: | ||
| raise ValueError("Input tensor y_pseudo_gt should be 4D.") | ||
| if y_ref is not None and y_ref.dim() != 4: | ||
| raise ValueError("Input tensor y_ref should be 4D.") | ||
| if x.shape != y_pseudo_gt.shape: | ||
| raise ValueError("Input tensor x and y_pseudo_gt should have the same shape.") | ||
| if y_ref is not None and y_pseudo_gt.shape != y_ref.shape: | ||
| raise ValueError("Input tensor y_pseudo_gt and y_ref should have the same shape.") | ||
cxlcl marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| # compute loss | ||
| loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input) | ||
|
|
||
| return loss | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| # Copyright (c) MONAI Consortium | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Callable | ||
|
|
||
| import torch | ||
| from torch import nn | ||
|
|
||
|
|
||
| def _zdot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Complex dot product between tensors x1 and x2: sum(x1.*x2) | ||
| """ | ||
| if torch.is_complex(x1): | ||
| assert torch.is_complex(x2), "x1 and x2 must both be complex" | ||
| return torch.sum(x1.conj() * x2) | ||
| else: | ||
| return torch.sum(x1 * x2) | ||
|
|
||
|
|
||
| def _zdot_single(x: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Complex dot product between tensor x and itself | ||
| """ | ||
| res = _zdot(x, x) | ||
| if torch.is_complex(res): | ||
| return res.real | ||
| else: | ||
| return res | ||
|
|
||
|
|
||
| class ConjugateGradient(nn.Module): | ||
| """ | ||
| Congugate Gradient (CG) solver for linear systems Ax = y. | ||
| For linear_op that is positive definite and self-adjoint, CG is | ||
| guaranteed to converge CG is often used to solve linear systems of the form | ||
| Ax = y, where A is too large to store explicitly, but can be computed via a | ||
| linear operator. | ||
| As a result, here we won't set A explicitly as a matrix, but rather as a | ||
| linear operator. For example, A could be a FFT/IFFT operation | ||
| """ | ||
|
|
||
| def __init__(self, linear_op: Callable, num_iter: int): | ||
| """ | ||
| Args: | ||
| linear_op: Linear operator | ||
| num_iter: Number of iterations to run CG | ||
| """ | ||
| super().__init__() | ||
|
|
||
| self.linear_op = linear_op | ||
| self.num_iter = num_iter | ||
|
|
||
| def update( | ||
| self, x: torch.Tensor, p: torch.Tensor, r: torch.Tensor, rsold: torch.Tensor | ||
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | ||
| """ | ||
| perform one iteration of the CG method. It takes the current solution x, | ||
| the current search direction p, the current residual r, and the old | ||
| residual norm rsold as inputs. Then it computes the new solution, search | ||
| direction, residual, and residual norm, and returns them. | ||
| """ | ||
|
|
||
| dy = self.linear_op(p) | ||
| p_dot_dy = _zdot(p, dy) | ||
| alpha = rsold / p_dot_dy | ||
| x = x + alpha * p | ||
| r = r - alpha * dy | ||
| rsnew = _zdot_single(r) | ||
| beta = rsnew / rsold | ||
| rsold = rsnew | ||
| p = beta * p + r | ||
| return x, p, r, rsold | ||
|
|
||
| def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| run conjugate gradient for num_iter iterations to solve Ax = y | ||
| Args: | ||
| x: tensor (real or complex); Initial guess for linear system Ax = y. | ||
| The size of x should be applicable to the linear operator. For | ||
| example, if the linear operator is FFT, then x is HCHW; if the | ||
| linear operator is a matrix multiplication, then x is a vector | ||
| y: tensor (real or complex); Measurement. Same size as x | ||
| Returns: | ||
| x: Solution to Ax = y | ||
| """ | ||
| # Compute residual | ||
| r = y - self.linear_op(x) | ||
| rsold = _zdot_single(r) | ||
| p = r | ||
|
|
||
| # Update | ||
| for i in range(self.num_iter): | ||
cxlcl marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| x, p, r, rsold = self.update(x, p, r, rsold) | ||
| if rsold < 1e-10: | ||
| break | ||
| return x | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.