diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py
index 717c1394a8f..00c7c41333f 100644
--- a/verl/experimental/agent_loop/agent_loop.py
+++ b/verl/experimental/agent_loop/agent_loop.py
@@ -772,7 +772,8 @@ def batch_fn(data_list: list[DataProto]) -> list[torch.Tensor]:
         self._init_agent_loop_workers()

         # Initially we're in sleep mode.
-        self.sleep()
+        if self.config.actor_rollout_ref.rollout.free_cache_engine:
+            self.sleep()

     def _initialize_llm_servers(self):
         self.rollout_tp_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 299355313d6..3a95c078407 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -635,13 +635,15 @@ async def rollout_mode(self):
             for name, param in params.items()
         )

-        await self.rollout.resume(tags=["weights"])
+        if self.config.rollout.free_cache_engine:
+            await self.rollout.resume(tags=["weights"])
         log_gpu_memory_usage("After resume weights", logger=logger)
         await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
         log_gpu_memory_usage("After update_weights", logger=logger)
         del params, per_tensor_param
         aggressive_empty_cache(force_sync=True)
-        await self.rollout.resume(tags=["kv_cache"])
+        if self.config.rollout.free_cache_engine:
+            await self.rollout.resume(tags=["kv_cache"])
         log_gpu_memory_usage("After resume kv_cache", logger=logger)

         self.base_sync_done = True
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index c1fc6b7d9ef..42826731052 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -574,12 +574,14 @@ async def rollout_mode(self):
         set_expandable_segments(False)

-        await self.rollout.resume(tags=["weights"])
+        if self.config.rollout.free_cache_engine:
+            await self.rollout.resume(tags=["weights"])
         await self.rollout.update_weights(per_tensor_param)
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.actor.actor_module)
         aggressive_empty_cache(force_sync=True)
-        await self.rollout.resume(tags=["kv_cache"])
+        if self.config.rollout.free_cache_engine:
+            await self.rollout.resume(tags=["kv_cache"])

         # important: need to manually set the random states of each tp to be identical.
         self.torch_random_states = get_torch_device().get_rng_state()
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
index 276d808020c..3414e5bae8d 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -403,6 +403,9 @@ async def resume(self, tags: list[str]):
         Args:
             tags: weights or kv_cache.
         """
+        if not self.config.free_cache_engine:
+            return
+
         if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
             self.inference_engine.wake_up(tags=tags)
         else:
@@ -411,6 +414,10 @@
     async def release(self):
         """Release weights and kv cache in GPU memory."""
         self.inference_engine.reset_prefix_cache()
+
+        if not self.config.free_cache_engine:
+            return
+
         self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)

     async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs):
@@ -540,11 +547,13 @@ async def resume(self, tags: list[str]):
         Args:
             tags: weights or kv_cache.
         """
-        self.inference_engine.wake_up(tags=tags)
+        if self.config.free_cache_engine:
+            self.inference_engine.wake_up(tags=tags)

     async def release(self):
         """Release weights and kv cache in GPU memory."""
-        self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)
+        if self.config.free_cache_engine:
+            self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)

     async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs):
         """Update the weights of the rollout model.