Commit 727a952

Minor multi gpu doc improvement (#6649)
minor multi gpu doc improvement
1 parent 14d9afb commit 727a952

File tree

1 file changed: +4 −3 lines changed

docs/source/process.mdx

Lines changed: 4 additions & 3 deletions
```diff
@@ -360,7 +360,7 @@ The [`~Dataset.map`] also works with the rank of the process if you set `with_ra
 >>> dataset = load_dataset("fka/awesome-chatgpt-prompts", split="train")
 >>>
 >>> # Get an example model and its tokenizer
->>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
+>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B-Chat").eval()
 >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
 >>>
 >>> def gpu_computation(batch, rank):
@@ -378,8 +378,9 @@ The [`~Dataset.map`] also works with the rank of the process if you set `with_ra
 ...         tokenize=False,
 ...         add_generation_prompt=True
 ...     ) for chat in chats]
-...     model_inputs = tokenizer(texts, return_tensors="pt").to(device)
-...     outputs = model.generate(**model_inputs, max_new_tokens=512)
+...     model_inputs = tokenizer(texts, padding=True, return_tensors="pt").to(device)
+...     with torch.no_grad():
+...         outputs = model.generate(**model_inputs, max_new_tokens=512)
 ...     batch["output"] = tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ...     return batch
 >>>
```
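
The commit makes the doc example more robust: `.eval()` and `torch.no_grad()` disable dropout and gradient tracking during generation, and `padding=True` lets prompts of different lengths be batched into a single tensor. For context, the patched `gpu_computation` belongs to the multi-GPU example in `process.mdx`, where it is dispatched across GPUs with [`~Dataset.map`]. A minimal sketch of that driving call follows; the `batch_size` value is an illustrative assumption, not part of this commit:

```python
# Sketch: dispatch gpu_computation across GPUs with Dataset.map.
# Assumes `dataset`, `model`, `tokenizer`, and `gpu_computation`
# are defined as in the snippet above.
from multiprocess import set_start_method

import torch

if __name__ == "__main__":
    # CUDA requires the "spawn" start method when using multiprocessing
    set_start_method("spawn")
    updated_dataset = dataset.map(
        gpu_computation,
        batched=True,
        batch_size=16,                       # illustrative value
        with_rank=True,                      # pass each worker's rank to gpu_computation
        num_proc=torch.cuda.device_count(),  # one worker process per GPU
    )
```

With `with_rank=True`, each worker receives its process rank as the second argument, which `gpu_computation` maps to a CUDA device so the workers spread evenly across the available GPUs.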
