@fritzo
Member

@fritzo fritzo commented Feb 24, 2021

Resolves #1106

This implements pyro.sample(..., obs=___, obs_mask=___) by transforming that one statement into a pair of poutine.masked sample statements plus a final deterministic statement. This solution is simple and non-invasive, with changes being limited to the pyro.sample() function.

One disadvantage of this approach is that the new latent variable site is named name + "_unobserved" rather than simply name. However, fixing this would require a much more invasive change. In particular, it is unclear how handlers should treat a single site that is partially observed and partially sampled. I suspect this 3-site-name approach is actually optimal.
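
A toy sketch of the three-site expansion described above, using plain Python lists in place of tensors. The helper `sample_with_obs_mask` and its `draw` argument are hypothetical names chosen for illustration. This is not Pyro's implementation: it only shows the site naming and the elementwise interleaving, and it does not model how the log density is masked.

```python
import random


def sample_with_obs_mask(name, draw, obs, obs_mask):
    """Toy stand-in for the three-site expansion (hypothetical helper).

    draw: zero-argument callable returning one sample.
    obs: list of values; entries where obs_mask is False are placeholders.
    obs_mask: list of bools, True where the value is observed.
    """
    if len(obs) != len(obs_mask):
        raise ValueError("obs and obs_mask must have the same length")
    # Site 1: the observed values. In Pyro this site would sit under
    # poutine.mask(mask=obs_mask).
    observed = list(obs)
    # Site 2: a fresh latent draw for every position. In Pyro this site
    # would sit under poutine.mask(mask=~obs_mask).
    unobserved = [draw() for _ in obs]
    # Site 3: deterministic interleaving, selecting by mask elementwise.
    value = [o if m else u for o, u, m in zip(observed, unobserved, obs_mask)]
    return {
        f"{name}_observed": observed,
        f"{name}_unobserved": unobserved,
        name: value,
    }


random.seed(0)
sites = sample_with_obs_mask(
    "x",
    draw=lambda: random.gauss(0.0, 1.0),
    obs=[1.5, 0.0, -0.3, 0.0],
    obs_mask=[True, False, True, False],
)
```

Downstream code would read `sites["x"]`, which holds the supplied values at observed positions and the drawn values everywhere else.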

To get this to work with AutoGuides, we may need to add masking logic to the autoguides; otherwise I think some entirely local guides may diverge to infinite entropy and eventually produce NaNs. However, since that is really a shortcoming of how poutine.mask is handled rather than of obs_mask, I have not addressed autoguides in this PR.

Tested

  • ELBO smoke tests in tests/infer/test_valid_models.py
  • MCMC smoke tests in tests/infer/mcmc/test_valid_models.py

@fehiepsi
Member

It is nice to have a simple API like this for users. My main concern is that this will introduce unnecessary masked latent sample statements in HMC. That might not matter, because those extra values do not contribute to the log density, so their trajectory is just a random walk (modulo geometric quantities such as the U-turn condition).

Regarding the API, how does this new keyword interact with event_dim? IIUC, users will need to mask all invalid event dimensions.

@fritzo
Member Author

fritzo commented Feb 24, 2021

how does this new keyword interact with event_dim?

This interface does not support masking of events. Actually I think event masking would be nonsensical: each event should be either entirely observed or drawn jointly in entirety. It seems weird to interleave pieces of a so-called event. I think of 'event' as being an 'atomic' sample that cannot be split.

My main concern is that this will introduce unnecessary masked latent sample statements in HMC

Agreed, this seems similar to the issue I mentioned with AutoGuide. And similarly I believe this is a more basic issue with poutine.mask, and not specific to obs_mask. I think an optimal strategy could be to transform masked sample statements to contiguous unmasked statements like

# masked -> contiguous
contiguous_sample = masked_sample[mask]

# contiguous -> masked
masked_sample = zeros(...)
masked_sample[mask] = contiguous_sample

However this optimization might be invalid if the mask depends on upstream sample statements. I think this deserves further discussion in a separate issue.
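
The round trip in the pseudocode above can be made concrete with plain Python lists standing in for tensors. The names `to_contiguous` and `to_masked` are hypothetical, and the sketch assumes a fixed mask that does not depend on upstream samples, which is exactly the case where the caveat above does not apply.

```python
def to_contiguous(masked_sample, mask):
    """Gather: keep only the entries where mask is True."""
    if len(masked_sample) != len(mask):
        raise ValueError("masked_sample and mask must have the same length")
    return [v for v, m in zip(masked_sample, mask) if m]


def to_masked(contiguous_sample, mask, fill=0.0):
    """Scatter: write the contiguous entries back to the True positions."""
    if len(contiguous_sample) != sum(mask):
        raise ValueError("contiguous_sample length must equal the number of True mask entries")
    masked_sample = [fill] * len(mask)
    values = iter(contiguous_sample)
    for i, m in enumerate(mask):
        if m:
            masked_sample[i] = next(values)
    return masked_sample


mask = [True, False, True, True, False]
full = [0.1, 0.2, 0.3, 0.4, 0.5]

packed = to_contiguous(full, mask)   # only the three selected entries
restored = to_masked(packed, mask)   # zeros at the unselected positions
```

Entries outside the mask are not recoverable after the round trip; they come back as the `fill` value, which is harmless as long as those positions never contribute to the log density.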

@fritzo fritzo added this to the 1.6 release milestone Feb 24, 2021
@fehiepsi
Member

This interface does not support masking of events

Should we raise an error in this case then? I assume that users might write code like

with plate("a", A):
    sample("x", dist.Normal(...).to_event(1), obs=..., obs_mask=ambiguous_shape?)

an optimal strategy could be to transform masked sample statements to contiguous unmasked statements

Sounds like a proposal for MaskTransform... ;)

@fritzo
Member Author

fritzo commented Feb 25, 2021

@fehiepsi good point. I've fixed handling of multivariate distributions, added a shape check and informative error message, and added multivariate tests for both ok and bad shapes.
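
One way such a shape check could look, sketched on plain shape tuples. The helper `check_obs_mask_shape` and its error messages are assumptions made for illustration, not the code or wording added in this PR. The rule assumed here is that the mask must broadcast to the distribution's batch shape and must not extend into event dimensions, since each event is treated as atomic.

```python
def check_obs_mask_shape(mask_shape, batch_shape):
    """Raise ValueError unless mask_shape broadcasts to batch_shape.

    Hypothetical helper: shapes are plain tuples of ints, aligned on the right.
    """
    if len(mask_shape) > len(batch_shape):
        raise ValueError(
            f"Invalid obs_mask shape {mask_shape}: more dims than "
            f"batch_shape {batch_shape}; events cannot be partially masked"
        )
    # Compare dimensions from the right, as in standard broadcasting.
    for m, b in zip(reversed(mask_shape), reversed(batch_shape)):
        if m != 1 and m != b:
            raise ValueError(
                f"Invalid obs_mask shape {mask_shape}: not broadcastable "
                f"to batch_shape {batch_shape}"
            )


# Example: Normal(...).to_event(1) inside a plate of size 5, with event size 3,
# so batch_shape is (5,) and event_shape is (3,).
batch_shape = (5,)

check_obs_mask_shape((5,), batch_shape)   # one flag per event: accepted
check_obs_mask_shape((1,), batch_shape)   # broadcasts over the plate: accepted

errors = []
for bad_shape in [(5, 3), (3,)]:          # per-element mask, wrong length
    try:
        check_obs_mask_shape(bad_shape, batch_shape)
    except ValueError as e:
        errors.append(str(e))
```

Under this rule, a mask shaped like the full data `(5, 3)` is rejected rather than silently reinterpreted, which addresses the ambiguous-shape case raised earlier in the thread.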

MaskTransform

Interesting, we could even add a MaskReparam 🤔

with poutine.mask(mask=obs_mask):
    observed = sample(f"{name}_observed", fn, *args, **kwargs, obs=obs)
with poutine.mask(mask=~obs_mask):
    unobserved = sample(f"{name}_unobserved", fn, *args, **kwargs)
Collaborator

Cool, so this will let downstream sample statements depend on the sampled unobserved values, right (since they will appear in value together with the observed ones)?

Member Author

Correct. And it took less than three years to figure out how to do it 😄

@eb8680
Member

eb8680 commented Feb 25, 2021

What is the expected behavior of obs_masks under Predictive?

@fritzo
Member Author

fritzo commented Feb 25, 2021

What is the expected behavior of obs_masks under Predictive?

The observed part should just be provided by the user, the unobserved part should be sampled from the posterior, and finally these two parts are interleaved and returned.

@eb8680
Member

eb8680 commented Feb 25, 2021

The observed part should just be provided by the user, the unobserved part should be sampled from the posterior, and finally these two parts are interleaved and returned.

It seems like we'd want some mechanism for optionally disregarding the observed part, though, perhaps via poutine.uncondition if not automatically in Predictive? Maybe the current version of uncondition is sufficient? At any rate, I guess we can leave higher-level interfaces to future issues/PRs.

Member

@eb8680 eb8680 left a comment

LGTM

@fritzo
Member Author

fritzo commented Feb 25, 2021

@eb8680 @fehiepsi thanks for reviewing!

@eb8680 I agree there is more to think about wrt uncondition and posterior predictive etc.

Member

@fehiepsi fehiepsi left a comment

LGTM too!

Development

Successfully merging this pull request may close these issues.

Support masked sample statement for batched imputation

5 participants