Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# 0.15.11

Bijectors for ProductNamedTupleDistribution are now implemented.

`Bijectors.output_size` is now exported. This function provides information about the size of transformed variables. There are two main invocations:

- `output_size(b, input_size::Tuple)` returns the size of the output of `b`, given an input that has size `input_size`.
- `output_size(b, dist::Distribution)` returns the size of the output of `b`, given an input sampled from distribution `dist`. For most distributions this is implemented by calling `output_size(b, size(dist))`; however, ProductNamedTupleDistribution does not implement `size`, so this method is necessary.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.15.10"
version = "0.15.11"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down Expand Up @@ -67,4 +67,4 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1 change: 1 addition & 0 deletions docs/src/transforms.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ Bijectors.transformed(d::Distribution, b::Bijector)
## Utilities

```@docs
Bijectors.output_size
Bijectors.elementwise
Bijectors.isinvertible
Bijectors.isclosedform(t::Bijectors.Transform)
Expand Down
3 changes: 2 additions & 1 deletion src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ export TransformDistribution,
RadialLayer,
Coupling,
InvertibleBatchNorm,
elementwise
elementwise,
output_size

const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0")))
_debug(str) = @debug str
Expand Down
204 changes: 204 additions & 0 deletions src/bijectors/named_stacked.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""
NamedStacked{names}(transforms::NamedTuple, ranges::NamedTuple)

A bijector that contains a `NamedTuple` of bijectors. This is meant primarily for transforming
`Distributions.ProductNamedTupleDistribution` and samples from them.

The arguments `transforms` and `ranges` must be `NamedTuple`s with the same field names, and
these must also match the field names of the `ProductNamedTupleDistribution` that this bijector
corresponds to.

`ranges` specifies the index or indices in the output vector that correspond to the output of
each individual bijector in `transforms`. Its elements should be either `UnitRange`s or integers.
UnitRanges are necessary when the output of a transform is not a scalar. If the output is a
scalar then an integer should be used.

## Example

```jldoctest
julia> using Bijectors, LinearAlgebra

julia> d = Distributions.ProductNamedTupleDistribution((
a = LogNormal(),
b = InverseGamma(2, 3),
c = MvNormal(zeros(2), I),
));

julia> b = bijector(d)
Bijectors.NamedStacked{(:a, :b, :c), @NamedTuple{a::Base.Fix1{typeof(broadcast), typeof(log)}, b::Base.Fix1{typeof(broadcast), typeof(log)}, c::typeof(identity)}, @NamedTuple{a::Int64, b::Int64, c::UnitRange{Int64}}}((a = Base.Fix1{typeof(broadcast), typeof(log)}(broadcast, log), b = Base.Fix1{typeof(broadcast), typeof(log)}(broadcast, log), c = identity), (a = 1, b = 2, c = 3:4))

julia> b.transforms.a == bijector(d.dists.a)
true

julia> x = (a = 1.0, b = 2.0, c = [0.5, -0.5]);

julia> y, logjac = with_logabsdet_jacobian(b, x)
([0.0, 0.6931471805599453, 0.5, -0.5], -0.6931471805599453)
```
"""
struct NamedStacked{names,Ttrf<:NamedTuple{names},Trng<:NamedTuple{names}} <: Transform
# This should be a NamedTuple of bijectors
transforms::Ttrf
# This should be a NamedTuple of UnitRanges OR integers.
ranges::Trng

function NamedStacked{names}(
transforms::Ttrf, ranges::Trng
) where {names,Ttrf<:NamedTuple{names},Trng<:NamedTuple{names}}
return new{names,Ttrf,Trng}(transforms, ranges)
end
end

# Need to overload this or else it goes into a stack overflow between Inverse(b) and
# isinvertible(b)...
isinvertible(::NamedStacked) = true

# Base.size doesn't work on ProductNamedTupleDistribution, so we need some custom machinery
# here. This enables us to nest PNTDists within each other.
# NOTE: For the outputs of this function to be correct, `trf` MUST be equal to
# bijector(dist).
function output_size(trf::NamedStacked, ::Distributions.ProductNamedTupleDistribution)
return (sum(length, trf.ranges),)
end

@generated function bijector(
d::Distributions.ProductNamedTupleDistribution{names}
) where {names}
exprs = []
push!(exprs, :(transforms = NamedTuple()))
push!(exprs, :(ranges = NamedTuple()))
push!(exprs, :(offset = 1))
for n in names
push!(exprs, :(dist = d.dists.$n))
push!(exprs, :(trf = bijector(dist)))
push!(exprs, :(output_sz_tuple = output_size(trf, dist)))
push!(
exprs,
:(
if length(output_sz_tuple) == 0
output_range = offset
offset += 1
elseif length(output_sz_tuple) == 1
output_range = offset:(offset + only(output_sz_tuple) - 1)
offset += only(output_sz_tuple)
else
errmsg = "output size for distribution $d must not be multidimensional"
throw(ArgumentError(errmsg))
end
),
)
push!(exprs, :(transforms = merge(transforms, ($n=trf,))))
push!(exprs, :(ranges = merge(ranges, ($n=output_range,))))
end
push!(exprs, :(return NamedStacked{names}(transforms, ranges)))
return Expr(:block, exprs...)
end

@generated function transform(ns::NamedStacked{names}, x::NamedTuple{names}) where {names}
exprs = []
# Note that `names` cannot be empty as `product_distribution(NamedTuple())` errors, so
# we don't need to handle that case.
for (i, n) in enumerate(names)
if i == 1
# need a vcat in case there's only one transform and it returns a scalar -- we
# always want transform to return a vector.
push!(exprs, :(output = vcat(ns.transforms.$n(x.$n))))
else
push!(exprs, :(output = vcat(output, ns.transforms.$n(x.$n))))
end
end
push!(exprs, :(return output))
return Expr(:block, exprs...)
end

@generated function with_logabsdet_jacobian(
ns::NamedStacked{names}, x::NamedTuple{names}
) where {names}
exprs = []
# Note that `names` cannot be empty as `product_distribution(NamedTuple())` errors, so
# we don't need to handle that case.
for (i, n) in enumerate(names)
if i == 1
push!(
exprs,
quote
first_out, first_logjac = with_logabsdet_jacobian(
ns.transforms.$n, x.$n
)
output = vcat(first_out)
logjac = first_logjac
end,
)
else
push!(
exprs,
quote
next_out, next_logjac = with_logabsdet_jacobian(ns.transforms.$n, x.$n)
output = vcat(output, next_out)
logjac += next_logjac
end,
)
end
end
push!(exprs, :(return output, logjac))
return Expr(:block, exprs...)
end

@generated function transform(
nsi::Inverse{<:NamedStacked{names}}, y::AbstractVector
) where {names}
exprs = []
push!(exprs, :(output = NamedTuple()))
for (i, n) in enumerate(names)
if i == 1
push!(
exprs,
:(output = ($n=inverse(nsi.orig.transforms.$n)(y[nsi.orig.ranges.$n]),)),
)
else
push!(
exprs,
:(
output = merge(
output, ($n=inverse(nsi.orig.transforms.$n)(y[nsi.orig.ranges.$n]),)
)
),
)
end
end
push!(exprs, :(return output))
return Expr(:block, exprs...)
end

@generated function with_logabsdet_jacobian(
nsi::Inverse{<:NamedStacked{names}}, y::AbstractVector
) where {names}
exprs = []
for (i, n) in enumerate(names)
if i == 1
push!(
exprs,
quote
first_out, first_logjac = with_logabsdet_jacobian(
inverse(nsi.orig.transforms.$n), y[nsi.orig.ranges.$n]
)
output = ($n=first_out,)
logjac = first_logjac
end,
)
else
push!(
exprs,
quote
next_out, next_logjac = with_logabsdet_jacobian(
inverse(nsi.orig.transforms.$n), y[nsi.orig.ranges.$n]
)
output = merge(output, ($n=next_out,))
logjac += next_logjac
end,
)
end
end
push!(exprs, :(return output, logjac))
return Expr(:block, exprs...)
end
12 changes: 12 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ Returns the output size of `f` given the input size `sz`.
"""
output_size(f, sz) = sz
output_size(f::ComposedFunction, sz) = output_size(f.outer, output_size(f.inner, sz))
"""
output_size(f, dist::Distribution)

Returns the output size of `f` given the input distribution `dist`. This is useful when
Base.size(dist) is not defined, e.g. for `ProductNamedTupleDistribution` and in particular
is used by DynamicPPL when generating new random values for transformed distributions.

By default this just calls `output_size(f, size(dist))`, but this can be overloaded for
specific distributions.
"""
output_size(f, dist::Distribution) = output_size(f, size(dist))

"""
output_length(f, len::Int)
Expand Down Expand Up @@ -300,6 +311,7 @@ end
# General
include("bijectors/composed.jl")
include("bijectors/stacked.jl")
include("bijectors/named_stacked.jl")
include("bijectors/reshape.jl")

# Specific
Expand Down
Loading