add WarmupStepLR #308

Merged
HiHippie merged 11 commits into main from dev_WarmupStepLR on Jun 29, 2022
Conversation

@HiHippie
Contributor

@HiHippie HiHippie commented Jun 23, 2022

This PR adds WarmupStepLR support to LiBai and fixes an error in the test_warmup_exponential test code.

  • WarmupStepLR implementation
  • unittest code
  • Testing (trained ViT for a number of iterations; the WarmupStepLR learning-rate decay behaves as expected)
  • Fix for the error in the test_warmup_exponential test code

Auto-scaling the config to train.train_iter=100, train.warmup_iter=2, step_size=10

[06/27 08:45:41 lb.engine.trainer]: Starting training from iteration 0
[06/27 08:45:43 lb.utils.events]:  iteration: 0/100  consumed_samples: 16  total_loss: 2.344  data_time: 0.2230 s/iter  lr: 0.00e+00  
[06/27 08:45:43 lb.utils.events]:  eta: 0:00:42  iteration: 1/100  consumed_samples: 32  total_loss: 2.373  data_time: 0.2223 s/iter  lr: 5.00e-04  
[06/27 08:45:44 lb.utils.events]:  eta: 0:00:26  iteration: 2/100  consumed_samples: 48  total_loss: 2.401  time: 0.2740 s/iter  data_time: 0.2175 s/iter total_throughput: 58.38 samples/s lr: 1.00e-03  
[06/27 08:45:44 lb.utils.events]:  eta: 0:00:26  iteration: 3/100  consumed_samples: 64  total_loss: 2.706  time: 0.2739 s/iter  data_time: 0.2169 s/iter total_throughput: 58.41 samples/s lr: 1.00e-03  
[06/27 08:45:44 lb.utils.events]:  eta: 0:00:26  iteration: 4/100  consumed_samples: 80  total_loss: 2.693  time: 0.2773 s/iter  data_time: 0.2180 s/iter total_throughput: 57.70 samples/s lr: 1.00e-03  
[06/27 08:45:44 lb.utils.events]:  eta: 0:00:26  iteration: 5/100  consumed_samples: 96  total_loss: 2.648  time: 0.2784 s/iter  data_time: 0.2162 s/iter total_throughput: 57.47 samples/s lr: 1.00e-03  
[06/27 08:45:45 lb.utils.events]:  eta: 0:00:26  iteration: 6/100  consumed_samples: 112  total_loss: 2.602  time: 0.2788 s/iter  data_time: 0.2170 s/iter total_throughput: 57.39 samples/s lr: 1.00e-03  
[06/27 08:45:45 lb.utils.events]:  eta: 0:00:25  iteration: 7/100  consumed_samples: 128  total_loss: 2.57  time: 0.2788 s/iter  data_time: 0.2174 s/iter total_throughput: 57.38 samples/s lr: 1.00e-03  
[06/27 08:45:45 lb.utils.events]:  eta: 0:00:25  iteration: 8/100  consumed_samples: 144  total_loss: 2.538  time: 0.2749 s/iter  data_time: 0.2142 s/iter total_throughput: 58.21 samples/s lr: 1.00e-03  
[06/27 08:45:46 lb.evaluation.evaluator]: with eval_iter 10, reset total samples 10000 to 2560
[06/27 08:45:46 lb.evaluation.evaluator]: Start inference on 2560 samples
[06/27 08:46:02 lb.evaluation.evaluator]: Total valid samples: 2560
[06/27 08:46:02 lb.evaluation.evaluator]: Total inference time: 0:00:07.512320 (0.002940 s / iter per device, on 4 devices)
[06/27 08:46:02 lb.evaluation.evaluator]: Total inference pure compute time: 0:00:01 (0.000398 s / iter per device, on 4 devices)
[06/27 08:46:02 lb.engine.default]: Evaluation results for CIFAR10Dataset in csv format:
[06/27 08:46:02 lb.evaluation.utils]: copypaste: Acc@1=14.6875
[06/27 08:46:02 lb.evaluation.utils]: copypaste: Acc@5=49.84375
[06/27 08:46:02 lb.engine.hooks]: Saved first model at 14.68750 @ 9 steps
[06/27 08:46:03 lb.utils.checkpoint]: Saving checkpoint to output_unittest/test_vit/model_best
[06/27 08:46:09 lb.utils.events]:  eta: 0:00:25  iteration: 9/100  consumed_samples: 160  total_loss: 2.536  time: 0.2766 s/iter  data_time: 0.2146 s/iter total_throughput: 57.84 samples/s lr: 1.00e-03  
[06/27 08:46:09 lb.utils.events]:  eta: 0:00:24  iteration: 10/100  consumed_samples: 176  total_loss: 2.534  time: 0.2692 s/iter  data_time: 0.2082 s/iter total_throughput: 59.43 samples/s lr: 1.00e-04  
[06/27 08:46:10 lb.utils.events]:  eta: 0:00:24  iteration: 11/100  consumed_samples: 192  total_loss: 2.467  time: 0.2686 s/iter  data_time: 0.2083 s/iter total_throughput: 59.57 samples/s lr: 1.00e-04  
[06/27 08:46:10 lb.utils.events]:  eta: 0:00:24  iteration: 12/100  consumed_samples: 208  total_loss: 2.428  time: 0.2693 s/iter  data_time: 0.2085 s/iter total_throughput: 59.42 samples/s lr: 1.00e-04  
[06/27 08:46:10 lb.utils.events]:  eta: 0:00:23  iteration: 13/100  consumed_samples: 224  total_loss: 2.414  time: 0.2665 s/iter  data_time: 0.2071 s/iter total_throughput: 60.03 samples/s lr: 1.00e-04  
[06/27 08:46:10 lb.utils.events]:  eta: 0:00:23  iteration: 14/100  consumed_samples: 240  total_loss: 2.428  time: 0.2654 s/iter  data_time: 0.2069 s/iter total_throughput: 60.28 samples/s lr: 1.00e-04  
[06/27 08:46:11 lb.utils.events]:  eta: 0:00:23  iteration: 15/100  consumed_samples: 256  total_loss: 2.414  time: 0.2641 s/iter  data_time: 0.2061 s/iter total_throughput: 60.59 samples/s lr: 1.00e-04  
[06/27 08:46:11 lb.utils.events]:  eta: 0:00:22  iteration: 16/100  consumed_samples: 272  total_loss: 2.401  time: 0.2621 s/iter  data_time: 0.2048 s/iter total_throughput: 61.05 samples/s lr: 1.00e-04  
[06/27 08:46:11 lb.utils.events]:  eta: 0:00:22  iteration: 17/100  consumed_samples: 288  total_loss: 2.389  time: 0.2597 s/iter  data_time: 0.2030 s/iter total_throughput: 61.62 samples/s lr: 1.00e-04  
[06/27 08:46:11 lb.utils.events]:  eta: 0:00:21  iteration: 18/100  consumed_samples: 304  total_loss: 2.378  time: 0.2581 s/iter  data_time: 0.2020 s/iter total_throughput: 62.00 samples/s lr: 1.00e-04  
[06/27 08:46:12 lb.evaluation.evaluator]: with eval_iter 10, reset total samples 10000 to 2560
[06/27 08:46:12 lb.evaluation.evaluator]: Start inference on 2560 samples
[06/27 08:46:28 lb.evaluation.evaluator]: Total valid samples: 2560
[06/27 08:46:28 lb.evaluation.evaluator]: Total inference time: 0:00:07.433129 (0.002909 s / iter per device, on 4 devices)
[06/27 08:46:28 lb.evaluation.evaluator]: Total inference pure compute time: 0:00:00 (0.000322 s / iter per device, on 4 devices)
[06/27 08:46:28 lb.engine.default]: Evaluation results for CIFAR10Dataset in csv format:
[06/27 08:46:28 lb.evaluation.utils]: copypaste: Acc@1=14.4921875
[06/27 08:46:28 lb.evaluation.utils]: copypaste: Acc@5=48.515625
[06/27 08:46:28 lb.engine.hooks]: Not saving as latest eval score for Acc@1 is 14.49219, not better than best score 14.68750 @ iteration 9.
[06/27 08:46:28 lb.utils.events]:  eta: 0:00:20  iteration: 19/100  consumed_samples: 320  total_loss: 2.37  time: 0.2574 s/iter  data_time: 0.2016 s/iter total_throughput: 62.16 samples/s lr: 1.00e-04  
[06/27 08:46:29 lb.utils.events]:  eta: 0:00:19  iteration: 20/100  consumed_samples: 336  total_loss: 2.362  time: 0.2561 s/iter  data_time: 0.1996 s/iter total_throughput: 62.48 samples/s lr: 1.00e-05  
[06/27 08:46:29 lb.utils.events]:  eta: 0:00:19  iteration: 21/100  consumed_samples: 352  total_loss: 2.37  time: 0.2556 s/iter  data_time: 0.1984 s/iter total_throughput: 62.59 samples/s lr: 1.00e-05  
[06/27 08:46:29 lb.utils.events]:  eta: 0:00:19  iteration: 22/100  consumed_samples: 368  total_loss: 2.362  time: 0.2548 s/iter  data_time: 0.1975 s/iter total_throughput: 62.79 samples/s lr: 1.00e-05  
[06/27 08:46:29 lb.utils.events]:  eta: 0:00:18  iteration: 23/100  consumed_samples: 384  total_loss: 2.362  time: 0.2542 s/iter  data_time: 0.1964 s/iter total_throughput: 62.95 samples/s lr: 1.00e-05  
[06/27 08:46:30 lb.utils.events]:  eta: 0:00:18  iteration: 24/100  consumed_samples: 400  total_loss: 2.362  time: 0.2542 s/iter  data_time: 0.1955 s/iter total_throughput: 62.95 samples/s lr: 1.00e-05  
[06/27 08:46:30 lb.utils.events]:  eta: 0:00:18  iteration: 25/100  consumed_samples: 416  total_loss: 2.364  time: 0.2546 s/iter  data_time: 0.1956 s/iter total_throughput: 62.84 samples/s lr: 1.00e-05  
[06/27 08:46:30 lb.utils.events]:  eta: 0:00:18  iteration: 26/100  consumed_samples: 432  total_loss: 2.362  time: 0.2542 s/iter  data_time: 0.1944 s/iter total_throughput: 62.93 samples/s lr: 1.00e-05  
[06/27 08:46:30 lb.utils.events]:  eta: 0:00:17  iteration: 27/100  consumed_samples: 448  total_loss: 2.362  time: 0.2538 s/iter  data_time: 0.1930 s/iter total_throughput: 63.04 samples/s lr: 1.00e-05  
[06/27 08:46:31 lb.utils.events]:  eta: 0:00:17  iteration: 28/100  consumed_samples: 464  total_loss: 2.362  time: 0.2537 s/iter  data_time: 0.1935 s/iter total_throughput: 63.08 samples/s lr: 1.00e-05 
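The log above shows the expected shape of the schedule: a 2-iteration linear warmup, then a 10× decay every `step_size=10` iterations. A minimal sketch of such a schedule (a hypothetical reconstruction, not LiBai's actual implementation) that reproduces the logged lr values:

```python
def warmup_step_lr(iter_idx, base_lr=1e-3, warmup_iters=2, step_size=10, gamma=0.1):
    """Hypothetical WarmupStepLR sketch: linear warmup from 0 to base_lr
    over warmup_iters, then decay by gamma every step_size iterations
    (counted from iteration 0)."""
    if iter_idx < warmup_iters:
        return base_lr * iter_idx / warmup_iters  # linear warmup starting at 0
    return base_lr * gamma ** (iter_idx // step_size)

# Matches the logged values: 0 at iter 0, 5.00e-04 at iter 1,
# 1.00e-03 for iters 2-9, 1.00e-04 from iter 10, 1.00e-05 from iter 20.
```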

@HiHippie HiHippie requested review from CPFLAME and Ldpe2G June 27, 2022 08:54
@Ldpe2G Ldpe2G requested a review from oneflow-ci-bot June 28, 2022 06:58
```diff
 def _get_exponential_lr(base_lr, gamma, max_iters, warmup_iters):
     valid_values = []
-    for idx in range(max_iters - warmup_iters):
+    for idx in range(warmup_iters, max_iters+1):
```
Contributor Author

I changed test_warmup_exponential because while testing I found a bug here that made the test fail.

The change at line 136 is because `max_iters - warmup_iters` makes the indices start from 0, but 0-5 is the warmup interval, which makes assertEqual fail.

I haven't worked with unittest before, so please take a look — I'm not sure this change matches the intent of the test.
@CPFLAME @Ldpe2G
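The index fix can be illustrated with toy numbers (hypothetical values, purely for illustration): the original range starts at 0 and overlaps the warmup interval, while the corrected range enumerates only the post-warmup iterations:

```python
max_iters, warmup_iters = 10, 5

# Before the fix: indices start at 0, which falls inside the warmup
# interval [0, warmup_iters), so decay values get compared to warmup lrs.
old_indices = list(range(max_iters - warmup_iters))

# After the fix: indices cover only the post-warmup iterations.
new_indices = list(range(warmup_iters, max_iters + 1))

print(old_indices)  # [0, 1, 2, 3, 4]
print(new_indices)  # [5, 6, 7, 8, 9, 10]
```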

Collaborator

Looks correct to me — this collects the lrs from warmup_iters up to max_iter.

```diff
     sched.step()
     lrs.append(opt.param_groups[0]["lr"])
-    self.assertTrue(np.allclose(lrs[:5], [0.005, 1.004, 2.003, 3.002, 4.001]))
+    self.assertTrue(np.allclose(lrs[:5], [0.005, 0.00401, 0.0030199999999999997, 0.00203, 0.0010399999999999997]))
```
Contributor Author

This assertion fails because the lr here actually decreases during warmup.

Collaborator

The lr decreases during warmup? Is that expected behavior?

Contributor Author

@HiHippie HiHippie Jun 29, 2022


It should be expected behavior: when end_factor < start_factor, the lr decreases during warmup.

This code computes end_factor from the lr that test_warmup_exponential has already updated, which makes it smaller.

https://github.com/Oneflow-Inc/oneflow/blob/b17a9cd6b930b5817c63623fb682bd708377a93b/python/oneflow/nn/optimizer/warmup_lr.py#L149-L158
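A minimal sketch of a linear warmup multiplier (a hypothetical formula mirroring the linked warmup_lr.py logic only in spirit, not its actual source) shows why the lr can fall during warmup when end_factor < start_factor:

```python
def linear_warmup_factor(step, warmup_iters, start_factor, end_factor):
    """Hypothetical linear-warmup lr multiplier: interpolates from
    start_factor to end_factor over warmup_iters steps."""
    if step >= warmup_iters:
        return end_factor
    return start_factor + (end_factor - start_factor) * step / warmup_iters

# With end_factor < start_factor the multiplier strictly decreases
# across the warmup interval:
factors = [linear_warmup_factor(s, 5, 1.0, 0.1) for s in range(5)]
assert all(a > b for a, b in zip(factors, factors[1:]))
```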

Collaborator

Got it.

@HiHippie HiHippie requested review from oneflow-ci-bot and removed request for oneflow-ci-bot June 29, 2022 07:13
@HiHippie HiHippie requested review from oneflow-ci-bot and removed request for oneflow-ci-bot June 29, 2022 07:29
@HiHippie HiHippie requested review from oneflow-ci-bot and removed request for oneflow-ci-bot June 29, 2022 07:58
```diff
     """
     polynomial_lr = flow.optim.lr_scheduler.PolynomialLR(
-        optimizer, steps=max_iter, end_learning_rate=end_learning_rate, power=power, cycle=cycle
+        optimizer, decay_batch=max_iter, end_learning_rate=end_learning_rate, power=power, cycle=cycle
```
Contributor Author

My oneflow version is not the latest, so the tests passed locally, but while preparing to merge I found that this keyword argument had changed.
See PR Oneflow-Inc/models#349.

So I updated it here as well.
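For reference, polynomial decay parameterized by a `decay_batch`-style argument can be sketched as follows (a hedged reconstruction of the standard formula, not oneflow's actual source):

```python
import math

def polynomial_lr(base_lr, step, decay_batch, end_lr=0.0, power=1.0, cycle=False):
    """Hypothetical polynomial-decay sketch: lr falls from base_lr to
    end_lr over decay_batch steps following (1 - t)**power."""
    if cycle:
        # Stretch the decay window to the next multiple of decay_batch,
        # so the lr decays, jumps back up, and decays again.
        decay_batch = decay_batch * max(1, math.ceil(step / decay_batch))
    else:
        # Clamp: past decay_batch the lr stays at end_lr.
        step = min(step, decay_batch)
    return (base_lr - end_lr) * (1 - step / decay_batch) ** power + end_lr
```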

@HiHippie HiHippie requested review from oneflow-ci-bot and removed request for oneflow-ci-bot June 29, 2022 08:03
@HiHippie HiHippie merged commit 9f433f0 into main Jun 29, 2022
@HiHippie HiHippie deleted the dev_WarmupStepLR branch June 29, 2022 08:11