Skip to content

Commit 9843859

Browse files
committed
fix bug of batch_size omission
1 parent cfa3bcd commit 9843859

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

verl/protocol.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,16 @@ def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str
198198
return tensor_dict1
199199

200200

201-
def numpy_dict_to_tensor_dict(numpy_dict: dict[str, np.ndarray]) -> TensorDict:
201+
def numpy_dict_to_tensor_dict(numpy_dict: dict[str, np.ndarray], batch_size=None) -> TensorDict:
202202
"""Convert a dictionary of numpy arrays to a tensordict"""
203-
tensor_dict = tensordict.TensorDict()
203+
tensor_dict = {}
204204
for key, val in numpy_dict.items():
205-
tensor_dict[key] = torch.from_numpy(val)
206-
return tensor_dict
205+
if isinstance(val, np.ndarray):
206+
tensor_dict[key] = torch.from_numpy(val)
207+
else:
208+
tensor_dict[key] = val
209+
210+
return TensorDict(tensor_dict, batch_size=batch_size)
207211

208212

209213
def list_of_dict_to_dict_of_list(list_of_dict: list[dict]):
@@ -339,18 +343,13 @@ def __getitem__(self, item):
339343
raise TypeError(f"Indexing with {type(item)} is not supported")
340344

341345
def __getstate__(self):
342-
return pickle.dumps(self.batch.numpy()), self.non_tensor_batch, self.meta_info
346+
return pickle.dumps(self.batch.numpy()), self.batch.batch_size, self.non_tensor_batch, self.meta_info
343347

344348
def __setstate__(self, data):
345-
batch_deserialized_bytes, non_tensor_batch, meta_info = data
349+
batch_deserialized_bytes, batch_size, non_tensor_batch, meta_info = data
346350
batch_deserialized = pickle.loads(batch_deserialized_bytes)
347-
348-
tensor_dict = torch.utils._pytree.tree_map(
349-
lambda x: torch.from_numpy(x) if isinstance(x, np.ndarray) else x,
350-
batch_deserialized
351-
)
352-
353-
self.batch = TensorDict.from_dict(tensor_dict)
351+
352+
self.batch = numpy_dict_to_tensor_dict(batch_deserialized, batch_size=batch_size)
354353
self.non_tensor_batch = non_tensor_batch
355354
self.meta_info = meta_info
356355

0 commit comments

Comments
 (0)