-
Notifications
You must be signed in to change notification settings - Fork 2.4k
[ray] refactor: Accelerate Tensor serialization by converting to np.ndarray #3425
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request aims to optimize serialization performance by switching from torch.save to using numpy and pickle. While the benchmarks show a significant improvement in serialization speed, the current implementation introduces a critical issue where serialization will fail if a DataProto object does not contain a tensor batch (self.batch is None). Furthermore, the new approach loses the device information of the tensors, causing all deserialized tensors to be on the CPU. This is a functional regression from the previous implementation and can lead to performance issues or errors in a multi-device environment. My review includes critical fixes for the crash and a detailed suggestion to preserve device information.
verl/protocol.py
Outdated
| def __getstate__(self): | ||
| import io | ||
|
|
||
| buffer = io.BytesIO() | ||
| if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None: | ||
| batch_to_save = self.batch.contiguous().consolidate() | ||
| else: | ||
| batch_to_save = self.batch | ||
| torch.save(batch_to_save, buffer) | ||
| buffer_bytes = buffer.getvalue() | ||
| return buffer_bytes, self.non_tensor_batch, self.meta_info | ||
| return pickle.dumps(self.batch.numpy()), self.non_tensor_batch, self.meta_info |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation of __getstate__ will raise an AttributeError if self.batch is None, as it would attempt to call .numpy() on None. DataProto objects can be initialized with batch=None, so this case must be handled to prevent crashes during serialization.
| def __getstate__(self): | |
| import io | |
| buffer = io.BytesIO() | |
| if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None: | |
| batch_to_save = self.batch.contiguous().consolidate() | |
| else: | |
| batch_to_save = self.batch | |
| torch.save(batch_to_save, buffer) | |
| buffer_bytes = buffer.getvalue() | |
| return buffer_bytes, self.non_tensor_batch, self.meta_info | |
| return pickle.dumps(self.batch.numpy()), self.non_tensor_batch, self.meta_info | |
| def __getstate__(self): | |
| batch_bytes = pickle.dumps(self.batch.numpy()) if self.batch is not None else None | |
| return batch_bytes, self.non_tensor_batch, self.meta_info | |
| def __setstate__(self, data): | ||
| import io | ||
|
|
||
| batch_deserialized_bytes, non_tensor_batch, meta_info = data | ||
| batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) | ||
| batch = torch.load( | ||
| batch_deserialized, | ||
| weights_only=False, | ||
| map_location="cpu" if not get_torch_device().is_available() else None, | ||
| ) | ||
| self.batch = batch | ||
| batch_deserialized_bytes = pickle.loads(batch_deserialized_bytes) | ||
|
|
||
| self.batch = numpy_dict_to_tensor_dict(batch_deserialized_bytes) | ||
| self.non_tensor_batch = non_tensor_batch | ||
| self.meta_info = meta_info |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This implementation of __setstate__ does not correctly handle deserialization when the original self.batch was None. Assuming __getstate__ is fixed to handle the None case, batch_deserialized_bytes could be None, which would cause pickle.loads(None) to raise a TypeError. The logic should check for None before attempting to deserialize.
def __setstate__(self, data):
batch_deserialized_bytes, non_tensor_batch, meta_info = data
if batch_deserialized_bytes is not None:
numpy_dict = pickle.loads(batch_deserialized_bytes)
self.batch = numpy_dict_to_tensor_dict(numpy_dict)
else:
self.batch = None
self.non_tensor_batch = non_tensor_batch
self.meta_info = meta_info| def __getstate__(self): | ||
| import io | ||
|
|
||
| buffer = io.BytesIO() | ||
| if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None: | ||
| batch_to_save = self.batch.contiguous().consolidate() | ||
| else: | ||
| batch_to_save = self.batch | ||
| torch.save(batch_to_save, buffer) | ||
| buffer_bytes = buffer.getvalue() | ||
| return buffer_bytes, self.non_tensor_batch, self.meta_info | ||
| return pickle.dumps(self.batch.numpy()), self.non_tensor_batch, self.meta_info | ||
|
|
||
| def __setstate__(self, data): | ||
| import io | ||
|
|
||
| batch_deserialized_bytes, non_tensor_batch, meta_info = data | ||
| batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) | ||
| batch = torch.load( | ||
| batch_deserialized, | ||
| weights_only=False, | ||
| map_location="cpu" if not get_torch_device().is_available() else None, | ||
| ) | ||
| self.batch = batch | ||
| batch_deserialized_bytes = pickle.loads(batch_deserialized_bytes) | ||
|
|
||
| self.batch = numpy_dict_to_tensor_dict(batch_deserialized_bytes) | ||
| self.non_tensor_batch = non_tensor_batch | ||
| self.meta_info = meta_info |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new serialization mechanism loses the device information of the tensors in self.batch. The previous implementation using torch.load had a map_location argument that provided control over the device on which tensors were loaded. The new implementation always deserializes tensors to the CPU because torch.from_numpy creates CPU tensors. This is a functional regression that could lead to performance degradation from unnecessary device transfers (e.g., GPU -> CPU -> GPU) or errors if downstream code expects tensors on a specific device.
To address this, you should store the device in __getstate__ and use it in __setstate__ to restore the tensors to their original device.
Here is an example of how you could modify both methods to preserve device information (this also includes the fix for the None batch case):
In __getstate__:
def __getstate__(self):
device = str(self.batch.device) if self.batch is not None else None
batch_bytes = pickle.dumps(self.batch.numpy()) if self.batch is not None else None
return batch_bytes, self.non_tensor_batch, self.meta_info, deviceIn __setstate__:
def __setstate__(self, data):
# Handle both old and new format for backward compatibility
if len(data) == 4:
batch_bytes, non_tensor_batch, meta_info, device = data
else:
batch_bytes, non_tensor_batch, meta_info = data
device = None
if batch_bytes is not None:
numpy_dict = pickle.loads(batch_bytes)
self.batch = numpy_dict_to_tensor_dict(numpy_dict)
if device and device != "cpu":
self.batch = self.batch.to(device)
else:
self.batch = None
self.non_tensor_batch = non_tensor_batch
self.meta_info = meta_info|
Could you create a test case for nested tensor? Thanks! https://docs.pytorch.org/docs/stable/nested.html |
|
Could you add an environment variable called VERL_DATAPROTO_SERIALIZATION_METHOD=numpy to toggle this? Some old data proto saved in old format won't be able to load if we directly modify the code |
|
I think this PR is ready for review, and the failed test cases should be unrelated to this PR :). cc @vermouth1992 |
|
When a device is specified, TensorDict automatically transfers tensors to that device. The current serialization process preserves the TensorDict's device information, which is then utilized during deserialization to restore the tensors to their original device. >> import torch
>>> from tensordict import TensorDict
>>> x=torch.rand((2,3))
>>> TensorDict({'x':x})
TensorDict(
fields={
x: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> TensorDict({'x':x}, device="npu")
TensorDict(
fields={
x: Tensor(shape=torch.Size([2, 3]), device=npu:0, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=npu:0,
is_shared=False)
>>> TensorDict({'x':x}, device="npu:0")
TensorDict(
fields={
x: Tensor(shape=torch.Size([2, 3]), device=npu:0, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=npu:0,
is_shared=False) |
…darray (volcengine#3425) ### What does this PR do? For a data size of 6400x20480, the average serialization duration was reduced from 3.32s to 1.32s following this optimization, resulting in a ~151% improvement. ``` # tensor average serialize:2.58s deserialize:0.74s total:3.32s TaskRunner pid=1904793) baymax debug serialize time=2.5947s (TaskRunner pid=1904793) baymax debug serialize time=2.593357s (TaskRunner pid=1904793) baymax debug serialize time=2.580081s (TaskRunner pid=1904793) baymax debug serialize time=2.582321s (WorkerDict pid=1905183) baymax debug deserialize time=0.475745s (WorkerDict pid=1905184) baymax debug deserialize time=0.538223s (WorkerDict pid=1905181) baymax debug deserialize time=0.609146s (WorkerDict pid=1905182) baymax debug deserialize time=0.61064s (WorkerDict pid=1905189) baymax debug deserialize time=0.597746s (WorkerDict pid=1905185) baymax debug deserialize time=0.530353s (WorkerDict pid=1905180) baymax debug deserialize time=0.811555s (WorkerDict pid=1905194) baymax debug deserialize time=0.513646s (WorkerDict pid=1905193) baymax debug deserialize time=0.962868s (WorkerDict pid=1905179) baymax debug deserialize time=0.929226s (WorkerDict pid=1905186) baymax debug deserialize time=0.701976s (WorkerDict pid=1905191) baymax debug deserialize time=0.867236s (WorkerDict pid=1905192) baymax debug deserialize time=0.858472s (WorkerDict pid=1905187) baymax debug deserialize time=1.045251s (WorkerDict pid=1905188) baymax debug deserialize time=0.960867s (WorkerDict pid=1905190) baymax debug deserialize time=1.010673s # numpy average serialize:0.000617s deserialize:1.32s total:1.32s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.00016s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.000117s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.000158s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.000182s �[36m(WorkerDict pid=1730035)�[0m baymax debug deserialize time=0.867232s �[36m(WorkerDict pid=1730036)�[0m baymax debug deserialize time=0.97372s �[36m(WorkerDict pid=1730028)�[0m baymax debug deserialize time=1.08627s �[36m(WorkerDict pid=1730034)�[0m baymax debug deserialize time=1.187599s �[36m(WorkerDict pid=1730037)�[0m baymax debug deserialize time=1.165926s �[36m(WorkerDict pid=1730025)�[0m baymax debug deserialize time=1.281101s �[36m(WorkerDict pid=1730029)�[0m baymax debug deserialize time=1.359834s �[36m(WorkerDict pid=1730027)�[0m baymax debug deserialize time=1.281978s �[36m(WorkerDict pid=1730030)�[0m baymax debug deserialize time=1.329298s �[36m(WorkerDict pid=1730026)�[0m baymax debug deserialize time=1.475415s �[36m(WorkerDict pid=1730031)�[0m baymax debug deserialize time=1.422345s �[36m(WorkerDict pid=1730033)�[0m baymax debug deserialize time=1.378894s �[36m(WorkerDict pid=1730039)�[0m baymax debug deserialize time=1.368721s �[36m(WorkerDict pid=1730040)�[0m baymax debug deserialize time=1.601587s �[36m(WorkerDict pid=1730042)�[0m baymax debug deserialize time=1.768378s �[36m(WorkerDict pid=1730038)�[0m baymax debug deserialize time=1.765994s ``` ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Huazhong <[email protected]>
…darray (volcengine#3425) ### What does this PR do? For a data size of 6400x20480, the average serialization duration was reduced from 3.32s to 1.32s following this optimization, resulting in a ~151% improvement. ``` # tensor average serialize:2.58s deserialize:0.74s total:3.32s TaskRunner pid=1904793) baymax debug serialize time=2.5947s (TaskRunner pid=1904793) baymax debug serialize time=2.593357s (TaskRunner pid=1904793) baymax debug serialize time=2.580081s (TaskRunner pid=1904793) baymax debug serialize time=2.582321s (WorkerDict pid=1905183) baymax debug deserialize time=0.475745s (WorkerDict pid=1905184) baymax debug deserialize time=0.538223s (WorkerDict pid=1905181) baymax debug deserialize time=0.609146s (WorkerDict pid=1905182) baymax debug deserialize time=0.61064s (WorkerDict pid=1905189) baymax debug deserialize time=0.597746s (WorkerDict pid=1905185) baymax debug deserialize time=0.530353s (WorkerDict pid=1905180) baymax debug deserialize time=0.811555s (WorkerDict pid=1905194) baymax debug deserialize time=0.513646s (WorkerDict pid=1905193) baymax debug deserialize time=0.962868s (WorkerDict pid=1905179) baymax debug deserialize time=0.929226s (WorkerDict pid=1905186) baymax debug deserialize time=0.701976s (WorkerDict pid=1905191) baymax debug deserialize time=0.867236s (WorkerDict pid=1905192) baymax debug deserialize time=0.858472s (WorkerDict pid=1905187) baymax debug deserialize time=1.045251s (WorkerDict pid=1905188) baymax debug deserialize time=0.960867s (WorkerDict pid=1905190) baymax debug deserialize time=1.010673s # numpy average serialize:0.000617s deserialize:1.32s total:1.32s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.00016s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.000117s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.000158s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.000182s �[36m(WorkerDict pid=1730035)�[0m baymax debug deserialize time=0.867232s �[36m(WorkerDict pid=1730036)�[0m baymax debug deserialize time=0.97372s �[36m(WorkerDict pid=1730028)�[0m baymax debug deserialize time=1.08627s �[36m(WorkerDict pid=1730034)�[0m baymax debug deserialize time=1.187599s �[36m(WorkerDict pid=1730037)�[0m baymax debug deserialize time=1.165926s �[36m(WorkerDict pid=1730025)�[0m baymax debug deserialize time=1.281101s �[36m(WorkerDict pid=1730029)�[0m baymax debug deserialize time=1.359834s �[36m(WorkerDict pid=1730027)�[0m baymax debug deserialize time=1.281978s �[36m(WorkerDict pid=1730030)�[0m baymax debug deserialize time=1.329298s �[36m(WorkerDict pid=1730026)�[0m baymax debug deserialize time=1.475415s �[36m(WorkerDict pid=1730031)�[0m baymax debug deserialize time=1.422345s �[36m(WorkerDict pid=1730033)�[0m baymax debug deserialize time=1.378894s �[36m(WorkerDict pid=1730039)�[0m baymax debug deserialize time=1.368721s �[36m(WorkerDict pid=1730040)�[0m baymax debug deserialize time=1.601587s �[36m(WorkerDict pid=1730042)�[0m baymax debug deserialize time=1.768378s �[36m(WorkerDict pid=1730038)�[0m baymax debug deserialize time=1.765994s ``` ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Huazhong <[email protected]>
…darray (volcengine#3425) ### What does this PR do? For a data size of 6400x20480, the average serialization duration was reduced from 3.32s to 1.32s following this optimization, resulting in a ~151% improvement. ``` # tensor average serialize:2.58s deserialize:0.74s total:3.32s TaskRunner pid=1904793) baymax debug serialize time=2.5947s (TaskRunner pid=1904793) baymax debug serialize time=2.593357s (TaskRunner pid=1904793) baymax debug serialize time=2.580081s (TaskRunner pid=1904793) baymax debug serialize time=2.582321s (WorkerDict pid=1905183) baymax debug deserialize time=0.475745s (WorkerDict pid=1905184) baymax debug deserialize time=0.538223s (WorkerDict pid=1905181) baymax debug deserialize time=0.609146s (WorkerDict pid=1905182) baymax debug deserialize time=0.61064s (WorkerDict pid=1905189) baymax debug deserialize time=0.597746s (WorkerDict pid=1905185) baymax debug deserialize time=0.530353s (WorkerDict pid=1905180) baymax debug deserialize time=0.811555s (WorkerDict pid=1905194) baymax debug deserialize time=0.513646s (WorkerDict pid=1905193) baymax debug deserialize time=0.962868s (WorkerDict pid=1905179) baymax debug deserialize time=0.929226s (WorkerDict pid=1905186) baymax debug deserialize time=0.701976s (WorkerDict pid=1905191) baymax debug deserialize time=0.867236s (WorkerDict pid=1905192) baymax debug deserialize time=0.858472s (WorkerDict pid=1905187) baymax debug deserialize time=1.045251s (WorkerDict pid=1905188) baymax debug deserialize time=0.960867s (WorkerDict pid=1905190) baymax debug deserialize time=1.010673s # numpy average serialize:0.000617s deserialize:1.32s total:1.32s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.00016s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.000117s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.000158s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.000182s �[36m(WorkerDict pid=1730035)�[0m baymax debug deserialize time=0.867232s �[36m(WorkerDict pid=1730036)�[0m baymax debug deserialize time=0.97372s �[36m(WorkerDict pid=1730028)�[0m baymax debug deserialize time=1.08627s �[36m(WorkerDict pid=1730034)�[0m baymax debug deserialize time=1.187599s �[36m(WorkerDict pid=1730037)�[0m baymax debug deserialize time=1.165926s �[36m(WorkerDict pid=1730025)�[0m baymax debug deserialize time=1.281101s �[36m(WorkerDict pid=1730029)�[0m baymax debug deserialize time=1.359834s �[36m(WorkerDict pid=1730027)�[0m baymax debug deserialize time=1.281978s �[36m(WorkerDict pid=1730030)�[0m baymax debug deserialize time=1.329298s �[36m(WorkerDict pid=1730026)�[0m baymax debug deserialize time=1.475415s �[36m(WorkerDict pid=1730031)�[0m baymax debug deserialize time=1.422345s �[36m(WorkerDict pid=1730033)�[0m baymax debug deserialize time=1.378894s �[36m(WorkerDict pid=1730039)�[0m baymax debug deserialize time=1.368721s �[36m(WorkerDict pid=1730040)�[0m baymax debug deserialize time=1.601587s �[36m(WorkerDict pid=1730042)�[0m baymax debug deserialize time=1.768378s �[36m(WorkerDict pid=1730038)�[0m baymax debug deserialize time=1.765994s ``` ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Huazhong <[email protected]>
…darray (volcengine#3425) ### What does this PR do? For a data size of 6400x20480, the average serialization duration was reduced from 3.32s to 1.32s following this optimization, resulting in a ~151% improvement. ``` # tensor average serialize:2.58s deserialize:0.74s total:3.32s TaskRunner pid=1904793) baymax debug serialize time=2.5947s (TaskRunner pid=1904793) baymax debug serialize time=2.593357s (TaskRunner pid=1904793) baymax debug serialize time=2.580081s (TaskRunner pid=1904793) baymax debug serialize time=2.582321s (WorkerDict pid=1905183) baymax debug deserialize time=0.475745s (WorkerDict pid=1905184) baymax debug deserialize time=0.538223s (WorkerDict pid=1905181) baymax debug deserialize time=0.609146s (WorkerDict pid=1905182) baymax debug deserialize time=0.61064s (WorkerDict pid=1905189) baymax debug deserialize time=0.597746s (WorkerDict pid=1905185) baymax debug deserialize time=0.530353s (WorkerDict pid=1905180) baymax debug deserialize time=0.811555s (WorkerDict pid=1905194) baymax debug deserialize time=0.513646s (WorkerDict pid=1905193) baymax debug deserialize time=0.962868s (WorkerDict pid=1905179) baymax debug deserialize time=0.929226s (WorkerDict pid=1905186) baymax debug deserialize time=0.701976s (WorkerDict pid=1905191) baymax debug deserialize time=0.867236s (WorkerDict pid=1905192) baymax debug deserialize time=0.858472s (WorkerDict pid=1905187) baymax debug deserialize time=1.045251s (WorkerDict pid=1905188) baymax debug deserialize time=0.960867s (WorkerDict pid=1905190) baymax debug deserialize time=1.010673s # numpy average serialize:0.000617s deserialize:1.32s total:1.32s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.00016s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.000117s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.000158s �[36m(TaskRunner pid=1729638)�[0m baymax debug serialize time=0.000182s �[36m(WorkerDict pid=1730035)�[0m baymax debug deserialize time=0.867232s �[36m(WorkerDict pid=1730036)�[0m baymax debug deserialize time=0.97372s �[36m(WorkerDict pid=1730028)�[0m baymax debug deserialize time=1.08627s �[36m(WorkerDict pid=1730034)�[0m baymax debug deserialize time=1.187599s �[36m(WorkerDict pid=1730037)�[0m baymax debug deserialize time=1.165926s �[36m(WorkerDict pid=1730025)�[0m baymax debug deserialize time=1.281101s �[36m(WorkerDict pid=1730029)�[0m baymax debug deserialize time=1.359834s �[36m(WorkerDict pid=1730027)�[0m baymax debug deserialize time=1.281978s �[36m(WorkerDict pid=1730030)�[0m baymax debug deserialize time=1.329298s �[36m(WorkerDict pid=1730026)�[0m baymax debug deserialize time=1.475415s �[36m(WorkerDict pid=1730031)�[0m baymax debug deserialize time=1.422345s �[36m(WorkerDict pid=1730033)�[0m baymax debug deserialize time=1.378894s �[36m(WorkerDict pid=1730039)�[0m baymax debug deserialize time=1.368721s �[36m(WorkerDict pid=1730040)�[0m baymax debug deserialize time=1.601587s �[36m(WorkerDict pid=1730042)�[0m baymax debug deserialize time=1.768378s �[36m(WorkerDict pid=1730038)�[0m baymax debug deserialize time=1.765994s ``` ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: Huazhong <[email protected]>
What does this PR do?
For a data size of 6400x20480, the average serialization duration was reduced from 3.32s to 1.32s following this optimization, resulting in a ~151% improvement.
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,megatron,sglang,vllm,rollout,trainer,ci,training_utils,recipe,hardware,deployment,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,like[megatron, fsdp, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][fsdp, megatron] feat: dynamic batchingTest
API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel in theverlSlack workspace. (If not accessible, please try the Feishu group (飞书群).)