123 changes: 89 additions & 34 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -432,6 +432,7 @@ def __init__(self, layers, hcg, strategy):
self.loss_fn_idx = 0

self._compute_loss = True
self._return_host_tensor = False
self.callbacks = pipeline_parallel_callbacks_

logger.info(
@@ -1026,13 +1027,18 @@ def train_batch(

return train_loss

def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
def eval_batch(
self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
):
self.user_hooks_enabled = False
# reset the virtual pp rank for each run
self.set_virtual_pipeline_rank(0)

self._layers.eval()
origin_compute_loss = self._compute_loss
self._compute_loss = compute_loss
origin_return_host_tensor = self._return_host_tensor
self._return_host_tensor = return_host_tensor

# store data id for micro_batch
self.micro_batch_id = 0
@@ -1051,7 +1057,6 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
startup_steps = min(startup_steps, self.accumulate_steps)
steady_steps = self.accumulate_steps - startup_steps

input_buffers = []
output_buffers = []

# convert to micro dataset
@@ -1072,8 +1077,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
skip_check_meta=True,
batch_p2p_comm=self._use_batch_p2p_comm,
)
if not self.is_pipeline_last_stage():
self._release_output(output_tensor)
else:
self._offload_tensors(output_tensor)

input_buffers.append(input_tensor)
output_buffers.append(output_tensor)

if steady_steps > 0:
@@ -1094,8 +1102,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
skip_check_meta=True,
batch_p2p_comm=self._use_batch_p2p_comm,
)
if not self.is_pipeline_last_stage():
self._release_output(output_tensor)
else:
self._offload_tensors(output_tensor)

input_buffers.append(input_tensor)
output_buffers.append(output_tensor)

if not last_iter:
@@ -1105,11 +1116,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
)

if self._compute_loss:
self.train_loss = self._broadcast_final_loss()
train_loss = self._broadcast_final_loss()
else:
self.train_loss = output_buffers
train_loss = output_buffers

return self.train_loss
self._compute_loss = origin_compute_loss
self._return_host_tensor = origin_return_host_tensor
return train_loss
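A minimal usage sketch of the new `return_host_tensor` flag (not part of the diff; `model` and `batch` are illustrative names, with `model` assumed to be a pipeline-parallel model returned by `fleet.distributed_model`):

# Evaluate one batch without computing the loss; the per-micro-batch outputs
# kept on the last pipeline stage are copied to host (pinned when available)
# memory instead of staying resident on the GPU.
logits = model.eval_batch(batch, compute_loss=False, return_host_tensor=True)

# The default path is unchanged: with compute_loss=True the broadcast loss is
# returned and return_host_tensor has no effect.
loss = model.eval_batch(batch, compute_loss=True, loss_fn_idx=0)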

def _maybe_loss_compute(
self, output_tensor, micro_dataset, overlap_schedule_mode=False
@@ -1424,6 +1437,23 @@ def _optimizer_step(self):
if self.lr_scheduler:
self.lr_scheduler.step()

def _offload_tensors(self, output_tensor):
if not self._return_host_tensor:
return
if isinstance(output_tensor, (tuple, list)):
for t in output_tensor:
host_tensor = (
t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
)
host_tensor._share_buffer_to(t)
else:
host_tensor = (
output_tensor.pin_memory()
if hasattr(output_tensor, "pin_memory")
else output_tensor.cpu()
)
host_tensor._share_buffer_to(output_tensor)
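For reference, a hedged sketch (not part of the diff) of the offload pattern that `_offload_tensors` applies above, shown on a standalone tensor; `_share_buffer_to` is the same internal Paddle method the new code calls:

import paddle

x = paddle.randn([4, 8])  # device tensor when run on a CUDA place
# Copy to page-locked host memory when supported, plain CPU memory otherwise.
host = x.pin_memory() if hasattr(x, "pin_memory") else x.cpu()
# Re-point x at the host copy so later readers of x no longer hold GPU memory.
host._share_buffer_to(x)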

def _release_output(self, output):
def can_free(t):
return (
@@ -1694,10 +1724,12 @@ def _get_forward_input(self, virtual_pp_rank):
assert hasattr(self, 'output_tensors')
if not self._forward_only:
assert hasattr(self, 'output_tensor_grads')
assert len(self.input_tensors[virtual_pp_rank]) == (
len(self.output_tensors[virtual_pp_rank]) + 1
)
input_tensor = self.input_tensors[virtual_pp_rank][-1]
assert len(self.input_tensors[virtual_pp_rank]) == (
len(self.output_tensors[virtual_pp_rank]) + 1
)
input_tensor = self.input_tensors[virtual_pp_rank][-1]
else:
input_tensor = self.input_tensors[virtual_pp_rank].pop()
return input_tensor

def _store_forward_outputs(
@@ -1712,11 +1744,17 @@ def _store_forward_outputs(
self.schedule_chunks[virtual_pp_rank].append(schedule_chunk)
if self.is_pipeline_last_stage():
self.loss_fn_chunks.append(loss_fn_node)

if self._forward_only:
if self._forward_only:
# no need to store tensor for backward
if self._compute_loss:
self.output_tensors[virtual_pp_rank].pop()
# save output_tensors for return value of eval batch
else:
self._offload_tensors(output_tensor)
else:
# no need to store tensor for backward
self.input_tensors[virtual_pp_rank].pop()
self.output_tensors[virtual_pp_rank].pop()
if self._forward_only:
self.output_tensors[virtual_pp_rank].pop()

def _forward_step_helper(
self,
@@ -2022,7 +2060,7 @@ def forward_backward_pipeline(
# this strategy is inspired by:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
if not compute_loss:
assert not forward_only, (
assert forward_only, (
"compute_loss can only be set to False when forward_only is set to True"
)

@@ -2669,7 +2707,7 @@ def backward_async_comm(

# no steady steps, which only occurs when accumulate_step == num_stage
if not steady_steps:
output_tensor_grad = p2p.recv_backward(
output_tensor_grad = self._p2p_helper.recv_backward(
self.is_pipeline_last_stage(),
batch_p2p_comm=self._use_batch_p2p_comm,
)
Expand Down Expand Up @@ -2800,12 +2838,14 @@ def backward_async_comm(
if self._enable_timer:
self.timers("broadcast_final_loss").start()
with paddle.amp.auto_cast(enable=False):
train_loss = self._broadcast_final_loss(return_micro_batch_loss)
train_loss_or_logits = self._broadcast_final_loss(
return_micro_batch_loss
)
if self._enable_timer:
self.timers("broadcast_final_loss").stop()
else:
# else just return all intermediate output tensor for all micro steps
train_loss = self.output_tensors
# else just return logits without loss func calc
train_loss_or_logits = self.output_tensors.pop()

if self._clear_every_step_cache:
self._p2p_helper.clear_meta_cache()
@@ -2823,7 +2863,7 @@ def backward_async_comm(
), "p2p dynamic_cnt should equal to send_recv_meta_list"
self._p2p_helper._dynamic_cnt = 0

return train_loss
return train_loss_or_logits

def train_batch(
self,
Expand Down Expand Up @@ -2854,13 +2894,18 @@ def train_batch(

return train_loss

def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
def eval_batch(
self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
):
self.user_hooks_enabled = False
# reset the virtual pp rank for each run
self.set_virtual_pipeline_rank(0)

self._layers.eval()
origin_compute_loss = self._compute_loss
self._compute_loss = compute_loss
origin_return_host_tensor = self._return_host_tensor
self._return_host_tensor = return_host_tensor

# check loss_fn_idx is valid and loss_fn exists
assert (
@@ -2869,7 +2914,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
), f"loss function {loss_fn_idx} should exist to compute loss"
self.loss_fn_idx = loss_fn_idx

return self.forward_backward_pipeline(data, None, forward_only=True)
train_loss_or_logits = self.forward_backward_pipeline(
data, None, forward_only=True, compute_loss=compute_loss
)
self._init_buffers()
self._compute_loss = origin_compute_loss
self._return_host_tensor = origin_return_host_tensor
return train_loss_or_logits
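A sketch of consuming the no-loss return value of this eval_batch (illustrative, not part of the diff; `model` and `batch` are assumed names):

outputs = model.eval_batch(batch, compute_loss=False, return_host_tensor=True)
# Typically only the last pipeline stage holds the logits of interest;
# earlier stages store their intermediate activations instead.
if model.is_pipeline_last_stage():
    for micro_batch_out in outputs:
        ...  # post-process outputs that already live in host memory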

def get_static_scheduler(self):
return self.forward_backward_pipeline(
Expand Down Expand Up @@ -2959,7 +3010,7 @@ def forward_backward_pipeline(
if self.processed_steps < g_profile_pipeline_details_steps:
get_sync_logger().info("start forward_backward_pipeline")
if not compute_loss:
assert not forward_only, (
assert forward_only, (
"compute_loss can only be set to False when forward_only is set to True"
)

@@ -2977,7 +3028,7 @@ def forward_backward_pipeline(

assert (
self.accumulate_steps == self.num_stages
or self.accumulate_steps % self.num_stages != 0
or self.accumulate_steps % self.num_stages == 0
), (
f"accumulate_steps({self.accumulate_steps}) and num_stages({self.num_stages}) should be a multiple or accumulate_steps % num_stages == 0"
)
Expand Down Expand Up @@ -3108,12 +3159,14 @@ def forward_backward_pipeline(
if self._enable_timer:
self.timers("broadcast_final_loss").start()
with paddle.amp.auto_cast(enable=False):
train_loss = self._broadcast_final_loss(return_micro_batch_loss)
train_loss_or_logits = self._broadcast_final_loss(
return_micro_batch_loss
)
if self._enable_timer:
self.timers("broadcast_final_loss").stop()
else:
# else just return all intermediate output tensor for all micro steps
train_loss = self.output_tensors
# else just return logits without loss func calc
train_loss_or_logits = self.output_tensors.pop()

if self._clear_every_step_cache:
self._p2p_helper.clear_meta_cache()
@@ -3124,7 +3177,7 @@ def forward_backward_pipeline(
get_sync_logger().info("end forward_backward_pipeline")
self.processed_steps += 1
self._check_user_hooks_status_at_step_end()
return train_loss
return train_loss_or_logits


class OffloadQueue(queue.Queue):
Expand Down Expand Up @@ -3187,7 +3240,7 @@ def forward_backward_pipeline(
):
self._reset_user_hooks_status()
if not compute_loss:
assert not forward_only, (
assert forward_only, (
"compute_loss can only be set to False when forward_only is set to True"
)
assert self._using_cache, (
@@ -3462,12 +3515,14 @@ def forward_backward_pipeline(
if self._enable_timer:
self.timers("broadcast_final_loss").start()
with paddle.amp.auto_cast(enable=False):
train_loss = self._broadcast_final_loss(return_micro_batch_loss)
train_loss_or_logits = self._broadcast_final_loss(
return_micro_batch_loss
)
if self._enable_timer:
self.timers("broadcast_final_loss").stop()
else:
# else just return all intermediate output tensor for all micro steps
train_loss = self.output_tensors
# else just return logits without loss func calc
train_loss_or_logits = self.output_tensors.pop()

if self._clear_every_step_cache:
self._p2p_helper.clear_meta_cache()
@@ -3478,7 +3533,7 @@ def forward_backward_pipeline(
get_sync_logger().info("end forward_backward_pipeline")
self.processed_steps += 1
self._check_user_hooks_status_at_step_end()
return train_loss
return train_loss_or_logits


def tuple_to_dict_helper(input_tensor):