Skip to content

Commit 083da9a

Browse files
authored
[misc] fix: fix DataProto __getstate__ bug (#2962)
1 parent ae28570 commit 083da9a

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

verl/protocol.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,10 @@ def __getstate__(self):
332332

333333
buffer = io.BytesIO()
334334
if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None:
335-
self.batch = self.batch.contiguous()
336-
self.batch = self.batch.consolidate()
337-
torch.save(self.batch, buffer)
335+
batch_to_save = self.batch.contiguous().consolidate()
336+
else:
337+
batch_to_save = self.batch
338+
torch.save(batch_to_save, buffer)
338339
buffer_bytes = buffer.getvalue()
339340
return buffer_bytes, self.non_tensor_batch, self.meta_info
340341

0 commit comments

Comments
 (0)