From c615f3d39104b0bbb3436a314d14151509465d11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Czhangbo9674=E2=80=9D?= Date: Wed, 21 Jul 2021 07:24:44 +0000 Subject: [PATCH 01/12] add state_dict and load_state_dict and unittest for class GradScaler --- python/paddle/amp/grad_scaler.py | 47 +++++++++++++++++++ .../paddle/fluid/dygraph/amp/loss_scaler.py | 39 +++++++++++++++ .../test_imperative_auto_mixed_precision.py | 19 ++++++++ 3 files changed, 105 insertions(+) diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py index 827a320b2cc9c4..66f5f98a9aa0f7 100644 --- a/python/paddle/amp/grad_scaler.py +++ b/python/paddle/amp/grad_scaler.py @@ -432,3 +432,50 @@ def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf): """ super(GradScaler, self).set_decr_every_n_nan_or_inf(new_decr_every_n_nan_or_inf) + + def state_dict(self): + """ + Returns the state of the scaler as a `dict`, If this instance is not enabled, returns an empty dict. + Reurns: + A dict of scaler includes: + init_loss_scaling (float, optional): The initial loss scaling factor. + incr_ratio(float, optional): The multiplier to use when increasing the loss scaling. + decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing the loss scaling. + incr_every_n_steps(int, optional): Increases loss scaling every n consecutive steps with finite gradients. + decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n accumulated steps with nan or inf gradients. + + Examples: + .. code-block:: python + import paddle + scaler = paddle.amp.GradScaler(enable=True, + init_loss_scaling=1024, + incr_ratio=2.0, + decr_ratio=0.5, + incr_every_n_steps=1000, + decr_every_n_nan_or_inf=2, + use_dynamic_loss_scaling=True) + scaler_state = scaler.state_dict() + scaler.load_state_dict(scaler_state) + """ + return super(GradScaler, self).state_dict() + + def load_state_dict(self, state_dict): + """ + Loads the scaler state. 
+ Args: + state_dict(dict): scaler state. Should be an object returned from a call to `AmpScaler.state_dict()`. + + Examples: + .. code-block:: python + import paddle + scaler = paddle.amp.GradScaler(enable=True, + init_loss_scaling=1024, + incr_ratio=2.0, + decr_ratio=0.5, + incr_every_n_steps=1000, + decr_every_n_nan_or_inf=2, + use_dynamic_loss_scaling=True) + scaler_state = scaler.state_dict() + scaler.load_state_dict(scaler_state) + """ + super(GradScaler, self).load_state_dict(state_dict) diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py index 96ee4514ac2b93..101d86369dca97 100644 --- a/python/paddle/fluid/dygraph/amp/loss_scaler.py +++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py @@ -357,3 +357,42 @@ def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf): new_decr_every_n_nan_or_inf(int): The new_decr_every_n_nan_or_inf used to update the num `n`, `n` represent decreases loss scaling every `n` accumulated steps with nan or inf gradients. """ self._decr_every_n_nan_or_inf = new_decr_every_n_nan_or_inf + + def state_dict(self): + """ + Returns state of the scaler as a `dict`, If this instance is not enabled, returns an empty dict. + Reurns: + A dict of scaler includes: + init_loss_scaling (float, optional): The initial loss scaling factor. + incr_ratio(float, optional): The multiplier to use when increasing the loss scaling. + decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing the loss scaling. + incr_every_n_steps(int, optional): Increases loss scaling every n consecutive steps with finite gradients. + decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n accumulated steps with nan or inf gradients. 
+ """ + return { + "init_loss_scaling": self._init_loss_scaling, + "incr_ratio": self._incr_ratio, + "decr_ratio": self._decr_ratio, + "incr_every_n_steps": self._incr_every_n_steps, + "decr_every_n_nan_or_inf": self._decr_every_n_nan_or_inf + } if self._enable else {} + + def load_state_dict(self, state_dict): + """ + Loads the scaler state. + Args: + state_dict(dict): scaler state. Should be an object returned from a call to `AmpScaler.state_dict()`. + """ + if not self._enable: + return + + if len(state_dict) == 0: + raise RuntimeError( + "The input state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler.") + + self._init_loss_scaling = state_dict["init_loss_scaling"] + self._incr_ratio = state_dict["incr_ratio"] + self._decr_ratio = state_dict["decr_ratio"] + self._incr_every_n_steps = state_dict["incr_every_n_steps"] + self._decr_every_n_nan_or_inf = state_dict["decr_every_n_nan_or_inf"] diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index e3d2bda8921287..5b53b69b2590fd 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -237,6 +237,25 @@ def test_get_and_set(self): scaler.set_init_loss_scaling(100) self.assertEqual(scaler.get_init_loss_scaling() == 100, True) + def test_state_dict_and_load_state_dict(self): + with fluid.dygraph.guard(): + scaler1 = paddle.amp.GradScaler( + enable=True, + init_loss_scaling=14, + incr_ratio=233.0, + decr_ratio=0.523, + incr_every_n_steps=1090, + decr_every_n_nan_or_inf=20, + use_dynamic_loss_scaling=True) + scaler_state = scaler1.state_dict() + scaler2 = paddle.amp.GradScaler(enable=True) + scaler2.load_state_dict(scaler_state) + self.assertEqual(scaler2.get_init_loss_scaling() == 14, True) + self.assertEqual(scaler2.get_incr_ratio() == 233.0, True) + 
self.assertEqual(scaler2.get_decr_ratio() == 0.523, True) + self.assertEqual(scaler2.get_incr_every_n_steps() == 1090, True) + self.assertEqual(scaler2.get_decr_every_n_nan_or_inf() == 20, True) + def reader_decorator(reader): def __reader__(): From 1dce33afcce5641943aac9a2ef413a2aedd1701a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Czhangbo9674=E2=80=9D?= Date: Mon, 26 Jul 2021 12:36:10 +0000 Subject: [PATCH 02/12] refine unittest for coverage of load_state_dict --- .../test_imperative_auto_mixed_precision.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index 5b53b69b2590fd..d57efa3217bdc4 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -256,6 +256,18 @@ def test_state_dict_and_load_state_dict(self): self.assertEqual(scaler2.get_incr_every_n_steps() == 1090, True) self.assertEqual(scaler2.get_decr_every_n_nan_or_inf() == 20, True) + scaler3 = paddle.amp.GradScaler(enable=False) + scaler3.load_state_dict(scaler_state) + self.assertEqual(scaler3.is_enable() == False, True) + + def test_state_dict_and_load_state_dict_error(self): + def test_error(): + state_empty = {} + scaler = paddle.amp.GradScaler(enable=True) + scaler.load_state_dict(state_empty) + + self.assertRaises(RuntimeError, test_error) + def reader_decorator(reader): def __reader__(): From f917dff0f7bece061b3a01f822ce82486c4b58e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Czhangbo9674=E2=80=9D?= Date: Tue, 27 Jul 2021 01:24:17 +0000 Subject: [PATCH 03/12] refine comments of code-block --- python/paddle/amp/grad_scaler.py | 8 ++++++++ python/paddle/fluid/dygraph/amp/loss_scaler.py | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/python/paddle/amp/grad_scaler.py 
b/python/paddle/amp/grad_scaler.py index 66f5f98a9aa0f7..80349857b01361 100644 --- a/python/paddle/amp/grad_scaler.py +++ b/python/paddle/amp/grad_scaler.py @@ -436,6 +436,7 @@ def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf): def state_dict(self): """ Returns the state of the scaler as a `dict`, If this instance is not enabled, returns an empty dict. + Reurns: A dict of scaler includes: init_loss_scaling (float, optional): The initial loss scaling factor. @@ -445,8 +446,11 @@ def state_dict(self): decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n accumulated steps with nan or inf gradients. Examples: + .. code-block:: python + import paddle + scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, incr_ratio=2.0, @@ -462,12 +466,16 @@ def state_dict(self): def load_state_dict(self, state_dict): """ Loads the scaler state. + Args: state_dict(dict): scaler state. Should be an object returned from a call to `AmpScaler.state_dict()`. Examples: + .. code-block:: python + import paddle + scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, incr_ratio=2.0, diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py index 101d86369dca97..605e66f5f6c35f 100644 --- a/python/paddle/fluid/dygraph/amp/loss_scaler.py +++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py @@ -360,7 +360,8 @@ def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf): def state_dict(self): """ - Returns state of the scaler as a `dict`, If this instance is not enabled, returns an empty dict. + Returns the state of the scaler as a `dict`, If this instance is not enabled, returns an empty dict. + Reurns: A dict of scaler includes: init_loss_scaling (float, optional): The initial loss scaling factor. @@ -380,6 +381,7 @@ def state_dict(self): def load_state_dict(self, state_dict): """ Loads the scaler state. + Args: state_dict(dict): scaler state. 
Should be an object returned from a call to `AmpScaler.state_dict()`. """ From eb96c243b01d75be60d516256a87eb7b37ff02e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Czhangbo9674=E2=80=9D?= Date: Tue, 27 Jul 2021 03:25:54 +0000 Subject: [PATCH 04/12] refine some comments --- python/paddle/amp/grad_scaler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py index 80349857b01361..3b6fdae39e64e0 100644 --- a/python/paddle/amp/grad_scaler.py +++ b/python/paddle/amp/grad_scaler.py @@ -220,7 +220,7 @@ def set_init_loss_scaling(self, new_init_loss_scaling): Set the initial loss scaling factor by `new_init_loss_scaling`. Args: - new_init_loss_scaling(int): The new_init_loss_scaling used to update initial loss scaling factor. + new_init_loss_scaling(float): The new_init_loss_scaling used to update initial loss scaling factor. Examples: .. code-block:: python @@ -459,7 +459,6 @@ def state_dict(self): decr_every_n_nan_or_inf=2, use_dynamic_loss_scaling=True) scaler_state = scaler.state_dict() - scaler.load_state_dict(scaler_state) """ return super(GradScaler, self).state_dict() @@ -468,7 +467,7 @@ def load_state_dict(self, state_dict): Loads the scaler state. Args: - state_dict(dict): scaler state. Should be an object returned from a call to `AmpScaler.state_dict()`. + state_dict(dict): scaler state. Should be an object returned from a call to `GradScaler.state_dict()`. 
Examples: From e2f855a627e19e3648f4fe7d729476777692238f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Czhangbo9674=E2=80=9D?= Date: Mon, 2 Aug 2021 02:47:51 +0000 Subject: [PATCH 05/12] refine state_dict code and unittest --- .../paddle/fluid/dygraph/amp/loss_scaler.py | 27 +++-- .../test_imperative_auto_mixed_precision.py | 108 ++++++++++++++++++ 2 files changed, 127 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py index 605e66f5f6c35f..2065bec8af3bc4 100644 --- a/python/paddle/fluid/dygraph/amp/loss_scaler.py +++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py @@ -364,18 +364,24 @@ def state_dict(self): Reurns: A dict of scaler includes: - init_loss_scaling (float, optional): The initial loss scaling factor. - incr_ratio(float, optional): The multiplier to use when increasing the loss scaling. - decr_ratio(float, optional): The less-than-one-multiplier to use when decreasing the loss scaling. - incr_every_n_steps(int, optional): Increases loss scaling every n consecutive steps with finite gradients. - decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n accumulated steps with nan or inf gradients. + scale (tensor): The loss scaling factor. + incr_ratio(float): The multiplier to use when increasing the loss scaling. + decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling. + incr_every_n_steps(int): Increases loss scaling every n consecutive steps with finite gradients. + decr_every_n_nan_or_inf(int): Decreases loss scaling every n accumulated steps with nan or inf gradients. + incr_count(int): The number of recent consecutive unskipped steps. + decr_count(int): The number of recent consecutive skipped steps. + use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. If False, fixed loss_scaling is used. If True, the loss scaling is updated dynamically. Default is True. 
""" return { - "init_loss_scaling": self._init_loss_scaling, + "scale": self._scale.numpy(), "incr_ratio": self._incr_ratio, "decr_ratio": self._decr_ratio, "incr_every_n_steps": self._incr_every_n_steps, - "decr_every_n_nan_or_inf": self._decr_every_n_nan_or_inf + "decr_every_n_nan_or_inf": self._decr_every_n_nan_or_inf, + "incr_count": self._incr_count, + "decr_count": self._decr_count, + "use_dynamic_loss_scaling": self._use_dynamic_loss_scaling } if self._enable else {} def load_state_dict(self, state_dict): @@ -393,8 +399,13 @@ def load_state_dict(self, state_dict): "The input state dict is empty, possibly because it was saved " "from a disabled instance of GradScaler.") - self._init_loss_scaling = state_dict["init_loss_scaling"] + self._init_loss_scaling = state_dict["scale"][0] + self._scale = to_variable( + np.array([self._init_loss_scaling]).astype(np.float32)) self._incr_ratio = state_dict["incr_ratio"] self._decr_ratio = state_dict["decr_ratio"] self._incr_every_n_steps = state_dict["incr_every_n_steps"] self._decr_every_n_nan_or_inf = state_dict["decr_every_n_nan_or_inf"] + self._incr_count = state_dict["incr_count"] + self._decr_count = state_dict["decr_count"] + self._use_dynamic_loss_scaling = state_dict["use_dynamic_loss_scaling"] diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index d57efa3217bdc4..006c39619e44e4 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -279,6 +279,114 @@ def __reader__(): return __reader__ +class TestGradScalerStateDict(unittest.TestCase): + def train_resnet(self, + enable_amp=True, + use_data_loader=True, + use_save_load=True): + seed = 90 + + EPOCH_NUM = 4 # 设置外层循环次数 + + batch_size = train_parameters["batch_size"] + batch_num = 1 + + paddle.seed(seed) + 
paddle.framework.random._manual_program_seed(seed) + + resnet = ResNet(use_cudnn=True) + optimizer = optimizer_setting( + train_parameters, parameter_list=resnet.parameters()) + np.random.seed(seed) + train_reader = paddle.batch( + paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size) + + dy_param_init_value = {} + for param in resnet.parameters(): + dy_param_init_value[param.name] = param.numpy() + + program = None + scaler = paddle.amp.GradScaler( + enable=enable_amp, init_loss_scaling=2.**10) + + if use_data_loader: + train_reader = paddle.batch( + reader_decorator(paddle.dataset.flowers.train(use_xmap=False)), + batch_size=batch_size, + drop_last=True) + train_loader = fluid.io.DataLoader.from_generator( + capacity=4, + use_double_buffer=True, + iterable=True, + return_list=True) + train_loader.set_sample_list_generator(train_reader) + train_reader = train_loader + + for epoch_id in range(EPOCH_NUM): + for batch_id, data in enumerate(train_reader()): + if batch_id >= batch_num: + break + if use_data_loader: + img, label = data + else: + dy_x_data = np.array( + [x[0].reshape(3, 224, 224) + for x in data]).astype('float32') + if len(np.array([x[1] for x in data]).astype( + 'int64')) != batch_size: + continue + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = paddle.to_tensor(dy_x_data) + label = paddle.to_tensor(y_data) + label.stop_gradient = True + + with paddle.amp.auto_cast(enable=enable_amp): + out = resnet(img) + + loss = paddle.nn.functional.cross_entropy( + input=out, label=label) + avg_loss = paddle.mean(x=loss) + + dy_out = avg_loss.numpy() + + scaled_loss = scaler.scale(avg_loss) + scaled_loss.backward() + + scaler.minimize(optimizer, scaled_loss) + + dy_grad_value = {} + for param in resnet.parameters(): + if param.trainable: + np_array = np.array(param._grad_ivar().value() + .get_tensor()) + dy_grad_value[param.name + fluid.core.grad_var_suffix( + )] = np_array + + resnet.clear_gradients() + + 
dy_param_value = {} + for param in resnet.parameters(): + dy_param_value[param.name] = param.numpy() + if use_save_load and epoch_id == 2: + paddle.save(scaler.state_dict(), 'ResNet_model.pdparams') + dict_load = paddle.load('ResNet_model.pdparams') + scaler.load_state_dict(dict_load) + return dy_out, dy_param_value, dy_grad_value + + def test_with_state_dict(self): + with fluid.dygraph.guard(): + out_use_state_dict = self.train_resnet( + enable_amp=True, use_data_loader=True, use_save_load=True) + out_no_state_dict = self.train_resnet( + enable_amp=True, use_data_loader=True, use_save_load=False) + print('save_load:', out_use_state_dict[0], out_no_state_dict[0]) + self.assertTrue( + np.allclose( + out_use_state_dict[0], out_no_state_dict[0], atol=1.e-2)) + + class TestResnet2(unittest.TestCase): """ Use paddle-2.0 API From b74be1bdb86b8966690ac63bc19b3e7355f6119a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Czhangbo9674=E2=80=9D?= Date: Mon, 2 Aug 2021 03:33:10 +0000 Subject: [PATCH 06/12] add #require gpu, xpu for GradScaler get/set example code --- python/paddle/amp/grad_scaler.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py index 3b6fdae39e64e0..4c6a0893ad2731 100644 --- a/python/paddle/amp/grad_scaler.py +++ b/python/paddle/amp/grad_scaler.py @@ -47,7 +47,8 @@ class GradScaler(AmpScaler): Examples: .. code-block:: python - + + # required: gpu, xpu import paddle model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) @@ -91,7 +92,8 @@ def scale(self, var): Examples: .. code-block:: python - + + # required: gpu, xpu import paddle model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) @@ -128,6 +130,7 @@ def minimize(self, optimizer, *args, **kwargs): .. code-block:: python + # required: gpu, xpu import paddle model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) @@ -156,6 +159,7 @@ def is_enable(self): Examples: .. 
code-block:: python + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -178,7 +182,8 @@ def is_use_dynamic_loss_scaling(self): Examples: .. code-block:: python - + + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -202,7 +207,9 @@ def get_init_loss_scaling(self): Examples: .. code-block:: python + # required: gpu, xpu import paddle + paddle.set_device('gpu') scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, incr_ratio=2.0, @@ -224,7 +231,8 @@ def set_init_loss_scaling(self, new_init_loss_scaling): Examples: .. code-block:: python - + + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -250,6 +258,7 @@ def get_incr_ratio(self): Examples: .. code-block:: python + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -273,6 +282,7 @@ def set_incr_ratio(self, new_incr_ratio): Examples: .. code-block:: python + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -298,6 +308,7 @@ def get_decr_ratio(self): Examples: .. code-block:: python + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -321,6 +332,7 @@ def set_decr_ratio(self, new_decr_ratio): Examples: .. code-block:: python + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -346,6 +358,7 @@ def get_incr_every_n_steps(self): Examples: .. code-block:: python + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -369,6 +382,7 @@ def set_incr_every_n_steps(self, new_incr_every_n_steps): Examples: .. 
code-block:: python + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -394,6 +408,7 @@ def get_decr_every_n_nan_or_inf(self): Examples: .. code-block:: python + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -417,6 +432,7 @@ def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf): Examples: .. code-block:: python + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -449,6 +465,7 @@ def state_dict(self): .. code-block:: python + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, @@ -473,6 +490,7 @@ def load_state_dict(self, state_dict): .. code-block:: python + # required: gpu, xpu import paddle scaler = paddle.amp.GradScaler(enable=True, From 81aa57b1916b7786ac7943a470b34e16fa321b65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Czhangbo9674=E2=80=9D?= Date: Mon, 2 Aug 2021 03:50:08 +0000 Subject: [PATCH 07/12] add #require gpu, xpu for GradScaler get/set example code --- python/paddle/amp/grad_scaler.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py index 4c6a0893ad2731..ddba7794e21d9d 100644 --- a/python/paddle/amp/grad_scaler.py +++ b/python/paddle/amp/grad_scaler.py @@ -48,7 +48,6 @@ class GradScaler(AmpScaler): .. code-block:: python - # required: gpu, xpu import paddle model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) @@ -93,7 +92,6 @@ def scale(self, var): .. code-block:: python - # required: gpu, xpu import paddle model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) @@ -130,7 +128,6 @@ def minimize(self, optimizer, *args, **kwargs): .. code-block:: python - # required: gpu, xpu import paddle model = paddle.nn.Conv2D(3, 2, 3, bias_attr=True) @@ -159,7 +156,7 @@ def is_enable(self): Examples: .. 
code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -183,7 +180,7 @@ def is_use_dynamic_loss_scaling(self): Examples: .. code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -207,7 +204,7 @@ def get_init_loss_scaling(self): Examples: .. code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle paddle.set_device('gpu') scaler = paddle.amp.GradScaler(enable=True, @@ -232,7 +229,7 @@ def set_init_loss_scaling(self, new_init_loss_scaling): Examples: .. code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -258,7 +255,7 @@ def get_incr_ratio(self): Examples: .. code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -282,7 +279,7 @@ def set_incr_ratio(self, new_incr_ratio): Examples: .. code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -308,7 +305,7 @@ def get_decr_ratio(self): Examples: .. code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -332,7 +329,7 @@ def set_decr_ratio(self, new_decr_ratio): Examples: .. code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -358,7 +355,7 @@ def get_incr_every_n_steps(self): Examples: .. code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -382,7 +379,7 @@ def set_incr_every_n_steps(self, new_incr_every_n_steps): Examples: .. 
code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -408,7 +405,7 @@ def get_decr_every_n_nan_or_inf(self): Examples: .. code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -432,7 +429,7 @@ def set_decr_every_n_nan_or_inf(self, new_decr_every_n_nan_or_inf): Examples: .. code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, @@ -465,7 +462,7 @@ def state_dict(self): .. code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, @@ -490,7 +487,7 @@ def load_state_dict(self, state_dict): .. code-block:: python - # required: gpu, xpu + # required: gpu,xpu import paddle scaler = paddle.amp.GradScaler(enable=True, From 50855472cee40f4bc01d44ae32b09ebeb7937d3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Czhangbo9674=E2=80=9D?= Date: Mon, 2 Aug 2021 07:27:22 +0000 Subject: [PATCH 08/12] refine example code --- python/paddle/amp/grad_scaler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/amp/grad_scaler.py b/python/paddle/amp/grad_scaler.py index ddba7794e21d9d..18c436a0bb95f7 100644 --- a/python/paddle/amp/grad_scaler.py +++ b/python/paddle/amp/grad_scaler.py @@ -206,7 +206,6 @@ def get_init_loss_scaling(self): # required: gpu,xpu import paddle - paddle.set_device('gpu') scaler = paddle.amp.GradScaler(enable=True, init_loss_scaling=1024, incr_ratio=2.0, From c039e7ef581dc3372ac65e006da6275280d9300e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Czhangbo9674=E2=80=9D?= Date: Fri, 6 Aug 2021 02:42:06 +0000 Subject: [PATCH 09/12] refine unittest for state_dict --- .../tests/unittests/test_imperative_auto_mixed_precision.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index 006c39619e44e4..1b27ecfffe60cd 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -378,9 +378,9 @@ def train_resnet(self, def test_with_state_dict(self): with fluid.dygraph.guard(): out_use_state_dict = self.train_resnet( - enable_amp=True, use_data_loader=True, use_save_load=True) + enable_amp=True, use_data_loader=False, use_save_load=True) out_no_state_dict = self.train_resnet( - enable_amp=True, use_data_loader=True, use_save_load=False) + enable_amp=True, use_data_loader=False, use_save_load=False) print('save_load:', out_use_state_dict[0], out_no_state_dict[0]) self.assertTrue( np.allclose( From 97453ed3a7f9e4dbd3f1f5c367e2d27cfa958131 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Czhangbo9674=E2=80=9D?= Date: Fri, 6 Aug 2021 02:58:29 +0000 Subject: [PATCH 10/12] refine unittest for state_dict --- .../tests/unittests/test_imperative_auto_mixed_precision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index 1b27ecfffe60cd..f9496b22288a19 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -381,7 +381,7 @@ def test_with_state_dict(self): enable_amp=True, use_data_loader=False, use_save_load=True) out_no_state_dict = self.train_resnet( enable_amp=True, use_data_loader=False, use_save_load=False) - print('save_load:', out_use_state_dict[0], out_no_state_dict[0]) + print(out_use_state_dict[0], out_no_state_dict[0]) self.assertTrue( np.allclose( out_use_state_dict[0], out_no_state_dict[0], 
atol=1.e-2)) From f85b0e4686ecf2be8d90dd1fddf3a8b3d1afb635 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Czhangbo9674=E2=80=9D?= Date: Fri, 6 Aug 2021 10:57:29 +0000 Subject: [PATCH 11/12] fix bug of DataLoader in TestGradScalerStateDict --- .../unittests/test_imperative_auto_mixed_precision.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index f9496b22288a19..1f0b70951bd958 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -373,15 +373,17 @@ def train_resnet(self, paddle.save(scaler.state_dict(), 'ResNet_model.pdparams') dict_load = paddle.load('ResNet_model.pdparams') scaler.load_state_dict(dict_load) + if use_data_loader: + train_reader._reset() return dy_out, dy_param_value, dy_grad_value def test_with_state_dict(self): with fluid.dygraph.guard(): out_use_state_dict = self.train_resnet( - enable_amp=True, use_data_loader=False, use_save_load=True) + enable_amp=True, use_data_loader=True, use_save_load=True) out_no_state_dict = self.train_resnet( - enable_amp=True, use_data_loader=False, use_save_load=False) - print(out_use_state_dict[0], out_no_state_dict[0]) + enable_amp=True, use_data_loader=True, use_save_load=False) + print('save_load:', out_use_state_dict[0], out_no_state_dict[0]) self.assertTrue( np.allclose( out_use_state_dict[0], out_no_state_dict[0], atol=1.e-2)) From 42561fd578526a44309cd84ed8f191f53a1e310a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Czhangbo9674=E2=80=9D?= Date: Tue, 10 Aug 2021 12:41:40 +0000 Subject: [PATCH 12/12] add flag FLAGS_cudnn_deterministic --- .../test_imperative_auto_mixed_precision.py | 92 ++++++++++--------- 1 file changed, 47 insertions(+), 45 deletions(-) diff --git 
a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index 1f0b70951bd958..17d50ed8c19de0 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -286,10 +286,8 @@ def train_resnet(self, use_save_load=True): seed = 90 - EPOCH_NUM = 4 # 设置外层循环次数 - batch_size = train_parameters["batch_size"] - batch_num = 1 + batch_num = 4 paddle.seed(seed) paddle.framework.random._manual_program_seed(seed) @@ -322,54 +320,51 @@ def train_resnet(self, train_loader.set_sample_list_generator(train_reader) train_reader = train_loader - for epoch_id in range(EPOCH_NUM): - for batch_id, data in enumerate(train_reader()): - if batch_id >= batch_num: - break - if use_data_loader: - img, label = data - else: - dy_x_data = np.array( - [x[0].reshape(3, 224, 224) - for x in data]).astype('float32') - if len(np.array([x[1] for x in data]).astype( - 'int64')) != batch_size: - continue - y_data = np.array( - [x[1] for x in data]).astype('int64').reshape(-1, 1) - - img = paddle.to_tensor(dy_x_data) - label = paddle.to_tensor(y_data) - label.stop_gradient = True + for batch_id, data in enumerate(train_reader()): + if batch_id >= batch_num: + break + if use_data_loader: + img, label = data + else: + dy_x_data = np.array([x[0].reshape(3, 224, 224) + for x in data]).astype('float32') + if len(np.array([x[1] + for x in data]).astype('int64')) != batch_size: + continue + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) - with paddle.amp.auto_cast(enable=enable_amp): - out = resnet(img) + img = paddle.to_tensor(dy_x_data) + label = paddle.to_tensor(y_data) + label.stop_gradient = True - loss = paddle.nn.functional.cross_entropy( - input=out, label=label) - avg_loss = paddle.mean(x=loss) + with paddle.amp.auto_cast(enable=enable_amp): + out = resnet(img) - dy_out = 
avg_loss.numpy() + loss = paddle.nn.functional.cross_entropy(input=out, label=label) + avg_loss = paddle.mean(x=loss) - scaled_loss = scaler.scale(avg_loss) - scaled_loss.backward() + dy_out = avg_loss.numpy() - scaler.minimize(optimizer, scaled_loss) + scaled_loss = scaler.scale(avg_loss) + scaled_loss.backward() - dy_grad_value = {} - for param in resnet.parameters(): - if param.trainable: - np_array = np.array(param._grad_ivar().value() - .get_tensor()) - dy_grad_value[param.name + fluid.core.grad_var_suffix( - )] = np_array + scaler.minimize(optimizer, scaled_loss) - resnet.clear_gradients() + dy_grad_value = {} + for param in resnet.parameters(): + if param.trainable: + np_array = np.array(param._grad_ivar().value().get_tensor()) + dy_grad_value[param.name + fluid.core.grad_var_suffix( + )] = np_array - dy_param_value = {} - for param in resnet.parameters(): - dy_param_value[param.name] = param.numpy() - if use_save_load and epoch_id == 2: + resnet.clear_gradients() + + dy_param_value = {} + for param in resnet.parameters(): + dy_param_value[param.name] = param.numpy() + + if use_save_load and batch_id == 2: paddle.save(scaler.state_dict(), 'ResNet_model.pdparams') dict_load = paddle.load('ResNet_model.pdparams') scaler.load_state_dict(dict_load) @@ -378,6 +373,8 @@ def train_resnet(self, return dy_out, dy_param_value, dy_grad_value def test_with_state_dict(self): + if fluid.core.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) with fluid.dygraph.guard(): out_use_state_dict = self.train_resnet( enable_amp=True, use_data_loader=True, use_save_load=True) @@ -385,8 +382,7 @@ def test_with_state_dict(self): enable_amp=True, use_data_loader=True, use_save_load=False) print('save_load:', out_use_state_dict[0], out_no_state_dict[0]) self.assertTrue( - np.allclose( - out_use_state_dict[0], out_no_state_dict[0], atol=1.e-2)) + np.allclose(out_use_state_dict[0], out_no_state_dict[0])) class TestResnet2(unittest.TestCase): @@ -479,6 +475,8 
@@ def train_resnet(self, enable_amp=True, use_data_loader=False): return dy_out, dy_param_value, dy_grad_value def test_resnet(self): + if fluid.core.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) with fluid.dygraph.guard(): out_fp32 = self.train_resnet(enable_amp=False) out_amp = self.train_resnet(enable_amp=True) @@ -486,6 +484,8 @@ def test_resnet(self): self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2)) def test_with_data_loader(self): + if fluid.core.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) with fluid.dygraph.guard(): out_fp32 = self.train_resnet(enable_amp=False, use_data_loader=True) out_amp = self.train_resnet(enable_amp=True, use_data_loader=True) @@ -566,6 +566,8 @@ def train_resnet(self, enable_amp=True): return dy_out, dy_param_value, dy_grad_value def test_resnet(self): + if fluid.core.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) out_fp32 = self.train_resnet(enable_amp=False) out_amp = self.train_resnet(enable_amp=True) print(out_fp32[0], out_amp[0])