These changes gate every sleep/wake call on the rollout engine behind the `free_cache_engine` config flag: when the flag is disabled the inference engine stays resident in GPU memory, so it is never put to sleep and must never be woken.

verl/experimental/agent_loop/agent_loop.py (2 additions & 1 deletion)

```diff
@@ -772,7 +772,8 @@ def batch_fn(data_list: list[DataProto]) -> list[torch.Tensor]:
         self._init_agent_loop_workers()
 
         # Initially we're in sleep mode.
-        self.sleep()
+        if self.config.actor_rollout_ref.rollout.free_cache_engine:
+            self.sleep()
 
     def _initialize_llm_servers(self):
         self.rollout_tp_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
```
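The point of the guard is symmetry: with `free_cache_engine=False` the engine is never put to sleep, so no later code may try to wake it. A minimal sketch of that invariant, with a hypothetical `SleepableEngine` standing in for the vLLM engine handle (names and structure are illustrative, not verl's API):

```python
# Minimal sketch of the sleep/wake symmetry this PR enforces.
# SleepableEngine is hypothetical; it stands in for the vLLM engine handle.
class SleepableEngine:
    def __init__(self, free_cache_engine: bool):
        self.free_cache_engine = free_cache_engine
        self.asleep = False

    def sleep(self) -> None:
        # With free_cache_engine=False the engine stays resident;
        # skipping here means every wake() below must be skipped too.
        if not self.free_cache_engine:
            return
        self.asleep = True  # weights offloaded, KV cache freed

    def wake(self) -> None:
        if not self.free_cache_engine:
            return  # engine never slept, so there is nothing to restore
        self.asleep = False
```

Every call site therefore applies the same `free_cache_engine` check, which is exactly what the worker changes below do.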
verl/workers/fsdp_workers.py (4 additions & 2 deletions)

```diff
@@ -635,13 +635,15 @@ async def rollout_mode(self):
             for name, param in params.items()
         )
 
-        await self.rollout.resume(tags=["weights"])
+        if self.config.rollout.free_cache_engine:
+            await self.rollout.resume(tags=["weights"])
         log_gpu_memory_usage("After resume weights", logger=logger)
         await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
         log_gpu_memory_usage("After update_weights", logger=logger)
         del params, per_tensor_param
         aggressive_empty_cache(force_sync=True)
-        await self.rollout.resume(tags=["kv_cache"])
+        if self.config.rollout.free_cache_engine:
+            await self.rollout.resume(tags=["kv_cache"])
         log_gpu_memory_usage("After resume kv_cache", logger=logger)
 
         self.base_sync_done = True
```
verl/workers/megatron_workers.py (4 additions & 2 deletions)

```diff
@@ -574,12 +574,14 @@ async def rollout_mode(self):
 
         set_expandable_segments(False)
 
-        await self.rollout.resume(tags=["weights"])
+        if self.config.rollout.free_cache_engine:
+            await self.rollout.resume(tags=["weights"])
         await self.rollout.update_weights(per_tensor_param)
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.actor.actor_module)
         aggressive_empty_cache(force_sync=True)
-        await self.rollout.resume(tags=["kv_cache"])
+        if self.config.rollout.free_cache_engine:
+            await self.rollout.resume(tags=["kv_cache"])
 
         # important: need to manually set the random states of each tp to be identical.
         self.torch_random_states = get_torch_device().get_rng_state()
```
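Both workers wake the engine in two stages: weights first, then the KV cache only after trainer-side buffers have been freed with `aggressive_empty_cache`, presumably so the cache allocation does not compete with weight syncing for GPU memory. A condensed sketch of that ordering under the new guard (the `rollout` and `config` objects are illustrative stand-ins, not verl's exact types):

```python
# Condensed sketch of the two-stage wake-up used by both workers' rollout_mode.
# `rollout` and `config` are illustrative stand-ins, not verl's exact types.
async def enter_rollout_mode(rollout, config, per_tensor_param):
    if config.rollout.free_cache_engine:
        # Stage 1: bring only the weight buffers back onto the GPU.
        await rollout.resume(tags=["weights"])
    # Sync trainer weights into the inference engine.
    await rollout.update_weights(per_tensor_param)
    # (Trainer-side buffers are released here via del + aggressive_empty_cache,
    # so the KV cache allocation below does not compete for memory.)
    if config.rollout.free_cache_engine:
        # Stage 2: reallocate the KV cache last, once memory is free.
        await rollout.resume(tags=["kv_cache"])
```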
verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py (11 additions & 2 deletions)

```diff
@@ -403,6 +403,9 @@ async def resume(self, tags: list[str]):
         Args:
             tags: weights or kv_cache.
         """
+        if not self.config.free_cache_engine:
+            return
+
         if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
             self.inference_engine.wake_up(tags=tags)
         else:
@@ -411,6 +414,10 @@ async def resume(self, tags: list[str]):
     async def release(self):
         """Release weights and kv cache in GPU memory."""
         self.inference_engine.reset_prefix_cache()
+
+        if not self.config.free_cache_engine:
+            return
+
         self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)
 
     async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs):
@@ -540,11 +547,13 @@ async def resume(self, tags: list[str]):
         Args:
             tags: weights or kv_cache.
         """
-        self.inference_engine.wake_up(tags=tags)
+        if self.config.free_cache_engine:
+            self.inference_engine.wake_up(tags=tags)
 
     async def release(self):
         """Release weights and kv cache in GPU memory."""
-        self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)
+        if self.config.free_cache_engine:
+            self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)
 
     async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs):
         """Update the weights of the rollout model.
```
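One detail worth noting in the first hunk: `wake_up` only gained a `tags` parameter in newer vLLM releases, so the existing code probes the signature at runtime instead of pinning a version. The same fallback pattern in isolation (a sketch, with `engine` standing in for `self.inference_engine`):

```python
import inspect

def wake_engine(engine, tags: list[str]) -> None:
    """Wake a vLLM engine across versions; a sketch, with `engine`
    standing in for self.inference_engine."""
    if "tags" in inspect.signature(engine.wake_up).parameters:
        engine.wake_up(tags=tags)  # newer vLLM: wake weights/kv_cache separately
    else:
        engine.wake_up()  # older vLLM: single all-or-nothing wake-up
```

Note also that the first `release()` still calls `reset_prefix_cache()` before the early return, so cached prefixes are invalidated even when the engine never sleeps.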