Add alignment CRF test. Fix missing fill_()#109
Add alignment CRF test. Fix missing fill_()#109JohnReid wants to merge 14 commits intoharvardnlp:masterfrom
Conversation
|
I realised I hadn't tested many of the distribution properties. I've tried tests for few more but it looks like there are at least two more issues to resolve. |
torch_struct/alignment.py
Outdated
| charta[1][:, b, point:, 1, ind, :, :, Mid] = semiring.one_( | ||
| charta[1][:, b, point:, 1, ind, :, :, Mid] | ||
| ) | ||
| charta[1][:, b, point:, 1, ind, :, :, Mid] = charta[1][:, b, point:, 1, ind, :, :, Mid].fill_(0) |
There was a problem hiding this comment.
Unfortunately this is not going to work.
We need to call
init = torch.zeros(charta[1].shape).bool()
init[:, b, point:, 1, ind, :, :, Mid].fill_(True)
charta[1] = semiring.fill(charta[1], init, semiring.one)
There was a problem hiding this comment.
(this should fix your other issues too)
There was a problem hiding this comment.
Great, thanks for that. I have to admit I had just copied the code from the one_() method before it was removed in #105. My assumption was that it was the correct code.
There was a problem hiding this comment.
I'm still facing a few issues though. I fixed a few of them in the commits below but some remain. The main sticking point seems to be that the BandedMatrixs are not correctly dispatched to multiply rather than matmul() in semirings.py. The matmul implementation only works for standard tensors. This affects dist.entropy, dist.sample(), dist.topk() but not the partition, argmax, marginals.
I tried to fix this rather naively by overloading the classmethod matmul in some of the semirings but this broke the existing tests. I backed that out and am trying to understand how the code relates to the description in the torch struct paper so that I can make the correct fix.
PR following up discussion here.
For the tests to pass I also had to update
genbmm. See PR here.Note that the tests only check the shape of the
argmaxandmarginals. The values are not checked.