
[Feature Request] Add Liger CE Loss #2692

@pbontrager

Description


Add a new loss in the cross_entropy_loss.py file that inherits from the SFT loss but calls the Liger fused_linear_cross_entropy loss. It will need to handle the case where the input is a DTensor, converting it to a regular tensor before calling the Liger loss.
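A minimal sketch of how the pieces could fit together, under stated assumptions. The class name `LigerLinearCrossEntropyLoss` and its `(hidden, weight, targets)` call signature are hypothetical. To keep the example runnable without PyTorch or Liger installed, the fused kernel is replaced by a pure-Python reference that computes the mean cross-entropy over non-ignored tokens: logits are produced from the hidden states and the output weight one chunk of tokens at a time and reduced immediately, which is the chunking idea behind fused linear cross-entropy. The DTensor conversion is duck-typed on `full_tensor()`; a real version would check `isinstance(x, DTensor)` and call the Liger kernel instead.

```python
import math
from typing import Sequence


def _dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def _logsumexp(row: Sequence[float]) -> float:
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    return m + math.log(sum(math.exp(z - m) for z in row))


def _to_local(x):
    # Duck-typed stand-in for `isinstance(x, DTensor)`: PyTorch's DTensor has
    # full_tensor(), which gathers its shards into a regular tensor.
    return x.full_tensor() if hasattr(x, "full_tensor") else x


def chunked_linear_cross_entropy(hidden, weight, targets, ignore_index=-100, chunk_size=1024):
    """Mean cross-entropy of (hidden @ weight.T) against targets.

    Logits are built for one chunk of tokens at a time, so the full
    [num_tokens, vocab_size] logit matrix is never held in memory at once.
    """
    total, count = 0.0, 0
    for start in range(0, len(hidden), chunk_size):
        h_chunk = hidden[start:start + chunk_size]
        t_chunk = targets[start:start + chunk_size]
        # Logits for this chunk only: [len(h_chunk), vocab_size].
        logits = [[_dot(h, w) for w in weight] for h in h_chunk]
        for row, t in zip(logits, t_chunk):
            if t == ignore_index:
                continue  # padded / masked positions do not contribute
            total += _logsumexp(row) - row[t]
            count += 1
        # `logits` is replaced on the next iteration; only the running sum is kept.
    return total / max(count, 1)


class LigerLinearCrossEntropyLoss:
    """Hypothetical loss class. In torchtune it would subclass the SFT loss and
    call Liger's fused_linear_cross_entropy in place of the reference above."""

    def __init__(self, ignore_index: int = -100, chunk_size: int = 1024):
        self.ignore_index = ignore_index
        self.chunk_size = chunk_size

    def __call__(self, hidden, weight, targets) -> float:
        hidden = _to_local(hidden)  # convert DTensor inputs before the fused call
        return chunked_linear_cross_entropy(
            hidden, weight, targets, self.ignore_index, self.chunk_size
        )
```

The result does not depend on `chunk_size`; chunking only changes how many logits exist at once, which is where the memory saving of a fused kernel comes from.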

Edge case: the model's output layer is a tied embedding that is TP-sharded (a DTensor). In that case we will either have to unshard and then reshard the weight every step, or throw an error. (This assumes that Liger losses don't work with sharded weights.)
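The two options for this edge case could be expressed as a small guard that runs before the loss. `resolve_tied_output_weight` is a hypothetical helper, not an existing torchtune or Liger function, and it is written in plain Python so it runs without a distributed setup. Like the issue, it assumes the fused kernel cannot take a sharded weight, and it detects sharding by duck typing on `full_tensor()`, which PyTorch's `DTensor` provides for gathering shards.

```python
def resolve_tied_output_weight(weight, unshard_every_step: bool = False):
    """Return an output weight the fused loss can consume, or fail loudly.

    Hypothetical helper for the tied-embedding case. A weight is treated as
    TP-sharded if it exposes full_tensor(), as PyTorch's DTensor does.
    """
    if not hasattr(weight, "full_tensor"):
        return weight  # already a full, unsharded weight: use it directly
    if not unshard_every_step:
        # Option 2 from the issue: refuse the unsupported configuration.
        raise ValueError(
            "The output layer is a tied, TP-sharded embedding and the fused "
            "loss needs an unsharded weight; pass unshard_every_step=True to "
            "gather it on every step instead."
        )
    # Option 1 from the issue: gather the full weight for this step. Its
    # gradient then has to be resharded back onto the embedding shards.
    return weight.full_tensor()
```

Defaulting to the error keeps the cost explicit: unsharding a tied embedding means gathering the full `[vocab_size, hidden_dim]` weight on every step, which is something a user would likely want to opt into deliberately.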

A good validation of this feature would be to check whether this loss improves further on the numbers here compared with the compiled linear cross-entropy loss.

Metadata


Assignees: No one assigned
Labels: community help wanted (We would love the community's help completing this issue)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
