Skip to content

Commit e03bedb

Browse files
Varun Sundar Rabindranath authored and LucasWilkinson committed
add reduce kernel grouping
1 parent b54298c commit e03bedb

File tree

1 file changed

+91
-4
lines changed

1 file changed

+91
-4
lines changed

vllm/profiler/visualize_layerwise_profile.py

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,36 @@ def is_mem_op(op_name: str):
172172
def is_vocab_embedding_op(op_name: str):
173173
return "vocabparallelembed" in op_name.lower()
174174

175+
# nccl ops
176+
def is_nccl_op(op_name: str):
    """Return True when ``op_name`` contains "nccl", ignoring letter case."""
    lowered = op_name.lower()
    return "nccl" in lowered
178+
179+
def is_nccl_all_reduce(op_name: str):
    """Return True for NCCL op names that carry an all-reduce marker.

    The name must pass ``is_nccl_op`` and, compared case-insensitively,
    contain either "all_reduce" or "allreduce".
    """
    if not is_nccl_op(op_name):
        return False
    lowered = op_name.lower()
    return any(marker in lowered for marker in ("all_reduce", "allreduce"))
183+
184+
def is_nccl_gather(op_name: str):
    """Return True for NCCL op names containing "gather" (case-insensitive).

    The name must also pass ``is_nccl_op``.
    """
    if not is_nccl_op(op_name):
        return False
    return "gather" in op_name.lower()
187+
188+
def is_nccl_broadcast(op_name: str):
    """Return True for NCCL op names containing "broadcast" (case-insensitive).

    The name must also pass ``is_nccl_op``.
    """
    if not is_nccl_op(op_name):
        return False
    return "broadcast" in op_name.lower()
191+
192+
# Reduce ops types
193+
def is_cross_device_reduce_1stage(op_name: str):
    """Return True when ``op_name`` contains "cross_device_reduce_1stage".

    The comparison is case-sensitive.
    """
    marker = "cross_device_reduce_1stage"
    return marker in op_name
195+
196+
def is_cross_device_reduce_2stage(op_name: str):
    """Return True when ``op_name`` contains "cross_device_reduce_2stage".

    The comparison is case-sensitive.
    """
    marker = "cross_device_reduce_2stage"
    return marker in op_name
198+
199+
def is_custom_ar_all_reduce_unreg(op_name: str):
    """Return True when ``op_name`` contains "_C_custom_ar::all_reduce_unreg".

    The comparison is case-sensitive.
    """
    marker = "_C_custom_ar::all_reduce_unreg"
    return marker in op_name
201+
202+
def is_reduce_kernel(op_name: str):
    """Return True when ``op_name`` contains "reduce_kernel" (case-sensitive)."""
    marker = "reduce_kernel"
    return marker in op_name
204+
175205
headers = list(trace_df)
176206
ops = copy.deepcopy(headers)
177207

@@ -196,6 +226,33 @@ def is_vocab_embedding_op(op_name: str):
196226
elementwise_ops = list(filter(lambda x: is_elementwise_op(x), ops))
197227
ops = list(filter(lambda x: x not in elementwise_ops, ops))
198228

229+
nccl_all_reduce_ops = list(filter(lambda x: is_nccl_all_reduce(x), ops))
230+
ops = list(filter(lambda x: x not in nccl_all_reduce_ops, ops))
231+
232+
nccl_gather_ops = list(filter(lambda x: is_nccl_gather(x), ops))
233+
ops = list(filter(lambda x: x not in nccl_gather_ops, ops))
234+
235+
nccl_broadcast_ops = list(filter(lambda x: is_nccl_broadcast(x), ops))
236+
ops = list(filter(lambda x: x not in nccl_broadcast_ops, ops))
237+
238+
nccl_other_ops = list(filter(lambda x: is_nccl_op(x), ops))
239+
ops = list(filter(lambda x: x not in nccl_other_ops, ops))
240+
241+
cross_device_reduce_1stage_ops = list(
242+
filter(lambda x: is_cross_device_reduce_1stage(x), ops))
243+
ops = list(filter(lambda x: x not in cross_device_reduce_1stage_ops, ops))
244+
245+
cross_device_reduce_2stage_ops = list(
246+
filter(lambda x: is_cross_device_reduce_2stage(x), ops))
247+
ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops))
248+
249+
custom_ar_all_reduce_unreg_ops = list(
250+
filter(lambda x: is_custom_ar_all_reduce_unreg(x), ops))
251+
ops = list(filter(lambda x: x not in custom_ar_all_reduce_unreg_ops, ops))
252+
253+
reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops))
254+
ops = list(filter(lambda x: x not in reduce_kernel_ops, ops))
255+
199256
if len(attention_ops):
200257
trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1)
201258
if len(quant_ops):
@@ -213,10 +270,40 @@ def is_vocab_embedding_op(op_name: str):
213270
trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum",
214271
axis=1)
215272

216-
trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops +
217-
vocab_embed_ops + mem_ops + elementwise_ops,
218-
axis=1,
219-
inplace=True)
273+
if len(nccl_all_reduce_ops):
274+
trace_df['nccl_all_reduce_ops'] = trace_df[nccl_all_reduce_ops].agg(
275+
"sum", axis=1)
276+
if len(nccl_gather_ops):
277+
trace_df['nccl_gather_ops'] = trace_df[nccl_gather_ops].agg("sum",
278+
axis=1)
279+
if len(nccl_broadcast_ops):
280+
trace_df['nccl_broadcast_ops'] = trace_df[nccl_broadcast_ops].agg(
281+
"sum", axis=1)
282+
if len(nccl_other_ops):
283+
trace_df['nccl_other_ops'] = trace_df[nccl_other_ops].agg("sum",
284+
axis=1)
285+
286+
if len(cross_device_reduce_1stage_ops):
287+
trace_df['cross_device_reduce_1stage_ops'] = trace_df[
288+
cross_device_reduce_1stage_ops].agg("sum", axis=1)
289+
if len(cross_device_reduce_2stage_ops):
290+
trace_df['cross_device_reduce_2stage_ops'] = trace_df[
291+
cross_device_reduce_2stage_ops].agg("sum", axis=1)
292+
if len(custom_ar_all_reduce_unreg_ops):
293+
trace_df['custom_ar_all_reduce_unreg_ops'] = trace_df[
294+
custom_ar_all_reduce_unreg_ops].agg("sum", axis=1)
295+
if len(reduce_kernel_ops):
296+
trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
297+
axis=1)
298+
299+
trace_df.drop(
300+
attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops +
301+
mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops +
302+
nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops +
303+
cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops +
304+
reduce_kernel_ops,
305+
axis=1,
306+
inplace=True)
220307
return trace_df
221308

222309

0 commit comments

Comments
 (0)