
Commit 05c4364

support nested tensor
1 parent 6cce18c

2 files changed: +258 -38 lines changed

tests/test_protocol_on_cpu.py

Lines changed: 199 additions & 1 deletion

@@ -22,7 +22,14 @@
 from tensordict import TensorDict

 from verl import DataProto
-from verl.protocol import union_numpy_dict, union_tensor_dict
+from verl.protocol import (
+    deserialize_single_tensor,
+    deserialize_tensordict,
+    serialize_single_tensor,
+    serialize_tensordict,
+    union_numpy_dict,
+    union_tensor_dict,
+)


 def test_union_tensor_dict():
@@ -614,3 +621,194 @@ def test_to_tensordict():
     assert torch.all(torch.eq(output["obs"], obs)).item()
     assert output["labels"] == labels
     assert output["name"] == "abdce"
+
+
+def test_serialize_deserialize_single_tensor():
+    """Test serialization and deserialization of a single tensor"""
+    # Create test tensor
+    original_tensor = torch.randn(3, 4, 5)
+
+    # Serialize
+    dtype, shape, data = serialize_single_tensor(original_tensor)
+
+    # Deserialize
+    reconstructed_tensor = deserialize_single_tensor((dtype, shape, data))
+
+    # Verify results
+    assert torch.allclose(original_tensor, reconstructed_tensor)
+    assert original_tensor.shape == reconstructed_tensor.shape
+    assert original_tensor.dtype == reconstructed_tensor.dtype
+
+
+def test_serialize_deserialize_tensordict_regular_tensors():
+    """Test serialization and deserialization of TensorDict with regular tensors"""
+    # Create test data
+    batch_size = (5, 3)
+    tensor1 = torch.randn(*batch_size, 4)
+    tensor2 = torch.randint(0, 10, (*batch_size, 2))
+
+    # Create TensorDict
+    original_tensordict = TensorDict({"tensor1": tensor1, "tensor2": tensor2}, batch_size=batch_size)
+
+    # Serialize
+    batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict)
+
+    # Deserialize
+    reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device, encoded_items))
+
+    # Verify results
+    assert original_tensordict.batch_size == reconstructed_tensordict.batch_size
+    assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())
+
+    for key in original_tensordict.keys():
+        original_tensor = original_tensordict[key]
+        reconstructed_tensor = reconstructed_tensordict[key]
+
+        assert torch.allclose(original_tensor, reconstructed_tensor)
+        assert original_tensor.shape == reconstructed_tensor.shape
+        assert original_tensor.dtype == reconstructed_tensor.dtype
+
+
+def test_serialize_deserialize_tensordict_nested_tensors():
+    """Test serialization and deserialization of TensorDict with nested tensors"""
+    # Create nested tensor
+    tensor_list = [torch.randn(2, 3), torch.randn(3, 4), torch.randn(1, 5)]
+    nested_tensor = torch.nested.as_nested_tensor(tensor_list)
+
+    # Create regular tensor for comparison
+    regular_tensor = torch.randn(3, 4, 5)
+
+    # Create TensorDict
+    original_tensordict = TensorDict({"nested": nested_tensor, "regular": regular_tensor}, batch_size=(3,))
+
+    # Serialize
+    batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict)
+
+    # Deserialize
+    reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device, encoded_items))
+
+    # Verify results
+    assert original_tensordict.batch_size == reconstructed_tensordict.batch_size
+    assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())
+
+    # Verify regular tensor
+    original_regular = original_tensordict["regular"]
+    reconstructed_regular = reconstructed_tensordict["regular"]
+
+    assert torch.allclose(original_regular, reconstructed_regular)
+    assert original_regular.shape == reconstructed_regular.shape
+    assert original_regular.dtype == reconstructed_regular.dtype
+
+    # Verify nested tensor
+    original_nested = original_tensordict["nested"]
+    reconstructed_nested = reconstructed_tensordict["nested"]
+
+    # Check if it's a nested tensor
+    assert original_nested.is_nested
+    assert reconstructed_nested.is_nested
+
+    # Check layout
+    assert original_nested.layout == reconstructed_nested.layout
+
+    # Check each tensor after unbinding
+    original_unbind = original_nested.unbind()
+    reconstructed_unbind = reconstructed_nested.unbind()
+
+    assert len(original_unbind) == len(reconstructed_unbind)
+
+    for orig, recon in zip(original_unbind, reconstructed_unbind, strict=False):
+        assert torch.allclose(orig, recon)
+        assert orig.shape == recon.shape
+        assert orig.dtype == recon.dtype
+
+
+def test_serialize_deserialize_tensordict_mixed_types():
+    """Test serialization and deserialization of TensorDict with mixed tensor types"""
+    # Create tensors with different data types
+    float_tensor = torch.randn(2, 3).float()
+    double_tensor = torch.randn(2, 3).double()
+    int_tensor = torch.randint(0, 10, (2, 3)).int()
+    long_tensor = torch.randint(0, 10, (2, 3)).long()
+    bool_tensor = torch.tensor([[True, False], [False, True]])
+
+    # Create nested tensor
+    tensor_list = [
+        torch.randn(2, 3),
+        torch.randn(3, 4),
+    ]
+    nested_tensor = torch.nested.as_nested_tensor(tensor_list)
+
+    # Create TensorDict
+    original_tensordict = TensorDict(
+        {
+            "float": float_tensor,
+            "double": double_tensor,
+            "int": int_tensor,
+            "long": long_tensor,
+            "bool": bool_tensor,
+            "nested": nested_tensor,
+        },
+        batch_size=(2,),
+    )
+
+    # Serialize
+    batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict)
+
+    # Deserialize
+    reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device, encoded_items))
+
+    # Verify results
+    assert original_tensordict.batch_size == reconstructed_tensordict.batch_size
+    assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())
+
+    for key in original_tensordict.keys():
+        original_tensor = original_tensordict[key]
+        reconstructed_tensor = reconstructed_tensordict[key]
+
+        if original_tensor.is_nested:
+            # For nested tensors, check each tensor after unbinding
+            original_unbind = original_tensor.unbind()
+            reconstructed_unbind = reconstructed_tensor.unbind()
+
+            assert len(original_unbind) == len(reconstructed_unbind)
+
+            for orig, recon in zip(original_unbind, reconstructed_unbind, strict=False):
+                assert torch.allclose(orig, recon, equal_nan=True)
+                assert orig.shape == recon.shape
+                assert orig.dtype == recon.dtype
+        else:
+            # For regular tensors, compare directly
+            assert torch.allclose(original_tensor, reconstructed_tensor, equal_nan=True)
+            assert original_tensor.shape == reconstructed_tensor.shape
+            assert original_tensor.dtype == reconstructed_tensor.dtype
+
+
+def test_serialize_deserialize_tensordict_with_device():
+    """Test serialization and deserialization of TensorDict with device information"""
+    # Create test data
+    batch_size = (2, 3)
+    tensor1 = torch.randn(*batch_size, 4)
+    tensor2 = torch.randint(0, 10, (*batch_size, 2))
+
+    # Create TensorDict with device information
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    original_tensordict = TensorDict({"tensor1": tensor1, "tensor2": tensor2}, batch_size=batch_size, device=device)
+
+    # Serialize
+    batch_size_serialized, device_serialized, encoded_items = serialize_tensordict(original_tensordict)
+
+    # Deserialize
+    reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device_serialized, encoded_items))
+
+    # Verify results
+    assert original_tensordict.batch_size == reconstructed_tensordict.batch_size
+    assert str(original_tensordict.device) == str(reconstructed_tensordict.device)
+    assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())
+
+    for key in original_tensordict.keys():
+        original_tensor = original_tensordict[key]
+        reconstructed_tensor = reconstructed_tensordict[key]
+
+        assert torch.allclose(original_tensor.cpu(), reconstructed_tensor.cpu())
+        assert original_tensor.shape == reconstructed_tensor.shape
+        assert original_tensor.dtype == reconstructed_tensor.dtype

verl/protocol.py

Lines changed: 59 additions & 37 deletions

@@ -249,6 +249,61 @@ def unfold_batch_dim(data: "DataProto", batch_dims=2):
     return type(data)(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info)


+def serialize_single_tensor(obj: torch.Tensor) -> tuple[str, tuple[int, ...], int | memoryview]:
+    data = obj.flatten().contiguous().view(torch.uint8).numpy()
+    dtype = str(obj.dtype).removeprefix("torch.")
+    return dtype, obj.shape, data
+
+
+def serialize_tensordict(batch: TensorDict) -> tuple[tuple[int, ...], Optional[str], dict[str, tuple[str, Any]]]:
+    encoded_items: dict[str, tuple[Any]] = {}
+    for k, v in batch.items():
+        if not v.is_nested:
+            encoded_items[k] = serialize_single_tensor(v)
+        else:
+            layout = str(v.layout).removeprefix("torch.")
+            data = [serialize_single_tensor(tensor) for tensor in v.unbind()]
+            encoded_items[k] = (layout, data)
+
+    batch_size = tuple(batch.batch_size)
+    device = str(batch.device) if batch.device is not None else None
+    return batch_size, device, encoded_items
+
+
+def deserialize_single_tensor(arr: Any) -> torch.Tensor:
+    dtype, shape, data = arr
+
+    torch_dtype = getattr(torch, dtype)
+    assert isinstance(torch_dtype, torch.dtype)
+
+    buffer = bytearray(data)
+    # Create uint8 array
+    arr = torch.frombuffer(buffer, dtype=torch.uint8)
+    # Convert back to proper shape & type
+    return arr.view(torch_dtype).view(shape)
+
+
+def deserialize_tensordict(arr: Any) -> TensorDict:
+    batch_size, device, encoded_items = arr
+    decoded_items: dict[str, Any] = {}
+
+    for k, v in encoded_items.items():
+        if len(v) == 3:
+            # decode single tensor
+            decoded_items[k] = deserialize_single_tensor(v)
+        elif len(v) == 2:
+            # decode nested tensor
+            layout, data = v
+            torch_layout = getattr(torch, layout)
+            decoded_items[k] = torch.nested.as_nested_tensor(
+                [deserialize_single_tensor(tensor) for tensor in data], layout=torch_layout
+            )
+        else:
+            raise ValueError(f"Invalid tensor encoding format, expected length 2 or 3, got {len(v)}")
+
+    return TensorDict(source=decoded_items, batch_size=batch_size, device=device)
+
+
 def collate_fn(x: list["DataProtoItem"]):
     batch = []
     non_tensor_batch = []
@@ -338,28 +393,10 @@ def __getstate__(self):

         if os.getenv("VERL_DATAPROTO_SERIALIZATION_METHOD") == "numpy":
             if batch is not None:
-                dtypes = {}
-                batch_to_serialize = {}
-                for k, v in batch.items():
-                    dtypes[k] = str(v.dtype).removeprefix("torch.")
-                    if v.dtype == torch.bfloat16:
-                        batch_to_serialize[k] = v.view(torch.uint8).numpy()
-                    else:
-                        batch_to_serialize[k] = v.numpy()
-                batch_size = batch.batch_size
-            else:
-                dtypes = None
-                batch_to_serialize = None
-                batch_size = None
+                batch = serialize_tensordict(self.batch)

         return (
-            pickle.dumps(
-                {
-                    "batch_size": batch_size,
-                    "dtypes": dtypes,
-                    "data": batch_to_serialize,
-                }
-            ),
+            batch,
             self.non_tensor_batch,
             self.meta_info,
         )
@@ -375,23 +412,8 @@ def __setstate__(self, data):
         batch_deserialized_bytes, non_tensor_batch, meta_info = data

         if os.getenv("VERL_DATAPROTO_SERIALIZATION_METHOD") == "numpy":
-            batch_deserialized = pickle.loads(batch_deserialized_bytes)
-
-            numpy_dict = batch_deserialized["data"]
-            batch_size = batch_deserialized["batch_size"]
-            dtypes = batch_deserialized["dtypes"]
-            if numpy_dict is not None:
-                tensor_dict = {}
-                for k, v in numpy_dict.items():
-                    dtype = dtypes[k]
-                    if dtype == "bfloat16":
-                        tensor_dict[k] = torch.from_numpy(v).view(getattr(torch, dtype))
-                    else:
-                        tensor_dict[k] = torch.from_numpy(v)
-                self.batch = TensorDict(
-                    tensor_dict,
-                    batch_size=batch_size,
-                )
+            if batch_deserialized_bytes is not None:
+                self.batch = deserialize_tensordict(batch_deserialized_bytes)
             else:
                 self.batch = None
         else:

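For reference, a minimal round-trip sketch of the helpers this commit adds (not part of the diff; it only assumes the serialize_tensordict / deserialize_tensordict functions defined in verl/protocol.py above):

# Illustrative sketch: round-trip a TensorDict that mixes a nested (ragged)
# tensor with a dense one through the new helpers.
import torch
from tensordict import TensorDict

from verl.protocol import deserialize_tensordict, serialize_tensordict

# Two samples with different sequence lengths -> nested tensor.
ragged = torch.nested.as_nested_tensor([torch.randn(2, 4), torch.randn(5, 4)])
td = TensorDict({"ragged": ragged, "dense": torch.randn(2, 3)}, batch_size=(2,))

# serialize_tensordict returns (batch_size, device, encoded_items): dense
# tensors are encoded as (dtype, shape, uint8 buffer) triples, nested tensors
# as (layout, [per-component triples]) pairs, so the payload is plain Python
# containers plus numpy uint8 buffers.
payload = serialize_tensordict(td)
restored = deserialize_tensordict(payload)

assert restored["ragged"].is_nested
assert torch.allclose(restored["dense"], td["dense"])

With VERL_DATAPROTO_SERIALIZATION_METHOD=numpy set, DataProto.__getstate__ and __setstate__ route the batch through this same pair of helpers, which is how nested tensors now survive pickling.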