Cholesky numerical stability: inverse transform #356
devmotion
left a comment
AFAICT it's even simpler since LogExpFunctions.logcosh is supposed to provide a numerically stable and efficient implementation of
👀 I'll give that a spin
Looking under the hood logcosh is implemented the same way as above, but the single function call is great 👍
Co-authored-by: David Widmann <[email protected]>
src/bijectors/corr.jl
  for _ in 1:(j - 1)
      z = tanh(y[idx])
-     logz = 2 * log(2 / (exp(y[idx]) + exp(-y[idx])))
+     logz = -2 * LogExpFunctions.logcosh(y[idx])
The name logz is a bit meaningless now I guess 🙂
Sorry I dropped the ball on this, it probably slipped off my radar around Christmas. @devmotion Would you be willing to review again? I'll try to look into the forward transform PR (#357) separately.
This looks good to me, but it may be worth someone checking whether the same replacement could or should also be used on line 417, in the rrule in src/corr.jl.
devmotion
left a comment
Can you update the initial comment to (more clearly) reflect the current status and benchmarks of the PR? IMO it's a bit difficult to tell whether the comments and plots refer to the current version of the PR or some previous iteration.
@devmotion I've just updated them! I also added another commit to calculate tanh and logcosh outside of the loop, which means that it's now also a tiny bit faster than before (at the cost of more allocations) - not sure which is better.
Co-authored-by: David Widmann <[email protected]>
Many thanks @penelopeysm, @joelkandiah and @devmotion!
This PR attempts to improve the numerical stability of the inverse transform for Cholesky matrices, specifically the _inv_link_chol_lkj(::AbstractVector) method. (For the forward transform, see this PR: #357.)
It does this by replacing log1p(-z^2) / 2, where z = tanh(y[idx]), with -LogExpFunctions.logcosh(y[idx]), which is the same expression (see @devmotion's comment below: #356 (review)).
Accuracy 1 - 'typical' samples
First, to make sure there aren't any regressions, we'll:
Code to generate plot
There isn't really much between the two implementations: sometimes the old one is better, sometimes the new one is. In any case, the differences are very small, so I think the new implementation can be said to almost break even, although I do think the old implementation is very slightly better.
Accuracy 2 - random unconstrained samples
However, when sampling in the unconstrained space, there's no guarantee that the resulting sample will resemble anything like the samples obtained via a forward transformation. This leads to issues like #279.
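To illustrate the kind of failure this can cause, here's a minimal sketch (not the benchmark code from this PR; old_expr and new_expr are names made up for illustration). Mathematically the two expressions agree, since 1 - tanh(y)^2 = 1 / cosh(y)^2, but in Float64 tanh(y) rounds to exactly ±1 once |y| is larger than roughly 19, at which point log1p(-z^2) evaluates to -Inf:

```julia
using LogExpFunctions: logcosh

# Old expression from the PR description: log1p(-z^2) / 2 with z = tanh(y).
# (Hypothetical helper name, for illustration only.)
old_expr(y) = log1p(-tanh(y)^2) / 2

# New expression: -logcosh(y), which never forms tanh(y) explicitly.
new_expr(y) = -logcosh(y)

# Small |y| resembles 'typical' samples; large |y| can easily occur when
# sampling directly in the unconstrained space.
for y in (0.5, 5.0, 20.0, 50.0)
    println("y = ", y, ":  old = ", old_expr(y), "  new = ", new_expr(y))
end
```

For moderate y the two should agree to within rounding error; for the large values the old expression is expected to return -Inf while the new one stays finite (approximately -(|y| - log(2))).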
To test out the numerical stability of invlinking random transformed samples, we can:
Code to generate plot
As can be seen, the new method leads to much smaller errors (consistently around the magnitude of eps() ~ 1e-16), whereas the old method often has errors that are several orders of magnitude larger.
Performance
What next
Note that this PR doesn't actually fully solve #279. That issue arises not because of the inverse transformation, but rather because of the forward transformation (in the call to logabsdetjac). This is a result of further numerical instabilities in other functions, specifically the linking one. #357 contains a potential fix for this.