@@ -198,12 +198,16 @@ def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str
198198 return tensor_dict1
199199
200200
201- def numpy_dict_to_tensor_dict (numpy_dict : dict [str , np .ndarray ]) -> TensorDict :
201+ def numpy_dict_to_tensor_dict (numpy_dict : dict [str , np .ndarray ], batch_size = None ) -> TensorDict :
202202 """Convert a dictionary of numpy arrays to a tensordict"""
203- tensor_dict = tensordict . TensorDict ()
203+ tensor_dict = {}
204204 for key , val in numpy_dict .items ():
205- tensor_dict [key ] = torch .from_numpy (val )
206- return tensor_dict
205+ if isinstance (val , np .ndarray ):
206+ tensor_dict [key ] = torch .from_numpy (val )
207+ else :
208+ tensor_dict [key ] = val
209+
210+ return TensorDict (tensor_dict , batch_size = batch_size )
207211
208212
209213def list_of_dict_to_dict_of_list (list_of_dict : list [dict ]):
@@ -339,18 +343,13 @@ def __getitem__(self, item):
339343 raise TypeError (f"Indexing with { type (item )} is not supported" )
340344
341345 def __getstate__ (self ):
342- return pickle .dumps (self .batch .numpy ()), self .non_tensor_batch , self .meta_info
346+ return pickle .dumps (self .batch .numpy ()), self .batch . batch_size , self . non_tensor_batch , self .meta_info
343347
344348 def __setstate__ (self , data ):
345- batch_deserialized_bytes , non_tensor_batch , meta_info = data
349+ batch_deserialized_bytes , batch_size , non_tensor_batch , meta_info = data
346350 batch_deserialized = pickle .loads (batch_deserialized_bytes )
347-
348- tensor_dict = torch .utils ._pytree .tree_map (
349- lambda x : torch .from_numpy (x ) if isinstance (x , np .ndarray ) else x ,
350- batch_deserialized
351- )
352-
353- self .batch = TensorDict .from_dict (tensor_dict )
351+
352+ self .batch = numpy_dict_to_tensor_dict (batch_deserialized , batch_size = batch_size )
354353 self .non_tensor_batch = non_tensor_batch
355354 self .meta_info = meta_info
356355
0 commit comments