Commit 6b347df

incorporating the review
1 parent a65a9d6 commit 6b347df

4 files changed: 6 additions, 6 deletions

examples/trl/README.md

Lines changed: 2 additions & 2 deletions
@@ -41,7 +41,7 @@ PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 grpo.py \
 --num_generations 4 \
 --max_completion_length 64 \
 --use_peft True \
---lora_target_modules q_proj, k_proj \
+--lora_target_modules q_proj k_proj \
 --num_train_epochs 1 \
 --save_strategy="epoch"
 ```
@@ -66,7 +66,7 @@ PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 ../gaudi_spawn.py --wo
 --num_generations 4 \
 --max_completion_length 64 \
 --use_peft True \
---lora_target_modules q_proj, k_proj \
+--lora_target_modules q_proj k_proj \
 --max_steps=500 \
 --logging_steps=10 \
 --save_steps=100
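
The comma matters because the shell splits `q_proj, k_proj` into the two tokens `q_proj,` and `k_proj`. As a minimal sketch, assuming `--lora_target_modules` is parsed as a space-separated list (modeled here with plain `argparse` and `nargs="+"`, which is an assumption about the script rather than something this diff shows), the old form leaves a stray comma in the first module name:

```python
import argparse


def parse_modules(argv):
    """Parse a list-valued flag the way a space-separated CLI option would."""
    parser = argparse.ArgumentParser()
    # Assumption: the flag takes one or more space-separated values.
    parser.add_argument("--lora_target_modules", nargs="+", default=None)
    return parser.parse_args(argv).lora_target_modules


# Tokens as the shell would pass them for the old and new README commands.
old = parse_modules(["--lora_target_modules", "q_proj,", "k_proj"])
new = parse_modules(["--lora_target_modules", "q_proj", "k_proj"])

print(old)  # first entry keeps the trailing comma
print(new)
```

With the comma, the first entry is `q_proj,`, which would not match a layer named `q_proj`.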

examples/trl/grpo.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ class ScriptArguments:
     model_name_or_path: Optional[str] = field(default="Qwen/Qwen2-0.5B-Instruct", metadata={"help": "the model name"})
     dataset_name: Optional[str] = field(default=None, metadata={"help": "the dataset name"})
     use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"})
-    num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
+    num_workers: Optional[int] = field(default=1, metadata={"help": "the number of workers"})
     subset: Optional[str] = field(default=None, metadata={"help": "the subset to use"})
     streaming: Optional[bool] = field(default=False, metadata={"help": "whether to stream the dataset"})
     dataset_train_split: str = field(default="train[:5%]", metadata={"help": "Dataset split to use for training."})

examples/trl/requirements_grpo.txt

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 trl == 0.17.0
 peft == 0.12.0
-datasets == 3.0.0
+datasets
 tyro
 evaluate
 scikit-learn == 1.5.2

optimum/habana/trl/trainer/grpo_trainer.py

Lines changed: 2 additions & 2 deletions
@@ -49,8 +49,8 @@
     selective_log_softmax,
 )

-from optimum.habana.transformers import trainer as habana_trainer
-from optimum.habana.transformers.trainer import _get_input_update_settings
+from ...transformers import trainer as habana_trainer
+from ...transformers.trainer import _get_input_update_settings
 from optimum.utils import logging

 from ... import GaudiConfig, GaudiTrainer
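
The relative imports should resolve to the same modules as the absolute ones they replace. A quick check with the standard library's `importlib.util.resolve_name`, assuming the file's package is `optimum.habana.trl.trainer` (inferred from the path `optimum/habana/trl/trainer/grpo_trainer.py`, not stated in the diff):

```python
from importlib.util import resolve_name

# Assumption: package name inferred from the file path in this commit.
package = "optimum.habana.trl.trainer"

# Three leading dots climb two levels above the current package.
resolved_pkg = resolve_name("...transformers", package)
resolved_mod = resolve_name("...transformers.trainer", package)

print(resolved_pkg)
print(resolved_mod)
```

This only resolves names as strings; it does not import `optimum`, so it runs without the library installed.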
