@@ -30,44 +30,78 @@ int main(int argc, char **argv) {
3030 ImageParam A (type_of<float >(), 2 );
3131 ImageParam B (type_of<float >(), 2 );
3232
33- Var x (" x" ), xi (" xi" ), xo (" xo" ), y (" y" ), yo (" yo" ), yi (" yi" ), yii (" yii" ), xii (" xii" );
34- Func matrix_mul (" matrix_mul" );
35-
33+ Var x (" x" ), y (" y" );
3634 RDom k (0 , matrix_size);
37- RVar ki;
35+
36+ Func matrix_mul (" matrix_mul" );
3837
3938 matrix_mul (x, y) += A (k, y) * B (x, k);
4039
4140 Func out;
4241 out (x, y) = matrix_mul (x, y);
4342
44- Var xy;
43+ // Now the schedule. Single-threaded, it hits 155 GFlops on Skylake-X
44+ // i9-9960x with AVX-512 (80% of peak), and 87 GFlops with AVX2 (90% of
45+ // peak).
46+ //
47+ // Using 16 threads (and no hyperthreading), hits 2080 GFlops (67% of peak)
48+ // and 1310 GFLops (85% of peak) respectively.
4549
46- out.tile (x, y, xi, yi, 24 , 32 )
47- .fuse (x, y, xy)
48- .parallel (xy)
49- .split (yi, yi, yii, 4 )
50- .vectorize (xi, 8 )
50+ const int vec = target.natural_vector_size <float >();
51+
52+ // Size the inner loop tiles to fit into the number of registers available
53+ // on the target, using either 12 accumulator registers or 24.
54+ const int inner_tile_x = 3 * vec;
55+ const int inner_tile_y = (target.has_feature (Target::AVX512) || target.arch != Target::X86) ? 8 : 4 ;
56+
57+ // The shape of the outer tiling
58+ const int tile_y = matrix_size / 4 ;
59+ const int tile_k = matrix_size / 16 ;
60+
61+ Var xy (" xy" ), xi (" xi" ), yi (" yi" ), yii (" yii" );
62+
63+ out.tile (x, y, xi, yi, inner_tile_x, tile_y)
64+ .split (yi, yi, yii, inner_tile_y)
65+ .vectorize (xi, vec)
5166 .unroll (xi)
52- .unroll (yii);
67+ .unroll (yii)
68+ .fuse (x, y, xy)
69+ .parallel (xy);
70+
71+ RVar ko (" ko" ), ki (" ki" );
72+ Var z (" z" );
73+ matrix_mul.update ().split (k, ko, ki, tile_k);
74+
75+ // Factor the reduction so that we can do outer blocking over the reduction
76+ // dimension.
77+ Func intm = matrix_mul.update ().rfactor (ko, z);
5378
54- matrix_mul.compute_at (out, yi)
55- .vectorize (x, 8 )
79+ intm.compute_at (matrix_mul, y)
80+ .vectorize (x, vec)
81+ .unroll (x)
5682 .unroll (y);
5783
58- matrix_mul .update (0 )
59- .reorder (x, y, k )
60- .vectorize (x, 8 )
84+ intm .update (0 )
85+ .reorder (x, y, ki )
86+ .vectorize (x, vec )
6187 .unroll (x)
62- .unroll (y)
63- .unroll (k, 2 );
88+ .unroll (y);
89+
90+ matrix_mul.compute_at (out, xy)
91+ .vectorize (x, vec)
92+ .unroll (x);
93+
94+ matrix_mul.update ()
95+ .split (y, y, yi, inner_tile_y)
96+ .reorder (x, yi, y, ko)
97+ .vectorize (x, vec)
98+ .unroll (x)
99+ .unroll (yi);
64100
65101 out
66102 .bound (x, 0 , matrix_size)
67103 .bound (y, 0 , matrix_size);
68104
69- out.compile_jit ();
70-
71105 Buffer<float > mat_A (matrix_size, matrix_size);
72106 Buffer<float > mat_B (matrix_size, matrix_size);
73107 Buffer<float > output (matrix_size, matrix_size);
0 commit comments