Skip to content

Conversation

@xylcbd
Copy link
Contributor

@xylcbd xylcbd commented Jul 9, 2025

What does this PR do?

FIX: '_io.BytesIO' object has no attribute 'startswith'

#1976

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

  1. download the test dataset.
huggingface-cli download --repo-type dataset xylcbd/pgdp5k_mini
  1. convert data to parquet format
import argparse
import os

import datasets

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", default="~/data/pgdp5k_mini")
    args = parser.parse_args()

    data_source = "xylcbd/pgdp5k_mini"
    dataset = datasets.load_dataset(data_source)
    train_dataset = dataset["train"]
    train_dataset.to_parquet(os.path.join(args.local_dir, "train.parquet"))
  1. test dataset loading:
import os

import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
tokenizer = hf_tokenizer(model_path)
processor = hf_processor(model_path)
config = OmegaConf.create(
    {
        "prompt_key": "prompt",
        "max_prompt_length": 1024,
        "filter_overlong_prompts": True,
        "filter_overlong_prompts_workers": 2,
    }
)
dataset = RLHFDataset(
    data_files=os.path.expanduser("~/data/pgdp5k_mini/train.parquet"),
    tokenizer=tokenizer,
    config=config,
    processor=processor,
)

dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, drop_last=True, collate_fn=collate_fn)

a = next(iter(dataloader))

from verl import DataProto

tensors = {}
non_tensors = {}

for key, val in a.items():
    if isinstance(val, torch.Tensor):
        tensors[key] = val
    else:
        non_tensors[key] = val

data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)

assert "multi_modal_data" in data_proto.non_tensor_batch, data_proto
assert "multi_modal_inputs" in data_proto.non_tensor_batch, data_proto

data = dataset[0]["input_ids"]
output = tokenizer.batch_decode([data])[0]
print(f"type: type{output}")
print(f"\n\noutput: {output}")
  1. Error reported before repair (no error reported after repair)
AttributeError: '_io.BytesIO' object has no attribute 'startswith'

API and Usage Example

No change for the API.

Design & Code Changes

No change for the design.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

xylcbd added 2 commits July 9, 2025 11:49
'_io.BytesIO' object has no attribute 'startswith'
…-"bytes"-in-image

[data] fix: fix bug of '_io.BytesIO' object has no attribute 'startswith'
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The pull request fixes an AttributeError by changing how image bytes are handled in process_image. The code now directly returns a PIL Image object instead of storing BytesIO in the image dictionary. It's crucial to verify that this change doesn't break compatibility with the rest of the codebase, especially the fetch_image function.

if "bytes" in image:
assert "image" not in image, "Cannot have both `bytes` and `image`"
image["image"] = BytesIO(image["bytes"])
return Image.open(BytesIO(image["bytes"]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The original code assigns BytesIO(image["bytes"]) to image["image"], but this change directly returns Image.open(BytesIO(image["bytes"])). Ensure that the subsequent fetch_image function is compatible with this change, as it might have been expecting a dictionary with an 'image' key.

return Image.open(BytesIO(image["bytes"]))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix is okay

@xylcbd
Copy link
Contributor Author

xylcbd commented Jul 15, 2025

@vermouth1992 Any review comments? plz

@warmsnow-sh
Copy link

I think perhaps it should be wrapped as a PIL object instead of being returned directly. This way, subsequent processing, such as resize, won't be skipped.

@xylcbd
Copy link
Contributor Author

xylcbd commented Jul 21, 2025

Yes, there is a problem with returning it directly. I'll make some changes.

@Maxwell-Jia
Copy link
Contributor

This PR should be merged, the current process_image has logical issues and affects #2398. @wuxibin89

@wuxibin89 wuxibin89 merged commit 31ac4dc into volcengine:main Aug 8, 2025
46 of 53 checks passed
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Aug 11, 2025
…ith' (volcengine#2430)

### What does this PR do?

FIX: '_io.BytesIO' object has no attribute 'startswith'

volcengine#1976

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test
1. download the test dataset.
```
huggingface-cli download --repo-type dataset xylcbd/pgdp5k_mini
```
2. convert data to parquet format
```
import argparse
import os

import datasets

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", default="~/data/pgdp5k_mini")
    args = parser.parse_args()

    data_source = "xylcbd/pgdp5k_mini"
    dataset = datasets.load_dataset(data_source)
    train_dataset = dataset["train"]
    train_dataset.to_parquet(os.path.join(args.local_dir, "train.parquet"))
```
3. test dataset loading:
```
import os

import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
tokenizer = hf_tokenizer(model_path)
processor = hf_processor(model_path)
config = OmegaConf.create(
    {
        "prompt_key": "prompt",
        "max_prompt_length": 1024,
        "filter_overlong_prompts": True,
        "filter_overlong_prompts_workers": 2,
    }
)
dataset = RLHFDataset(
    data_files=os.path.expanduser("~/data/pgdp5k_mini/train.parquet"),
    tokenizer=tokenizer,
    config=config,
    processor=processor,
)

dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, drop_last=True, collate_fn=collate_fn)

a = next(iter(dataloader))

from verl import DataProto

tensors = {}
non_tensors = {}

for key, val in a.items():
    if isinstance(val, torch.Tensor):
        tensors[key] = val
    else:
        non_tensors[key] = val

data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)

assert "multi_modal_data" in data_proto.non_tensor_batch, data_proto
assert "multi_modal_inputs" in data_proto.non_tensor_batch, data_proto

data = dataset[0]["input_ids"]
output = tokenizer.batch_decode([data])[0]
print(f"type: type{output}")
print(f"\n\noutput: {output}")
```
4. Error reported before repair (no error reported after repair)
```
AttributeError: '_io.BytesIO' object has no attribute 'startswith'
```
### API and Usage Example

No change for the API.

### Design & Code Changes

No change for the design.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
ChangyiYang pushed a commit to SwordFaith/verl that referenced this pull request Aug 16, 2025
…ith' (volcengine#2430)

### What does this PR do?

FIX: '_io.BytesIO' object has no attribute 'startswith'

volcengine#1976

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test
1. download the test dataset.
```
huggingface-cli download --repo-type dataset xylcbd/pgdp5k_mini
```
2. convert data to parquet format
```
import argparse
import os

import datasets

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", default="~/data/pgdp5k_mini")
    args = parser.parse_args()

    data_source = "xylcbd/pgdp5k_mini"
    dataset = datasets.load_dataset(data_source)
    train_dataset = dataset["train"]
    train_dataset.to_parquet(os.path.join(args.local_dir, "train.parquet"))
```
3. test dataset loading:
```
import os

import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
tokenizer = hf_tokenizer(model_path)
processor = hf_processor(model_path)
config = OmegaConf.create(
    {
        "prompt_key": "prompt",
        "max_prompt_length": 1024,
        "filter_overlong_prompts": True,
        "filter_overlong_prompts_workers": 2,
    }
)
dataset = RLHFDataset(
    data_files=os.path.expanduser("~/data/pgdp5k_mini/train.parquet"),
    tokenizer=tokenizer,
    config=config,
    processor=processor,
)

dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, drop_last=True, collate_fn=collate_fn)

a = next(iter(dataloader))

from verl import DataProto

tensors = {}
non_tensors = {}

for key, val in a.items():
    if isinstance(val, torch.Tensor):
        tensors[key] = val
    else:
        non_tensors[key] = val

data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)

assert "multi_modal_data" in data_proto.non_tensor_batch, data_proto
assert "multi_modal_inputs" in data_proto.non_tensor_batch, data_proto

data = dataset[0]["input_ids"]
output = tokenizer.batch_decode([data])[0]
print(f"type: type{output}")
print(f"\n\noutput: {output}")
```
4. Error reported before repair (no error reported after repair)
```
AttributeError: '_io.BytesIO' object has no attribute 'startswith'
```
### API and Usage Example

No change for the API.

### Design & Code Changes

No change for the design.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
whatadayG pushed a commit to whatadayG/verl that referenced this pull request Sep 5, 2025
…ith' (volcengine#2430)

### What does this PR do?

FIX: '_io.BytesIO' object has no attribute 'startswith'

volcengine#1976

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test
1. download the test dataset.
```
huggingface-cli download --repo-type dataset xylcbd/pgdp5k_mini
```
2. convert data to parquet format
```
import argparse
import os

import datasets

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", default="~/data/pgdp5k_mini")
    args = parser.parse_args()

    data_source = "xylcbd/pgdp5k_mini"
    dataset = datasets.load_dataset(data_source)
    train_dataset = dataset["train"]
    train_dataset.to_parquet(os.path.join(args.local_dir, "train.parquet"))
```
3. test dataset loading:
```
import os

import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
tokenizer = hf_tokenizer(model_path)
processor = hf_processor(model_path)
config = OmegaConf.create(
    {
        "prompt_key": "prompt",
        "max_prompt_length": 1024,
        "filter_overlong_prompts": True,
        "filter_overlong_prompts_workers": 2,
    }
)
dataset = RLHFDataset(
    data_files=os.path.expanduser("~/data/pgdp5k_mini/train.parquet"),
    tokenizer=tokenizer,
    config=config,
    processor=processor,
)

dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, drop_last=True, collate_fn=collate_fn)

a = next(iter(dataloader))

from verl import DataProto

tensors = {}
non_tensors = {}

for key, val in a.items():
    if isinstance(val, torch.Tensor):
        tensors[key] = val
    else:
        non_tensors[key] = val

data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)

assert "multi_modal_data" in data_proto.non_tensor_batch, data_proto
assert "multi_modal_inputs" in data_proto.non_tensor_batch, data_proto

data = dataset[0]["input_ids"]
output = tokenizer.batch_decode([data])[0]
print(f"type: type{output}")
print(f"\n\noutput: {output}")
```
4. Error reported before repair (no error reported after repair)
```
AttributeError: '_io.BytesIO' object has no attribute 'startswith'
```
### API and Usage Example

No change for the API.

### Design & Code Changes

No change for the design.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
WncFht pushed a commit to WncFht/verl that referenced this pull request Oct 10, 2025
…ith' (volcengine#2430)

### What does this PR do?

FIX: '_io.BytesIO' object has no attribute 'startswith'

volcengine#1976

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test
1. download the test dataset.
```
huggingface-cli download --repo-type dataset xylcbd/pgdp5k_mini
```
2. convert data to parquet format
```
import argparse
import os

import datasets

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", default="~/data/pgdp5k_mini")
    args = parser.parse_args()

    data_source = "xylcbd/pgdp5k_mini"
    dataset = datasets.load_dataset(data_source)
    train_dataset = dataset["train"]
    train_dataset.to_parquet(os.path.join(args.local_dir, "train.parquet"))
```
3. test dataset loading:
```
import os

import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
tokenizer = hf_tokenizer(model_path)
processor = hf_processor(model_path)
config = OmegaConf.create(
    {
        "prompt_key": "prompt",
        "max_prompt_length": 1024,
        "filter_overlong_prompts": True,
        "filter_overlong_prompts_workers": 2,
    }
)
dataset = RLHFDataset(
    data_files=os.path.expanduser("~/data/pgdp5k_mini/train.parquet"),
    tokenizer=tokenizer,
    config=config,
    processor=processor,
)

dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, drop_last=True, collate_fn=collate_fn)

a = next(iter(dataloader))

from verl import DataProto

tensors = {}
non_tensors = {}

for key, val in a.items():
    if isinstance(val, torch.Tensor):
        tensors[key] = val
    else:
        non_tensors[key] = val

data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)

assert "multi_modal_data" in data_proto.non_tensor_batch, data_proto
assert "multi_modal_inputs" in data_proto.non_tensor_batch, data_proto

data = dataset[0]["input_ids"]
output = tokenizer.batch_decode([data])[0]
print(f"type: type{output}")
print(f"\n\noutput: {output}")
```
4. Error reported before repair (no error reported after repair)
```
AttributeError: '_io.BytesIO' object has no attribute 'startswith'
```
### API and Usage Example

No change for the API.

### Design & Code Changes

No change for the design.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants