When using `MarlinInt4WeightQBitsTensor` and its associated optimized GEMM kernel, the weights, scales, and zero-points are read back incorrectly once parallelization increases.
As a consequence, output features with an index higher than 128 are corrupted when enough inputs are processed in parallel.
The following test reproduces the issue:
```python
@pytest.mark.xfail(reason="Bug in Marlin kernel", strict=False)
```
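When triaging this kind of corruption, it can help to report exactly which output features diverge from a reference result, not just a single pass/fail. The helper below is a hypothetical, framework-agnostic sketch: the name `corrupted_output_features` and the `atol` tolerance are assumptions and are not part of the library under test. In a real test, `reference` could come from a plain matmul with the dequantized weights and `actual` from the Marlin kernel, both converted to nested lists of shape `[batch][out_features]`.

```python
def corrupted_output_features(reference, actual, atol=1e-2):
    """Hypothetical helper: return the sorted indices of output features
    (columns) whose absolute error exceeds `atol` in at least one row.

    `reference` and `actual` are 2-D nested lists of shape [batch][out_features].
    """
    if len(reference) != len(actual):
        raise ValueError("batch sizes differ")
    n_features = len(reference[0])
    bad = set()
    for ref_row, act_row in zip(reference, actual):
        if len(ref_row) != n_features or len(act_row) != n_features:
            raise ValueError("output feature dimensions differ")
        # Flag every column where this row deviates beyond the tolerance.
        for j, (r, a) in enumerate(zip(ref_row, act_row)):
            if abs(r - a) > atol:
                bad.add(j)
    return sorted(bad)


# Synthetic data (not real kernel output): every column from 128 onward is off by 1.0.
batch, out_features = 4, 256
reference = [[float(i + j) for j in range(out_features)] for i in range(batch)]
actual = [[v + (1.0 if j >= 128 else 0.0) for j, v in enumerate(row)] for row in reference]
bad = corrupted_output_features(reference, actual)
print(bad[0], len(bad))  # 128 128
```

Printing the first corrupted index makes it easy to see whether the failure boundary sits exactly at a tile-sized offset such as 128, which points toward an indexing problem in the readback rather than a general precision issue.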