
Commit 8f73030

HollowMan6 authored and techkang committed
[worker] fix: respect free_cache_engine flag (volcengine#3442)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Continuation of volcengine#1464. Recent changes have broken the `free_cache_engine` option again.

### Checklist Before Starting

- [X] Search for similar PRs. Paste at least one query link here: ...
- [X] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

Unit test cases might not be feasible, as the `sleep`/`wake_up` calls can happen anywhere in the codebase. An end-to-end test might be resource-consuming.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [X] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [X] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

Signed-off-by: Hollow Man <[email protected]>
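For reference, the flag this commit respects is `actor_rollout_ref.rollout.free_cache_engine`, the same config path the diffs below read. A minimal sketch of toggling it, assuming verl's stock OmegaConf-based trainer config (the load path here is an assumption, not part of this PR):

```python
# Minimal sketch, assuming the stock trainer config layout; the YAML
# path below is an assumption for illustration.
from omegaconf import OmegaConf

cfg = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")

# Keep the inference engine resident: with this flag off, resume() and
# release() become no-ops, so sleep()/wake_up() are never issued and
# the KV cache is never freed between rollouts.
cfg.actor_rollout_ref.rollout.free_cache_engine = False
```

When the flag is enabled, the engine is released between rollouts and resumed in two stages (weights first, then KV cache) before the next generation, as the worker diffs below show.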
1 parent ddb1f2a commit 8f73030

File tree

4 files changed: +21 -7 lines changed


verl/experimental/agent_loop/agent_loop.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -772,7 +772,8 @@ def batch_fn(data_list: list[DataProto]) -> list[torch.Tensor]:
         self._init_agent_loop_workers()

         # Initially we're in sleep mode.
-        self.sleep()
+        if self.config.actor_rollout_ref.rollout.free_cache_engine:
+            self.sleep()

     def _initialize_llm_servers(self):
         self.rollout_tp_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
```

verl/workers/fsdp_workers.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -662,13 +662,15 @@ async def rollout_mode(self):
             for name, param in params.items()
         )

-        await self.rollout.resume(tags=["weights"])
+        if self.config.rollout.free_cache_engine:
+            await self.rollout.resume(tags=["weights"])
         log_gpu_memory_usage("After resume weights", logger=logger)
         await self.rollout.update_weights(per_tensor_param, peft_config=peft_config, base_sync_done=self.base_sync_done)
         log_gpu_memory_usage("After update_weights", logger=logger)
         del params, per_tensor_param
         aggressive_empty_cache(force_sync=True)
-        await self.rollout.resume(tags=["kv_cache"])
+        if self.config.rollout.free_cache_engine:
+            await self.rollout.resume(tags=["kv_cache"])
         log_gpu_memory_usage("After resume kv_cache", logger=logger)

         self.base_sync_done = True
```

verl/workers/megatron_workers.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -574,12 +574,14 @@ async def rollout_mode(self):

         set_expandable_segments(False)

-        await self.rollout.resume(tags=["weights"])
+        if self.config.rollout.free_cache_engine:
+            await self.rollout.resume(tags=["weights"])
         await self.rollout.update_weights(per_tensor_param)
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.actor.actor_module)
         aggressive_empty_cache(force_sync=True)
-        await self.rollout.resume(tags=["kv_cache"])
+        if self.config.rollout.free_cache_engine:
+            await self.rollout.resume(tags=["kv_cache"])

         # important: need to manually set the random states of each tp to be identical.
         self.torch_random_states = get_torch_device().get_rng_state()
```

verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -403,6 +403,9 @@ async def resume(self, tags: list[str]):
         Args:
             tags: weights or kv_cache.
         """
+        if not self.config.free_cache_engine:
+            return
+
         if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
             self.inference_engine.wake_up(tags=tags)
         else:
@@ -411,6 +414,10 @@ async def resume(self, tags: list[str]):
     async def release(self):
         """Release weights and kv cache in GPU memory."""
         self.inference_engine.reset_prefix_cache()
+
+        if not self.config.free_cache_engine:
+            return
+
         self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)

     async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs):
@@ -540,11 +547,13 @@ async def resume(self, tags: list[str]):
         Args:
             tags: weights or kv_cache.
         """
-        self.inference_engine.wake_up(tags=tags)
+        if self.config.free_cache_engine:
+            self.inference_engine.wake_up(tags=tags)

     async def release(self):
         """Release weights and kv cache in GPU memory."""
-        self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)
+        if self.config.free_cache_engine:
+            self.inference_engine.sleep(level=VLLM_SLEEP_LEVEL)

     async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None], **kwargs):
         """Update the weights of the rollout model.
```
