Commit 24a5d53

Adding the circulant graph queued variable ring algorithm for Bcast.
This algorithm achieves better performance than existing algorithms for both small and large message sizes. The algorithm is based on the circulant graph abstraction and Jesper Larsson Traff's recent paper: https://dl.acm.org/doi/full/10.1145/3735139. It creates communication schedules around various rings in the circulant graph, then repeats the schedule to pipeline message chunks. We introduce a FIFO queue for overlapping sends and receives across communication rounds, which particularly benefits small messages.

In the graph below, we show the algorithm's performance for a fixed chunk size (256k) and queue length (24) at various scales on ANL Aurora (N, PPN). The baseline is the best-performing algorithm currently in MPICH, so all speedups represent improvements over every algorithm in the library. We note that performance drops around the selected chunk size (256k); by tuning the chunk size near this message size, a speedup can be achieved across all message sizes at all scales.
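As a small illustration of the ring structure described above, the sketch below computes the roughly-doubling "skip" distances that define the circulant-graph rings for an example communicator size. It is a standalone toy, not part of the commit; it mirrors the skip precalculation step in bcast_intra_circ_qvring.c, and the value p = 12 is just an example.

/* Standalone sketch: print the roughly-doubling skips for p processes. */
#include <stdio.h>

int main(void)
{
    int p = 12;                     /* example communicator size */

    /* depth = ceil(log2(p)): number of communication rounds */
    int depth = 1;
    while ((1 << depth) < p)
        depth++;

    /* skips[depth] = p; each smaller skip is the ceiling of half the next one */
    int skips[32];
    skips[depth] = p;
    for (int i = depth - 1; i >= 0; i--)
        skips[i] = skips[i + 1] / 2 + (skips[i + 1] & 0x1);

    for (int i = 0; i <= depth; i++)
        printf("skips[%d] = %d\n", i, skips[i]);    /* 1, 2, 3, 6, 12 for p = 12 */
    return 0;
}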
1 parent 7fcdc20 commit 24a5d53

File tree

7 files changed: +413 -1 lines changed

src/mpi/coll/bcast/Makefile.mk

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ mpi_core_sources += \
     src/mpi/coll/bcast/bcast_intra_binomial.c \
     src/mpi/coll/bcast/bcast_intra_scatter_recursive_doubling_allgather.c \
     src/mpi/coll/bcast/bcast_intra_scatter_ring_allgather.c \
+    src/mpi/coll/bcast/bcast_intra_circ_qvring.c \
     src/mpi/coll/bcast/bcast_intra_smp.c \
     src/mpi/coll/bcast/bcast_intra_tree.c \
     src/mpi/coll/bcast/bcast_intra_pipelined_tree.c \
src/mpi/coll/bcast/bcast_intra_circ_qvring.c

Lines changed: 364 additions & 0 deletions
@@ -0,0 +1,364 @@
#include "mpiimpl.h"

/* Algorithm: Circulant graph queued variable ring bcast
 * This algorithm is based on the paper by Jesper Larsson Traff:
 * https://dl.acm.org/doi/full/10.1145/3735139, with additional optimizations.
 * It is optimal for both small and large message sizes.
 */

struct sched_args_t {
    int *skips;
    int *send_sched;
    int *next;
    int *prev;
    int *extra;
    int tree_depth;
    int comm_size;
};

struct queue_tracker_t {
    MPIR_Request *req;
    int chunk_id;
    int need_wait;
};

static int all_blocks(int r, int r_, int s, int e, int k, int *buffer, struct sched_args_t *args);
static void gen_rsched(int r, int *buffer, struct sched_args_t *args);
static void gen_ssched(int r, struct sched_args_t *args);

static int get_baseblock(int r, struct sched_args_t *args);

int MPIR_Bcast_intra_circ_qvring(void *buffer,
                                 MPI_Aint count,
                                 MPI_Datatype datatype,
                                 int root, MPIR_Comm *comm,
                                 const int chunk_size, const int q_len, int coll_attr)
{
    int mpi_errno = MPI_SUCCESS;

    int comm_size = comm->local_size;
    int rank = comm->rank;

    if (comm_size < 2) {
        goto fn_exit;
    }

    /* depth = ceil(log2(comm_size)): number of communication rounds */
    int depth = 1;
    while ((0x1 << depth) < comm_size) {
        depth++;
    }

    /* Circulant Graph Queued Variable Ring Bcast:
     * This algorithm uses the circulant graph abstraction to create communication
     * rings of every "n" processes, where the "n" values are called "skips".
     * We use a roughly doubling pattern to generate the skips, creating log2(p)
     * communication rounds and making the algorithm latency-optimal for small messages.
     *
     * The algorithm works by generating a schedule of sends and receives based on the
     * number of processes (see the paper for the construction). The schedule is then
     * executed to complete the broadcast. The message can be broken into chunks and
     * pipelined, repeating the send schedule for each chunk, to optimize for
     * bandwidth/large messages.
     *
     * This implementation also includes a queued send optimization, where "q_len" sends
     * can be outstanding simultaneously. Bcast is a one-to-many operation, so the root
     * can begin the algorithm with "q_len" non-blocking sends, and each receiving
     * process then does the same. New sends are introduced as previous sends complete,
     * using a FIFO queue. This optimization further improves small-message performance.
     */

    MPIR_CHKLMEM_DECL();
    int *skips, *recv_sched, *send_sched, *next, *prev, *extra;
    MPIR_CHKLMEM_MALLOC(skips, (depth + 1) * sizeof(int));
    MPIR_CHKLMEM_MALLOC(recv_sched, depth * sizeof(int));
    MPIR_CHKLMEM_MALLOC(send_sched, depth * sizeof(int));
    MPIR_CHKLMEM_MALLOC(next, (depth + 2) * sizeof(int));
    MPIR_CHKLMEM_MALLOC(prev, (depth + 2) * sizeof(int));
    MPIR_CHKLMEM_MALLOC(extra, depth * sizeof(int));

    // Precalculate skips (roughly doubling)
    skips[depth] = comm_size;
    for (int i = depth - 1; i >= 0; i--) {
        skips[i] = (skips[i + 1] / 2) + (skips[i + 1] & 0x1);
    }

    // Generate send and receive schedules
    /* next/prev are passed offset by one so the helpers can use index -1 as a
     * list sentinel. */
    struct sched_args_t args = {
        skips, send_sched, next + 1, prev + 1, extra,
        depth, comm_size
    };
    gen_rsched(rank, recv_sched, &args);
    gen_ssched(rank, &args);

    // Datatype Handling:
    MPI_Aint type_size;
    int is_contig;
    int buf_size;

    MPIR_Datatype_get_size_macro(datatype, type_size);
    buf_size = count * type_size;

    if (buf_size == 0)
        goto dealloc;

    if (HANDLE_IS_BUILTIN(datatype)) {
        is_contig = 1;
    } else {
        MPIR_Datatype_is_contig(datatype, &is_contig);
    }

    /* Non-contiguous types are staged through a temporary contiguous buffer. */
    void *tmp_buf;
    if (is_contig) {
        tmp_buf = buffer;
    } else {
        MPIR_CHKLMEM_MALLOC(tmp_buf, buf_size);
        if (rank == root) {
            mpi_errno = MPIR_Localcopy(buffer, count, datatype, tmp_buf, buf_size,
                                       MPIR_BYTE_INTERNAL);
            MPIR_ERR_CHECK(mpi_errno);
        }
    }

    // Handle pipeline chunks
    /* n_chunk = ceil(buf_size / chunk_size); the last chunk may be shorter. */
    int n_chunk;
    int last_msg_size;

    if (chunk_size == 0) {
        n_chunk = 1;
        last_msg_size = buf_size;
    } else {
        n_chunk = (buf_size / chunk_size) + (buf_size % chunk_size != 0);
        last_msg_size = (buf_size % chunk_size == 0)
            ? chunk_size
            : buf_size % chunk_size;
    }

    /* can_send[i] is set once chunk i is available locally; the root starts
     * with every chunk. */
    char *can_send;
    MPIR_CHKLMEM_MALLOC(can_send, n_chunk * sizeof(char));
    for (int i = 0; i < n_chunk; i++) {
        can_send[i] = (rank == root);
    }

    // Run schedule
    /* x shifts the starting round of the repeated depth-round schedule based on
     * n_chunk mod depth; offset maps per-pass schedule entries to absolute chunk
     * indices (entries that map below zero are skipped). */
    int x = (((depth - ((n_chunk - 1) % depth)) % depth) + depth) % depth;
    int offset = -x;

    /* Effective queue length: at least one outstanding request slot. */
    int tru_ql = q_len;
    if (tru_ql < 1)
        tru_ql = 1;

    struct queue_tracker_t *requests;
    MPIR_CHKLMEM_MALLOC(requests, tru_ql * sizeof(struct queue_tracker_t));
    for (int i = 0; i < tru_ql; i++) {
        requests[i].need_wait = 0;
    }

    int q_head = 0;
    int q_tail = 0;
    int q_used = 0;
    for (int i = x; i < n_chunk - 1 + depth + x; i++) {
        int k = i % depth;

        /* Send phase: forward the scheduled chunk to the peer skips[k] ahead on
         * the ring, first waiting on its pending receive if it has not arrived. */
        if (send_sched[k] + offset >= 0) {
            int peer = (rank + skips[k]) % comm_size;
            if (peer) {
                int send_block = send_sched[k] + offset;
                if (send_block >= n_chunk)
                    send_block = n_chunk - 1;
                int msg_size = (send_block != n_chunk - 1) ? chunk_size : last_msg_size;

                if (can_send[send_block] == 0) {
                    for (int j = 0; j < tru_ql; j++) {
                        if (requests[j].chunk_id == send_block) {
                            mpi_errno = MPIC_Wait(requests[j].req);
                            MPIR_ERR_CHECK(mpi_errno);
                            requests[j].need_wait = 0;
                            MPIR_Request_free(requests[j].req);

                            can_send[send_block] = 1;
                            break;
                        }
                    }
                }

                mpi_errno = MPIC_Isend(((char *) tmp_buf) + (chunk_size * send_block),
                                       msg_size, MPIR_BYTE_INTERNAL, peer, MPIR_BCAST_TAG,
                                       comm, &(requests[q_head].req), coll_attr);
                MPIR_ERR_CHECK(mpi_errno);
                requests[q_head].chunk_id = -1;     /* sends never gate can_send */
                requests[q_head].need_wait = 1;

                q_head = (q_head + 1) % tru_ql;
                q_used = 1;
            }
        }

        /* If the FIFO queue is full, retire the oldest request before posting
         * more operations. */
        if (q_used && q_head == q_tail) {
            if (requests[q_tail].need_wait) {
                mpi_errno = MPIC_Wait(requests[q_tail].req);
                MPIR_ERR_CHECK(mpi_errno);
                requests[q_tail].need_wait = 0;
                MPIR_Request_free(requests[q_tail].req);
            }

            if (requests[q_tail].chunk_id != -1) {
                can_send[requests[q_tail].chunk_id] = 1;
            }

            q_tail = (q_tail + 1) % tru_ql;
        }

        /* Receive phase: post a non-blocking receive for the scheduled chunk
         * from the peer skips[k] behind on the ring. */
        if (recv_sched[k] + offset >= 0 && (rank != root)) {
            int peer = (rank - skips[k] + comm_size) % comm_size;

            int recv_block = recv_sched[k] + offset;
            if (recv_block >= n_chunk)
                recv_block = n_chunk - 1;
            int msg_size = (recv_block != n_chunk - 1) ? chunk_size : last_msg_size;

            mpi_errno = MPIC_Irecv(((char *) tmp_buf) + (chunk_size * recv_block),
                                   msg_size, MPIR_BYTE_INTERNAL, peer, MPIR_BCAST_TAG,
                                   comm, &(requests[q_head].req));
            MPIR_ERR_CHECK(mpi_errno);
            requests[q_head].chunk_id = recv_block;
            requests[q_head].need_wait = 1;

            q_head = (q_head + 1) % tru_ql;
            q_used = 1;
        }

        if (q_used && q_head == q_tail) {
            if (requests[q_tail].need_wait) {
                mpi_errno = MPIC_Wait(requests[q_tail].req);
                MPIR_ERR_CHECK(mpi_errno);
                requests[q_tail].need_wait = 0;
                MPIR_Request_free(requests[q_tail].req);
            }

            if (requests[q_tail].chunk_id != -1) {
                can_send[requests[q_tail].chunk_id] = 1;
            }

            q_tail = (q_tail + 1) % tru_ql;
        }

        /* A full pass over the schedule is done; move to the next window of chunks. */
        if (k == depth - 1) {
            offset += depth;
        }
    }

    /* Drain any requests still outstanding in the queue. */
    for (int i = 0; i < tru_ql; i++) {
        if (requests[i].need_wait) {
            mpi_errno = MPIC_Wait(requests[i].req);
            MPIR_ERR_CHECK(mpi_errno);
            MPIR_Request_free(requests[i].req);
        }
    }

    if (!is_contig) {
        mpi_errno = MPIR_Localcopy(tmp_buf, buf_size, MPIR_BYTE_INTERNAL, buffer, count, datatype);
        MPIR_ERR_CHECK(mpi_errno);
    }

  dealloc:
    MPIR_CHKLMEM_FREEALL();

  fn_exit:
    return mpi_errno;
  fn_fail:
    goto fn_exit;
}

//////// HELPER FUNCTIONS ////////

/* Recursive helper from the paper's schedule construction, used by gen_rsched. */
static int all_blocks(int r, int r_, int s, int e, int k, int *buffer, struct sched_args_t *args) {
    while (e != -1) {
        if ((r_ + args->skips[e] <= r - args->skips[k])
            && (r_ + args->skips[e] < s)) {
            if (r_ + args->skips[e] <= r - args->skips[k + 1]) {
                k = all_blocks(r, r_ + args->skips[e], s, e, k, buffer, args);
            }
            if (r_ > r - args->skips[k + 1]) {
                return k;
            }
            s = r_ + args->skips[e];
            buffer[k] = e;
            k += 1;
            args->next[args->prev[e]] = args->next[e];
            args->prev[args->next[e]] = args->prev[e];
        }
        e = args->next[e];
    }
    return k;
}

/* Generate the receive schedule for rank r: buffer[k] is the (relative) block
 * received in round k. */
static void gen_rsched(int r, int *buffer, struct sched_args_t *args) {
    /* Doubly linked list over block indices, with -1 as the sentinel. */
    for (int i = 0; i <= args->tree_depth; i++) {
        args->next[i] = i - 1;
        args->prev[i] = i + 1;
    }
    args->prev[args->tree_depth] = -1;
    args->next[-1] = args->tree_depth;
    args->prev[-1] = 0;

    int b = get_baseblock(r, args);

    args->next[args->prev[b]] = args->next[b];
    args->prev[args->next[b]] = args->prev[b];

    all_blocks(args->comm_size + r, 0, args->comm_size * 2, args->tree_depth, 0, buffer, args);

    for (int i = 0; i < args->tree_depth; i++) {
        if (buffer[i] == args->tree_depth) {
            buffer[i] = b;
        } else {
            buffer[i] = buffer[i] - args->tree_depth;
        }
    }
}

/* Generate the send schedule for rank r into args->send_sched. */
static void gen_ssched(int r, struct sched_args_t *args) {
    if (r == 0) {
        for (int i = 0; i < args->tree_depth; i++) {
            args->send_sched[i] = i;
        }
        return;
    }

    int b = get_baseblock(r, args);

    int r_ = r;
    int c = b;
    int e = args->comm_size;
    for (int i = args->tree_depth - 1; i > 0; i--) {
        if (r_ < args->skips[i]) {
            if ((r_ + args->skips[i] < e)
                || (e < args->skips[i - 1])
                || ((i == 1) && (b > 0))) {
                args->send_sched[i] = c;
            } else {
                gen_rsched((r + args->skips[i]) % args->comm_size, args->extra, args);
                args->send_sched[i] = args->extra[i];
            }
            if (e > args->skips[i]) {
                e = args->skips[i];
            }
        } else {
            c = i - args->tree_depth;
            e = e - args->skips[i];
            if ((r_ > args->skips[i])
                || (r_ <= e)
                || (i == 1)
                || (e < args->skips[i - 1])) {
                args->send_sched[i] = c;
            } else {
                gen_rsched((r + args->skips[i]) % args->comm_size, args->extra, args);
                args->send_sched[i] = args->extra[i];
            }
            r_ -= args->skips[i];
        }
    }
    args->send_sched[0] = b - args->tree_depth;
}

/* Greedily decompose r over the skips; returns the index of the base block,
 * or tree_depth if no prefix of skips sums exactly to r. */
static int get_baseblock(int r, struct sched_args_t *args) {
    int r_ = 0;
    for (int i = args->tree_depth - 1; i >= 0; i--) {
        if (r_ + args->skips[i] == r) {
            return i;
        } else if (r_ + args->skips[i] < r) {
            r_ += args->skips[i];
        }
    }
    return args->tree_depth;
}

src/mpi/coll/coll_algorithms.txt

Lines changed: 3 additions & 0 deletions
@@ -66,6 +66,9 @@ bcast-intra:
     binomial
     scatter_recursive_doubling_allgather
     scatter_ring_allgather
+    circ_qvring
+        extra_params: chunk_size, q_len
+        cvar_params: CIRC_CHUNK_SIZE, CIRC_Q_LEN
     smp
         restrictions: parent-comm
     tree
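
For reference, a minimal sketch of how the new path might be exercised once CVARs are generated from the entry above. The selection variable MPIR_CVAR_BCAST_INTRA_ALGORITHM already exists in MPICH; the parameter names MPIR_CVAR_BCAST_CIRC_CHUNK_SIZE and MPIR_CVAR_BCAST_CIRC_Q_LEN are assumptions derived from the cvar_params line and may differ from the generated names.

/* Minimal MPI test sketch (not part of the commit). Run with the algorithm
 * selected through the environment, for example:
 *   MPIR_CVAR_BCAST_INTRA_ALGORITHM=circ_qvring \
 *   MPIR_CVAR_BCAST_CIRC_CHUNK_SIZE=262144 \      (assumed generated name)
 *   MPIR_CVAR_BCAST_CIRC_Q_LEN=24 \               (assumed generated name)
 *   mpiexec -n 8 ./bcast_test
 */
#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char **argv)
{
    MPI_Init(&argc, &argv);

    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    const int count = 1 << 20;                  /* 1 Mi ints, large enough to pipeline */
    int *buf = malloc(count * sizeof(int));
    if (rank == 0)
        for (int i = 0; i < count; i++)
            buf[i] = i;

    MPI_Bcast(buf, count, MPI_INT, 0, MPI_COMM_WORLD);

    /* Spot-check one value on every rank */
    if (buf[count - 1] != count - 1)
        printf("rank %d: unexpected value\n", rank);

    free(buf);
    MPI_Finalize();
    return 0;
}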
