Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,30 @@ def fit(self):
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)

if "rollout_log_probs" in batch.batch.keys():
# TODO: we may want to add diff of probs too.
rollout_old_log_probs = batch.batch["rollout_log_probs"]
actor_old_log_probs = batch.batch["old_log_probs"]
attention_mask = batch.batch["attention_mask"]
responses = batch.batch["responses"]
response_length = responses.size(1)
response_mask = attention_mask[:, -response_length:]

rollout_probs = torch.exp(rollout_old_log_probs)
actor_probs = torch.exp(actor_old_log_probs)
rollout_probs_diff = torch.abs(rollout_probs - actor_probs)
rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())
rollout_probs_diff_max = torch.max(rollout_probs_diff)
rollout_probs_diff_mean = torch.mean(rollout_probs_diff)
rollout_probs_diff_std = torch.std(rollout_probs_diff)
metrics.update(
{
"training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(),
"training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(),
"training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(),
}
)

if self.use_reference_policy:
# compute reference log_prob
with _timer("ref", timing_raw):
Expand Down
8 changes: 4 additions & 4 deletions verl/workers/rollout/sglang_rollout/async_sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:

# users can customize different sampling_params at different run
with self.update_sampling_params(**kwargs):
print(f"{self.sampling_params=}")
# print(f"{self.sampling_params=}")
if self._tp_rank == 0:
loop = asyncio.get_event_loop()
output = loop.run_until_complete(
Expand All @@ -390,11 +390,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
out = _post_process_outputs(self.tokenizer, output)

response = out[0].to(idx.device)
# log_probs = out[1].to(idx.device)
rollout_log_probs = out[1].to(idx.device)

if response.shape[1] < self.config.response_length:
response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
# log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)
rollout_log_probs = pad_sequence_to_length(rollout_log_probs, self.config.response_length, self.pad_token_id)

# utilize current sampling params
if self.sampling_params.get("n", 1) > 1 and do_sample:
Expand Down Expand Up @@ -428,7 +428,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
"prompts": idx,
"responses": response,
"input_ids": seq, # here input_ids become the whole sentences
# 'old_log_probs': log_probs, # we will recompute old log prob with actor
'rollout_log_probs': rollout_log_probs, # we will recompute old log prob with actor
"attention_mask": attention_mask,
"position_ids": position_ids,
},
Expand Down
2 changes: 1 addition & 1 deletion verl/workers/rollout/sglang_rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:

# users can customize different sampling_params at different run
with self.update_sampling_params(**kwargs):
print(f"{self.sampling_params=}")
# print(f"{self.sampling_params=}")
output = self.inference_engine.generate(
prompt=None, # because we have already convert it to prompt token id
sampling_params=self.sampling_params,
Expand Down
6 changes: 3 additions & 3 deletions verl/workers/rollout/vllm_rollout/vllm_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
# TODO(sgm): disable logprob when recompute_log_prob is enable
# if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
response = output[0].to(idx.device)
# log_probs = output[1].to(idx.device)
log_probs = output[1].to(idx.device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check if logprob is >0 here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

logprob will never be > 0?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I am referring to is this:

kwargs = dict(
            n=1,
            logprobs=0,  # can be set to 0 and let actor to recompute
            max_tokens=config.response_length,
        )

We may need to set `logprobs` > 0 for vLLM to return log-probabilities.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With `logprobs == 0`, vLLM still returns the log-probability of the sampled token, so it is fine here.


if response.shape[1] < self.config.response_length:
response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
# log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)
log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)

# utilize current sampling params
if self.sampling_params.n > 1 and do_sample:
Expand Down Expand Up @@ -262,7 +262,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
"prompts": idx,
"responses": response,
"input_ids": seq, # here input_ids become the whole sentences
# 'old_log_probs': log_probs, # we will recompute old log prob with actor
'rollout_log_probs': log_probs, # we will recompute old log prob with actor
"attention_mask": attention_mask,
"position_ids": position_ids,
},
Expand Down
12 changes: 10 additions & 2 deletions verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,19 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
# if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)

response = []
rollout_log_probs = []
for output in outputs:
for sample_id in range(len(output.outputs)):
response.append(output.outputs[sample_id].token_ids)
response_ids = output.outputs[sample_id].token_ids
response.append(response_ids)
curr_log_prob = []
for i, logprob in enumerate(output.outputs[sample_id].logprobs):
curr_log_prob.append(logprob[response_ids[i]].logprob)
rollout_log_probs.append(curr_log_prob)

response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(idx.device)
rollout_log_probs = pad_2d_list_to_length(rollout_log_probs, -1, max_length=self.config.response_length).to(idx.device)
rollout_log_probs = rollout_log_probs.to(torch.float32)

if self.sampling_params.n > 1 and do_sample:
idx = _repeat_interleave(idx, self.sampling_params.n)
Expand Down Expand Up @@ -322,7 +330,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
"prompts": idx,
"responses": response,
"input_ids": seq, # here input_ids become the whole sentences
# 'old_log_probs': log_probs, # we will recompute old log prob with actor
'rollout_log_probs': rollout_log_probs, # we will recompute old log prob with actor
"attention_mask": attention_mask,
"position_ids": position_ids,
},
Expand Down
Loading