|
     gaudi_mistral_rmsnorm_forward,
     gaudi_mixtral_block_dynamic_moe_forward,
     gaudi_mixtral_block_sparse_moe_forward,
+    gaudi_mixtral_block_moe_forward,
     gaudi_mixtral_rmsnorm_forward,
     gaudi_opt_attention_forward,
     gaudi_opt_decoder_forward,
@@ -557,15 +558,15 @@ def adapt_transformers_to_gaudi():
     transformers.models.mixtral.modeling_mixtral.MixtralAttention = GaudiMixtralAttention
     transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM = GaudiMixtralForCausalLM
     transformers.models.mixtral.modeling_mixtral.MixtralModel = GaudiMixtralModel
-    # We need this workaround until moe op in hpu is supporting fp8
-    if os.environ.get("QUANT_CONFIG"):
-        transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = (
-            gaudi_mixtral_block_sparse_moe_forward
-        )
-    else:
-        transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = (
-            gaudi_mixtral_block_dynamic_moe_forward
-        )
+    transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.sparse_moe_forward = (
+        gaudi_mixtral_block_sparse_moe_forward
+    )
+    transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.dynamic_moe_forward = (
+        gaudi_mixtral_block_dynamic_moe_forward
+    )
+    transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = (
+        gaudi_mixtral_block_moe_forward
+    )
     transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer = GaudiMixtralDecoderLayer
     transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward
     transformers.models.mixtral.configuration_mixtral.MixtralConfig = MixtralConfig
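
Net effect of the new assignments: `MixtralSparseMoeBlock` keeps both Gaudi implementations reachable as methods (`sparse_moe_forward` and `dynamic_moe_forward`), and `forward` now goes through a single dispatcher, `gaudi_mixtral_block_moe_forward`, instead of the adapter picking one implementation at import time via the `QUANT_CONFIG` environment variable. The sketch below is only a guess at that dispatcher's shape, assuming it reuses the same `QUANT_CONFIG` check that the removed adapter code performed; the actual `gaudi_mixtral_block_moe_forward` lives in the Gaudi Mixtral modeling file and may select between the two paths differently.

```python
import os


def gaudi_mixtral_block_moe_forward(self, hidden_states):
    # Hypothetical sketch, not the code from this commit: the diff only shows
    # that a single dispatching `forward` is installed alongside
    # `sparse_moe_forward` and `dynamic_moe_forward` on MixtralSparseMoeBlock.
    # One plausible shape mirrors the QUANT_CONFIG check the old adapter used.
    if os.environ.get("QUANT_CONFIG"):
        # FP8 quantization path: per the removed comment, the HPU MoE op does
        # not yet support fp8, so fall back to the sparse implementation.
        return self.sparse_moe_forward(hidden_states)
    # Default path: use the dynamic MoE implementation.
    return self.dynamic_moe_forward(hidden_states)
```

Moving the decision into a runtime dispatcher (rather than choosing the bound `forward` once when `adapt_transformers_to_gaudi()` runs) would keep both code paths importable and testable regardless of the environment, which is consistent with why both implementations are now attached as named methods.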
|