
Commit 76298ad

[recipe] feat: Add sleep/wakeup mode for gen rm vllm service and add tqdm showing process (#2739)
### What does this PR do?

Add sleep/wakeup mode for the GenRM vLLM service and add a tqdm progress bar. This capability is particularly beneficial when the model server shares resources with a training workload on the same machine. It allows the reward model service to be temporarily offloaded (to free up GPU memory) during intensive training sessions and reloaded when the service is required again.
1 parent d640f99 commit 76298ad

File tree

2 files changed: +40 -9 lines changed


recipe/genrm_remote/README.md

Lines changed: 9 additions & 1 deletion
@@ -7,8 +7,16 @@
 Deploy the pretrained GenRM model using vLLM. Skip this step if you want to use an external api service.
 
 ```bash
-vllm serve verl-team/GenRM-CI-Test-1.5B --served-model-name genrm-demo
+VLLM_SERVER_DEV_MODE=1 vllm serve verl-team/GenRM-CI-Test-1.5B --served-model-name genrm-demo --enable-sleep-mode --dtype float32
 ```
+
+Note that the wake_up and sleep operations for managing CUDA memory in vLLM are only available when both `VLLM_SERVER_DEV_MODE=1` and `enable_sleep_mode` are set. This capability is particularly beneficial when the model server shares resources with a training workload on the same machine. It allows the reward model service to be temporarily offloaded (to free up GPU memory) during intensive training sessions and reloaded when the service is required again. The relevant vllm code implementation can be found below:
+
+[VLLM_SERVER_DEV_MODE](https://github.com/vllm-project/vllm/blob/5a19a6c6705fe83db2e3517a2d2f473586901743/vllm/entrypoints/openai/api_server.py#L971)
+
+[sleep and wake_up mode](https://github.com/vllm-project/vllm/blob/5a19a6c6705fe83db2e3517a2d2f473586901743/vllm/entrypoints/openai/api_server.py#L994-L1003)
+
+When the backend is configured as `SERVER_BACKEND`="VLLM", the `USE_OFFLOAD` flag can be toggled between True and False (see `reward_function.py`).
 
 ### Step 2: Perform RL using GenRM
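For quick verification of the sleep/wake_up endpoints mentioned in the README change, here is a minimal sketch. It assumes the server was launched with `VLLM_SERVER_DEV_MODE=1` and `--enable-sleep-mode` and is reachable at the `http://localhost:30000` base URL used by `reward_function.py`; the `toggle_server` helper is illustrative and not part of this commit.

```python
import requests

BASE_URL = "http://localhost:30000"  # assumed: same base URL as in reward_function.py

def toggle_server(task: str) -> None:
    """POST to the dev-mode /sleep or /wake_up endpoint exposed by vLLM."""
    assert task in ("sleep", "wake_up"), f"Invalid task: {task}"
    response = requests.post(f"{BASE_URL}/{task}")
    response.raise_for_status()  # a 200 response means the request was accepted

if __name__ == "__main__":
    toggle_server("sleep")    # offload the reward model to free GPU memory
    toggle_server("wake_up")  # reload it before the next scoring round
```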

recipe/genrm_remote/reward_function.py

Lines changed: 31 additions & 8 deletions
@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from concurrent.futures import ThreadPoolExecutor
+import random
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from time import sleep
 
 import requests
+import tqdm
 
 from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed
 
+SERVER_BACKEND = "VLLM"
+USE_OFFLOAD = True
 BASE_URL = "http://localhost:30000"
 API_KEY = "EMPTY"
 MAX_RETRIES = 3
@@ -42,6 +47,13 @@
 """.strip()
 
 
+def vllm_execute_method(task="sleep"):
+    assert task in ["sleep", "wake_up"], f"Invalid task: {task}"
+    url_root = BASE_URL
+    response = requests.post(url_root + "/" + task)
+    assert response.status_code == 200
+
+
 def get_response(problem, solution_str, ground_truth):
     prompt = GENRM_PROMPT_TEMPLATE.format(problem=problem, solution=solution_str)
     messages = [{"role": "user", "content": prompt}]
@@ -77,14 +89,14 @@ def compute_reward(response):
     return reward_score
 
 
-def compute_score(data_source, solution_str, ground_truth, extra_info):
+def compute_score(data_source, solution_str, ground_truth, extra_info, index):
     split = extra_info["split"]
     from verl.utils.reward_score import default_compute_score
 
     func_rm_score = default_compute_score(data_source, solution_str, ground_truth, extra_info)
 
     if split == "test":
-        return func_rm_score
+        return func_rm_score, index
     else:
         problem = extra_info["question"]
         response = get_response(problem, solution_str, ground_truth)
@@ -93,18 +105,29 @@ def compute_score(data_source, solution_str, ground_truth, extra_info):
     else:
         reward_score = 0.0
 
-    return reward_score
+    return reward_score, index
 
 
 def compute_score_batch(data_sources, solution_strs, ground_truths, extra_infos):
+    results = []
+    indexes = list(range(len(data_sources)))
+    if SERVER_BACKEND == "VLLM" and USE_OFFLOAD:
+        vllm_execute_method("wake_up")
+
     with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
         futures = []
-        for data_source, solution_str, ground_truth, extra_info in zip(
-            data_sources, solution_strs, ground_truths, extra_infos, strict=True
+        for data_source, solution_str, ground_truth, extra_info, index in zip(
+            data_sources, solution_strs, ground_truths, extra_infos, indexes, strict=True
         ):
-            future = executor.submit(compute_score, data_source, solution_str, ground_truth, extra_info)
+            future = executor.submit(compute_score, data_source, solution_str, ground_truth, extra_info, index)
+            time.sleep(0.001 * random.random())
             futures.append(future)
 
-        results = [future.result() for future in futures]
+        for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
+            results.append(future.result())
+        results = sorted(results, key=lambda x: x[-1], reverse=False)
+        results = [result[0] for result in results]
 
+    if SERVER_BACKEND == "VLLM" and USE_OFFLOAD:
+        vllm_execute_method("sleep")
     return results
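To illustrate the concurrency pattern introduced above (submit each sample with its index, collect results via `as_completed` under a tqdm progress bar, then restore the original order), here is a small self-contained sketch. The `score_one` function and its length-based score are hypothetical stand-ins for `compute_score`; only the ordering/progress pattern mirrors the commit, and `tqdm` is the same extra dependency the diff adds.

```python
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import tqdm

def score_one(sample: str, index: int) -> tuple[float, int]:
    # Hypothetical stand-in for compute_score: returns (score, original index).
    time.sleep(random.random() * 0.01)  # simulate variable scoring latency
    return float(len(sample)), index

def score_batch(samples: list[str]) -> list[float]:
    results = []
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(score_one, s, i) for i, s in enumerate(samples)]
        # as_completed yields futures as they finish, which lets tqdm show real progress,
        # but it also scrambles the completion order ...
        for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
            results.append(future.result())
    # ... so the carried index is used to put the scores back into input order.
    results.sort(key=lambda pair: pair[-1])
    return [score for score, _ in results]

if __name__ == "__main__":
    print(score_batch(["a", "bbb", "cc"]))  # -> [1.0, 3.0, 2.0]
```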
