Skip to content

Commit 2e4fbab

Browse files
committed
update data protocol
1 parent b925fd8 commit 2e4fbab

File tree

3 files changed

+271
-40
lines changed

3 files changed

+271
-40
lines changed

.github/requirements-test.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ codetiming
22
datasets
33
pillow
44
pytest
5+
ray[default]
56
ruff
7+
tensordict
68
torch
79
torchvision
810
transformers

tests/test_dataproto.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import os
17+
from typing import Any, Dict, List, Optional
18+
19+
import numpy as np
20+
import pytest
21+
import torch
22+
23+
from verl.protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto
24+
25+
26+
def _get_data_proto(
27+
tensors: Optional[Dict[str, List[Any]]] = None,
28+
non_tensors: Optional[Dict[str, List[Any]]] = None,
29+
meta_info: Optional[Dict[str, Any]] = None,
30+
) -> DataProto:
31+
if tensors is None and non_tensors is None:
32+
tensors = {"obs": [1, 2, 3, 4, 5, 6]}
33+
non_tensors = {"labels": ["a", "b", "c", "d", "e", "f"]}
34+
35+
if tensors is not None:
36+
tensors = {k: torch.tensor(v) if not isinstance(v, torch.Tensor) else v for k, v in tensors.items()}
37+
38+
if non_tensors is not None:
39+
non_tensors = {
40+
k: np.array(v, dtype=object) if not isinstance(v, np.ndarray) else v for k, v in non_tensors.items()
41+
}
42+
43+
meta_info = meta_info or {"info": "test_info"}
44+
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)
45+
46+
47+
def _assert_equal(data1: DataProto, data2: Optional[DataProto] = None):
48+
data2 = data2 or _get_data_proto()
49+
if data1.batch is not None:
50+
assert data1.batch.keys() == data2.batch.keys()
51+
for key in data1.batch.keys():
52+
assert torch.all(data1.batch[key] == data2.batch[key])
53+
else:
54+
assert data2.batch is None
55+
56+
if data1.non_tensor_batch is not None:
57+
assert data1.non_tensor_batch.keys() == data2.non_tensor_batch.keys()
58+
for key in data1.non_tensor_batch.keys():
59+
assert np.all(data1.non_tensor_batch[key] == data2.non_tensor_batch[key])
60+
else:
61+
assert data2.non_tensor_batch is None
62+
63+
assert data1.meta_info == data2.meta_info
64+
65+
66+
def test_tensor_dict_constructor():
67+
obs = torch.randn(100, 10)
68+
act = torch.randn(100, 10, 3)
69+
data = DataProto.from_dict(tensors={"obs": obs, "act": act})
70+
assert len(data) == 100
71+
72+
with pytest.raises(AssertionError):
73+
data = DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=2)
74+
75+
with pytest.raises(AssertionError):
76+
data = DataProto.from_dict(tensors={"obs": obs, "act": act}, num_batch_dims=3)
77+
78+
labels = np.array(["a", "b", "c"], dtype=object)
79+
data = DataProto.from_dict(non_tensors={"labels": labels})
80+
assert len(data) == 3
81+
82+
83+
def test_getitem():
84+
data = _get_data_proto()
85+
assert data[0].batch["obs"] == torch.tensor(1)
86+
assert data[0].non_tensor_batch["labels"] == "a"
87+
_assert_equal(data[1:3], _get_data_proto({"obs": [2, 3]}, {"labels": ["b", "c"]}))
88+
_assert_equal(data[[0, 2]], _get_data_proto({"obs": [1, 3]}, {"labels": ["a", "c"]}))
89+
_assert_equal(data[torch.tensor([1])], _get_data_proto({"obs": [2]}, {"labels": ["b"]}))
90+
91+
92+
def test_select_pop():
93+
obs = torch.randn(100, 10)
94+
act = torch.randn(100, 3)
95+
dataset = _get_data_proto(tensors={"obs": obs, "act": act}, meta_info={"p": 1, "q": 2})
96+
selected_dataset = dataset.select(batch_keys=["obs"], meta_info_keys=["p"])
97+
98+
assert selected_dataset.batch.keys() == {"obs"}
99+
assert selected_dataset.meta_info.keys() == {"p"}
100+
assert dataset.batch.keys() == {"obs", "act"}
101+
assert dataset.meta_info.keys() == {"p", "q"}
102+
103+
popped_dataset = dataset.pop(batch_keys=["obs"], meta_info_keys=["p"])
104+
assert popped_dataset.batch.keys() == {"obs"}
105+
assert popped_dataset.meta_info.keys() == {"p"}
106+
assert dataset.batch.keys() == {"act"}
107+
assert dataset.meta_info.keys() == {"q"}
108+
109+
110+
def test_chunk_concat_split():
111+
data = _get_data_proto()
112+
with pytest.raises(AssertionError):
113+
data.chunk(5)
114+
115+
chunked_data = data.chunk(2)
116+
117+
assert len(chunked_data) == 2
118+
expected_data = _get_data_proto({"obs": [1, 2, 3]}, {"labels": ["a", "b", "c"]})
119+
_assert_equal(chunked_data[0], expected_data)
120+
121+
concat_data = DataProto.concat(chunked_data)
122+
_assert_equal(concat_data, data)
123+
124+
splitted_data = data.split(2)
125+
assert len(splitted_data) == 3
126+
expected_data = _get_data_proto({"obs": [1, 2]}, {"labels": ["a", "b"]})
127+
_assert_equal(splitted_data[0], expected_data)
128+
129+
130+
def test_reorder():
131+
data = _get_data_proto()
132+
data.reorder(torch.tensor([3, 4, 2, 0, 1, 5]))
133+
expected_data = _get_data_proto({"obs": [4, 5, 3, 1, 2, 6]}, {"labels": ["d", "e", "c", "a", "b", "f"]})
134+
_assert_equal(data, expected_data)
135+
136+
137+
@pytest.mark.parametrize("interleave", [True, False])
138+
def test_repeat(interleave: bool):
139+
data = _get_data_proto({"obs": [1, 2]}, {"labels": ["a", "b"]})
140+
repeated_data = data.repeat(repeat_times=2, interleave=interleave)
141+
expected_tensors = {"obs": [1, 1, 2, 2] if interleave else [1, 2, 1, 2]}
142+
expected_non_tensors = {"labels": ["a", "a", "b", "b"] if interleave else ["a", "b", "a", "b"]}
143+
_assert_equal(repeated_data, _get_data_proto(expected_tensors, expected_non_tensors))
144+
145+
146+
@pytest.mark.parametrize("size_divisor", [2, 3])
147+
def test_dataproto_pad_unpad(size_divisor: int):
148+
data = _get_data_proto({"obs": [1, 2, 3]}, {"labels": ["a", "b", "c"]})
149+
# test size_divisor=2
150+
padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=size_divisor)
151+
unpadded_data = unpad_dataproto(padded_data, pad_size=pad_size)
152+
153+
if size_divisor == 2:
154+
assert pad_size == 1
155+
expected_tensors = {"obs": [1, 2, 3, 1]}
156+
expected_non_tensors = {"labels": ["a", "b", "c", "a"]}
157+
expected_data = _get_data_proto(expected_tensors, expected_non_tensors)
158+
else:
159+
assert pad_size == 0
160+
expected_data = data
161+
162+
_assert_equal(padded_data, expected_data)
163+
_assert_equal(unpadded_data, data)
164+
165+
166+
def test_data_proto_save_load():
167+
data = _get_data_proto()
168+
data.save_to_disk("test_data.pt")
169+
loaded_data = DataProto.load_from_disk("test_data.pt")
170+
os.remove("test_data.pt")
171+
_assert_equal(data, loaded_data)
172+
173+
174+
def test_union_tensor_dict():
175+
obs = torch.randn(100, 10)
176+
data1 = _get_data_proto({"obs": obs, "act": torch.randn(100, 3)})
177+
data2 = _get_data_proto({"obs": obs, "rew": torch.randn(100)})
178+
data1.union(data2)
179+
180+
data1 = _get_data_proto({"obs": obs, "act": torch.randn(100, 3)})
181+
data2 = _get_data_proto({"obs": obs + 1, "rew": torch.randn(100)})
182+
with pytest.raises(ValueError):
183+
data1.union(data2)

0 commit comments

Comments
 (0)