
Frozen dataclass as a batch #21577

@mazurkin


Description & Motivation

I have a frozen (immutable) dataclass as a batch (this is my way of controlling complexity), and with bf16-true precision I get the following error in HalfPrecision, which calls apply_to_collection without the allow_frozen parameter:

[rank0]: ValueError: A frozen dataclass was passed to `apply_to_collection` but this is not allowed.

I would like to suggest a modification to _apply_to_collection_slow: in the is_dataclass_instance(data) branch, check whether the dataclass has a to() method via hasattr(data, 'to'); if it does, just call data.to(dtype=dtype) and let the user handle the conversion themselves.
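A minimal sketch of the proposed dispatch, assuming nothing about the real _apply_to_collection_slow internals (the helper name and the plain float field are illustrative stand-ins so the example runs without torch):

```python
import dataclasses


def apply_to_frozen_dataclass(data, dtype):
    """Sketch: delegate conversion to the dataclass's own `to()` if present."""
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        if data.__dataclass_params__.frozen:
            if hasattr(data, "to"):
                # let the user-defined method produce the converted copy
                return data.to(dtype=dtype)
            raise ValueError(
                "A frozen dataclass was passed to `apply_to_collection`"
                " but this is not allowed."
            )
    return data


@dataclasses.dataclass(frozen=True)
class Sample:
    value: float

    def to(self, dtype):
        # user-controlled conversion; here it just casts the single field
        return Sample(value=dtype(self.value))


converted = apply_to_frozen_dataclass(Sample(value=3.0), dtype=int)
print(converted.value)  # 3
```

A frozen dataclass without a to() method would still raise the current ValueError, so existing behavior is preserved.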

Then I could add a to() method to my batch:

import dataclasses
import typing

import torch

@dataclasses.dataclass(frozen=True)
class MyBatch:

    input: torch.Tensor

    def to(self, device: torch.device, dtype: torch.dtype, non_blocking: bool = False, dataset_idx: int = 0) -> typing.Self:
        # I am handling the conversion myself
        return MyBatch(
            input=self.input.to(device=device, dtype=dtype, non_blocking=non_blocking),
        )

This feature would also make the transfer_batch_to_device() hook unnecessary, since the batch can move itself to another device.

Pitch

The feature provides more flexibility and keeps the code neat.

Alternatives

  • make the dataclass mutable (unfrozen): possible, but it makes the code fragile
  • convert the dataclass to a dict and back: two additional operations per batch
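For reference, the dict round-trip alternative would look roughly like this (the helper name and Point class are hypothetical; it works, but pays for a flatten plus a rebuild on every batch):

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Point:
    x: float
    y: float


def convert_via_dict(batch, fn):
    # operation 1: flatten the frozen dataclass into a mutable dict
    d = dataclasses.asdict(batch)
    d = {k: fn(v) for k, v in d.items()}
    # operation 2: rebuild a new frozen instance from the converted dict
    return type(batch)(**d)


p = convert_via_dict(Point(x=1.0, y=2.0), fn=lambda v: v * 2)
print(p)  # Point(x=2.0, y=4.0)
```

Note that dataclasses.asdict recurses into nested dataclasses, so reconstructing anything non-flat takes extra bookkeeping, which is exactly the fragility a user-defined to() method avoids.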

Additional context

No response

cc @lantiga
