Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion paddle/fluid/pybind/op_function_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"c_reduce_min", {"Out"}},
{"c_reduce_prod", {"Out"}},
{"c_reduce", {"Out"}},
{"c_allgather", {"Out"}},
{"c_scatter", {"Out"}},
{"barrier", {"Out"}},
{"fake_quantize_dequantize_moving_average_abs_max",
Expand Down
11 changes: 5 additions & 6 deletions python/paddle/distributed/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,14 +632,13 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
ring_id = 0 if group is None else group.id
nranks = _get_global_group().nranks if group is None else group.nranks

op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)

if in_dygraph_mode():
_C_ops.c_allgather(tensor, out, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks)
out = _C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks)
else:
op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
if not isinstance(tensor_list, list):
raise ValueError("The type of 'tensor_list' for all_gather "
"should be list.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import paddle
from paddle.fluid.dygraph.layers import Layer
from ...utils.log_util import logger, layer_to_str
from ..pp_utils.utils import _hp_recompute, _initialize_recompute_setting

__all__ = []

Expand Down Expand Up @@ -134,7 +135,10 @@ def __init__(self,
num_stages=None,
topology=None,
loss_fn=None,
seg_method="uniform"):
seg_method="uniform",
recompute_interval=0,
recompute_offload=False,
recompute_partition=False):
super(PipelineLayer, self).__init__()
if num_stages is None and topology is None:
raise ValueError("should provide num_stages or topology")
Expand All @@ -147,6 +151,16 @@ def __init__(self,
self.layers = layers
self._loss_fn = loss_fn
self._topo = topology
self._recompute_interval = recompute_interval
self._recompute_offload = recompute_offload
self._recompute_partition = recompute_partition

if recompute_interval > 0:
logger.info(
"Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}".
format(recompute_offload, recompute_partition))
_initialize_recompute_setting(recompute_offload, recompute_partition)

world_size = dist.get_world_size()
self.global_rank = dist.get_rank()

Expand Down Expand Up @@ -312,11 +326,44 @@ def _build_layer(self):
else:
self.run_function.append(layer)

def forward_function(self, start, end):
    """Build a callable that runs ``self.run_function[start:end]`` in order.

    Args:
        start: Index of the first entry of ``self.run_function`` to run.
        end: Index one past the last entry to run.

    Returns:
        A function taking positional inputs. A single positional argument is
        handed to the first layer unchanged; several positional arguments
        are handed over together as one tuple. Each layer's output becomes
        the next layer's input, and the last output is returned. With an
        empty slice the (possibly packed) input is returned as is.
    """

    def execute_func(*x):
        # Unwrap a lone argument so layers receive the value, not a 1-tuple.
        if len(x) == 1:
            x = x[0]
        # The slice is taken at call time, so it reflects the current
        # contents of self.run_function.
        for layer in self.run_function[start:end]:
            x = layer(x)
        return x

    return execute_func

def forward(self, input):
    """Run every entry of ``self.run_function`` on ``input`` exactly once.

    When ``self._recompute_interval`` is not positive, all layers run as a
    single segment. Otherwise the layers are split into consecutive segments
    of ``self._recompute_interval`` entries (the last one may be shorter),
    and each segment is dispatched through ``_hp_recompute`` when
    ``self._need_recompute`` says so, or called directly when it does not.

    Args:
        input: Value fed to the first layer. In the segmented path a
            non-tuple value is wrapped into a 1-tuple before each segment
            and unpacked as positional arguments.

    Returns:
        The output of the last segment, or ``input`` unchanged when there
        are no layers to run.
    """
    # `<= 0` rather than `== 0`: a negative interval would give `range` a
    # negative step and an empty iteration, silently skipping every layer.
    # This also matches __init__, which enables recompute only for `> 0`.
    if self._recompute_interval <= 0:
        return self.forward_function(0, len(self.run_function))(input)

    num_layers = len(self.run_function)
    for start_idx in range(0, num_layers, self._recompute_interval):
        end_idx = min(start_idx + self._recompute_interval, num_layers)
        funcs = self.run_function[start_idx:end_idx]

        # Segments are invoked with *input, so normalise to a tuple first.
        if not isinstance(input, tuple):
            input = (input, )

        segment = self.forward_function(start_idx, end_idx)
        if self._need_recompute(funcs, input):
            # NOTE(review): _hp_recompute is defined in pp_utils.utils and
            # is not visible here; presumably it re-runs `segment` during
            # backward instead of keeping activations -- confirm.
            input = _hp_recompute(segment, *input)
        else:
            input = segment(*input)

    return input

def _need_recompute(self, funcs, inputs):
if not any(input_.stop_gradient == False for input_ in inputs
if isinstance(input_, paddle.Tensor)):
return False

params = [f.parameters() for f in funcs if isinstance(f, Layer)]
return any(len(list(p)) > 0 for p in params)

def save_state_dict(self, path):
if self._topo.get_coord(self.global_rank).data != 0:
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

MODEL_PARALLEL_RNG = 'model_parallel_rng'

# This file is adapted from Megatron-LM's approach to controlling random states under model parallelism (MP):
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/random.py


class RNGStatesTracker:
"""
Expand All @@ -46,6 +49,15 @@ def add(self, name, seed):
self.states_[name] = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(orig_rng_state)

def get_states_tracker(self):
    """Return a shallow copy of the tracked RNG states.

    Returns:
        A new ``dict`` mapping each registered name to its stored state.
        Adding or removing keys on the result does not affect the tracker;
        the state objects themselves are shared, not copied.
    """
    return dict(self.states_)

def set_states_tracker(self, states):
    """Replace the tracked RNG states with ``states``.

    Args:
        states: Mapping from state name to stored RNG state. It is kept by
            reference (no copy is made), so later changes the caller makes
            to this mapping are visible through the tracker.
    """
    self.states_ = states

@contextlib.contextmanager
def rng_state(self, name=MODEL_PARALLEL_RNG):
if name not in self.states_:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor
from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
Expand Down Expand Up @@ -48,6 +48,8 @@ def __init__(self, layers, hcg, strategy):

p2p.initialize_p2p_groups(hcg)

_initialize_recompute_hcg(hcg)

self.is_first_stage = self.stage_id == 0
self.is_last_stage = (self.stage_id == (self.num_stages - 1))
self.global_rank = self._hcg.get_global_rank()
Expand Down Expand Up @@ -213,6 +215,9 @@ def _load_micro_batch(self, cache_id):
if self.is_first_stage:
assert len(inputs) == 2, "length of input should be 2"
if isinstance(inputs[0], tuple):
assert len(
inputs[0]
) > 1, "If you use tuple for input data, it should have at least two inputs."
batch_size = inputs[0][0].shape[0]
assert self.micro_batch_size * self.accumulate_steps == batch_size, (
"batch_size needs to be divisible by micro_batch_size. Currently, "
Expand Down
Loading