-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[AMP] add state_dict and load_state_dict and unittest for class GradScaler #34300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
c615f3d
1dce33a
f917dff
eb96c24
e2f855a
b74be1b
81aa57b
5085547
c039e7e
5b699d3
97453ed
f85b0e4
d171365
42561fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -237,6 +237,37 @@ def test_get_and_set(self): | |
| scaler.set_init_loss_scaling(100) | ||
| self.assertEqual(scaler.get_init_loss_scaling() == 100, True) | ||
|
|
||
| def test_state_dict_and_load_state_dict(self): | ||
| with fluid.dygraph.guard(): | ||
| scaler1 = paddle.amp.GradScaler( | ||
| enable=True, | ||
| init_loss_scaling=14, | ||
| incr_ratio=233.0, | ||
| decr_ratio=0.523, | ||
| incr_every_n_steps=1090, | ||
| decr_every_n_nan_or_inf=20, | ||
| use_dynamic_loss_scaling=True) | ||
| scaler_state = scaler1.state_dict() | ||
| scaler2 = paddle.amp.GradScaler(enable=True) | ||
| scaler2.load_state_dict(scaler_state) | ||
| self.assertEqual(scaler2.get_init_loss_scaling() == 14, True) | ||
| self.assertEqual(scaler2.get_incr_ratio() == 233.0, True) | ||
| self.assertEqual(scaler2.get_decr_ratio() == 0.523, True) | ||
| self.assertEqual(scaler2.get_incr_every_n_steps() == 1090, True) | ||
| self.assertEqual(scaler2.get_decr_every_n_nan_or_inf() == 20, True) | ||
|
|
||
| scaler3 = paddle.amp.GradScaler(enable=False) | ||
| scaler3.load_state_dict(scaler_state) | ||
| self.assertEqual(scaler3.is_enable() == False, True) | ||
|
|
||
| def test_state_dict_and_load_state_dict_error(self): | ||
| def test_error(): | ||
| state_empty = {} | ||
| scaler = paddle.amp.GradScaler(enable=True) | ||
| scaler.load_state_dict(state_empty) | ||
|
|
||
| self.assertRaises(RuntimeError, test_error) | ||
|
|
||
|
|
||
| def reader_decorator(reader): | ||
| def __reader__(): | ||
|
|
@@ -248,6 +279,114 @@ def __reader__(): | |
| return __reader__ | ||
|
|
||
|
|
||
| class TestGradScalerStateDict(unittest.TestCase): | ||
| def train_resnet(self, | ||
| enable_amp=True, | ||
| use_data_loader=True, | ||
| use_save_load=True): | ||
| seed = 90 | ||
|
|
||
| EPOCH_NUM = 4 # 设置外层循环次数 | ||
|
|
||
| batch_size = train_parameters["batch_size"] | ||
| batch_num = 1 | ||
|
|
||
| paddle.seed(seed) | ||
| paddle.framework.random._manual_program_seed(seed) | ||
|
|
||
| resnet = ResNet(use_cudnn=True) | ||
| optimizer = optimizer_setting( | ||
| train_parameters, parameter_list=resnet.parameters()) | ||
| np.random.seed(seed) | ||
| train_reader = paddle.batch( | ||
| paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size) | ||
|
|
||
| dy_param_init_value = {} | ||
| for param in resnet.parameters(): | ||
| dy_param_init_value[param.name] = param.numpy() | ||
|
|
||
| program = None | ||
| scaler = paddle.amp.GradScaler( | ||
| enable=enable_amp, init_loss_scaling=2.**10) | ||
|
|
||
| if use_data_loader: | ||
| train_reader = paddle.batch( | ||
| reader_decorator(paddle.dataset.flowers.train(use_xmap=False)), | ||
| batch_size=batch_size, | ||
| drop_last=True) | ||
| train_loader = fluid.io.DataLoader.from_generator( | ||
| capacity=4, | ||
| use_double_buffer=True, | ||
| iterable=True, | ||
| return_list=True) | ||
| train_loader.set_sample_list_generator(train_reader) | ||
| train_reader = train_loader | ||
|
|
||
| for epoch_id in range(EPOCH_NUM): | ||
| for batch_id, data in enumerate(train_reader()): | ||
| if batch_id >= batch_num: | ||
| break | ||
| if use_data_loader: | ||
| img, label = data | ||
| else: | ||
| dy_x_data = np.array( | ||
| [x[0].reshape(3, 224, 224) | ||
| for x in data]).astype('float32') | ||
| if len(np.array([x[1] for x in data]).astype( | ||
| 'int64')) != batch_size: | ||
| continue | ||
| y_data = np.array( | ||
| [x[1] for x in data]).astype('int64').reshape(-1, 1) | ||
|
|
||
| img = paddle.to_tensor(dy_x_data) | ||
| label = paddle.to_tensor(y_data) | ||
| label.stop_gradient = True | ||
|
|
||
| with paddle.amp.auto_cast(enable=enable_amp): | ||
| out = resnet(img) | ||
|
|
||
| loss = paddle.nn.functional.cross_entropy( | ||
| input=out, label=label) | ||
| avg_loss = paddle.mean(x=loss) | ||
|
|
||
| dy_out = avg_loss.numpy() | ||
|
|
||
| scaled_loss = scaler.scale(avg_loss) | ||
| scaled_loss.backward() | ||
|
|
||
| scaler.minimize(optimizer, scaled_loss) | ||
|
|
||
| dy_grad_value = {} | ||
| for param in resnet.parameters(): | ||
| if param.trainable: | ||
| np_array = np.array(param._grad_ivar().value() | ||
| .get_tensor()) | ||
| dy_grad_value[param.name + fluid.core.grad_var_suffix( | ||
| )] = np_array | ||
|
|
||
| resnet.clear_gradients() | ||
|
|
||
| dy_param_value = {} | ||
| for param in resnet.parameters(): | ||
| dy_param_value[param.name] = param.numpy() | ||
| if use_save_load and epoch_id == 2: | ||
| paddle.save(scaler.state_dict(), 'ResNet_model.pdparams') | ||
| dict_load = paddle.load('ResNet_model.pdparams') | ||
| scaler.load_state_dict(dict_load) | ||
|
Comment on lines
+368
to
+370
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. check if the state value are equal
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks, the state values are equal. |
||
| return dy_out, dy_param_value, dy_grad_value | ||
|
|
||
| def test_with_state_dict(self): | ||
| with fluid.dygraph.guard(): | ||
| out_use_state_dict = self.train_resnet( | ||
| enable_amp=True, use_data_loader=True, use_save_load=True) | ||
| out_no_state_dict = self.train_resnet( | ||
| enable_amp=True, use_data_loader=True, use_save_load=False) | ||
| print('save_load:', out_use_state_dict[0], out_no_state_dict[0]) | ||
| self.assertTrue( | ||
| np.allclose( | ||
| out_use_state_dict[0], out_no_state_dict[0], atol=1.e-2)) | ||
|
||
|
|
||
|
|
||
| class TestResnet2(unittest.TestCase): | ||
| """ | ||
| Use paddle-2.0 API | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
确认一下类型变更之后对原来的情况是否完全兼容?比如参数类型检查是否有相关设置
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks,原本数据类型是float,初次添加注释的时候错误提交了int,这里修改为了正确的数据类型float。