Commit 9bcb9b7

abadams authored and mcourteaux committed
Reschedule the matrix multiply performance app (#8418)
Someone was using this as a reference expert schedule, but it was stale and a bit simplistic for large matrices. I rescheduled it to get a better fraction of peak. This also now demonstrates how to use rfactor to block an sgemm over the k axis.
1 parent a8966e9
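
The rfactor directive mentioned above is Halide's way of splitting an associative reduction into independent partial reductions that can be vectorized, unrolled, or parallelized on their own, with a final stage that combines the partial results. Below is a minimal sketch of the blocking pattern (illustrative names and sizes, not the committed schedule, which appears in full in the diff further down):

    #include "Halide.h"
    using namespace Halide;

    int main() {
        const int size = 1024;  // illustrative matrix size
        ImageParam A(Float(32), 2), B(Float(32), 2);
        Var x("x"), y("y"), z("z");
        RDom k(0, size);

        // An sgemm reduction; += gives matrix_mul an implicit zero init.
        Func matrix_mul("matrix_mul");
        matrix_mul(x, y) += A(k, y) * B(x, k);

        // Block the reduction over k...
        RVar ko("ko"), ki("ki");
        matrix_mul.update().split(k, ko, ki, 64);

        // ...then rfactor turns each ko block into an independent partial
        // sum, indexed by the new pure Var z. matrix_mul's update becomes
        // a short reduction over z that combines the partial sums.
        Func partial = matrix_mul.update().rfactor(ko, z);

        // The partial sums can now be scheduled like any other Func, e.g.:
        partial.update().vectorize(x, 8);
        return 0;
    }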

1 file changed

test/performance/matrix_multiplication.cpp

Lines changed: 54 additions & 20 deletions
@@ -30,44 +30,78 @@ int main(int argc, char **argv) {
     ImageParam A(type_of<float>(), 2);
     ImageParam B(type_of<float>(), 2);
 
-    Var x("x"), xi("xi"), xo("xo"), y("y"), yo("yo"), yi("yi"), yii("yii"), xii("xii");
-    Func matrix_mul("matrix_mul");
-
+    Var x("x"), y("y");
     RDom k(0, matrix_size);
-    RVar ki;
+
+    Func matrix_mul("matrix_mul");
 
     matrix_mul(x, y) += A(k, y) * B(x, k);
 
     Func out;
     out(x, y) = matrix_mul(x, y);
 
-    Var xy;
+    // Now the schedule. Single-threaded, it hits 155 GFlops on Skylake-X
+    // i9-9960x with AVX-512 (80% of peak), and 87 GFlops with AVX2 (90% of
+    // peak).
+    //
+    // Using 16 threads (and no hyperthreading), hits 2080 GFlops (67% of peak)
+    // and 1310 GFlops (85% of peak) respectively.
 
-    out.tile(x, y, xi, yi, 24, 32)
-        .fuse(x, y, xy)
-        .parallel(xy)
-        .split(yi, yi, yii, 4)
-        .vectorize(xi, 8)
+    const int vec = target.natural_vector_size<float>();
+
+    // Size the inner loop tiles to fit into the number of registers available
+    // on the target, using either 12 accumulator registers or 24.
+    const int inner_tile_x = 3 * vec;
+    const int inner_tile_y = (target.has_feature(Target::AVX512) || target.arch != Target::X86) ? 8 : 4;
+
+    // The shape of the outer tiling
+    const int tile_y = matrix_size / 4;
+    const int tile_k = matrix_size / 16;
+
+    Var xy("xy"), xi("xi"), yi("yi"), yii("yii");
+
+    out.tile(x, y, xi, yi, inner_tile_x, tile_y)
+        .split(yi, yi, yii, inner_tile_y)
+        .vectorize(xi, vec)
         .unroll(xi)
-        .unroll(yii);
+        .unroll(yii)
+        .fuse(x, y, xy)
+        .parallel(xy);
+
+    RVar ko("ko"), ki("ki");
+    Var z("z");
+    matrix_mul.update().split(k, ko, ki, tile_k);
+
+    // Factor the reduction so that we can do outer blocking over the reduction
+    // dimension.
+    Func intm = matrix_mul.update().rfactor(ko, z);
 
-    matrix_mul.compute_at(out, yi)
-        .vectorize(x, 8)
+    intm.compute_at(matrix_mul, y)
+        .vectorize(x, vec)
+        .unroll(x)
         .unroll(y);
 
-    matrix_mul.update(0)
-        .reorder(x, y, k)
-        .vectorize(x, 8)
+    intm.update(0)
+        .reorder(x, y, ki)
+        .vectorize(x, vec)
         .unroll(x)
-        .unroll(y)
-        .unroll(k, 2);
+        .unroll(y);
+
+    matrix_mul.compute_at(out, xy)
+        .vectorize(x, vec)
+        .unroll(x);
+
+    matrix_mul.update()
+        .split(y, y, yi, inner_tile_y)
+        .reorder(x, yi, y, ko)
+        .vectorize(x, vec)
+        .unroll(x)
+        .unroll(yi);
 
     out
         .bound(x, 0, matrix_size)
         .bound(y, 0, matrix_size);
 
-    out.compile_jit();
-
     Buffer<float> mat_A(matrix_size, matrix_size);
     Buffer<float> mat_B(matrix_size, matrix_size);
     Buffer<float> output(matrix_size, matrix_size);
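
Where the inner tile shape comes from: vectorize(xi, vec) with inner_tile_x = 3 * vec keeps three vector accumulators per row of the micro-kernel, so inner_tile_y = 8 rows uses 3 * 8 = 24 of the 32 vector registers on AVX-512, and inner_tile_y = 4 rows uses 3 * 4 = 12 of the 16 on AVX2, matching the "12 accumulator registers or 24" comment in the schedule.

The GFlops figures quoted in the schedule comment imply a timing harness along these lines; a sketch only, assuming the Func out, Buffer output, and matrix_size from the hunk above, plus the benchmark() helper from Halide's halide_benchmark.h (the test's actual harness lies outside this hunk):

    #include <cstdio>
    #include "halide_benchmark.h"

    // benchmark() runs the body repeatedly and returns the best time per
    // iteration in seconds, so the first-call JIT compile doesn't skew it.
    double t = Halide::Tools::benchmark([&]() { out.realize(output); });

    // An N x N sgemm does 2*N^3 flops: one multiply and one add per term
    // of the k reduction.
    double gflops = 2e-9 * matrix_size * matrix_size * matrix_size / t;
    printf("%.1f GFLOPS\n", gflops);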
