Skip to content

Commit e894627

Browse files
ellen-m1claude
andcommitted
[rollout] feat: add optional tags parameter to release()
Add tags parameter to release() across all rollout backends (SGLang, vLLM, TRT-LLM), matching the existing resume(tags) signature from verl-project#1911. Callers can now selectively release ["weights"], ["kv_cache"], or both. Co-Authored-By: Claude Opus 4.6 <[email protected]>
1 parent d9d94b4 commit e894627

File tree

6 files changed

+397
-15
lines changed

6 files changed

+397
-15
lines changed
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for the `tags` parameter on rollout release().
15+
16+
Each backend's release() accepts an optional `tags` argument that selects
17+
which GPU resources to release (["weights"], ["kv_cache"], or both).
18+
19+
The shared validation logic lives in `_tag_utils.validate_release_tags()`
20+
and is tested directly (no mocking needed). Backend-specific behavior
21+
(vLLM sleep-level mapping, TRT-LLM tag resolution) is tested via
22+
lightweight mock objects that exercise each backend's release() method
23+
without requiring GPU or distributed infrastructure.
24+
"""
25+
26+
from __future__ import annotations
27+
28+
from unittest.mock import AsyncMock, MagicMock
29+
30+
import pytest
31+
32+
from verl.workers.rollout._tag_utils import validate_release_tags
33+
34+
# ---------------------------------------------------------------------------
35+
# validate_release_tags — shared logic (real code, no mocks)
36+
# ---------------------------------------------------------------------------
37+
38+
39+
class TestValidateReleaseTags:
40+
def test_none_returns_both(self):
41+
assert validate_release_tags(None) == {"kv_cache", "weights"}
42+
43+
def test_weights_only(self):
44+
assert validate_release_tags(["weights"]) == {"weights"}
45+
46+
def test_kv_cache_only(self):
47+
assert validate_release_tags(["kv_cache"]) == {"kv_cache"}
48+
49+
def test_both_explicit(self):
50+
assert validate_release_tags(["kv_cache", "weights"]) == {"kv_cache", "weights"}
51+
52+
def test_duplicates_deduplicated(self):
53+
assert validate_release_tags(["weights", "weights"]) == {"weights"}
54+
55+
def test_unknown_tag_raises(self):
56+
with pytest.raises(ValueError, match="Unknown release tags"):
57+
validate_release_tags(["bogus"])
58+
59+
def test_mixed_valid_and_unknown_raises(self):
60+
with pytest.raises(ValueError, match="Unknown release tags"):
61+
validate_release_tags(["weights", "bogus"])
62+
63+
def test_empty_list_raises(self):
64+
with pytest.raises(ValueError, match="must not be empty"):
65+
validate_release_tags([])
66+
67+
68+
# ---------------------------------------------------------------------------
69+
# Backend-specific release() behavior via mock objects.
70+
#
71+
# These test the exact method logic from each backend's release() without
72+
# importing the actual classes (which require torch, sglang, ray, etc.).
73+
# The validate_release_tags() call is real code — only the async I/O
74+
# (engine calls, server adapters) is mocked.
75+
# ---------------------------------------------------------------------------
76+
77+
78+
async def _sglang_release(self, tags=None):
79+
"""Mirrors SglangRollout.release() — calls validate_release_tags."""
80+
tag_set = validate_release_tags(tags)
81+
await self._init_server_adapter()
82+
if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.config.free_cache_engine:
83+
await self._engine.release_memory_occupation(tags=sorted(tag_set))
84+
85+
86+
async def _vllm_release(self, tags=None):
87+
"""Mirrors VllmRollout.release() — calls validate_release_tags."""
88+
tag_set = validate_release_tags(tags)
89+
if not self.config.free_cache_engine:
90+
return
91+
if tag_set == {"kv_cache", "weights"}:
92+
level = self.sleep_level
93+
elif tag_set == {"kv_cache"}:
94+
level = 1
95+
else:
96+
raise NotImplementedError(
97+
f"vLLM release does not support tags={tags!r}; only ['kv_cache', 'weights'] or ['kv_cache'] are supported"
98+
)
99+
await self._execute_method("sleep", kwargs={"level": level})
100+
101+
102+
# TRT-LLM weight tags (from ServerAdapter._WEIGHTS_TAGS)
103+
_TRTLLM_WEIGHTS_TAGS = [
104+
"sampler",
105+
"drafter",
106+
"guided_decoder",
107+
"spec_resource_manager",
108+
"model_extra",
109+
"executor_extra",
110+
"model",
111+
"draft_model",
112+
]
113+
114+
115+
async def _trtllm_release(self, tags=None):
116+
"""Mirrors TrtllmRollout.release() — calls validate_release_tags."""
117+
tag_set = validate_release_tags(tags)
118+
if not self.is_leader_rank or not self.config.free_cache_engine:
119+
return
120+
await self._init_server_adapter()
121+
resolved_tags = []
122+
if "weights" in tag_set:
123+
resolved_tags.extend(_TRTLLM_WEIGHTS_TAGS)
124+
if "kv_cache" in tag_set:
125+
resolved_tags.append("kv_cache")
126+
await self._adapter.release_memory_occupation(tags=resolved_tags)
127+
128+
129+
# ---------------------------------------------------------------------------
130+
# Mock factories
131+
# ---------------------------------------------------------------------------
132+
133+
134+
def _make_sglang_mock():
135+
mock = MagicMock()
136+
mock._init_server_adapter = AsyncMock()
137+
mock._engine = AsyncMock()
138+
mock._engine.release_memory_occupation = AsyncMock(return_value={"status": "ok"})
139+
mock.device_mesh = {"infer_tp": MagicMock(get_local_rank=MagicMock(return_value=0))}
140+
mock.config = MagicMock(free_cache_engine=True)
141+
return mock
142+
143+
144+
def _make_vllm_mock():
145+
mock = MagicMock()
146+
mock.config = MagicMock(free_cache_engine=True)
147+
mock.sleep_level = 2
148+
mock._execute_method = AsyncMock()
149+
return mock
150+
151+
152+
def _make_trtllm_mock():
153+
mock = MagicMock()
154+
mock.is_leader_rank = True
155+
mock.config = MagicMock(free_cache_engine=True)
156+
mock._init_server_adapter = AsyncMock()
157+
mock._adapter = AsyncMock()
158+
mock._adapter.release_memory_occupation = AsyncMock(return_value={"status": "ok"})
159+
return mock
160+
161+
162+
# ---------------------------------------------------------------------------
163+
# SGLang tests
164+
# ---------------------------------------------------------------------------
165+
166+
167+
class TestSglangReleaseTags:
168+
@pytest.mark.asyncio
169+
async def test_default_releases_both(self):
170+
mock = _make_sglang_mock()
171+
await _sglang_release(mock)
172+
mock._engine.release_memory_occupation.assert_called_once_with(tags=["kv_cache", "weights"])
173+
174+
@pytest.mark.asyncio
175+
async def test_weights_only(self):
176+
mock = _make_sglang_mock()
177+
await _sglang_release(mock, tags=["weights"])
178+
mock._engine.release_memory_occupation.assert_called_once_with(tags=["weights"])
179+
180+
@pytest.mark.asyncio
181+
async def test_kv_cache_only(self):
182+
mock = _make_sglang_mock()
183+
await _sglang_release(mock, tags=["kv_cache"])
184+
mock._engine.release_memory_occupation.assert_called_once_with(tags=["kv_cache"])
185+
186+
@pytest.mark.asyncio
187+
async def test_unknown_tag_raises(self):
188+
mock = _make_sglang_mock()
189+
with pytest.raises(ValueError, match="Unknown release tags"):
190+
await _sglang_release(mock, tags=["bogus"])
191+
192+
@pytest.mark.asyncio
193+
async def test_free_cache_disabled_is_noop(self):
194+
mock = _make_sglang_mock()
195+
mock.config.free_cache_engine = False
196+
await _sglang_release(mock)
197+
mock._engine.release_memory_occupation.assert_not_called()
198+
199+
200+
# ---------------------------------------------------------------------------
201+
# vLLM tests
202+
# ---------------------------------------------------------------------------
203+
204+
205+
class TestVllmReleaseTags:
206+
@pytest.mark.asyncio
207+
async def test_default_releases_both(self):
208+
mock = _make_vllm_mock()
209+
await _vllm_release(mock)
210+
mock._execute_method.assert_called_once_with("sleep", kwargs={"level": 2})
211+
212+
@pytest.mark.asyncio
213+
async def test_kv_cache_only(self):
214+
mock = _make_vllm_mock()
215+
await _vllm_release(mock, tags=["kv_cache"])
216+
mock._execute_method.assert_called_once_with("sleep", kwargs={"level": 1})
217+
218+
@pytest.mark.asyncio
219+
async def test_weights_only_not_supported(self):
220+
mock = _make_vllm_mock()
221+
with pytest.raises(NotImplementedError):
222+
await _vllm_release(mock, tags=["weights"])
223+
224+
@pytest.mark.asyncio
225+
async def test_unknown_tag_raises_value_error(self):
226+
mock = _make_vllm_mock()
227+
with pytest.raises(ValueError, match="Unknown release tags"):
228+
await _vllm_release(mock, tags=["bogus"])
229+
230+
@pytest.mark.asyncio
231+
async def test_free_cache_disabled_is_noop(self):
232+
mock = _make_vllm_mock()
233+
mock.config.free_cache_engine = False
234+
await _vllm_release(mock) # valid tags, but free_cache_engine=False → noop
235+
mock._execute_method.assert_not_called()
236+
237+
@pytest.mark.asyncio
238+
async def test_free_cache_disabled_still_validates(self):
239+
mock = _make_vllm_mock()
240+
mock.config.free_cache_engine = False
241+
with pytest.raises(ValueError, match="Unknown release tags"):
242+
await _vllm_release(mock, tags=["bogus"])
243+
244+
245+
# ---------------------------------------------------------------------------
246+
# TRT-LLM tests
247+
# ---------------------------------------------------------------------------
248+
249+
250+
class TestTrtllmReleaseTags:
251+
@pytest.mark.asyncio
252+
async def test_default_releases_both(self):
253+
mock = _make_trtllm_mock()
254+
await _trtllm_release(mock)
255+
call_tags = mock._adapter.release_memory_occupation.call_args.kwargs["tags"]
256+
assert "kv_cache" in call_tags
257+
for wt in _TRTLLM_WEIGHTS_TAGS:
258+
assert wt in call_tags
259+
260+
@pytest.mark.asyncio
261+
async def test_weights_only(self):
262+
mock = _make_trtllm_mock()
263+
await _trtllm_release(mock, tags=["weights"])
264+
call_tags = mock._adapter.release_memory_occupation.call_args.kwargs["tags"]
265+
assert "kv_cache" not in call_tags
266+
for wt in _TRTLLM_WEIGHTS_TAGS:
267+
assert wt in call_tags
268+
269+
@pytest.mark.asyncio
270+
async def test_kv_cache_only(self):
271+
mock = _make_trtllm_mock()
272+
await _trtllm_release(mock, tags=["kv_cache"])
273+
call_tags = mock._adapter.release_memory_occupation.call_args.kwargs["tags"]
274+
assert call_tags == ["kv_cache"]
275+
276+
@pytest.mark.asyncio
277+
async def test_unknown_tag_raises(self):
278+
mock = _make_trtllm_mock()
279+
with pytest.raises(ValueError, match="Unknown release tags"):
280+
await _trtllm_release(mock, tags=["bogus"])
281+
282+
@pytest.mark.asyncio
283+
async def test_non_leader_is_noop(self):
284+
mock = _make_trtllm_mock()
285+
mock.is_leader_rank = False
286+
await _trtllm_release(mock, tags=["weights"])
287+
mock._adapter.release_memory_occupation.assert_not_called()

verl/workers/rollout/_tag_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Lightweight tag validation utilities for rollout release/resume.
15+
16+
This module has zero heavy dependencies (no torch, ray, etc.) so it can
17+
be imported in unit tests without GPU or distributed infrastructure.
18+
"""
19+
20+
from __future__ import annotations
21+
22+
_VALID_RELEASE_TAGS = frozenset({"kv_cache", "weights"})
23+
_DEFAULT_RELEASE_TAGS = ("kv_cache", "weights")
24+
25+
26+
def validate_release_tags(tags: list[str] | None) -> set[str]:
27+
"""Normalize and validate release tags.
28+
29+
Args:
30+
tags: List of tags to release, or None for the default (both).
31+
32+
Returns:
33+
A set of validated tags.
34+
35+
Raises:
36+
ValueError: If any tag is not in {"kv_cache", "weights"}.
37+
"""
38+
if tags is None:
39+
return set(_DEFAULT_RELEASE_TAGS)
40+
tag_set = set(tags)
41+
if not tag_set:
42+
raise ValueError("release tags must not be empty; pass None to release all")
43+
unknown = tag_set - _VALID_RELEASE_TAGS
44+
if unknown:
45+
raise ValueError(f"Unknown release tags: {unknown!r}; expected subset of {sorted(_VALID_RELEASE_TAGS)}")
46+
return tag_set

verl/workers/rollout/base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,13 @@ async def update_weights(
6464
pass
6565

6666
@abstractmethod
67-
async def release(self):
68-
"""Release weights and kv cache in GPU memory."""
67+
async def release(self, tags: list[str] | None = None):
68+
"""Release weights and/or kv cache in GPU memory.
69+
70+
Args:
71+
tags: List of tags to release, e.g. ["weights"], ["kv_cache"], or
72+
["kv_cache", "weights"]. If None (default), releases both.
73+
"""
6974
pass
7075

7176
def generate_sequences(self, prompts: DataProto) -> DataProto:

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,19 @@ async def resume(self, tags: list[str]):
174174
if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.config.free_cache_engine:
175175
await self._engine.resume_memory_occupation(tags=tags)
176176

177-
async def release(self):
178-
"""Release weights and kv cache in GPU memory."""
177+
async def release(self, tags: list[str] | None = None):
178+
"""Release weights and/or kv cache in GPU memory.
179+
180+
Args:
181+
tags: List of tags to release, e.g. ["weights"], ["kv_cache"], or
182+
["kv_cache", "weights"]. If None (default), releases both.
183+
"""
184+
from verl.workers.rollout._tag_utils import validate_release_tags
185+
186+
tag_set = validate_release_tags(tags)
179187
await self._init_server_adapter()
180188
if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.config.free_cache_engine:
181-
await self._engine.release_memory_occupation(tags=["kv_cache", "weights"])
189+
await self._engine.release_memory_occupation(tags=sorted(tag_set))
182190

183191
async def update_weights(
184192
self, weights: Generator[tuple[str, torch.Tensor], None, None], global_steps: int = None, **kwargs

0 commit comments

Comments
 (0)