Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.15.5"
version = "0.15.6"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
42 changes: 17 additions & 25 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
CorrBijector <: Bijector

A bijector implementation of Stan's parametrization method for Correlation matrix:
https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html
https://mc-stan.org/docs/reference-manual/transforms.html#correlation-matrix-transform.section

Basically, a unconstrained strictly upper triangular matrix `y` is transformed to
a correlation matrix by following readable but not that efficient form:
Expand Down Expand Up @@ -348,13 +348,12 @@ function _inv_link_chol_lkj(Y::AbstractMatrix)
T = float(eltype(W))
logJ = zero(T)

idx = 1
@inbounds for j in 1:K
log_remainder = zero(T) # log of proportion of unit vector remaining
for i in 1:(j - 1)
z = tanh(Y[i, j])
W[i, j] = z * exp(log_remainder)
log_remainder += log1p(-z^2) / 2
log_remainder -= LogExpFunctions.logcosh(Y[i, j])
logJ += log_remainder
end
logJ += log_remainder
Expand All @@ -375,15 +374,18 @@ function _inv_link_chol_lkj(y::AbstractVector)
T = float(eltype(W))
logJ = zero(T)

z_vec = tanh.(y)
lc_vec = LogExpFunctions.logcosh.(y)

idx = 1
@inbounds for j in 1:K
log_remainder = zero(T) # log of proportion of unit vector remaining
for i in 1:(j - 1)
z = tanh(y[idx])
idx += 1
z = z_vec[idx]
W[i, j] = z * exp(log_remainder)
log_remainder += log1p(-z^2) / 2
log_remainder -= lc_vec[idx]
logJ += log_remainder
idx += 1
end
logJ += log_remainder
W[j, j] = exp(log_remainder)
Expand All @@ -405,17 +407,18 @@ function _inv_link_chol_lkj_rrule(y::AbstractVector)
logJ = zero(T)

z_vec = tanh.(y)
lc_vec = LogExpFunctions.logcosh.(y)

idx = 1
W[1, 1] = 1
@inbounds for j in 2:K
log_remainder = zero(T) # log of proportion of unit vector remaining
for i in 1:(j - 1)
z = z_vec[idx]
idx += 1
W[i, j] = z * exp(log_remainder)
log_remainder += log1p(-z^2) / 2
log_remainder -= lc_vec[idx]
logJ += log_remainder
idx += 1
end
logJ += log_remainder
W[j, j] = exp(log_remainder)
Expand Down Expand Up @@ -461,13 +464,8 @@ function _logabsdetjac_inv_corr(Y::AbstractMatrix)
K = LinearAlgebra.checksquare(Y)

result = float(zero(eltype(Y)))
for j in 2:K, i in 1:(j - 1)
@inbounds abs_y_i_j = abs(Y[i, j])
result +=
(K - i + 1) * (
IrrationalConstants.logtwo -
(abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j))
)
@inbounds for j in 2:K, i in 1:(j - 1)
result -= (K - i + 1) * LogExpFunctions.logcosh(Y[i, j])
end
return result
end
Expand All @@ -477,13 +475,8 @@ function _logabsdetjac_inv_corr(y::AbstractVector)

result = float(zero(eltype(y)))
for (i, y_i) in enumerate(y)
abs_y_i = abs(y_i)
row_idx = vec_to_triu1_row_index(i)
result +=
(K - row_idx + 1) * (
IrrationalConstants.logtwo -
(abs_y_i + LogExpFunctions.log1pexp(-2 * abs_y_i))
)
result -= (K - row_idx + 1) * LogExpFunctions.logcosh(y_i)
end
return result
end
Expand All @@ -496,10 +489,9 @@ function _logabsdetjac_inv_chol(y::AbstractVector)
@inbounds for j in 2:K
tmp = zero(result)
for _ in 1:(j - 1)
z = tanh(y[idx])
logz = log(1 - z^2)
result += logz + (tmp / 2)
tmp += logz
logcoshy = LogExpFunctions.logcosh(y[idx])
tmp -= logcoshy
result += tmp - logcoshy
idx += 1
end
end
Expand Down
Loading