Commit 3522dfb

vermouth1992 and gemini-code-assist[bot] authored and committed
[misc] feat: prototype deprecate DataProto and replace with Tensordict: part 1 (volcengine#2733)
### What does this PR do?

- Add TensorDict utilities and tests to cover the current DataProto functionalities.
- Add nested tensor example to remove padding throughout the system
- Add image example
- Upgrade tensordict to v0.10

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
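
The PR title convention from the checklist (`[{modules}] {type}: {description}`, with an optional leading `[BREAKING]`) can be checked mechanically. The following is a minimal sketch of such a check, not the project's actual CI script, which is not shown in this commit. The names `TITLE_RE` and `is_valid_pr_title` are hypothetical; the module and type lists are copied from the checklist.

```python
import re

# Module and type names copied from the PR checklist.
MODULES = {
    "fsdp", "megatron", "sglang", "vllm", "rollout", "trainer", "ci",
    "training_utils", "recipe", "hardware", "deployment", "ray", "worker",
    "single_controller", "misc", "perf", "model", "algo", "env", "tool",
    "ckpt", "doc", "data",
}
TYPES = {"feat", "fix", "refactor", "chore", "test"}

# Hypothetical pattern: optional [BREAKING], then [modules], then "type: description".
TITLE_RE = re.compile(r"^(\[BREAKING\])?\[([^\]]+)\]\s+(\w+):\s+(.+)$")


def is_valid_pr_title(title: str) -> bool:
    """Return True if the title follows the convention described in the checklist."""
    match = TITLE_RE.match(title)
    if match is None:
        return False
    # Multiple modules are separated by commas, e.g. "[megatron, fsdp, doc]".
    modules = [name.strip() for name in match.group(2).split(",")]
    return all(name in MODULES for name in modules) and match.group(3) in TYPES


print(is_valid_pr_title("[misc] feat: prototype deprecate DataProto"))
print(is_valid_pr_title("[BREAKING][fsdp, megatron] feat: dynamic batching"))
print(is_valid_pr_title("feat: title without a module tag"))
```

Under these assumptions the first two titles are accepted and the third is rejected, because it has no bracketed module list.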
1 parent d3f156d commit 3522dfb

10 files changed: +776 -16 lines changed

requirements-npu.txt

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ peft>=0.15.2
 pyarrow>=15.0.0
 pybind11
 pylatexenc
-tensordict>=0.8.0,<=0.9.1,!=0.9.0
+tensordict>=0.8.0,<=0.10.0,!=0.9.0
 transformers==4.52.4
 ray==2.46.0
 wandb

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ pybind11
 pylatexenc
 pre-commit
 ray[default]
-tensordict>=0.8.0,<=0.9.1,!=0.9.0
+tensordict>=0.8.0,<=0.10.0,!=0.9.0
 torchdata
 transformers
 # vllm==0.8.4

requirements_sglang.txt

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ pyarrow>=19.0.0
 pybind11
 pylatexenc
 ray[default]>=2.10
-tensordict>=0.8.0,<=0.9.1,!=0.9.0
+tensordict>=0.8.0,<=0.10.0,!=0.9.0
 torchdata
 torchvision
 transformers

setup.py

Lines changed: 3 additions & 3 deletions
@@ -37,7 +37,7 @@
     "pylatexenc",
     "ray[default]>=2.41.0",
     "torchdata",
-    "tensordict>=0.8.0,<=0.9.1,!=0.9.0",
+    "tensordict>=0.8.0,<=0.10.0,!=0.9.0",
     "transformers",
     "wandb",
     "packaging>=20.0",
@@ -49,9 +49,9 @@
 GEO_REQUIRES = ["mathruler", "torchvision", "qwen_vl_utils"]
 GPU_REQUIRES = ["liger-kernel", "flash-attn"]
 MATH_REQUIRES = ["math-verify"]  # Add math-verify as an optional dependency
-VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.9.1,!=0.9.0", "vllm>=0.7.3,<=0.9.1"]
+VLLM_REQUIRES = ["tensordict>=0.8.0,<=0.10.0,!=0.9.0", "vllm>=0.7.3,<=0.9.1"]
 SGLANG_REQUIRES = [
-    "tensordict>=0.8.0,<=0.9.1,!=0.9.0",
+    "tensordict>=0.8.0,<=0.10.0,!=0.9.0",
     "sglang[srt,openai]==0.4.10.post2",
     "torch==2.7.1",
 ]
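
The requirement files and `setup.py` all receive the same change: the upper bound on `tensordict` moves from `<=0.9.1` to `<=0.10.0`, while `0.9.0` stays excluded. A small sketch using the `packaging` library (already listed in `setup.py`) shows which versions each specifier admits. The candidate version strings are illustrative only and are not claimed to be published releases.

```python
from packaging.specifiers import SpecifierSet

# Specifier strings copied from the diff.
old_spec = SpecifierSet(">=0.8.0,<=0.9.1,!=0.9.0")
new_spec = SpecifierSet(">=0.8.0,<=0.10.0,!=0.9.0")

# Illustrative version strings (assumption: chosen to probe the boundaries).
candidates = ["0.7.2", "0.8.0", "0.9.0", "0.9.1", "0.10.0", "0.10.1"]


def allowed(spec: SpecifierSet, versions: list[str]) -> list[str]:
    """Return the versions that satisfy the specifier set."""
    return [version for version in versions if spec.contains(version)]


print("old:", allowed(old_spec, candidates))  # old: ['0.8.0', '0.9.1']
print("new:", allowed(new_spec, candidates))  # new: ['0.8.0', '0.9.1', '0.10.0']
```

Among these candidates the only difference is that `0.10.0` becomes installable; anything above `0.10.0` is still rejected by the upper bound.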

tests/special_sanity/validate_structure.py

Lines changed: 5 additions & 1 deletion
@@ -86,7 +86,11 @@ def main() -> None:
     parser.add_argument(
         "--allow-files",
         nargs="*",
-        default=["tests/test_protocol_on_cpu.py", "tests/test_base_config_on_cpu.py"],
+        default=[
+            "tests/test_protocol_on_cpu.py",
+            "tests/test_base_config_on_cpu.py",
+            "tests/test_protocol_v2_on_cpu.py",
+        ],
         help="Extra top-level test folders that are exempt from the rule",
     )
     args = parser.parse_args()
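
The hunk above extends the default of an `argparse` option declared with `nargs="*"`. As a rough illustration of how such an option behaves, the standalone sketch below rebuilds only that one argument (the helper `build_parser` and the path `tests/foo.py` are hypothetical, and the rest of `validate_structure.py` is not reproduced) and parses three different command lines.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser containing only the --allow-files option from the diff."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--allow-files",
        nargs="*",
        default=[
            "tests/test_protocol_on_cpu.py",
            "tests/test_base_config_on_cpu.py",
            "tests/test_protocol_v2_on_cpu.py",
        ],
        help="Extra top-level test folders that are exempt from the rule",
    )
    return parser


# Flag omitted: the default list is used as-is.
default_args = build_parser().parse_args([])
# Flag with values: the values replace the default rather than extend it.
override_args = build_parser().parse_args(["--allow-files", "tests/foo.py"])
# Flag with no values: nargs="*" yields an empty list.
empty_args = build_parser().parse_args(["--allow-files"])

print(default_args.allow_files)
print(override_args.allow_files)
print(empty_args.allow_files)
```

One consequence worth noting: a caller who passes `--allow-files` explicitly has to repeat every exempt file, since the default list is not merged with the supplied values.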

tests/test_protocol_on_cpu.py

Lines changed: 16 additions & 0 deletions
@@ -16,7 +16,9 @@
 
 import numpy as np
 import pytest
+import tensordict
 import torch
+from packaging.version import parse as parse_version
 from tensordict import TensorDict
 
 from verl import DataProto
@@ -598,3 +600,17 @@ def test_dataproto_chunk_after_index():
     selected = data[torch_int_mask]
     assert isinstance(selected.batch.batch_size, torch.Size)
     assert all(isinstance(d, int) for d in selected.batch.batch_size)
+
+
+@pytest.mark.skipif(
+    parse_version(tensordict.__version__) < parse_version("0.10"), reason="requires at least tensordict 0.10"
+)
+def test_to_tensordict():
+    obs = torch.tensor([1, 2, 3, 4, 5, 6])
+    labels = ["a", "b", "c", "d", "e", "f"]
+    data = DataProto.from_dict(tensors={"obs": obs}, non_tensors={"labels": labels}, meta_info={"name": "abdce"})
+    output = data.to_tensordict()
+
+    assert torch.all(torch.eq(output["obs"], obs)).item()
+    assert output["labels"] == labels
+    assert output["name"] == "abdce"
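
The new test is skipped when the installed `tensordict` is older than 0.10, and the comparison goes through `packaging.version.parse` rather than comparing raw strings. The sketch below illustrates why that matters; `meets_minimum` is a hypothetical helper, not part of the commit, and the version strings are examples.

```python
from packaging.version import parse as parse_version

# Plain string comparison is lexicographic, so "0.10" sorts before "0.9.1".
string_says_older = "0.10" < "0.9.1"

# Parsed versions compare release segments numerically, so 0.10 is newer than 0.9.1.
parsed_says_newer = parse_version("0.10") > parse_version("0.9.1")


def meets_minimum(installed: str, minimum: str = "0.10") -> bool:
    """Return True if the installed version is at least the minimum (hypothetical helper)."""
    return parse_version(installed) >= parse_version(minimum)


print(string_says_older)        # True
print(parsed_says_newer)        # True
print(meets_minimum("0.10.0"))  # True
print(meets_minimum("0.9.1"))   # False
```

With a string comparison the skip condition would misfire for `0.10`, so the parsed comparison is the safer way to gate the test.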
