
Conversation

PopSoda2002 (Contributor) commented Jul 27, 2025

What does this PR do?

From issue #2677.

Pad the `prompt`, `response` & `mask` per sample before batch post-processing, to save time.
Main idea (diagram of the padding layout omitted):

```python
# prompt_ids: left padded with zeros (e.g., [0,0,0,0,1,2,3,4])
# response_ids: right padded with zeros (e.g., [5,...,15,0,0,0,0])
# input_ids: concatenation of prompt + response
# Masks, for a prompt [1,2,3,4] and an 11-token response
# [5,6,7,8,9,10,11,(tool start)12,13(tool end),14,15]:
# - prompt_attention_mask: 0s for padding, 1s for real tokens
#   e.g., [0,0,0,0,1,1,1,1]
# - response_attention_mask: 0s for padding, 1s for real tokens
#   e.g., [1,1,1,1,1,1,1,1,1,1,1,0,0,0,0]
# - attention_mask: concatenation of prompt_attention_mask and response_attention_mask
#   e.g., [0,0,0,0,1,1,1,1(prompt),1,1,1,1,1,1,1,1,1,1,1,0,0,0,0(response)]
# - response_mask: 1s for LLM-generated tokens, 0s for tool-response/padding tokens
#   e.g., [1,1,1,1,1,1,1,(tool start)0,0(tool end),1,1,0,0,0,0]
# - position_ids: sequential positions for real tokens, starting at 0; 0 on padding
#   e.g., [0,0,0,0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,0,0,0,0]
```
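
To make the layout concrete, here is a minimal runnable sketch of the scheme above (not the verl implementation; `pad_id = 0` and the fixed max lengths are assumptions for illustration):

```python
import torch

# Build the padded tensors for the single worked example above.
pad_id, max_prompt, max_response = 0, 8, 15

prompt = [1, 2, 3, 4]
response = [5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]  # 7 LLM + 2 tool + 2 LLM tokens
llm_mask = [1] * 7 + [0] * 2 + [1] * 2              # 0 over the tool-response tokens

n_pad_p = max_prompt - len(prompt)
n_pad_r = max_response - len(response)

prompt_ids = torch.tensor([pad_id] * n_pad_p + prompt)      # left padded
response_ids = torch.tensor(response + [pad_id] * n_pad_r)  # right padded
input_ids = torch.cat([prompt_ids, response_ids])

prompt_attention_mask = torch.tensor([0] * n_pad_p + [1] * len(prompt))
response_attention_mask = torch.tensor([1] * len(response) + [0] * n_pad_r)
attention_mask = torch.cat([prompt_attention_mask, response_attention_mask])

response_mask = torch.tensor(llm_mask + [0] * n_pad_r)

# position_ids: running position over real tokens, 0 on padding positions
position_ids = (attention_mask.cumsum(0) - 1).clamp(min=0) * attention_mask
```

Because every sample then has the same fixed length, batching them later is a cheap `torch.stack`, which is the time saving this PR targets.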

Test

Environment setup: follow this [tutorial](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/blob/main/rlhf/verl/multi-turn/tool_examples/agent_loop.md).
Test config on 4×H100:

```bash
#!/bin/bash
# run on 4xH100 with optimizations for stability
# make sure your current working directory is the root of the project

set -x

ulimit -n 65535

# add environment variables for network stability
export CUDA_HOME=/usr/local/cuda
export CUDA_VISIBLE_DEVICES=4,5,6,7

PROJECT_DIR="$(pwd)"
CONFIG_PATH="$PROJECT_DIR/examples/sglang_multiturn/config"

python3 -m verl.trainer.main_ppo \
    --config-path="$CONFIG_PATH" \
    --config-name='gsm8k_multiturn_grpo' \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=256 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.return_raw_chat=True \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-3B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.n=16 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='gsm8k_async_rl' \
    trainer.experiment_name='qwen2.5-3b_function_rm-gsm8k-sgl-multi-w-tool-verify-n16-agent-loop-v1' \
    trainer.n_gpus_per_node=4 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=20 \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="$PROJECT_DIR/examples/sglang_multiturn/config/tool_config/gsm8k_tool_config.yaml" \
    trainer.total_epochs=15 \
    actor_rollout_ref.rollout.update_weights_bucket_megabytes=128 \
    actor_rollout_ref.rollout.trace.backend=weave \
    actor_rollout_ref.rollout.trace.token2text=True \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.multi_turn.enable=true
```

Before (v1) & After (v2)

(screenshots comparing the v1 and v2 runs omitted)

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`

CLAassistant commented Jul 27, 2025

CLA assistant check
All committers have signed the CLA.

gemini-code-assist bot left a comment:

Code Review

This pull request introduces padding before batch post-processing in the agent loop to improve performance. It also addresses potential serialization issues by converting numpy int64 values to standard Python integers.

PopSoda2002 changed the title from "[Feat.][perf] Padding before batch post-process in agent-loop" to "[perf] feat: Padding before batch post-process in agent-loop to save time" on Jul 27, 2025
"""Number of chat turns, including user, assistant, tool."""
metrics: AgentLoopMetrics
"""Auxiliary performance metrics"""
processed_tensors: dict = None
Collaborator: We can save padded ids directly in `prompt_ids`, `response_ids`, `response_mask`.

PopSoda2002 (author): But what about tensors like `attention_mask`?

PopSoda2002 (author): My fault, they can all be computed from those three. Should we move the calculation into the post-process?

Collaborator: Yes, we should move the pad logic for `prompt_ids`, `response_ids`, `response_mask` from `_postprocess` to `_run_agent_loop`.

PopSoda2002 (author): Thx! I will do this.

PopSoda2002 (author): Done.
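
A hypothetical sketch of what that suggestion implies (the helper name and length parameters are illustrative, not verl's actual code): each `_run_agent_loop` call pads its own sample, so `_postprocess` only has to stack already-padded tensors.

```python
import torch
import torch.nn.functional as F

def pad_single_sample(prompt_ids, response_ids, response_mask,
                      max_prompt_len, max_response_len, pad_id=0):
    """Pad one sample: prompt left padded, response and mask right padded."""
    p, r, m = map(torch.tensor, (prompt_ids, response_ids, response_mask))
    p = F.pad(p, (max_prompt_len - p.numel(), 0), value=pad_id)    # left pad
    r = F.pad(r, (0, max_response_len - r.numel()), value=pad_id)  # right pad
    m = F.pad(m, (0, max_response_len - m.numel()), value=0)
    return p, r, m

# _postprocess can then batch with a plain stack, e.g.:
# prompt_batch = torch.stack([o.prompt_ids for o in outputs], dim=0)
```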

The next thread is on this excerpt:

```python
attention_mask = torch.cat([prompt_output["attention_mask"], response_output["attention_mask"]], dim=1)

# Overwrite with padded data, converted to lists for safe serialization.
output.prompt_ids = prompt_output["input_ids"].squeeze(0).tolist()
```
Collaborator: We can keep the tensor as it is.

Collaborator: `_run_agent_loop` can return another class instead of `AgentLoopOutput`.

PopSoda2002 (author): Fixed! Thx!
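
One plausible shape for such an internal return type, sketched with illustrative names (the actual class in verl may differ):

```python
from dataclasses import dataclass
import torch

@dataclass
class PaddedAgentLoopOutput:
    """Internal result of _run_agent_loop carrying already-padded tensors.

    Illustrative only; the public AgentLoopOutput keeps its existing fields.
    """
    prompt_ids: torch.Tensor     # (max_prompt_len,), left padded
    response_ids: torch.Tensor   # (max_response_len,), right padded
    response_mask: torch.Tensor  # (max_response_len,), 0 on tool/pad tokens
    num_turns: int
    metrics: "AgentLoopMetrics"  # reuse the existing metrics type
```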

wuxibin89 (Collaborator): LGTM, let's merge it after CI passes.

zhaochenyang20 (Collaborator): Great job!

wuxibin89 merged commit c3df0b5 into volcengine:main on Jul 30, 2025
50 of 53 checks passed
GuanxingLu: Great job!

yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request on Jul 31, 2025.
Juniper1021 pushed a commit to Juniper1021/verl that referenced this pull request on Aug 7, 2025.
whatadayG pushed a commit to whatadayG/verl that referenced this pull request on Sep 5, 2025.