
Commit 24b288f

Merge pull request #16 from ROCmSoftwarePlatform/adabeyta_update_hf_training

Removed hardcoded warmup steps.

2 parents 8097220 + 8cc1f10

File tree: 2 files changed (+3, −3 lines changed)


src/transformers/trainer.py (2 additions, 2 deletions)

@@ -1876,8 +1876,8 @@ def _inner_training_loop(

             metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)

-            total_samples = args.max_steps*total_train_batch_size if args.max_steps > 0 else num_examples*num_train_epochs
-            perf_samples = total_samples - 10*total_train_batch_size
+            total_samples = self.state.global_step*total_train_batch_size if args.max_steps > 0 else num_examples*num_train_epochs
+            perf_samples = total_samples - self.args.warmup_steps*total_train_batch_size
             stable_train_metrics = speed_metrics("stable_train", start_train_stable_time, perf_samples)

             self.store_flos()
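The change above stops assuming exactly 10 warmup steps when computing "stable" training throughput: the warmup portion is now sized by `args.warmup_steps`, and the sample count comes from the steps actually run (`self.state.global_step`). A minimal standalone sketch of that computation, using assumed example values and a hypothetical stand-in for the `speed_metrics` helper:

```python
def speed_metrics(split, runtime_seconds, num_samples):
    # Hypothetical stand-in for transformers' speed_metrics helper:
    # reports samples processed per second for the given split.
    return {f"{split}_samples_per_second": round(num_samples / runtime_seconds, 3)}

# Assumed example values, mirroring the variables in the hunk above.
global_step = 100            # steps actually run (self.state.global_step)
total_train_batch_size = 32  # effective global batch size
warmup_steps = 10            # args.warmup_steps (previously hardcoded as 10)

total_samples = global_step * total_train_batch_size
# Exclude the warmup samples so "stable_train" throughput reflects steady state only.
perf_samples = total_samples - warmup_steps * total_train_batch_size

metrics = speed_metrics("stable_train", runtime_seconds=60.0, num_samples=perf_samples)
print(metrics)  # {'stable_train_samples_per_second': 48.0}
```

With a different `--warmup_steps` value, the old hardcoded `10` would have under- or over-subtracted warmup samples; tying the subtraction to `args.warmup_steps` keeps the stable-throughput number consistent with the actual schedule.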

src/transformers/training_args.py (1 addition, 1 deletion)

@@ -568,7 +568,7 @@ class TrainingArguments:
     warmup_ratio: float = field(
         default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
     )
-    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
+    warmup_steps: int = field(default=10, metadata={"help": "Linear warmup over warmup_steps."})

     log_level: Optional[str] = field(
         default="passive",
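For context, `TrainingArguments` declares its settings with the standard `dataclasses` field pattern, where `default` sets the value and `metadata["help"]` supplies the CLI help text. A minimal sketch using a hypothetical `Args` class in place of the real `TrainingArguments`:

```python
from dataclasses import dataclass, field

@dataclass
class Args:
    # Hypothetical miniature of TrainingArguments, showing the field/metadata
    # pattern the hunk above edits; only the default value changed (0 -> 10).
    warmup_steps: int = field(
        default=10, metadata={"help": "Linear warmup over warmup_steps."}
    )

args = Args()
print(args.warmup_steps)  # 10
```

Changing the default here to 10 keeps the trainer's `perf_samples` subtraction meaningful even when the user does not pass `--warmup_steps` explicitly.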

0 commit comments