
Commit ea023f0

xhaihao, Zjq9409 and Wei-Lin-Intel committed
Add support for Baichuan2
Below is an example for baichuan-inc/Baichuan2-7B-Chat:

python3 run_generation.py \
    --model_name_or_path baichuan-inc/Baichuan2-7B-Chat \
    --bf16 --trim_logits --batch_size 1 \
    --max_input_tokens 1024 --max_new_tokens 512 \
    --use_kv_cache --use_hpu_graphs --use_flash_attention \
    --reuse_cache \
    --no-ignore_eos

Below is an example for baichuan-inc/Baichuan2-13B-Chat:

python3 run_generation.py \
    --model_name_or_path baichuan-inc/Baichuan2-13B-Chat \
    --bf16 --trim_logits --batch_size 1 \
    --max_input_tokens 1024 --max_new_tokens 512 \
    --use_kv_cache --use_hpu_graphs --bucket_size 256 \
    --bucket_internal --reuse_cache \
    --no-ignore_eos

Co-authored-by: Jianqian Zhou <jianqian.zhou@intel.com>
Co-authored-by: Wei Lin <wei2.lin@intel.com>
Signed-off-by: Haihao Xiang <haihao.xiang@intel.com>
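Those commands drive examples/text-generation/run_generation.py. For orientation, here is a rough programmatic sketch of the same flow, not the script itself: it assumes a Gaudi machine with optimum-habana installed, leaves out the HPU-graph, bucketing and flash-attention flags, and the exact tokenizer-loading flags may vary:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # registers the Gaudi-optimized Baichuan classes (see modeling_utils.py below)

name = "baichuan-inc/Baichuan2-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16).to("hpu")

inputs = tokenizer("Tell me about Gaudi.", return_tensors="pt").to("hpu")
outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```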
1 parent d6914e9 commit ea023f0

12 files changed: 1,969 additions and 10 deletions


examples/language-modeling/run_clm.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -472,9 +472,11 @@ def main():
 
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
-    embedding_size = model.get_input_embeddings().weight.shape[0]
-    if len(tokenizer) > embedding_size:
-        model.resize_token_embeddings(len(tokenizer))
+    # We need to skip this test for Baichuan pretraining
+    if config.model_type not in ("baichuan",):
+        embedding_size = model.get_input_embeddings().weight.shape[0]
+        if len(tokenizer) > embedding_size:
+            model.resize_token_embeddings(len(tokenizer))
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
```
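One subtlety in the guard above: `not in` against a parenthesized string is a substring test, not tuple membership, so the one-element tuple needs its trailing comma. A standalone illustration in plain Python:

```python
model_type = "baichuan"

# ("baichuan") is just a parenthesized string; ("baichuan",) is a tuple.
assert ("baichuan") == "baichuan"
assert model_type in ("baichuan",)   # exact tuple membership
assert "chuan" in ("baichuan")       # substring match on a plain string!
assert "chuan" not in ("baichuan",)  # not an element of the tuple
```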

examples/text-generation/run_lm_eval.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -111,13 +111,14 @@ def __init__(self, tokenizer, model, args, options):
             "gptj",
             "starcoder2",
             "gemma",
+            "baichuan",
         ]:
             self.model_inputs.update(
                 {
                     "reuse_cache": self.options.reuse_cache,
                 }
             )
-        if self.model.config.model_type in ["llama", "mistral", "qwen2", "falcon", "starcoder2", "gemma"]:
+        if self.model.config.model_type in ["llama", "mistral", "qwen2", "falcon", "starcoder2", "gemma", "baichuan"]:
             if self.model.config.model_type != "falcon":
                 self.model_inputs.update(
                     {
```

optimum/habana/transformers/generation/utils.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -112,6 +112,7 @@
     "paligemma",
     "idefics2",
     "mllama",
+    "baichuan",
 ]
 
 
@@ -1081,8 +1082,9 @@ def generate(
                     "qwen2_moe",
                     "gemma",
                     "gemma2",
+                    "baichuan",
                 ]
-            ), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2 and starcoder2 at the moment"
+            ), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2 and baichuan at the moment"
             if not generation_config.bucket_internal:
                 assert (
                     generation_config.bucket_size <= 0
@@ -1288,8 +1290,12 @@ def generate(
                 "gemma",
                 "gemma2",
                 "qwen2_moe",
+                "baichuan",
             ]:
-                if self.config.max_position_embeddings < calculated_max_length:
+                if (
+                    hasattr(self.config, "max_position_embeddings")
+                    and self.config.max_position_embeddings < calculated_max_length
+                ):
                     unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length)
 
         # 8. determine generation mode
```
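The new `hasattr` guard matters because the Baichuan 13B config never sets `max_position_embeddings` (see configuration_baichuan.py below). A minimal standalone illustration of the short-circuit, using a hypothetical stand-in config class:

```python
class FakeConfig:
    """Stand-in for a 13B-style Baichuan config with no max_position_embeddings."""

cfg = FakeConfig()
calculated_max_length = 1536

# Without the guard, cfg.max_position_embeddings would raise AttributeError.
needs_sincos_update = (
    hasattr(cfg, "max_position_embeddings")
    and cfg.max_position_embeddings < calculated_max_length
)
assert needs_sincos_update is False  # `and` short-circuits safely
```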

optimum/habana/transformers/modeling_utils.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -28,6 +28,9 @@
 )
 from .models import (
     GAUDI_WHISPER_ATTENTION_CLASSES,
+    BaichuanConfig,
+    BaichuanForCausalLM,
+    BaichuanTokenizer,
     DeciLMConfig,
     DeciLMForCausalLM,
     Gaudi2Idefics2ImageProcessor,
@@ -676,3 +679,8 @@ def adapt_transformers_to_gaudi():
     transformers.models.xglm.modeling_xglm.XGLMModel.forward = gaudi_xglm_model_forward
     transformers.models.xglm.modeling_xglm.XGLMAttention.forward = gaudi_xglm_attention_forward
     transformers.models.xglm.modeling_xglm.XGLMDecoderLayer.forward = gaudi_xglm_decoder_layer_forward
+
+    # Optimization for Baichuan2 on Gaudi
+    transformers.AutoConfig.register("baichuan", BaichuanConfig)
+    transformers.AutoTokenizer.register(BaichuanConfig, slow_tokenizer_class=BaichuanTokenizer)
+    transformers.AutoModelForCausalLM.register(BaichuanConfig, BaichuanForCausalLM)
```
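With these registrations in place, the stock `Auto*` factories resolve the `baichuan` model type without `trust_remote_code`. A small sketch, assuming an environment where optimum-habana's Baichuan modeling code imports cleanly; the tiny config sizes are illustrative, the real ones come from the checkpoint's config.json:

```python
from transformers import AutoConfig, AutoModelForCausalLM

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from optimum.habana.transformers.models import BaichuanConfig

adapt_transformers_to_gaudi()

# AutoConfig.for_model looks up the "baichuan" model_type registered above.
config = AutoConfig.for_model("baichuan")
assert isinstance(config, BaichuanConfig)

# A deliberately tiny, non-pretrained model, just to show the class mapping.
tiny = AutoModelForCausalLM.from_config(
    AutoConfig.for_model(
        "baichuan",
        vocab_size=512,
        hidden_size=128,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
    )
)
print(type(tiny).__name__)  # BaichuanForCausalLM
```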

optimum/habana/transformers/models/__init__.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -1,4 +1,9 @@
 from .albert import gaudi_albert_forward
+from .baichuan import (
+    BaichuanConfig,
+    BaichuanForCausalLM,
+    BaichuanTokenizer,
+)
 from .bart import (
     gaudi_BartAttention_forward,
     gaudi_BartDecoder_forward,
```
optimum/habana/transformers/models/baichuan/__init__.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -0,0 +1,5 @@
+from .configuration_baichuan import BaichuanConfig
+from .modeling_baichuan import (
+    BaichuanForCausalLM,
+)
+from .tokenization_baichuan import BaichuanTokenizer
```
optimum/habana/transformers/models/baichuan/configuration_baichuan.py

Lines changed: 79 additions & 0 deletions

```diff
@@ -0,0 +1,79 @@
+# Copyright 2023 Baichuan Inc. All Rights Reserved.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adapted from the following sources:
+https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/configuration_baichuan.py
+https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/configuration_baichuan.py
+"""
+
+import sys
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class BaichuanConfig(PretrainedConfig):
+    model_type = "baichuan"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=125696,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        hidden_act="silu",
+        max_position_embeddings=sys.maxsize,
+        model_max_length=4096,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        gradient_checkpointing=False,
+        z_loss_weight=0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        # The 13B config doesn't have max_position_embeddings
+        if max_position_embeddings < sys.maxsize:
+            self.max_position_embeddings = max_position_embeddings
+        self.model_max_length = model_max_length
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.z_loss_weight = z_loss_weight
+        self.gradient_checkpointing = gradient_checkpointing
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
```
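One quirk worth noting: because `max_position_embeddings` defaults to `sys.maxsize`, the attribute is only stored when a finite value is passed (the 7B case), which is exactly what the `hasattr` guard added to `generate()` above relies on. A quick check, assuming the import path exposed by models/__init__.py:

```python
from optimum.habana.transformers.models import BaichuanConfig

cfg_7b_style = BaichuanConfig(max_position_embeddings=4096)
assert cfg_7b_style.max_position_embeddings == 4096

cfg_13b_style = BaichuanConfig()  # default sys.maxsize -> attribute never set
assert not hasattr(cfg_13b_style, "max_position_embeddings")
```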
optimum/habana/transformers/models/baichuan/generation_utils.py

Lines changed: 108 additions & 0 deletions

```diff
@@ -0,0 +1,108 @@
+# Copyright 2023 Baichuan Inc. All Rights Reserved.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adapted from the following sources:
+https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/main/generation_utils.py
+https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/generation_utils.py
+"""
+
+from queue import Queue
+from typing import List
+
+import torch
+
+
+def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int = 0):
+    def _parse_messages(messages, split_role="user"):
+        system, rounds = "", []
+        round = []
+        for i, message in enumerate(messages):
+            if message["role"] == "system":
+                assert i == 0
+                system = message["content"]
+                continue
+            if message["role"] == split_role and round:
+                rounds.append(round)
+                round = []
+            round.append(message)
+        if round:
+            rounds.append(round)
+        return system, rounds
+
+    max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
+    max_input_tokens = model.config.model_max_length - max_new_tokens
+    system, rounds = _parse_messages(messages, split_role="user")
+    system_tokens = tokenizer.encode(system)
+    max_history_tokens = max_input_tokens - len(system_tokens)
+
+    history_tokens = []
+    for round in rounds[::-1]:
+        round_tokens = []
+        for message in round:
+            if message["role"] == "user":
+                round_tokens.append(model.generation_config.user_token_id)
+            else:
+                round_tokens.append(model.generation_config.assistant_token_id)
+            round_tokens.extend(tokenizer.encode(message["content"]))
+        if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
+            history_tokens = round_tokens + history_tokens  # concat left
+            if len(history_tokens) < max_history_tokens:
+                continue
+        break
+
+    input_tokens = system_tokens + history_tokens
+    if messages[-1]["role"] != "assistant":
+        input_tokens.append(model.generation_config.assistant_token_id)
+    input_tokens = input_tokens[-max_input_tokens:]  # truncate left
+    return torch.LongTensor([input_tokens]).to(model.device)
+
+
+class TextIterStreamer:
+    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.skip_special_tokens = skip_special_tokens
+        self.tokens = []
+        self.text_queue = Queue()
+        self.next_tokens_are_prompt = True
+
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+        else:
+            if len(value.shape) > 1:
+                value = value[0]
+            self.tokens.extend(value.tolist())
+            self.text_queue.put(self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+    def end(self):
+        self.text_queue.put(None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get()
+        if value is None:
+            raise StopIteration()
+        else:
+            return value
```
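A hedged usage sketch: `build_chat_input` packs a chat history right-to-left into the model's context window, and `TextIterStreamer` yields the cumulative decoded text as tokens arrive. Here `model` and `tokenizer` are assumed to be a loaded Baichuan2 chat checkpoint and its tokenizer; `generate` runs on a worker thread because it blocks:

```python
from threading import Thread

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize what HPU graphs do."},
]

# Drops the oldest rounds first, then appends the assistant token id.
input_ids = build_chat_input(model, tokenizer, messages, max_new_tokens=512)

streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(
    target=model.generate,
    kwargs={"inputs": input_ids, "streamer": streamer, "max_new_tokens": 512},
).start()

for partial_text in streamer:  # each item is the full decoded text so far
    print(partial_text, end="\r")
```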
