-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Support obs_mask kwarg in sample statements #2772
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
It is nice to have a simple api for users like this. My main concern is this will draw unnecessary masked latent statement in HMC. It might not be important because those extra values does not contribute to the log density, so its trajectory is just random walk (modulo some geometric quantity like u-turn condition). Regarding the api, how does this new keyword interact with event_dim? IIUC users will need to mask all invalid event dimensions I guess. |
This interface does not support masking of events. Actually I think event masking would be nonsensical: each event should be either entirely observed or drawn jointly in entirety. It seems weird to interleave pieces of a so-called event. I think of 'event' as being an 'atomic' sample that cannot be split.
Agreed, this seems similar to the issue I mentioned with # masked -> contiguous
contiguous_sample = masked_sample[mask]
# contiguous -> masked
masked_sample = zeros(...)
masked_sample[mask] = contiguous_sampleHowever this optimization might be invalid if the mask depends on upstream sample statements. I think this deserves further discussion in a separate issue. |
Should we raise an error in this case then? I assume that users might write codes like with plate("a", A):
sample("x", dist.Normal(...).to_event(1), obs=..., obs_mask=ambiguous_shape?)
Sound like a proposal for MaskTransform... ;) |
|
@fehiepsi good point. I've fixed handling of multivariate distributions, added a shape check and informative error message, and added multivariate tests for both ok and bad shapes.
Interesting, we could even add a |
| with poutine.mask(mask=obs_mask): | ||
| observed = sample(f"{name}_observed", fn, *args, **kwargs, obs=obs) | ||
| with poutine.mask(mask=~obs_mask): | ||
| unobserved = sample(f"{name}_unobserved", fn, *args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cool so this will let downstream sample statements depend on the sampled unobserved values right (since they will appear in value together with observed)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct. And it took less than three years to figure out how to do it 😄
|
What is the expected behavior of |
The |
It seems like we'd want some mechanism for optionally disregarding the observed part, though, perhaps via |
eb8680
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
fehiepsi
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM too!
Resolves #1106
This implements
pyro.sample(..., obs=___, obs_mask=___)by transforming that one statement into a pair ofpoutine.maskedsamplestatements plus a finaldeterministicstatement. This solution is simple and non-invasive, with changes being limited to thepyro.sample()function.One disadvantage of this approach is that the new latent variable site is named
name + "_unobserved"rather than simplyname. However fixing this would require a much more invasive change. In particular it is unclear how handlers should treat a single site that is partially observed and partially sampled. I suspect this 3-site-name approach is actually optimal.I think to get this to work with
AutoGuides we may need to add masking logic in autoguides, otherwise I think some entirely-local guides may diverge to infinite entropy and eventually NAN out. However since that is really a shortcoming of handling ofpoutine.maskrather thanobs_mask, I have not addressed autoguides in this PR.Tested