@@ -10,8 +10,8 @@

 # Architecture -> (module, class).
 _MODELS = {
-    "AquilaModel": ("llama", "LlamaForCausalLM"),
-    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
+    "AquilaModel": ("aquila", "AquilaForCausalLM"),
+    "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"),  # AquilaChat2
     "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
     "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
@@ -24,12 +24,12 @@
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
     "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
-    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
+    "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
-    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
+    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     # transformers's mpt class has lower case
@@ -41,6 +41,7 @@
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
     "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
+    "YiForCausalLM": ("yi", "YiForCausalLM")
 }

 # Models not supported by ROCm.
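
For reference, this table maps a Hugging Face architecture string to the (module, class) pair that implements it, so each model file can be imported lazily. Below is a minimal sketch of a resolver over _MODELS; the function name and the vllm.model_executor.models package path are assumptions based on vLLM's layout, not something shown in this diff:

import importlib
from typing import Optional, Type

def load_model_cls(model_arch: str) -> Optional[Type]:
    # Look up the (module, class) pair; unknown architectures return None.
    if model_arch not in _MODELS:
        return None
    module_name, model_cls_name = _MODELS[model_arch]
    # Lazy import: the implementing module (assumed to live under
    # vllm.model_executor.models) is only loaded when the architecture
    # is actually requested.
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, model_cls_name, None)

With the mappings changed above, load_model_cls("MistralForCausalLM") would resolve to a dedicated mistral.MistralForCausalLM class instead of reusing the shared Llama implementation.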