Description & Motivation
I use a frozen (immutable) dataclass as a batch (this is my way of controlling complexity). With bf16-true precision, I get the following error from HalfPrecision, which calls apply_to_collection without the allow_frozen parameter:
[rank0]: ValueError: A frozen dataclass was passed to `apply_to_collection` but this is not allowed.
I would like to suggest a modification of _apply_to_collection_slow: in the is_dataclass_instance(data) branch, check whether the dataclass defines a to() method via hasattr(data, 'to'); if it does, simply call data.to(dtype=dtype) and let the user handle the conversion themselves.
So I could add the new method to my batch:

```python
import dataclasses
import typing

import torch


@dataclasses.dataclass(frozen=True)
class MyBatch:
    input: torch.Tensor

    def to(self, device: torch.device, dtype: torch.dtype, non_blocking: bool = False, dataset_idx: int = 0) -> typing.Self:
        # I handle the conversion myself
        return MyBatch(
            input=self.input.to(device=device, dtype=dtype, non_blocking=non_blocking),
        )
```

This feature would also allow getting rid of the transfer_batch_to_device() hook, since the batch can move itself to another device.
Pitch
The feature provides more flexibility and keeps the code neat.
Alternatives
- make the dataclass mutable (unfrozen) - possible, but it makes the code fragile
- convert the dataclass to a dict and back - two additional operations per batch
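For comparison, a minimal sketch of the dict round-trip alternative (again with a plain value standing in for a tensor, and a hypothetical `convert` helper):

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Batch:
    value: float  # stands in for a torch.Tensor


def convert(batch: Batch, dtype) -> Batch:
    # Round-trip through a dict: two extra operations per batch.
    as_dict = dataclasses.asdict(batch)                    # dataclass -> dict
    converted = {k: dtype(v) for k, v in as_dict.items()}  # convert leaves
    return Batch(**converted)                              # dict -> dataclass
```

Note that dataclasses.replace also works on frozen instances (it constructs a new object via __init__), which avoids the intermediate dict but still requires rebuilding the batch field by field.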
Additional context
No response
cc @lantiga