From 9ea98416a58b6e2f34d4f5ceda8f81c8e97d2926 Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Wed, 21 Jul 2021 20:18:55 +0800 Subject: [PATCH 01/14] support 1f1b --- .../fleet/meta_parallel/pipeline_parallel.py | 518 +++--------------- .../pp_utils/p2p_communication.py | 177 +++++- 2 files changed, 256 insertions(+), 439 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 9f2a4aaffb4745..7468230331c524 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -85,31 +85,7 @@ def __init__(self, layers, hcg, strategy): if self.use_data_parallel: logger.info("start broadcast dp parameters") broadcast_dp_parameters(self._layers, self._hcg) - - def _init_caches(self, num_caches): - if self.num_caches >= num_caches: - return - self.num_caches = num_caches - self.num_caches - for key in self.caches: - self.caches[key].extend([None] * self.num_caches) - - def _reduce_final_loss(self): - if self.is_last_stage: - assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss" - loss = self.total_loss.clone() / self.accumulate_steps - paddle.distributed.broadcast( - loss, - src=self.global_rank, - use_calc_stream=True, - group=self.pp_group) - else: - loss = paddle.to_tensor(0.0) - paddle.distributed.broadcast( - loss, - src=self._hcg.get_rank_from_stage(self.num_stages - 1), - use_calc_stream=True, - group=self.pp_group) - return loss + self.data_id = 0 def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): assert isinstance(optimizer, HybridParallelOptimizer), ( @@ -131,113 +107,100 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): # store total loss of entire batch self.total_loss = None - self._init_caches(self.accumulate_steps) - startup_steps = self.num_stages - self.stage_id - 1 - forward_steps = 0 - backward_steps = 0 - - # forward - while (forward_steps < self.accumulate_steps): - self._forward(cache_id=forward_steps) - forward_steps += 1 - - # backward - while (backward_steps < self.accumulate_steps): - self._backward(cache_id=backward_steps) - backward_steps += 1 - - self._layers.allreduce_shared_weight_gradients() - - # optimizer - self.train_loss = self._reduce_final_loss() - self._step() - return self.train_loss - - def _forward(self, cache_id): - # load data - self._load_micro_batch(cache_id) - if self.stage_id != 0: - self._recv_activations(cache_id) - - if isinstance(self.caches['inputs'][cache_id], tuple): - inputs = tuple(t for t in self.caches['inputs'][cache_id]) - else: - inputs = self.caches['inputs'][cache_id] - self._clear_grads(inputs) - outputs = self._layers.forward(inputs) + self.micro_batch_size = self._strategy.pipeline_configs[ + 'micro_batch_size'] + self.accumulate_steps = self._strategy.pipeline_configs[ + 'accumulate_steps'] - self.caches['outputs'][cache_id] = outputs + self.num_stages = self._hcg.get_pipe_parallel_world_size() - if self.is_last_stage: - if self._layers._loss_fn is not None: - labels = self.caches['labels'][cache_id] - outputs = self._layers._loss_fn(outputs, labels) + # Compute number of warmup microbatches. + num_microbatches = self.accumulate_steps + num_warmup_microbatches = (self.num_stages - self.stage_id - 1) + num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) + num_microbatches_remaining = num_microbatches - num_warmup_microbatches - if self.is_last_stage: - self.current_loss = outputs - if isinstance(self.current_loss, paddle.Tensor): - if self.total_loss is None: - self.total_loss = paddle.zeros_like(self.current_loss) - self.total_loss += self.current_loss.detach() - else: - if self.total_loss is None: - self.total_loss = [ - paddle.zeros_like(v) for v in self.current_loss - ] - for idx, v in enumerate(self.current_loss): - self.total_loss[idx] += v.detach() + input_tensors = [] + output_tensors = [] + losses_reduced = [] - if self.accumulate_steps > 1: - self.current_loss = self.current_loss / self.accumulate_steps + for step_id in range(num_warmup_microbatches): + input_tensor = p2p.recv_forward() + output_tensor = self._forward_step(input_tensor) + p2p.send_forward(output_tensor) - self.caches['outputs'][cache_id] = self.current_loss.clone() + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) - else: - self._send_activations(cache_id) + #print("warmup is endding..") - def _backward(self, cache_id): - if self.is_last_stage: - if self.scaler: - paddle.autograd.backward( - self.scaler.scale(self.caches['outputs'][cache_id])) + if num_microbatches_remaining > 0: + input_tensor = p2p.recv_forward() + + for i in range(num_microbatches_remaining): + last_iteration = (i == (num_microbatches_remaining - 1)) + + if input_tensor is not None: + input_tensor.stop_gradient = False + output_tensor = self._forward_step(input_tensor) + + output_tensor_grad = p2p.send_forward_recv_backward(output_tensor) + + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + + input_tensor, output_tensor = input_tensors.pop( + 0), output_tensors.pop(0) + + #print("input_tensor: ", input_tensor, "output_tensor: ", output_tensor, "output_tensor_grad: ", output_tensor_grad) + input_tensor_grad = \ + self._backward_step(input_tensor, output_tensor, output_tensor_grad) + + #print("input_tensor_grad: ", input_tensor_grad) + if last_iteration: + input_tensor = None + #print("start send backward") + p2p.send_backward(input_tensor_grad) else: - paddle.autograd.backward(self.caches['outputs'][cache_id]) + input_tensor = \ + p2p.send_backward_recv_forward(input_tensor_grad) - self._send_gradients(cache_id) - return - self._recv_gradients(cache_id) + for i in range(num_warmup_microbatches): + input_tensor = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) - outputs = self.caches['outputs'][cache_id] + output_tensor_grad = p2p.recv_backward() - grad_tensors = self.grad_tensors - if isinstance(outputs, tuple): - out_tensors = [t for t in outputs if is_float_tensor(t)] - assert len(out_tensors) == len(grad_tensors) - paddle.autograd.backward( - tensors=out_tensors, grad_tensors=grad_tensors) - else: - paddle.autograd.backward( - tensors=[outputs], grad_tensors=[grad_tensors]) - - grad_tensors = None - if self.stage_id != 0: self._send_gradients(cache_id) - self.caches['outputs'][cache_id] = None - - def _broadcast_data(self, data): - if isinstance(data, paddle.Tensor): - paddle.distributed.broadcast( - data, - src=self._hcg.get_model_parallel_group_src_rank(), - group=self._hcg.get_model_parallel_group()) - else: - for d in data: - assert isinstance(d, paddle.Tensor) - paddle.distributed.broadcast( - d, - src=self._hcg.get_model_parallel_group_src_rank(), - group=self._hcg.get_model_parallel_group()) - return data + input_tensor_grad = \ + self._backward_step(input_tensor, output_tensor, output_tensor_grad) + p2p.send_backward(input_tensor_grad) + + self.data_id = 0 + return 10.0 + + def _forward_step(self, input_tensor): + if self.stage_id == 0: + input_tensor = self._load_micro_batch(self.data_id) + + output_tensor = self._layers.forward(input_tensor) + + if self.is_last_stage: + labels = self._load_micro_batch(self.data_id) + output_tensor = self._layers._loss_fn(output_tensor, labels) + #print("micro batch id: ", self.data_id, "output_tensor: ", output_tensor) + self.data_id += 1 + return output_tensor + + def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): + #print("output_tensor: ", output_tensor, "output_tensor_grad", output_tensor_grad) + paddle.autograd.backward( + tensors=[output_tensor], grad_tensors=[output_tensor_grad]) + input_tensor_grad = None + if input_tensor is not None: + input_tensor_grad = input_tensor.grad + + return input_tensor_grad def _load_micro_batch(self, cache_id): inputs = self.data @@ -246,8 +209,6 @@ def _load_micro_batch(self, cache_id): if self.is_first_stage: assert len(inputs) == 2, "length of input should be 2" - if self.use_model_parallel: - inputs[0] = self._broadcast_data(inputs[0]) if isinstance(inputs[0], tuple): batch_size = inputs[0][0].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size, ( @@ -258,329 +219,24 @@ def _load_micro_batch(self, cache_id): data = [ input[begin:end, :].clone().detach() for input in inputs[0] ] - self.caches['inputs'][cache_id] = tuple(data) + return tuple(data) else: batch_size = inputs[0].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size - self.caches['inputs'][cache_id] = inputs[0][begin:end, :].clone( - ).detach() + return inputs[0][begin:end, :].clone().detach() elif self.is_last_stage: assert len(inputs) == 2, "length of input should be 2" - if self.use_model_parallel: - inputs[1] = self._broadcast_data(inputs[1]) if isinstance(inputs[1], tuple): batch_size = inputs[1][0].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size data = [ input[begin:end, :].clone().detach() for input in inputs[1] ] - self.caches['labels'][cache_id] = tuple(data) + return tuple(data) else: batch_size = inputs[1].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size - self.caches['labels'][cache_id] = inputs[1][begin:end, :].clone( - ).detach() + return inputs[1][begin:end, :].clone().detach() else: # No data input is required for other stages inputs = None - - def _send_meta(self, data, peer): - if isinstance(data, paddle.Tensor): - tensor_type = paddle.to_tensor([0]) - # send tensor type - p2p.send(tensor_type, self.next_stage_id) - - # send len(shape) - dims = paddle.to_tensor(len(data.shape)) - p2p.send(dims, self.next_stage_id) - - # send shape - shape = paddle.to_tensor(data.shape) - p2p.send(shape, self.next_stage_id) - - # send dtype - dtype = paddle.to_tensor(paddle_2_number(data.dtype)) - p2p.send(dtype, self.next_stage_id) - - elif isinstance(data, tuple): - tensor_type = paddle.to_tensor([1]) - p2p.send(tensor_type, self.next_stage_id) - - nums = paddle.to_tensor(len(data)) - p2p.send(nums, self.next_stage_id) - - for idx, d in enumerate(data): - assert isinstance(d, paddle.Tensor) - # send len(shape) - dims = paddle.to_tensor(len(d.shape)) - p2p.send(dims, self.next_stage_id) - - # send shape - shape = paddle.to_tensor(d.shape) - p2p.send(shape, self.next_stage_id) - - # send dtype - dtype = paddle.to_tensor(paddle_2_number(d.dtype)) - p2p.send(dtype, self.next_stage_id) - - def _recv_meta(self, peer): - tensor_type = paddle.to_tensor([0]) - p2p.recv(tensor_type, self.prev_stage_id) - - tensor_type = tensor_type.item() - - if tensor_type == 0: - # recv len(shape) - dims = paddle.to_tensor([0]) - p2p.recv(dims, self.prev_stage_id) - - dims = dims.item() - - # recv shape - shape = paddle.to_tensor([0] * dims) - p2p.recv(shape, self.prev_stage_id) - - shape = shape.numpy().tolist() - - # recv dtype - dtype = paddle.to_tensor([0]) - p2p.recv(dtype, self.prev_stage_id) - - return self._allocate_cache( - shape, dtype=number_2_dtype(dtype.item()), num_caches=1)[0] - elif tensor_type == 1: - num = paddle.to_tensor([0]) - p2p.recv(num, self.prev_stage_id) - num = num.item() - shapes = [] - dtypes = [] - for i in range(num): - # recv len(shape) - dims = paddle.to_tensor([0]) - p2p.recv(dims, self.prev_stage_id) - - # recv shape - dims = dims.item() - shape = paddle.to_tensor([0] * dims) - p2p.recv(shape, self.prev_stage_id) - shapes.append(shape.numpy().tolist()) - - # recv dtype - dtype = paddle.to_tensor([0]) - p2p.recv(dtype, self.prev_stage_id) - dtypes.append(number_2_dtype(dtype.item())) - - caches = self._allocate_caches(shapes, dtypes, num_caches=1)[0] - caches = tuple(caches) - return caches - - def _is_valid_send_recv(self, tensor): - tensor_numel = np.prod(tensor.shape) - assert tensor_numel != 0, "can't send/recv zero element" - return tensor_numel % self.mp_degree == 0 - - def _send_activations(self, cache_id): - outputs = self.caches['outputs'][cache_id] - - if self.send_meta: - self.send_meta = False - self._send_meta(outputs, self.next_stage_id) - - if isinstance(outputs, paddle.Tensor): - if self.is_pipe_partitioned and self._is_valid_send_recv(outputs): - p2p.send_partial( - outputs.detach(), - self.next_stage_id, - mp_degree=self.mp_degree, - mp_rank=self.mp_rank) - else: - p2p.send(outputs.detach(), self.next_stage_id) - - elif isinstance(outputs, tuple): - for output in outputs: - if self.is_pipe_partitioned and self._is_valid_send_recv( - output): - p2p.send_partial( - output.detach(), - self.next_stage_id, - mp_degree=self.mp_degree, - mp_rank=self.mp_rank) - else: - p2p.send(output.detach(), self.next_stage_id) - - def _send_gradients(self, cache_id): - inputs = self.caches['inputs'][cache_id] - if isinstance(inputs, paddle.Tensor): - assert inputs.grad is not None - if self.is_pipe_partitioned and self._is_valid_send_recv( - inputs.grad): - grad = p2p.send_partial( - inputs.grad, - self.prev_stage_id, - mp_degree=self.mp_degree, - mp_rank=self.mp_rank) - else: - p2p.send(inputs.grad, self.prev_stage_id) - else: - for idx, d in enumerate(inputs): - # Skip tensors that will not produce a grad - if not is_float_tensor(d): - assert d.grad is None - continue - - if self.is_pipe_partitioned and self._is_valid_send_recv( - d.grad): - grad = p2p.send_partial( - d.grad, - self.prev_stage_id, - mp_degree=self.mp_degree, - mp_rank=self.mp_rank) - else: - p2p.send(d.grad, self.prev_stage_id) - - self.caches['inputs'][cache_id] = None - - def _recv_activations(self, cache_id): - inputs = None - if self.recv_cache is None: - self.recv_cache = self._recv_meta(self.prev_stage_id) - - if isinstance(self.recv_cache, paddle.Tensor): - if self.is_pipe_partitioned and self._is_valid_send_recv( - self.recv_cache): - p2p.recv_partial(self.recv_cache, self.prev_stage_id, - self.mp_degree, self.mp_rank) - p2p.partial_allgather_operator( - self.recv_cache, - mp_ranks=self.mp_degree, - mp_rank_id=self.mp_rank, - group=self._hcg.get_model_parallel_group(), - use_calc_stream=True) - else: - p2p.recv(self.recv_cache, self.prev_stage_id) - - inputs = self.recv_cache.clone().detach() - inputs.stop_gradient = not is_float_tensor(inputs) - - else: - assert isinstance(self.recv_cache, tuple) - inputs = [None] * len(self.recv_cache) - for idx, d in enumerate(self.recv_cache): - if self.is_pipe_partitioned and self._is_valid_send_recv(d): - assert isinstance(d, paddle.Tensor) - p2p.recv_partial(d, self.prev_stage_id, self.mp_degree, - self.mp_rank) - p2p.partial_allgather_operator( - d, - mp_ranks=self.mp_degree, - mp_rank_id=self.mp_rank, - group=self._hcg.get_model_parallel_group(), - use_calc_stream=True) - else: - assert isinstance(d, paddle.Tensor) - p2p.recv(d, self.prev_stage_id) - inputs[idx] = d.clone().detach() - - inputs = tuple(inputs) - - for d in inputs: - d.stop_gradient = not is_float_tensor(d) - - self.caches['inputs'][cache_id] = inputs - - def _recv_gradients(self, cache_id): - outputs = self.caches['outputs'][cache_id] - if self.grad_tensors is None: - if isinstance(outputs, paddle.Tensor): - s = list(outputs.shape) - dtype = get_tensor_dtype(outputs.dtype) - self.grad_tensors = self._allocate_cache( - s, dtype, num_caches=1)[0] - else: - sizes = [list(d.shape) for d in outputs if is_float_tensor(d)] - dtypes = [ - get_tensor_dtype(d.dtype) for d in outputs - if is_float_tensor(d) - ] - self.grad_tensors = self._allocate_caches( - sizes, dtypes, num_caches=1)[0] - - if isinstance(self.grad_tensors, paddle.Tensor): - if self.is_pipe_partitioned and self._is_valid_send_recv( - self.grad_tensors): - p2p.recv_partial(self.grad_tensors, self.next_stage_id, - self.mp_degree, self.mp_rank) - p2p.partial_allgather_operator( - self.grad_tensors, - mp_ranks=self.mp_degree, - mp_rank_id=self.mp_rank, - group=self._hcg.get_model_parallel_group(), - use_calc_stream=True) - else: - p2p.recv(self.grad_tensors, self.next_stage_id) - - else: - assert isinstance(outputs, tuple) - for d in self.grad_tensors: - if self.is_pipe_partitioned and self._is_valid_send_recv(d): - p2p.recv_partial(d, self.next_stage_id, self.mp_degree, - self.mp_rank) - p2p.partial_allgather_operator( - d, - mp_ranks=self.mp_degree, - mp_rank_id=self.mp_rank, - group=self._hcg.get_model_parallel_group(), - use_calc_stream=True) - else: - p2p.recv(d, self.next_stage_id) - - def _step(self): - if self.scaler: - self.scaler.minimize(self.optimizer, self.train_loss) - else: - self.optimizer.step() - self.optimizer.clear_grad() - if self.lr_scheduler: - self.lr_scheduler.step() - - def _clear_grads(self, inputs): - if isinstance(inputs, paddle.Tensor): - if inputs.grad is not None: - inputs.clear_gradient() - else: - for d in inputs: - if d.grad is not None: - d.clear_gradient() - - def _allocate_zeros(self, shape, dtype): - return paddle.zeros(shape, dtype) - - def _allocate_cache(self, shape, dtype, num_caches=-1): - caches = [] - if num_caches == -1: - num_caches = self.num_caches - for count in range(num_caches): - caches.append(self._allocate_zeros(shape, dtype)) - return caches - - def _allocate_caches(self, shapes, dtypes, num_caches=-1): - caches = [] - if num_caches == -1: - num_caches = self.num_caches - for count in range(num_caches): - cache = [] - for shape, dtype in zip(shapes, dtypes): - cache.append(self._allocate_zeros(shape, dtype)) - caches.append(cache) - return caches - - def save_state_dict(self, model_path): - state_dict = self._layers.state_dict() - paddle.save(state_dict, model_path) - - def load_state_dict(self, model_path): - state_dict = paddle.load(self.model_path) - self._layers.set_state_dict(state_dict) - - def forward(self, *inputs, **kwargs): - raise RuntimeError("Call train_batch for pipeline instead of forward.") diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 44090be94f1a7d..612db17ad3be12 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -16,6 +16,7 @@ _groups = None _hcg = None +_tensor_shape = (2, 1024, 768) def initialize_p2p_groups(hcg): @@ -132,12 +133,172 @@ def _get_send_recv_group(src_stage, dest_stage): stage_id = None first_stage = 0 last_stage = _hcg.get_pipe_parallel_world_size() - 1 - if (src_stage == first_stage and dest_stage == last_stage) or \ - (dest_stage == first_stage and src_stage == last_stage): - stage_id = last_stage - elif src_stage > dest_stage: - stage_id = dest_stage - else: - stage_id = src_stage - group_id = _hcg.get_rank_from_stage(stage_id=stage_id) + #if (src_stage == first_stage and dest_stage == last_stage) or \ + # (dest_stage == first_stage and src_stage == last_stage): + # stage_id = last_stage + #if src_stage > dest_stage: + # stage_id = dest_stage + #else: + # stage_id = src_stage + #group_id = _hcg.get_rank_from_stage(stage_id=stage_id) + group_id = _hcg.get_rank_from_stage(stage_id=src_stage) return _groups[group_id] + + +def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next): + + tensor_recv_prev = None + tensor_recv_next = None + + global _tensor_shape, _groups, _hcg + tensor_chunk_shape = _tensor_shape + dtype = "float32" + + current_stage = _hcg.get_stage_id() + prev_stage = current_stage - 1 + next_stage = current_stage + 1 + + if recv_prev: + tensor_recv_prev = paddle.empty(shape=tensor_chunk_shape, dtype=dtype) + if recv_next: + tensor_recv_next = paddle.empty(shape=tensor_chunk_shape, dtype=dtype) + + if tensor_send_prev is not None: + group = _get_send_recv_group( + src_stage=current_stage, dest_stage=prev_stage) + #print("group msg:", group) + paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) + paddle.distributed.send( + tensor_send_prev, dst=0, group=group, use_calc_stream=False) + if tensor_recv_prev is not None: + group = _get_send_recv_group( + src_stage=prev_stage, dest_stage=current_stage) + #print("group msg:", group) + paddle.distributed.recv( + tensor_recv_prev, src=0, group=group, use_calc_stream=True) + #print("tensor_recv_prev", tensor_recv_prev.numpy()) + + if tensor_send_next is not None: + group = _get_send_recv_group( + src_stage=current_stage, dest_stage=next_stage) + #print("group msg:", group) + paddle.distributed.wait(tensor_send_next, use_calc_stream=True) + paddle.distributed.send( + tensor_send_next, dst=1, group=group, use_calc_stream=False) + if tensor_recv_next is not None: + group = _get_send_recv_group( + src_stage=next_stage, dest_stage=current_stage) + #print("group msg:", group) + paddle.distributed.recv( + tensor_recv_next, src=1, group=group, use_calc_stream=True) + #print("tensor_recv_next", tensor_recv_next.numpy()) + + return tensor_recv_prev, tensor_recv_next + + +def recv_forward(): + if _hcg.is_first_stage: + input_tensor = None + else: + input_tensor, _ = _communicate( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=True, + recv_next=False) + return input_tensor + + +def recv_backward(): + if _hcg.is_last_stage: + output_tensor_grad = None + else: + _, output_tensor_grad = _communicate( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=False, + recv_next=True) + return output_tensor_grad + + +def send_forward(output_tensor): + if not _hcg.is_last_stage: + _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=False) + + +def send_backward(input_tensor_grad): + if not _hcg.is_first_stage: + _communicate( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=False) + + +def send_forward_recv_backward(output_tensor): + if _hcg.is_last_stage: + output_tensor_grad = None + else: + _, output_tensor_grad = _communicate( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=True) + return output_tensor_grad + + +def send_backward_recv_forward(input_tensor_grad): + if _hcg.is_first_stage: + input_tensor = None + else: + input_tensor, _ = _communicate( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=True, + recv_next=False) + return input_tensor + + +# def send_forward_recv_forward(output_tensor, recv_prev, timers=None): +# """Batched recv from previous rank and send to next rank in pipeline.""" +# if timers is not None: +# timers('forward-send-forward-recv').start() +# input_tensor, _ = _communicate( +# tensor_send_next=output_tensor, +# tensor_send_prev=None, +# recv_prev=recv_prev, +# recv_next=False) +# if timers is not None: +# timers('forward-send-forward-recv').stop() +# return input_tensor + +# def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None): +# """Batched recv from next rank and send to previous rank in pipeline.""" +# if timers is not None: +# timers('backward-send-backward-recv').start() +# _, output_tensor_grad = _communicate( +# tensor_send_next=None, +# tensor_send_prev=input_tensor_grad, +# recv_prev=False, +# recv_next=recv_next) +# if timers is not None: +# timers('backward-send-backward-recv').stop() +# return output_tensor_grad + +# def send_forward_backward_recv_forward_backward( +# output_tensor, input_tensor_grad, recv_prev, +# recv_next, timers=None): +# """Batched send and recv with previous and next ranks in pipeline.""" +# if timers is not None: +# timers('forward-backward-send-forward-backward-recv').start() +# input_tensor, output_tensor_grad = _communicate( +# tensor_send_next=output_tensor, +# tensor_send_prev=input_tensor_grad, +# recv_prev=recv_prev, +# recv_next=recv_next) +# if timers is not None: +# timers('forward-backward-send-forward-backward-recv').stop() +# return input_tensor, output_tensor_grad From 9449f2801bbaec93f1c8c081f239509c272c81e9 Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Mon, 26 Jul 2021 22:03:49 +0800 Subject: [PATCH 02/14] support 1f1b for pipeline --- .../paddle/distributed/fleet/base/topology.py | 46 +++++++++++++++++-- .../fleet/meta_parallel/pipeline_parallel.py | 19 +++++++- .../pp_utils/p2p_communication.py | 44 +++++++++++------- 3 files changed, 88 insertions(+), 21 deletions(-) diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 004b3fb0f666bc..f15b636969e254 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -156,6 +156,16 @@ def __init__(self, topology): self.is_first_stage = (self.stage_id == 0) self.is_last_stage = (self.stage_id == (self._pp_degree - 1)) + # create p2p_groups + self._p2p_groups = self._build_p2p_lists() + if self._pp_degree > 1: + self._set_p2p_group() + print("send_next_group: ", self.send_next_group) + print("send_prev_group: ", self.send_prev_group) + print("recv_next_group: ", self.recv_next_group) + print("recv_prev_group: ", self.recv_prev_group) + + debug_str = "HybridParallelInfo: rank_id: %d, mp_degree: %d, " \ "sharding_degree: %d, pp_degree: %d, dp_degree: %d" % (self.global_rank, self._mp_degree, self._sharding_degree, self._pp_degree, self._dp_degree) @@ -164,9 +174,6 @@ def __init__(self, topology): self._dp_group, self._check_group) logger.info(debug_str) - # create p2p_groups and no new group - self._p2p_groups = self._build_p2p_lists() - global _HYBRID_PARALLEL_GROUP _HYBRID_PARALLEL_GROUP = self @@ -236,6 +243,39 @@ def _set_check_group(self, parallel_method="data"): return parallel_group, parallel_comm_group + def _set_p2p_group(self): + comm_lists = self._topo.get_comm_list('pipe') + + self.send_next_group = None + self.send_prev_group = None + self.recv_next_group = None + self.recv_prev_group = None + for comm_ranks in comm_lists: + assert len(comm_ranks) == self._pp_degree + for idx, rank in enumerate(comm_ranks): + curr_rank = rank + next_rank = comm_ranks[(idx + 1) % self._pp_degree] + prev_rank = comm_ranks[(idx - 1) % self._pp_degree] + next_group = paddle.distributed.new_group( + ranks=[curr_rank, next_rank]) + prev_group = paddle.distributed.new_group( + ranks=[prev_rank, curr_rank]) + + if self.global_rank == curr_rank: + self.send_next_group = next_group + self.send_prev_group = prev_group + elif self.global_rank == next_rank: + self.recv_prev_group = next_group + elif self.global_rank == prev_rank: + self.recv_next_group = prev_group + else: + pass + + assert self.send_next_group is not None + assert self.send_prev_group is not None + assert self.recv_next_group is not None + assert self.recv_prev_group is not None + def topology(self): return self._topo diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 7468230331c524..22a2533c87f6b9 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -121,13 +121,20 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) num_microbatches_remaining = num_microbatches - num_warmup_microbatches + print("num_warmup_microbatches: ", num_warmup_microbatches, + "num_microbatches_remaining: ", num_microbatches_remaining) + input_tensors = [] output_tensors = [] losses_reduced = [] for step_id in range(num_warmup_microbatches): + logger("==recv F==") input_tensor = p2p.recv_forward() + if input_tensor is not None: + input_tensor.stop_gradient = False output_tensor = self._forward_step(input_tensor) + print("==send F==") p2p.send_forward(output_tensor) input_tensors.append(input_tensor) @@ -136,6 +143,7 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): #print("warmup is endding..") if num_microbatches_remaining > 0: + print("==recv F==") input_tensor = p2p.recv_forward() for i in range(num_microbatches_remaining): @@ -145,6 +153,7 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): input_tensor.stop_gradient = False output_tensor = self._forward_step(input_tensor) + print("==send F; recv B==") output_tensor_grad = p2p.send_forward_recv_backward(output_tensor) input_tensors.append(input_tensor) @@ -153,16 +162,20 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): input_tensor, output_tensor = input_tensors.pop( 0), output_tensors.pop(0) - #print("input_tensor: ", input_tensor, "output_tensor: ", output_tensor, "output_tensor_grad: ", output_tensor_grad) + print("input_tensor: ", input_tensor, "output_tensor: ", + output_tensor, "output_tensor_grad: ", output_tensor_grad) input_tensor_grad = \ self._backward_step(input_tensor, output_tensor, output_tensor_grad) - #print("input_tensor_grad: ", input_tensor_grad) + print("input_tensor_grad: ", input_tensor_grad) if last_iteration: input_tensor = None #print("start send backward") + print("==send B==") p2p.send_backward(input_tensor_grad) else: + print("==send B; recv F==") + print("input_tensor_grad: ", input_tensor_grad) input_tensor = \ p2p.send_backward_recv_forward(input_tensor_grad) @@ -170,10 +183,12 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) + print("==recv B==") output_tensor_grad = p2p.recv_backward() input_tensor_grad = \ self._backward_step(input_tensor, output_tensor, output_tensor_grad) + print("==send B==") p2p.send_backward(input_tensor_grad) self.data_id = 0 diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 612db17ad3be12..03c5b8c232dc3c 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -21,10 +21,10 @@ def initialize_p2p_groups(hcg): global _groups, _hcg - _groups = [ - paddle.distributed.new_group(ranks=group) - for group in hcg.get_p2p_groups() - ] + # _groups = [ + # paddle.distributed.new_group(ranks=group) + # for group in hcg.get_p2p_groups() + # ] _hcg = hcg @@ -164,33 +164,45 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next): tensor_recv_next = paddle.empty(shape=tensor_chunk_shape, dtype=dtype) if tensor_send_prev is not None: - group = _get_send_recv_group( - src_stage=current_stage, dest_stage=prev_stage) + # group = _get_send_recv_group( + # src_stage=current_stage, dest_stage=prev_stage) #print("group msg:", group) paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) paddle.distributed.send( - tensor_send_prev, dst=0, group=group, use_calc_stream=False) + tensor_send_prev, + dst=0, + group=_hcg.send_prev_group, + use_calc_stream=False) if tensor_recv_prev is not None: - group = _get_send_recv_group( - src_stage=prev_stage, dest_stage=current_stage) + # group = _get_send_recv_group( + # src_stage=prev_stage, dest_stage=current_stage) #print("group msg:", group) paddle.distributed.recv( - tensor_recv_prev, src=0, group=group, use_calc_stream=True) + tensor_recv_prev, + src=0, + group=_hcg.recv_prev_group, + use_calc_stream=True) #print("tensor_recv_prev", tensor_recv_prev.numpy()) if tensor_send_next is not None: - group = _get_send_recv_group( - src_stage=current_stage, dest_stage=next_stage) + # group = _get_send_recv_group( + # src_stage=current_stage, dest_stage=next_stage) #print("group msg:", group) paddle.distributed.wait(tensor_send_next, use_calc_stream=True) paddle.distributed.send( - tensor_send_next, dst=1, group=group, use_calc_stream=False) + tensor_send_next, + dst=1, + group=_hcg.send_next_group, + use_calc_stream=False) if tensor_recv_next is not None: - group = _get_send_recv_group( - src_stage=next_stage, dest_stage=current_stage) + # group = _get_send_recv_group( + # src_stage=next_stage, dest_stage=current_stage) #print("group msg:", group) paddle.distributed.recv( - tensor_recv_next, src=1, group=group, use_calc_stream=True) + tensor_recv_next, + src=1, + group=_hcg.recv_next_group, + use_calc_stream=True) #print("tensor_recv_next", tensor_recv_next.numpy()) return tensor_recv_prev, tensor_recv_next From e59aea2f7849ca754fb0cc8fbc8b4eb0eb025bd0 Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Mon, 26 Jul 2021 22:38:50 +0800 Subject: [PATCH 03/14] add train_loss for pipeline --- .../fleet/meta_parallel/pipeline_parallel.py | 71 +++++++++++++------ 1 file changed, 49 insertions(+), 22 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 22a2533c87f6b9..335c9e5e61efb3 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -108,6 +108,9 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): # store total loss of entire batch self.total_loss = None + # store data id for micro_batch + self.data_id = 0 + self.micro_batch_size = self._strategy.pipeline_configs[ 'micro_batch_size'] self.accumulate_steps = self._strategy.pipeline_configs[ @@ -121,29 +124,21 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) num_microbatches_remaining = num_microbatches - num_warmup_microbatches - print("num_warmup_microbatches: ", num_warmup_microbatches, - "num_microbatches_remaining: ", num_microbatches_remaining) - input_tensors = [] output_tensors = [] losses_reduced = [] for step_id in range(num_warmup_microbatches): - logger("==recv F==") input_tensor = p2p.recv_forward() if input_tensor is not None: input_tensor.stop_gradient = False output_tensor = self._forward_step(input_tensor) - print("==send F==") p2p.send_forward(output_tensor) input_tensors.append(input_tensor) output_tensors.append(output_tensor) - #print("warmup is endding..") - if num_microbatches_remaining > 0: - print("==recv F==") input_tensor = p2p.recv_forward() for i in range(num_microbatches_remaining): @@ -153,7 +148,6 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): input_tensor.stop_gradient = False output_tensor = self._forward_step(input_tensor) - print("==send F; recv B==") output_tensor_grad = p2p.send_forward_recv_backward(output_tensor) input_tensors.append(input_tensor) @@ -162,20 +156,13 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): input_tensor, output_tensor = input_tensors.pop( 0), output_tensors.pop(0) - print("input_tensor: ", input_tensor, "output_tensor: ", - output_tensor, "output_tensor_grad: ", output_tensor_grad) input_tensor_grad = \ self._backward_step(input_tensor, output_tensor, output_tensor_grad) - print("input_tensor_grad: ", input_tensor_grad) if last_iteration: input_tensor = None - #print("start send backward") - print("==send B==") p2p.send_backward(input_tensor_grad) else: - print("==send B; recv F==") - print("input_tensor_grad: ", input_tensor_grad) input_tensor = \ p2p.send_backward_recv_forward(input_tensor_grad) @@ -183,16 +170,19 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) - print("==recv B==") output_tensor_grad = p2p.recv_backward() input_tensor_grad = \ self._backward_step(input_tensor, output_tensor, output_tensor_grad) - print("==send B==") p2p.send_backward(input_tensor_grad) - self.data_id = 0 - return 10.0 + self._layers.allreduce_shared_weight_gradients() + + # optimizer + self.train_loss = self._reduce_final_loss() + + self._step() + return self.train_loss def _forward_step(self, input_tensor): if self.stage_id == 0: @@ -203,12 +193,21 @@ def _forward_step(self, input_tensor): if self.is_last_stage: labels = self._load_micro_batch(self.data_id) output_tensor = self._layers._loss_fn(output_tensor, labels) - #print("micro batch id: ", self.data_id, "output_tensor: ", output_tensor) + assert isinstance( + output_tensor, paddle. + Tensor), "Currently, loss_fn should obtain Paddle.Tensor dtype" + + if self.accumulate_steps > 1: + output_tensor = output_tensor / self.accumulate_steps + + if self.total_loss is None: + self.total_loss = paddle.zeros_like(output_tensor) + self.total_loss += output_tensor.detach() + self.data_id += 1 return output_tensor def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): - #print("output_tensor: ", output_tensor, "output_tensor_grad", output_tensor_grad) paddle.autograd.backward( tensors=[output_tensor], grad_tensors=[output_tensor_grad]) input_tensor_grad = None @@ -255,3 +254,31 @@ def _load_micro_batch(self, cache_id): else: # No data input is required for other stages inputs = None + + def _reduce_final_loss(self): + if self.is_last_stage: + assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss" + loss = self.total_loss.clone() + paddle.distributed.broadcast( + loss, + src=self.global_rank, + use_calc_stream=True, + group=self.pp_group) + else: + loss = paddle.to_tensor(0.0) + paddle.distributed.broadcast( + loss, + src=self._hcg.get_rank_from_stage(self.num_stages - 1), + use_calc_stream=True, + group=self.pp_group) + return loss + + def _step(self): + if self.scaler: + self.scaler.minimize(self.optimizer, self.train_loss) + else: + self.optimizer.step() + + self.optimizer.clear_grad() + if self.lr_scheduler: + self.lr_scheduler.step() From 092854217f4545a307b2bbff20c99b993332ebda Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Tue, 27 Jul 2021 11:55:49 +0800 Subject: [PATCH 04/14] rm part of code --- .../paddle/distributed/fleet/base/topology.py | 12 ++-- .../fleet/meta_parallel/pipeline_parallel.py | 56 ++++++++++--------- 2 files changed, 36 insertions(+), 32 deletions(-) diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index f15b636969e254..6cb5eb971fa31b 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -250,26 +250,28 @@ def _set_p2p_group(self): self.send_prev_group = None self.recv_next_group = None self.recv_prev_group = None + for comm_ranks in comm_lists: assert len(comm_ranks) == self._pp_degree for idx, rank in enumerate(comm_ranks): curr_rank = rank next_rank = comm_ranks[(idx + 1) % self._pp_degree] prev_rank = comm_ranks[(idx - 1) % self._pp_degree] + next_group = paddle.distributed.new_group( ranks=[curr_rank, next_rank]) + if self.global_rank == curr_rank: + self.send_next_group = next_group + elif self.global_rank == next_rank: + self.recv_prev_group = next_group + prev_group = paddle.distributed.new_group( ranks=[prev_rank, curr_rank]) if self.global_rank == curr_rank: - self.send_next_group = next_group self.send_prev_group = prev_group - elif self.global_rank == next_rank: - self.recv_prev_group = next_group elif self.global_rank == prev_rank: self.recv_next_group = prev_group - else: - pass assert self.send_next_group is not None assert self.send_prev_group is not None diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 335c9e5e61efb3..c395a20eb4843d 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -29,6 +29,22 @@ __all__ = [] +class SendRecvMeta: + def __init__(self): + # next meta + self.next_shape_message = None + self.next_dtype_messgae = None + + self.prev_shape_message = None + self.prev_dtype_message = None + + def ready_for_next(self): + return self.next_shape_message is not None and self.next_dtype_messgae is not None + + def ready_for_prev(self): + return self.next_shape_message is not None and self.next_dtype_messgae is not None + + class PipelineParallel(MetaParallelBase): def __init__(self, layers, hcg, strategy): if not isinstance(layers, PipelineLayer): @@ -41,21 +57,10 @@ def __init__(self, layers, hcg, strategy): self.is_pipe_partitioned = self.use_model_parallel - self.num_caches = 0 - self.caches = { - 'inputs': [], - 'labels': [], - 'outputs': [], - } - - self.recv_cache = None - self.grad_tensors = None - - self.send_meta = True - - self.current_loss = paddle.to_tensor(0.0) self.total_loss = None + self.send_recv_meta = SendRecvMeta() + self.micro_batch_size = self._strategy.pipeline_configs[ 'micro_batch_size'] self.accumulate_steps = self._strategy.pipeline_configs[ @@ -66,6 +71,7 @@ def __init__(self, layers, hcg, strategy): self.prev_stage_id = self.stage_id - 1 self.next_stage_id = self.stage_id + 1 self.pp_group = self._hcg.get_pipe_parallel_group() + p2p.initialize_p2p_groups(hcg) self.is_first_stage = self.stage_id == 0 @@ -90,19 +96,22 @@ def __init__(self, layers, hcg, strategy): def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): assert isinstance(optimizer, HybridParallelOptimizer), ( 'optimizer should be HybridParallelOptimizer subclass.') - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - self.scaler = scaler + assert isinstance(scaler, (None, HybridParallelGradScaler)), ( + 'scaler should be HybridParallelGradScaler subclass or None.') assert fluid.framework._dygraph_tracer()._has_grad, ( 'Please enable the generation of gradients.') if self.is_first_stage or self.is_last_stage: assert data is not None, ( - "For the first and the last stage, the data_iter must be set.") + "For the first and the last stage, the data must be set.") else: data = None + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.scaler = scaler self.data = data + self._layers.train() # store total loss of entire batch @@ -111,13 +120,6 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): # store data id for micro_batch self.data_id = 0 - self.micro_batch_size = self._strategy.pipeline_configs[ - 'micro_batch_size'] - self.accumulate_steps = self._strategy.pipeline_configs[ - 'accumulate_steps'] - - self.num_stages = self._hcg.get_pipe_parallel_world_size() - # Compute number of warmup microbatches. num_microbatches = self.accumulate_steps num_warmup_microbatches = (self.num_stages - self.stage_id - 1) @@ -178,10 +180,10 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): self._layers.allreduce_shared_weight_gradients() - # optimizer self.train_loss = self._reduce_final_loss() - self._step() + # optimizer + self._optimizer_step() return self.train_loss def _forward_step(self, input_tensor): @@ -273,7 +275,7 @@ def _reduce_final_loss(self): group=self.pp_group) return loss - def _step(self): + def _optimizer_step(self): if self.scaler: self.scaler.minimize(self.optimizer, self.train_loss) else: From f076ac01db952e6af8f0e0b30f2798269a64a9d6 Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Tue, 27 Jul 2021 18:32:24 +0800 Subject: [PATCH 05/14] add send_recv_meta --- .../fleet/meta_parallel/pipeline_parallel.py | 25 +-- .../pp_utils/p2p_communication.py | 206 +++++++----------- 2 files changed, 82 insertions(+), 149 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index c395a20eb4843d..b0162751cb8003 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -23,28 +23,12 @@ from ..utils.hybrid_parallel_util import broadcast_mp_parameters from ..utils.hybrid_parallel_util import broadcast_dp_parameters from ..utils.log_util import logger -from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer +from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler from .pp_utils import p2p_communication as p2p __all__ = [] -class SendRecvMeta: - def __init__(self): - # next meta - self.next_shape_message = None - self.next_dtype_messgae = None - - self.prev_shape_message = None - self.prev_dtype_message = None - - def ready_for_next(self): - return self.next_shape_message is not None and self.next_dtype_messgae is not None - - def ready_for_prev(self): - return self.next_shape_message is not None and self.next_dtype_messgae is not None - - class PipelineParallel(MetaParallelBase): def __init__(self, layers, hcg, strategy): if not isinstance(layers, PipelineLayer): @@ -59,8 +43,6 @@ def __init__(self, layers, hcg, strategy): self.total_loss = None - self.send_recv_meta = SendRecvMeta() - self.micro_batch_size = self._strategy.pipeline_configs[ 'micro_batch_size'] self.accumulate_steps = self._strategy.pipeline_configs[ @@ -96,8 +78,9 @@ def __init__(self, layers, hcg, strategy): def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): assert isinstance(optimizer, HybridParallelOptimizer), ( 'optimizer should be HybridParallelOptimizer subclass.') - assert isinstance(scaler, (None, HybridParallelGradScaler)), ( - 'scaler should be HybridParallelGradScaler subclass or None.') + if scaler is not None: + assert isinstance(scaler, HybridParallelGradScaler), ( + 'scaler should be HybridParallelGradScaler subclass or None.') assert fluid.framework._dygraph_tracer()._has_grad, ( 'Please enable the generation of gradients.') diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 03c5b8c232dc3c..8163d4ab173e9b 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -13,73 +13,77 @@ # limitations under the License. import paddle +from .utils import is_float_tensor, get_tensor_dtype, paddle_2_number, number_2_dtype _groups = None _hcg = None _tensor_shape = (2, 1024, 768) -def initialize_p2p_groups(hcg): - global _groups, _hcg - # _groups = [ - # paddle.distributed.new_group(ranks=group) - # for group in hcg.get_p2p_groups() - # ] - _hcg = hcg +class SendRecvMeta: + def __init__(self): + self.send_shape_message = None + self.send_dtype_messgae = None + self.recv_shape_message = None + self.recv_dtype_message = None -def _is_valid_communciate(src_stage, dest_stage): - first_stage = 0 - last_stage = _hcg.get_pipe_parallel_world_size() - 1 - assert abs(src_stage-dest_stage) == 1 or \ - (src_stage == first_stage and dest_stage == last_stage) or \ - (src_stage == last_stage and dest_stage == first_stage) + self.has_send_meta = False + self.has_recv_meta = False + def recv_meta(self, group): + tensor_type = paddle.to_tensor([0]) + paddle.distributed.recv(tensor_type, src=0, group=group) -def partial_send_operator(tensor, - dst=0, - mp_ranks=1, - mp_rank_id=0, - group=None, - use_calc_stream=True): + tensor_type = tensor_type.item() + if tensor_type == 0: + # recv len(shape) + dims = paddle.to_tensor([0]) + paddle.distributed.recv(dims, src=0, group=group) + dims = dims.item() - if group is not None and not group.is_member(): - return - ring_id = 0 if group is None else group.id - return paddle.fluid.core.ops.partial_send( - tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer', - dst, 'num', mp_ranks, 'id', mp_rank_id) + # recv shape + shape = paddle.to_tensor([0] * dims) + paddle.distributed.recv(shape, src=0, group=group) + shape = shape.numpy().tolist() + # recv dtype + dtype = paddle.to_tensor([0]) + paddle.distributed.recv(dtype, src=0, group=group) -def partial_recv_operator(tensor, - src=0, - mp_ranks=1, - mp_rank_id=0, - group=None, - use_calc_stream=True): + self.recv_shape_message = shape + self.recv_dtype_message = dtype.item() - if group is not None and not group.is_member(): - return - ring_id = 0 if group is None else group.id + def send_meta(self, tensor, group): + if isinstance(tensor, paddle.Tensor): + tensor_type = paddle.to_tensor([0]) - return paddle.fluid.core.ops.partial_recv( - tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer', - src, 'num', mp_ranks, 'id', mp_rank_id, 'dtype', tensor.dtype, - 'out_shape', tensor.shape) + # send tensor type + paddle.distributed.send(tensor_type, dst=1, group=group) + # send len(shape) + dims = paddle.to_tensor(len(tensor.shape)) + paddle.distributed.send(dims, dst=1, group=group) -def partial_allgather_operator(tensor, - mp_ranks=1, - mp_rank_id=0, - group=None, - use_calc_stream=True): - if group is not None and not group.is_member(): - return - ring_id = 0 if group is None else group.id + # send shape + shape = paddle.to_tensor(tensor.shape) + paddle.distributed.send(shape, dst=1, group=group) - return paddle.fluid.core.ops.partial_allgather_( - tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, - 'nranks', mp_ranks, 'rank', mp_rank_id) + # send dtype + dtype = paddle.to_tensor(paddle_2_number(tensor.dtype)) + paddle.distributed.send(dtype, dst=1, group=group) + + def set_send_message(self, tensor): + self.send_shape_message = tensor.shape + self.send_dtype_message = paddle_2_number(tensor.dtype) + + +_send_recv_meta = SendRecvMeta() + + +def initialize_p2p_groups(hcg): + global _groups, _hcg + _hcg = hcg def send(tensor, dest_stage): @@ -101,46 +105,11 @@ def recv(tensor, src_stage): tensor, src=0 if dest_stage > src_stage else 1, group=group) -def send_partial(tensor, dest_stage, mp_degree, mp_rank): - global _groups, _hcg - src_stage = _hcg.get_stage_id() - _is_valid_communciate(src_stage, dest_stage) - group = _get_send_recv_group(src_stage, dest_stage) - return partial_send_operator( - tensor, - dst=1 if dest_stage > src_stage else 0, - mp_ranks=mp_degree, - mp_rank_id=mp_rank, - group=group) - - -def recv_partial(tensor, src_stage, mp_degree, mp_rank): - global _groups, _hcg - dest_stage = _hcg.get_stage_id() - - _is_valid_communciate(src_stage, dest_stage) - group = _get_send_recv_group(src_stage, dest_stage) - return partial_recv_operator( - tensor, - src=0 if dest_stage > src_stage else 1, - mp_ranks=mp_degree, - mp_rank_id=mp_rank, - group=group) - - def _get_send_recv_group(src_stage, dest_stage): global _groups, _hcg stage_id = None first_stage = 0 last_stage = _hcg.get_pipe_parallel_world_size() - 1 - #if (src_stage == first_stage and dest_stage == last_stage) or \ - # (dest_stage == first_stage and src_stage == last_stage): - # stage_id = last_stage - #if src_stage > dest_stage: - # stage_id = dest_stage - #else: - # stage_id = src_stage - #group_id = _hcg.get_rank_from_stage(stage_id=stage_id) group_id = _hcg.get_rank_from_stage(stage_id=src_stage) return _groups[group_id] @@ -159,9 +128,16 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next): next_stage = current_stage + 1 if recv_prev: - tensor_recv_prev = paddle.empty(shape=tensor_chunk_shape, dtype=dtype) + # tensor_recv_prev = paddle.empty(shape=tensor_chunk_shape, dtype=dtype) + tensor_recv_prev = paddle.empty( + shape=_send_recv_meta.recv_shape_message, + dtype=number_2_dtype(_send_recv_meta.recv_dtype_message)) + if recv_next: - tensor_recv_next = paddle.empty(shape=tensor_chunk_shape, dtype=dtype) + # tensor_recv_next = paddle.empty(shape=tensor_chunk_shape, dtype=dtype) + tensor_recv_next = paddle.empty( + shape=_send_recv_meta.send_shape_message, + dtype=number_2_dtype(_send_recv_meta.send_dtype_message)) if tensor_send_prev is not None: # group = _get_send_recv_group( @@ -212,6 +188,11 @@ def recv_forward(): if _hcg.is_first_stage: input_tensor = None else: + # check recv forward + if not _send_recv_meta.has_recv_meta: + _send_recv_meta.recv_meta(_hcg.recv_prev_group) + _send_recv_meta.has_recv_meta = True + input_tensor, _ = _communicate( tensor_send_next=None, tensor_send_prev=None, @@ -234,6 +215,12 @@ def recv_backward(): def send_forward(output_tensor): if not _hcg.is_last_stage: + + if not _send_recv_meta.has_send_meta: + _send_recv_meta.set_send_message(output_tensor) + _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group) + _send_recv_meta.has_send_meta = True + _communicate( tensor_send_next=output_tensor, tensor_send_prev=None, @@ -254,6 +241,11 @@ def send_forward_recv_backward(output_tensor): if _hcg.is_last_stage: output_tensor_grad = None else: + if not _send_recv_meta.has_send_meta: + _send_recv_meta.set_send_message(output_tensor) + _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group) + _send_recv_meta.has_send_meta = True + _, output_tensor_grad = _communicate( tensor_send_next=output_tensor, tensor_send_prev=None, @@ -272,45 +264,3 @@ def send_backward_recv_forward(input_tensor_grad): recv_prev=True, recv_next=False) return input_tensor - - -# def send_forward_recv_forward(output_tensor, recv_prev, timers=None): -# """Batched recv from previous rank and send to next rank in pipeline.""" -# if timers is not None: -# timers('forward-send-forward-recv').start() -# input_tensor, _ = _communicate( -# tensor_send_next=output_tensor, -# tensor_send_prev=None, -# recv_prev=recv_prev, -# recv_next=False) -# if timers is not None: -# timers('forward-send-forward-recv').stop() -# return input_tensor - -# def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None): -# """Batched recv from next rank and send to previous rank in pipeline.""" -# if timers is not None: -# timers('backward-send-backward-recv').start() -# _, output_tensor_grad = _communicate( -# tensor_send_next=None, -# tensor_send_prev=input_tensor_grad, -# recv_prev=False, -# recv_next=recv_next) -# if timers is not None: -# timers('backward-send-backward-recv').stop() -# return output_tensor_grad - -# def send_forward_backward_recv_forward_backward( -# output_tensor, input_tensor_grad, recv_prev, -# recv_next, timers=None): -# """Batched send and recv with previous and next ranks in pipeline.""" -# if timers is not None: -# timers('forward-backward-send-forward-backward-recv').start() -# input_tensor, output_tensor_grad = _communicate( -# tensor_send_next=output_tensor, -# tensor_send_prev=input_tensor_grad, -# recv_prev=recv_prev, -# recv_next=recv_next) -# if timers is not None: -# timers('forward-backward-send-forward-backward-recv').stop() -# return input_tensor, output_tensor_grad From c636a3199ac798901624483c033c413ed05895b2 Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Wed, 28 Jul 2021 11:09:49 +0800 Subject: [PATCH 06/14] support tuple --- .../pp_utils/p2p_communication.py | 247 ++++++++++-------- 1 file changed, 132 insertions(+), 115 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 8163d4ab173e9b..93b2de423380e3 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -23,7 +23,7 @@ class SendRecvMeta: def __init__(self): self.send_shape_message = None - self.send_dtype_messgae = None + self.send_dtype_message = None self.recv_shape_message = None self.recv_dtype_message = None @@ -31,51 +31,85 @@ def __init__(self): self.has_send_meta = False self.has_recv_meta = False + def _recv_shape_dtype(self, group): + # recv len(shape) + dims = paddle.to_tensor([0]) + paddle.distributed.recv(dims, src=0, group=group) + dims = dims.item() + + # recv shape + shape = paddle.to_tensor([0] * dims) + paddle.distributed.recv(shape, src=0, group=group) + + # recv dtype + dtype = paddle.to_tensor([0]) + paddle.distributed.recv(dtype, src=0, group=group) + return shape.numpy().tolist(), dtype.item() + def recv_meta(self, group): tensor_type = paddle.to_tensor([0]) paddle.distributed.recv(tensor_type, src=0, group=group) - tensor_type = tensor_type.item() - if tensor_type == 0: - # recv len(shape) - dims = paddle.to_tensor([0]) - paddle.distributed.recv(dims, src=0, group=group) - dims = dims.item() - - # recv shape - shape = paddle.to_tensor([0] * dims) - paddle.distributed.recv(shape, src=0, group=group) - shape = shape.numpy().tolist() - - # recv dtype - dtype = paddle.to_tensor([0]) - paddle.distributed.recv(dtype, src=0, group=group) + if tensor_type == 0: + shape, dtype = self._recv_shape_dtype(group) self.recv_shape_message = shape - self.recv_dtype_message = dtype.item() + self.recv_dtype_message = dtype + + elif tensor_type == 1: + num = paddle.to_tensor([0]) + paddle.distributed.recv(num, src=0, group=group) + num = num.item() + shapes = [] + dtypes = [] + for i in range(num): + shape, dtype = self._recv_shape_dtype() + shapes.append(shape) + dtypes.append(dtype) + + self.recv_shape_message = tuple(shapes) + self.recv_dtype_message = tuple(dtypes) + + def _send_dims_shape_dtype(self, tensor, group): + # send len(shape) + dims = paddle.to_tensor(len(tensor.shape)) + paddle.distributed.send(dims, dst=1, group=group) + + # send shape + shape = paddle.to_tensor(tensor.shape) + paddle.distributed.send(shape, dst=1, group=group) + + # send dtype + dtype = paddle.to_tensor(paddle_2_number(tensor.dtype)) + paddle.distributed.send(dtype, dst=1, group=group) def send_meta(self, tensor, group): if isinstance(tensor, paddle.Tensor): tensor_type = paddle.to_tensor([0]) - # send tensor type paddle.distributed.send(tensor_type, dst=1, group=group) - # send len(shape) - dims = paddle.to_tensor(len(tensor.shape)) - paddle.distributed.send(dims, dst=1, group=group) + self._send_dims_shape_dtype(tensor, group) + elif isinstance(tensor, tuple): + tensor_type = paddle.to_tensor([1]) + # send tensor type + paddle.distributed.send(tensor_type, dst=1, group=group) - # send shape - shape = paddle.to_tensor(tensor.shape) - paddle.distributed.send(shape, dst=1, group=group) + nums = paddle.to_tensor(len(tensor)) + paddle.distributed.send(nums, dst=1, group=group) - # send dtype - dtype = paddle.to_tensor(paddle_2_number(tensor.dtype)) - paddle.distributed.send(dtype, dst=1, group=group) + for d in tensor: + assert isinstance(d, paddle.Tensor) + self._send_dims_shape_dtype(d) def set_send_message(self, tensor): - self.send_shape_message = tensor.shape - self.send_dtype_message = paddle_2_number(tensor.dtype) + if isinstance(tensor, paddle.Tensor): + self.send_shape_message = tensor.shape + self.send_dtype_message = paddle_2_number(tensor.dtype) + elif isinstance(tensor, tuple): + self.send_shape_message = tuple([d.shape for d in tensor]) + self.send_dtype_message = tuple( + [paddle_2_number(d.dtype) for d in tensor]) _send_recv_meta = SendRecvMeta() @@ -86,100 +120,90 @@ def initialize_p2p_groups(hcg): _hcg = hcg -def send(tensor, dest_stage): - global _groups, _hcg - src_stage = _hcg.get_stage_id() - _is_valid_communciate(src_stage, dest_stage) - group = _get_send_recv_group(src_stage, dest_stage) - return paddle.distributed.send( - tensor, dst=1 if dest_stage > src_stage else 0, group=group) - - -def recv(tensor, src_stage): - global _groups, _hcg - dest_stage = _hcg.get_stage_id() - - _is_valid_communciate(src_stage, dest_stage) - group = _get_send_recv_group(src_stage, dest_stage) - return paddle.distributed.recv( - tensor, src=0 if dest_stage > src_stage else 1, group=group) - - -def _get_send_recv_group(src_stage, dest_stage): - global _groups, _hcg - stage_id = None - first_stage = 0 - last_stage = _hcg.get_pipe_parallel_world_size() - 1 - group_id = _hcg.get_rank_from_stage(stage_id=src_stage) - return _groups[group_id] - - def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next): + global _groups, _hcg tensor_recv_prev = None tensor_recv_next = None - global _tensor_shape, _groups, _hcg - tensor_chunk_shape = _tensor_shape - dtype = "float32" - - current_stage = _hcg.get_stage_id() - prev_stage = current_stage - 1 - next_stage = current_stage + 1 + recv_shape_msg = _send_recv_meta.recv_shape_message + recv_dtype_msg = _send_recv_meta.recv_dtype_message + send_shape_msg = _send_recv_meta.send_shape_message + send_dtype_msg = _send_recv_meta.send_dtype_message if recv_prev: - # tensor_recv_prev = paddle.empty(shape=tensor_chunk_shape, dtype=dtype) - tensor_recv_prev = paddle.empty( - shape=_send_recv_meta.recv_shape_message, - dtype=number_2_dtype(_send_recv_meta.recv_dtype_message)) + if isinstance(recv_shape_msg, tuple): + for idx, shape in enumerate(recv_shape_msg): + tensor_recv_prev = tuple([ + paddle.empty( + shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])) + ]) + else: + tensor_recv_prev = paddle.empty( + shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg)) if recv_next: - # tensor_recv_next = paddle.empty(shape=tensor_chunk_shape, dtype=dtype) - tensor_recv_next = paddle.empty( - shape=_send_recv_meta.send_shape_message, - dtype=number_2_dtype(_send_recv_meta.send_dtype_message)) + if isinstance(send_shape_msg, tuple): + for idx, shape in enumerate(send_shape_msg): + tensor_recv_next = tuple([ + paddle.empty( + shape=shape, dtype=number_2_dtype(send_dtype_msg[idx])) + ]) + else: + tensor_recv_next = paddle.empty( + shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg)) if tensor_send_prev is not None: - # group = _get_send_recv_group( - # src_stage=current_stage, dest_stage=prev_stage) - #print("group msg:", group) - paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) - paddle.distributed.send( - tensor_send_prev, - dst=0, - group=_hcg.send_prev_group, - use_calc_stream=False) + if isinstance(tensor_send_prev, tuple): + for d in tensor_send_prev: + paddle.distributed.wait(d, use_calc_stream=True) + paddle.distributed.send( + d, dst=0, group=_hcg.send_prev_group, use_calc_stream=False) + else: + paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) + paddle.distributed.send( + tensor_send_prev, + dst=0, + group=_hcg.send_prev_group, + use_calc_stream=False) + if tensor_recv_prev is not None: - # group = _get_send_recv_group( - # src_stage=prev_stage, dest_stage=current_stage) - #print("group msg:", group) - paddle.distributed.recv( - tensor_recv_prev, - src=0, - group=_hcg.recv_prev_group, - use_calc_stream=True) - #print("tensor_recv_prev", tensor_recv_prev.numpy()) + if isinstance(tensor_recv_prev, tuple): + for d in tensor_recv_prev: + paddle.distributed.recv( + d, src=0, group=_hcg.recv_prev_group, use_calc_stream=True) + else: + paddle.distributed.recv( + tensor_recv_prev, + src=0, + group=_hcg.recv_prev_group, + use_calc_stream=True) if tensor_send_next is not None: - # group = _get_send_recv_group( - # src_stage=current_stage, dest_stage=next_stage) - #print("group msg:", group) - paddle.distributed.wait(tensor_send_next, use_calc_stream=True) - paddle.distributed.send( - tensor_send_next, - dst=1, - group=_hcg.send_next_group, - use_calc_stream=False) + if isinstance(tensor_send_next, tuple): + for d in tensor_send_next: + paddle.distributed.wait(d, use_calc_stream=True) + paddle.distributed.send( + d, dst=1, group=_hcg.send_next_group, use_calc_stream=False) + else: + paddle.distributed.wait(tensor_send_next, use_calc_stream=True) + paddle.distributed.send( + tensor_send_next, + dst=1, + group=_hcg.send_next_group, + use_calc_stream=False) + if tensor_recv_next is not None: - # group = _get_send_recv_group( - # src_stage=next_stage, dest_stage=current_stage) - #print("group msg:", group) - paddle.distributed.recv( - tensor_recv_next, - src=1, - group=_hcg.recv_next_group, - use_calc_stream=True) - #print("tensor_recv_next", tensor_recv_next.numpy()) + if isinstance(tensor_recv_next, tuple): + for d in tensor_recv_next: + paddle.distributed.recv( + d, src=1, group=_hcg.recv_next_group, use_calc_stream=True) + else: + paddle.distributed.recv( + tensor_recv_next, + src=1, + group=_hcg.recv_next_group, + use_calc_stream=True) return tensor_recv_prev, tensor_recv_next @@ -188,7 +212,6 @@ def recv_forward(): if _hcg.is_first_stage: input_tensor = None else: - # check recv forward if not _send_recv_meta.has_recv_meta: _send_recv_meta.recv_meta(_hcg.recv_prev_group) _send_recv_meta.has_recv_meta = True @@ -215,7 +238,6 @@ def recv_backward(): def send_forward(output_tensor): if not _hcg.is_last_stage: - if not _send_recv_meta.has_send_meta: _send_recv_meta.set_send_message(output_tensor) _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group) @@ -241,11 +263,6 @@ def send_forward_recv_backward(output_tensor): if _hcg.is_last_stage: output_tensor_grad = None else: - if not _send_recv_meta.has_send_meta: - _send_recv_meta.set_send_message(output_tensor) - _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group) - _send_recv_meta.has_send_meta = True - _, output_tensor_grad = _communicate( tensor_send_next=output_tensor, tensor_send_prev=None, From 1a55ccb8350bf53884a7a5fd00c2415bd7f79bb2 Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Wed, 28 Jul 2021 22:29:17 +0800 Subject: [PATCH 07/14] add utest --- .../fleet/meta_parallel/pipeline_parallel.py | 47 ++++- .../pp_utils/p2p_communication.py | 22 ++- .../hybrid_parallel_pp_transformer.py | 187 ++++++++++++++++++ 3 files changed, 238 insertions(+), 18 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index b0162751cb8003..b0a938ba58aeb6 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -75,6 +75,18 @@ def __init__(self, layers, hcg, strategy): broadcast_dp_parameters(self._layers, self._hcg) self.data_id = 0 + def _set_tensor_trainable(self, tensor): + if tensor is None: + return + + if isinstance(tensor, tuple): + for t in tensor: + if is_float_tensor(t): + t.stop_gradient = False + else: + if is_float_tensor(tensor): + tensor.stop_gradient = False + def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): assert isinstance(optimizer, HybridParallelOptimizer), ( 'optimizer should be HybridParallelOptimizer subclass.') @@ -115,8 +127,8 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): for step_id in range(num_warmup_microbatches): input_tensor = p2p.recv_forward() - if input_tensor is not None: - input_tensor.stop_gradient = False + self._set_tensor_trainable(input_tensor) + output_tensor = self._forward_step(input_tensor) p2p.send_forward(output_tensor) @@ -129,8 +141,7 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): for i in range(num_microbatches_remaining): last_iteration = (i == (num_microbatches_remaining - 1)) - if input_tensor is not None: - input_tensor.stop_gradient = False + self._set_tensor_trainable(input_tensor) output_tensor = self._forward_step(input_tensor) output_tensor_grad = p2p.send_forward_recv_backward(output_tensor) @@ -193,12 +204,32 @@ def _forward_step(self, input_tensor): return output_tensor def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): - paddle.autograd.backward( - tensors=[output_tensor], grad_tensors=[output_tensor_grad]) + if self.is_last_stage: + assert output_tensor_grad is None + paddle.autograd.backward( + tensors=[output_tensor], grad_tensors=[None]) + else: + if isinstance(output_tensor, tuple): + outputs = [t for t in output_tensor if not t.stop_gradient] + assert len(outputs) == len(output_tensor_grad) + print("outputs: ", type(outputs), len(outputs)) + print("output_tensor_grad: ", type(output_tensor_grad), + len(output_tensor_grad)) + print(output_tensor_grad) + paddle.autograd.backward( + tensors=outputs, + grad_tensors=[t for t in output_tensor_grad]) + else: + paddle.autograd.backward( + tensors=[output_tensor], grad_tensors=[output_tensor_grad]) + input_tensor_grad = None if input_tensor is not None: - input_tensor_grad = input_tensor.grad - + if isinstance(input_tensor, tuple): + input_tensor_grad = tuple( + [t.grad for t in input_tensor if not t.stop_gradient]) + else: + input_tensor_grad = input_tensor.grad return input_tensor_grad def _load_micro_batch(self, cache_id): diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 93b2de423380e3..6316b6cf2006c7 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -17,7 +17,6 @@ _groups = None _hcg = None -_tensor_shape = (2, 1024, 768) class SendRecvMeta: @@ -63,7 +62,7 @@ def recv_meta(self, group): shapes = [] dtypes = [] for i in range(num): - shape, dtype = self._recv_shape_dtype() + shape, dtype = self._recv_shape_dtype(group) shapes.append(shape) dtypes.append(dtype) @@ -100,14 +99,15 @@ def send_meta(self, tensor, group): for d in tensor: assert isinstance(d, paddle.Tensor) - self._send_dims_shape_dtype(d) + self._send_dims_shape_dtype(d, group=group) def set_send_message(self, tensor): if isinstance(tensor, paddle.Tensor): self.send_shape_message = tensor.shape self.send_dtype_message = paddle_2_number(tensor.dtype) elif isinstance(tensor, tuple): - self.send_shape_message = tuple([d.shape for d in tensor]) + self.send_shape_message = tuple( + [d.shape for d in tensor if not d.stop_gradient]) self.send_dtype_message = tuple( [paddle_2_number(d.dtype) for d in tensor]) @@ -133,22 +133,24 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next): if recv_prev: if isinstance(recv_shape_msg, tuple): + tensor_recv_prev = [] for idx, shape in enumerate(recv_shape_msg): - tensor_recv_prev = tuple([ + tensor_recv_prev.append( paddle.empty( - shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])) - ]) + shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx]))) + tensor_recv_prev = tuple(tensor_recv_prev) else: tensor_recv_prev = paddle.empty( shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg)) if recv_next: if isinstance(send_shape_msg, tuple): + tensor_recv_next = [] for idx, shape in enumerate(send_shape_msg): - tensor_recv_next = tuple([ + tensor_recv_next.append( paddle.empty( - shape=shape, dtype=number_2_dtype(send_dtype_msg[idx])) - ]) + shape=shape, dtype=number_2_dtype(send_dtype_msg[idx]))) + tensor_recv_next = tuple(tensor_recv_next) else: tensor_recv_next = paddle.empty( shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg)) diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py new file mode 100644 index 00000000000000..e7510591578aed --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py @@ -0,0 +1,187 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import unittest +import paddle +import numpy as np +import random +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from hybrid_parallel_pp_layer import AlexNetPipeDesc, AlexNet +from paddle.fluid.dygraph.layers import Layer +from paddle.fluid import layers +import paddle.nn.functional as F + +import paddle +import numpy as np +import random +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from paddle.fluid.dygraph.container import Sequential +from paddle.distributed.fleet.meta_parallel import PipelineLayer, LayerDesc +from paddle.fluid.dygraph.layers import Layer +import paddle.nn as nn +import paddle.fluid as fluid + + +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducability.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + dp_id) + + +batch_size = 4 +length = 8 +micro_batch_size = 2 +vocab_size = 128 +hidden_size = 3 +d_model = hidden_size +dim_feedforward = 4 * d_model + + +class EmbeddingNet(Layer): + def __init__(self): + super(EmbeddingNet, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(vocab_size, hidden_size) + + def forward(self, x): + attention_mask = paddle.tensor.triu( + (paddle.ones( + (length, length), dtype="float32") * -1e9), 1) + attention_mask.stop_gradient = True + w_emb = self.word_embeddings(x) + p_emb = self.position_embeddings(x) + + return w_emb, attention_mask, p_emb.detach() + + +class TransformerNet(Layer): + def __init__(self): + super(TransformerNet, self).__init__() + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.q_proj = nn.Linear(d_model, d_model) + self.k_proj = nn.Linear(d_model, d_model) + self.v_proj = nn.Linear(d_model, d_model) + + self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5) + + def forward(self, x, mask): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5) + + weights = F.softmax(product + mask) + weights = F.dropout(weights, 0.2) + tgt = layers.matmul(weights, v) + residual = tgt + tgt = self.norm1(tgt) + tgt = residual + tgt + + out = self.linear2(F.gelu(self.linear1(tgt), approximate=True)) + return out + + +class EmbeddingPipe(EmbeddingNet): + def forward(self, x): + return super().forward(x) + + +class TransformerNetPipe(TransformerNet): + def forward(self, args): + x, mask, p_emb = args[0], args[1], args[2] + + output = super().forward(x, mask) + output = output + p_emb + mask.stop_gradient = True + return output, mask, p_emb + + +class CriterionPipe(Layer): + def __init__(self): + super(CriterionPipe, self).__init__() + + def forward(self, out, label): + loss = out.mean() + return loss + + +class ModelPipe(PipelineLayer): + def __init__(self, topology): + self.descs = [] + self.descs.append(LayerDesc(EmbeddingPipe)) + + for x in range(4): + self.descs.append(LayerDesc(TransformerNetPipe)) + + self.descs.append(lambda x: x[0]) + + super().__init__( + layers=self.descs, loss_fn=CriterionPipe(), topology=topology) + + +class TestDistPPTraning(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = { + "accumulate_steps": batch_size // micro_batch_size, + "micro_batch_size": micro_batch_size + } + fleet.init(is_collective=True, strategy=strategy) + + def test_pp_model(self): + hcg = fleet.get_hybrid_communicate_group() + word_size = hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + pp_id = hcg.get_stage_id() + rank_id = dist.get_rank() + topology = hcg.topology() + set_random_seed(1024, dp_id, rank_id) + + model = ModelPipe(topology) + scheduler = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True) + optimizer = paddle.optimizer.SGD(learning_rate=scheduler, + parameters=model.parameters()) + + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + + for step_id in range(5): + x_data = np.random.randint(0, vocab_size, size=[batch_size, length]) + x = paddle.to_tensor(x_data) + x.stop_gradient = True + loss = model.train_batch([x, x], optimizer, scheduler) + print("loss: ", loss.numpy()) + + +if __name__ == "__main__": + unittest.main() From 7a2af576d5dd75962e025b0a9a171ef7338f85b9 Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Thu, 29 Jul 2021 16:44:08 +0800 Subject: [PATCH 08/14] add send_partial/recv_partial --- .../paddle/distributed/fleet/base/topology.py | 19 --- .../fleet/meta_parallel/pipeline_parallel.py | 17 +-- .../pp_utils/p2p_communication.py | 115 +++++++++++++++--- .../hybrid_parallel_pp_transformer.py | 30 ++--- 4 files changed, 116 insertions(+), 65 deletions(-) diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 6cb5eb971fa31b..623a3efaf26304 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -157,7 +157,6 @@ def __init__(self, topology): self.is_last_stage = (self.stage_id == (self._pp_degree - 1)) # create p2p_groups - self._p2p_groups = self._build_p2p_lists() if self._pp_degree > 1: self._set_p2p_group() print("send_next_group: ", self.send_next_group) @@ -177,21 +176,6 @@ def __init__(self, topology): global _HYBRID_PARALLEL_GROUP _HYBRID_PARALLEL_GROUP = self - def _build_p2p_lists(self): - comm_lists = self._topo.get_comm_list('pipe') - p2p_lists = [] - for rank in range(self.nranks): - for comm_ranks in comm_lists: - assert len(comm_ranks) == self._pp_degree - if rank in comm_ranks: - idx = comm_ranks.index(rank) - next_rank = comm_ranks[(idx + 1) % self._pp_degree] - p2p_lists.append([rank, next_rank]) - break - assert len( - p2p_lists) == self.nranks, "len(p2p_lists) should be equal nranks" - return p2p_lists - def get_parallel_mode(self): # there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel # NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and @@ -346,9 +330,6 @@ def get_sharding_parallel_group_src_rank(self): # TODO should the src rank related to the shard rank for each parameter ? return self._sharding_comm_group.ranks[0] - def get_p2p_groups(self): - return self._p2p_groups - # check parallel group def get_check_parallel_group(self): return self._check_comm_group diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index b0a938ba58aeb6..7c1424d3194949 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -50,8 +50,6 @@ def __init__(self, layers, hcg, strategy): self.num_stages = self._hcg.get_pipe_parallel_world_size() self.stage_id = self._hcg.get_stage_id() - self.prev_stage_id = self.stage_id - 1 - self.next_stage_id = self.stage_id + 1 self.pp_group = self._hcg.get_pipe_parallel_group() p2p.initialize_p2p_groups(hcg) @@ -73,7 +71,8 @@ def __init__(self, layers, hcg, strategy): if self.use_data_parallel: logger.info("start broadcast dp parameters") broadcast_dp_parameters(self._layers, self._hcg) - self.data_id = 0 + + self.micro_batch_id = 0 def _set_tensor_trainable(self, tensor): if tensor is None: @@ -113,7 +112,7 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): self.total_loss = None # store data id for micro_batch - self.data_id = 0 + self.micro_batch_id = 0 # Compute number of warmup microbatches. num_microbatches = self.accumulate_steps @@ -182,12 +181,12 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): def _forward_step(self, input_tensor): if self.stage_id == 0: - input_tensor = self._load_micro_batch(self.data_id) + input_tensor = self._load_micro_batch(self.micro_batch_id) output_tensor = self._layers.forward(input_tensor) if self.is_last_stage: - labels = self._load_micro_batch(self.data_id) + labels = self._load_micro_batch(self.micro_batch_id) output_tensor = self._layers._loss_fn(output_tensor, labels) assert isinstance( output_tensor, paddle. @@ -200,7 +199,7 @@ def _forward_step(self, input_tensor): self.total_loss = paddle.zeros_like(output_tensor) self.total_loss += output_tensor.detach() - self.data_id += 1 + self.micro_batch_id += 1 return output_tensor def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): @@ -212,10 +211,6 @@ def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): if isinstance(output_tensor, tuple): outputs = [t for t in output_tensor if not t.stop_gradient] assert len(outputs) == len(output_tensor_grad) - print("outputs: ", type(outputs), len(outputs)) - print("output_tensor_grad: ", type(output_tensor_grad), - len(output_tensor_grad)) - print(output_tensor_grad) paddle.autograd.backward( tensors=outputs, grad_tensors=[t for t in output_tensor_grad]) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 6316b6cf2006c7..930f8203e8c8dd 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -15,7 +15,6 @@ import paddle from .utils import is_float_tensor, get_tensor_dtype, paddle_2_number, number_2_dtype -_groups = None _hcg = None @@ -115,22 +114,78 @@ def set_send_message(self, tensor): _send_recv_meta = SendRecvMeta() +def send_partial(tensor, + dst=0, + nranks=1, + rank_id=0, + group=None, + use_calc_stream=True): + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + return paddle.fluid.core.ops.partial_send( + tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer', + dst, 'num', nranks, 'id', rank_id) + + +def recv_partial(tensor, + src=0, + nranks=1, + rank_id=0, + group=None, + use_calc_stream=True): + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + + paddle.fluid.core.ops.partial_recv( + tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer', + src, 'num', nranks, 'id', rank_id, 'dtype', tensor.dtype, 'out_shape', + tensor.shape) + + if nranks > 1: + pass + + +def allgather_partial(tensor, + nranks=1, + rank_id=0, + group=None, + use_calc_stream=True): + if nranks == 1: + return + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + + return paddle.fluid.core.ops.partial_allgather_( + tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, + 'nranks', nranks, 'rank', rank_id) + + def initialize_p2p_groups(hcg): - global _groups, _hcg + global _hcg _hcg = hcg -def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next): - global _groups, _hcg +def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): + global _hcg tensor_recv_prev = None tensor_recv_next = None + # send / recv message recv_shape_msg = _send_recv_meta.recv_shape_message recv_dtype_msg = _send_recv_meta.recv_dtype_message send_shape_msg = _send_recv_meta.send_shape_message send_dtype_msg = _send_recv_meta.send_dtype_message + # model parallel message + + mp_group = _hcg.get_model_parallel_group() + mp_degree = self._hcg.get_model_parallel_world_size() + mp_rank = self._hcg.get_model_parallel_rank() + if recv_prev: if isinstance(recv_shape_msg, tuple): tensor_recv_prev = [] @@ -155,25 +210,54 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next): tensor_recv_next = paddle.empty( shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg)) + # start to p2p communicate if tensor_send_prev is not None: if isinstance(tensor_send_prev, tuple): for d in tensor_send_prev: paddle.distributed.wait(d, use_calc_stream=True) - paddle.distributed.send( - d, dst=0, group=_hcg.send_prev_group, use_calc_stream=False) + send_partial( + d, + dst=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_prev_group, + use_calc_stream=False) + + # paddle.distributed.send( + # d, dst=0, group=_hcg.send_prev_group, use_calc_stream=False) else: paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) - paddle.distributed.send( + # paddle.distributed.send( + # tensor_send_prev, + # dst=0, + # group=_hcg.send_prev_group, + # use_calc_stream=False) + send_partial( tensor_send_prev, dst=0, + nranks=mp_degree, + rank_id=mp_rank, group=_hcg.send_prev_group, use_calc_stream=False) if tensor_recv_prev is not None: if isinstance(tensor_recv_prev, tuple): for d in tensor_recv_prev: - paddle.distributed.recv( - d, src=0, group=_hcg.recv_prev_group, use_calc_stream=True) + # paddle.distributed.recv( + # d, src=0, group=_hcg.recv_prev_group, use_calc_stream=True) + recv_partial( + d, + src=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_prev_group, + use_calc_stream=True) + allgather_partial( + d, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) else: paddle.distributed.recv( tensor_recv_prev, @@ -185,6 +269,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next): if isinstance(tensor_send_next, tuple): for d in tensor_send_next: paddle.distributed.wait(d, use_calc_stream=True) + paddle.distributed.send( d, dst=1, group=_hcg.send_next_group, use_calc_stream=False) else: @@ -218,7 +303,7 @@ def recv_forward(): _send_recv_meta.recv_meta(_hcg.recv_prev_group) _send_recv_meta.has_recv_meta = True - input_tensor, _ = _communicate( + input_tensor, _ = _p2p_helper( tensor_send_next=None, tensor_send_prev=None, recv_prev=True, @@ -230,7 +315,7 @@ def recv_backward(): if _hcg.is_last_stage: output_tensor_grad = None else: - _, output_tensor_grad = _communicate( + _, output_tensor_grad = _p2p_helper( tensor_send_next=None, tensor_send_prev=None, recv_prev=False, @@ -245,7 +330,7 @@ def send_forward(output_tensor): _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group) _send_recv_meta.has_send_meta = True - _communicate( + _p2p_helper( tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=False, @@ -254,7 +339,7 @@ def send_forward(output_tensor): def send_backward(input_tensor_grad): if not _hcg.is_first_stage: - _communicate( + _p2p_helper( tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=False, @@ -265,7 +350,7 @@ def send_forward_recv_backward(output_tensor): if _hcg.is_last_stage: output_tensor_grad = None else: - _, output_tensor_grad = _communicate( + _, output_tensor_grad = _p2p_helper( tensor_send_next=output_tensor, tensor_send_prev=None, recv_prev=False, @@ -277,7 +362,7 @@ def send_backward_recv_forward(input_tensor_grad): if _hcg.is_first_stage: input_tensor = None else: - input_tensor, _ = _communicate( + input_tensor, _ = _p2p_helper( tensor_send_next=None, tensor_send_prev=input_tensor_grad, recv_prev=True, diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py index e7510591578aed..84971f2bc35571 100644 --- a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py @@ -19,25 +19,13 @@ import paddle import numpy as np import random -import paddle import paddle.distributed as dist import paddle.distributed.fleet as fleet -from hybrid_parallel_pp_layer import AlexNetPipeDesc, AlexNet -from paddle.fluid.dygraph.layers import Layer from paddle.fluid import layers import paddle.nn.functional as F - -import paddle -import numpy as np -import random -import paddle -import paddle.distributed as dist -import paddle.distributed.fleet as fleet -from paddle.fluid.dygraph.container import Sequential from paddle.distributed.fleet.meta_parallel import PipelineLayer, LayerDesc from paddle.fluid.dygraph.layers import Layer import paddle.nn as nn -import paddle.fluid as fluid def set_random_seed(seed, dp_id, rank_id): @@ -47,11 +35,11 @@ def set_random_seed(seed, dp_id, rank_id): paddle.seed(seed + dp_id) -batch_size = 4 +batch_size = 8 length = 8 micro_batch_size = 2 vocab_size = 128 -hidden_size = 3 +hidden_size = 16 d_model = hidden_size dim_feedforward = 4 * d_model @@ -69,8 +57,10 @@ def forward(self, x): attention_mask.stop_gradient = True w_emb = self.word_embeddings(x) p_emb = self.position_embeddings(x) + w_emb = w_emb + p_emb - return w_emb, attention_mask, p_emb.detach() + # need to fix bug of backward() + return w_emb, attention_mask class TransformerNet(Layer): @@ -109,12 +99,12 @@ def forward(self, x): class TransformerNetPipe(TransformerNet): def forward(self, args): - x, mask, p_emb = args[0], args[1], args[2] + x, mask = args[0], args[1] output = super().forward(x, mask) - output = output + p_emb + output = output mask.stop_gradient = True - return output, mask, p_emb + return output, mask class CriterionPipe(Layer): @@ -131,7 +121,7 @@ def __init__(self, topology): self.descs = [] self.descs.append(LayerDesc(EmbeddingPipe)) - for x in range(4): + for x in range(5): self.descs.append(LayerDesc(TransformerNetPipe)) self.descs.append(lambda x: x[0]) @@ -180,7 +170,7 @@ def test_pp_model(self): x = paddle.to_tensor(x_data) x.stop_gradient = True loss = model.train_batch([x, x], optimizer, scheduler) - print("loss: ", loss.numpy()) + # TODO(shenliang03) add utest for loss if __name__ == "__main__": From 31d1b66042934279db4772f32ae6161b6f0d7ca9 Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Thu, 29 Jul 2021 20:19:35 +0800 Subject: [PATCH 09/14] send/recv --- .../fleet/meta_parallel/pipeline_parallel.py | 46 ++++++------- .../pp_utils/p2p_communication.py | 69 ++++++++++++------- 2 files changed, 65 insertions(+), 50 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 7c1424d3194949..cf0d45aa1f30a5 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -115,16 +115,15 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): self.micro_batch_id = 0 # Compute number of warmup microbatches. - num_microbatches = self.accumulate_steps - num_warmup_microbatches = (self.num_stages - self.stage_id - 1) - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_microbatches_remaining = num_microbatches - num_warmup_microbatches + # self.accumulate_steps = self.accumulate_steps + startup_steps = (self.num_stages - self.stage_id - 1) + startup_steps = min(startup_steps, self.accumulate_steps) + steady_steps = self.accumulate_steps - startup_steps input_tensors = [] output_tensors = [] - losses_reduced = [] - for step_id in range(num_warmup_microbatches): + for step_id in range(startup_steps): input_tensor = p2p.recv_forward() self._set_tensor_trainable(input_tensor) @@ -134,11 +133,11 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): input_tensors.append(input_tensor) output_tensors.append(output_tensor) - if num_microbatches_remaining > 0: + if steady_steps > 0: input_tensor = p2p.recv_forward() - for i in range(num_microbatches_remaining): - last_iteration = (i == (num_microbatches_remaining - 1)) + for i in range(steady_steps): + last_iteration = (i == (steady_steps - 1)) self._set_tensor_trainable(input_tensor) output_tensor = self._forward_step(input_tensor) @@ -151,24 +150,23 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): input_tensor, output_tensor = input_tensors.pop( 0), output_tensors.pop(0) - input_tensor_grad = \ - self._backward_step(input_tensor, output_tensor, output_tensor_grad) + input_tensor_grad = self._backward_step(input_tensor, output_tensor, + output_tensor_grad) if last_iteration: input_tensor = None p2p.send_backward(input_tensor_grad) else: - input_tensor = \ - p2p.send_backward_recv_forward(input_tensor_grad) + input_tensor = p2p.send_backward_recv_forward(input_tensor_grad) - for i in range(num_warmup_microbatches): + for i in range(startup_steps): input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) output_tensor_grad = p2p.recv_backward() - input_tensor_grad = \ - self._backward_step(input_tensor, output_tensor, output_tensor_grad) + input_tensor_grad = self._backward_step(input_tensor, output_tensor, + output_tensor_grad) p2p.send_backward(input_tensor_grad) self._layers.allreduce_shared_weight_gradients() @@ -241,27 +239,23 @@ def _load_micro_batch(self, cache_id): "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d." % (batch_size, self.micro_batch_size, self.accumulate_steps)) - data = [ - input[begin:end, :].clone().detach() for input in inputs[0] - ] + data = [input[begin:end, :].detach() for input in inputs[0]] return tuple(data) else: batch_size = inputs[0].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size - return inputs[0][begin:end, :].clone().detach() + return inputs[0][begin:end, :].detach() elif self.is_last_stage: assert len(inputs) == 2, "length of input should be 2" if isinstance(inputs[1], tuple): batch_size = inputs[1][0].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size - data = [ - input[begin:end, :].clone().detach() for input in inputs[1] - ] + data = [input[begin:end, :].detach() for input in inputs[1]] return tuple(data) else: batch_size = inputs[1].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size - return inputs[1][begin:end, :].clone().detach() + return inputs[1][begin:end, :].detach() else: # No data input is required for other stages inputs = None @@ -269,14 +263,14 @@ def _load_micro_batch(self, cache_id): def _reduce_final_loss(self): if self.is_last_stage: assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss" - loss = self.total_loss.clone() + loss = self.total_loss.detach() paddle.distributed.broadcast( loss, src=self.global_rank, use_calc_stream=True, group=self.pp_group) else: - loss = paddle.to_tensor(0.0) + loss = paddle.zeros(shape=[1], dtype="float32") paddle.distributed.broadcast( loss, src=self._hcg.get_rank_from_stage(self.num_stages - 1), diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 930f8203e8c8dd..fffb0f6514f85d 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -143,9 +143,6 @@ def recv_partial(tensor, src, 'num', nranks, 'id', rank_id, 'dtype', tensor.dtype, 'out_shape', tensor.shape) - if nranks > 1: - pass - def allgather_partial(tensor, nranks=1, @@ -153,7 +150,7 @@ def allgather_partial(tensor, group=None, use_calc_stream=True): if nranks == 1: - return + return tensor if group is not None and not group.is_member(): return ring_id = 0 if group is None else group.id @@ -181,10 +178,9 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): send_dtype_msg = _send_recv_meta.send_dtype_message # model parallel message - mp_group = _hcg.get_model_parallel_group() - mp_degree = self._hcg.get_model_parallel_world_size() - mp_rank = self._hcg.get_model_parallel_rank() + mp_degree = _hcg.get_model_parallel_world_size() + mp_rank = _hcg.get_model_parallel_rank() if recv_prev: if isinstance(recv_shape_msg, tuple): @@ -222,16 +218,8 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): rank_id=mp_rank, group=_hcg.send_prev_group, use_calc_stream=False) - - # paddle.distributed.send( - # d, dst=0, group=_hcg.send_prev_group, use_calc_stream=False) else: paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) - # paddle.distributed.send( - # tensor_send_prev, - # dst=0, - # group=_hcg.send_prev_group, - # use_calc_stream=False) send_partial( tensor_send_prev, dst=0, @@ -243,8 +231,6 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): if tensor_recv_prev is not None: if isinstance(tensor_recv_prev, tuple): for d in tensor_recv_prev: - # paddle.distributed.recv( - # d, src=0, group=_hcg.recv_prev_group, use_calc_stream=True) recv_partial( d, src=0, @@ -259,39 +245,74 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): group=mp_group, use_calc_stream=True) else: - paddle.distributed.recv( + recv_partial( tensor_recv_prev, src=0, + nranks=mp_degree, + rank_id=mp_rank, group=_hcg.recv_prev_group, use_calc_stream=True) + allgather_partial( + tensor_recv_prev, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) if tensor_send_next is not None: if isinstance(tensor_send_next, tuple): for d in tensor_send_next: paddle.distributed.wait(d, use_calc_stream=True) - paddle.distributed.send( - d, dst=1, group=_hcg.send_next_group, use_calc_stream=False) + send_partial( + d, + dst=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_next_group, + use_calc_stream=False) else: paddle.distributed.wait(tensor_send_next, use_calc_stream=True) - paddle.distributed.send( + send_partial( tensor_send_next, dst=1, + nranks=mp_degree, + rank_id=mp_rank, group=_hcg.send_next_group, use_calc_stream=False) if tensor_recv_next is not None: if isinstance(tensor_recv_next, tuple): for d in tensor_recv_next: - paddle.distributed.recv( - d, src=1, group=_hcg.recv_next_group, use_calc_stream=True) + recv_partial( + d, + src=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_next_group, + use_calc_stream=True) + allgather_partial( + d, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) + else: - paddle.distributed.recv( + recv_partial( tensor_recv_next, src=1, + nranks=mp_degree, + rank_id=mp_rank, group=_hcg.recv_next_group, use_calc_stream=True) + allgather_partial( + tensor_recv_next, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) return tensor_recv_prev, tensor_recv_next From d6abdb72a1a7cbc5b27cddd5a11f12f8d987adcf Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Thu, 29 Jul 2021 20:32:17 +0800 Subject: [PATCH 10/14] send/recv --- .../fleet/meta_parallel/pipeline_parallel.py | 44 ++++++++----------- .../pp_utils/p2p_communication.py | 5 ++- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index cf0d45aa1f30a5..b5447a3833ecda 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -11,13 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -import numpy as np - import paddle import paddle.fluid as fluid from .meta_parallel_base import MetaParallelBase -from .pp_utils.utils import is_float_tensor, get_tensor_dtype, paddle_2_number, number_2_dtype -from .pp_utils import utils +from .pp_utils.utils import is_float_tensor from .parallel_layers.pp_layers import PipelineLayer from ..utils.hybrid_parallel_util import broadcast_mp_parameters @@ -35,12 +32,9 @@ def __init__(self, layers, hcg, strategy): raise TypeError( "The Layer should be a derived class of PipelineLayer.") super(PipelineParallel, self).__init__(layers, hcg, strategy) - self.use_pipe_parallel = self._hcg.get_pipe_parallel_world_size() > 1 self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1 self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1 - self.is_pipe_partitioned = self.use_model_parallel - self.total_loss = None self.micro_batch_size = self._strategy.pipeline_configs[ @@ -57,9 +51,7 @@ def __init__(self, layers, hcg, strategy): self.is_first_stage = self.stage_id == 0 self.is_last_stage = (self.stage_id == (self.num_stages - 1)) self.global_rank = self._hcg.get_global_rank() - - self.mp_degree = self._hcg.get_model_parallel_world_size() - self.mp_rank = self._hcg.get_model_parallel_rank() + self.micro_batch_id = 0 logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format( self.num_stages, self.stage_id)) @@ -72,8 +64,6 @@ def __init__(self, layers, hcg, strategy): logger.info("start broadcast dp parameters") broadcast_dp_parameters(self._layers, self._hcg) - self.micro_batch_id = 0 - def _set_tensor_trainable(self, tensor): if tensor is None: return @@ -114,14 +104,16 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): # store data id for micro_batch self.micro_batch_id = 0 - # Compute number of warmup microbatches. - # self.accumulate_steps = self.accumulate_steps + # Next, use the 1f1b scheduling strategy. + # this strategy is inspired by: + # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py + startup_steps = (self.num_stages - self.stage_id - 1) startup_steps = min(startup_steps, self.accumulate_steps) steady_steps = self.accumulate_steps - startup_steps - input_tensors = [] - output_tensors = [] + input_buffers = [] + output_buffers = [] for step_id in range(startup_steps): input_tensor = p2p.recv_forward() @@ -130,38 +122,38 @@ def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): output_tensor = self._forward_step(input_tensor) p2p.send_forward(output_tensor) - input_tensors.append(input_tensor) - output_tensors.append(output_tensor) + input_buffers.append(input_tensor) + output_buffers.append(output_tensor) if steady_steps > 0: input_tensor = p2p.recv_forward() for i in range(steady_steps): - last_iteration = (i == (steady_steps - 1)) + last_iter = (i == (steady_steps - 1)) self._set_tensor_trainable(input_tensor) output_tensor = self._forward_step(input_tensor) output_tensor_grad = p2p.send_forward_recv_backward(output_tensor) - input_tensors.append(input_tensor) - output_tensors.append(output_tensor) + input_buffers.append(input_tensor) + output_buffers.append(output_tensor) - input_tensor, output_tensor = input_tensors.pop( - 0), output_tensors.pop(0) + input_tensor, output_tensor = input_buffers.pop( + 0), output_buffers.pop(0) input_tensor_grad = self._backward_step(input_tensor, output_tensor, output_tensor_grad) - if last_iteration: + if last_iter: input_tensor = None p2p.send_backward(input_tensor_grad) else: input_tensor = p2p.send_backward_recv_forward(input_tensor_grad) for i in range(startup_steps): - input_tensor = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) + input_tensor = input_buffers.pop(0) + output_tensor = output_buffers.pop(0) output_tensor_grad = p2p.recv_backward() diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index fffb0f6514f85d..851d26b7ab5441 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -13,12 +13,14 @@ # limitations under the License. import paddle -from .utils import is_float_tensor, get_tensor_dtype, paddle_2_number, number_2_dtype +from .utils import paddle_2_number, number_2_dtype _hcg = None class SendRecvMeta: + """Mainly used to help p2p communication context information""" + def __init__(self): self.send_shape_message = None self.send_dtype_message = None @@ -263,7 +265,6 @@ def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): if isinstance(tensor_send_next, tuple): for d in tensor_send_next: paddle.distributed.wait(d, use_calc_stream=True) - send_partial( d, dst=1, From d8fa25f549c6ce29f64db47d7c96fe50df96c43b Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Fri, 30 Jul 2021 13:16:27 +0800 Subject: [PATCH 11/14] support amp for pp --- .../distributed/fleet/meta_parallel/pipeline_parallel.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index b5447a3833ecda..1cec106caec82b 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -195,8 +195,10 @@ def _forward_step(self, input_tensor): def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): if self.is_last_stage: assert output_tensor_grad is None - paddle.autograd.backward( - tensors=[output_tensor], grad_tensors=[None]) + if self.scaler: + paddle.autograd.backward(self.scaler.scale(output_tensor)) + else: + paddle.autograd.backward(output_tensor) else: if isinstance(output_tensor, tuple): outputs = [t for t in output_tensor if not t.stop_gradient] From bff38395546af02cf4d3b2fd60de80aa95d23c8b Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Fri, 30 Jul 2021 15:33:47 +0800 Subject: [PATCH 12/14] support amp for pp --- .../paddle/distributed/fleet/base/topology.py | 8 +++---- .../pp_utils/p2p_communication.py | 22 ++++++++++++------- ...test_parallel_dygraph_pipeline_parallel.py | 3 +++ 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 623a3efaf26304..5b8d185212c23c 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -159,11 +159,6 @@ def __init__(self, topology): # create p2p_groups if self._pp_degree > 1: self._set_p2p_group() - print("send_next_group: ", self.send_next_group) - print("send_prev_group: ", self.send_prev_group) - print("recv_next_group: ", self.recv_next_group) - print("recv_prev_group: ", self.recv_prev_group) - debug_str = "HybridParallelInfo: rank_id: %d, mp_degree: %d, " \ "sharding_degree: %d, pp_degree: %d, dp_degree: %d" % (self.global_rank, self._mp_degree, @@ -313,6 +308,9 @@ def get_pipe_parallel_world_size(self): def get_pipe_parallel_group(self): return self._pp_comm_group + def get_p2p_groups(self): + return self.send_next_group, self.send_prev_group, self.recv_next_group, self.recv_prev_group + # sharding parallel message: def _get_sharding_parallel_id(self): return self._topo.get_coord(self.global_rank).sharding diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 851d26b7ab5441..18d35351667592 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -14,8 +14,22 @@ import paddle from .utils import paddle_2_number, number_2_dtype +from ...utils import log_util as logger _hcg = None +_send_recv_meta = SendRecvMeta() + + +def initialize_p2p_groups(hcg): + global _hcg + _hcg = hcg + send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups( + ) + + debug_str = "P2pInfo: send_next_group: %s, send_prev_group: %s, " \ + "recv_next_group: %s, recv_prev_group: %s" % (repr(send_next_group), + repr(send_prev_group),repr(recv_next_group), repr(recv_prev_group)) + logger.info(debug_str) class SendRecvMeta: @@ -113,9 +127,6 @@ def set_send_message(self, tensor): [paddle_2_number(d.dtype) for d in tensor]) -_send_recv_meta = SendRecvMeta() - - def send_partial(tensor, dst=0, nranks=1, @@ -162,11 +173,6 @@ def allgather_partial(tensor, 'nranks', nranks, 'rank', rank_id) -def initialize_p2p_groups(hcg): - global _hcg - _hcg = hcg - - def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): global _hcg diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py index 73967782aea2da..d40d54a22c402a 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py @@ -33,6 +33,9 @@ def test_hybrid_parallel_pp_tuple_inputs(self): def test_pipeline_parallel(self): self.run_mnist_2gpu('hybrid_parallel_pp_amp.py') + def test_hybrid_parallel_transformer(self): + self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py') + if __name__ == "__main__": unittest.main() From c15c1412195e46636b9a5e105a6cce10f6ce8a22 Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Fri, 30 Jul 2021 16:16:07 +0800 Subject: [PATCH 13/14] fix bug --- .../fleet/meta_parallel/pp_utils/p2p_communication.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 18d35351667592..83db7744f87733 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -17,7 +17,6 @@ from ...utils import log_util as logger _hcg = None -_send_recv_meta = SendRecvMeta() def initialize_p2p_groups(hcg): @@ -127,6 +126,9 @@ def set_send_message(self, tensor): [paddle_2_number(d.dtype) for d in tensor]) +_send_recv_meta = SendRecvMeta() + + def send_partial(tensor, dst=0, nranks=1, From aa97008af7071869c7a4a779545363c335dde438 Mon Sep 17 00:00:00 2001 From: shenliang03 Date: Fri, 30 Jul 2021 18:50:47 +0800 Subject: [PATCH 14/14] fix logger --- .../fleet/meta_parallel/pp_utils/p2p_communication.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 83db7744f87733..e533b2ef3f7a33 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -14,7 +14,7 @@ import paddle from .utils import paddle_2_number, number_2_dtype -from ...utils import log_util as logger +from ...utils.log_util import logger _hcg = None