Skip to content

Commit 4a6d13a

Browse files
authored
[recipe] fix: fix issue when running split ppo (verl-project#2745)
1 parent e48433e commit 4a6d13a

3 files changed

Lines changed: 10 additions & 4 deletions

File tree

examples/split_placement/config/ppo_trainer_split.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
hydra:
44
searchpath:
5-
- file://verl/trainer/config
5+
- file://../../verl/trainer/config
66

77
defaults:
88
- ppo_trainer

examples/split_placement/split_monkey_patch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def fit(self):
169169
lam=self.config.algorithm.lam,
170170
num_repeat=self.config.actor_rollout_ref.rollout.n,
171171
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
172+
config=self.config.algorithm,
172173
)
173174

174175
# implement critic warmup

verl/workers/critic/dp_critic.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,11 @@ def compute_values(self, data: DataProto) -> torch.Tensor:
160160
micro_batch_size = data.meta_info["micro_batch_size"]
161161
use_dynamic_bsz = data.meta_info["use_dynamic_bsz"]
162162
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
163-
select_keys = ["responses", "input_ids", "response_mask", "attention_mask", "position_ids"]
163+
select_keys = (
164+
["responses", "input_ids", "response_mask", "attention_mask", "position_ids"]
165+
if "response_mask" in data.batch
166+
else ["responses", "input_ids", "attention_mask", "position_ids"]
167+
)
164168
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
165169

166170
data = data.select(batch_keys=select_keys, non_tensor_batch_keys=non_tensor_select_keys)
@@ -182,8 +186,9 @@ def compute_values(self, data: DataProto) -> torch.Tensor:
182186
if use_dynamic_bsz:
183187
values = restore_dynamic_batch(values, batch_idx_list)
184188

185-
response_mask = data.batch["response_mask"]
186-
values = values * response_mask # Only action tokens have values
189+
if "response_mask" in data.batch:
190+
response_mask = data.batch["response_mask"]
191+
values = values * response_mask # Only action tokens have values
187192
return values
188193

189194
@GPUMemoryLogger(role="dp critic", logger=logger)

0 commit comments

Comments
 (0)