Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
/docs/Manifest.toml
/test/coverage/Manifest.toml
/test/Manifest.toml
.vscode/settings.json
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"

[extensions]
AccessorsExt = "Accessors"
ChangesOfVariablesExt = "ChangesOfVariables"
InverseFunctionsExt = "InverseFunctions"
ReactantExt = "Reactant"

[compat]
Accessors = "0.1.42"
Expand All @@ -36,6 +38,7 @@ InverseFunctions = "0.1"
LinearAlgebra = "1.6"
LogExpFunctions = "0.3"
Random = "1.6"
Reactant = "0.2"
StaticArrays = "1"
julia = "1.10"

Expand Down
49 changes: 49 additions & 0 deletions ext/ReactantExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
module ReactantExt
using TransformVariables: TransformVariables, ArrayTransformation, LogJacFlag,
logjac_zero, transform_with, _ensure_float, dimension
using Reactant
using Reactant: TracedRNumber, AnyTracedRArray

RInt = Union{Int, TracedRNumber{Int}}
Base.@propagate_inbounds function TransformVariables.tv_getindex(a::AnyTracedRArray, i::RInt)
@allowscalar a[i]
end

TransformVariables._ensure_float(x::Type{T}) where {T<:TracedRNumber} = T

@noinline function TransformVariables.transform_with(
flag::TransformVariables.LogJacFlag,
transformation::TransformVariables.ArrayTransformation,
x::AnyTracedRArray,
index::T
) where {T}
(; inner_transformation, dims) = transformation
# NOTE not using index increments as that somehow breaks type inference
d = dimension(inner_transformation) # length of an element transformation
len = prod(dims) # number of elements
𝐼 = reshape(range(index; length = len, step = d), dims)
@info 𝐼
ℓa = logjac_zero(flag, _ensure_float(eltype(x)))
tmp,_,_ = transform_with(flag, inner_transformation, x, first(𝐼))
if typeof(tmp) <: Number
yℓ = similar(x, typeof(tmp), length(𝐼))
elseif typeof(tmp) <: AbstractArray
yℓ = [similar(tmp) for _ in 1:length(𝐼)]
else
throw(ArgumentError("Number and AbstractArray transformations are only supported in Reactant compilation mode"))
end
@trace for i in eachindex(𝐼)
idx = 𝐼[i]
y, ℓ, _ = transform_with(flag, inner_transformation, x, idx)
# if !isempty(y)
ℓa += ℓ
# end
@allowscalar yℓ[i] = y
i += 1
end
index′ = index + d * len
yℓ, ℓa, index′
end


end
2 changes: 1 addition & 1 deletion src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ $(SIGNATURES)

Initial value for log Jacobian calculations.
"""
logjac_zero(::LogJac, ::Type{T}) where {T<:Real} = log(one(T))
logjac_zero(::LogJac, ::Type{T}) where {T<:Number} = log(one(T))

logjac_zero(::NoLogJac, _) = NOLOGJAC

Expand Down
38 changes: 19 additions & 19 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ abstract type ScalarTransform <: AbstractTransform end

dimension(::ScalarTransform) = 1

function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index::Int)
transform(t, @inbounds x[index]), flag, index + 1
function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit surprised that you need this, since transform_with is called internally and index is an integer.

Can you explain what the actual type is that you need here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes in Reactant you can get a TracedNumber{Integer} if I loop through this and the array x is a AbstractTraced array.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifically for the @trace macro it will raise the index to a TracedRNumber{Int} so I needed to handle this.

transform(t, @inbounds tv_getindex(x, index)), flag, index + 1
end

function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index::Int)
transform_and_logjac(t, @inbounds x[index])..., index + 1
function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

transform_and_logjac(t, @inbounds tv_getindex(x, index))..., index + 1
end

function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y)
Expand All @@ -43,15 +43,15 @@ Identity ``x ↦ x``.
"""
struct Identity <: ScalarTransform end

transform(::Identity, x::Real) = x
transform(::Identity, x::Number) = x
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle I am OK with widening Real here and in other places, the intent is to exclude Complex. It unfortunate that Base does not have an intermediate type for this purpose, but we can use Number.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya I agree it is very sad.


transform_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x))
transform_and_logjac(::Identity, x::Number) = x, logjac_zero(LogJac(), typeof(x))

inverse_eltype(t::Identity, ::Type{T}) where T = T

inverse(::Identity, x::Number) = x

inverse_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x))
inverse_and_logjac(::Identity, x::Number) = x, logjac_zero(LogJac(), typeof(x))

####
#### elementary scalar transforms
Expand All @@ -64,9 +64,9 @@ Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive real
"""
struct TVExp <: ScalarTransform end

transform(::TVExp, x::Real) = exp(x)
transform(::TVExp, x::Number) = exp(x)

transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x
transform_and_logjac(t::TVExp, x::Number) = transform(t, x), x

inverse_eltype(t::TVExp, ::Type{T}) where T = _ensure_float(T)

Expand All @@ -83,9 +83,9 @@ Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1).
"""
struct TVLogistic <: ScalarTransform end

transform(::TVLogistic, x::Real) = logistic(x)
transform(::TVLogistic, x::Number) = logistic(x)

transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x)
transform_and_logjac(t::TVLogistic, x::Number) = transform(t, x), logistic_logjac(x)

inverse_eltype(t::TVLogistic, ::Type{T}) where T = _ensure_float(T)

Expand All @@ -100,13 +100,13 @@ $(TYPEDEF)

Shift transformation `x ↦ x + shift`.
"""
struct TVShift{T <: Real} <: ScalarTransform
struct TVShift{T} <: ScalarTransform
shift::T
end

transform(t::TVShift, x::Real) = x + t.shift
transform(t::TVShift, x::Number) = x + t.shift

transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x))
transform_and_logjac(t::TVShift, x::Number) = transform(t, x), logjac_zero(LogJac(), typeof(x))

inverse_eltype(t::TVShift{S}, ::Type{T}) where {S,T} = typeof(zero(_ensure_float(T)) - zero(S))

Expand All @@ -129,15 +129,15 @@ end

TVScale(scale::T) where {T} = TVScale{T}(scale)

transform(t::TVScale, x::Real) = t.scale * x
transform(t::TVScale, x::Number) = t.scale * x

transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale)
transform_and_logjac(t::TVScale{<:Real}, x::Number) = transform(t, x), log(t.scale)

inverse_eltype(t::TVScale{S}, ::Type{T}) where {S,T} = typeof(oneunit(T) / oneunit(S))

inverse(t::TVScale, x::Number) = x / t.scale

inverse_and_logjac(t::TVScale{<:Real}, x::Number) = inverse(t, x), -log(t.scale)
inverse_and_logjac(t::TVScale, x::Number) = inverse(t, x), -log(t.scale)

"""
$(TYPEDEF)
Expand All @@ -147,8 +147,8 @@ Negative transformation `x ↦ -x`.
struct TVNeg <: ScalarTransform
end

transform(::TVNeg, x::Real) = -x
transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x))
transform(::TVNeg, x::Number) = -x
transform_and_logjac(t::TVNeg, x::Number) = transform(t, x), logjac_zero(LogJac(), typeof(x))

inverse_eltype(::TVNeg, ::Type{T}) where T = typeof(-oneunit(T))
inverse(::TVNeg, x::Number) = -x
Expand Down
6 changes: 3 additions & 3 deletions src/special_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, inde
(; n) = t
T = _ensure_float(eltype(x))
log_r = zero(T)
y = Vector{T}(undef, n)
y = similar(x, T, n)
ℓ = logjac_zero(flag, T)
@inbounds for i in 1:(n - 1)
xi = x[index]
Expand Down Expand Up @@ -171,7 +171,7 @@ function transform_with(flag::LogJacFlag, t::UnitVectorNorm, x::AbstractVector,
(; n, chi_prior) = t
T = _ensure_float(eltype(x))
log_r = zero(T)
y = Vector{T}(undef, n)
y = similar(x, T, n)
copyto!(y, 1, x, index, n)
r = norm(y, 2)
__normalize!(y, r)
Expand Down Expand Up @@ -230,7 +230,7 @@ function transform_with(flag::LogJacFlag, t::UnitSimplex, x::AbstractVector, ind
T = _ensure_float(eltype(x))
ℓ = logjac_zero(flag, T)
stick = one(T)
y = Vector{T}(undef, n)
y = similar(x, T, n)
@inbounds for i in 1:n-1
xi = x[index]
index += 1
Expand Down
18 changes: 17 additions & 1 deletion src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,23 @@
### logistic and logit
###

function logistic_logjac(x::Real)

"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it actually necessary to add a proper doctoring to an internal method? Should it just be a comment?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a tendency to add docstrings to everything, so I think it is fine.

$(SIGNATURES)

When `a <: Reactant.AnyTracedRArray` this adds a `@allowscalar` annotation so that the function can
be compiled with Reactant. When `a` is not a `Reactant.AnyTracedRArray` it simply returns `a[i]`.

!!! warn
This is necessary because by default Reactant does not allow scalar indexing of arrays unless you
opt in. Note that often `Reactant` is able to raise the scalar indexing to the level of the whole
array so this operation is not necessarily slow, but there are no guarantees.
"""
Base.@propagate_inbounds function tv_getindex(a, i)
return a[i]
end

function logistic_logjac(x::Number)
mx = -abs(x)
mx - 2*log1pexp(mx)
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
Expand Down
33 changes: 33 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using TransformVariables:
import ChangesOfVariables, InverseFunctions
using Enzyme: autodiff, ReverseWithPrimal, Active, Const
using Unitful: @u_str, ustrip, uconvert
using Reactant

"are we in a CI environment (fewer iterations)"
const CIENV = get(ENV, "CI", "") == "true"
Expand Down Expand Up @@ -652,6 +653,7 @@ end
end
end


# if VERSION ≥ v"1.1"
# if CIENV
# @info "installing Zygote"
Expand Down Expand Up @@ -996,6 +998,37 @@ end
@test_throws InexactError inverse(t, fill(Complex(0, 1), 3))
end


###
### Reactant compat tests
###
@testset "Reactant.jl" begin

@testset "Scalar transforms" begin
a = 3.1
ar = ConcreteRNumber(a)
@test @jit(transform(asℝ, ar)) ≈ transform(asℝ, a)
@test @jit(transform(asℝ₊, ar)) ≈ transform(asℝ₊, a)
@test @jit(transform(asℝ₋, ar)) ≈ transform(asℝ₋, a)
@test @jit(transform(as𝕀, ar)) ≈ transform(as𝕀, a)
end

@testset "Array transforms begin" begin
tr = as(Array, as(Array, asℝ₊, 3), 3)
a = randn(dimension(tr))
ar = Reactant.to_rarray(a)

outr = @jit(transform_and_logjac(tr, ar))
out = transform_and_logjac(tr, a)
@test outr[1] ≈ out[1]
@test outr[2] ≈ out[2]

end


end


####
#### static analysis with JET
####
Expand Down
Loading