Commit 55854c8

Move tests of experimental GRPO with replay buffer to tests/experimental (#4329)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 4352074 commit 55854c8

File tree

2 files changed: +276 -269 lines changed

Lines changed: 276 additions & 0 deletions
@@ -0,0 +1,276 @@
import pytest
import torch
from datasets import load_dataset

from trl import GRPOTrainer
from trl.experimental.grpo_with_replay_buffer import (
    GRPOWithReplayBufferConfig,
    GRPOWithReplayBufferTrainer,
    ReplayBuffer,
)

from ..testing_utils import TrlTestCase


@pytest.mark.low_priority
class TestReplayBuffer:
    def setup_method(self):
        self.replay_buffer = ReplayBuffer(max_size=5)

    def test_add(self):
        # Add elements to the replay buffer
        scores = [0.5, 0.8, 0.3, 0.9, 0.7]
        data = [
            {"id": 1},
            {"id": 2},
            {"id": 3},
            {"id": 4},
            {"id": 5},
        ]
        self.replay_buffer.add(scores, data)

        # Check if the buffer contains the correct number of elements
        assert len(self.replay_buffer.heap) == 5

        # Check if the buffer maintains the min-heap property
        heap_scores = [item[0] for item in self.replay_buffer.heap]
        assert heap_scores[0] == min(heap_scores)
        assert heap_scores[0] == 0.3

    def test_add_more_than_maxlen(self):
        # Add elements to the replay buffer
        scores = [0.5, 0.8, 0.3, 0.9, 0.7, 0.6, 0.4]
        data = [
            {"id": 1},
            {"id": 2},
            {"id": 3},
            {"id": 4},
            {"id": 5},
            {"id": 6},
            {"id": 7},
        ]
        self.replay_buffer.add(scores, data)

        # Check if the buffer contains the correct number of elements
        assert len(self.replay_buffer.heap) == 5

        # Check if the buffer maintains the min-heap property
        heap_scores = [item[0] for item in self.replay_buffer.heap]
        assert heap_scores[0] == min(heap_scores)
        assert heap_scores[0] == 0.5  # 0.3 and 0.4 should be removed

    def test_sample(self):
        # Add elements to the replay buffer
        scores = [0.5, 0.8, 0.3, 0.9, 0.7]
        data = [
            {"id": 1},
            {"id": 2},
            {"id": 3},
            {"id": 4},
            {"id": 5},
        ]
        self.replay_buffer.add(scores, data)

        # Sample elements from the buffer
        sampled = self.replay_buffer.sample(num_samples=3)

        # Check if the sampled elements are from the buffer
        assert len(sampled) == 3
        for item in sampled:
            assert item in [entry[1] for entry in self.replay_buffer.heap]


@pytest.mark.low_priority
class TestUpdateWithReplayBuffer:
    def setup_method(self):
        config = GRPOWithReplayBufferConfig(
            replay_buffer_size=5,
        )
        self.trainer = GRPOWithReplayBufferTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            args=config,
            train_dataset=None,
        )
        self.trainer.replay_buffer = ReplayBuffer(max_size=5)
        self.trainer.num_generations = 2

    def _prepopulate_buffer(self, with_pixels=False, with_logprobs=False):
        scores = [0.1, 0.9]
        data = [
            {
                "prompt_ids": torch.tensor([[100, 101], [102, 103]]),
                "prompt_mask": torch.ones(2, 2, dtype=torch.long),
                "completion_ids": torch.tensor([[5, 6], [7, 8]]),
                "completion_mask": torch.ones(2, 2, dtype=torch.long),
                "advantages": torch.tensor([[0.5, 0.6]]),
                **({"pixel_values": torch.randn(2, 3, 224, 224)} if with_pixels else {}),
                **({"old_per_token_logps": torch.randn(2, 2)} if with_logprobs else {}),
            },
            {
                "prompt_ids": torch.tensor([[104, 105], [106, 107]]),
                "prompt_mask": torch.ones(2, 2, dtype=torch.long),
                "completion_ids": torch.tensor([[13, 14], [15, 16]]),
                "completion_mask": torch.ones(2, 2, dtype=torch.long),
                "advantages": torch.tensor([[0.8, 0.85]]),
                **({"pixel_values": torch.randn(2, 3, 224, 224)} if with_pixels else {}),
                **({"old_per_token_logps": torch.randn(2, 2)} if with_logprobs else {}),
            },
        ]
        self.trainer.replay_buffer.add(scores, data)

    def _make_inputs(self, group_advantages, with_pixels=False, with_logprobs=False):
        inputs = {
            "group_advantages": group_advantages,
            "prompt_ids": torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]),
            "prompt_mask": torch.ones(4, 2, dtype=torch.long),
            "completion_ids": torch.tensor([[9, 10], [11, 12], [13, 14], [15, 16]]),
            "completion_mask": torch.ones(4, 2, dtype=torch.long),
            "prompt_inputs": {"pixel_values": torch.randn(4, 3, 224, 224)} if with_pixels else {},
            "old_per_token_logps": torch.randn(4, 2) if with_logprobs else None,
        }
        inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages)
        return inputs

    def test_update_with_replay_buffer_no_variance(self):
        self._prepopulate_buffer(with_pixels=True, with_logprobs=True)
        group_advantages = torch.tensor([[0.5, 0.5], [0.8, 0.8]])  # no variance
        inputs = self._make_inputs(group_advantages, with_pixels=True, with_logprobs=True)
        original_prompt_ids = inputs["prompt_ids"].clone()

        outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)

        assert outputs is not None
        assert "pixel_values" in outputs
        assert "old_per_token_logps" in outputs
        assert len(self.trainer.replay_buffer.heap) == 2
        for pid in outputs["prompt_ids"]:
            assert pid.tolist() not in original_prompt_ids.tolist()

    def test_update_with_replay_buffer_with_variance(self):
        self._prepopulate_buffer()
        group_advantages = torch.tensor([[0.6, 0.4], [0.7, 1.2]])  # has variance
        inputs = self._make_inputs(group_advantages)

        sampled = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)

        assert len(self.trainer.replay_buffer.heap) == 4  # grew
        assert sampled is None

    def test_update_with_mixed_variance(self):
        self._prepopulate_buffer()
        group_advantages = torch.tensor([[0.6, 0.6], [0.3, 0.45]])  # one no-variance, one variance
        inputs = self._make_inputs(group_advantages)
        original_prompt_ids = inputs["prompt_ids"].clone().view(-1, self.trainer.num_generations, 2).tolist()

        outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)

        assert len(self.trainer.replay_buffer.heap) == 3  # grew by 1
        output_prompt_ids = outputs["prompt_ids"].view(-1, self.trainer.num_generations, 2).tolist()

        buffer_ids = [item[1]["prompt_ids"].tolist() for item in self.trainer.replay_buffer.heap]
        found_from_buffer = any(pid in buffer_ids for pid in output_prompt_ids)
        found_from_original = any(pid in original_prompt_ids for pid in output_prompt_ids)

        assert found_from_buffer
        assert found_from_original
        assert [[1, 2], [3, 4]] not in output_prompt_ids  # excluded no-variance group

    def test_update_with_inputs_different_seq_len(self):
        """
        Test with inputs where the sequence lengths are different from the prepopulated buffer.
        """
        self._prepopulate_buffer()
        pad_token_id = self.trainer.processing_class.pad_token_id
        group_advantages = torch.tensor([[0.6, 0.6], [0.3, 0.45]])  # one no-variance, one variance
        inputs = {
            "group_advantages": group_advantages,
            "prompt_ids": torch.tensor(
                [
                    [1, 2, pad_token_id],
                    [1, 2, pad_token_id],
                    [3, 4, 5],
                    [3, 4, 5],
                ]
            ),
            "prompt_mask": torch.tensor([[1, 1, 0], [1, 1, 0], [1, 1, 1], [1, 1, 1]], dtype=torch.long),
            "completion_ids": torch.tensor(
                [
                    [1009, 1010, pad_token_id],
                    [1011, 1012, 1013],
                    [1013, 1014, pad_token_id],
                    [1015, 1016, 1017],
                ]
            ),
            "completion_mask": torch.tensor([[1, 1, 0], [1, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.long),
            "prompt_inputs": {},
        }
        inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages)

        outputs_after_sampling = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
        # Seq length of current batch should be preserved
        assert outputs_after_sampling["prompt_ids"].shape[-1] == 3
        assert len(self.trainer.replay_buffer.heap) == 3
        output_prompt_ids = outputs_after_sampling["prompt_ids"].view(-1, self.trainer.num_generations, 3).tolist()

        buffered_prompt_completion_ids = [
            (item[1]["prompt_ids"].tolist(), item[1]["completion_ids"].tolist())
            for item in self.trainer.replay_buffer.heap
        ]
        buffered_prompt_ids, buffered_completion_ids = zip(*buffered_prompt_completion_ids)

        # Check for new entry with seq len 3 in buffer
        assert [[3, 4, 5], [3, 4, 5]] in buffered_prompt_ids  # excluded no-variance group
        assert [
            [1013, 1014, pad_token_id],
            [1015, 1016, 1017],
        ] in buffered_completion_ids  # excluded no-variance group

        # Check that sampled outputs contain one group with prompt_ids starting with a pad token
        assert [
            [pad_token_id, 101, 102],
            [pad_token_id, 102, 103],
        ] in output_prompt_ids or [
            [pad_token_id, 104, 105],
            [pad_token_id, 106, 107],
        ] in output_prompt_ids


@pytest.mark.low_priority
class TestGRPOWithReplayBufferTrainer(TrlTestCase):
    def test_training_with_replay_buffer(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        # Guarantee that some rewards have 0 std
        def custom_reward_func(completions, **kwargs):
            if torch.rand(1).item() < 0.25:
                return [0] * len(completions)  # all-zero rewards for this batch, giving zero std within groups
            else:
                return torch.rand(len(completions)).tolist()

        training_args = GRPOWithReplayBufferConfig(
            output_dir=self.tmp_dir,
            learning_rate=0.1,  # increase the learning rate to speed up the test
            per_device_train_batch_size=4,  # reduce the batch size to reduce memory usage
            num_generations=4,  # reduce the number of generations to reduce memory usage
            max_completion_length=8,  # reduce the completion length to reduce memory usage
            replay_buffer_size=8,
            report_to="none",
        )
        trainer = GRPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs=[custom_reward_func],
            args=training_args,
            train_dataset=dataset,
        )

        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        trainer.train()

        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Check that the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
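
For orientation, the TestReplayBuffer cases above pin down the buffer contract: add() keeps only the max_size highest-scoring entries in a min-heap of (score, data) pairs, heap[0] holds the lowest retained score, and sample() returns stored data items. The following is a minimal sketch consistent with those assertions; it is not the trl.experimental implementation, the class name and the uniform sampling policy are assumptions, and it ignores score ties.

import heapq
import random


class MinHeapReplayBufferSketch:
    """Illustrative stand-in for ReplayBuffer, not the TRL implementation."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.heap = []  # min-heap of (score, data); heap[0] is the lowest retained score

    def add(self, scores, data):
        for score, item in zip(scores, data):
            if len(self.heap) < self.max_size:
                heapq.heappush(self.heap, (score, item))
            elif score > self.heap[0][0]:
                # Buffer is full: evict the lowest-scoring entry in favor of this one.
                heapq.heapreplace(self.heap, (score, item))
            # Equal scores would make heapq compare the data dicts; a real
            # implementation needs a tie-breaker, which this sketch omits.

    def sample(self, num_samples):
        # Uniform sampling with replacement over retained entries (policy assumed here).
        return [entry[1] for entry in random.choices(self.heap, k=num_samples)]

Fed the scores from test_add_more_than_maxlen (0.5, 0.8, 0.3, 0.9, 0.7, 0.6, 0.4) with max_size=5, this sketch evicts the 0.3 entry, rejects 0.4, and leaves heap[0][0] == 0.5, matching that test's assertions.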

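On the variance gate exercised by TestUpdateWithReplayBuffer: the inputs derive group_std_rewards from group_advantages.std(dim=1), and the assertions treat a zero-std group as one whose prompts are swapped for a buffered group, while groups with spread are written to the buffer. A quick standalone check of the mixed-variance tensors used above (illustrative only):

import torch

group_advantages = torch.tensor([[0.6, 0.6], [0.3, 0.45]])  # one flat group, one with spread
group_std = group_advantages.std(dim=1)

print(group_std)       # tensor([0.0000, 0.1061])
print(group_std == 0)  # tensor([ True, False]) -> only the first group lacks a learning signal

This lines up with test_update_with_mixed_variance, where [[1, 2], [3, 4]] (the flat group's prompts) must be absent from the returned batch and at least one buffered group must appear in its place.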