[BACKEND] Implement BF16x3 trick #7592
Conversation
On MI300, TF32 is
Thanks for the review @antiagainst, all feedback addressed now.
antiagainst left a comment:
LGTM; need @ThomasRaoux to take a look too.
ThomasRaoux left a comment:
A few more minor comments.
> template <InputPrecision precision>

This doesn't need to be a template arg; make it a function argument.
nit: suggested change

- auto replaceNansWithZeros(Value value, PatternRewriter &rewriter) -> Value {
+ Value replaceNansWithZeros(Value value, PatternRewriter &rewriter) {
nit: suggested change

- auto getBF16Count(triton::InputPrecision precision) -> unsigned {
+ unsigned getBF16Count(triton::InputPrecision precision) {
Feedback addressed, thanks again @ThomasRaoux
ThomasRaoux left a comment:
LGTM
**Update:** I have found that for better perf we need to use 3-6 BF16 dot products, but not more. My findings are at: https://gist.github.com/plotfi/72554bd410ea55d8ae67b501c69b2766

The short version is that the Triton Bench tutorial matmul with F32 benefits by 60-70% using 3 BF16 dots, or 10-15% using 6 BF16 dots. I think this is sufficient to move forward as a replacement for MI350's TF32, and it is in line with what hipblas does: https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipblaslt/tensilelite/Tensile/Components/LocalRead.py#L288-L330

There is a similar implementation in XLA as well: https://github.com/openxla/xla/blob/e33f93fb7220d408811afdc926cf10baaf49c64e/xla/backends/gpu/codegen/triton/dot_algorithms.cc#L152

--------

Implements emulation of a 32-bit floating-point dot operation using 3 BF16s. This is based on https://arxiv.org/abs/1904.06376 and works because the mantissas of 3 BF16s add up to the mantissa of an fp32.

Storing 1 fp32 in 3 bf16s:

```python
import torch

def BF16(v):
    return v.to(torch.bfloat16)

def FP32(v):
    return v.to(torch.float32)

def BF16x3(v):
    # Each limb holds the residual left after rounding away the previous limbs.
    b0 = BF16(v)
    b1 = BF16(v - FP32(b0))
    b2 = BF16(v - FP32(b0) - FP32(b1))
    return (b0, b1, b2)

original = torch.rand(1, 1, dtype=torch.float32)
bf16x3 = BF16x3(original)
```

Emulating multiplication of two fp32s:

```python
def mul_bf16x3(a, b, c):
    a0, a1, a2 = BF16x3(a)
    b0, b1, b2 = BF16x3(b)
    c = c + (a0 * b0)  # low low
    c = c + (a1 * b0)  # mid low
    c = c + (a0 * b1)  # low mid
    c = c + (a1 * b1)  # mid mid
    c = c + (a0 * b2)  # low hi
    c = c + (a2 * b0)  # hi low
    c = c + (a1 * b2)  # mid hi
    c = c + (a2 * b1)  # hi mid
    c = c + (a2 * b2)  # hi hi
    return c

a = torch.rand(1, 1, dtype=torch.float32)
b = torch.rand(1, 1, dtype=torch.float32)
c = torch.zeros(1, 1, dtype=torch.float32)  # accumulator
result = mul_bf16x3(a, b, c)
```

The emulation using BF16x3 is used when invoking tl.dot with input precision 'BF16x3'. This pass is implemented in a GPU-agnostic manner, but it is needed to work around MI350's lack of TF32 support. That part is a work in progress and will be based on this patch.
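For orientation, here is a minimal sketch of how a kernel might opt into the emulated path through `tl.dot`'s `input_precision` argument. The precision string `"bf16x3"`, the kernel name, and the launch parameters are assumptions for illustration; the patch only states that the emulation is selected via tl.dot's input precision.

```python
# Hypothetical usage sketch: a plain tiled matmul that asks tl.dot for the
# BF16x3 emulation. The "bf16x3" string and the kernel/parameter names are
# assumptions for illustration, not the patch's verified API surface.
import torch
import triton
import triton.language as tl


@triton.jit
def matmul_bf16x3_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                         BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        # Operands stay fp32 at the Triton level; the backend splits them into
        # BF16 limbs and issues several BF16 dots accumulated in fp32.
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        acc = tl.dot(a, b, acc, input_precision="bf16x3")
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)


# Shapes are multiples of the block sizes to keep the sketch mask-free.
M, N, K = 512, 512, 512
a = torch.rand(M, K, dtype=torch.float32, device="cuda")
b = torch.rand(K, N, dtype=torch.float32, device="cuda")
c = torch.empty(M, N, dtype=torch.float32, device="cuda")
grid = (M // 128, N // 128)
matmul_bf16x3_kernel[grid](a, b, c, M, N, K, BLOCK_M=128, BLOCK_N=128, BLOCK_K=32)
```

The kernel itself does not change beyond the precision flag; the decomposition into BF16 limbs happens in the backend rewrite.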
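To make the accuracy and perf trade-off from the update concrete, here is a small PyTorch-only sketch that compares plain BF16, and BF16x3 emulation with 3, 6, or all 9 limb products, against an fp32 reference. Which products the 3-dot and 6-dot variants keep is an assumption here (ordered by the magnitude of each term), not necessarily the patch's exact choice.

```python
# Numerical sanity sketch, independent of the patch: relative error of plain
# BF16 vs. BF16x3 emulation with a varying number of limb products.
import torch


def split_bf16x3(v):
    b0 = v.to(torch.bfloat16)
    b1 = (v - b0.to(torch.float32)).to(torch.bfloat16)
    b2 = (v - b0.to(torch.float32) - b1.to(torch.float32)).to(torch.bfloat16)
    return b0, b1, b2


def matmul_bf16x3(a, b, num_dots=9):
    a0, a1, a2 = (x.to(torch.float32) for x in split_bf16x3(a))
    b0, b1, b2 = (x.to(torch.float32) for x in split_bf16x3(b))
    # Terms ordered roughly by magnitude: the first is O(1), the next two are
    # O(2^-8), the next three O(2^-16), and the last three are usually below
    # the fp32 rounding of the leading term.
    terms = [(a0, b0),
             (a1, b0), (a0, b1),
             (a1, b1), (a0, b2), (a2, b0),
             (a1, b2), (a2, b1), (a2, b2)]
    c = torch.zeros(a.shape[0], b.shape[1], dtype=torch.float32)
    for ai, bi in terms[:num_dots]:
        c = c + ai @ bi
    return c


torch.manual_seed(0)
a = torch.rand(256, 256, dtype=torch.float32)
b = torch.rand(256, 256, dtype=torch.float32)
ref = a @ b


def rel_err(x):
    return ((x - ref).abs().max() / ref.abs().max()).item()


bf16_only = (a.to(torch.bfloat16).to(torch.float32)
             @ b.to(torch.bfloat16).to(torch.float32))
print("bf16 only :", rel_err(bf16_only))
print("3 dots    :", rel_err(matmul_bf16x3(a, b, 3)))
print("6 dots    :", rel_err(matmul_bf16x3(a, b, 6)))
print("9 dots    :", rel_err(matmul_bf16x3(a, b, 9)))
```

With 3 dots the error should already drop well below plain BF16 while staying short of full fp32, which is consistent with the gist's observation that 3-6 dots are the practical sweet spot.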
New contributor declaration
I am not making a trivial change, such as fixing a typo in a comment.
I have written a PR description following these rules.
I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
Select one of the following:
- `/test` for lit tests
- `/python/test` for end-to-end tests