@@ -79,6 +79,8 @@ def gptmodel_forward_qwen2_5_vl(
7979 assert mpu .get_context_parallel_world_size () == 1 , "qwen2_5_vl's context parallel is not accurate yet"
8080 pre_process = unwrap_model (model ).pre_process
8181 post_process = unwrap_model (model ).post_process
82+ pixel_values = multi_modal_inputs ["pixel_values" ].to (input_ids .device ) if "pixel_values" in multi_modal_inputs else None
83+ image_grid_thw = multi_modal_inputs ["image_grid_thw" ].to (input_ids .device ) if "image_grid_thw" in multi_modal_inputs else None
8284 if pack_seqs :
8385 batch_size , seq_len = attention_mask .shape [:2 ]
8486 input_ids_rmpad , packed_seq_params = preprocess_packed_seqs (input_ids , attention_mask , pre_process = True )
@@ -88,8 +90,8 @@ def gptmodel_forward_qwen2_5_vl(
8890 attention_mask = None ,
8991 position_ids = position_ids ,
9092 packed_seq_params = packed_seq_params ,
91- pixel_values = multi_modal_inputs [ " pixel_values" ]. to ( input_ids . device ) ,
92- image_grid_thw = multi_modal_inputs [ " image_grid_thw" ]. to ( input_ids . device ) ,
93+ pixel_values = pixel_values ,
94+ image_grid_thw = image_grid_thw ,
9395 )
9496
9597 if post_process and logits_processor is not None :
@@ -105,8 +107,8 @@ def gptmodel_forward_qwen2_5_vl(
105107 input_ids = new_input_ids ,
106108 position_ids = new_position_ids ,
107109 attention_mask = new_attention_mask ,
108- pixel_values = multi_modal_inputs [ " pixel_values" ]. to ( input_ids . device ) ,
109- image_grid_thw = multi_modal_inputs [ " image_grid_thw" ]. to ( input_ids . device ) ,
110+ pixel_values = pixel_values ,
111+ image_grid_thw = image_grid_thw ,
110112 )
111113 output = recover_left_padding (output , new_attention_mask , attention_mask , sequence_length , post_process = post_process )
112114 if value_model and post_process :
0 commit comments