Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# 0.15.11

Bijectors for ProductNamedTupleDistribution are now implemented.

`Bijectors.output_size` is now exported. This function provides information about the size of transformed variables. There are two main invocations:

- `output_size(b, input_size::Tuple)` returns the size of the output of `b`, given an input that has size `input_size`.
- `output_size(b, dist::Distribution)` returns the size of the output of `b`, given an input sampled from distribution `dist`. For most distributions this is implemented by calling `output_size(b, size(dist))`; however, ProductNamedTupleDistribution does not implement `size`, so this method is necessary.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.15.10"
version = "0.15.11"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down Expand Up @@ -67,4 +67,4 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1 change: 1 addition & 0 deletions docs/src/transforms.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ Bijectors.transformed(d::Distribution, b::Bijector)
## Utilities

```@docs
Bijectors.output_size
Bijectors.elementwise
Bijectors.isinvertible
Bijectors.isclosedform(t::Bijectors.Transform)
Expand Down
3 changes: 2 additions & 1 deletion src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ export TransformDistribution,
RadialLayer,
Coupling,
InvertibleBatchNorm,
elementwise
elementwise,
output_size

const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0")))
_debug(str) = @debug str
Expand Down
204 changes: 204 additions & 0 deletions src/bijectors/named_stacked.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""
NamedStacked{names}(transforms::NamedTuple, ranges::NamedTuple)

A bijector that contains a `NamedTuple` of bijectors. This is meant primarily for transforming
`Distributions.ProductNamedTupleDistribution` and samples from them.

The arguments `transforms` and `ranges` must be `NamedTuple`s with the same field names, and
these must also match the field names of the `ProductNamedTupleDistribution` that this bijector
corresponds to.

`ranges` specifies the index or indices in the output vector that correspond to the output of
each individual bijector in `transforms`. Its elements should be either `UnitRange`s or integers.
UnitRanges are necessary when the output of a transform is not a scalar. If the output is a
scalar then an integer should be used.

## Example

```jldoctest
julia> using Bijectors, LinearAlgebra

julia> d = Distributions.ProductNamedTupleDistribution((
a = LogNormal(),
b = InverseGamma(2, 3),
c = MvNormal(zeros(2), I),
));

julia> b = bijector(d)
Bijectors.NamedStacked{(:a, :b, :c), @NamedTuple{a::Base.Fix1{typeof(broadcast), typeof(log)}, b::Base.Fix1{typeof(broadcast), typeof(log)}, c::typeof(identity)}, @NamedTuple{a::Int64, b::Int64, c::UnitRange{Int64}}}((a = Base.Fix1{typeof(broadcast), typeof(log)}(broadcast, log), b = Base.Fix1{typeof(broadcast), typeof(log)}(broadcast, log), c = identity), (a = 1, b = 2, c = 3:4))

julia> b.transforms.a == bijector(d.dists.a)
true

julia> x = (a = 1.0, b = 2.0, c = [0.5, -0.5]);

julia> y, logjac = with_logabsdet_jacobian(b, x)
([0.0, 0.6931471805599453, 0.5, -0.5], -0.6931471805599453)
```
"""
struct NamedStacked{names,Ttrf<:NamedTuple{names},Trng<:NamedTuple{names}} <: Transform
# This should be a NamedTuple of bijectors
transforms::Ttrf
# This should be a NamedTuple of UnitRanges OR integers.
ranges::Trng

function NamedStacked{names}(
transforms::Ttrf, ranges::Trng
) where {names,Ttrf<:NamedTuple{names},Trng<:NamedTuple{names}}
return new{names,Ttrf,Trng}(transforms, ranges)
end
end

# Need to overload this or else it goes into a stack overflow between Inverse(b) and
# isinvertible(b)...
isinvertible(::NamedStacked) = true

# Base.size doesn't work on ProductNamedTupleDistribution, so we need some custom machinery
# here. This enables us to nest PNTDists within each other.
# NOTE: For the outputs of this function to be correct, `trf` MUST be equal to
# bijector(dist).
function output_size(trf::NamedStacked, ::Distributions.ProductNamedTupleDistribution)
return (sum(length, trf.ranges),)
end

@generated function bijector(
d::Distributions.ProductNamedTupleDistribution{names}
) where {names}
exprs = []
push!(exprs, :(transforms = NamedTuple()))
push!(exprs, :(ranges = NamedTuple()))
push!(exprs, :(offset = 1))
for n in names
push!(exprs, :(dist = d.dists.$n))
push!(exprs, :(trf = bijector(dist)))
push!(exprs, :(output_sz_tuple = output_size(trf, dist)))
push!(
exprs,
:(
if length(output_sz_tuple) == 0
output_range = offset
offset += 1
elseif length(output_sz_tuple) == 1
output_range = offset:(offset + only(output_sz_tuple) - 1)
offset += only(output_sz_tuple)
else
errmsg = "output size for distribution $d must not be multidimensional"
throw(ArgumentError(errmsg))
end
),
)
push!(exprs, :(transforms = merge(transforms, ($n=trf,))))
push!(exprs, :(ranges = merge(ranges, ($n=output_range,))))
end
push!(exprs, :(return NamedStacked{names}(transforms, ranges)))
return Expr(:block, exprs...)
end

@generated function transform(ns::NamedStacked{names}, x::NamedTuple{names}) where {names}
exprs = []
# Note that `names` cannot be empty as `product_distribution(NamedTuple())` errors, so
# we don't need to handle that case.
for (i, n) in enumerate(names)
if i == 1
# need a vcat in case there's only one transform and it returns a scalar -- we
# always want transform to return a vector.
push!(exprs, :(output = vcat(ns.transforms.$n(x.$n))))
else
push!(exprs, :(output = vcat(output, ns.transforms.$n(x.$n))))
end
end
push!(exprs, :(return output))
return Expr(:block, exprs...)
end

@generated function with_logabsdet_jacobian(
ns::NamedStacked{names}, x::NamedTuple{names}
) where {names}
exprs = []
# Note that `names` cannot be empty as `product_distribution(NamedTuple())` errors, so
# we don't need to handle that case.
for (i, n) in enumerate(names)
if i == 1
push!(
exprs,
quote
first_out, first_logjac = with_logabsdet_jacobian(
ns.transforms.$n, x.$n
)
output = vcat(first_out)
logjac = first_logjac
end,
)
else
push!(
exprs,
quote
next_out, next_logjac = with_logabsdet_jacobian(ns.transforms.$n, x.$n)
output = vcat(output, next_out)
logjac += next_logjac
end,
)
end
end
push!(exprs, :(return output, logjac))
return Expr(:block, exprs...)
end

@generated function transform(
nsi::Inverse{<:NamedStacked{names}}, y::AbstractVector
) where {names}
exprs = []
push!(exprs, :(output = NamedTuple()))
for (i, n) in enumerate(names)
if i == 1
push!(
exprs,
:(output = ($n=inverse(nsi.orig.transforms.$n)(y[nsi.orig.ranges.$n]),)),
)
else
push!(
exprs,
:(
output = merge(
output, ($n=inverse(nsi.orig.transforms.$n)(y[nsi.orig.ranges.$n]),)
)
),
)
end
end
push!(exprs, :(return output))
return Expr(:block, exprs...)
end

@generated function with_logabsdet_jacobian(
nsi::Inverse{<:NamedStacked{names}}, y::AbstractVector
) where {names}
exprs = []
for (i, n) in enumerate(names)
if i == 1
push!(
exprs,
quote
first_out, first_logjac = with_logabsdet_jacobian(
inverse(nsi.orig.transforms.$n), y[nsi.orig.ranges.$n]
)
output = ($n=first_out,)
logjac = first_logjac
end,
)
else
push!(
exprs,
quote
next_out, next_logjac = with_logabsdet_jacobian(
inverse(nsi.orig.transforms.$n), y[nsi.orig.ranges.$n]
)
output = merge(output, ($n=next_out,))
logjac += next_logjac
end,
)
end
end
push!(exprs, :(return output, logjac))
return Expr(:block, exprs...)
end
13 changes: 13 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ Returns the output size of `f` given the input size `sz`.
"""
output_size(f, sz) = sz
output_size(f::ComposedFunction, sz) = output_size(f.outer, output_size(f.inner, sz))
"""
output_size(f, dist::Distribution)

Returns the output size of `f` given the input distribution `dist`. This is useful when
Base.size(dist) is not defined, e.g. for `ProductNamedTupleDistribution` and in particular
is used by DynamicPPL when generating new random values for transformed distributions.

By default this just calls `output_size(f, size(dist))`, but this can be overloaded for
specific distributions.
"""
output_size(f, dist::Distribution) = output_size(f, size(dist))
output_size(f::ComposedFunction, dist::Distribution) = output_size(f, size(dist))

"""
output_length(f, len::Int)
Expand Down Expand Up @@ -300,6 +312,7 @@ end
# General
include("bijectors/composed.jl")
include("bijectors/stacked.jl")
include("bijectors/named_stacked.jl")
include("bijectors/reshape.jl")

# Specific
Expand Down
3 changes: 1 addition & 2 deletions src/transformed_distribution.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
function variateform(d::Distribution, b)
sz_in = size(d)
sz_out = output_size(b, sz_in)
sz_out = output_size(b, d)
return ArrayLikeVariate{length(sz_out)}
end

Expand Down
Loading