-
Notifications
You must be signed in to change notification settings - Fork 2.4k
[ray] refactor: Accelerate Tensor serialization by converting to np.ndarray #3425
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
0142225
cfa3bcd
9843859
70d84c8
183b9f1
fcd8e92
67f6697
ffadb28
a7f1389
9252ab3
bece1a1
6cce18c
05c4364
1317efc
16e793d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -198,6 +198,14 @@ def union_numpy_dict(tensor_dict1: dict[str, np.ndarray], tensor_dict2: dict[str | |
| return tensor_dict1 | ||
|
|
||
|
|
||
| def numpy_dict_to_tensor_dict(numpy_dict: dict[str, np.ndarray]) -> TensorDict: | ||
| """Convert a dictionary of numpy arrays to a tensordict""" | ||
| tensor_dict = tensordict.TensorDict() | ||
| for key, val in numpy_dict.items(): | ||
| tensor_dict[key] = torch.from_numpy(val) | ||
| return tensor_dict | ||
|
|
||
|
|
||
| def list_of_dict_to_dict_of_list(list_of_dict: list[dict]): | ||
| if len(list_of_dict) == 0: | ||
| return {} | ||
|
|
@@ -331,28 +339,13 @@ def __getitem__(self, item): | |
| raise TypeError(f"Indexing with {type(item)} is not supported") | ||
|
|
||
| def __getstate__(self): | ||
| import io | ||
|
|
||
| buffer = io.BytesIO() | ||
| if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None: | ||
| batch_to_save = self.batch.contiguous().consolidate() | ||
| else: | ||
| batch_to_save = self.batch | ||
| torch.save(batch_to_save, buffer) | ||
| buffer_bytes = buffer.getvalue() | ||
| return buffer_bytes, self.non_tensor_batch, self.meta_info | ||
| return pickle.dumps(self.batch.numpy()), self.non_tensor_batch, self.meta_info | ||
|
|
||
| def __setstate__(self, data): | ||
| import io | ||
|
|
||
| batch_deserialized_bytes, non_tensor_batch, meta_info = data | ||
| batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) | ||
| batch = torch.load( | ||
| batch_deserialized, | ||
| weights_only=False, | ||
| map_location="cpu" if not get_torch_device().is_available() else None, | ||
| ) | ||
| self.batch = batch | ||
| batch_deserialized_bytes = pickle.loads(batch_deserialized_bytes) | ||
|
|
||
| self.batch = numpy_dict_to_tensor_dict(batch_deserialized_bytes) | ||
| self.non_tensor_batch = non_tensor_batch | ||
| self.meta_info = meta_info | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

This implementation of `__setstate__` handles the case where the serialized batch is `None`:

def __setstate__(self, data):
batch_deserialized_bytes, non_tensor_batch, meta_info = data
if batch_deserialized_bytes is not None:
numpy_dict = pickle.loads(batch_deserialized_bytes)
self.batch = numpy_dict_to_tensor_dict(numpy_dict)
else:
self.batch = None
self.non_tensor_batch = non_tensor_batch
    self.meta_info = meta_info

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The new serialization mechanism loses the device information of the tensors in `self.batch`. To address this, you should store the device in `__getstate__` and restore it in `__setstate__`. Here is an example of how you could modify both methods to preserve device information (this also includes the fix for the `None` batch case). In `__getstate__`:

def __getstate__(self):
device = str(self.batch.device) if self.batch is not None else None
batch_bytes = pickle.dumps(self.batch.numpy()) if self.batch is not None else None
    return batch_bytes, self.non_tensor_batch, self.meta_info, device

In `__setstate__`:

def __setstate__(self, data):
# Handle both old and new format for backward compatibility
if len(data) == 4:
batch_bytes, non_tensor_batch, meta_info, device = data
else:
batch_bytes, non_tensor_batch, meta_info = data
device = None
if batch_bytes is not None:
numpy_dict = pickle.loads(batch_bytes)
self.batch = numpy_dict_to_tensor_dict(numpy_dict)
if device and device != "cpu":
self.batch = self.batch.to(device)
else:
self.batch = None
self.non_tensor_batch = non_tensor_batch
self.meta_info = meta_info |
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation of
`__getstate__` will raise an `AttributeError` if `self.batch` is `None`, as it would attempt to call `.numpy()` on `None`. `DataProto` objects can be initialized with `batch=None`, so this case must be handled to prevent crashes during serialization.