
Commit 02b4cac

Add BigDL Llama worker for batching on decoding (vllm-project#4)
* Init
* refine
1 parent: e36fc39

File tree: 8 files changed (+248 −96 lines)

tests/under_models/send_mock_request.py

Lines changed: 2 additions & 1 deletion
@@ -43,6 +43,7 @@ async def step_async(self) -> List[RequestOutput]:
             blocks_to_swap_out={},
             blocks_to_copy={},
         )
+        print(output)
 
         # TODO: change this to real one
         return RequestOutput(request_id=request_id, prompt="", prompt_token_ids=[1, 3087, 8970, 338, 263], outputs=[], finished=False)
@@ -109,7 +110,7 @@ async def _run_workers_async(
 @pytest.mark.asyncio
 async def test_model_execution():
     # Let's build an engine_args
-    engine_args = AsyncEngineArgs(model='/models/vicuna-7b/', tokenizer='/models/vicuna-7b/', tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', dtype='auto', seed=0, max_model_len=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, block_size=16, swap_space=16, gpu_memory_utilization=0.9, max_num_batched_tokens=None, max_num_seqs=256, disable_log_stats=False, revision=None, tokenizer_revision=None, quantization=None, engine_use_ray=False, disable_log_requests=True, max_log_len=None)
+    engine_args = AsyncEngineArgs(model='/models/vicuna-7b/', tokenizer='/models/vicuna-7b/', tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='dummy', dtype='auto', seed=0, max_model_len=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, block_size=16, swap_space=16, gpu_memory_utilization=0.9, max_num_batched_tokens=None, max_num_seqs=256, disable_log_stats=False, revision=None, tokenizer_revision=None, quantization=None, engine_use_ray=False, disable_log_requests=True, max_log_len=None)
     # Start the engine
     engine = AsyncLLMEngine.from_engine_args(engine_args)
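
For context, a minimal sketch of how a mock request could be driven through the engine once test_model_execution has built it. The prompt text, sampling settings, and request id below are illustrative assumptions, not values from this commit; only AsyncLLMEngine.generate and SamplingParams are taken from vLLM itself.

from vllm import SamplingParams
from vllm.engine.async_llm_engine import AsyncLLMEngine

async def drive_mock_request(engine: AsyncLLMEngine) -> None:
    # Hypothetical settings for a single short decode.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
    # AsyncLLMEngine.generate yields RequestOutput objects as decoding progresses.
    async for request_output in engine.generate("Mock prompt",
                                                sampling_params,
                                                request_id="mock-0"):
        print(request_output)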

vllm/core/scheduler.py

Lines changed: 6 additions & 6 deletions
@@ -68,12 +68,12 @@ def __init__(
 
         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name="fcfs")
-        # Create the block space manager.
-        self.block_manager = BlockSpaceManager(
-            block_size=self.cache_config.block_size,
-            num_gpu_blocks=self.cache_config.num_gpu_blocks,
-            num_cpu_blocks=self.cache_config.num_cpu_blocks,
-            sliding_window=self.cache_config.sliding_window)
+        # # Create the block space manager.
+        # self.block_manager = BlockSpaceManager(
+        #     block_size=self.cache_config.block_size,
+        #     num_gpu_blocks=self.cache_config.num_gpu_blocks,
+        #     num_cpu_blocks=self.cache_config.num_cpu_blocks,
+        #     sliding_window=self.cache_config.sliding_window)
 
         # TODO(zhuohan): Use deque instead of list for better performance.
         # Sequence groups in the WAITING state.

vllm/engine/llm_engine.py

Lines changed: 2 additions & 2 deletions
@@ -109,8 +109,8 @@ def __init__(
         else:
             self._init_workers(distributed_init_method)
 
-        # Profile the memory usage and initialize the cache.
-        self._init_cache()
+        # # Profile the memory usage and initialize the cache.
+        # self._init_cache()
 
         # Create the scheduler.
         self.scheduler = Scheduler(scheduler_config, cache_config)

vllm/model_executor/input_metadata.py

Lines changed: 27 additions & 27 deletions
@@ -25,48 +25,48 @@ def __init__(
         seq_groups: List[Tuple[List[int], SamplingParams]],
         seq_data: Dict[int, SequenceData],
         prompt_lens: List[int],
-        slot_mapping: torch.Tensor,
+        # slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         max_context_len: int,
-        block_tables: torch.Tensor,
+        # block_tables: torch.Tensor,
         sliding_window: Optional[int] = None,
     ) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
         self.prompt_lens = prompt_lens
-        self.slot_mapping = slot_mapping
+        # self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.max_context_len = max_context_len
-        self.block_tables = block_tables
+        # self.block_tables = block_tables
 
         self.to_cache = None
-        if sliding_window is not None:
-            # We need to keep the positions of sliding windows within
-            # the key / value tables, this is helpful to know which
-            # elements we need to cache and where
-            to_cache, start_idx = [], 0
-            for prompt_len in self.prompt_lens:
-                to_cache.extend(
-                    range(
-                        start_idx + max(0, prompt_len - sliding_window),
-                        start_idx + prompt_len,
-                    ))
-                start_idx += prompt_len
-            to_cache.extend(range(start_idx, slot_mapping.shape[0]))
-            self.to_cache = torch.tensor(to_cache,
-                                         dtype=torch.int32,
-                                         device=self.slot_mapping.device)
+        # if sliding_window is not None:
+        #     # We need to keep the positions of sliding windows within
+        #     # the key / value tables, this is helpful to know which
+        #     # elements we need to cache and where
+        #     to_cache, start_idx = [], 0
+        #     for prompt_len in self.prompt_lens:
+        #         to_cache.extend(
+        #             range(
+        #                 start_idx + max(0, prompt_len - sliding_window),
+        #                 start_idx + prompt_len,
+        #             ))
+        #         start_idx += prompt_len
+        #     to_cache.extend(range(start_idx, slot_mapping.shape[0]))
+        #     self.to_cache = torch.tensor(to_cache,
+        #                                  dtype=torch.int32,
+        #                                  device=self.slot_mapping.device)
 
         self.num_prompts = len(prompt_lens)
         self.num_prompt_tokens = sum(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
-        self.num_valid_tokens = slot_mapping.shape[0]
-        if block_tables.numel() > 0:
-            self.max_num_blocks_per_seq = block_tables.shape[1]
-        else:
-            self.max_num_blocks_per_seq = 0
-        assert block_tables.shape[0] == self.num_generation_tokens
-        assert context_lens.shape[0] == self.num_generation_tokens
+        # self.num_valid_tokens = slot_mapping.shape[0]
+        # if block_tables.numel() > 0:
+        #     self.max_num_blocks_per_seq = block_tables.shape[1]
+        # else:
+        #     self.max_num_blocks_per_seq = 0
+        # assert block_tables.shape[0] == self.num_generation_tokens
+        # assert context_lens.shape[0] == self.num_generation_tokens
 
         # Set during the execution of the first attention op.
         self.attn_bias: List[AttentionBias] = []
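
With the slot-mapping and block-table fields commented out, InputMetadata reduces to token and length bookkeeping. A rough sketch of a decode-step construction matching the trimmed signature; the sequence id and token ids are made-up illustrative values, not from this commit.

import torch

from vllm.model_executor.input_metadata import InputMetadata
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData

# Hypothetical single-sequence decoding step: seq id 0 with a 5-token context.
metadata = InputMetadata(
    seq_groups=[([0], SamplingParams())],
    seq_data={0: SequenceData([1, 3087, 8970, 338, 263])},
    prompt_lens=[],                                    # no prompts in a pure decode batch
    context_lens=torch.tensor([5], dtype=torch.int32),
    max_context_len=5,
)
print(metadata.num_generation_tokens)                  # -> 1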

vllm/model_executor/model_loader.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
     "GPTJForCausalLM": GPTJForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
     "InternLMForCausalLM": InternLMForCausalLM,
-    "LlamaForCausalLM": LlamaForCausalLM,
+    "LlamaForCausalLM": BigDLLlamaForCausalLM,
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
     "MistralForCausalLM": MistralForCausalLM,
     "MPTForCausalLM": MPTForCausalLM,

vllm/model_executor/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
 from vllm.model_executor.models.internlm import InternLMForCausalLM
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.model_executor.models.bigdl_llama import BigDLLlamaForCausalLM
 from vllm.model_executor.models.mpt import MPTForCausalLM
 from vllm.model_executor.models.opt import OPTForCausalLM
 from vllm.model_executor.models.qwen import QWenLMHeadModel
@@ -26,6 +27,7 @@
     "GPTNeoXForCausalLM",
     "InternLMForCausalLM",
     "LlamaForCausalLM",
+    "BigDLLlamaForCausalLM",
     "MPTForCausalLM",
     "OPTForCausalLM",
     "QWenLMHeadModel",

vllm/model_executor/models/bigdl_llama.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+import torch
+from torch import nn
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, LlamaConfig
+from typing import Optional, Tuple, List, Type, Dict
+
+from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
+                                               get_tokenizer)
+from vllm.model_executor.quantization_utils import QuantizationConfig
+from vllm.sequence import SamplerOutput, SequenceOutputs
+import math
+
+import pdb
+
+from transformers.generation.logits_process import (
+    LogitsProcessorList,
+    RepetitionPenaltyLogitsProcessor,
+    TemperatureLogitsWarper,
+    TopKLogitsWarper,
+    TopPLogitsWarper,
+)
+
+def prepare_logits_processor(
+    temperature: float, repetition_penalty: float, top_p: float, top_k: int
+) -> LogitsProcessorList:
+    processor_list = LogitsProcessorList()
+    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
+    if temperature >= 1e-5 and temperature != 1.0:
+        processor_list.append(TemperatureLogitsWarper(temperature))
+    # if repetition_penalty > 1.0:
+    #     processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
+    if 1e-8 <= top_p < 1.0:
+        processor_list.append(TopPLogitsWarper(top_p))
+    if top_k > 0:
+        processor_list.append(TopKLogitsWarper(top_k))
+    return processor_list
+
+class BigDLLlamaForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        # pdb.set_trace()
+        self.config = config
+        self.model = AutoModelForCausalLM.from_pretrained(config._name_or_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
+
+    def decode(self, generated_ids: List[int]) -> str:
+        return self.tokenizer.decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+    def forward(
+        self, seq_group_meta_data_lists, kv_cache: Optional = None
+    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
+        kv_cache_0 = self.model.config.num_hidden_layers
+        kv_cache_1 = 2
+        bigdl_kv_cache = [[torch.Tensor() for _ in range(kv_cache_1)] for _ in range(kv_cache_0)]
+        seq_len = len(seq_group_meta_data_lists)
+        for i in range(seq_len):
+            if kv_cache.get(i) is None:
+                kv_cache[i] = bigdl_kv_cache[:]
+
+        bigdl_input_ids = []
+        bigdl_position_ids = []
+        cur_seq_ids = []
+        bigdl_sampling_params = {}
+
+        all_decoding = True
+        for seq_group_meta_data in seq_group_meta_data_lists:
+            req_id = seq_group_meta_data.request_id
+            all_decoding = all_decoding and (not seq_group_meta_data.is_prompt)
+            seq_ids = list(seq_group_meta_data.seq_data.keys())
+            seq_id = seq_ids[0]
+            print(seq_id)
+            cur_seq_ids.append(seq_id)
+            seq_data = seq_group_meta_data.seq_data[seq_id]
+
+            cur_seq_input_ids = seq_data.get_token_ids()
+            bigdl_input_ids.append(cur_seq_input_ids)
+
+            bigdl_sampling_params[seq_id] = seq_group_meta_data.sampling_params
+
+            context_len = seq_data.get_len()
+            bigdl_position_ids.append(range(context_len))
+        if all_decoding:
+            for seq_group_meta_data in seq_group_meta_data_lists:
+                for i in range(kv_cache_0):
+                    for j in range(kv_cache_1):
+                        bigdl_kv_cache[i][j] = torch.cat((bigdl_kv_cache[i][j], kv_cache[seq_id][i][j]), dim=0)
+
+        bigdl_input_ids = torch.tensor(bigdl_input_ids, device="cuda")
+        bigdl_position_ids = torch.tensor(bigdl_position_ids, device="cuda")
+        if all_decoding:
+            kwargs = {
+                "input_ids": bigdl_input_ids,
+                "position_ids": bigdl_position_ids,
+                "past_key_values": bigdl_kv_cache,
+                "use_cache": True,
+                "return_dict": True,
+            }
+        else:
+            kwargs = {
+                "input_ids": bigdl_input_ids,
+                "position_ids": bigdl_position_ids,
+                "past_key_values": None,
+                "use_cache": True,
+                "return_dict": True,
+            }
+        # kwargs["position_ids"] = position_ids
+        outputs = self.model.forward(**kwargs)
+        index = 0
+        bigdl_output = []
+        for seq_id in cur_seq_ids:
+            cur_sampling_params = bigdl_sampling_params[seq_id]
+            logits_processor = prepare_logits_processor(
+                cur_sampling_params.temperature, 1,
+                cur_sampling_params.top_p, cur_sampling_params.top_k
+            )
+
+            last_token_logits = logits_processor(None, outputs.logits[index:index+1, -1, :])[0]
+            probs = torch.softmax(last_token_logits, dim=-1)
+            indices = torch.multinomial(probs, num_samples=2)
+            tokens = [int(token) for token in indices.tolist()]
+
+            logprobs = math.log(probs[tokens[0]])
+            seq_output = SequenceOutputs(
+                parent_seq_id = seq_id,
+                output_token = tokens[0],
+                logprobs = {tokens[0]: logprobs}
+            )
+            bigdl_output.append([seq_output])
+
+            for i in range(kv_cache_0):
+                for j in range(kv_cache_1):
+                    kv_cache[seq_id][i][j] = outputs.past_key_values[i][j][index].unsqueeze(0)
+            index = index + 1
+        return bigdl_output
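
The sampling tail of forward (build the logits processors, softmax the last-token logits, draw with multinomial, record a logprob) can be exercised in isolation. Below is a small self-contained sketch using made-up logits for a toy 8-token vocabulary and reusing the same prepare_logits_processor construction as the file above; the temperature/top-p/top-k values are arbitrary.

import math

import torch
from transformers.generation.logits_process import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

def prepare_logits_processor(temperature: float, repetition_penalty: float,
                             top_p: float, top_k: int) -> LogitsProcessorList:
    # Same construction as bigdl_llama.py above (repetition penalty left disabled there too).
    processor_list = LogitsProcessorList()
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processor_list.append(TopKLogitsWarper(top_k))
    return processor_list

# Made-up last-token logits for a toy vocabulary of 8 tokens.
logits = torch.tensor([[0.1, 2.0, 0.3, 1.5, -0.2, 0.0, 0.7, 0.9]])
processor = prepare_logits_processor(temperature=0.7, repetition_penalty=1.0,
                                     top_p=0.9, top_k=4)
last_token_logits = processor(None, logits)[0]
probs = torch.softmax(last_token_logits, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))   # sampled token id
logprob = math.log(probs[token])                        # logprob recorded for that token
print(token, logprob)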
