Skip to content
224 changes: 223 additions & 1 deletion tests/test_protocol_on_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,14 @@
from tensordict import TensorDict

from verl import DataProto
from verl.protocol import union_numpy_dict, union_tensor_dict
from verl.protocol import (
deserialize_single_tensor,
deserialize_tensordict,
serialize_single_tensor,
serialize_tensordict,
union_numpy_dict,
union_tensor_dict,
)


def test_union_tensor_dict():
Expand Down Expand Up @@ -614,3 +621,218 @@ def test_to_tensordict():
assert torch.all(torch.eq(output["obs"], obs)).item()
assert output["labels"] == labels
assert output["name"] == "abdce"


def test_serialize_deserialize_single_tensor():
"""Test serialization and deserialization of a single tensor"""
# Create test tensor
original_tensor = torch.randn(3, 4, 5)

# Serialize
dtype, shape, data = serialize_single_tensor(original_tensor)

# Deserialize
reconstructed_tensor = deserialize_single_tensor((dtype, shape, data))

# Verify results
assert torch.allclose(original_tensor, reconstructed_tensor)
assert original_tensor.shape == reconstructed_tensor.shape
assert original_tensor.dtype == reconstructed_tensor.dtype


def test_serialize_deserialize_tensordict_regular_tensors():
"""Test serialization and deserialization of TensorDict with regular tensors"""
# Create test data
batch_size = (5, 3)
tensor1 = torch.randn(*batch_size, 4)
tensor2 = torch.randint(0, 10, (*batch_size, 2))

# Create TensorDict
original_tensordict = TensorDict({"tensor1": tensor1, "tensor2": tensor2}, batch_size=batch_size)

# Serialize
batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict)

# Deserialize
reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device, encoded_items))

# Verify results
assert original_tensordict.batch_size == reconstructed_tensordict.batch_size
assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())

for key in original_tensordict.keys():
original_tensor = original_tensordict[key]
reconstructed_tensor = reconstructed_tensordict[key]

assert torch.allclose(original_tensor, reconstructed_tensor)
assert original_tensor.shape == reconstructed_tensor.shape
assert original_tensor.dtype == reconstructed_tensor.dtype


def test_serialize_deserialize_tensordict_nested_tensors():
"""Test serialization and deserialization of TensorDict with nested tensors"""
# Create nested tensor
tensor_list = [torch.randn(2, 3), torch.randn(3, 4), torch.randn(1, 5)]
nested_tensor = torch.nested.as_nested_tensor(tensor_list)

# Create regular tensor for comparison
regular_tensor = torch.randn(3, 4, 5)

# Create TensorDict
original_tensordict = TensorDict({"nested": nested_tensor, "regular": regular_tensor}, batch_size=(3,))

# Serialize
batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict)

# Deserialize
reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device, encoded_items))

# Verify results
assert original_tensordict.batch_size == reconstructed_tensordict.batch_size
assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())

# Verify regular tensor
original_regular = original_tensordict["regular"]
reconstructed_regular = reconstructed_tensordict["regular"]

assert torch.allclose(original_regular, reconstructed_regular)
assert original_regular.shape == reconstructed_regular.shape
assert original_regular.dtype == reconstructed_regular.dtype

# Verify nested tensor
original_nested = original_tensordict["nested"]
reconstructed_nested = reconstructed_tensordict["nested"]

# Check if it's a nested tensor
assert original_nested.is_nested
assert reconstructed_nested.is_nested

# Check layout
assert original_nested.layout == reconstructed_nested.layout

# Check each tensor after unbinding
original_unbind = original_nested.unbind()
reconstructed_unbind = reconstructed_nested.unbind()

assert len(original_unbind) == len(reconstructed_unbind)

for orig, recon in zip(original_unbind, reconstructed_unbind, strict=False):
assert torch.allclose(orig, recon)
assert orig.shape == recon.shape
assert orig.dtype == recon.dtype


def test_serialize_deserialize_tensordict_mixed_types():
"""Test serialization and deserialization of TensorDict with mixed tensor types"""
# Create tensors with different data types
float_tensor = torch.randn(2, 3).float()
double_tensor = torch.randn(2, 3).double()
int_tensor = torch.randint(0, 10, (2, 3)).int()
long_tensor = torch.randint(0, 10, (2, 3)).long()
bool_tensor = torch.tensor([[True, False], [False, True]])
bfloat16_tensor = torch.randn(2, 3).bfloat16()

# Add fp8 tensor (if available)
# Note: FP8 is not natively supported in all PyTorch versions
# We'll check if it's available and conditionally include it
has_fp8 = hasattr(torch, "float8_e5m2") or hasattr(torch, "float8_e4m3fn")
if has_fp8:
try:
# Try to create an FP8 tensor (implementation may vary)
# This is a placeholder - actual FP8 support might require specific hardware
fp8_tensor = torch.randn(2, 3)
if hasattr(torch, "float8_e5m2"):
fp8_tensor = fp8_tensor.to(torch.float8_e5m2)
elif hasattr(torch, "float8_e4m3fn"):
fp8_tensor = fp8_tensor.to(torch.float8_e4m3fn)
except Exception:
has_fp8 = False

# Create nested tensor
tensor_list = [
torch.randn(2, 3),
torch.randn(3, 4),
]
nested_tensor = torch.nested.as_nested_tensor(tensor_list)

# Create TensorDict with all available types
tensordict_data = {
"float": float_tensor,
"double": double_tensor,
"int": int_tensor,
"long": long_tensor,
"bool": bool_tensor,
"bfloat16": bfloat16_tensor,
"nested": nested_tensor,
}

# Conditionally add fp8 tensor if available
if has_fp8:
tensordict_data["fp8"] = fp8_tensor

original_tensordict = TensorDict(
tensordict_data,
batch_size=(2,),
)

# Serialize
batch_size_serialized, device, encoded_items = serialize_tensordict(original_tensordict)

# Deserialize
reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device, encoded_items))

# Verify results
assert original_tensordict.batch_size == reconstructed_tensordict.batch_size
assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())

for key in original_tensordict.keys():
original_tensor = original_tensordict[key]
reconstructed_tensor = reconstructed_tensordict[key]

if original_tensor.is_nested:
# For nested tensors, check each tensor after unbinding
original_unbind = original_tensor.unbind()
reconstructed_unbind = reconstructed_tensor.unbind()

assert len(original_unbind) == len(reconstructed_unbind)

for orig, recon in zip(original_unbind, reconstructed_unbind, strict=False):
assert torch.allclose(orig, recon, equal_nan=True)
assert orig.shape == recon.shape
assert orig.dtype == recon.dtype
else:
# For regular tensors, compare directly
assert torch.all(original_tensor == reconstructed_tensor)
assert original_tensor.shape == reconstructed_tensor.shape
assert original_tensor.dtype == reconstructed_tensor.dtype


def test_serialize_deserialize_tensordict_with_device():
"""Test serialization and deserialization of TensorDict with device information"""
# Create test data
batch_size = (2, 3)
tensor1 = torch.randn(*batch_size, 4)
tensor2 = torch.randint(0, 10, (*batch_size, 2))

# Create TensorDict with device information
device = "cuda" if torch.cuda.is_available() else "cpu"
original_tensordict = TensorDict({"tensor1": tensor1, "tensor2": tensor2}, batch_size=batch_size, device=device)

# Serialize
batch_size_serialized, device_serialized, encoded_items = serialize_tensordict(original_tensordict)

# Deserialize
reconstructed_tensordict = deserialize_tensordict((batch_size_serialized, device_serialized, encoded_items))

# Verify results
assert original_tensordict.batch_size == reconstructed_tensordict.batch_size
assert str(original_tensordict.device) == str(reconstructed_tensordict.device)
assert set(original_tensordict.keys()) == set(reconstructed_tensordict.keys())

for key in original_tensordict.keys():
original_tensor = original_tensordict[key]
reconstructed_tensor = reconstructed_tensordict[key]

assert torch.allclose(original_tensor.cpu(), reconstructed_tensor.cpu())
assert original_tensor.shape == reconstructed_tensor.shape
assert original_tensor.dtype == reconstructed_tensor.dtype
108 changes: 91 additions & 17 deletions verl/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,61 @@ def unfold_batch_dim(data: "DataProto", batch_dims=2):
return type(data)(batch=tensor, non_tensor_batch=non_tensor_new, meta_info=data.meta_info)


def serialize_single_tensor(obj: torch.Tensor) -> tuple[str, tuple[int, ...], int | memoryview]:
data = obj.flatten().contiguous().view(torch.uint8).numpy()
dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data


def serialize_tensordict(batch: TensorDict) -> tuple[tuple[int, ...], Optional[str], dict[str, tuple[str, Any]]]:
encoded_items: dict[str, tuple[Any]] = {}
for k, v in batch.items():
if not v.is_nested:
encoded_items[k] = serialize_single_tensor(v)
else:
layout = str(v.layout).removeprefix("torch.")
data = [serialize_single_tensor(tensor) for tensor in v.unbind()]
encoded_items[k] = (layout, data)

batch_size = tuple(batch.batch_size)
device = str(batch.device) if batch.device is not None else None
return batch_size, device, encoded_items


def deserialize_single_tensor(arr: Any) -> torch.Tensor:
dtype, shape, data = arr

torch_dtype = getattr(torch, dtype)
assert isinstance(torch_dtype, torch.dtype)

buffer = bytearray(data)
# Create uint8 array
arr = torch.frombuffer(buffer, dtype=torch.uint8)
# Convert back to proper shape & type
return arr.view(torch_dtype).view(shape)


def deserialize_tensordict(arr: Any) -> TensorDict:
batch_size, device, encoded_items = arr
decoded_items: dict[str, Any] = {}

for k, v in encoded_items.items():
if len(v) == 3:
# decode single tensor
decoded_items[k] = deserialize_single_tensor(v)
elif len(v) == 2:
# decode nested tensor
layout, data = v
torch_layout = getattr(torch, layout)
decoded_items[k] = torch.nested.as_nested_tensor(
[deserialize_single_tensor(tensor) for tensor in data], layout=torch_layout
)
else:
raise ValueError(f"Invalid tensor encoding format, expected length 2 or 3, got {len(v)}")

return TensorDict(source=decoded_items, batch_size=batch_size, device=device)


def collate_fn(x: list["DataProtoItem"]):
batch = []
non_tensor_batch = []
Expand Down Expand Up @@ -331,28 +386,47 @@ def __getitem__(self, item):
raise TypeError(f"Indexing with {type(item)} is not supported")

def __getstate__(self):
import io

buffer = io.BytesIO()
if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None:
batch_to_save = self.batch.contiguous().consolidate()
batch = self.batch.contiguous().consolidate()
else:
batch_to_save = self.batch
torch.save(batch_to_save, buffer)
buffer_bytes = buffer.getvalue()
return buffer_bytes, self.non_tensor_batch, self.meta_info
batch = self.batch

def __setstate__(self, data):
import io
if os.getenv("VERL_DATAPROTO_SERIALIZATION_METHOD") == "numpy":
if batch is not None:
batch = serialize_tensordict(self.batch)

return (
batch,
self.non_tensor_batch,
self.meta_info,
)
else:
import io

buffer = io.BytesIO()
torch.save(batch, buffer)
buffer_bytes = buffer.getvalue()
return buffer_bytes, self.non_tensor_batch, self.meta_info

def __setstate__(self, data):
batch_deserialized_bytes, non_tensor_batch, meta_info = data
batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)
batch = torch.load(
batch_deserialized,
weights_only=False,
map_location="cpu" if not get_torch_device().is_available() else None,
)
self.batch = batch

if os.getenv("VERL_DATAPROTO_SERIALIZATION_METHOD") == "numpy":
if batch_deserialized_bytes is not None:
self.batch = deserialize_tensordict(batch_deserialized_bytes)
else:
self.batch = None
else:
import io

batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)
batch = torch.load(
batch_deserialized,
weights_only=False,
map_location="cpu" if not get_torch_device().is_available() else None,
)
self.batch = batch

self.non_tensor_batch = non_tensor_batch
self.meta_info = meta_info

Expand Down