@@ -59,45 +59,30 @@ def residual_slice_shape_fake(residual: torch.Tensor, rank: int) -> int:
     return slices[rank].shape[0]


-def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool) -> Callable:
-
-    def match_gemm_rs_ag_gemm(
+def match_gemm_rs_ag_gemm(
         residual: torch.Tensor,
         gemm_1_weights: torch.Tensor,
         gemm_1_activations: torch.Tensor,
         rms_norm_weights: torch.Tensor,
         gemm_2_weights: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
-        mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
-
-        # It would be nice to do this instead of having two separate patterns.
-        all_reduce = tensor_model_parallel_all_reduce(mm_1)
-        #if custom_ar:
-        #    all_reduce = torch.ops.vllm.outplace_all_reduce.default(
-        #        mm_1, tp_group_name)
-        #else:
-        #    all_reduce = torch.ops.higher_order.auto_functionalized(
-        #        torch.ops.vllm.inplace_all_reduce.default,
-        #        tensor=mm_1,
-        #        group_name=tp_group_name)
-        #    all_reduce = all_reduce[1]
-
-        norm_res = torch.ops.higher_order.auto_functionalized(
-            torch.ops._C.fused_add_rms_norm.default,
-            input=all_reduce,
-            residual=residual,
-            weight=rms_norm_weights,
-            epsilon=1e-05)
-        normalized = norm_res[1]
-        new_residual = norm_res[2]
-
-        gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0])
-        mm_2 = torch.ops.aten.mm.default(normalized, gemm_2_w_perm)
-
-        return mm_2, new_residual
-
-    return match_gemm_rs_ag_gemm
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
+    mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
+    all_reduce = tensor_model_parallel_all_reduce(mm_1)
+
+    norm_res = torch.ops.higher_order.auto_functionalized(
+        torch.ops._C.fused_add_rms_norm.default,
+        input=all_reduce,
+        residual=residual,
+        weight=rms_norm_weights,
+        epsilon=1e-05)
+    normalized = norm_res[1]
+    new_residual = norm_res[2]
+
+    gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0])
+    mm_2 = torch.ops.aten.mm.default(normalized, gemm_2_w_perm)
+
+    return mm_2, new_residual


 def get_gemm_rs_ag_gemm(use_flux: bool, max_m: int, gemm_1_type: torch.dtype,
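
Editor's note (not part of the diff): for readers unfamiliar with what match_gemm_rs_ag_gemm traces, the following is a minimal single-rank sketch of the matched computation: GEMM, all-reduce, fused residual-add + RMSNorm, GEMM. The helper name, the pure-PyTorch RMSNorm, and the identity stand-in for tensor_model_parallel_all_reduce are illustrative assumptions, not vLLM code.

import torch

def gemm_rs_ag_gemm_reference(residual, gemm_1_weights, gemm_1_activations,
                              rms_norm_weights, gemm_2_weights, eps=1e-5):
    # GEMM 1: activations @ W1^T, matching the permute + mm in the pattern.
    mm_1 = gemm_1_activations @ gemm_1_weights.t()
    # Stand-in for tensor_model_parallel_all_reduce (identity on one rank).
    reduced = mm_1
    # fused_add_rms_norm: add the residual, then RMS-normalize the sum.
    new_residual = reduced + residual
    variance = new_residual.pow(2).mean(dim=-1, keepdim=True)
    normalized = new_residual * torch.rsqrt(variance + eps) * rms_norm_weights
    # GEMM 2: normalized @ W2^T.
    mm_2 = normalized @ gemm_2_weights.t()
    return mm_2, new_residual

On a single rank this should agree with the traced pattern up to the numerics of the fused RMSNorm kernel.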
@@ -253,40 +238,26 @@ def gemm_rs_ag_gemm_fake(
     return getattr(torch.ops.vllm, name).default


-def get_match_final(tp_group_name: str, use_custom_ar: bool) -> Callable:
-
-    def match_final(
+def match_final(
         my_residual: torch.Tensor,
         gemm_1_weights: torch.Tensor,
         gemm_1_activations: torch.Tensor,
         rms_norm_weights: torch.Tensor,
-    ) -> torch.Tensor:
-        gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
-        mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
-
-        # TODO: it would be nice to be able to use the official api directly.
-        all_reduce = tensor_model_parallel_all_reduce(mm_1)
-        #if use_custom_ar:
-        #    all_reduce = torch.ops.vllm.outplace_all_reduce.default(
-        #        mm_1, tp_group_name)
-        #else:
-        #    all_reduce = torch.ops.higher_order.auto_functionalized(
-        #        torch.ops.vllm.inplace_all_reduce.default,
-        #        tensor=mm_1,
-        #        group_name=tp_group_name)
-        #    all_reduce = all_reduce[1]
-
-        norm_res = torch.ops.higher_order.auto_functionalized(
-            torch.ops._C.fused_add_rms_norm.default,
-            input=all_reduce,
-            residual=my_residual,
-            weight=rms_norm_weights,
-            epsilon=1e-05)
-        normalized = norm_res[1]
-
-        return normalized
-
-    return match_final
+) -> torch.Tensor:
+    gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
+    mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
+
+    all_reduce = tensor_model_parallel_all_reduce(mm_1)
+
+    norm_res = torch.ops.higher_order.auto_functionalized(
+        torch.ops._C.fused_add_rms_norm.default,
+        input=all_reduce,
+        residual=my_residual,
+        weight=rms_norm_weights,
+        epsilon=1e-05)
+    normalized = norm_res[1]
+
+    return normalized


 # Register this as a custom op since all reduce cannot be torch.compiled yet.
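
Editor's note (not part of the diff): the comment above refers to wrapping the collective-calling replacement in a custom op so that torch.compile sees an opaque call with a registered fake (meta) implementation. A hedged sketch of that mechanism, assuming PyTorch >= 2.4's torch.library.custom_op; the mylib::gemm_ag_final namespace and the two-argument signature are placeholders, not the registration vLLM actually performs.

import torch

# Placeholder namespace/op name; vLLM registers its own op under torch.ops.vllm.
@torch.library.custom_op("mylib::gemm_ag_final", mutates_args=())
def gemm_ag_final(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # The real body may call collectives that torch.compile cannot trace through.
    return x @ w.t()

@gemm_ag_final.register_fake
def _(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only implementation used by the compiler during tracing.
    return x.new_empty((x.shape[0], w.shape[0]))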
@@ -362,39 +333,17 @@ def __init__(self, config: CompilationConfig):
         inputs = [resid, x, w, resid_w, x2]
         final_inputs = [x, w, resid, resid_w]

-        # register multiple patterns for all group names.
-        world_size = get_tensor_model_parallel_world_size()
-        group_names = [f"tp:{rank}" for rank in range(world_size)]
-
-        m = get_match_gemm_rs_ag_gemm(group_names[0], False)
         register_replacement(
-            m,
-            m,
+            match_gemm_rs_ag_gemm,
+            match_gemm_rs_ag_gemm,
             inputs,
             fwd_only, [self.gemm_rs_ag_gemm_pattern],
             extra_check=lambda m: self.record_match(m))

-        for group_name in group_names:
-            for m in [
-                    #get_match_gemm_rs_ag_gemm(group_name, False),
-                    #get_match_gemm_rs_ag_gemm(group_name, True)
-            ]:
-                register_replacement(
-                    m,
-                    m,
-                    inputs,
-                    fwd_only, [self.gemm_rs_ag_gemm_pattern],
-                    extra_check=lambda m: self.record_match(m))
-
-            for m in [
-                    get_match_final(group_name, False),
-                    #get_match_final(group_name, True)
-            ]:
-                torch._inductor.pattern_matcher._seen_patterns.clear()
-                register_replacement(m,
-                                     torch.ops.vllm.gemm_ag_final,
-                                     final_inputs, fwd_only,
-                                     [self.final_pattern])
+        register_replacement(match_final,
+                             torch.ops.vllm.gemm_ag_final,
+                             final_inputs, fwd_only,
+                             [self.final_pattern])

     def record_match(self, match: Match) -> bool:
         # Hijack the extra_check to record the match and
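
Editor's note (not part of the diff): a self-contained sketch of how torch._inductor.pattern_matcher.register_replacement is driven, using a toy pattern. The argument order (search fn, replacement fn, example inputs, fwd_only trace fn, list of PatternMatcherPass objects, optional extra_check) mirrors the calls above; the toy pattern itself, and whether inductor canonicalizes it into an actual match, are illustrative assumptions.

import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                              register_replacement)

def search_pattern(x: torch.Tensor, y: torch.Tensor):
    # Toy pattern to look for: multiply-then-add.
    return x * 2.0 + y

def replacement(x: torch.Tensor, y: torch.Tensor):
    # Equivalent rewrite: y + 2.0 * x expressed as a single add with alpha.
    return torch.add(y, x, alpha=2.0)

patterns = PatternMatcherPass()
example_inputs = [torch.empty(8, 8), torch.empty(8, 8)]
register_replacement(search_pattern, replacement, example_inputs, fwd_only,
                     [patterns])

An extra_check callback, as in the extra_check=lambda m: self.record_match(m) calls above, can veto or record each candidate Match before the rewrite is applied; the pass object is later applied to the compiled FX graph.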