
Conversation


@plotfi plotfi commented Jul 22, 2025

Update: I have found that for better perf, we need to use 3-6 BF16 dot products but not more. My findings are at:

https://gist.github.com/plotfi/72554bd410ea55d8ae67b501c69b2766

The short version is that the Triton Bench tutorial matmul with F32 speeds up by 60-70% when using 3 BF16 dots, or by 10-15% when using 6 BF16 dots.

I think this is sufficient to move forward as a replacement for MI350's missing TF32, and it is in line with what hipBLASLt does:

https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330

There is a similar implementation in XLA as well: https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152


Implements emulation of a 32-bit floating-point dot operation using 3 BF16s. This is based on https://arxiv.org/abs/1904.06376 and works because the mantissas of 3 BF16 values add up to cover the mantissa of an fp32.

Storing 1 fp32 in 3 bf16s:

```python
import torch

def BF16(v):
    return v.to(torch.bfloat16)

def FP32(v):
    return v.to(torch.float32)

def BF16x3(v):
    b0 = BF16(v)
    b1 = BF16(v - FP32(b0))
    b2 = BF16(v - FP32(b0) - FP32(b1))
    return (b0, b1, b2)

original = torch.rand(1, 1, dtype=torch.float32)
bf16x3 = BF16x3(original)
```
Emulating multiplication of two fp32s (b0 holds the most significant bits, b2 the least):

```python
def mul_bf16x3(a, b, c):
    a0, a1, a2 = BF16x3(a)
    b0, b1, b2 = BF16x3(b)
    c = c + (a0 * b0)  # hi * hi
    c = c + (a1 * b0)  # mid * hi
    c = c + (a0 * b1)  # hi * mid
    c = c + (a1 * b1)  # mid * mid
    c = c + (a0 * b2)  # hi * low
    c = c + (a2 * b0)  # low * hi
    c = c + (a1 * b2)  # mid * low
    c = c + (a2 * b1)  # low * mid
    c = c + (a2 * b2)  # low * low
    return c

a = torch.rand(1, 1, dtype=torch.float32)
b = torch.rand(1, 1, dtype=torch.float32)
c = torch.zeros(1, 1, dtype=torch.float32)  # fp32 accumulator

result = mul_bf16x3(a, b, c)
```
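A quick way to convince yourself of the accuracy win, again in pure Python with a hypothetical truncating `to_bf16` stand-in for the torch conversion: the nine-term emulation lands orders of magnitude closer to the exact product than a single bf16 multiply does.

```python
import struct

def to_bf16(x: float) -> float:
    # Truncate to bfloat16 by zeroing the low 16 bits of the
    # float32 bit pattern (round-toward-zero stand-in for torch).
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    return struct.unpack('<f', struct.pack('<I', bits & 0xFFFF0000))[0]

def bf16x3(v):
    b0 = to_bf16(v)
    b1 = to_bf16(v - b0)
    b2 = to_bf16(v - b0 - b1)
    return b0, b1, b2

def mul_bf16x3(a, b):
    a0, a1, a2 = bf16x3(a)
    b0, b1, b2 = bf16x3(b)
    # Accumulate the nine partial products, smallest first, so the
    # low-order contributions are not swamped by the large terms.
    terms = [a2*b2, a2*b1, a1*b2, a2*b0, a0*b2, a1*b1, a1*b0, a0*b1, a0*b0]
    return sum(terms)

a, b = 0.123456789, 0.987654321
err_bf16x3 = abs(mul_bf16x3(a, b) - a * b)
err_bf16 = abs(to_bf16(a) * to_bf16(b) - a * b)
print(err_bf16x3, err_bf16)  # the 9-dot emulation is far more accurate
```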

The BF16x3 emulation is used when invoking tl.dot with input precision 'BF16x3'. The pass is implemented in a GPU-agnostic manner, but it is needed to support MI350, which lacks TF32. That part is a work in progress and will be based on this patch.

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /python/test for end-to-end tests

@plotfi plotfi marked this pull request as draft July 22, 2025 08:43
@plotfi plotfi force-pushed the plotfi-bf16x3-dot branch 4 times, most recently from f34500a to 591d437 Compare July 22, 2025 19:55

scxiao commented Aug 16, 2025

On MI300, the XF32 instructions take 32-bit floats but round the mantissa to 10 bits in order to perform reduced-precision multiplication, so I think we only need b0 and b1 (enough to cover 10 bits of mantissa), meaning 4 instead of 9 BF16 dots are needed.
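For illustration, a minimal pure-Python sketch of this 4-dot variant (`to_bf16` and `split2` are hypothetical helpers; truncation stands in for the actual bf16/XF32 rounding): dropping every term involving the b2 chunk still leaves well over the ~10 mantissa bits XF32 provides.

```python
import struct

def to_bf16(x: float) -> float:
    # Truncating bfloat16 conversion (illustrative stand-in).
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    return struct.unpack('<f', struct.pack('<I', bits & 0xFFFF0000))[0]

def split2(v):
    # Keep only the two most significant bf16 chunks.
    b0 = to_bf16(v)
    b1 = to_bf16(v - b0)
    return b0, b1

def mul_bf16x2(a, b):
    # 4-dot variant: all terms involving the low-order b2 chunk
    # are dropped, leaving four partial products.
    a0, a1 = split2(a)
    b0, b1 = split2(b)
    return a1 * b1 + a1 * b0 + a0 * b1 + a0 * b0

a, b = 0.123456789, 0.987654321
rel_err = abs(mul_bf16x2(a, b) - a * b) / abs(a * b)
print(rel_err)  # comfortably below XF32's ~10-bit mantissa budget
```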

@plotfi plotfi force-pushed the plotfi-bf16x3-dot branch 6 times, most recently from 8622a54 to 8132a2e Compare September 26, 2025 23:05
@plotfi plotfi force-pushed the plotfi-bf16x3-dot branch from 8132a2e to 056783a Compare October 2, 2025 00:22
@plotfi plotfi marked this pull request as ready for review October 2, 2025 01:27
@plotfi plotfi force-pushed the plotfi-bf16x3-dot branch from a6a28f2 to ee9c90e Compare October 10, 2025 16:53
@plotfi plotfi requested a review from antiagainst October 10, 2025 16:53
@plotfi plotfi force-pushed the plotfi-bf16x3-dot branch from ee9c90e to 993098f Compare October 10, 2025 16:56

plotfi commented Oct 10, 2025

Thanks for the review @antiagainst, all feedback addressed now.

@plotfi plotfi force-pushed the plotfi-bf16x3-dot branch 5 times, most recently from 0cd4704 to 4be3f28 Compare October 16, 2025 22:44
@antiagainst antiagainst left a comment

LGTM; need @ThomasRaoux to take a look too.

@plotfi plotfi force-pushed the plotfi-bf16x3-dot branch from 7676d22 to 840f482 Compare October 18, 2025 03:59
@plotfi plotfi force-pushed the plotfi-bf16x3-dot branch from 35b0f99 to 9228dac Compare October 18, 2025 04:31
@ThomasRaoux ThomasRaoux left a comment
few more minor comments

c.getLoc(), rewriter.getF32FloatAttr(0)));
};

template <InputPrecision precision>

doesn't need to be a template arg; make it a function argument

maxNumImpreciseAcc);
};

auto replaceNansWithZeros(Value value, PatternRewriter &rewriter) -> Value {

nit:

Suggested change
auto replaceNansWithZeros(Value value, PatternRewriter &rewriter) -> Value {
Value replaceNansWithZeros(Value value, PatternRewriter &rewriter) {

return rewriter.create<arith::SelectOp>(value.getLoc(), nans, zero, value);
};

auto getBF16Count(triton::InputPrecision precision) -> unsigned {

nit:

Suggested change
auto getBF16Count(triton::InputPrecision precision) -> unsigned {
unsigned getBF16Count(triton::InputPrecision precision) {


plotfi commented Oct 18, 2025

Feedback addressed, thanks again @ThomasRaoux

@ThomasRaoux ThomasRaoux left a comment

LGTM

@antiagainst antiagainst enabled auto-merge (squash) October 19, 2025 03:48
@antiagainst antiagainst merged commit 33e7dc2 into triton-lang:main Oct 19, 2025
9 checks passed
masahi pushed a commit to masahi/triton that referenced this pull request Oct 24, 2025
plotfi added a commit to plotfi/triton-facebookexperimental that referenced this pull request Nov 14, 2025
Summary:
This is a cherry-pick of triton-lang/triton#7592

This is the D86786661 pick, but for beta

[BACKEND] Implement BF16x3 trick (#7592)


Reviewed By: danzimm, NikhilAPatel

Differential Revision: D86786661
meta-codesync bot pushed a commit to facebookexperimental/triton that referenced this pull request Nov 15, 2025
Summary:
Pull Request resolved: #667

This is a cherry-pick of triton-lang/triton#7592

This is the D86786661 pick, but for beta

[BACKEND] Implement BF16x3 trick (#7592)


Reviewed By: agron911, danzimm, NikhilAPatel, jananisriram, Sibylau

Differential Revision: D86786661

fbshipit-source-id: e254b1bdf61a400e1b55f9dc9d35a541460b1571

4 participants