Commit 8132a2e

[BACKEND] Implement BF16x3 trick
Implements emulation of a 32-bit floating point dot operation using 3 BF16s. This is based on https://arxiv.org/abs/1904.06376 and works because the mantissas of 3 BF16s add up to the mantissa of an fp32.

Storing 1 fp32 in 3 bf16s:

```
def BF16(v):
    return v.to(torch.bfloat16)

def FP32(v):
    return v.to(torch.float32)

def BF16x3(v):
    # b0 holds the high-order mantissa bits, b1 the middle, b2 the low-order bits.
    b0 = BF16(v)
    b1 = BF16(v - FP32(b0))
    b2 = BF16(v - FP32(b0) - FP32(b1))
    return (b0, b1, b2)

original = torch.rand(1, 1, dtype=torch.float32)
bf16x3 = BF16x3(original)
```

Emulating multiplication of two fp32s:

```
def mul_bf16x3(a, b, c):
    a0, a1, a2 = BF16x3(a)  # a0 = hi, a1 = mid, a2 = lo
    b0, b1, b2 = BF16x3(b)
    c = c + (a0 * b0)  # hi  * hi
    c = c + (a1 * b0)  # mid * hi
    c = c + (a0 * b1)  # hi  * mid
    c = c + (a1 * b1)  # mid * mid
    c = c + (a0 * b2)  # hi  * lo
    c = c + (a2 * b0)  # lo  * hi
    c = c + (a1 * b2)  # mid * lo
    c = c + (a2 * b1)  # lo  * mid
    c = c + (a2 * b2)  # lo  * lo
    return c

a = torch.rand(1, 1, dtype=torch.float32)
b = torch.rand(1, 1, dtype=torch.float32)
c = torch.zeros(1, 1, dtype=torch.float32)  # accumulator
result = mul_bf16x3(a, b, c)
```

The BF16x3 emulation is used when tl.dot is invoked with input_precision='bf16x3'. The pass is implemented in a GPU-agnostic manner, but it is needed to support MI350, which lacks TF32 support. That integration is a work in progress and will be based on this patch.
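Not part of the original commit message: as a quick sanity check, the decomposition and the partial-product accumulation can be verified against plain fp32 arithmetic. The snippet reuses the BF16, FP32, and BF16x3 helpers defined above; accumulating the nine partial products in fp32 mimics a hardware BF16 dot with an fp32 accumulator.

```
import torch

# Assumes the BF16(), FP32() and BF16x3() helpers from the commit message above.
torch.manual_seed(0)
x = torch.rand(4, 4, dtype=torch.float32)
b0, b1, b2 = BF16x3(x)
# The three BF16 parts should sum back to x to (roughly) fp32 accuracy.
print((FP32(b0) + FP32(b1) + FP32(b2) - x).abs().max())

a = torch.rand(4, 4, dtype=torch.float32)
b = torch.rand(4, 4, dtype=torch.float32)
a0, a1, a2 = BF16x3(a)
b0, b1, b2 = BF16x3(b)
# Accumulate all nine partial products in fp32, mimicking a BF16 dot with an
# fp32 accumulator; the result should track the fp32 product a * b closely.
acc = torch.zeros_like(a)
for ai in (a0, a1, a2):
    for bj in (b0, b1, b2):
        acc = acc + FP32(ai) * FP32(bj)
print((acc - a * b).abs().max())
```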
1 parent a7dab71 commit 8132a2e

File tree

12 files changed: +391 −5 lines changed


03-matrix-multiplication.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@

```
import torch

import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def get_cuda_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
    ]


def get_hip_autotune_config():
    sizes = [
        {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
        {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
        {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
        {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
    ]
    return [triton.Config(s | {'matrix_instr_nonkdim': 16}, num_warps=8, num_stages=2) for s in sizes]


def get_autotune_config():
    if is_cuda():
        return get_cuda_autotune_config()
    else:
        return get_hip_autotune_config()


@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    PRECISION: tl.constexpr
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator += tl.dot(a, b, input_precision=PRECISION)
        # accumulator = tl.dot(a, b, accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def matmul(a, b, precision="ieee"):
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float32)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        PRECISION=precision
    )
    return c


precisions = ["ieee", "bf16", "bf16x3", "bf16x6", "bf16x9"]
torch.manual_seed(0)

for precision in precisions:
    a = torch.rand((512, 512), device=DEVICE, dtype=torch.float32) - 0.5
    b = torch.rand((512, 512), device=DEVICE, dtype=torch.float32) - 0.5
    triton_output = matmul(a, b, precision=precision)
    torch_output = torch.matmul(a, b)
    #print(f"triton_output_with_fp32_inputs={triton_output}")
    #print(f"torch_output_with_fp32_inputs={torch_output}")

    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):
        print(f'✅ Triton and Torch match for input_precision={precision}')
    else:
        print(f'❌ Triton and Torch differ for input_precision={precision}')

ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS'

configs = []
configs.append(
    triton.testing.Benchmark(
        x_names=["M", "N", "K"],
        x_vals=[128 * i for i in range(2, 33)],
        line_arg="provider",
        line_vals=[ref_lib.lower(), "triton-ieee", "triton-bf16", "triton-bf16x3", "triton-bf16x6", "triton-bf16x9"],
        line_names=[ref_lib, "Triton-IEEE", "Triton-BF16", "Triton-BF16x3", "Triton-BF16x6", "Triton-BF16x9"],
        styles=[("green", "-"), ("blue", "-"), ("blue", "-"), ("blue", "-"), ("blue", "-"), ("blue", "-")],
        ylabel="TFLOPS",
        plot_name="matmul-performance-f32",
        args={},
    ))


@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device=DEVICE, dtype=torch.float32)
    b = torch.randn((K, N), device=DEVICE, dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    if provider == ref_lib.lower():
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    if provider.startswith('triton-'):
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, provider.removeprefix('triton-')), quantiles=quantiles)
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


benchmark.run(show_plots=False, print_data=True)
```
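A possible extension of the tutorial above (not part of the patch): reporting the maximum absolute error of each precision mode against torch.matmul, rather than only the allclose verdict. The sketch reuses the `matmul` helper and `DEVICE` defined in the tutorial.

```
# Hypothetical accuracy probe, reusing matmul() and DEVICE from the tutorial above.
torch.manual_seed(0)
a = torch.rand((512, 512), device=DEVICE, dtype=torch.float32) - 0.5
b = torch.rand((512, 512), device=DEVICE, dtype=torch.float32) - 0.5
ref = torch.matmul(a, b)
for precision in ["ieee", "bf16", "bf16x3", "bf16x6", "bf16x9"]:
    err = (matmul(a, b, precision=precision) - ref).abs().max().item()
    print(f"{precision:>7}: max abs error = {err:.3e}")
```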

include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 5 additions & 1 deletion
```
@@ -129,7 +129,11 @@ def TT_InputPrecisionAttr : I32EnumAttr<
   [
     I32EnumAttrCase<"TF32", 0, "tf32">,
     I32EnumAttrCase<"TF32x3", 1, "tf32x3">,
-    I32EnumAttrCase<"IEEE", 2, "ieee">
+    I32EnumAttrCase<"IEEE", 2, "ieee">,
+    I32EnumAttrCase<"BF16", 3, "bf16">,
+    I32EnumAttrCase<"BF16x3", 4, "bf16x3">,
+    I32EnumAttrCase<"BF16x6", 5, "bf16x6">,
+    I32EnumAttrCase<"BF16x9", 6, "bf16x9">
   ]>{
   let cppNamespace = "::mlir::triton";
 }
```

include/triton/Dialect/TritonGPU/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
```
@@ -188,6 +188,17 @@ def TritonGPUF32DotTC : Pass<"tritongpu-F32DotTC", "mlir::ModuleOp"> {
                             "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
 }
 
+def TritonGPUBF16DotTC : Pass<"tritongpu-BF16DotTC", "mlir::ModuleOp"> {
+  let summary = "3xBF16 trick";
+
+  let description = [{
+    Decompose fp32 `DotOp` instructions into BF16 operations.
+    See https://arxiv.org/abs/1904.06376
+  }];
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
+}
+
 def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
   let summary = "prefetch";
```

lib/Dialect/TritonGPU/Transforms/BF16DotTC.cpp

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@

```
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "llvm/Support/raw_ostream.h"

namespace mlir {
namespace triton {
namespace gpu {

#define GEN_PASS_DEF_TRITONGPUBF16DOTTC
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace {

// Implement 3xBF16 https://arxiv.org/abs/1904.06376
class BF16x3 : public OpRewritePattern<DotOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(DotOp dotOp,
                                PatternRewriter &rewriter) const override {
    switch (dotOp.getInputPrecision()) {
    case InputPrecision::BF16:
    case InputPrecision::BF16x3:
    case InputPrecision::BF16x6:
    case InputPrecision::BF16x9:
      break;
    default:
      return failure();
    }

    auto isF32 = [](Value operand) {
      return cast<RankedTensorType>(operand.getType()).getElementType().isF32();
    };
    if (!isF32(dotOp.getA()) || !isF32(dotOp.getB())) {
      return failure();
    }

    // Aux functions
    auto f32ToBF16 = [&](Value value) -> Value {
      auto fp32Type = cast<RankedTensorType>(value.getType());
      auto bf16Type = RankedTensorType::get(
          fp32Type.getShape(), rewriter.getBF16Type(), fp32Type.getEncoding());
      return rewriter.create<arith::TruncFOp>(dotOp.getLoc(), bf16Type, value)
          .getResult();
    };
    auto bf16ToF32 = [&](Value value) -> Value {
      auto bf16Type = cast<RankedTensorType>(value.getType());
      auto fp32Type = RankedTensorType::get(
          bf16Type.getShape(), rewriter.getF32Type(), bf16Type.getEncoding());
      return rewriter.create<arith::ExtFOp>(dotOp.getLoc(), fp32Type, value)
          .getResult();
    };
    auto zeroLike = [&](Value c) -> Value {
      return rewriter.create<SplatOp>(
          dotOp->getLoc(), c.getType(),
          rewriter.create<arith::ConstantOp>(dotOp->getLoc(),
                                             rewriter.getF32FloatAttr(0)));
    };
    auto add = [&](Value a, Value b) -> Value {
      return rewriter.create<arith::AddFOp>(dotOp.getLoc(), a, b);
    };
    auto sub = [&](Value a, Value b) -> Value {
      return rewriter.create<arith::SubFOp>(dotOp.getLoc(), a, b);
    };
    auto dot = [&](Value a, Value b, Value c) -> Value {
      return rewriter.create<DotOp>(dotOp->getLoc(), c.getType(), a, b, c,
                                    InputPrecision::BF16,
                                    dotOp.getMaxNumImpreciseAcc());
    };
    auto replaceNansWithZeros = [&](Value value) -> Value {
      auto nans = rewriter.create<arith::CmpFOp>(
          dotOp->getLoc(), arith::CmpFPredicate::UNO, value, value);
      auto zero = zeroLike(value);
      return rewriter.create<arith::SelectOp>(dotOp->getLoc(), nans, zero,
                                              value);
    };

    auto SplitF32 = [&](Value input, unsigned N) -> std::vector<Value> {
      std::vector<Value> split_inputs;
      split_inputs.reserve(N);
      for (int i = 0; i < N; ++i) {
        Value input_as_bf16 = f32ToBF16(input);
        if (i != N - 1) {
          Value input_as_f32 = bf16ToF32(input_as_bf16);
          input = rewriter.create<arith::SubFOp>(dotOp->getLoc(), input,
                                                 input_as_f32);
        }
        split_inputs.push_back(input_as_bf16);
      }
      return split_inputs;
    };

    const int hi = 0;
    const int med = 1;
    const int lo = 2;

    const unsigned N = 3;
    auto lhs_parts = SplitF32(dotOp.getA(), N);
    auto rhs_parts = SplitF32(dotOp.getB(), N);

    auto result = zeroLike(dotOp.getC());

    if (dotOp.getInputPrecision() == InputPrecision::BF16x9) {
      result = dot(lhs_parts[lo], rhs_parts[lo], result);
      result = dot(lhs_parts[med], rhs_parts[lo], result);
      result = dot(lhs_parts[lo], rhs_parts[med], result);

      result = dot(lhs_parts[med], rhs_parts[med], result);

      result = dot(lhs_parts[lo], rhs_parts[hi], result);
      result = dot(lhs_parts[hi], rhs_parts[lo], result);

      result = dot(lhs_parts[med], rhs_parts[hi], result);
      result = dot(lhs_parts[hi], rhs_parts[med], result);

    } else if (dotOp.getInputPrecision() == InputPrecision::BF16x6) {
      result = dot(lhs_parts[med], rhs_parts[med], result);

      result = dot(lhs_parts[lo], rhs_parts[hi], result);
      result = dot(lhs_parts[hi], rhs_parts[lo], result);

      result = dot(lhs_parts[med], rhs_parts[hi], result);
      result = dot(lhs_parts[hi], rhs_parts[med], result);

    } else if (dotOp.getInputPrecision() == InputPrecision::BF16x3) {
      result = dot(lhs_parts[med], rhs_parts[hi], result);
      result = dot(lhs_parts[hi], rhs_parts[med], result);
    }

    result = replaceNansWithZeros(result);
    result = dot(lhs_parts[hi], rhs_parts[hi], result);
    result = add(result, dotOp.getC());

    rewriter.replaceOp(dotOp, result);
    return success();
  }
};

} // anonymous namespace

struct BF16DotTCPass : public impl::TritonGPUBF16DotTCBase<BF16DotTCPass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();

    RewritePatternSet decomposePatterns(context);
    decomposePatterns.add<BF16x3>(context);
    if (applyPatternsGreedily(m, std::move(decomposePatterns)).failed()) {
      signalPassFailure();
    }
  }
};

} // namespace gpu
} // namespace triton
} // namespace mlir
```
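For orientation only, here is a rough Python paraphrase of the arithmetic emitted by the pass above. The names `split_f32`, `dot_bf16`, and `emulated_dot` are hypothetical, and each emitted BF16 DotOp is modeled as an fp32 matmul of up-cast bf16 operands; the index convention mirrors the pass (hi=0, med=1, lo=2), with the hi·hi term always accumulated last after NaNs in the partial sum are zeroed.

```
import torch

def split_f32(x, n=3):
    # Mirrors SplitF32: peel off n bf16 parts, most significant first.
    parts = []
    for i in range(n):
        p = x.to(torch.bfloat16)
        if i != n - 1:
            x = x - p.to(torch.float32)
        parts.append(p)
    return parts  # [hi, med, lo]

def dot_bf16(a, b, c):
    # Stand-in for an emitted BF16 DotOp: bf16 inputs, fp32 accumulation.
    return c + a.to(torch.float32) @ b.to(torch.float32)

def emulated_dot(a, b, c, precision="bf16x3"):
    hi, med, lo = 0, 1, 2
    A, B = split_f32(a), split_f32(b)
    acc = torch.zeros_like(c)
    terms = {
        "bf16": [],
        "bf16x3": [(med, hi), (hi, med)],
        "bf16x6": [(med, med), (lo, hi), (hi, lo), (med, hi), (hi, med)],
        "bf16x9": [(lo, lo), (med, lo), (lo, med), (med, med),
                   (lo, hi), (hi, lo), (med, hi), (hi, med)],
    }[precision]
    for i, j in terms:
        acc = dot_bf16(A[i], B[j], acc)
    acc = torch.nan_to_num(acc, nan=0.0)  # replaceNansWithZeros
    acc = dot_bf16(A[hi], B[hi], acc)     # hi * hi term is always last
    return acc + c                        # add the original accumulator

# Example: compare the bf16x3 emulation against a plain fp32 matmul.
a = torch.rand(64, 64) - 0.5
b = torch.rand(64, 64) - 0.5
c = torch.zeros(64, 64)
print((emulated_dot(a, b, c, "bf16x3") - a @ b).abs().max())
```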

lib/Dialect/TritonGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```
@@ -2,6 +2,7 @@ add_triton_library(TritonGPUTransforms
   AccelerateMatmul.cpp
   Coalesce.cpp
   F32DotTC.cpp
+  BF16DotTC.cpp
   FuseNestedLoops.cpp
   CombineTensorSelectAndIf.cpp
   DecomposeScaledBlocked.cpp
```

python/src/ir.cc

Lines changed: 4 additions & 0 deletions
```
@@ -308,6 +308,10 @@ void init_triton_ir(py::module &&m) {
       .value("TF32", InputPrecision::TF32)
       .value("TF32x3", InputPrecision::TF32x3)
       .value("IEEE", InputPrecision::IEEE)
+      .value("BF16", InputPrecision::BF16)
+      .value("BF16x3", InputPrecision::BF16x3)
+      .value("BF16x6", InputPrecision::BF16x6)
+      .value("BF16x9", InputPrecision::BF16x9)
       .export_values();
 
   py::enum_<ScaleDotElemType>(m, "ScaleDotElemTypeTY", py::module_local())
```

python/src/passes.cc

Lines changed: 1 addition & 0 deletions
```
@@ -71,6 +71,7 @@ void init_triton_passes_ttgpuir(py::module &&m) {
   ADD_PASS_WRAPPER_0("add_accelerate_matmul", createTritonGPUAccelerateMatmul);
   ADD_PASS_WRAPPER_0("add_reorder_instructions",
                      createTritonGPUReorderInstructions);
+  ADD_PASS_WRAPPER_0("add_bf16_dot_tc", createTritonGPUBF16DotTC);
   ADD_PASS_WRAPPER_0("add_f32_dot_tc", createTritonGPUF32DotTC);
   ADD_PASS_OPTION_WRAPPER_1("add_optimize_dot_operands",
                             createTritonGPUOptimizeDotOperands, bool);
```
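The commit notes that the MI350 integration is still work in progress; purely as an illustration of where the new binding could sit, a backend's TTGIR pipeline stage might call it alongside the existing wrappers. The surrounding pipeline code below is a hypothetical sketch, not part of this patch; only the `add_bf16_dot_tc` binding itself comes from the diff above.

```
from triton._C.libtriton import ir, passes

def make_ttgir_sketch(mod):
    # Hypothetical placement; the real MI350 wiring is not in this patch.
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    # ... other TritonGPU passes ...
    passes.ttgpuir.add_bf16_dot_tc(pm)  # decompose fp32 dots into BF16 dots
    pm.run(mod)
    return mod
```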

python/test/unit/language/test_core.py

Lines changed: 3 additions & 2 deletions
```
@@ -3083,7 +3083,7 @@ def get_test_dot_base_cases():
     return [(*shape, 4, False, False, epilogue, input_precision, in_dtype, out_dtype, 1, None)
             for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)]
             for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
-            for input_precision in ['tf32', 'tf32x3', 'ieee']
+            for input_precision in ['tf32', 'tf32x3', 'ieee', 'bf16', 'bf16x3']
             for in_dtype, out_dtype in [('float16', 'float16'), ('float16',
                                                                  'float32'), ('float32',
                                                                               'float32'), ('float64', 'float64')]
@@ -3237,7 +3237,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
         pytest.skip(f"{in_dtype} only supported on CDNA4 and gfx12")
     if in_dtype in ("float8e5b16", "float8e4b8") and not is_hip_cdna3():
         pytest.skip(f"{in_dtype} only supported on CDNA3")
-    if not ((input_precision == "ieee") or (input_precision == "tf32" and is_hip_cdna3())):
+    if not ((input_precision == "bf16x3") or (input_precision == "ieee") or
+            (input_precision == "tf32" and is_hip_cdna3())):
         pytest.skip(f"{input_precision} not supported on HIP")
     if kpack == 2 and in_dtype == 'int8' and K < 64:
         pytest.skip("kpack too large for K")
```
