
Commit c7a92e4

Generation fixes (#2787)
1 parent 45326e3 commit c7a92e4

2 files changed: +52 -21 lines changed


tests/torchtune/generation/test_generation.py

Lines changed: 8 additions & 5 deletions
@@ -406,6 +406,7 @@ def test_stop_tokens_batched_uneven_stopping(
         model = request.getfixturevalue(model)
         temperature = 0.6
         top_k = 100
+        pad_id = -100

         stop_tokens = [3991, 3987, 3969]

@@ -418,13 +419,14 @@ def test_stop_tokens_batched_uneven_stopping(
             temperature=temperature,
             top_k=top_k,
             stop_tokens=stop_tokens,
+            pad_id=pad_id,
         )

         expected_output = torch.tensor(
             [
                 [2, 3, 4, 5, 6, 7, 8, 9, 3954, 3920, 3991],
-                [2, 3, 4, 5, 6, 7, 8, 9, 3983, 3987, 0],
-                [2, 3, 4, 5, 6, 7, 8, 9, 3969, 0, 0],
+                [2, 3, 4, 5, 6, 7, 8, 9, 3983, 3987, pad_id],
+                [2, 3, 4, 5, 6, 7, 8, 9, 3969, pad_id, pad_id],
             ]
         )

@@ -445,7 +447,7 @@ def test_stop_tokens_batched_uneven_stopping_left_padded(
         model = request.getfixturevalue(model)
         temperature = 0.6
         top_k = 100
-
+        pad_id = -100
         # Updated stop tokens to match the new generated tokens
         stop_tokens = [3991, 3987, 3969]

@@ -458,13 +460,14 @@ def test_stop_tokens_batched_uneven_stopping_left_padded(
             temperature=temperature,
             top_k=top_k,
             stop_tokens=stop_tokens,
+            pad_id=pad_id,
         )

         expected_output = torch.tensor(
             [
                 [0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 3954, 3920, 3991],
-                [0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 3983, 3987, 0],
-                [0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 3969, 0, 0],
+                [0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 3983, 3987, pad_id],
+                [0, 0, 2, 3, 4, 5, 6, 7, 8, 9, 3969, pad_id, pad_id],
             ]
         )
         assert torch.equal(outputs, expected_output)
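
For reference, a small self-contained sketch (not part of the commit) of the padding behaviour these updated expectations encode: once a sequence in the batch hits one of its stop tokens, the remaining positions are filled with pad_id (here -100) rather than 0. The tensor names mirror the test data but the exact values past the stop tokens are invented for illustration.

# Illustration only -- mirrors the batched uneven-stopping case asserted above.
import torch

pad_id = -100
generated = torch.tensor(
    [
        [2, 3, 4, 5, 6, 7, 8, 9, 3983, 3987, 1234],  # stopped at 3987, one extra step generated
        [2, 3, 4, 5, 6, 7, 8, 9, 3969, 1234, 1234],  # stopped at 3969, two extra steps generated
    ]
)
# 1 = position kept, 0 = generated after the sequence already hit a stop token
stop_token_mask = torch.tensor(
    [
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
    ]
)
generated.masked_fill_(~stop_token_mask.bool(), pad_id)
# generated now ends in [..., 3987, -100] and [..., 3969, -100, -100],
# matching the pad_id entries in the expected_output tensors above.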

torchtune/generation/_generation.py

Lines changed: 44 additions & 16 deletions
@@ -206,7 +206,7 @@ def generate(
     top_k: Optional[int] = None,
     stop_tokens: Optional[list[int]] = None,
     rng: Optional[torch.Generator] = None,
-    custom_generate_next_token: Optional[Callable] = None,
+    compiled_generate_next_token: Optional[Callable] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Generates tokens from a model conditioned on a prompt, and also returns logits for the generations.
@@ -223,23 +223,48 @@ def generate(
         stop_tokens (Optional[list[int]]): If specified, generation is stopped when any of these tokens are generated,
             default None.
         rng (Optional[torch.Generator]): random number generator, default None.
-        custom_generate_next_token (Optional[Callable]): If specified, we'll use the
-            ``custom_generate_next_token function``. This is generally only useful if
-            you want to specify a ``torch.compile`` version of the generate next token for
-            performance reasons. If None, we use the default :func:`generate_next_token`.
+        compiled_generate_next_token (Optional[Callable]): This argument is typically a reference to a compiled version of
+            the :func:`generate_next_token` function. During autoregressive decoding, this function is called instead of the default
+            :func:`generate_next_token` in order to accelerate generation. :func:`generate_next_token` will still be used for the
+            first token generation - or "pre-fill" pass.
             Default is None.

     Note:
         This function has only been tested with decoder-only models.

     Examples:
-        >>> model = torchtune.models.llama3.llama3_8b()
-        >>> tokenizer = torchtune.models.llama3.llama3_tokenizer()
-        >>> prompt = tokenizer.encode("Hi my name is")
-        >>> rng.manual_seed(42)
-        >>> output, logits = generate(model, torch.tensor(prompt), max_generated_tokens=100, pad_id=0)
+        >>> import torch
+        >>> from torchtune.models.llama3 import llama3_tokenizer
+        >>> from torchtune.models.llama3 import llama3_8b
+        >>> from torchtune.generation import generate
+        >>> from torchtune.training.checkpointing import FullModelHFCheckpointer
+        >>> from torchtune.data import Message
+
+        >>> model = llama3_8b().cuda()
+
+        >>> checkpointer = FullModelHFCheckpointer(
+        ...     checkpoint_dir="/tmp/Meta-Llama-3-8B-Instruct",
+        ...     checkpoint_files=[
+        ...         "model-00001-of-00004.safetensors",
+        ...         "model-00002-of-00004.safetensors",
+        ...         "model-00003-of-00004.safetensors",
+        ...         "model-00004-of-00004.safetensors",
+        ...     ],
+        ...     model_type="LLAMA3",
+        ...     output_dir="/tmp/torchtune/llama3_8b",
+        ... )
+        >>> checkpoint = checkpointer.load_checkpoint()
+        >>> model.load_state_dict(checkpoint["model"])
+
+        >>> tokenizer = llama3_tokenizer("/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model")
+        >>> messages = [
+        ...     Message(role="assistant", content="Hi my name is"),
+        ... ]
+        >>> prompt = tokenizer({"messages": messages}, inference=True)
+        >>> output, logits = generate(model, torch.tensor(prompt["tokens"], device='cuda'), max_generated_tokens=100, pad_id=0)
         >>> print(tokenizer.decode(output[0].tolist()))
-        Hi my name is Jeremy and I'm a friendly language model assistant!
+
+        >>> Hi my name is Marley. Nice to meet you, Marley! How are you doing today?... [truncated]

     Returns:
         tuple[torch.Tensor, torch.Tensor]: tuple of two tensors:
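
As a usage note on the renamed parameter: the docstring above says compiled_generate_next_token is typically a compiled version of generate_next_token. A minimal sketch of how a caller might wire that up is below; it assumes generate_next_token is importable from torchtune.generation and that model and prompt have been set up as in the docstring example. It is illustrative, not taken from the commit.

import torch
from torchtune.generation import generate, generate_next_token

# Compile the per-step decode function once; the uncompiled generate_next_token
# is still used internally for the prefill pass over the prompt.
compiled_next_token = torch.compile(generate_next_token)

output, logits = generate(
    model,   # assumed: a decoder-only model prepared as in the example above
    prompt,  # assumed: a 1D or 2D tensor of prompt token ids
    max_generated_tokens=100,
    pad_id=0,
    compiled_generate_next_token=compiled_next_token,
)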
@@ -251,9 +276,6 @@ def generate(
     """
     prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt

-    if custom_generate_next_token is None:
-        custom_generate_next_token = generate_next_token
-
     bsz, prompt_length = prompt.size()
     total_response_length = prompt_length + max_generated_tokens

@@ -356,6 +378,12 @@ def generate(
         if stop_token_reached.all().item():
             return generated_tokens, generated_logits

+    next_token_fn = (
+        compiled_generate_next_token
+        if compiled_generate_next_token is not None
+        else generate_next_token
+    )
+
     for _ in range(max_generated_tokens - 1):
         # update stop_token_mask if we reached a stop token in a previous step
         # by appending the logical not of stop_token_reached to the end of the mask
@@ -387,7 +415,7 @@ def generate(
             condition = uniform_val >= 1.0 - epsilon
             q = -torch.where(condition, -epsilon, torch.log(uniform_val))

-        tokens, logits = custom_generate_next_token(
+        tokens, logits = next_token_fn(
             model,
             input_pos=curr_input_pos,
             x=tokens.clone(),
@@ -409,7 +437,7 @@ def generate(

     # mask out generated tokens in seqs that already hit a stop token
     if stop_tokens is not None:
-        generated_tokens *= stop_token_mask
+        generated_tokens.masked_fill_(~stop_token_mask.bool(), pad_id)
         generated_logits *= stop_token_mask[:, -generated_logits.shape[1] :, None]

     return generated_tokens, generated_logits
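
This last hunk is the behavioural core of the fix: the old in-place multiply zeroed out every position after a stop token, which is why the tests previously expected literal 0s, whereas masked_fill_ writes the caller's pad_id there instead. A tiny standalone contrast with invented token values, not taken from the repo:

import torch

pad_id = -100
tokens = torch.tensor([[7, 7, 3987, 5, 5]])        # kept generating after stop token 3987
stop_token_mask = torch.tensor([[1, 1, 1, 0, 0]])  # 0 where the sequence had already stopped

old_behaviour = tokens * stop_token_mask                                       # -> [[7, 7, 3987, 0, 0]]
new_behaviour = tokens.clone().masked_fill_(~stop_token_mask.bool(), pad_id)   # -> [[7, 7, 3987, -100, -100]]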
