Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ abstract type ScalarTransform <: AbstractTransform end

dimension(::ScalarTransform) = 1

function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index::Int)
function transform_with(flag::NoLogJac, t::ScalarTransform, x::AbstractVector, index)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit surprised that you need this, since transform_with is called internally and index is an integer.

Can you explain what the actual type is that you need here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes in Reactant you can get a TracedNumber{Integer} if I loop through this and the array x is a AbstractTraced array.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifically for the @trace macro it will raise the index to a TracedRNumber{Int} so I needed to handle this.

transform(t, @inbounds x[index]), flag, index + 1
end

function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index::Int)
function transform_with(::LogJac, t::ScalarTransform, x::AbstractVector, index)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

transform_and_logjac(t, @inbounds x[index])..., index + 1
end

Expand All @@ -43,15 +43,15 @@ Identity ``x ↦ x``.
"""
struct Identity <: ScalarTransform end

transform(::Identity, x::Real) = x
transform(::Identity, x::Number) = x
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle I am OK with widening Real here and in other places, the intent is to exclude Complex. It unfortunate that Base does not have an intermediate type for this purpose, but we can use Number.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya I agree it is very sad.


transform_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x))
transform_and_logjac(::Identity, x::Number) = x, logjac_zero(LogJac(), typeof(x))

inverse_eltype(t::Identity, ::Type{T}) where T = T

inverse(::Identity, x::Number) = x

inverse_and_logjac(::Identity, x::Real) = x, logjac_zero(LogJac(), typeof(x))
inverse_and_logjac(::Identity, x::Number) = x, logjac_zero(LogJac(), typeof(x))

####
#### elementary scalar transforms
Expand All @@ -64,9 +64,9 @@ Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive real
"""
struct TVExp <: ScalarTransform end

transform(::TVExp, x::Real) = exp(x)
transform(::TVExp, x::Number) = exp(x)

transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x
transform_and_logjac(t::TVExp, x::Number) = transform(t, x), x

inverse_eltype(t::TVExp, ::Type{T}) where T = _ensure_float(T)

Expand All @@ -83,9 +83,9 @@ Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1).
"""
struct TVLogistic <: ScalarTransform end

transform(::TVLogistic, x::Real) = logistic(x)
transform(::TVLogistic, x::Number) = logistic(x)

transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x)
transform_and_logjac(t::TVLogistic, x::Number) = transform(t, x), logistic_logjac(x)

inverse_eltype(t::TVLogistic, ::Type{T}) where T = _ensure_float(T)

Expand All @@ -100,13 +100,13 @@ $(TYPEDEF)

Shift transformation `x ↦ x + shift`.
"""
struct TVShift{T <: Real} <: ScalarTransform
struct TVShift{T} <: ScalarTransform
shift::T
end

transform(t::TVShift, x::Real) = x + t.shift
transform(t::TVShift, x::Number) = x + t.shift

transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x))
transform_and_logjac(t::TVShift, x::Number) = transform(t, x), logjac_zero(LogJac(), typeof(x))

inverse_eltype(t::TVShift{S}, ::Type{T}) where {S,T} = typeof(zero(_ensure_float(T)) - zero(S))

Expand All @@ -129,15 +129,15 @@ end

TVScale(scale::T) where {T} = TVScale{T}(scale)

transform(t::TVScale, x::Real) = t.scale * x
transform(t::TVScale, x::Number) = t.scale * x

transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale)
transform_and_logjac(t::TVScale{<:Real}, x::Number) = transform(t, x), log(t.scale)

inverse_eltype(t::TVScale{S}, ::Type{T}) where {S,T} = typeof(oneunit(T) / oneunit(S))

inverse(t::TVScale, x::Number) = x / t.scale

inverse_and_logjac(t::TVScale{<:Real}, x::Number) = inverse(t, x), -log(t.scale)
inverse_and_logjac(t::TVScale, x::Number) = inverse(t, x), -log(t.scale)

"""
$(TYPEDEF)
Expand All @@ -147,8 +147,8 @@ Negative transformation `x ↦ -x`.
struct TVNeg <: ScalarTransform
end

transform(::TVNeg, x::Real) = -x
transform_and_logjac(t::TVNeg, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x))
transform(::TVNeg, x::Number) = -x
transform_and_logjac(t::TVNeg, x::Number) = transform(t, x), logjac_zero(LogJac(), typeof(x))

inverse_eltype(::TVNeg, ::Type{T}) where T = typeof(-oneunit(T))
inverse(::TVNeg, x::Number) = -x
Expand Down
2 changes: 1 addition & 1 deletion src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
### logistic and logit
###

function logistic_logjac(x::Real)
function logistic_logjac(x::Number)
mx = -abs(x)
mx - 2*log1pexp(mx)
end
Expand Down
Loading