
Commit 4c4bf74

mgoin and shreyankg authored and committed
[Model] Support Grok1 (vllm-project#13795)
Signed-off-by: mgoin <[email protected]>
1 parent 335729c commit 4c4bf74

File tree: 11 files changed, +634 −17 lines changed

docs/source/models/supported_models.md
Lines changed: 5 additions & 0 deletions

@@ -286,6 +286,11 @@ See [this page](#generative-models) for more information on how to use generative models.
   * `parasail-ai/GritLM-7B-vllm`.
   * ✅︎
   * ✅︎
+- * `Grok1ModelForCausalLM`
+  * Grok1
+  * `hpcai-tech/grok-1`.
+  * ✅︎
+  * ✅︎
 - * `InternLMForCausalLM`
   * InternLM
   * `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.
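
For reference, here is a minimal offline-inference sketch using the model entry added above. The checkpoint name and the need for `trust_remote_code` come from this commit; the `tensor_parallel_size` value is an assumption about available hardware (Grok-1 is a very large MoE model and will not fit on a single GPU), so adjust it to your setup.

```python
from vllm import LLM, SamplingParams

# Model id and trust_remote_code mirror tests/models/registry.py below;
# tensor_parallel_size=8 is an assumption about the local GPU count.
llm = LLM(
    model="hpcai-tech/grok-1",
    trust_remote_code=True,
    tensor_parallel_size=8,
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```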

tests/models/registry.py
Lines changed: 2 additions & 0 deletions

@@ -130,6 +130,8 @@ def check_available_online(
     "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"),
     "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"),
     "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"),
+    "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1",
+                                             trust_remote_code=True),
     "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
                                            trust_remote_code=True),
     "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",

vllm/model_executor/layers/fused_moe/fused_moe.py
Lines changed: 31 additions & 12 deletions

@@ -1040,6 +1040,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           w2: torch.Tensor,
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
+                          activation: str = "silu",
                           use_fp8_w8a8: bool = False,
                           use_int8_w8a16: bool = False,
                           use_int4_w4a16: bool = False,
@@ -1053,9 +1054,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           a2_scale: Optional[torch.Tensor] = None,
                           block_shape: Optional[List[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
-                       use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
-                       global_num_experts, expert_map, w1_scale, w2_scale,
-                       w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
+                       activation, use_fp8_w8a8, use_int8_w8a16,
+                       use_int4_w4a16, global_num_experts, expert_map,
+                       w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
+                       block_shape)


 def inplace_fused_experts_fake(
@@ -1064,6 +1066,7 @@ def inplace_fused_experts_fake(
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
+        activation: str = "silu",
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
@@ -1093,6 +1096,7 @@ def outplace_fused_experts(
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
+        activation: str = "silu",
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
@@ -1106,7 +1110,7 @@ def outplace_fused_experts(
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[List[int]] = None) -> torch.Tensor:
     return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
-                              False, use_fp8_w8a8, use_int8_w8a16,
+                              False, activation, use_fp8_w8a8, use_int8_w8a16,
                               use_int4_w4a16, global_num_experts, expert_map,
                               w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
                               a2_scale, block_shape)
@@ -1118,6 +1122,7 @@ def outplace_fused_experts_fake(
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
+        activation: str = "silu",
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
         use_int4_w4a16: bool = False,
@@ -1147,6 +1152,7 @@ def fused_experts(hidden_states: torch.Tensor,
                   topk_weights: torch.Tensor,
                   topk_ids: torch.Tensor,
                   inplace: bool = False,
+                  activation: str = "silu",
                   use_fp8_w8a8: bool = False,
                   use_int8_w8a16: bool = False,
                   use_int4_w4a16: bool = False,
@@ -1162,15 +1168,17 @@ def fused_experts(hidden_states: torch.Tensor,

     if inplace:
         torch.ops.vllm.inplace_fused_experts(
-            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
-            use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map,
-            w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
+            hidden_states, w1, w2, topk_weights, topk_ids, activation,
+            use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
+            expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
+            block_shape)
         return hidden_states
     else:
         return torch.ops.vllm.outplace_fused_experts(
-            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
-            use_int8_w8a16, use_int4_w4a16, global_num_experts, expert_map,
-            w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
+            hidden_states, w1, w2, topk_weights, topk_ids, activation,
+            use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
+            expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
+            block_shape)


 def fused_experts_impl(hidden_states: torch.Tensor,
@@ -1179,6 +1187,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                        topk_weights: torch.Tensor,
                        topk_ids: torch.Tensor,
                        inplace: bool = False,
+                       activation: str = "silu",
                        use_fp8_w8a8: bool = False,
                        use_int8_w8a16: bool = False,
                        use_int4_w4a16: bool = False,
@@ -1303,8 +1312,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
             use_int4_w4a16=use_int4_w4a16,
             block_shape=block_shape)

-        torch.ops._C.silu_and_mul(intermediate_cache2,
-                                  intermediate_cache1.view(-1, N))
+        if activation == "silu":
+            torch.ops._C.silu_and_mul(intermediate_cache2,
+                                      intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            torch.ops._C.gelu_and_mul(intermediate_cache2,
+                                      intermediate_cache1.view(-1, N))
+        else:
+            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

         invoke_fused_moe_kernel(intermediate_cache2,
                                 w2,
@@ -1339,6 +1354,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    activation: str = "silu",
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
@@ -1370,6 +1386,8 @@ def fused_moe(
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - inplace (bool): If True, perform the operation in-place.
         Defaults to False.
+    - activation (str): The activation function to apply after the first
+        MoE layer.
     - num_expert_group: Optional[int]: additional parameter for grouped_topk
     - topk_group: Optional[int]: additional parameter for grouped_topk
     - use_grouped_topk: If True, use grouped_topk instead of fused_topk
@@ -1420,6 +1438,7 @@ def fused_moe(
                          topk_weights,
                          topk_ids,
                          inplace=inplace,
+                         activation=activation,
                          use_fp8_w8a8=use_fp8_w8a8,
                          use_int8_w8a16=use_int8_w8a16,
                          use_int4_w4a16=use_int4_w4a16,
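
To make the new activation switch concrete, the sketch below is an unfused PyTorch reference of the per-expert computation that `fused_experts_impl` performs: the first GEMM's output is split in half, the first half is passed through SiLU or GeLU and multiplied with the second half (the `silu_and_mul` / `gelu_and_mul` contract), then projected back through `w2` and combined with the top-k routing weights. This is an illustrative reference, not the Triton path, and the helper name `torch_moe_reference` is hypothetical.

```python
import torch
import torch.nn.functional as F


def torch_moe_reference(hidden_states: torch.Tensor,  # [T, H]
                        w1: torch.Tensor,              # [E, 2*I, H], gate/up stacked
                        w2: torch.Tensor,              # [E, H, I]
                        topk_weights: torch.Tensor,    # [T, K]
                        topk_ids: torch.Tensor,        # [T, K]
                        activation: str = "silu") -> torch.Tensor:
    """Unfused reference for the expert MLP with a gated activation."""
    out = torch.zeros_like(hidden_states)
    for e in range(w1.shape[0]):
        token_idx, slot = torch.where(topk_ids == e)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]                    # tokens routed to expert e
        gate, up = (x @ w1[e].t()).chunk(2, dim=-1)     # first GEMM, split in half
        if activation == "silu":
            act = F.silu(gate)
        elif activation == "gelu":
            act = F.gelu(gate)
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
        y = (act * up) @ w2[e].t()                      # second GEMM back to hidden
        out.index_add_(0, token_idx,
                       y * topk_weights[token_idx, slot].unsqueeze(-1))
    return out
```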

vllm/model_executor/layers/fused_moe/layer.py
Lines changed: 17 additions & 5 deletions

@@ -120,7 +120,8 @@ def apply(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         return self.forward(x=x,
                             layer=layer,
@@ -134,7 +135,8 @@ def apply(
                             expert_map=expert_map,
                             custom_routing_function=custom_routing_function,
                             scoring_func=scoring_func,
-                            e_score_correction_bias=e_score_correction_bias)
+                            e_score_correction_bias=e_score_correction_bias,
+                            activation=activation)

     def forward_cuda(
         self,
@@ -150,7 +152,8 @@ def forward_cuda(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -170,6 +173,7 @@ def forward_cuda(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
             global_num_experts=global_num_experts,
             expert_map=expert_map)

@@ -186,9 +190,11 @@ def forward_cpu(
         global_num_experts: int = -1,
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
         **kwargs,
     ):
         assert custom_routing_function is None
+        assert activation == "silu", f"{activation} is not supported."
         return layer.ipex_fusion(
             x,
             use_grouped_topk,
@@ -213,7 +219,8 @@ def forward_tpu(
         expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         assert not use_grouped_topk
         assert num_expert_group is None
@@ -225,6 +232,7 @@ def forward_tpu(
         if e_score_correction_bias is not None:
             raise NotImplementedError(
                 "Expert score correction bias is not supported for TPU.")
+        assert activation == "silu", f"{activation} is not supported for TPU."
         return fused_moe_pallas(hidden_states=x,
                                 w1=layer.w13_weight,
                                 w2=layer.w2_weight,
@@ -277,6 +285,7 @@ def __init__(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ):
         super().__init__()

@@ -305,6 +314,7 @@ def __init__(
         self.custom_routing_function = custom_routing_function
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
+        self.activation = activation
         self.expert_map = None

         if self.ep_size > 1:
@@ -653,7 +663,9 @@ def forward(self, hidden_states: torch.Tensor,
             num_expert_group=self.num_expert_group,
             custom_routing_function=self.custom_routing_function,
             scoring_func=self.scoring_func,
-            e_score_correction_bias=self.e_score_correction_bias)
+            e_score_correction_bias=self.e_score_correction_bias,
+            activation=self.activation,
+        )

         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             # Default set to False. (May have to add shared expert outputs.)
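
With `FusedMoE` now storing `self.activation` and threading it through `apply()`, a GeLU-gated MoE model (such as the Grok1 support this commit adds) can opt in at construction time. The sketch below is only an outline: the argument names other than `activation` follow `FusedMoE`'s existing constructor, the sizes are illustrative, and the layer must be built inside a running vLLM model (it queries the tensor-parallel group), so this is not a standalone script.

```python
from vllm.model_executor.layers.fused_moe import FusedMoE

# Inside a model's MoE block definition (values are illustrative only):
experts = FusedMoE(
    num_experts=8,
    top_k=2,
    hidden_size=6144,
    intermediate_size=32768,
    reduce_results=True,
    renormalize=True,
    quant_config=None,
    activation="gelu",   # new in this commit; the default stays "silu"
)
# experts(hidden_states, router_logits) now forwards activation="gelu"
# through quant_method.apply(...) down to fused_experts(...).
```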

vllm/model_executor/layers/quantization/awq_marlin.py
Lines changed: 2 additions & 0 deletions

@@ -469,7 +469,9 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
+        assert activation == "silu", "Only SiLU activation is supported."
         if expert_map is not None:
             raise NotImplementedError(
                 "Expert Parallelism is not supported for "

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
Lines changed: 4 additions & 0 deletions

@@ -219,6 +219,7 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts

@@ -240,6 +241,7 @@ def apply(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
             use_fp8_w8a8=True,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
@@ -550,7 +552,9 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
+        assert activation == "silu", "Only SiLU activation is supported."
         if expert_map is not None:
             raise NotImplementedError(
                 "Expert Parallelism is not supported for "

vllm/model_executor/layers/quantization/experts_int8.py
Lines changed: 2 additions & 0 deletions

@@ -113,6 +113,7 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts

@@ -134,6 +135,7 @@ def apply(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
             use_int8_w8a16=True,
             global_num_experts=global_num_experts,
             expert_map=expert_map,

vllm/model_executor/layers/quantization/fp8.py
Lines changed: 2 additions & 0 deletions

@@ -675,6 +675,7 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts

@@ -698,6 +699,7 @@ def apply(
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            activation=activation,
             use_fp8_w8a8=True,
             global_num_experts=global_num_experts,
             expert_map=expert_map,

vllm/model_executor/layers/quantization/gptq_marlin.py
Lines changed: 3 additions & 0 deletions

@@ -590,7 +590,10 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
+        assert activation == "silu", "Only SiLU activation is supported."
+
         # The input must currently be float16
         orig_dtype = x.dtype
         x = x.half()
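
The Marlin, CPU, and TPU paths above accept the new `activation` keyword but only implement fused SiLU gating, so each adds an explicit guard. A backend that later grows its own fused kernel would follow the same pattern; the class below is a hypothetical illustration of that guard, not code from this commit.

```python
import torch


class MyFusedMoEMethod:  # hypothetical backend, not part of this commit
    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              router_logits: torch.Tensor, top_k: int, renormalize: bool,
              activation: str = "silu", **kwargs) -> torch.Tensor:
        # Fail fast instead of silently applying the wrong gating function,
        # mirroring the asserts added to awq_marlin, gptq_marlin, and the
        # CPU/TPU paths in this commit.
        assert activation == "silu", "Only SiLU activation is supported."
        ...  # dispatch to the backend's fused kernel here
        return x
```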
