Commit 99d4e53

Fix dtype issue with valid sequence length in torch.compile bs=1 (#1532)
1 parent 1553e83 commit 99d4e53

1 file changed

Lines changed: 1 addition & 1 deletion

examples/text-generation/run_generation.py

@@ -462,7 +462,7 @@ def generate(size=None, reduce_recompile=False):
 
     def compute_valid_sequence_lengths_tensor(input_tokens):
         attn_mask = input_tokens["attention_mask"]
-        return torch.sum(attn_mask, dim=1)
+        return torch.sum(attn_mask, dim=1, dtype=torch.int32)
 
     valid_sequence_lengths = compute_valid_sequence_lengths_tensor(input_tokens).to(args.device)
     generation_config.valid_sequence_lengths = valid_sequence_lengths
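
The effect of the one-line change can be checked in isolation. The sketch below is illustrative and not taken from the commit: the `attention_mask` values are made up, and it assumes the mask arrives as an int64 tensor, which is what tokenizers typically produce. It shows that `torch.sum` on an integer tensor returns int64 by default, while passing `dtype=torch.int32` pins the result to int32.

```python
import torch

# Hypothetical attention mask for a batch of two padded prompts
# (1 = real token, 0 = padding). The int64 dtype is an assumption
# about what the tokenizer returns.
attention_mask = torch.tensor(
    [[1, 1, 1, 0, 0],
     [1, 1, 1, 1, 1]],
    dtype=torch.int64,
)

# Before the change: summing an integer tensor yields int64 by default.
lengths_before = torch.sum(attention_mask, dim=1)

# After the change: the result dtype is requested explicitly as int32.
lengths_after = torch.sum(attention_mask, dim=1, dtype=torch.int32)

print(lengths_before.dtype, lengths_after.dtype)
print(lengths_after.tolist())
```

The values are the same in both cases; only the dtype differs. Per the commit title, the int32 result is what resolves the dtype issue under `torch.compile` with batch size 1.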
