
Commit b75cbba

rebase + simplify
Signed-off-by: Bill Nell <[email protected]>
1 parent 41ab065 commit b75cbba

File tree

1 file changed: +41 -92 lines changed

vllm/compilation/collective_fusion.py

Lines changed: 41 additions & 92 deletions
@@ -59,45 +59,30 @@ def residual_slice_shape_fake(residual: torch.Tensor, rank: int) -> int:
     return slices[rank].shape[0]
 
 
-def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool) -> Callable:
-
-    def match_gemm_rs_ag_gemm(
+def match_gemm_rs_ag_gemm(
        residual: torch.Tensor,
        gemm_1_weights: torch.Tensor,
        gemm_1_activations: torch.Tensor,
        rms_norm_weights: torch.Tensor,
        gemm_2_weights: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
-        mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
-
-        # It would be nice to do this instead of having two separate patterns.
-        all_reduce = tensor_model_parallel_all_reduce(mm_1)
-        #if custom_ar:
-        #    all_reduce = torch.ops.vllm.outplace_all_reduce.default(
-        #        mm_1, tp_group_name)
-        #else:
-        #    all_reduce = torch.ops.higher_order.auto_functionalized(
-        #        torch.ops.vllm.inplace_all_reduce.default,
-        #        tensor=mm_1,
-        #        group_name=tp_group_name)
-        #    all_reduce = all_reduce[1]
-
-        norm_res = torch.ops.higher_order.auto_functionalized(
-            torch.ops._C.fused_add_rms_norm.default,
-            input=all_reduce,
-            residual=residual,
-            weight=rms_norm_weights,
-            epsilon=1e-05)
-        normalized = norm_res[1]
-        new_residual = norm_res[2]
-
-        gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0])
-        mm_2 = torch.ops.aten.mm.default(normalized, gemm_2_w_perm)
-
-        return mm_2, new_residual
-
-    return match_gemm_rs_ag_gemm
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
+    mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
+    all_reduce = tensor_model_parallel_all_reduce(mm_1)
+
+    norm_res = torch.ops.higher_order.auto_functionalized(
+        torch.ops._C.fused_add_rms_norm.default,
+        input=all_reduce,
+        residual=residual,
+        weight=rms_norm_weights,
+        epsilon=1e-05)
+    normalized = norm_res[1]
+    new_residual = norm_res[2]
+
+    gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0])
+    mm_2 = torch.ops.aten.mm.default(normalized, gemm_2_w_perm)
+
+    return mm_2, new_residual
 
 
 def get_gemm_rs_ag_gemm(use_flux: bool, max_m: int, gemm_1_type: torch.dtype,
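For reference, the subgraph that `match_gemm_rs_ag_gemm` is meant to capture corresponds to the eager-mode computation below. This is only an illustrative sketch in plain PyTorch, not code from the commit: the RMSNorm is written out by hand instead of calling the fused `torch.ops._C.fused_add_rms_norm` kernel, and the tensor-parallel all-reduce is stubbed out for a single process.

```python
import torch


def reference_gemm_rs_ag_gemm(residual, w1, x, rms_w, w2, eps=1e-5):
    # GEMM 1: x @ w1.T (the traced pattern expresses the transpose as
    # aten.permute over the weight).
    mm_1 = x @ w1.t()

    # In the real pattern this is tensor_model_parallel_all_reduce; on a
    # single process it is the identity.
    all_reduce = mm_1

    # fused_add_rms_norm adds the residual, then normalizes; the fused op
    # hands back both the normalized activations and the updated residual.
    new_residual = all_reduce + residual
    var = new_residual.pow(2).mean(dim=-1, keepdim=True)
    normalized = new_residual * torch.rsqrt(var + eps) * rms_w

    # GEMM 2 on the normalized activations.
    mm_2 = normalized @ w2.t()
    return mm_2, new_residual
```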
@@ -253,40 +238,26 @@ def gemm_rs_ag_gemm_fake(
     return getattr(torch.ops.vllm, name).default
 
 
-def get_match_final(tp_group_name: str, use_custom_ar: bool) -> Callable:
-
-    def match_final(
+def match_final(
        my_residual: torch.Tensor,
        gemm_1_weights: torch.Tensor,
        gemm_1_activations: torch.Tensor,
        rms_norm_weights: torch.Tensor,
-    ) -> torch.Tensor:
-        gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
-        mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
-
-        # TODO: it would be nice to be able to use the official api directly.
-        all_reduce = tensor_model_parallel_all_reduce(mm_1)
-        #if use_custom_ar:
-        #    all_reduce = torch.ops.vllm.outplace_all_reduce.default(
-        #        mm_1, tp_group_name)
-        #else:
-        #    all_reduce = torch.ops.higher_order.auto_functionalized(
-        #        torch.ops.vllm.inplace_all_reduce.default,
-        #        tensor=mm_1,
-        #        group_name=tp_group_name)
-        #    all_reduce = all_reduce[1]
-
-        norm_res = torch.ops.higher_order.auto_functionalized(
-            torch.ops._C.fused_add_rms_norm.default,
-            input=all_reduce,
-            residual=my_residual,
-            weight=rms_norm_weights,
-            epsilon=1e-05)
-        normalized = norm_res[1]
-
-        return normalized
-
-    return match_final
+) -> torch.Tensor:
+    gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
+    mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)
+
+    all_reduce = tensor_model_parallel_all_reduce(mm_1)
+
+    norm_res = torch.ops.higher_order.auto_functionalized(
+        torch.ops._C.fused_add_rms_norm.default,
+        input=all_reduce,
+        residual=my_residual,
+        weight=rms_norm_weights,
+        epsilon=1e-05)
+    normalized = norm_res[1]
+
+    return normalized
 
 
 # Register this as a custom op since all reduce cannot be torch.compiled yet.
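The context comment above refers to wrapping the replacement graph in a custom op so that the non-traceable all-reduce stays opaque to torch.compile. A rough sketch of that idea using the PyTorch 2.4+ `torch.library.custom_op` API is shown below; the op name `mylib::gemm_ag_final_ref` and its body are invented for illustration and are not the commit's actual registration, which uses vLLM's own helpers.

```python
import torch


@torch.library.custom_op("mylib::gemm_ag_final_ref", mutates_args=())
def gemm_ag_final_ref(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # The real implementation would invoke collectives that torch.compile
    # cannot trace; hiding them behind a custom op keeps the compiled
    # graph well-formed.
    return x @ w.t()


@gemm_ag_final_ref.register_fake
def _(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Shape-only "fake" implementation used during tracing/compilation.
    return torch.empty((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)
```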
@@ -362,39 +333,17 @@ def __init__(self, config: CompilationConfig):
         inputs = [resid, x, w, resid_w, x2]
         final_inputs = [x, w, resid, resid_w]
 
-        # register multiple patterns for all group names.
-        world_size = get_tensor_model_parallel_world_size()
-        group_names = [f"tp:{rank}" for rank in range(world_size)]
-
-        m = get_match_gemm_rs_ag_gemm(group_names[0], False)
         register_replacement(
-            m,
-            m,
+            match_gemm_rs_ag_gemm,
+            match_gemm_rs_ag_gemm,
             inputs,
             fwd_only, [self.gemm_rs_ag_gemm_pattern],
             extra_check=lambda m: self.record_match(m))
 
-        for group_name in group_names:
-            for m in [
-                    #get_match_gemm_rs_ag_gemm(group_name, False),
-                    #get_match_gemm_rs_ag_gemm(group_name, True)
-            ]:
-                register_replacement(
-                    m,
-                    m,
-                    inputs,
-                    fwd_only, [self.gemm_rs_ag_gemm_pattern],
-                    extra_check=lambda m: self.record_match(m))
-
-            for m in [
-                    get_match_final(group_name, False),
-                    #get_match_final(group_name, True)
-            ]:
-                torch._inductor.pattern_matcher._seen_patterns.clear()
-                register_replacement(m,
-                                     torch.ops.vllm.gemm_ag_final,
-                                     final_inputs, fwd_only,
-                                     [self.final_pattern])
+        register_replacement(match_final,
+                             torch.ops.vllm.gemm_ag_final,
+                             final_inputs, fwd_only,
+                             [self.final_pattern])
 
     def record_match(self, match: Match) -> bool:
         # Hijack the extra_check to record the match and
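For context, `register_replacement` is the `torch._inductor.pattern_matcher` helper called above: it traces a search function and a replacement function over example inputs (here via `fwd_only`) and installs the rewrite into a `PatternMatcherPass`. Below is a small, self-contained sketch of that API; the toy `search`/`replacement` functions and tensor shapes are made up for illustration and are not from the commit, and this is a private inductor API whose exact signature can shift between PyTorch releases.

```python
import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)

# Pass object that will hold the compiled pattern, analogous to
# self.gemm_rs_ag_gemm_pattern / self.final_pattern in the commit.
my_patterns = PatternMatcherPass()


def search(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Subgraph to look for: a matmul followed by a bias add.
    return torch.mm(x, w) + b


def replacement(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # What matched subgraphs are rewritten into: a single fused addmm.
    return torch.addmm(b, x, w)


# Example inputs only drive tracing; their values are irrelevant.
example_inputs = [torch.empty(8, 16), torch.empty(16, 4), torch.empty(8, 4)]

register_replacement(search, replacement, example_inputs, fwd_only,
                     [my_patterns])
```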
