Commit 0cd1aa7

lkk12014402 authored and zzhang37 committed
Refactor mixtral moe block. (huggingface#1635)
1 parent ed26bad · commit 0cd1aa7

5 files changed

Lines changed: 25 additions & 11 deletions


examples/trl/README.md

Lines changed: 2 additions & 2 deletions
@@ -43,11 +43,11 @@ $ pip install -U -r requirements.txt
     --use_flash_attention
 ```
 
-2. Supervised fine-tuning of the mistralai/Mixtral-8x7B-v0.1 on 4 cards:
+2. Supervised fine-tuning of the mistralai/Mixtral-8x7B-Instruct-v0.1 on 4 cards:
 
 ```
 DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 python ../gaudi_spawn.py --world_size 4 --use_deepspeed sft.py \
-    --model_name_or_path mistralai/Mixtral-8x7B-v0.1 \
+    --model_name_or_path mistralai/Mixtral-8x7B-Instruct-v0.1 \
     --dataset_name "philschmid/dolly-15k-oai-style" \
     --subset 'data/' \
     --streaming False \

optimum/habana/transformers/modeling_utils.py

Lines changed: 10 additions & 9 deletions
@@ -217,6 +217,7 @@
     gaudi_mistral_rmsnorm_forward,
     gaudi_mixtral_block_dynamic_moe_forward,
     gaudi_mixtral_block_sparse_moe_forward,
+    gaudi_mixtral_block_moe_forward,
     gaudi_mixtral_rmsnorm_forward,
     gaudi_opt_attention_forward,
     gaudi_opt_decoder_forward,
@@ -557,15 +558,15 @@ def adapt_transformers_to_gaudi():
     transformers.models.mixtral.modeling_mixtral.MixtralAttention = GaudiMixtralAttention
     transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM = GaudiMixtralForCausalLM
     transformers.models.mixtral.modeling_mixtral.MixtralModel = GaudiMixtralModel
-    # We need this workaround until moe op in hpu is supporting fp8
-    if os.environ.get("QUANT_CONFIG"):
-        transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = (
-            gaudi_mixtral_block_sparse_moe_forward
-        )
-    else:
-        transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = (
-            gaudi_mixtral_block_dynamic_moe_forward
-        )
+    transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.sparse_moe_forward = (
+        gaudi_mixtral_block_sparse_moe_forward
+    )
+    transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.dynamic_moe_forward = (
+        gaudi_mixtral_block_dynamic_moe_forward
+    )
+    transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = (
+        gaudi_mixtral_block_moe_forward
+    )
     transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer = GaudiMixtralDecoderLayer
     transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward
     transformers.models.mixtral.configuration_mixtral.MixtralConfig = MixtralConfig
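
The behavioral change in this hunk: adapt_transformers_to_gaudi() no longer picks one MoE forward at patch time based on the QUANT_CONFIG environment variable; it attaches both implementations as methods on MixtralSparseMoeBlock and installs a thin dispatcher as forward, so the choice is made per call. Below is a minimal, self-contained sketch of that wiring; ToyMoeBlock and the shortened function bodies are stand-ins so the snippet runs without transformers or an HPU, not the real optimum-habana code.

import os


class ToyMoeBlock:
    """Stand-in for transformers' MixtralSparseMoeBlock (hypothetical)."""

    training = False


def gaudi_sparse_moe(self, hidden_states):
    # Stand-in for gaudi_mixtral_block_sparse_moe_forward
    return ("sparse", hidden_states)


def gaudi_dynamic_moe(self, hidden_states):
    # Stand-in for gaudi_mixtral_block_dynamic_moe_forward
    return ("dynamic", hidden_states)


def gaudi_moe_forward(self, hidden_states):
    # Same condition as the new dispatcher: dynamic MoE only for inference
    # without an fp8 quantization config.
    if not self.training and not os.environ.get("QUANT_CONFIG"):
        return self.dynamic_moe_forward(hidden_states)
    return self.sparse_moe_forward(hidden_states)


# Patch-time wiring, mirroring the hunk above: both implementations become
# ordinary methods, and `forward` only chooses between them at call time.
ToyMoeBlock.sparse_moe_forward = gaudi_sparse_moe
ToyMoeBlock.dynamic_moe_forward = gaudi_dynamic_moe
ToyMoeBlock.forward = gaudi_moe_forward

block = ToyMoeBlock()
print(block.forward("h"))   # ('dynamic', 'h') in eval mode with no QUANT_CONFIG set
block.training = True
print(block.forward("h"))   # ('sparse', 'h') while training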

optimum/habana/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -180,6 +180,7 @@
     MixtralConfig,
     gaudi_mixtral_block_dynamic_moe_forward,
     gaudi_mixtral_block_sparse_moe_forward,
+    gaudi_mixtral_block_moe_forward,
     gaudi_mixtral_rmsnorm_forward,
 )
 from .mllama import (

optimum/habana/transformers/models/mixtral/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,5 +6,6 @@
     GaudiMixtralModel,
     gaudi_mixtral_block_dynamic_moe_forward,
     gaudi_mixtral_block_sparse_moe_forward,
+    gaudi_mixtral_block_moe_forward,
     gaudi_mixtral_rmsnorm_forward,
 )

optimum/habana/transformers/models/mixtral/modeling_mixtral.py

Lines changed: 11 additions & 0 deletions
@@ -20,6 +20,8 @@
 
 """PyTorch Mixtral model."""
 
+import os
+
 import contextlib
 import math
 from typing import List, Optional, Tuple, Union
@@ -357,6 +359,15 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
 
+def gaudi_mixtral_block_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+
+    # We need this workaround until moe op in hpu is supporting fp8
+    if not self.training and not os.environ.get("QUANT_CONFIG"):
+        return self.dynamic_moe_forward(hidden_states)
+
+    return self.sparse_moe_forward(hidden_states)
+
+
 def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Copied from MixtralSparseMoeBlock.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py
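
The new dispatcher takes the dynamic-MoE path only when the block is not training and no QUANT_CONFIG is set (the workaround named in the comment, pending fp8 support in the HPU MoE op). A small sketch of just that condition, independent of the model code; the config file name passed to QUANT_CONFIG below is made up for illustration.

import os
from typing import Optional


def uses_dynamic_moe(training: bool, quant_config: Optional[str]) -> bool:
    # Reproduce the env-var state, then evaluate the same condition as
    # gaudi_mixtral_block_moe_forward.
    if quant_config is not None:
        os.environ["QUANT_CONFIG"] = quant_config
    else:
        os.environ.pop("QUANT_CONFIG", None)
    return not training and not os.environ.get("QUANT_CONFIG")


print(uses_dynamic_moe(training=False, quant_config=None))           # True  -> dynamic_moe_forward
print(uses_dynamic_moe(training=True, quant_config=None))            # False -> sparse_moe_forward
print(uses_dynamic_moe(training=False, quant_config="quant.json"))   # False -> sparse_moe_forward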
