Skip to content

Commit 9d658a9

Browse files
authored
Add check for sharding stage1-v2 using amp master grad (#9333)
1 parent ab62e47 commit 9d658a9

1 file changed

Lines changed: 1 addition & 0 deletions

File tree

paddlenlp/trainer/training_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,6 +1350,7 @@ def is_segment_parallel_supported():
13501350

13511351
if "split_param" in sharding_parallel_config:
13521352
strategy.hybrid_configs["sharding_configs"].split_param = True
1353+
assert self.amp_master_grad, "Currently sharding stage1 v2 only support amp_master_grad"
13531354

13541355
if "enable_release_grads" in sharding_parallel_config:
13551356
strategy.hybrid_configs["sharding_configs"].release_gradients = True

0 commit comments

Comments
 (0)