Commit bc299b8
[sglang] fix: Qwen VLM Baseline (volcengine#3083)
### What does this PR do?

This PR fixes the script in https://github.com/volcengine/verl/blob/main/examples/grpo_trainer/run_qwen2_5_vl-7b.sh

The core issue was a `TypeError: 'NoneType' object is not callable`, raised because the variable `flash_attn_varlen_func` had been assigned `None`. This happened when the primary import from `transformers.modeling_flash_attention_utils` failed. As a solution, I add a nested `try...except` block: first attempt the import from `transformers`, and if that fails, import `flash_attn_varlen_func` directly from the `flash_attn` package.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is one of `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

I added a new test script here: `examples/grpo_trainer/run_qwen2_5_vl-7b-sglang.sh`

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.
### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: zhaochenyang20 <[email protected]>
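The nested-import fallback this PR describes can be sketched roughly as below. This is a simplified, standalone version, not the exact verl source: the actual patch imports from `flash_attn` unconditionally in the `except` branch, while the inner guard here is an extra assumption so the sketch also runs on machines without flash-attn installed.

```python
# Sketch of the nested import fallback from the PR description.
# If neither `transformers` nor `flash_attn` provides the varlen kernel,
# the name stays None and callers must guard before calling it.
try:
    from transformers.modeling_flash_attention_utils import (
        flash_attn_func,
        flash_attn_varlen_func,
    )
except ImportError:
    flash_attn_func = None
    try:
        # Fallback: import the varlen kernel directly from flash-attn
        from flash_attn import flash_attn_varlen_func
    except ImportError:
        # Extra guard not present in the actual patch (see lead-in)
        flash_attn_varlen_func = None
```

Whichever branch runs, downstream code only has to check one name (`flash_attn_varlen_func`) instead of knowing which package provided it.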
1 parent ce26a7b commit bc299b8

File tree

2 files changed: +72 −3 lines changed
examples/grpo_trainer/run_qwen2_5_vl-7b-sglang.sh

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
set -x

# python examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/geo3k/train.parquet \
    data.val_files=$HOME/data/geo3k/test.parquet \
    data.train_batch_size=512 \
    data.max_prompt_length=1024 \
    data.max_response_length=2048 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    data.image_key=images \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.01 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=sglang \
    +actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
    actor_rollout_ref.rollout.multi_stage_wake_up=True \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=True \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_grpo_example_geo3k' \
    trainer.experiment_name='qwen2_5_vl_7b_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@
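The trailing `$@` means any extra arguments passed to the launch script are appended after the defaults, so you can override Hydra options at the command line without editing the file. A minimal stub illustrating the forwarding (the function name is made up here; the real command is `python3 -m verl.trainer.main_ppo`):

```shell
# Stub standing in for the real trainer invocation, only to show how
# the trailing `$@` forwards extra command-line overrides.
launch_stub() {
  echo "defaults ... extra: $@"
}

# e.g. `bash run_qwen2_5_vl-7b-sglang.sh trainer.total_epochs=1`
# would append trainer.total_epochs=1 after the defaults:
launch_stub trainer.total_epochs=1
```

Because Hydra overrides are last-wins, an override passed this way takes precedence over the same key set earlier in the script.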

verl/models/transformers/qwen2_vl.py

Lines changed: 24 additions & 3 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import logging
 import os
 from dataclasses import dataclass
 from typing import Optional
@@ -36,13 +37,18 @@
     validate_ulysses_config,
 )

+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
 try:
     from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func

     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 except ImportError:
-    flash_attn_varlen_func = None
+    # Fallback: try to import from flash_attn package directly
+    from flash_attn import flash_attn_varlen_func

+    flash_attn_func = None
     _flash_supports_window_size = None


@@ -193,7 +199,12 @@ def flash_attention_forward(
     deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
     flash_kwargs["deterministic"] = deterministic

-    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all():
+    if (
+        flash_attn_varlen_func is not None
+        and position_ids is not None
+        and query_length != 1
+        and not (torch.diff(position_ids[0], dim=-1) >= 0).all()
+    ):
         batch_size = query_states.size(0)
         query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
             query_states, key_states, value_states, position_ids[0]
@@ -215,6 +226,16 @@ def flash_attention_forward(
         )
         attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
     else:
+        if (
+            flash_attn_varlen_func is None
+            and position_ids is not None
+            and query_length != 1
+            and not (torch.diff(position_ids[0], dim=-1) >= 0).all()
+        ):
+            logger.warning_once(
+                "flash_attn_varlen_func is not available; falling back to _flash_attention_forward. "
+                "This may be suboptimal for non-monotonic position_ids in VLM mRoPE."
+            )
         attn_output = _flash_attention_forward(
             query_states,
             key_states,
@@ -226,7 +247,7 @@
             use_top_left_mask=flash_attn_supports_top_left_mask(),
             deterministic=deterministic,
             **kwargs,
-        )  # do not pass position_ids to old flash_attention_forward
+        )

     return attn_output
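The condition `not (torch.diff(position_ids[0], dim=-1) >= 0).all()` in the diff above detects non-monotonic position ids, which Qwen2-VL's mRoPE can produce for sequences containing vision tokens; only that case needs the varlen kernel path. A pure-Python analogue of the check (the helper name is hypothetical, for illustration only):

```python
def positions_non_monotonic(position_ids):
    """True if position ids ever decrease, mirroring the torch check
    `not (torch.diff(position_ids, dim=-1) >= 0).all()` on a 1-D list."""
    return any(b < a for a, b in zip(position_ids, position_ids[1:]))

# Plain text prompts yield monotonically non-decreasing ids:
print(positions_non_monotonic([0, 1, 2, 3]))     # False
# mRoPE-style ids can restart within a sequence after an image block:
print(positions_non_monotonic([0, 1, 2, 0, 1]))  # True
```

When the kernel is missing, the patched code falls back to `_flash_attention_forward` and emits the warning once, rather than crashing on a `None` call as before.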
