48 changes: 48 additions & 0 deletions examples/grpo_trainer/run_qwen2_5_vl-7b-sglang.sh
@@ -0,0 +1,48 @@
set -x

# python examples/data_preprocess/geo3k.py --local_dir ~/data/geo3k

python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/geo3k/train.parquet \
data.val_files=$HOME/data/geo3k/test.parquet \
data.train_batch_size=512 \
data.max_prompt_length=1024 \
data.max_response_length=2048 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.image_key=images \
actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.01 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
actor_rollout_ref.rollout.gpu_memory_utilization=0.85 \
actor_rollout_ref.rollout.multi_stage_wake_up=True \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=True \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=20 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console","wandb"]' \
trainer.project_name='verl_grpo_example_geo3k' \
trainer.experiment_name='qwen2_5_vl_7b_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
27 changes: 24 additions & 3 deletions verl/models/transformers/qwen2_vl.py
@@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import logging
import os
from dataclasses import dataclass
from typing import Optional
@@ -36,13 +37,18 @@
    validate_ulysses_config,
)

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

try:
    from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
except ImportError:
    flash_attn_varlen_func = None
    # Fallback: try to import from flash_attn package directly
    from flash_attn import flash_attn_varlen_func

Collaborator review comment on the fallback import: "please add another try except here" (a sketch of this suggestion appears right after this hunk).

    flash_attn_func = None
    _flash_supports_window_size = None
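
As a sketch of what that review comment asks for (not the merged code; the module paths and symbol names are taken from the diff above), the fallback import could itself be guarded so the module still imports cleanly when neither transformers nor the flash_attn package provides the function:

```python
# Sketch only: nest a second try/except around the fallback import, as the reviewer suggests.
import inspect

try:
    from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
except ImportError:
    # Fallback: try to import from the flash_attn package directly.
    try:
        from flash_attn import flash_attn_varlen_func
    except ImportError:
        # Neither source is available; leave the symbol unset so callers can branch on it.
        flash_attn_varlen_func = None

    flash_attn_func = None
    _flash_supports_window_size = None
```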


@@ -193,7 +199,12 @@ def flash_attention_forward(
            deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
        flash_kwargs["deterministic"] = deterministic

    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all():
    if (
        flash_attn_varlen_func is not None
        and position_ids is not None
        and query_length != 1
        and not (torch.diff(position_ids[0], dim=-1) >= 0).all()
    ):
        batch_size = query_states.size(0)
        query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids[0]
@@ -215,6 +226,16 @@ def flash_attention_forward(
        )
        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
    else:
        if (
            flash_attn_varlen_func is None
            and position_ids is not None
            and query_length != 1
            and not (torch.diff(position_ids[0], dim=-1) >= 0).all()
        ):
            logger.warning_once(
                "flash_attn_varlen_func is not available; falling back to _flash_attention_forward. "
                "This may be suboptimal for non-monotonic position_ids in VLM mRoPE."
            )
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
@@ -226,7 +247,7 @@ def flash_attention_forward(
            use_top_left_mask=flash_attn_supports_top_left_mask(),
            deterministic=deterministic,
            **kwargs,
        )  # do not pass position_ids to old flash_attention_forward
        )

    return attn_output
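
The guard added in this function takes the varlen path only when the first row of `position_ids` is non-monotonic along the sequence dimension, which typically happens when several sequences are packed into one row and each new sequence restarts from 0. A standalone illustration of that check (not part of the diff; the tensors below are made up):

```python
# Illustration of the non-monotonicity check used in the guard above; values are illustrative.
import torch

packed = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])  # three sequences packed into one row
single = torch.tensor([[0, 1, 2, 3, 4, 5]])           # one ordinary sequence

# Mirrors `not (torch.diff(position_ids[0], dim=-1) >= 0).all()` from the diff.
print(not (torch.diff(packed[0], dim=-1) >= 0).all())  # True  -> take the varlen path
print(not (torch.diff(single[0], dim=-1) >= 0).all())  # False -> standard flash attention path
```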
