Skip to content

Commit 0809d83

Browse files
committed
add ut
1 parent cc74f4e commit 0809d83

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

test/auto_parallel/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
212212
NVIDIA_TF32_OVERRIDE=0)
213213
# End of unittests WITH single card WITHOUT timeout
214214

215+
py_test_modules(test_clear_param_storage_api MODULES
216+
test_clear_param_storage_api)
217+
215218
endif()
216219

217220
py_test_modules(test_job_schedule_profiler_range MODULES
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle
18+
19+
20+
class TestClearParamStorage(unittest.TestCase):
21+
def test_clear_param_storage(self):
22+
class TestLayer(paddle.nn.Layer):
23+
def __init__(self, dtype):
24+
super().__init__()
25+
self._w = self.create_parameter([2, 3], dtype=dtype)
26+
self._b = self.create_parameter([2, 3], dtype=dtype)
27+
self._w.color = {"color": "_w"}
28+
self._b.color = {"color": "_b"}
29+
30+
@paddle.amp.debugging.check_layer_numerics
31+
def forward(self, x):
32+
return x * self._w + self._b
33+
34+
dtype = 'float32'
35+
model = TestLayer(dtype)
36+
adam = paddle.optimizer.Adam(parameters=model.parameters())
37+
adam.clear_param_storage("_w")
38+
adam.clear_param_storage("_b")
39+
adam.reset_param_storage()
40+
41+
42+
if __name__ == '__main__':
43+
unittest.main()

0 commit comments

Comments
 (0)