Skip to content

Commit 64ae3b1

Browse files
wszczurekhabanaLiangyx2
authored andcommitted
Fix dtype issue with valid sequence length in torch.compile bs=1 (huggingface#1532)
1 parent 52f4cf4 commit 64ae3b1

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

examples/text-generation/run_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def generate(size=None, reduce_recompile=False):
462462

463463
def compute_valid_sequence_lengths_tensor(input_tokens):
464464
attn_mask = input_tokens["attention_mask"]
465-
return torch.sum(attn_mask, dim=1)
465+
return torch.sum(attn_mask, dim=1, dtype=torch.int32)
466466

467467
valid_sequence_lengths = compute_valid_sequence_lengths_tensor(input_tokens).to(args.device)
468468
generation_config.valid_sequence_lengths = valid_sequence_lengths

0 commit comments

Comments
 (0)