Skip to content

Commit c9b6af1

Browse files
authored
Merge pull request #524 from shenoynikhil/attn_bug_fix
Bug fix in bias repeat when multiplicity > 1
2 parents 0a3d03d + 170746f commit c9b6af1

2 files changed

Lines changed: 4 additions & 2 deletions

File tree

src/boltz/model/modules/transformers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ def forward(
299299
c = c.view((B * NW, W, -1))
300300
if mask is not None:
301301
mask = mask.view(B * NW, W)
302+
p = p.repeat_interleave(multiplicity, 0)
302303
p = p.view((p.shape[0] * NW, W, H, -1))
303304

304305
to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1)
@@ -311,7 +312,7 @@ def forward(
311312
s=c,
312313
z=p,
313314
mask=mask.float(),
314-
multiplicity=multiplicity,
315+
multiplicity=1, # bias term already expanded with multiplicity
315316
to_keys=to_keys_new,
316317
model_cache=model_cache,
317318
)

src/boltz/model/modules/transformersv2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def forward(
243243
q = q.view((B * NW, W, -1))
244244
c = c.view((B * NW, W, -1))
245245
mask = mask.view(B * NW, W)
246+
bias = bias.repeat_interleave(multiplicity, 0)
246247
bias = bias.view((bias.shape[0] * NW, W, H, -1))
247248

248249
to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1)
@@ -253,7 +254,7 @@ def forward(
253254
s=c,
254255
bias=bias,
255256
mask=mask.float(),
256-
multiplicity=multiplicity,
257+
multiplicity=1, # bias term already expanded with multiplicity
257258
to_keys=to_keys_new,
258259
)
259260

0 commit comments

Comments
 (0)