Skip to content

Commit 42561fd

Browse files
committed
add flag FLAGS_cudnn_deterministic
1 parent d171365 commit 42561fd

File tree

1 file changed

+47
-45
lines changed

1 file changed

+47
-45
lines changed

python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,8 @@ def train_resnet(self,
286286
use_save_load=True):
287287
seed = 90
288288

289-
EPOCH_NUM = 4  # set the number of outer-loop epochs
290-
291289
batch_size = train_parameters["batch_size"]
292-
batch_num = 1
290+
batch_num = 4
293291

294292
paddle.seed(seed)
295293
paddle.framework.random._manual_program_seed(seed)
@@ -322,54 +320,51 @@ def train_resnet(self,
322320
train_loader.set_sample_list_generator(train_reader)
323321
train_reader = train_loader
324322

325-
for epoch_id in range(EPOCH_NUM):
326-
for batch_id, data in enumerate(train_reader()):
327-
if batch_id >= batch_num:
328-
break
329-
if use_data_loader:
330-
img, label = data
331-
else:
332-
dy_x_data = np.array(
333-
[x[0].reshape(3, 224, 224)
334-
for x in data]).astype('float32')
335-
if len(np.array([x[1] for x in data]).astype(
336-
'int64')) != batch_size:
337-
continue
338-
y_data = np.array(
339-
[x[1] for x in data]).astype('int64').reshape(-1, 1)
340-
341-
img = paddle.to_tensor(dy_x_data)
342-
label = paddle.to_tensor(y_data)
343-
label.stop_gradient = True
323+
for batch_id, data in enumerate(train_reader()):
324+
if batch_id >= batch_num:
325+
break
326+
if use_data_loader:
327+
img, label = data
328+
else:
329+
dy_x_data = np.array([x[0].reshape(3, 224, 224)
330+
for x in data]).astype('float32')
331+
if len(np.array([x[1]
332+
for x in data]).astype('int64')) != batch_size:
333+
continue
334+
y_data = np.array(
335+
[x[1] for x in data]).astype('int64').reshape(-1, 1)
344336

345-
with paddle.amp.auto_cast(enable=enable_amp):
346-
out = resnet(img)
337+
img = paddle.to_tensor(dy_x_data)
338+
label = paddle.to_tensor(y_data)
339+
label.stop_gradient = True
347340

348-
loss = paddle.nn.functional.cross_entropy(
349-
input=out, label=label)
350-
avg_loss = paddle.mean(x=loss)
341+
with paddle.amp.auto_cast(enable=enable_amp):
342+
out = resnet(img)
351343

352-
dy_out = avg_loss.numpy()
344+
loss = paddle.nn.functional.cross_entropy(input=out, label=label)
345+
avg_loss = paddle.mean(x=loss)
353346

354-
scaled_loss = scaler.scale(avg_loss)
355-
scaled_loss.backward()
347+
dy_out = avg_loss.numpy()
356348

357-
scaler.minimize(optimizer, scaled_loss)
349+
scaled_loss = scaler.scale(avg_loss)
350+
scaled_loss.backward()
358351

359-
dy_grad_value = {}
360-
for param in resnet.parameters():
361-
if param.trainable:
362-
np_array = np.array(param._grad_ivar().value()
363-
.get_tensor())
364-
dy_grad_value[param.name + fluid.core.grad_var_suffix(
365-
)] = np_array
352+
scaler.minimize(optimizer, scaled_loss)
366353

367-
resnet.clear_gradients()
354+
dy_grad_value = {}
355+
for param in resnet.parameters():
356+
if param.trainable:
357+
np_array = np.array(param._grad_ivar().value().get_tensor())
358+
dy_grad_value[param.name + fluid.core.grad_var_suffix(
359+
)] = np_array
368360

369-
dy_param_value = {}
370-
for param in resnet.parameters():
371-
dy_param_value[param.name] = param.numpy()
372-
if use_save_load and epoch_id == 2:
361+
resnet.clear_gradients()
362+
363+
dy_param_value = {}
364+
for param in resnet.parameters():
365+
dy_param_value[param.name] = param.numpy()
366+
367+
if use_save_load and batch_id == 2:
373368
paddle.save(scaler.state_dict(), 'ResNet_model.pdparams')
374369
dict_load = paddle.load('ResNet_model.pdparams')
375370
scaler.load_state_dict(dict_load)
@@ -378,15 +373,16 @@ def train_resnet(self,
378373
return dy_out, dy_param_value, dy_grad_value
379374

380375
def test_with_state_dict(self):
376+
if fluid.core.is_compiled_with_cuda():
377+
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
381378
with fluid.dygraph.guard():
382379
out_use_state_dict = self.train_resnet(
383380
enable_amp=True, use_data_loader=True, use_save_load=True)
384381
out_no_state_dict = self.train_resnet(
385382
enable_amp=True, use_data_loader=True, use_save_load=False)
386383
print('save_load:', out_use_state_dict[0], out_no_state_dict[0])
387384
self.assertTrue(
388-
np.allclose(
389-
out_use_state_dict[0], out_no_state_dict[0], atol=1.e-2))
385+
np.allclose(out_use_state_dict[0], out_no_state_dict[0]))
390386

391387

392388
class TestResnet2(unittest.TestCase):
@@ -479,13 +475,17 @@ def train_resnet(self, enable_amp=True, use_data_loader=False):
479475
return dy_out, dy_param_value, dy_grad_value
480476

481477
def test_resnet(self):
478+
if fluid.core.is_compiled_with_cuda():
479+
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
482480
with fluid.dygraph.guard():
483481
out_fp32 = self.train_resnet(enable_amp=False)
484482
out_amp = self.train_resnet(enable_amp=True)
485483
print(out_fp32[0], out_amp[0])
486484
self.assertTrue(np.allclose(out_fp32[0], out_amp[0], atol=1.e-2))
487485

488486
def test_with_data_loader(self):
487+
if fluid.core.is_compiled_with_cuda():
488+
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
489489
with fluid.dygraph.guard():
490490
out_fp32 = self.train_resnet(enable_amp=False, use_data_loader=True)
491491
out_amp = self.train_resnet(enable_amp=True, use_data_loader=True)
@@ -566,6 +566,8 @@ def train_resnet(self, enable_amp=True):
566566
return dy_out, dy_param_value, dy_grad_value
567567

568568
def test_resnet(self):
569+
if fluid.core.is_compiled_with_cuda():
570+
fluid.set_flags({"FLAGS_cudnn_deterministic": True})
569571
out_fp32 = self.train_resnet(enable_amp=False)
570572
out_amp = self.train_resnet(enable_amp=True)
571573
print(out_fp32[0], out_amp[0])

0 commit comments

Comments (0)