diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index d7abe537b0..6de36a87b3 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -157,12 +157,16 @@ def __init__( self.init_loc_fn = init_loc_fn self._init_scale = init_scale self._computing_median = False + self._computing_quantiles = False + self._quantile_values = None def get_posterior( self, name: str, prior: Distribution ) -> Union[Distribution, torch.Tensor]: if self._computing_median: return self._get_posterior_median(name, prior) + if self._computing_quantiles: + return self._get_posterior_quantiles(name, prior) with helpful_support_errors({"name": name, "fn": prior}): transform = biject_to(prior.support) @@ -205,11 +209,30 @@ def median(self, *args, **kwargs): finally: self._computing_median = False + @torch.no_grad() def _get_posterior_median(self, name, prior): transform = biject_to(prior.support) loc, scale = self._get_params(name, prior) return transform(loc) + def quantiles(self, quantiles, *args, **kwargs): + self._computing_quantiles = True + self._quantile_values = quantiles + try: + return self(*args, **kwargs) + finally: + self._computing_quantiles = False + + @torch.no_grad() + def _get_posterior_quantiles(self, name, prior): + transform = biject_to(prior.support) + loc, scale = self._get_params(name, prior) + site_quantiles = torch.tensor( + self._quantile_values, dtype=loc.dtype, device=loc.device + ) + site_quantiles_values = dist.Normal(loc, scale).icdf(site_quantiles) + return transform(site_quantiles_values) + class AutoHierarchicalNormalMessenger(AutoNormalMessenger): """ @@ -263,12 +286,16 @@ def __init__( self._init_weight = init_weight self._hierarchical_sites = hierarchical_sites self._computing_median = False + self._computing_quantiles = False + self._quantile_values = None def get_posterior( self, name: str, prior: Distribution ) -> Union[Distribution, torch.Tensor]: if self._computing_median: return self._get_posterior_median(name, prior) + if self._computing_quantiles: + return self._get_posterior_quantiles(name, prior) with helpful_support_errors({"name": name, "fn": prior}): transform = biject_to(prior.support) @@ -351,6 +378,7 @@ def median(self, *args, **kwargs): finally: self._computing_median = False + @torch.no_grad() def _get_posterior_median(self, name, prior): transform = biject_to(prior.support) if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): @@ -360,6 +388,29 @@ def _get_posterior_median(self, name, prior): loc, scale = self._get_params(name, prior) return transform(loc) + def quantiles(self, quantiles, *args, **kwargs): + self._computing_quantiles = True + self._quantile_values = quantiles + try: + return self(*args, **kwargs) + finally: + self._computing_quantiles = False + + @torch.no_grad() + def _get_posterior_quantiles(self, name, prior): + transform = biject_to(prior.support) + if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): + loc, scale, weight = self._get_params(name, prior) + loc = loc + transform.inv(prior.mean) * weight + else: + loc, scale = self._get_params(name, prior) + + site_quantiles = torch.tensor( + self._quantile_values, dtype=loc.dtype, device=loc.device + ) + site_quantiles_values = dist.Normal(loc, scale).icdf(site_quantiles) + return transform(site_quantiles_values) + class AutoRegressiveMessenger(AutoMessenger): """ diff --git a/tests/infer/test_autoguide.py b/tests/infer/test_autoguide.py index f316ee049d..d3d7a7d7e9 100644 --- a/tests/infer/test_autoguide.py +++ b/tests/infer/test_autoguide.py @@ -522,10 +522,13 @@ def AutoGuideList_x(model): AutoLowRankMultivariateNormal, AutoLaplaceApproximation, AutoGuideList_x, + AutoNormalMessenger, + AutoHierarchicalNormalMessenger, ], ) @pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO, TraceEnum_ELBO]) def test_quantiles(auto_class, Elbo): + xfail_messenger(auto_class, Elbo) def model(): pyro.sample("y", dist.LogNormal(0.0, 1.0)) pyro.sample("z", dist.Beta(2.0, 2.0).expand([2]).to_event(1))