Corrupted outputs with Marlin int4 kernels as parallelization increases #332

@dacorvo

When using MarlinInt4WeightQBitsTensor and its associated optimized gemm kernel, there are issues with the weight/scales/zero-point readback as soon as parallelization increases.

As a consequence, output features with an index above 128 are corrupted once a sufficient number of inputs are processed in parallel.

Test to reproduce the issue here:

@pytest.mark.xfail(reason="Bug in Marlin kernel", strict=False)
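To localize this kind of corruption, it can help to compare the kernel output against a reference output and report which output-feature indices diverge. The following is only a minimal sketch on synthetic data: `corrupted_features` and `first_corrupted_feature` are hypothetical helpers, not part of the library, and the perturbation at index 128 is injected by hand to mimic the symptom described above.

```python
from typing import List, Optional, Sequence


def corrupted_features(
    reference: Sequence[Sequence[float]],
    actual: Sequence[Sequence[float]],
    atol: float = 1e-3,
) -> List[int]:
    """Return sorted output-feature indices where any input row differs by more than atol."""
    bad = set()
    for ref_row, act_row in zip(reference, actual):
        for j, (r, a) in enumerate(zip(ref_row, act_row)):
            if abs(r - a) > atol:
                bad.add(j)
    return sorted(bad)


def first_corrupted_feature(
    reference: Sequence[Sequence[float]],
    actual: Sequence[Sequence[float]],
    atol: float = 1e-3,
) -> Optional[int]:
    """Return the lowest corrupted output-feature index, or None if outputs match."""
    bad = corrupted_features(reference, actual, atol)
    return bad[0] if bad else None


# Synthetic illustration only (assumption): 4 inputs, 256 output features,
# with features 128 and above deliberately perturbed.
reference = [[float(j) for j in range(256)] for _ in range(4)]
actual = [
    [v + (1.0 if j >= 128 else 0.0) for j, v in enumerate(row)]
    for row in reference
]

print(first_corrupted_feature(reference, actual))  # prints 128
```

In a real reproduction, `reference` would come from a non-optimized matmul on dequantized weights and `actual` from the Marlin kernel, with the number of input rows increased until a mismatch appears.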

Labels: Stale, bug, help wanted
