Merged
11 changes: 7 additions & 4 deletions recipe/spin/fsdp_workers.py
@@ -18,6 +18,7 @@
 import os
 import warnings
 
+import numpy as np
 import psutil
 import torch
 import torch.distributed
@@ -483,11 +484,13 @@ def _switch_chat_template(self, data: DataProto):
         rm_attention_mask = []
 
         for i in range(data.batch.batch_size[0]):
+            if not isinstance(data.non_tensor_batch["raw_prompt"][i], list | np.ndarray):
+                raise TypeError(
+                    f"raw_prompt must be a list or numpy array, got {type(data.non_tensor_batch['raw_prompt'][i])}"
+                )
+
             # extract raw prompt
-            if isinstance(data.non_tensor_batch["raw_prompt"][i], list):
-                chat: list = data.non_tensor_batch["raw_prompt"][i]
-            else:
-                chat: list = data.non_tensor_batch["raw_prompt"][i].tolist()
+            chat: list = list(data.non_tensor_batch["raw_prompt"][i])
 
             # extract response
             response_ids = data.batch["responses"][i]
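For reference, a minimal standalone sketch of the normalization pattern that this hunk and the ones below converge on (illustrative names, not verl code): the up-front `isinstance` check uses PEP 604 union syntax, so it requires Python 3.10+, and a single `list()` call covers both branches the old `if`/`else` handled:

```python
import numpy as np


def normalize_raw_prompt(raw_prompt):
    # `list | np.ndarray` as an isinstance target needs Python >= 3.10.
    if not isinstance(raw_prompt, list | np.ndarray):
        raise TypeError(f"raw_prompt must be a list or numpy array, got {type(raw_prompt)}")
    # list() copies a list as-is and unpacks a 1-D object ndarray element by
    # element, so both inputs below yield the same plain Python list.
    return list(raw_prompt)


messages = [{"role": "user", "content": "Hi."}]
assert normalize_raw_prompt(messages) == messages
assert normalize_raw_prompt(np.array(messages, dtype=object)) == messages
```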
72 changes: 72 additions & 0 deletions tests/utils/dataset/test_rl_collate_fn_on_cpu.py
@@ -0,0 +1,72 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+def test_rl_collate_fn():
+    from verl.utils.dataset.rl_dataset import collate_fn
+
+    max_prompt_length = 5
+
+    test_data = [
+        {
+            # test tensor
+            "input_ids": torch.randint(0, 10, (max_prompt_length,)),
+            # test fixed length (1) list within a batch
+            "messages": [{"role": "user", "content": "Hi."}],
+            # test variable length list within a batch
+            "raw_prompt_ids": [1, 2, 3, 4],
+            # test string
+            "ability": "math",
+            # test dict
+            "reward_model": {"ground_truth": 5, "style": "rule"},
+            # test empty dict
+            "tools_kwargs": {},
+        },
+        {
+            "input_ids": torch.randint(0, 10, (max_prompt_length,)),
+            "messages": [{"role": "user", "content": "Hello."}],
+            "raw_prompt_ids": [1, 2, 3],
+            "ability": "toolcall",
+            "reward_model": {
+                "ground_truth": '[{"name": "rgb_to_cmyk", "arguments": {"r": 0, "g": 0, "b": 255}}]',
+                "style": "rule",
+            },
+            "tools_kwargs": {},
+        },
+    ]
+
+    batch_size = len(test_data)
+    batch = collate_fn(test_data)
+
+    # Tensor part
+    assert batch["input_ids"].shape == (batch_size, max_prompt_length)
+    assert isinstance(batch["input_ids"], torch.Tensor)
+
+    # Non-tensor parts
+    expected_types = {
+        "messages": list,
+        "raw_prompt_ids": list,
+        "ability": str,
+        "reward_model": dict,
+        "tools_kwargs": dict,
+    }
+
+    for key, dtype in expected_types.items():
+        assert batch[key].shape == (batch_size,), (
+            f"Expected shape {(batch_size,)} for '{key}', but got {batch[key].shape}"
+        )
+        assert isinstance(batch[key][0], dtype), (
+            f"'{key}' should contain elements of type {dtype}, but got {type(batch[key][0])}"
+        )
5 changes: 4 additions & 1 deletion verl/experimental/agent_loop/agent_loop.py
@@ -284,8 +284,11 @@ async def generate_sequences(self, batch: DataProto) -> DataProto:
         )
 
         for agent_name, messages, trajectory in zip(agent_names, raw_prompts, trajectory_info, strict=True):
+            if not isinstance(messages, list | np.ndarray):
+                raise TypeError(f"messages must be a list or numpy array, got {type(messages)}")
+
             tasks.append(
-                asyncio.create_task(self._run_agent_loop(agent_name, messages.tolist(), sampling_params, trajectory))
+                asyncio.create_task(self._run_agent_loop(agent_name, list(messages), sampling_params, trajectory))
             )
         outputs = await asyncio.gather(*tasks)
 
2 changes: 1 addition & 1 deletion verl/utils/dataset/rl_dataset.py
@@ -60,7 +60,7 @@ def collate_fn(data_list: list[dict]) -> dict:
         tensors[key] = torch.stack(val, dim=0)
 
     for key, val in non_tensors.items():
-        non_tensors[key] = np.array(val, dtype=object)
+        non_tensors[key] = np.fromiter(val, dtype=object, count=len(val))
 
     return {**tensors, **non_tensors}
 
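A minimal sketch of the failure mode this one-line change addresses (standalone illustration; note that `np.fromiter` with `dtype=object` requires NumPy >= 1.23): when every sample's list happens to have the same length, `np.array(..., dtype=object)` descends into the nested lists and produces a multi-dimensional array instead of a 1-D array of lists, which is exactly the fixed-length-list case the new test exercises:

```python
import numpy as np

# Two samples whose "messages" lists happen to share the same length (1).
val = [
    [{"role": "user", "content": "Hi."}],
    [{"role": "user", "content": "Hello."}],
]

# np.array broadcasts equal-length nested lists into a 2-D object array,
# so indexing a sample no longer returns the original list.
bad = np.array(val, dtype=object)
print(bad.shape)  # (2, 1)

# np.fromiter treats each element as an opaque object: always shape
# (batch_size,), and each element is still the original list.
good = np.fromiter(val, dtype=object, count=len(val))
print(good.shape)     # (2,)
print(type(good[0]))  # <class 'list'>
```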
11 changes: 7 additions & 4 deletions verl/workers/fsdp_workers.py
@@ -22,6 +22,7 @@
 from dataclasses import asdict
 from typing import Any
 
+import numpy as np
 import psutil
 import torch
 import torch.distributed
@@ -1526,11 +1527,13 @@ def _switch_chat_template(self, data: DataProto):
         rm_attention_mask = []
 
         for i in range(data.batch.batch_size[0]):
+            if not isinstance(data.non_tensor_batch["raw_prompt"][i], list | np.ndarray):
+                raise TypeError(
+                    f"raw_prompt must be a list or numpy array, got {type(data.non_tensor_batch['raw_prompt'][i])}"
+                )
+
             # extract raw prompt
-            if isinstance(data.non_tensor_batch["raw_prompt"][i], list):
-                chat: list = data.non_tensor_batch["raw_prompt"][i]
-            else:
-                chat: list = data.non_tensor_batch["raw_prompt"][i].tolist()
+            chat: list = list(data.non_tensor_batch["raw_prompt"][i])
 
             # extract response
             response_ids = data.batch["responses"][i]
13 changes: 8 additions & 5 deletions verl/workers/rollout/sglang_rollout/sglang_rollout.py
@@ -660,15 +660,15 @@ def _batch_level_generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
             {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
         ]
 
-        # Ensure token IDs are lists or numpy arrays
         for input_data in sglang_inputs:
-            if isinstance(input_data["prompt_token_ids"], np.ndarray):
-                input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist()
-            elif not isinstance(input_data["prompt_token_ids"], list):
+            # Ensure token IDs are lists or numpy arrays
+            if not isinstance(input_data["prompt_token_ids"], list | np.ndarray):
                 raise TypeError(
                     f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}"
                 )
 
+            input_data["prompt_token_ids"] = list(input_data["prompt_token_ids"])
+
         # Extract token IDs and image data for SGLang Engine
         idx_list = [input_data["prompt_token_ids"] for input_data in sglang_inputs]
         image_list = [input_data.get("image_data", None) for input_data in sglang_inputs]
@@ -1266,12 +1266,15 @@ def _preprocess_prompt_to_async_rollout_requests(self, prompts: DataProto, n: int) -> list[AsyncRolloutRequest]:
             else:
                 _interaction_kwargs = {}
 
+            if not isinstance(raw_prompt, list | np.ndarray):
+                raise TypeError(f"raw_prompt must be a list or numpy array, got {type(raw_prompt)}")
+
             req = AsyncRolloutRequest(
                 batch_data_id=data_idx,
                 rollout_offset=0,
                 request_id=str(uuid4()),
                 state=AsyncRolloutRequestStateEnum.PENDING,
-                messages=raw_prompt.tolist(),
+                messages=list(raw_prompt),
                 multi_modal_data=multi_modal_data,
                 tool_schemas=_tool_schemas,
                 tools_kwargs=_tools_kwargs,
9 changes: 4 additions & 5 deletions verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
@@ -275,16 +275,15 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
             {"prompt_token_ids": raw_prompt_ids} for raw_prompt_ids in non_tensor_batch.pop("raw_prompt_ids")
         ]
 
-        # ensure the type of `prompt_token_ids` passed to vllm is list[int]
-        # https://github.com/volcengine/verl/pull/772
         for input_data in vllm_inputs:
-            if isinstance(input_data["prompt_token_ids"], np.ndarray):
-                input_data["prompt_token_ids"] = input_data["prompt_token_ids"].tolist()
-            elif not isinstance(input_data["prompt_token_ids"], list):
+            # Ensure token IDs are lists or numpy arrays
+            if not isinstance(input_data["prompt_token_ids"], list | np.ndarray):
                 raise TypeError(
                     f"prompt_token_ids must be a list or numpy array, got {type(input_data['prompt_token_ids'])}"
                 )
 
+            input_data["prompt_token_ids"] = list(input_data["prompt_token_ids"])
+
         do_sample = prompts.meta_info.get("do_sample", True)
         is_validate = prompts.meta_info.get("validate", False)
         if not do_sample:
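One subtlety of the shared `list()` normalization, shown as a standalone sketch: for a numeric ndarray, `list()` yields NumPy scalar elements, whereas the `.tolist()` call it replaces yielded plain Python `int`s; both produce integer-valued sequences, but the element types differ:

```python
import numpy as np

ids = np.array([1, 2, 3])

print(type(list(ids)[0]))     # <class 'numpy.int64'> (width is platform-dependent)
print(type(ids.tolist()[0]))  # <class 'int'>
```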