@penelopeysm (Member) commented Nov 30, 2024

This PR attempts to improve the numerical stability of the inverse transform for Cholesky matrices, specifically the _inv_link_chol_lkj(::AbstractVector) method.

(For the forward transform, see this PR: #357)

It does this by replacing log1p(-z^2) / 2, where z = tanh(y[idx]), with -LogExpFunctions.logcosh(y[idx]), which is mathematically the same expression (see @devmotion's comment below: #356 (review)).
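As a quick sanity check that the two expressions agree numerically (a standalone snippet, not part of the PR's code):

using LogExpFunctions

# Check that log1p(-tanh(y)^2) / 2 == -logcosh(y) at moderate values of y
for y in (-3.0, -0.5, 0.0, 0.5, 3.0)
    z = tanh(y)
    @assert isapprox(log1p(-z^2) / 2, -LogExpFunctions.logcosh(y); atol=1e-12)
end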

Accuracy 1 - 'typical' samples

First, to make sure there aren't any regressions, we'll:

  1. Sample from an LKJCholesky distribution
  2. Transform it with the existing bijector
  3. Un-transform it with both the old and the new implementation
  4. Calculate the max absolute error introduced by the roundtrip transformation
  5. Plot the error with the new implementation vs the error with the old implementation
Code to generate plot
using Bijectors
using LinearAlgebra
using Distributions
using Random
using Plots
using IrrationalConstants
using LogExpFunctions

_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2

# This was copied from the main branch
function _inv_link_chol_lkj_old(y::AbstractVector)
    LinearAlgebra.require_one_based_indexing(y)
    K = _triu1_dim_from_length(length(y))
    W = similar(y, K, K)
    T = float(eltype(W))
    logJ = zero(T)
    idx = 1
    @inbounds for j in 1:K
        log_remainder = zero(T)  # log of proportion of unit vector remaining
        for i in 1:(j - 1)
            z = tanh(y[idx])
            W[i, j] = z * exp(log_remainder)
            log_remainder += log1p(-z^2) / 2
            logJ += log_remainder
            idx += 1
        end
        logJ += log_remainder
        W[j, j] = exp(log_remainder)
        for i in (j + 1):K
            W[i, j] = 0
        end
    end
    return W, logJ
end


function plot_maes(samples)
    log_mae_old = log10.([sample[1] for sample in samples])
    log_mae_new = log10.([sample[2] for sample in samples])
    scatter(log_mae_old, log_mae_new, label="")
    lim_min = floor(min(minimum(log_mae_old), minimum(log_mae_new)))
    lim_max = ceil(max(maximum(log_mae_old), maximum(log_mae_new)))
    plot!(lim_min:lim_max, lim_min:lim_max, label="y=x", color=:black)
    xlabel!("log10(maximum abs error old)")
    ylabel!("log10(maximum abs error new)")
end

function test_inverse_bijector(f_old, f_new)
    dist = LKJCholesky(5, 1.0, 'U')
    b = bijector(dist)
    Random.seed!(468)
    samples = map(1:500) do _
        x = rand(dist)
        y = b(x)
        x_true = Matrix{Float64}(x.U) # Convert to full matrix
        x_old = f_old(y)[1]
        x_new = f_new(y)[1]
        # Return the maximum absolute error between the original sample
        # and sample after roundtrip transformation
        (maximum(abs.(x_true - x_old)), maximum(abs.(x_true - x_new)))
    end
    return samples
end
plot_maes(test_inverse_bijector(_inv_link_chol_lkj_old, Bijectors._inv_link_chol_lkj))
savefig("bijector_typical.png")

[Figure: bijector_typical.png, scatter of log10(max abs error) for new vs old implementation]

There isn't much to choose between the two implementations: sometimes the old one is better, sometimes the new one. In any case, the differences are very small, so the new implementation roughly breaks even, although the old implementation is very slightly better on these samples.

Accuracy 2 - random unconstrained samples

However, when sampling in unconstrained space, there's no guarantee that the resulting sample will resemble anything obtained via the forward transformation. This leads to issues like #279.

To test out the numerical stability of invlinking random transformed samples, we can:

  1. Generate a random transformed sample.
  2. Invlink it with the old method, but using arbitrary-precision floats. This is our ground truth.
  3. Invlink it with the old method, with Float64 precision.
  4. Invlink it with the new method, with Float64 precision.
  5. Compare the errors as above.
Code to generate plot
function test_inverse_bijector_unconstrained(f_old, f_new)
    dist = LKJCholesky(5, 1.0, 'U')
    Random.seed!(468)
    samples = map(1:500) do _
        y = rand(dist.d * (dist.d - 1) ÷ 2) * 10
        x_true = f_old(Vector{BigFloat}(y))[1]
        x_old = f_old(y)[1]
        x_new = f_new(y)[1]
        # Return the maximum absolute error of each method relative to
        # the BigFloat ground truth
        (maximum(abs.(x_true - x_old)), maximum(abs.(x_true - x_new)))
    end
    return samples
end
plot_maes(test_inverse_bijector_unconstrained(_inv_link_chol_lkj_old, Bijectors._inv_link_chol_lkj))
savefig("bijector_unconstrained.png")

[Figure: bijector_unconstrained.png, scatter of log10(max abs error) for new vs old implementation on random unconstrained samples]

As can be seen, the new method leads to much smaller errors (consistently around the magnitude of eps() ~ 1e-16), whereas the old method often has errors that are several orders of magnitude larger.
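To see where the old path loses digits, consider a single large coordinate (an illustrative snippet; the error figures are approximate):

using LogExpFunctions

y = 10.0
z = tanh(y)                  # within ~4e-9 of 1, so 1 - z^2 is tiny
log1p(-z^2) / 2              # old path: 1 - z^2 carries roughly 1e-8 of
                             # relative rounding error, which the log inherits
-LogExpFunctions.logcosh(y)  # new path: ≈ -9.3068528215, accurate to ~eps()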


Performance

julia> using Chairmarks

julia> @be (rand(10) * 10) _inv_link_chol_lkj_old # main branch
Benchmark: 4915 samples with 60 evaluations
 min    256.250 ns (2 allocs: 272 bytes)
 median 290.267 ns (2 allocs: 272 bytes)
 mean   305.697 ns (2 allocs: 272 bytes, 0.17% gc time)
 max    32.407 μs (2 allocs: 272 bytes, 98.20% gc time)

julia> @be (rand(10) * 10) Bijectors._inv_link_chol_lkj
Benchmark: 4623 samples with 66 evaluations
 min    261.364 ns (6 allocs: 560 bytes)
 median 271.470 ns (6 allocs: 560 bytes)
 mean   297.287 ns (6 allocs: 560 bytes, 0.35% gc time)
 max    33.902 μs (6 allocs: 560 bytes, 98.25% gc time)

What next

Note that this PR doesn't actually fully solve #279. That issue arises not from the inverse transformation but from the forward transformation (in the call to logabsdetjac), as a result of further numerical instabilities in other functions, specifically the linking one. #357 contains a potential fix for this.

using Bijectors
using Distributions

θ_unconstrained = [
	-1.9887091960524537,
	-13.499454444466279,
	-0.39328331954134665,
	-4.426097270849902,
	13.101175413857023,
	7.66647404712346,
	9.249285786544894,
	4.714877413573335,
	6.233118490809442,
	22.28264809311481
]
n = 5
d = LKJCholesky(n, 10)
b = Bijectors.bijector(d)
b_inv = inverse(b)

θ = b_inv(θ_unconstrained)
Bijectors.logabsdetjac(b, θ)

# ERROR: DomainError with 1.0085229361957693:
# atanh(x) is only defined for |x| ≤ 1.
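Roughly speaking, the link recovers each unconstrained coordinate as atanh of a ratio of Cholesky entries, and rounding can push that ratio just outside [-1, 1]. A hypothetical illustration of the failure mode (not the actual Bijectors code):

# Hypothetical illustration: `w` and `remainder` stand in for a Cholesky
# entry and the remaining norm of its column; mathematically |w / remainder|
# should be <= 1, but one ulp of rounding is enough to break that.
w = 0.6
remainder = prevfloat(0.6)  # one ulp too small due to rounding
z = w / remainder           # 1.0000000000000002
atanh(z)                    # ERROR: DomainError with 1.0000000000000002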

@penelopeysm marked this pull request as draft on November 30, 2024
@penelopeysm force-pushed the py/chol-numerical branch 2 times, most recently from 164a33c to 9d11ba4, on November 30, 2024
@devmotion (Member) left a comment:

AFAICT it's even simpler since $$\log(1 - \tanh^2(x))/2 = \log(\mathrm{sech}^2(x)) / 2 = \log(\mathrm{sech}(x)) = \log(1 / \cosh(x)) = -\log(\cosh(x))$$, and LogExpFunctions.logcosh is supposed to provide a numerically stable and efficient implementation of $$\log(\cosh(\cdot))$$ (if there are problems they should be considered bugs).
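
For instance, the naive expression overflows long before logcosh does:

using LogExpFunctions

log(cosh(1000.0))                # Inf, because cosh(1000) overflows Float64
LogExpFunctions.logcosh(1000.0)  # 999.3068528194401 ≈ 1000 - log(2)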

@penelopeysm (Member, Author) commented:

👀 I'll give that a spin

@penelopeysm (Member, Author) commented:

Looking under the hood, logcosh is implemented essentially the same way as the expression above, but having it as a single function call is great 👍
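Something functionally equivalent to this (a sketch, not necessarily the package's exact source):

using LogExpFunctions, IrrationalConstants

# log(cosh(x)) = log(e^x + e^-x) - log(2); logaddexp handles large |x|
# without overflowing.
my_logcosh(x) = logaddexp(x, -x) - IrrationalConstants.logtwo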

    for _ in 1:(j - 1)
        z = tanh(y[idx])
-       logz = 2 * log(2 / (exp(y[idx]) + exp(-y[idx])))
+       logz = -2 * LogExpFunctions.logcosh(y[idx])
@devmotion (Member) commented:

The name logz is a bit meaningless now I guess 🙂

@penelopeysm changed the title from "Attempt to improve Cholesky numerical stability" to "Cholesky numerical stability: inverse transform" on December 1, 2024
@penelopeysm marked this pull request as ready for review on March 12, 2025
@penelopeysm (Member, Author) commented:

Sorry I dropped the ball on this; it probably slipped off my radar around Christmas. @devmotion Would you be willing to review again? I'll try to look into the forward transform PR (#357) separately.

@joelkandiah left a comment:

This looks good to me, but it may be worth someone checking whether this same replacement could or should also be used on line 417, in the rrule in src/corr.jl.

@devmotion (Member) left a comment:

Can you update the initial comment to (more clearly) reflect the current status and benchmarks of the PR? IMO it's a bit difficult to tell whether the comments and plots refer to the current version of the PR or some previous iteration.

@penelopeysm (Member, Author) commented Mar 18, 2025

@devmotion I've just updated them! I also added another commit to calculate tanh and logcosh outside of the loop, which makes it a tiny bit faster than before (at the cost of more allocations); I'm not sure which trade-off is better.
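For reference, the hoisted version has roughly this shape (a sketch combining the old code above with the stated changes, not the exact merged code):

using LinearAlgebra
using LogExpFunctions

# Uses _triu1_dim_from_length from the snippet earlier in this thread
function _inv_link_chol_lkj_sketch(y::AbstractVector)
    LinearAlgebra.require_one_based_indexing(y)
    K = _triu1_dim_from_length(length(y))
    W = similar(y, K, K)
    T = float(eltype(W))
    # Hoisted out of the loop: two extra allocations, but fewer scalar
    # transcendental calls inside the loop
    z_vec = tanh.(y)
    logcosh_vec = LogExpFunctions.logcosh.(y)
    logJ = zero(T)
    idx = 1
    @inbounds for j in 1:K
        log_remainder = zero(T)  # log of proportion of unit vector remaining
        for i in 1:(j - 1)
            W[i, j] = z_vec[idx] * exp(log_remainder)
            log_remainder -= logcosh_vec[idx]  # == log1p(-z^2) / 2
            logJ += log_remainder
            idx += 1
        end
        logJ += log_remainder
        W[j, j] = exp(log_remainder)
        for i in (j + 1):K
            W[i, j] = 0
        end
    end
    return W, logJ
end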

@yebai requested a review from devmotion on March 20, 2025
@yebai merged commit 8a525f1 (co-authored-by: David Widmann <[email protected]>) into main on March 21, 2025 (30 of 33 checks passed)
@yebai deleted the py/chol-numerical branch on March 21, 2025
@yebai (Member) commented Mar 21, 2025

Many thanks @penelopeysm, @joelkandiah and @devmotion!
