@@ -206,7 +206,7 @@ def generate(
     top_k: Optional[int] = None,
     stop_tokens: Optional[list[int]] = None,
     rng: Optional[torch.Generator] = None,
-    custom_generate_next_token: Optional[Callable] = None,
+    compiled_generate_next_token: Optional[Callable] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Generates tokens from a model conditioned on a prompt, and also returns logits for the generations.
@@ -223,23 +223,48 @@ def generate(
         stop_tokens (Optional[list[int]]): If specified, generation is stopped when any of these tokens are generated,
             default None.
         rng (Optional[torch.Generator]): random number generator, default None.
-        custom_generate_next_token (Optional[Callable]): If specified, we'll use the
-            ``custom_generate_next_token function``. This is generally only useful if
-            you want to specify a ``torch.compile`` version of the generate next token for
-            performance reasons. If None, we use the default :func:`generate_next_token`.
+        compiled_generate_next_token (Optional[Callable]): This argument is typically a reference to a compiled
+            version of :func:`generate_next_token`. During autoregressive decoding it is called in place of the
+            default :func:`generate_next_token` to accelerate generation. :func:`generate_next_token` is still
+            used for the first token generation (the "pre-fill" pass).
             Default is None.
 
     Note:
         This function has only been tested with decoder-only models.
 
     Examples:
-        >>> model = torchtune.models.llama3.llama3_8b()
-        >>> tokenizer = torchtune.models.llama3.llama3_tokenizer()
-        >>> prompt = tokenizer.encode("Hi my name is")
-        >>> rng.manual_seed(42)
-        >>> output, logits = generate(model, torch.tensor(prompt), max_generated_tokens=100, pad_id=0)
+        >>> import torch
+        >>> from torchtune.models.llama3 import llama3_tokenizer
+        >>> from torchtune.models.llama3 import llama3_8b
+        >>> from torchtune.generation import generate
+        >>> from torchtune.training.checkpointing import FullModelHFCheckpointer
+        >>> from torchtune.data import Message
+
+        >>> model = llama3_8b().cuda()
+
+        >>> checkpointer = FullModelHFCheckpointer(
+        ...     checkpoint_dir="/tmp/Meta-Llama-3-8B-Instruct",
+        ...     checkpoint_files=[
+        ...         "model-00001-of-00004.safetensors",
+        ...         "model-00002-of-00004.safetensors",
+        ...         "model-00003-of-00004.safetensors",
+        ...         "model-00004-of-00004.safetensors",
+        ...     ],
+        ...     model_type="LLAMA3",
+        ...     output_dir="/tmp/torchtune/llama3_8b",
+        ... )
+        >>> checkpoint = checkpointer.load_checkpoint()
+        >>> model.load_state_dict(checkpoint["model"])
+
+        >>> tokenizer = llama3_tokenizer("/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model")
+        >>> messages = [
+        ...     Message(role="assistant", content="Hi my name is"),
+        ... ]
+        >>> prompt = tokenizer({"messages": messages}, inference=True)
+        >>> output, logits = generate(model, torch.tensor(prompt["tokens"], device='cuda'), max_generated_tokens=100, pad_id=0)
         >>> print(tokenizer.decode(output[0].tolist()))
-        Hi my name is Jeremy and I'm a friendly language model assistant!
+
+        Hi my name is Marley. Nice to meet you, Marley! How are you doing today?... [truncated]
 
     Returns:
         tuple[torch.Tensor, torch.Tensor]: tuple of two tensors:
@@ -251,9 +276,6 @@ def generate(
     """
     prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt
 
-    if custom_generate_next_token is None:
-        custom_generate_next_token = generate_next_token
-
     bsz, prompt_length = prompt.size()
     total_response_length = prompt_length + max_generated_tokens
 
@@ -356,6 +378,12 @@ def generate(
         if stop_token_reached.all().item():
             return generated_tokens, generated_logits
 
+    next_token_fn = (
+        compiled_generate_next_token
+        if compiled_generate_next_token is not None
+        else generate_next_token
+    )
+
     for _ in range(max_generated_tokens - 1):
         # update stop_token_mask if we reached a stop token in a previous step
         # by appending the logical not of stop_token_reached to the end of the mask
@@ -387,7 +415,7 @@ def generate(
             condition = uniform_val >= 1.0 - epsilon
             q = -torch.where(condition, -epsilon, torch.log(uniform_val))
 
-        tokens, logits = custom_generate_next_token(
+        tokens, logits = next_token_fn(
             model,
             input_pos=curr_input_pos,
             x=tokens.clone(),
@@ -409,7 +437,7 @@ def generate(
 
     # mask out generated tokens in seqs that already hit a stop token
     if stop_tokens is not None:
-        generated_tokens *= stop_token_mask
+        generated_tokens.masked_fill_(~stop_token_mask.bool(), pad_id)
         generated_logits *= stop_token_mask[:, -generated_logits.shape[1] :, None]
 
     return generated_tokens, generated_logits
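
A minimal sketch of how a caller might use the renamed ``compiled_generate_next_token`` argument, assuming ``generate_next_token`` is exported from ``torchtune.generation`` alongside ``generate`` and reusing the ``model`` and ``prompt`` objects from the docstring example above:

import torch

from torchtune.generation import generate, generate_next_token

# Compile the per-token decode step once; per the docstring, the uncompiled
# generate_next_token is still used for the prefill pass over the prompt.
compiled_step = torch.compile(generate_next_token, mode="reduce-overhead")

output, logits = generate(
    model,  # assumed: the loaded Llama3 model from the example above
    torch.tensor(prompt["tokens"], device="cuda"),  # assumed: tokenized prompt from the example above
    max_generated_tokens=100,
    pad_id=0,
    compiled_generate_next_token=compiled_step,
)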
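
The last hunk changes how sequences that already hit a stop token are masked: multiplying by ``stop_token_mask`` forced every trailing position to token id 0, whereas ``masked_fill_`` writes ``pad_id`` there. A small toy comparison, using a made-up pad id:

import torch

pad_id = 42  # hypothetical pad token id; any non-zero id shows the difference
generated_tokens = torch.tensor([[5, 9, 7, 3]])
stop_token_mask = torch.tensor([[1, 1, 0, 0]])  # 0 marks positions after a stop token

# Old behaviour: multiplication zeroes out masked positions (always token id 0).
old = generated_tokens * stop_token_mask                              # tensor([[5, 9, 0, 0]])

# New behaviour: masked positions are filled with pad_id instead.
new = generated_tokens.masked_fill(~stop_token_mask.bool(), pad_id)  # tensor([[5, 9, 42, 42]])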