-
Notifications
You must be signed in to change notification settings - Fork 2.4k
[BREAKING][vllm, fsdp] feat: add Rollout-Training Mismatch Fix -- Truncated importance sampling #2953
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BREAKING][vllm, fsdp] feat: add Rollout-Training Mismatch Fix -- Truncated importance sampling #2953
Conversation
straight through trick for kl gradient estimation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces Truncated Importance Sampling (TIS) to correct for the mismatch between the rollout policy and the training policy in PPO. The changes look good overall, correctly implementing the TIS logic by re-weighting the PPO loss with a clipped importance ratio. The necessary configuration options and data plumbing have been added. There is one potential issue in a deprecated function that could lead to a runtime error if TIS is enabled but rollout log probs are not provided.
|
Hi VeRL Team, This feature has been incorporated in OpenRLHF: https://github.com/OpenRLHF/OpenRLHF/releases/tag/v0.8.9 Thanks! |
Why
|
|
Need vllm fix vllm-project/vllm#22387 to return |
| if self.config.use_kl_loss: | ||
| select_keys.append("ref_log_prob") | ||
| if self.config.imp_ratio_cap > 0: | ||
| select_keys.append("rollout_log_probs") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Server mode(agent loop) hasn't return rollout_log_probs for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, I have add a check here before adding rollout_log_probs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work!
| # support logging rollout prob for debugging purpose | ||
| calculate_log_probs: False | ||
| # "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling | ||
| calculate_log_probs: True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you turn the default value to False and add a running script for TIS?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I fixed it now and added one running scripts :)
verl/trainer/config/actor/actor.yaml
Outdated
| entropy_coeff: 0 | ||
|
|
||
| # whether to apply the truncated Importance Sampling (-1 for no importance sampling) | ||
| imp_ratio_cap: 5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed :)
| loss_agg_mode: str = "token-mean", | ||
| rollout_log_probs=None, | ||
| imp_ratio_cap=-1, | ||
| ): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this function is deprecated and you should change verl.trainer.ppo.core_algos.compute_policy_loss_vanilla instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
icic, I have included change in change in verl.trainer.ppo.core_algos.compute_policy_loss_vanilla as well.
Let me delete the change in the compute_policy_loss.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have fixed now :)
verl/workers/actor/dp_actor.py
Outdated
| loss_agg_mode=loss_agg_mode, | ||
| config=self.config, | ||
| rollout_log_probs=rollout_log_probs, | ||
| imp_ratio_cap=self.config.imp_ratio_cap, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to pass it as it's included in the config already
|
@jianfeng-Liu I do not quite understand your writing, but I guess you want to say
This is indeed correct; and one can further notice that |
| model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch} | ||
| response_mask = model_inputs["response_mask"] | ||
| old_log_prob = model_inputs["old_log_probs"] | ||
| rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.tis_imp_ratio_cap > 0 else None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A very minor problem: what if the user specify tis_imp_ratio_cap? but not specify calculate_log_probs = True? I would suggest directly checking whether "rollout_log_probs" is in model_inputs to avoid any risk of KeyNotFoundError (note this is already very deep in the whole verl codebase, so it could be hard for the user to debug)
|
|
||
| # Truncated Importance Sampling (TIS): https://fengyao.notion.site/off-policy-rl | ||
| # the truncation value C of truncated Importance Sampling (-1 for disable TIS) | ||
| tis_imp_ratio_cap: -1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think there should be some joint checking on the configuration? Like, if the user specifies tis_imp_ratio_cap, calculate_log_probs must be True?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed. Please verify in the config
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, we have already addressed this issue inverl/workers/actor/dp_actor.py as follows:
if self.config.tis_imp_ratio_cap > 0:
assert "rollout_log_probs" in data.batch.keys(), (
"Truncated Importance Sampling (TIS) requires to configure "
"`actor_rollout_ref.rollout.calculate_log_probs=True` "
"and is not currently supported in Server mode (agent loop)."
)
select_keys.append("rollout_log_probs")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, thanks! Looks good to me now.
@xihuai18 Great question! As to |
|
@yaof20 @LiyuanLucasLiu how do you deal with vLLM occasionally returning NaN logprobs for selected tokens? In my testing this is fairly common, have you observed this as well? |
@LeonEricsson Thanks for asking! After the patch, we haven't observed NaN logprobs. The log for our |
… functions (#3274) ### What does this PR do? <!-- > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. --> In the recent PR: - #2953, the file `workers/actor/dp_actor.py` was updated so that `rollout_log_probs` is passed to `policy_loss_fn`: https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456 In that PR, the "vanilla" policy loss function was modified to accept `rollout_log_probs` as an argument. However, other policy loss functions (e.g., "gspo") were not updated accordingly, which leads to an error such as: ``` TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs' ``` when setting `config.policy_loss.loss_mode` to one of these alternatives. Therefore, in this PR, `rollout_log_probs` is also added as an argument to the other policy loss functions. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
… functions (volcengine#3274) <!-- > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. --> In the recent PR: - volcengine#2953, the file `workers/actor/dp_actor.py` was updated so that `rollout_log_probs` is passed to `policy_loss_fn`: https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456 In that PR, the "vanilla" policy loss function was modified to accept `rollout_log_probs` as an argument. However, other policy loss functions (e.g., "gspo") were not updated accordingly, which leads to an error such as: ``` TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs' ``` when setting `config.policy_loss.loss_mode` to one of these alternatives. Therefore, in this PR, `rollout_log_probs` is also added as an argument to the other policy loss functions. - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python ``` > Demonstrate the high-level design if this PR is complex, and list the specific changes. > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Looking forward if these get upstreamed into vllm :) |
I gather you haven't stumbled upon an easier workaround for the fact that vllm doesn't expose a quantization API for |
Thanks for asking! For the logprob patch, it has been upstreamed into vllm already, as in: vllm-project/vllm#22387 |
Yes, vllm dont have the load weight api yet, but we do plan to work with vllm team to have this api provided in the future. |
Wow, that was fast :) |
…ncated importance sampling (volcengine#2953) ### What does this PR do? Support [vLLM-FSDP off-policy importance sampling correction](https://fengyao.notion.site/off-policy-rl) using Truncated Importance Sampling (TIS): <img width="859" height="382" alt="TIS" src="https://github.com/user-attachments/assets/adc8f797-aa14-4b29-b265-a682c281d08e" /> ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=gae \ data.train_files="$train_files" \ data.val_files="$test_files" \ data.train_batch_size=1024 \ data.max_prompt_length=1024 \ data.max_response_length=1024 \ data.filter_overlong_prompts=True \ data.truncation='error' \ actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \ actor_rollout_ref.model.enable_gradient_checkpointing=False \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.actor.use_kl_loss=False \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.model.path=Qwen/Qwen2.5-32B-Instruct \ critic.model.enable_gradient_checkpointing=False \ critic.ppo_micro_batch_size_per_gpu=8 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=4 \ trainer.save_freq=20 \ trainer.test_freq=10 \ trainer.total_epochs=15 \ actor_rollout_ref.rollout.calculate_log_probs=True \ # add this config to return rollout prob +actor_rollout_ref.actor.behav_imp_weight_cap=10.0$@ # add this config to set up C value in TIS ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Narsil-Dinghuai Zhang 张鼎怀 <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: LiyuanLucasLiu <[email protected]>
… functions (volcengine#3274) ### What does this PR do? <!-- > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. --> In the recent PR: - volcengine#2953, the file `workers/actor/dp_actor.py` was updated so that `rollout_log_probs` is passed to `policy_loss_fn`: https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456 In that PR, the "vanilla" policy loss function was modified to accept `rollout_log_probs` as an argument. However, other policy loss functions (e.g., "gspo") were not updated accordingly, which leads to an error such as: ``` TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs' ``` when setting `config.policy_loss.loss_mode` to one of these alternatives. Therefore, in this PR, `rollout_log_probs` is also added as an argument to the other policy loss functions. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
… functions (volcengine#3274) <!-- > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. --> In the recent PR: - volcengine#2953, the file `workers/actor/dp_actor.py` was updated so that `rollout_log_probs` is passed to `policy_loss_fn`: https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456 In that PR, the "vanilla" policy loss function was modified to accept `rollout_log_probs` as an argument. However, other policy loss functions (e.g., "gspo") were not updated accordingly, which leads to an error such as: ``` TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs' ``` when setting `config.policy_loss.loss_mode` to one of these alternatives. Therefore, in this PR, `rollout_log_probs` is also added as an argument to the other policy loss functions. - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python ``` > Demonstrate the high-level design if this PR is complex, and list the specific changes. > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
…ncated importance sampling (volcengine#2953) ### What does this PR do? Support [vLLM-FSDP off-policy importance sampling correction](https://fengyao.notion.site/off-policy-rl) using Truncated Importance Sampling (TIS): <img width="859" height="382" alt="TIS" src="https://github.com/user-attachments/assets/adc8f797-aa14-4b29-b265-a682c281d08e" /> ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=gae \ data.train_files="$train_files" \ data.val_files="$test_files" \ data.train_batch_size=1024 \ data.max_prompt_length=1024 \ data.max_response_length=1024 \ data.filter_overlong_prompts=True \ data.truncation='error' \ actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \ actor_rollout_ref.model.enable_gradient_checkpointing=False \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.actor.use_kl_loss=False \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.model.path=Qwen/Qwen2.5-32B-Instruct \ critic.model.enable_gradient_checkpointing=False \ critic.ppo_micro_batch_size_per_gpu=8 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=4 \ trainer.save_freq=20 \ trainer.test_freq=10 \ trainer.total_epochs=15 \ actor_rollout_ref.rollout.calculate_log_probs=True \ # add this config to return rollout prob +actor_rollout_ref.actor.behav_imp_weight_cap=10.0$@ # add this config to set up C value in TIS ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Narsil-Dinghuai Zhang 张鼎怀 <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: LiyuanLucasLiu <[email protected]>
… functions (volcengine#3274) ### What does this PR do? <!-- > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. --> In the recent PR: - volcengine#2953, the file `workers/actor/dp_actor.py` was updated so that `rollout_log_probs` is passed to `policy_loss_fn`: https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456 In that PR, the "vanilla" policy loss function was modified to accept `rollout_log_probs` as an argument. However, other policy loss functions (e.g., "gspo") were not updated accordingly, which leads to an error such as: ``` TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs' ``` when setting `config.policy_loss.loss_mode` to one of these alternatives. Therefore, in this PR, `rollout_log_probs` is also added as an argument to the other policy loss functions. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
### What does this PR do? Add the TIS support from #2953 to megatron actor ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) Co-authored-by: Shuang Yu <[email protected]>
…ncated importance sampling (volcengine#2953) Support [vLLM-FSDP off-policy importance sampling correction](https://fengyao.notion.site/off-policy-rl) using Truncated Importance Sampling (TIS): <img width="859" height="382" alt="TIS" src="https://github.com/user-attachments/assets/adc8f797-aa14-4b29-b265-a682c281d08e" /> - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=gae \ data.train_files="$train_files" \ data.val_files="$test_files" \ data.train_batch_size=1024 \ data.max_prompt_length=1024 \ data.max_response_length=1024 \ data.filter_overlong_prompts=True \ data.truncation='error' \ actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \ actor_rollout_ref.model.enable_gradient_checkpointing=False \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.actor.use_kl_loss=False \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.model.path=Qwen/Qwen2.5-32B-Instruct \ critic.model.enable_gradient_checkpointing=False \ critic.ppo_micro_batch_size_per_gpu=8 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=4 \ trainer.save_freq=20 \ trainer.test_freq=10 \ trainer.total_epochs=15 \ actor_rollout_ref.rollout.calculate_log_probs=True \ # add this config to return rollout prob +actor_rollout_ref.actor.behav_imp_weight_cap=10.0$@ # add this config to set up C value in TIS ``` > Demonstrate the high-level design if this PR is complex, and list the specific changes. > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Narsil-Dinghuai Zhang 张鼎怀 <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: LiyuanLucasLiu <[email protected]>
… functions (#3274) ### What does this PR do? <!-- > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. --> In the recent PR: - volcengine/verl#2953, the file `workers/actor/dp_actor.py` was updated so that `rollout_log_probs` is passed to `policy_loss_fn`: https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456 In that PR, the "vanilla" policy loss function was modified to accept `rollout_log_probs` as an argument. However, other policy loss functions (e.g., "gspo") were not updated accordingly, which leads to an error such as: ``` TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs' ``` when setting `config.policy_loss.loss_mode` to one of these alternatives. Therefore, in this PR, `rollout_log_probs` is also added as an argument to the other policy loss functions. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
### What does this PR do? Add the TIS support from volcengine/verl#2953 to megatron actor ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) Co-authored-by: Shuang Yu <[email protected]>
…ncated importance sampling (volcengine#2953) Support [vLLM-FSDP off-policy importance sampling correction](https://fengyao.notion.site/off-policy-rl) using Truncated Importance Sampling (TIS): <img width="859" height="382" alt="TIS" src="https://github.com/user-attachments/assets/adc8f797-aa14-4b29-b265-a682c281d08e" /> - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=gae \ data.train_files="$train_files" \ data.val_files="$test_files" \ data.train_batch_size=1024 \ data.max_prompt_length=1024 \ data.max_response_length=1024 \ data.filter_overlong_prompts=True \ data.truncation='error' \ actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \ actor_rollout_ref.model.enable_gradient_checkpointing=False \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.actor.use_kl_loss=False \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.model.path=Qwen/Qwen2.5-32B-Instruct \ critic.model.enable_gradient_checkpointing=False \ critic.ppo_micro_batch_size_per_gpu=8 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=4 \ trainer.save_freq=20 \ trainer.test_freq=10 \ trainer.total_epochs=15 \ actor_rollout_ref.rollout.calculate_log_probs=True \ # add this config to return rollout prob +actor_rollout_ref.actor.behav_imp_weight_cap=10.0$@ # add this config to set up C value in TIS ``` > Demonstrate the high-level design if this PR is complex, and list the specific changes. > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Narsil-Dinghuai Zhang 张鼎怀 <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: LiyuanLucasLiu <[email protected]>
…ncated importance sampling (volcengine#2953) ### What does this PR do? Support [vLLM-FSDP off-policy importance sampling correction](https://fengyao.notion.site/off-policy-rl) using Truncated Importance Sampling (TIS): <img width="859" height="382" alt="TIS" src="https://github.com/user-attachments/assets/adc8f797-aa14-4b29-b265-a682c281d08e" /> ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=gae \ data.train_files="$train_files" \ data.val_files="$test_files" \ data.train_batch_size=1024 \ data.max_prompt_length=1024 \ data.max_response_length=1024 \ data.filter_overlong_prompts=True \ data.truncation='error' \ actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \ actor_rollout_ref.model.enable_gradient_checkpointing=False \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.actor.use_kl_loss=False \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.model.path=Qwen/Qwen2.5-32B-Instruct \ critic.model.enable_gradient_checkpointing=False \ critic.ppo_micro_batch_size_per_gpu=8 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=4 \ trainer.save_freq=20 \ trainer.test_freq=10 \ trainer.total_epochs=15 \ actor_rollout_ref.rollout.calculate_log_probs=True \ # add this config to return rollout prob +actor_rollout_ref.actor.behav_imp_weight_cap=10.0$@ # add this config to set up C value in TIS ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Narsil-Dinghuai Zhang 张鼎怀 <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: LiyuanLucasLiu <[email protected]>
@zdhNarsil Will the gradients of K2 exhibit high variance, and will this affect training stability? |
… functions (volcengine#3274) ### What does this PR do? <!-- > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. --> In the recent PR: - volcengine#2953, the file `workers/actor/dp_actor.py` was updated so that `rollout_log_probs` is passed to `policy_loss_fn`: https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456 In that PR, the "vanilla" policy loss function was modified to accept `rollout_log_probs` as an argument. However, other policy loss functions (e.g., "gspo") were not updated accordingly, which leads to an error such as: ``` TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs' ``` when setting `config.policy_loss.loss_mode` to one of these alternatives. Therefore, in this PR, `rollout_log_probs` is also added as an argument to the other policy loss functions. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
### What does this PR do? Add the TIS support from volcengine#2953 to megatron actor ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) Co-authored-by: Shuang Yu <[email protected]>
…ncated importance sampling (volcengine#2953) ### What does this PR do? Support [vLLM-FSDP off-policy importance sampling correction](https://fengyao.notion.site/off-policy-rl) using Truncated Importance Sampling (TIS): <img width="859" height="382" alt="TIS" src="https://github.com/user-attachments/assets/adc8f797-aa14-4b29-b265-a682c281d08e" /> ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python python3 -m verl.trainer.main_ppo \ algorithm.adv_estimator=gae \ data.train_files="$train_files" \ data.val_files="$test_files" \ data.train_batch_size=1024 \ data.max_prompt_length=1024 \ data.max_response_length=1024 \ data.filter_overlong_prompts=True \ data.truncation='error' \ actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \ actor_rollout_ref.model.enable_gradient_checkpointing=False \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.actor.use_kl_loss=False \ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.model.path=Qwen/Qwen2.5-32B-Instruct \ critic.model.enable_gradient_checkpointing=False \ critic.ppo_micro_batch_size_per_gpu=8 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.use_kl_in_reward=False \ trainer.critic_warmup=0 \ trainer.logger='["console","wandb"]' \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ trainer.n_gpus_per_node=8 \ trainer.nnodes=4 \ trainer.save_freq=20 \ trainer.test_freq=10 \ trainer.total_epochs=15 \ actor_rollout_ref.rollout.calculate_log_probs=True \ # add this config to return rollout prob +actor_rollout_ref.actor.behav_imp_weight_cap=10.0$@ # add this config to set up C value in TIS ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Narsil-Dinghuai Zhang 张鼎怀 <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: LiyuanLucasLiu <[email protected]>
… functions (volcengine#3274) ### What does this PR do? <!-- > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. --> In the recent PR: - volcengine#2953, the file `workers/actor/dp_actor.py` was updated so that `rollout_log_probs` is passed to `policy_loss_fn`: https://github.com/volcengine/verl/blob/38d23914ee512a125e00763fe3ddcc8df4319346/verl/workers/actor/dp_actor.py#L448-L456 In that PR, the "vanilla" policy loss function was modified to accept `rollout_log_probs` as an argument. However, other policy loss functions (e.g., "gspo") were not updated accordingly, which leads to an error such as: ``` TypeError: compute_policy_loss_gspo() got an unexpected keyword argument 'rollout_log_probs' ``` when setting `config.policy_loss.loss_mode` to one of these alternatives. Therefore, in this PR, `rollout_log_probs` is also added as an argument to the other policy loss functions. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
### What does this PR do? Add the TIS support from volcengine#2953 to megatron actor ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) Co-authored-by: Shuang Yu <[email protected]>
### What does this PR do? Add the TIS support from volcengine#2953 to megatron actor ### Checklist Before Starting - [ ] Search for similar PRs. Paste at least one query link here: ... - [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) Co-authored-by: Shuang Yu <[email protected]>
What does this PR do?
Support vLLM-FSDP off-policy importance sampling correction using Truncated Importance Sampling (TIS):
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
API and Usage Example
Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)