@@ -172,6 +172,36 @@ def is_mem_op(op_name: str):
def is_vocab_embedding_op(op_name: str):
    """Tell whether ``op_name`` names a vocab-parallel embedding op.

    The check is a case-insensitive substring search for
    ``"vocabparallelembed"``.
    """
    lowered_name = op_name.lower()
    return "vocabparallelembed" in lowered_name
174174
175+ # nccl ops
def is_nccl_op(op_name: str):
    """Tell whether ``op_name`` contains ``"nccl"``, ignoring case."""
    lowered_name = op_name.lower()
    return "nccl" in lowered_name
178+
def is_nccl_all_reduce(op_name: str):
    """Tell whether ``op_name`` is an NCCL op that looks like an all-reduce.

    The name must satisfy ``is_nccl_op`` and, compared case-insensitively,
    contain either ``"all_reduce"`` or ``"allreduce"``.
    """
    if not is_nccl_op(op_name):
        return False
    lowered_name = op_name.lower()
    # Both spellings are accepted: with and without the underscore.
    return any(marker in lowered_name
               for marker in ("all_reduce", "allreduce"))
183+
def is_nccl_gather(op_name: str):
    """Tell whether ``op_name`` is an NCCL op whose name mentions "gather".

    The name must satisfy ``is_nccl_op``; the ``"gather"`` substring test is
    case-insensitive, so any NCCL name containing it (for example one that
    contains ``"all_gather"``) matches.
    """
    if not is_nccl_op(op_name):
        return False
    lowered_name = op_name.lower()
    return "gather" in lowered_name
187+
def is_nccl_broadcast(op_name: str):
    """Tell whether ``op_name`` is an NCCL op whose name mentions "broadcast".

    The name must satisfy ``is_nccl_op``; the ``"broadcast"`` substring test
    is case-insensitive.
    """
    if not is_nccl_op(op_name):
        return False
    lowered_name = op_name.lower()
    return "broadcast" in lowered_name
191+
# Reduce op types
def is_cross_device_reduce_1stage(op_name: str):
    """Tell whether ``op_name`` contains ``"cross_device_reduce_1stage"``.

    The substring match is case-sensitive.
    """
    marker = "cross_device_reduce_1stage"
    return marker in op_name
195+
def is_cross_device_reduce_2stage(op_name: str):
    """Tell whether ``op_name`` contains ``"cross_device_reduce_2stage"``.

    The substring match is case-sensitive.
    """
    marker = "cross_device_reduce_2stage"
    return marker in op_name
198+
def is_custom_ar_all_reduce_unreg(op_name: str):
    """Tell whether ``op_name`` contains ``"_C_custom_ar::all_reduce_unreg"``.

    The substring match is case-sensitive.
    """
    marker = "_C_custom_ar::all_reduce_unreg"
    return marker in op_name
201+
def is_reduce_kernel(op_name: str):
    """Tell whether ``op_name`` contains ``"reduce_kernel"``.

    The substring match is case-sensitive.
    """
    marker = "reduce_kernel"
    return marker in op_name
204+
175205 headers = list (trace_df )
176206 ops = copy .deepcopy (headers )
177207
@@ -196,6 +226,33 @@ def is_vocab_embedding_op(op_name: str):
196226 elementwise_ops = list (filter (lambda x : is_elementwise_op (x ), ops ))
197227 ops = list (filter (lambda x : x not in elementwise_ops , ops ))
198228
229+ nccl_all_reduce_ops = list (filter (lambda x : is_nccl_all_reduce (x ), ops ))
230+ ops = list (filter (lambda x : x not in nccl_all_reduce_ops , ops ))
231+
232+ nccl_gather_ops = list (filter (lambda x : is_nccl_gather (x ), ops ))
233+ ops = list (filter (lambda x : x not in nccl_gather_ops , ops ))
234+
235+ nccl_broadcast_ops = list (filter (lambda x : is_nccl_broadcast (x ), ops ))
236+ ops = list (filter (lambda x : x not in nccl_broadcast_ops , ops ))
237+
238+ nccl_other_ops = list (filter (lambda x : is_nccl_op (x ), ops ))
239+ ops = list (filter (lambda x : x not in nccl_other_ops , ops ))
240+
241+ cross_device_reduce_1stage_ops = list (
242+ filter (lambda x : is_cross_device_reduce_1stage (x ), ops ))
243+ ops = list (filter (lambda x : x not in cross_device_reduce_1stage_ops , ops ))
244+
245+ cross_device_reduce_2stage_ops = list (
246+ filter (lambda x : is_cross_device_reduce_2stage (x ), ops ))
247+ ops = list (filter (lambda x : x not in cross_device_reduce_2stage_ops , ops ))
248+
249+ custom_ar_all_reduce_unreg_ops = list (
250+ filter (lambda x : is_custom_ar_all_reduce_unreg (x ), ops ))
251+ ops = list (filter (lambda x : x not in custom_ar_all_reduce_unreg_ops , ops ))
252+
253+ reduce_kernel_ops = list (filter (lambda x : is_reduce_kernel (x ), ops ))
254+ ops = list (filter (lambda x : x not in reduce_kernel_ops , ops ))
255+
199256 if len (attention_ops ):
200257 trace_df ['attention' ] = trace_df [attention_ops ].agg ("sum" , axis = 1 )
201258 if len (quant_ops ):
@@ -213,10 +270,40 @@ def is_vocab_embedding_op(op_name: str):
213270 trace_df ['elementwise_ops' ] = trace_df [elementwise_ops ].agg ("sum" ,
214271 axis = 1 )
215272
216- trace_df .drop (attention_ops + quant_ops + gemm_ops + rms_norm_ops +
217- vocab_embed_ops + mem_ops + elementwise_ops ,
218- axis = 1 ,
219- inplace = True )
273+ if len (nccl_all_reduce_ops ):
274+ trace_df ['nccl_all_reduce_ops' ] = trace_df [nccl_all_reduce_ops ].agg (
275+ "sum" , axis = 1 )
276+ if len (nccl_gather_ops ):
277+ trace_df ['nccl_gather_ops' ] = trace_df [nccl_gather_ops ].agg ("sum" ,
278+ axis = 1 )
279+ if len (nccl_broadcast_ops ):
280+ trace_df ['nccl_broadcast_ops' ] = trace_df [nccl_broadcast_ops ].agg (
281+ "sum" , axis = 1 )
282+ if len (nccl_other_ops ):
283+ trace_df ['nccl_other_ops' ] = trace_df [nccl_other_ops ].agg ("sum" ,
284+ axis = 1 )
285+
286+ if len (cross_device_reduce_1stage_ops ):
287+ trace_df ['cross_device_reduce_1stage_ops' ] = trace_df [
288+ cross_device_reduce_1stage_ops ].agg ("sum" , axis = 1 )
289+ if len (cross_device_reduce_2stage_ops ):
290+ trace_df ['cross_device_reduce_2stage_ops' ] = trace_df [
291+ cross_device_reduce_2stage_ops ].agg ("sum" , axis = 1 )
292+ if len (custom_ar_all_reduce_unreg_ops ):
293+ trace_df ['custom_ar_all_reduce_unreg_ops' ] = trace_df [
294+ custom_ar_all_reduce_unreg_ops ].agg ("sum" , axis = 1 )
295+ if len (reduce_kernel_ops ):
296+ trace_df ['reduce_kernel_ops' ] = trace_df [reduce_kernel_ops ].agg ("sum" ,
297+ axis = 1 )
298+
299+ trace_df .drop (
300+ attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops +
301+ mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops +
302+ nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops +
303+ cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops +
304+ reduce_kernel_ops ,
305+ axis = 1 ,
306+ inplace = True )
220307 return trace_df
221308
222309
0 commit comments