
Support ProductNamedTupleDistribution #388

@penelopeysm

julia> using Distributions, Bijectors

julia> dist = product_distribution((a = Normal(), b = Normal()))
ProductNamedTupleDistribution{(:a, :b)}(
a: Normal{Float64}(μ=0.0, σ=1.0)
b: Normal{Float64}(μ=0.0, σ=1.0)
)

julia> bijector(dist)
ERROR: MethodError: no method matching bijector(::Distributions.ProductNamedTupleDistribution{(:a, :b), Tuple{Normal{…}, Normal{…}}, Continuous, Float64})
The function `bijector` exists, but no method is defined for this combination of argument types.

I think we need to add a case alongside the existing container-distribution methods of `bijector`:

# Container distributions.
bijector(d::DiscreteUnivariateDistribution) = identity
bijector(d::DiscreteMultivariateDistribution) = identity
bijector(d::ContinuousUnivariateDistribution) = TruncatedBijector(minimum(d), maximum(d))
bijector(d::Product{Discrete}) = identity
function bijector(d::Product{Continuous})
    D = eltype(d.v)
    return if has_constant_bijector(D)
        elementwise(bijector(d.v[1]))
    else
        # FIXME: This is not great. Should use something like
        # `Stacked(map(bijector, d.v))` instead.
        # TODO: Specialize. F.ex. for FillArrays.jl we can do much better.
        TruncatedBijector(_minmax(d.v)...)
    end
end
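One possible shape for the new case is sketched below, under assumptions. It builds one bijector per component by mapping `bijector` over the component distributions, assuming they are stored in the `dists` field of `ProductNamedTupleDistribution`, and then applies them name by name to a NamedTuple sample. `componentwise_bijectors` and `apply_componentwise` are hypothetical helpers, not part of Bijectors.jl. The result is a NamedTuple of bijectors rather than a single `Transform`, so this alone does not yet give `bijector(dist)` a proper return value.

```julia
using Distributions, Bijectors

# Hypothetical helper (not part of Bijectors.jl): one bijector per component,
# keyed by the same names. Assumes the components live in the `dists` field.
componentwise_bijectors(d::Distributions.ProductNamedTupleDistribution) =
    map(bijector, d.dists)

# Hypothetical helper: apply each bijector to the entry with the same name.
# `map` over two NamedTuples requires identical names in the same order.
apply_componentwise(bs::NamedTuple, x::NamedTuple) = map((b, xi) -> b(xi), bs, x)

dist = product_distribution((a = Normal(), b = LogNormal()))
bs = componentwise_bijectors(dist)
x = rand(dist)                  # a NamedTuple with fields :a and :b
y = apply_componentwise(bs, x)  # :a is already unconstrained, :b is mapped to log(x.b)
```

Keeping the names attached to each bijector avoids index bookkeeping, at the cost of not producing a vector-to-vector transform.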

I think `Stacked` already provides most of the code we need, although we may need extra code to convert the result to and from a NamedTuple.
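Assuming a stacked bijector such as Bijectors' `Stacked` operates on a flat vector, the NamedTuple conversion could be as small as the sketch below. It assumes every component is univariate, so each name maps to exactly one scalar; multivariate components would additionally need a length or range per name. `flatten_nt` and `unflatten_nt` are hypothetical names, and the `log` on the second slot only stands in for what a stacked transform would do.

```julia
# Hypothetical helpers, assuming every entry of the NamedTuple is a real scalar.

# NamedTuple -> Vector{Float64}, in field order.
flatten_nt(x::NamedTuple) = collect(float.(values(x)))

# Vector -> NamedTuple, reusing the original field names.
unflatten_nt(names::Tuple{Vararg{Symbol}}, v::AbstractVector) =
    NamedTuple{names}(Tuple(v))

x = (a = 0.5, b = 2.0)
v = flatten_nt(x)               # [0.5, 2.0]
# Stand-in for a stacked transform: identity on slot 1, log on slot 2.
w = [v[1], log(v[2])]
y = unflatten_nt(keys(x), w)    # a = 0.5, b = log(2.0)
```

Because the names come from `keys(x)`, the round trip `unflatten_nt(keys(x), flatten_nt(x))` recovers the original NamedTuple.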
