Skip to content
31 changes: 12 additions & 19 deletions verl/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,14 @@ def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str
return tensor_dict1


def numpy_dict_to_tensor_dict(numpy_dict: dict[str, np.ndarray]) -> TensorDict:
"""Convert a dictionary of numpy arrays to a tensordict"""
tensor_dict = tensordict.TensorDict()
for key, val in numpy_dict.items():
tensor_dict[key] = torch.from_numpy(val)
return tensor_dict


def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
if len(list_of_dict) == 0:
return {}
Expand Down Expand Up @@ -331,28 +339,13 @@ def __getitem__(self, item):
raise TypeError(f"Indexing with {type(item)} is not supported")

def __getstate__(self):
import io

buffer = io.BytesIO()
if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None:
batch_to_save = self.batch.contiguous().consolidate()
else:
batch_to_save = self.batch
torch.save(batch_to_save, buffer)
buffer_bytes = buffer.getvalue()
return buffer_bytes, self.non_tensor_batch, self.meta_info
return pickle.dumps(self.batch.numpy()), self.non_tensor_batch, self.meta_info
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current implementation of __getstate__ will raise an AttributeError if self.batch is None, as it would attempt to call .numpy() on None. DataProto objects can be initialized with batch=None, so this case must be handled to prevent crashes during serialization.

Suggested change
def __getstate__(self):
import io
buffer = io.BytesIO()
if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None:
batch_to_save = self.batch.contiguous().consolidate()
else:
batch_to_save = self.batch
torch.save(batch_to_save, buffer)
buffer_bytes = buffer.getvalue()
return buffer_bytes, self.non_tensor_batch, self.meta_info
return pickle.dumps(self.batch.numpy()), self.non_tensor_batch, self.meta_info
def __getstate__(self):
batch_bytes = pickle.dumps(self.batch.numpy()) if self.batch is not None else None
return batch_bytes, self.non_tensor_batch, self.meta_info


def __setstate__(self, data):
import io

batch_deserialized_bytes, non_tensor_batch, meta_info = data
batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)
batch = torch.load(
batch_deserialized,
weights_only=False,
map_location="cpu" if not get_torch_device().is_available() else None,
)
self.batch = batch
batch_deserialized_bytes = pickle.loads(batch_deserialized_bytes)

self.batch = numpy_dict_to_tensor_dict(batch_deserialized_bytes)
self.non_tensor_batch = non_tensor_batch
self.meta_info = meta_info
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This implementation of __setstate__ does not correctly handle deserialization when the original self.batch was None. Assuming __getstate__ is fixed to handle the None case, batch_deserialized_bytes could be None, which would cause pickle.loads(None) to raise a TypeError. The logic should check for None before attempting to deserialize.

    def __setstate__(self, data):
        batch_deserialized_bytes, non_tensor_batch, meta_info = data
        if batch_deserialized_bytes is not None:
            numpy_dict = pickle.loads(batch_deserialized_bytes)
            self.batch = numpy_dict_to_tensor_dict(numpy_dict)
        else:
            self.batch = None
        self.non_tensor_batch = non_tensor_batch
        self.meta_info = meta_info

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The new serialization mechanism loses the device information of the tensors in self.batch. The previous implementation using torch.load had a map_location argument that provided control over the device on which tensors were loaded. The new implementation always deserializes tensors to the CPU because torch.from_numpy creates CPU tensors. This is a functional regression that could lead to performance degradation from unnecessary device transfers (e.g., GPU -> CPU -> GPU) or errors if downstream code expects tensors on a specific device.

To address this, you should store the device in __getstate__ and use it in __setstate__ to restore the tensors to their original device.

Here is an example of how you could modify both methods to preserve device information (this also includes the fix for the None batch case):

In __getstate__:

def __getstate__(self):
    device = str(self.batch.device) if self.batch is not None else None
    batch_bytes = pickle.dumps(self.batch.numpy()) if self.batch is not None else None
    return batch_bytes, self.non_tensor_batch, self.meta_info, device

In __setstate__:

def __setstate__(self, data):
    # Handle both old and new format for backward compatibility
    if len(data) == 4:
        batch_bytes, non_tensor_batch, meta_info, device = data
    else:
        batch_bytes, non_tensor_batch, meta_info = data
        device = None

    if batch_bytes is not None:
        numpy_dict = pickle.loads(batch_bytes)
        self.batch = numpy_dict_to_tensor_dict(numpy_dict)
        if device and device != "cpu":
            self.batch = self.batch.to(device)
    else:
        self.batch = None
    
    self.non_tensor_batch = non_tensor_batch
    self.meta_info = meta_info


Expand Down
Loading