
Commit 843d101

Fix comments for PR #59644 (#59885)

* update
* update

1 parent: 90f966a

File tree: 3 files changed (+17 −2 lines)


python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

Lines changed: 8 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import distutils.util
 import os
 
 import paddle
@@ -46,6 +47,9 @@ def __init__(self, clip, hcg):
         self._clip = clip
         self._hcg = hcg
         self.not_sharding_stage1 = True
+        self._force_align_vpp_grad_sum_order = distutils.util.strtobool(
+            os.getenv('FLAGS_force_align_vpp_grad_sum_order', '0')
+        )
 
     def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
         # sharding first
@@ -99,6 +103,10 @@ def _global_norm(self, global_norm_var_dist, global_norm_var_not_dist):
 
     @no_grad()
     def _dygraph_clip(self, params_grads):
+        if self._force_align_vpp_grad_sum_order:
+            chunk_num = self._get_vpp_chunk_num(params_grads)
+            if chunk_num > 0:
+                return self._vpp_dygraph_clip(params_grads, chunk_num)
         sum_square_dist_fp16 = []
         sum_square_dist_bf16 = []
         sum_square_dist_fp32 = []
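For context, the new switch follows the common pattern of reading a boolean feature flag from the environment with distutils.util.strtobool and using it to gate an alternative code path. The sketch below is illustrative only: the flag name comes from the diff, while the surrounding class and the placeholder clip methods are hypothetical stand-ins, not Paddle APIs.

import distutils.util  # note: distutils is removed from the stdlib in Python 3.12+
import os


class ClipSketch:
    """Hypothetical stand-in showing only how the env flag gates the clip path."""

    def __init__(self):
        # strtobool maps '1'/'true'/'yes' to 1 and '0'/'false'/'no' to 0,
        # raising ValueError for anything else; an unset flag defaults to '0' (off).
        self._force_align_vpp_grad_sum_order = distutils.util.strtobool(
            os.getenv('FLAGS_force_align_vpp_grad_sum_order', '0')
        )

    def clip(self, params_grads):
        if self._force_align_vpp_grad_sum_order:
            # In the real optimizer this branches into the VPP-chunk-aware path.
            return self._vpp_clip_placeholder(params_grads)
        return self._default_clip_placeholder(params_grads)

    def _vpp_clip_placeholder(self, params_grads):
        return params_grads

    def _default_clip_placeholder(self, params_grads):
        return params_grads


# Usage: export FLAGS_force_align_vpp_grad_sum_order=1 before launching training
# to take the alternative branch.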

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 0 additions & 2 deletions
@@ -921,8 +921,6 @@ def __init__(self, layers, hcg, strategy):
         self._virtual_pp_rank = 0
         self._reset_counter()
 
-        self._check_sanity()
-
     def _check_sanity(self):
         assert (
             framework.in_dynamic_mode()

python/paddle/distributed/fleet/utils/tensor_fusion_helper.py

Lines changed: 9 additions & 0 deletions
@@ -461,6 +461,15 @@ def scale_grads(self):
 
         self._reset_params_checked_in()
 
+    @imperative_base.no_grad
+    def scale_and_split_grads(self):
+        assert self._task is not None, "Task is not initialized. "
+        self._task.wait()
+        scale_factor = 1.0 / self._comm_group.nranks
+        self.grad_storage.scale_(scale_factor)
+
+        self._reset_params_checked_in()
+
 
 def obtain_storage(
     parameters,
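The new scale_and_split_grads method follows the usual fused-gradient pattern: wait on the outstanding communication task, then scale the fused gradient buffer by 1/nranks so the reduced sum becomes an average. Below is a minimal, self-contained sketch of that pattern with hypothetical stand-ins for the communication task and the fused buffer (no real collective is issued, and no Paddle APIs are used).

class _FakeTask:
    """Stands in for an async communication handle; wait() would block until done."""

    def wait(self):
        pass


class _FusedBufferSketch:
    def __init__(self, grads, nranks):
        self.grad_storage = grads   # flat list standing in for the fused grad tensor
        self.nranks = nranks        # size of the communication group
        self._task = _FakeTask()    # pretend a reduction was already issued

    def scale_and_split_grads(self):
        assert self._task is not None, "Task is not initialized."
        self._task.wait()                  # ensure the reduction has finished
        scale = 1.0 / self.nranks          # turn the summed gradients into an average
        self.grad_storage = [g * scale for g in self.grad_storage]


buf = _FusedBufferSketch([2.0, 4.0, 8.0], nranks=4)
buf.scale_and_split_grads()
print(buf.grad_storage)  # [0.5, 1.0, 2.0]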
