Skip to content
3 changes: 2 additions & 1 deletion verl/utils/dataset/rl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def collate_fn(data_list: list[dict]) -> dict:
tensors[key] = torch.stack(val, dim=0)

for key, val in non_tensors.items():
non_tensors[key] = np.array(val, dtype=object)
non_tensors[key] = np.empty(len(val), dtype=object)
non_tensors[key][:] = val
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation uses np.empty followed by assignment. A more concise and potentially performant approach is to use np.fromiter to create the 1D object array directly, as collate_fn is on a critical performance path for data loading.

non_tensors[key] = np.fromiter(val, dtype=object, count=len(val))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion!
I've updated the code to use np.fromiter(val, dtype=object, count=len(val)) as recommended.


return {**tensors, **non_tensors}

Expand Down