Value expectation and 1st order CKY #93
Conversation
Hmm, for some reason the tests are not running on this. Trying to figure out why.
        .sum(2)
    )
    assert torch.isclose(
        E_val, log_probs.exp().unsqueeze(-1).mul(struct_vals).sum(0)
Just curious. Why not just make this the implementation of expected value? It seems just as good and perhaps more efficient.
Sorry, maybe I'm confused but isn't this enumerating over all possible structures explicitly?
Oh sorry, my comment was confusing.
I think a valid way of computing an expectation over any "part-level value" is to first compute the marginals (.marginals), then do an elementwise multiply (.mul), then sum. Doesn't that give you the same thing as the semiring?
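For concreteness, here is a tiny pure-Python check of that identity on a toy model (made-up potentials and values, no torch-struct involved): when f decomposes additively over parts, summing marginals times part values matches the brute-force expectation.

```python
import itertools
import math

# made-up unary log-potentials and per-part values: 3 positions x 2 states
phi = [[0.5, -0.2], [1.0, 0.3], [-0.4, 0.8]]
vals = [[2.0, 1.0], [0.0, 3.0], [1.5, -1.0]]

structs = list(itertools.product([0, 1], repeat=3))
score = lambda z: sum(phi[i][s] for i, s in enumerate(z))
logZ = math.log(sum(math.exp(score(z)) for z in structs))

# brute force: E[f(z)] = sum_z p(z) * f(z), with f(z) a sum of part values
E_brute = sum(
    math.exp(score(z) - logZ) * sum(vals[i][s] for i, s in enumerate(z))
    for z in structs
)

# marginals route: mu[i][s] = P(z_i = s), then E[f] = <mu, vals>
mu = [[0.0, 0.0] for _ in range(3)]
for z in structs:
    p = math.exp(score(z) - logZ)
    for i, s in enumerate(z):
        mu[i][s] += p
E_marg = sum(mu[i][s] * vals[i][s] for i in range(3) for s in (0, 1))

assert abs(E_brute - E_marg) < 1e-9
```

The two routes agree exactly because the value decomposes additively; linearity of expectation pushes the sum over structures inside each part.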
Oh wow, I didn't realize this! I just tested it out and it appears to be more efficient for larger structure sizes. I guess this is due to the fast log semiring implementation? I'll update things to use this approach instead.
Yeah, I think that is right... I haven't thought about this too much, but my guess is that this is just better on GPU hardware since the expectation is batched at the end. But it seems worth understanding when this works. I don't think you can compute Entropy this way? (but I might be wrong)
Makes sense. I also don't think entropy can be done this way -- I just tested it out and the results didn't match the semiring. I will switch to this implementation in the latest commit and get rid of the value semiring.
Fwiw I ran a quick speed comparison you might be interested in:

```python
B, N, C = 4, 200, 10
phis = torch.randn(B, N, C, C).cuda()
vals = torch.randn(B, N, C, C, 10).cuda()
```

Results from running w/ genbmm:

```python
%%timeit
LinearChainCRF(phis).expected_value(vals)
# >>> 100 loops, best of 3: 6.34 ms per loop

%%timeit
LinearChainCRF(phis).marginals.unsqueeze(-1).mul(vals).reshape(B, -1, vals.shape[-1]).sum(1)
# >>> 100 loops, best of 3: 5.64 ms per loop
```

Results from running w/o genbmm:

```python
%%timeit
LinearChainCRF(phis).expected_value(vals)
# >>> 100 loops, best of 3: 9.67 ms per loop

%%timeit
LinearChainCRF(phis).marginals.unsqueeze(-1).mul(vals).reshape(B, -1, vals.shape[-1]).sum(1)
# >>> 100 loops, best of 3: 8.83 ms per loop
```
torch_struct/distributions.py
Outdated
    """
    Compute expected value for distribution :math:`E_z[f(z)]` where f decomposes additively over the factors of p_z.

    Params:
This should be "Parameters:"
torch_struct/distributions.py
Outdated
    Compute expected value for distribution :math:`E_z[f(z)]` where f decomposes additively over the factors of p_z.

    Params:
        * values (*batch_shape x *event_shape, *value_shape): torch.FloatTensor that assigns a value to each part
Let's put the types in the first parens, and use :class:`torch.FloatTensor`.
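A sketch of what the suggested docstring format might look like (the signature and wording here are illustrative, not the actual torch-struct code):

```python
def expected_value(self, values):
    """
    Compute expected value for distribution :math:`E_z[f(z)]` where f
    decomposes additively over the factors of p_z.

    Parameters:
        values (:class:`torch.FloatTensor`): tensor of shape
            (*batch_shape x *event_shape x *value_shape) that assigns a
            value to each part.
    """
```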
    samples = []
    for k in range(nsamples):
        if k % 10 == 0:
        if k % batch_size == 0:
Oh yeah, sorry this is my fault. 10 is a global constant. Let's put it on MultiSampledSemiring.
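A minimal sketch of that refactor (attribute name assumed from the thread, not the actual torch-struct code): hoist the literal onto the semiring class so the sampling loop reads it from one place.

```python
class MultiSampledSemiring:
    # hypothetical home for the former magic constant 10
    batch_size = 10

def sample_batch_starts(nsamples):
    # indices at which the sampling loop would start a new batch
    return [k for k in range(nsamples) if k % MultiSampledSemiring.batch_size == 0]
```

For example, `sample_batch_starts(25)` yields `[0, 10, 20]`; changing the batching granularity then means touching only the class attribute.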
torch_struct/distributions.py
Outdated
    Implementation uses width-batched, forward-pass only

    * Parallel Time: :math:`O(N)` parallel merges.
    * Forward Memory: :math:`O(N^2)`
This can't be right... isn't the event shape O(N^3) alone?
Oops, yeah, that's left over from modifying the CKYCRF class.
torch_struct/full_cky_crf.py
Outdated
@@ -0,0 +1,114 @@
import torch
from .helpers import _Struct, Chart
from tqdm import tqdm
Be sure to run python setup.py style to run flake8. It will catch these errors.
torch_struct/helpers.py
Outdated
    Returns:
        v (torch.Tensor): the resulting output of the dynamic program
        edges (List[torch.Tensor]): the log edge potentials of the model.
Changing this to logpotentials throughout.
torch_struct/helpers.py
Outdated
    [scores], as in `Alignment`, `LinearChain`, `SemiMarkov`, `CKY_CRF`.
    An exceptional case is the `CKY` struct, which takes log potential parameters from production rules
    for a PCFG, which are by definition independent of position in the sequence.
    charts: Optional[List[Chart]] = None, the charts used in computing the dp. They are needed if we want to run the
Going to remove this for simplicity.
    for k in range(v.shape[0]):
        obj = v[k].sum(dim=0)

    with torch.autograd.enable_grad():  # in case input potentials don't have grads enabled.
torch_struct/semirings/semirings.py
Outdated
    return xs


def ValueExpectationSemiring(k):
Are you sure we don't have this already? Could have sworn someone added it.
I'm not 100% sure. I looked and didn't see it anywhere in master, so I went ahead with it. Maybe it's in another branch? There's the entropy semiring, which is very similar.
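For reference, the core of a first-order expectation semiring (the same family as the entropy semiring) can be sketched in a few lines of plain Python. Elements are pairs (p, p·v); addition is elementwise, and multiplication follows the product rule. This is a toy illustration on a made-up two-position model, not the torch-struct implementation.

```python
import itertools

# first-order expectation semiring: elements are pairs (p, p * v)
def plus(a, b):
    return (a[0] + b[0], a[1] + b[1])

def times(a, b):
    # product rule keeps the second slot equal to (total weight) * (weighted value sum)
    return (a[0] * b[0], a[0] * b[1] + a[1] * b[0])

def lift(p, v):
    # inject one part with weight p and value v
    return (p, p * v)

# toy model: 2 independent positions, 2 states each (made-up numbers)
w = [[0.6, 0.4], [0.2, 0.8]]   # part weights
v = [[1.0, 3.0], [2.0, 0.0]]   # part values

total = (1.0, 0.0)             # multiplicative identity
for i in range(2):
    step = (0.0, 0.0)          # additive identity
    for s in range(2):
        step = plus(step, lift(w[i][s], v[i][s]))
    total = times(total, step)

# total = (Z, sum_z w(z) f(z)); the expectation is the ratio of the slots
Z = sum(w[0][a] * w[1][b] for a, b in itertools.product([0, 1], repeat=2))
Ef = sum(w[0][a] * w[1][b] * (v[0][a] + v[1][b])
         for a, b in itertools.product([0, 1], repeat=2))
assert abs(total[0] - Z) < 1e-12
assert abs(total[1] - Ef) < 1e-12
```

Running any dynamic program in this semiring accumulates both the partition function and the weighted value sum in one pass, which is exactly what a value-expectation semiring needs to do.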
Thanks for the PR. Lots of nice stuff in here.

Quick dev question: when I try running

Interesting, yeah, not sure how to run those automatically. I will look into it.
Changes are: