10 changes: 3 additions & 7 deletions verl/workers/fsdp_workers.py
@@ -389,7 +389,6 @@ def init_model(self):
processing_class=self.processor if self.processor is not None else self.tokenizer,
)

-        torch.cuda.empty_cache()

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def save_checkpoint(self, path: str, global_step: int = 0, remove_previous_ckpt: bool = False):
@@ -415,6 +414,9 @@ def load_checkpoint(self, path: str, remove_ckpt_after_load: bool = False):
dist.barrier()
if self._use_param_offload:
offload_fsdp_model(self.fsdp_module)

+        if self._use_optimizer_offload:
+            offload_fsdp_optimizer(self.actor_optimizer)

"""ActorRolloutRefWorker"""

@@ -456,7 +458,6 @@ def update_actor(self, data: DataProto):
offload_fsdp_optimizer(optimizer=self.optimizer)

output = output.to("cpu")
-        torch.cuda.empty_cache()
return output

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@@ -491,7 +492,6 @@ def generate_sequences(self, prompts: DataProto):
print_gpu_memory_usage("After rollout generation")

output = output.to("cpu")
-        torch.cuda.empty_cache() # clear kv cache
print_gpu_memory_usage("After recompute log prob")
return output

@@ -522,7 +522,6 @@ def compute_log_prob(self, data: DataProto):
offload_fsdp_model(self.fsdp_module)

output = output.to("cpu")
-        torch.cuda.empty_cache()
print_gpu_memory_usage("After compute_log_prob")
return output

@@ -549,7 +548,6 @@ def compute_ref_log_prob(self, data: DataProto):
offload_fsdp_model(self.fsdp_module)

output = output.to("cpu")
-        torch.cuda.empty_cache()
print_gpu_memory_usage("After compute_ref_log_prob")
return output

@@ -572,7 +570,6 @@ def compute_values(self, data: DataProto):
offload_fsdp_model(self.fsdp_module)

output = output.to("cpu")
-        torch.cuda.empty_cache()
return output

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@@ -608,5 +605,4 @@ def update_critic(self, data: DataProto):
offload_fsdp_optimizer(optimizer=self.optimizer)

output = output.to("cpu")
-        torch.cuda.empty_cache()
return output
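
All seven deletions in this file share one rationale: torch.cuda.empty_cache() hands PyTorch's cached blocks back to the CUDA driver, so the next allocation pays the full cudaMalloc cost instead of being served from the caching allocator. A hypothetical micro-benchmark sketch (not part of this PR; the function name, tensor size, and iteration count are made up for illustration) shows the effect:

# Hypothetical micro-benchmark: times repeated large allocations with and
# without torch.cuda.empty_cache() between them.
import time

import torch

def time_allocations(n_iters: int = 50, empty_cache: bool = False) -> float:
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        # 256 MiB scratch tensor; the size is arbitrary for illustration.
        x = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda")
        del x
        if empty_cache:
            # Returns cached blocks to the driver, so the next torch.empty
            # must go through cudaMalloc again.
            torch.cuda.empty_cache()
    torch.cuda.synchronize()
    return time.perf_counter() - start

if __name__ == "__main__":
    print(f"caching allocator reuse: {time_allocations(empty_cache=False):.3f}s")
    print(f"empty_cache every iter : {time_allocations(empty_cache=True):.3f}s")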
9 changes: 8 additions & 1 deletion verl/workers/sharding_manager/fsdp_vllm.py
@@ -60,6 +60,14 @@ def __init__(
self.gen_random_states = None

def __enter__(self):
+        # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
+        # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
+        # Out of vllm scope, we should avoid empty cache to let pytorch use its caching
+        # allocator to speed up memory allocations.
+        #
+        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
+        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
+        torch.cuda.empty_cache()
print_gpu_memory_usage("Before state_dict() in sharding manager")
actor_weights = self.module.state_dict()
print_gpu_memory_usage("After state_dict() in sharding manager")
@@ -71,7 +79,6 @@ def __enter__(self):
print_gpu_memory_usage("After sync model weights in sharding manager")

del actor_weights
-        torch.cuda.empty_cache()
print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
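
Per the note added above, the relocated calls pair up around vLLM's sleep/wake cycle: one empty_cache before wake_up so CuMemAllocator can claim the freed memory, none in between, and a matching one after sleep (implied for __exit__, which this hunk does not show). A minimal sketch of that shape, assuming a hypothetical engine handle exposing the wake_up()/sleep() sleep-mode API and eliding the actual weight sync; this is an illustration, not the PR's implementation:

import torch

class ShardingManagerSketch:
    """Illustrative only: mirrors the empty_cache placement described above."""

    def __init__(self, module, engine):
        self.module = module  # FSDP-wrapped training module
        self.engine = engine  # hypothetical handle exposing wake_up()/sleep()

    def __enter__(self):
        # Return PyTorch's cached blocks to the driver so vLLM's
        # CuMemAllocator can claim that memory when the engine wakes up.
        torch.cuda.empty_cache()
        actor_weights = self.module.state_dict()
        self.engine.wake_up()
        # ... sync actor_weights into the engine here ...
        del actor_weights
        # No empty_cache() between wake_up and sleep: let PyTorch's caching
        # allocator reuse its blocks instead of re-paying cudaMalloc.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.engine.sleep()
        # Per the note above, the one matching empty_cache goes right after
        # vllm sleep, before training-side allocations resume.
        torch.cuda.empty_cache()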