@@ -66,10 +66,22 @@ def logpartition(self, arc_scores_in, lengths=None, force_grad=False):
6666 ]
6767 for _ in range (2 )
6868 ]
69- semiring .one_ (alpha [A ][C ][L ].data [:, :, :, 0 ].data )
70- semiring .one_ (alpha [A ][C ][R ].data [:, :, :, 0 ].data )
71- semiring .one_ (alpha [B ][C ][L ].data [:, :, :, - 1 ].data )
72- semiring .one_ (alpha [B ][C ][R ].data [:, :, :, - 1 ].data )
69+ mask = torch .zeros (alpha [A ][C ][L ].data .shape ).bool ()
70+ mask [:, :, :, 0 ].fill_ (True )
71+ alpha [A ][C ][L ].data [:] = semiring .fill (
72+ alpha [A ][C ][L ].data [:], mask , semiring .one
73+ )
74+ alpha [A ][C ][R ].data [:] = semiring .fill (
75+ alpha [A ][C ][R ].data [:], mask , semiring .one
76+ )
77+ mask = torch .zeros (alpha [B ][C ][L ].data [:].shape ).bool ()
78+ mask [:, :, :, - 1 ].fill_ (True )
79+ alpha [B ][C ][L ].data [:] = semiring .fill (
80+ alpha [B ][C ][L ].data [:], mask , semiring .one
81+ )
82+ alpha [B ][C ][R ].data [:] = semiring .fill (
83+ alpha [B ][C ][R ].data [:], mask , semiring .one
84+ )
7385
7486 if multiroot :
7587 start_idx = 0
@@ -119,10 +131,13 @@ def _check_potentials(self, arc_scores, lengths=None):
119131 lengths = torch .LongTensor ([N - 1 ] * batch ).to (arc_scores .device )
120132 assert max (lengths ) <= N , "Length longer than N"
121133 arc_scores = semiring .convert (arc_scores )
122- for b in range (batch ):
123- semiring .zero_ (arc_scores [:, b , lengths [b ] + 1 :, :])
124- semiring .zero_ (arc_scores [:, b , :, lengths [b ] + 1 :])
125134
135+ # Set the extra elements of the log-potentials to zero.
136+ keep = torch .ones_like (arc_scores ).bool ()
137+ for b in range (batch ):
138+ keep [:, b , lengths [b ] + 1 :, :].fill_ (0.0 )
139+ keep [:, b , :, lengths [b ] + 1 :].fill_ (0.0 )
140+ arc_scores = semiring .fill (arc_scores , ~ keep , semiring .zero )
126141 return arc_scores , batch , N , lengths
127142
128143 def _arrange_marginals (self , grads ):
0 commit comments