3 changes: 0 additions & 3 deletions recipe/prime/prime_fsdp_workers.py
@@ -247,8 +247,6 @@ def init_model(self):
             lr_scheduler=self.reward_lr_scheduler,
             tokenizer=self.tokenizer)

-        torch.cuda.empty_cache()
-
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_rm_score(self, data: DataProto):
         data = data.to('cuda')
@@ -321,7 +319,6 @@ def update_rm(self, data: DataProto):
             offload_fsdp_model_to_cpu(self.ref_module)
         if self._is_offload_optimizer:
             offload_fsdp_optimizer(optimizer=self.reward_optimizer)
-        torch.cuda.empty_cache()
         output = output.to('cpu')
         return output

8 changes: 0 additions & 8 deletions verl/workers/fsdp_workers.py
@@ -411,8 +411,6 @@ def init_model(self):
             lr_scheduler=self.actor_lr_scheduler,
             processing_class=self.processor if self.processor is not None else self.tokenizer)

-        torch.cuda.empty_cache()
-
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def update_actor(self, data: DataProto):
         # Support all hardwares
@@ -498,7 +496,6 @@ def generate_sequences(self, prompts: DataProto):
         output = output.to('cpu')

         # clear kv cache
-        torch.cuda.empty_cache()
         log_gpu_memory_usage('After recompute log prob', logger=logger)
         return output

@@ -561,7 +558,6 @@ def compute_ref_log_prob(self, data: DataProto):
         if self.world_size > 1:
             self.ref_policy.actor_module._handle.reshard(True)

-        torch.cuda.empty_cache()
         return output

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
@@ -784,8 +780,6 @@ def init_model(self):
             lr_scheduler=self.critic_lr_scheduler,
             processing_class=self.processor if self.processor is not None else self.tokenizer)

-        torch.cuda.empty_cache()
-
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_values(self, data: DataProto):

@@ -981,7 +975,6 @@ def init_model(self):
         # This is used to import external_lib into the huggingface systems
         import_external_libs(self.config.model.get('external_lib', None))
         self.reward_module = self._build_model(config=self.config)
-        torch.cuda.empty_cache()

     def _forward_micro_batch(self, micro_batch):
         from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange
@@ -1154,5 +1147,4 @@ def compute_rm_score(self, data: DataProto):
             self.reward_module._handle.reshard(True)

         output = output.to('cpu')
-        torch.cuda.empty_cache()
         return output
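The eight deletions above all follow the rationale spelled out in the sharding-manager comment below: outside the vLLM hand-off there is nothing to return the memory to, and emptying the cache only forces PyTorch's next allocation through slow `cudaMalloc` calls instead of reusing a cached block. As a rough illustration of that cost (not part of this PR; assumes a CUDA device is available, tensor sizes are arbitrary), a minimal timing sketch:

```python
import time

import torch


def alloc_loop(steps: int, empty_each_step: bool) -> float:
    """Time repeated large allocations, optionally emptying the CUDA cache each step."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        x = torch.empty(1024, 1024, 256, device="cuda")  # ~1 GiB of fp32
        del x
        if empty_each_step:
            # Returns cached blocks to the driver, so the next torch.empty()
            # has to call cudaMalloc again instead of reusing the freed block.
            torch.cuda.empty_cache()
    torch.cuda.synchronize()
    return time.perf_counter() - start


if __name__ == "__main__":
    print("with empty_cache per step:   ", alloc_loop(50, True))
    print("without empty_cache per step:", alloc_loop(50, False))
```

The second loop should run noticeably faster, which is the allocator reuse the removed calls were giving up.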
10 changes: 9 additions & 1 deletion verl/workers/sharding_manager/fsdp_vllm.py
@@ -70,6 +70,15 @@ def __init__(self,
         self.gen_random_states = None

     def __enter__(self):
+        # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
+        # after vllm sleep, since vllm has its own caching memory allocator (CuMemAllocator).
+        # Outside the vllm scope, we should avoid emptying the cache so that pytorch can keep
+        # using its caching allocator to speed up memory allocations.
+        #
+        # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
+        # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
+        torch.cuda.empty_cache()
+
         log_gpu_memory_usage('Before state_dict() in sharding manager memory', logger=logger)
         params = self.module.state_dict()
         log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
@@ -89,7 +98,6 @@ def __enter__(self):
         log_gpu_memory_usage('After sync model weights in sharding manager', logger=logger)

         del params
-        torch.cuda.empty_cache()
         log_gpu_memory_usage('After del state_dict and empty_cache in sharding manager', logger=logger)

         # TODO: offload FSDP model weights
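The comment added to `__enter__` captures the intended pattern: `torch.cuda.empty_cache()` is confined to the boundary where GPU memory is handed between PyTorch's caching allocator and vLLM's CuMemAllocator. As a minimal sketch of that pattern (not verl's actual sharding manager class; the name `InferenceMemoryBoundary` is hypothetical and the `engine.wake_up()` / `engine.sleep()` calls are assumed to follow vLLM's sleep-mode API):

```python
import torch


class InferenceMemoryBoundary:
    """Hypothetical context manager: empty the CUDA cache only when crossing
    into and out of the inference engine's own allocator, never in between."""

    def __init__(self, engine):
        # `engine` is assumed to expose wake_up()/sleep() as in vLLM's sleep mode.
        self.engine = engine

    def __enter__(self):
        # Release blocks cached by PyTorch so the engine's allocator can claim them.
        torch.cuda.empty_cache()
        self.engine.wake_up()
        return self.engine

    def __exit__(self, exc_type, exc_value, traceback):
        self.engine.sleep()
        # Free whatever generation left cached before training reuses the GPU.
        torch.cuda.empty_cache()
        # Code outside this context should not call empty_cache(); PyTorch's
        # caching allocator reuses freed blocks faster than fresh cudaMalloc.
```

Everything outside such a boundary keeps relying on PyTorch's caching allocator, which is why the scattered per-worker `empty_cache()` calls removed in the other files are no longer needed.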