We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ae28570 commit 083da9aCopy full SHA for 083da9a
verl/protocol.py
@@ -332,9 +332,10 @@ def __getstate__(self):
332
333
buffer = io.BytesIO()
334
if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None:
335
- self.batch = self.batch.contiguous()
336
- self.batch = self.batch.consolidate()
337
- torch.save(self.batch, buffer)
+ batch_to_save = self.batch.contiguous().consolidate()
+ else:
+ batch_to_save = self.batch
338
+ torch.save(batch_to_save, buffer)
339
buffer_bytes = buffer.getvalue()
340
return buffer_bytes, self.non_tensor_batch, self.meta_info
341
0 commit comments