```julia
julia> using Distributions, Bijectors

julia> dist = product_distribution((a = Normal(), b = Normal()))
ProductNamedTupleDistribution{(:a, :b)}(
a: Normal{Float64}(μ=0.0, σ=1.0)
b: Normal{Float64}(μ=0.0, σ=1.0)
)

julia> bijector(dist)
ERROR: MethodError: no method matching bijector(::Distributions.ProductNamedTupleDistribution{(:a, :b), Tuple{Normal{…}, Normal{…}}, Continuous, Float64})
The function `bijector` exists, but no method is defined for this combination of argument types.
```
I think we need to add a case here, in `Bijectors.jl/src/transformed_distribution.jl` (lines 75 to 90 in fbaf783):

```julia
# Container distributions.
bijector(d::DiscreteUnivariateDistribution) = identity
bijector(d::DiscreteMultivariateDistribution) = identity
bijector(d::ContinuousUnivariateDistribution) = TruncatedBijector(minimum(d), maximum(d))
bijector(d::Product{Discrete}) = identity
function bijector(d::Product{Continuous})
    D = eltype(d.v)
    return if has_constant_bijector(D)
        elementwise(bijector(d.v[1]))
    else
        # FIXME: This is not great. Should use something like
        # `Stacked(map(bijector, d.v))` instead.
        # TODO: Specialize. F.ex. for FillArrays.jl we can do much better.
        TruncatedBijector(_minmax(d.v)...)
    end
end
```

And I think that the `Stacked` bijector already gives us most of the code we actually need, although we might need some additional code to convert the result to and from a NamedTuple.
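To make the NamedTuple question concrete, here is a minimal sketch of the alternative to flattening into a vector for `Stacked`: keep the NamedTuple and apply each component's bijector to its own field. `NamedTupleBijector`, `namedtuple_bijector`, `inverse_transform` and `logabsdetjac_nt` are hypothetical names, not Bijectors.jl API. The sketch assumes that `ProductNamedTupleDistribution` stores its components in a `dists` field, and since the wrapper is not one of Bijectors' transform types it would not plug into `transformed` or similar machinery as-is.

```julia
using Distributions, Bijectors

# Hypothetical wrapper (not part of Bijectors.jl): one bijector per NamedTuple field.
struct NamedTupleBijector{B<:NamedTuple}
    bs::B
end

# Hypothetical constructor. Assumes the component distributions live in the
# `dists` field of Distributions.ProductNamedTupleDistribution.
function namedtuple_bijector(d::Distributions.ProductNamedTupleDistribution)
    return NamedTupleBijector(map(bijector, d.dists))
end

# Forward transform, applied field by field. `map` over two NamedTuples
# throws if their field names differ, so `x` must match the distribution.
(b::NamedTupleBijector)(x::NamedTuple) = map((f, xi) -> f(xi), b.bs, x)

# Inverse transform: invert each component bijector separately.
function inverse_transform(b::NamedTupleBijector, y::NamedTuple)
    return map((f, yi) -> inverse(f)(yi), b.bs, y)
end

# Each component acts only on its own field, so the Jacobian is block-diagonal
# and the log-absolute-determinants of the components simply add up.
function logabsdetjac_nt(b::NamedTupleBijector, x::NamedTuple)
    return sum(map((f, xi) -> logabsdetjac(f, xi), b.bs, x))
end
```

For example, `namedtuple_bijector(product_distribution((a = Normal(), b = Beta(2, 2))))` would send `b` through `bijector(Beta(2, 2))` and leave `a` on the real line, returning a NamedTuple with the same keys. Keeping the NamedTuple avoids the marshalling step, at the cost of not being the flat-vector transform that `Stacked`-based code expects.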