2 changes: 1 addition & 1 deletion README.md
@@ -33,7 +33,7 @@ EasyR1 is efficient and scalable due to the design of **[HybirdEngine](https://a
### Software Requirements

- Python 3.9+
- transformers>=4.49.0
- transformers>=4.51.0
- flash-attn>=2.4.3
- vllm>=0.8.3

2 changes: 1 addition & 1 deletion examples/baselines/qwen2_5_vl_3b_clevr.sh
@@ -14,6 +14,6 @@ python3 -m verl.trainer.main \
data.format_prompt="${FORMAT_PROMPT}" \
worker.actor.model.model_path=${MODEL_PATH} \
worker.rollout.tensor_parallel_size=1 \
worker.reward.score_function=r1v \
worker.reward.score_function=./examples/score_function/r1v.py:compute_score \
trainer.experiment_name=qwen2_5_vl_3b_clevr \
trainer.n_gpus_per_node=2
2 changes: 1 addition & 1 deletion examples/baselines/qwen2_5_vl_3b_geoqa8k.sh
@@ -14,6 +14,6 @@ python3 -m verl.trainer.main \
data.format_prompt="${FORMAT_PROMPT}" \
worker.actor.model.model_path=${MODEL_PATH} \
worker.rollout.tensor_parallel_size=1 \
worker.reward.score_function=r1v \
worker.reward.score_function=./examples/score_function/r1v.py:compute_score \
trainer.experiment_name=qwen2_5_vl_3b_geoqa8k \
trainer.n_gpus_per_node=8
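Both baseline scripts now pass `score_function` as a `<path>.py:<function>` spec instead of a built-in name. Below is a minimal sketch of how such a spec can be resolved with `importlib`; the loader EasyR1 actually uses lives in the reward worker and may differ, so treat the helper name and logic as illustrative only.

```python
# Sketch only: resolve a "<file>.py:<function>" score-function spec.
import importlib.util
from typing import Callable

def load_score_function(spec: str) -> Callable:
    # e.g. "./examples/score_function/r1v.py:compute_score"
    path, _, func_name = spec.partition(":")
    module_spec = importlib.util.spec_from_file_location("custom_score_function", path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, func_name)

compute_score = load_score_function("./examples/score_function/r1v.py:compute_score")
```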
6 changes: 3 additions & 3 deletions examples/config.yaml
@@ -47,8 +47,9 @@ worker:
offload_optimizer: true # true: more CPU memory; false: more GPU memory

rollout:
temperature: 1.0
n: 5
temperature: 1.0
top_p: 0.99
gpu_memory_utilization: 0.6
enforce_eager: false
enable_chunked_prefill: false
@@ -68,8 +69,7 @@ worker:

reward:
reward_type: function
score_function: math
skip_special_tokens: true
score_function: ./examples/score_function/math.py:compute_score

trainer:
total_episodes: 15
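The reordered rollout block makes the sampling defaults explicit (`n: 5`, `temperature: 1.0`, and the new `top_p: 0.99`). Roughly, these map onto vLLM sampling parameters inside the rollout worker; the snippet below is only an illustration of what the three knobs control, not how EasyR1 builds them internally.

```python
# Illustration only: what rollout.n / temperature / top_p control.
from vllm import SamplingParams

sampling_params = SamplingParams(
    n=5,              # rollout.n: responses sampled per prompt
    temperature=1.0,  # rollout.temperature
    top_p=0.99,       # rollout.top_p, now an explicit default
)
```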
verl/utils/reward_score/math.py → examples/score_function/math.py
@@ -18,23 +18,23 @@
from mathruler.grader import extract_boxed_content, grade_answer


def math_format_reward(predict_str: str) -> float:
def format_reward(predict_str: str) -> float:
pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
format_match = re.fullmatch(pattern, predict_str)
return 1.0 if format_match else 0.0


def math_acc_reward(predict_str: str, ground_truth: str) -> float:
def accuracy_reward(predict_str: str, ground_truth: str) -> float:
answer = extract_boxed_content(predict_str)
return 1.0 if grade_answer(answer, ground_truth) else 0.0


def math_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.1) -> Dict[str, float]:
predict_str = re.sub(r"\s*(<|>|/)\s*", r"\1", predict_str) # handle qwen2.5vl-32b format
format = math_format_reward(predict_str)
accuracy = math_acc_reward(predict_str, ground_truth)
format_score = format_reward(predict_str)
accuracy_score = accuracy_reward(predict_str, ground_truth)
return {
"overall": 0.9 * accuracy + 0.1 * format,
"format": format,
"accuracy": accuracy,
"overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
"format": format_score,
"accuracy": accuracy_score,
}
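With the rename, the exported entry point is plain `compute_score`, and the 0.9/0.1 mix is now the `format_weight=0.1` default rather than hard-coded constants. A hypothetical usage example follows; the import path is illustrative and assumes `mathruler` is installed and the file is importable.

```python
from examples.score_function.math import compute_score  # illustrative import path

# Correct answer with the expected <think>...</think> + \boxed{} format:
compute_score(r"<think>2 + 2 = 4</think>\boxed{4}", "4")
# -> {"overall": 1.0, "format": 1.0, "accuracy": 1.0}

# Correct answer but missing the <think> block: the format term drops out.
compute_score(r"\boxed{4}", "4")
# -> {"overall": 0.9, "format": 0.0, "accuracy": 1.0}  with format_weight=0.1
```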
19 changes: 9 additions & 10 deletions verl/utils/reward_score/r1v.py → examples/score_function/r1v.py
@@ -18,18 +18,17 @@
from mathruler.grader import grade_answer


def r1v_format_reward(predict_str: str) -> float:
def format_reward(predict_str: str) -> float:
pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
format_match = re.fullmatch(pattern, predict_str)
return 1.0 if format_match else 0.0


def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
def accuracy_reward(predict_str: str, ground_truth: str) -> float:
try:
ground_truth = ground_truth.strip()
content_match = re.search(r"<answer>(.*?)</answer>", predict_str)
given_answer = content_match.group(1).strip() if content_match else predict_str.strip()
if grade_answer(given_answer, ground_truth):
if grade_answer(given_answer, ground_truth.strip()):
return 1.0

except Exception:
@@ -38,11 +37,11 @@ def r1v_accuracy_reward(predict_str: str, ground_truth: str) -> float:
return 0.0


def r1v_compute_score(predict_str: str, ground_truth: str) -> Dict[str, float]:
format = r1v_format_reward(predict_str)
accuracy = r1v_accuracy_reward(predict_str, ground_truth)
def compute_score(predict_str: str, ground_truth: str, format_weight: float = 0.5) -> Dict[str, float]:
format_score = format_reward(predict_str)
accuracy_score = accuracy_reward(predict_str, ground_truth)
return {
"overall": 0.5 * accuracy + 0.5 * format,
"format": format,
"accuracy": accuracy,
"overall": (1 - format_weight) * accuracy_score + format_weight * format_score,
"format": format_score,
"accuracy": accuracy_score,
}
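Same pattern for the R1-V scorer: the default `format_weight=0.5` keeps the previous 50/50 split between the `<answer>` formatting check and answer correctness. Hypothetical usage (import path illustrative; requires `mathruler`):

```python
from examples.score_function.r1v import compute_score  # illustrative import path

compute_score("<think>count the red cubes</think> <answer>3</answer>", "3")
# -> {"overall": 1.0, "format": 1.0, "accuracy": 1.0}

compute_score("3", "3")  # no tags: format fails, raw text is graded as the answer
# -> {"overall": 0.5, "format": 0.0, "accuracy": 1.0}
```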
2 changes: 1 addition & 1 deletion requirements.txt
@@ -15,6 +15,6 @@ qwen-vl-utils
ray[default]
tensordict
torchdata
transformers>=4.49.0
transformers>=4.51.0
vllm>=0.7.3
wandb
32 changes: 16 additions & 16 deletions verl/protocol.py
@@ -136,8 +136,8 @@ def fold_batch_dim(data: "DataProto", new_batch_size: int):
tensor = tensor.view(new_batch_size, -1)
tensor.auto_batch_size_(batch_dims=1)

for key, val in non_tensor.items():
non_tensor[key] = np.reshape(val, newshape=(new_batch_size, -1, *val.shape[1:]))
for key, value in non_tensor.items():
non_tensor[key] = np.reshape(value, newshape=(new_batch_size, -1, *value.shape[1:]))

return DataProto(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info)
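The `val` to `value` rename here is cosmetic; the reshape itself is unchanged. For reference, this is what it does to each non-tensor entry in isolation (NumPy only, shapes made up for illustration):

```python
import numpy as np

value = np.arange(12).reshape(6, 2)  # (old_batch_size=6, feature=2)
new_batch_size = 3
folded = np.reshape(value, (new_batch_size, -1, *value.shape[1:]))
assert folded.shape == (3, 2, 2)     # batch dim folded into (3, 2), same as the newshape= call above
```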

@@ -182,14 +182,14 @@ def __len__(self) -> int:
if self.batch is not None:
return self.batch.batch_size[0]
elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
random_key = list(self.non_tensor_batch.keys())[0]
return self.non_tensor_batch[random_key].shape[0]
pivot_key = list(self.non_tensor_batch.keys())[0]
return self.non_tensor_batch[pivot_key].shape[0]
else:
return 0

def __getitem__(self, item: Union[int, slice]) -> Union["DataProto", "DataProtoItem"]:
tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
non_tensor_data = {key: value[item] for key, value in self.non_tensor_batch.items()}
return_type = DataProto if isinstance(item, slice) else DataProtoItem
return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)

@@ -223,9 +223,10 @@ def load_from_disk(filepath: str) -> "DataProto":

def print_size(self, prefix: str = "") -> None:
size_of_tensordict = 0
for tensor in self.batch.values():
if isinstance(tensor, torch.Tensor):
size_of_tensordict += tensor.element_size() * tensor.numel()
if self.batch is not None:
for tensor in self.batch.values():
if isinstance(tensor, torch.Tensor):
size_of_tensordict += tensor.element_size() * tensor.numel()

size_of_numpy_array = 0
for value in self.non_tensor_batch.values():
@@ -249,17 +250,16 @@ def check_consistency(self):
assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty."

batch_size = self.batch.batch_size[0]
for key, val in self.non_tensor_batch.items():
assert len(val) == batch_size, f"key {key} length {len(val)} is not equal to batch size {batch_size}."
for key, value in self.non_tensor_batch.items():
assert len(value) == batch_size, f"key {key} length {len(value)} is not equal to bsz {batch_size}."

@classmethod
def from_single_dict(
cls,
data: Dict[str, Union[torch.Tensor, NDArray]],
meta_info: Optional[Dict[str, Any]] = None,
) -> "DataProto":
tensors = {}
non_tensors = {}
tensors, non_tensors = {}, {}
for key, value in data.items():
if isinstance(value, torch.Tensor):
tensors[key] = value
@@ -551,7 +551,7 @@ def reorder(self, indices: torch.Tensor) -> None:
"""
indices_np = indices.detach().numpy()
self.batch = self.batch[indices]
self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
self.non_tensor_batch = {key: value[indices_np] for key, value in self.non_tensor_batch.items()}

def repeat(self, repeat_times: int = 2, interleave: bool = True) -> "DataProto":
"""
@@ -666,9 +666,9 @@ def allgather_dict_tensors(
output = {}
sorted_keys = sorted(tensors_as_dict.keys())
for key in sorted_keys:
val = tensors_as_dict[key]
output[key] = [torch.empty_like(val) for _ in range(size)]
torch.distributed.all_gather(output[key], val, group=group, async_op=False)
value = tensors_as_dict[key]
output[key] = [torch.empty_like(value) for _ in range(size)]
torch.distributed.all_gather(output[key], value, group=group, async_op=False)
output[key] = torch.cat(output[key], dim=dim)

if is_tensor_dict:
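The only behavioral change in `verl/protocol.py` is the `None` guard in `print_size`; everything else is a `val` to `value` or key-name rename. A standalone sketch of the guarded size accounting (assumed types, no verl imports):

```python
from typing import Dict, Optional
import torch

def tensordict_nbytes(batch: Optional[Dict[str, torch.Tensor]]) -> int:
    size = 0
    if batch is not None:  # mirrors the new guard: a DataProto may carry no tensor batch
        for tensor in batch.values():
            if isinstance(tensor, torch.Tensor):
                size += tensor.element_size() * tensor.numel()
    return size

assert tensordict_nbytes(None) == 0
assert tensordict_nbytes({"x": torch.zeros(2, 3)}) == 24  # 6 float32 elements x 4 bytes
```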
4 changes: 4 additions & 0 deletions verl/trainer/config.py
@@ -86,6 +86,10 @@ def post_init(self):
if self.save_checkpoint_path is None:
self.save_checkpoint_path = os.path.join("checkpoints", self.project_name, self.experiment_name)

self.save_checkpoint_path = os.path.abspath(self.save_checkpoint_path)
if self.load_checkpoint_path is not None:
self.load_checkpoint_path = os.path.abspath(self.load_checkpoint_path)


@dataclass
class PPOConfig:
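Normalizing the checkpoint paths with `os.path.abspath` in `post_init` pins a relative path to the driver's working directory once, so later consumers (workers, resume logic) do not re-resolve it against a different CWD. A small illustration; the concrete path is hypothetical and depends on where the trainer is launched:

```python
import os

save_checkpoint_path = os.path.join("checkpoints", "easy_r1", "exp")  # hypothetical default
print(os.path.abspath(save_checkpoint_path))
# e.g. /home/user/EasyR1/checkpoints/easy_r1/exp (fixed at config time, not per worker)
```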
23 changes: 13 additions & 10 deletions verl/trainer/main.py
@@ -11,9 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""

import json

@@ -23,7 +20,7 @@
from ..single_controller.ray import RayWorkerGroup
from ..utils.tokenizer import get_processor, get_tokenizer
from ..workers.fsdp_workers import FSDPWorker
from ..workers.reward import CustomRewardManager
from ..workers.reward import FunctionRewardManager
from .config import PPOConfig
from .ray_trainer import RayPPOTrainer, ResourcePoolManager, Role

@@ -35,7 +32,6 @@ class Runner:

def run(self, config: PPOConfig):
# print config
config.deep_post_init()
print(json.dumps(config.to_dict(), indent=2))

# instantiate tokenizer
@@ -68,8 +64,8 @@ def run(self, config: PPOConfig):
}
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
val_reward_fn = CustomRewardManager(tokenizer=tokenizer, config=config.worker.reward)
reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)
val_reward_fn = FunctionRewardManager(config=config.worker.reward, tokenizer=tokenizer)

trainer = RayPPOTrainer(
config=config,
@@ -95,11 +91,18 @@ def main():
default_config = OmegaConf.merge(default_config, file_config)

ppo_config = OmegaConf.merge(default_config, cli_args)
ppo_config = OmegaConf.to_object(ppo_config)
ppo_config: PPOConfig = OmegaConf.to_object(ppo_config)
ppo_config.deep_post_init()

if not ray.is_initialized():
# this is for local ray cluster
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}})
runtime_env = {
"env_vars": {
"TOKENIZERS_PARALLELISM": "true",
"NCCL_DEBUG": "WARN",
"VLLM_LOGGING_LEVEL": "INFO",
}
}
ray.init(runtime_env=runtime_env)

runner = Runner.remote()
ray.get(runner.run.remote(ppo_config))
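`deep_post_init()` now runs on the driver right after the config is materialized, instead of inside `Runner.run`. The rough shape of that flow is sketched below using names taken from the hunk above; everything else (file names, call order details) is illustrative, not a verbatim copy of `main()`.

```python
from omegaconf import OmegaConf
from verl.trainer.config import PPOConfig

default_config = OmegaConf.structured(PPOConfig())     # dataclass defaults
file_config = OmegaConf.load("examples/config.yaml")   # user YAML, if given
cli_args = OmegaConf.from_cli()                        # dot-list overrides
merged = OmegaConf.merge(default_config, file_config, cli_args)
ppo_config: PPOConfig = OmegaConf.to_object(merged)
ppo_config.deep_post_init()                            # resolve derived fields before ray.init
```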
4 changes: 2 additions & 2 deletions verl/trainer/metrics.py
@@ -110,11 +110,11 @@ def compute_timing_metrics(batch: DataProto, timing_raw: Dict[str, float]) -> Di
}


def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], n_gpus: int) -> Dict[str, Any]:
def compute_throughout_metrics(batch: DataProto, timing_raw: Dict[str, float], num_gpus: int) -> Dict[str, Any]:
total_num_tokens = sum(batch.meta_info["global_token_num"])
time = timing_raw["step"]
return {
"perf/total_num_tokens": total_num_tokens,
"perf/time_per_step": time,
"perf/throughput": total_num_tokens / (time * n_gpus),
"perf/throughput": total_num_tokens / (time * num_gpus),
}
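The rename to `num_gpus` does not change the metric; it is still tokens per second per GPU. A quick arithmetic check with made-up numbers:

```python
total_num_tokens = 1_048_576
time_per_step = 8.0  # seconds for the whole training step
num_gpus = 8
throughput = total_num_tokens / (time_per_step * num_gpus)
assert round(throughput) == 16384  # roughly 16.4k tokens/s per GPU
```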
38 changes: 14 additions & 24 deletions verl/trainer/ray_trainer.py
@@ -102,24 +102,16 @@ def get_resource_pool(self, role: Role) -> RayResourcePool:
"""Get the resource pool of the worker."""
return self.resource_pool_dict[self.mapping[role]]

def get_n_gpus(self) -> int:
def get_num_gpus(self) -> int:
"""Get the number of gpus in this cluster."""
return sum([n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes])

def _check_resource_available(self):
"""Check if the resource pool can be satisfied in this ray cluster."""
node_available_resources = ray.state.available_resources_per_node()
node_available_gpus = {node: node_info.get("GPU", 0) for node, node_info in node_available_resources.items()}

# check total required gpus can be satisfied
total_available_gpus = sum(node_available_gpus.values())
total_required_gpus = sum(
[n_gpus for process_on_nodes in self.resource_pool_spec.values() for n_gpus in process_on_nodes]
)
if total_available_gpus < total_required_gpus:
raise ValueError(
f"Total available GPUs {total_available_gpus} is less than total desired GPUs {total_required_gpus}."
)
gpus_available = ray.available_resources().get("GPU", 0)
gpus_required = self.get_num_gpus()
if gpus_available < gpus_required:
raise ValueError(f"Total available GPUs {gpus_available} is less than total desired GPUs {gpus_required}.")
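The rewritten `_check_resource_available` drops the per-node bookkeeping and simply compares the requirement against Ray's cluster-wide GPU count. A minimal standalone probe of the same API (requires a running or auto-started Ray instance; output is environment-dependent):

```python
import ray

ray.init(ignore_reinit_error=True)
gpus_available = ray.available_resources().get("GPU", 0)
print(f"GPUs currently visible to Ray: {gpus_available}")
```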


def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penalty="kl"):
@@ -128,11 +120,8 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.KLController, kl_penal
response_mask = data.batch["response_mask"]

# compute kl between ref_policy and current policy
if "ref_log_probs" in data.batch.keys():
kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
kld = kld * response_mask # (batch_size, response_length)
else:
kld = torch.zeros_like(response_mask, dtype=torch.float32)
kld = core_algos.compute_kl(data.batch["old_log_probs"], data.batch["ref_log_probs"], kl_penalty=kl_penalty)
kld = kld * response_mask # (batch_size, response_length)

data.batch["token_level_rewards"] = token_level_scores - kl_ctrl.kl_coef * kld
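With the zero-KL fallback removed, `apply_kl_penalty` now assumes `ref_log_probs` is always present in the batch. The token-level reward it produces is sketched below in isolation, assuming the plain log-ratio `kl` penalty variant; tensors and the coefficient are made-up examples.

```python
import torch

token_level_scores = torch.tensor([[1.0, 0.0, 0.5]])
old_log_probs = torch.tensor([[-1.0, -2.0, -0.5]])
ref_log_probs = torch.tensor([[-1.2, -1.8, -0.5]])
response_mask = torch.tensor([[1.0, 1.0, 0.0]])
kl_coef = 0.01

kld = (old_log_probs - ref_log_probs) * response_mask  # masked log-ratio KL estimate
token_level_rewards = token_level_scores - kl_coef * kld
```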

@@ -366,10 +355,10 @@ def _validate(self) -> Dict[str, Any]:
input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
sample_inputs.extend(input_texts)

if "multi_modal_inputs" in test_batch.non_tensor_batch.keys():
if "multi_modal_data" in test_batch.non_tensor_batch.keys():
test_gen_batch = test_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
)
else:
test_gen_batch = test_batch.pop(
@@ -567,10 +556,10 @@ def fit(self):
batch: DataProto = DataProto.from_single_dict(batch_dict)

# pop those keys for generation
if "multi_modal_inputs" in batch.non_tensor_batch.keys():
if "multi_modal_data" in batch.non_tensor_batch.keys():
gen_batch = batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
)
else:
gen_batch = batch.pop(
@@ -604,6 +593,7 @@ def fit(self):
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.worker.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
batch.non_tensor_batch.pop("multi_modal_data", None)

# compute reward
with _timer("reward", timing_raw):
@@ -694,10 +684,10 @@ def fit(self):
self._save_checkpoint()

# collect metrics
n_gpus = self.resource_pool_manager.get_n_gpus()
num_gpus = self.resource_pool_manager.get_num_gpus()
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, num_gpus=num_gpus))

self.logger.log(data=metrics, step=self.global_step)
