diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 8a31e499c5843c..07d41e5bb5fb13 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -432,6 +432,7 @@ def __init__(self, layers, hcg, strategy):
         self.loss_fn_idx = 0
 
         self._compute_loss = True
+        self._return_host_tensor = False
         self.callbacks = pipeline_parallel_callbacks_
 
         logger.info(
@@ -1026,13 +1027,18 @@ def train_batch(
 
         return train_loss
 
-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(
+        self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
+    ):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
 
         self._layers.eval()
+        origin_compute_loss = self._compute_loss
         self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor
 
         # store data id for micro_batch
         self.micro_batch_id = 0
@@ -1051,7 +1057,6 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         startup_steps = min(startup_steps, self.accumulate_steps)
         steady_steps = self.accumulate_steps - startup_steps
 
-        input_buffers = []
         output_buffers = []
 
         # convert to micro dataset
@@ -1072,8 +1077,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                 skip_check_meta=True,
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
+            if not self.is_pipeline_last_stage():
+                self._release_output(output_tensor)
+            else:
+                self._offload_tensors(output_tensor)
 
-            input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
 
         if steady_steps > 0:
@@ -1094,8 +1102,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                 skip_check_meta=True,
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
+            if not self.is_pipeline_last_stage():
+                self._release_output(output_tensor)
+            else:
+                self._offload_tensors(output_tensor)
 
-            input_buffers.append(input_tensor)
             output_buffers.append(output_tensor)
 
             if not last_iter:
@@ -1105,11 +1116,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                 )
 
         if self._compute_loss:
-            self.train_loss = self._broadcast_final_loss()
+            train_loss = self._broadcast_final_loss()
         else:
-            self.train_loss = output_buffers
+            train_loss = output_buffers
 
-        return self.train_loss
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss
 
     def _maybe_loss_compute(
         self, output_tensor, micro_dataset, overlap_schedule_mode=False
@@ -1424,6 +1437,23 @@ def _optimizer_step(self):
         if self.lr_scheduler:
             self.lr_scheduler.step()
 
+    def _offload_tensors(self, output_tensor):
+        if not self._return_host_tensor:
+            return
+        if isinstance(output_tensor, (tuple, list)):
+            for t in output_tensor:
+                host_tensor = (
+                    t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
+                )
+                host_tensor._share_buffer_to(t)
+        else:
+            host_tensor = (
+                output_tensor.pin_memory()
+                if hasattr(output_tensor, "pin_memory")
+                else output_tensor.cpu()
+            )
+            host_tensor._share_buffer_to(output_tensor)
+
     def _release_output(self, output):
         def can_free(t):
             return (
@@ -1694,10 +1724,12 @@ def _get_forward_input(self, virtual_pp_rank):
         assert hasattr(self, 'output_tensors')
         if not self._forward_only:
             assert hasattr(self, 'output_tensor_grads')
-        assert len(self.input_tensors[virtual_pp_rank]) == (
-            len(self.output_tensors[virtual_pp_rank]) + 1
-        )
-        input_tensor = self.input_tensors[virtual_pp_rank][-1]
+            assert len(self.input_tensors[virtual_pp_rank]) == (
+                len(self.output_tensors[virtual_pp_rank]) + 1
+            )
+            input_tensor = self.input_tensors[virtual_pp_rank][-1]
+        else:
+            input_tensor = self.input_tensors[virtual_pp_rank].pop()
         return input_tensor
 
     def _store_forward_outputs(
@@ -1712,11 +1744,17 @@ def _store_forward_outputs(
             self.schedule_chunks[virtual_pp_rank].append(schedule_chunk)
         if self.is_pipeline_last_stage():
             self.loss_fn_chunks.append(loss_fn_node)
-
-        if self._forward_only:
+            if self._forward_only:
+                # no need to store tensor for backward
+                if self._compute_loss:
+                    self.output_tensors[virtual_pp_rank].pop()
+                # save output_tensors for return value of eval batch
+                else:
+                    self._offload_tensors(output_tensor)
+        else:
             # no need to store tensor for backward
-            self.input_tensors[virtual_pp_rank].pop()
-            self.output_tensors[virtual_pp_rank].pop()
+            if self._forward_only:
+                self.output_tensors[virtual_pp_rank].pop()
 
     def _forward_step_helper(
         self,
@@ -2022,7 +2060,7 @@ def forward_backward_pipeline(
         # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
         if not compute_loss:
-            assert not forward_only, (
+            assert forward_only, (
                 "compute_loss can only be set to False when forward_only is set to True"
             )
 
@@ -2669,7 +2707,7 @@ def backward_async_comm(
 
         # no steady steps, which only occurs when accumulate_step == num_stage
         if not steady_steps:
-            output_tensor_grad = p2p.recv_backward(
+            output_tensor_grad = self._p2p_helper.recv_backward(
                 self.is_pipeline_last_stage(),
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
@@ -2800,12 +2838,14 @@ def backward_async_comm(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -2823,7 +2863,7 @@ def backward_async_comm(
             ), "p2p dynamic_cnt should equal to send_recv_meta_list"
             self._p2p_helper._dynamic_cnt = 0
 
-        return train_loss
+        return train_loss_or_logits
 
     def train_batch(
         self,
@@ -2854,13 +2894,18 @@ def train_batch(
 
         return train_loss
 
-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(
+        self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
+    ):
         self.user_hooks_enabled = False
         # reset the virtual pp rank for each run
         self.set_virtual_pipeline_rank(0)
 
         self._layers.eval()
+        origin_compute_loss = self._compute_loss
         self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor
 
         # check loss_fn_idx is valid and loss_fn exists
         assert (
@@ -2869,7 +2914,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         ), f"loss function {loss_fn_idx} should exist to compute loss"
         self.loss_fn_idx = loss_fn_idx
 
-        return self.forward_backward_pipeline(data, None, forward_only=True)
+        train_loss_or_logits = self.forward_backward_pipeline(
+            data, None, forward_only=True, compute_loss=compute_loss
+        )
+        self._init_buffers()
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss_or_logits
 
     def get_static_scheduler(self):
         return self.forward_backward_pipeline(
@@ -2959,7 +3010,7 @@ def forward_backward_pipeline(
         if self.processed_steps < g_profile_pipeline_details_steps:
             get_sync_logger().info("start forward_backward_pipeline")
         if not compute_loss:
-            assert not forward_only, (
+            assert forward_only, (
                 "compute_loss can only be set to False when forward_only is set to True"
             )
 
@@ -2977,7 +3028,7 @@ def forward_backward_pipeline(
 
         assert (
             self.accumulate_steps == self.num_stages
-            or self.accumulate_steps % self.num_stages != 0
+            or self.accumulate_steps % self.num_stages == 0
         ), (
             f"accumulate_steps({self.accumulate_steps}) and num_stages({self.num_stages}) should be a multiple or accumulate_steps % num_stages == 0"
         )
@@ -3108,12 +3159,14 @@ def forward_backward_pipeline(
             if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -3124,7 +3177,7 @@ def forward_backward_pipeline(
             get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits
 
 
 class OffloadQueue(queue.Queue):
@@ -3187,7 +3240,7 @@ def forward_backward_pipeline(
     ):
         self._reset_user_hooks_status()
         if not compute_loss:
-            assert not forward_only, (
+            assert forward_only, (
                 "compute_loss can only be set to False when forward_only is set to True"
            )
         assert self._using_cache, (
@@ -3462,12 +3515,14 @@ def forward_backward_pipeline(
            if self._enable_timer:
                 self.timers("broadcast_final_loss").start()
             with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
             if self._enable_timer:
                 self.timers("broadcast_final_loss").stop()
         else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()
 
         if self._clear_every_step_cache:
             self._p2p_helper.clear_meta_cache()
@@ -3478,7 +3533,7 @@ def forward_backward_pipeline(
             get_sync_logger().info("end forward_backward_pipeline")
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits
 
 
 def tuple_to_dict_helper(input_tensor):
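
Reviewer note: the buffer-sharing trick in the new `_offload_tensors` can be non-obvious. Below is a minimal standalone sketch of the same pattern, assuming a CUDA build of Paddle; the tensor names are illustrative and not part of this diff.

```python
import paddle

# Illustrative only: mirrors the pattern used by _offload_tensors above.
paddle.set_device("gpu")  # assumes a CUDA build of Paddle
dev = paddle.rand([2, 3])  # output tensor living on the GPU

# Copy to page-locked host memory when available, plain CPU memory otherwise.
host = dev.pin_memory() if hasattr(dev, "pin_memory") else dev.cpu()

# Re-point the original tensor at the host copy's storage. Callers holding
# `dev` keep a valid handle, while the GPU allocation becomes free to reclaim.
host._share_buffer_to(dev)

print(dev.place)  # now a host (pinned/CPU) place
```

Pinned memory is preferred over plain `cpu()` because page-locked buffers allow faster, asynchronous device-to-host copies, which matters when offloading per-micro-batch outputs.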
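A hedged usage sketch of the extended `eval_batch` API follows; `model` and `eval_loader` are placeholders, and the fleet/pipeline setup is assumed rather than shown in this diff.

```python
import paddle.distributed.fleet as fleet

# Assumed setup (not part of this diff): fleet initialized with pipeline
# parallelism configured, `model` wrapping a PipelineLayer, and `eval_loader`
# an evaluation DataLoader.
fleet.init(is_collective=True)
model = fleet.distributed_model(model)

for batch in eval_loader:
    # compute_loss=False: the last stage returns raw outputs (logits) instead
    # of broadcasting a loss. return_host_tensor=True: those outputs are moved
    # to pinned host memory via _offload_tensors, so device memory is not held
    # across micro-batches; non-last stages release their outputs instead.
    logits = model.eval_batch(
        batch, compute_loss=False, return_host_tensor=True
    )
```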