
Commit d9d746e

Merge pull request vllm-project#1 from DeepAuto-AI/geon-dev
merge code
2 parents 1c03585 + 7f2a7d8

File tree

5 files changed: +209 -83 lines changed

vllm/model_executor/layers/attention.py

Lines changed: 156 additions & 76 deletions
@@ -19,8 +19,12 @@
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
 _PARTITION_SIZE = 512
 
-from timber.models.timber_attention.attention1_block_gpu import paged_timber_attention
-
+from timber.models.timber_attention.attention1_block_gpu import (
+    paged_timber_attention,
+    timber_attention
+)
+from vllm.transformers_utils import config as vllm_transformers_config
+from timber.utils import get_bench
 BENCHMARK_ITERATION = 0
 
 class PagedAttention(nn.Module):
@@ -44,6 +48,7 @@ def __init__(
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
+        layer_index: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.num_heads = num_heads
@@ -61,6 +66,8 @@ def __init__(
         if self.head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
+
+        self.layer_index = layer_index
 
     def forward(
         self,
@@ -106,88 +113,160 @@ def forward(
                 input_metadata.slot_mapping.flatten(),
                 input_metadata.kv_cache_dtype,
             )
+
+        hip_k = int(os.environ.get('HIP_K', '1024'))
 
         if input_metadata.is_prompt:
             # Prompt run.
-            if self.num_kv_heads != self.num_heads:
-                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-                # project the key and value tensors to the desired number of
-                # heads.
-                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
-                query = query.view(query.shape[0], self.num_kv_heads,
-                                   self.num_queries_per_kv, query.shape[-1])
-                key = key[:, :,
-                          None, :].expand(key.shape[0], self.num_kv_heads,
-                                          self.num_queries_per_kv,
-                                          key.shape[-1])
-                value = value[:, :, None, :].expand(value.shape[0],
-                                                    self.num_kv_heads,
-                                                    self.num_queries_per_kv,
-                                                    value.shape[-1])
-            # normal attention
-            if (key_cache is None or value_cache is None
-                    or input_metadata.block_tables.numel() == 0):
-                # Set attention bias if not provided. This typically happens at
-                # the very attention layer of every iteration.
-                # FIXME(woosuk): This is a hack.
-                if input_metadata.attn_bias is None:
+            BENCHMARK_PROMPT_ATTENTION = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'
+            backend = os.environ.get('PROMPT_ATTENTION_BACKEND', 'vllm')
+            is_normal_attention = (key_cache is None) or (value_cache is None) or (input_metadata.block_tables.numel() == 0)
+            if backend == 'vllm':
+                if self.num_kv_heads != self.num_heads:
+                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+                    # project the key and value tensors to the desired number of
+                    # heads.
+                    # TODO(woosuk): Use MQA/GQA kernels for higher performance.
+                    query = query.view(
+                        query.shape[0],
+                        self.num_kv_heads,
+                        self.num_queries_per_kv,
+                        query.shape[-1],
+                    )
+                    key = key[:, :, None, :]\
+                        .expand(
+                            key.shape[0],
+                            self.num_kv_heads,
+                            self.num_queries_per_kv,
+                            key.shape[-1]
+                        )
+                    value = value[:, :, None, :]\
+                        .expand(
+                            value.shape[0],
+                            self.num_kv_heads,
+                            self.num_queries_per_kv,
+                            value.shape[-1]
+                        )
+                # normal attention
+                if is_normal_attention:
+                    # Set attention bias if not provided. This typically happens at
+                    # the very attention layer of every iteration.
+                    # FIXME(woosuk): This is a hack.
+                    if input_metadata.attn_bias is None:
+                        if self.alibi_slopes is None:
+                            attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                                [seq_len] * batch_size)
+                            if self.sliding_window is not None:
+                                attn_bias = attn_bias.make_local_attention(
+                                    self.sliding_window)
+                            input_metadata.attn_bias = attn_bias
+                        else:
+                            input_metadata.attn_bias = _make_alibi_bias(
+                                self.alibi_slopes, self.num_kv_heads, batch_size,
+                                seq_len, query.dtype)
+
+                    # TODO(woosuk): Too many view operations. Let's try to reduce
+                    # them in the future for code readability.
                     if self.alibi_slopes is None:
-                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                            [seq_len] * batch_size)
-                        if self.sliding_window is not None:
-                            attn_bias = attn_bias.make_local_attention(
-                                self.sliding_window)
-                        input_metadata.attn_bias = attn_bias
+                        query = query.unsqueeze(0)
+                        key = key.unsqueeze(0)
+                        value = value.unsqueeze(0)
                     else:
-                        input_metadata.attn_bias = _make_alibi_bias(
-                            self.alibi_slopes, self.num_kv_heads, batch_size,
-                            seq_len, query.dtype)
+                        query = query.unflatten(0, (batch_size, seq_len))
+                        key = key.unflatten(0, (batch_size, seq_len))
+                        value = value.unflatten(0, (batch_size, seq_len))
 
-                # TODO(woosuk): Too many view operations. Let's try to reduce
-                # them in the future for code readability.
-                if self.alibi_slopes is None:
-                    query = query.unsqueeze(0)
-                    key = key.unsqueeze(0)
-                    value = value.unsqueeze(0)
+                    if BENCHMARK_PROMPT_ATTENTION:
+                        start = torch.cuda.Event(enable_timing=True)
+                        end = torch.cuda.Event(enable_timing=True)
+                        start.record()
+
+                    out = xops.memory_efficient_attention_forward(
+                        query,
+                        key,
+                        value,
+                        attn_bias=input_metadata.attn_bias,
+                        p=0.0,
+                        scale=self.scale,
+                        op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
+                        (is_hip()) else None,
+                    )
+                    output = out.view_as(query)
+
+                    if BENCHMARK_PROMPT_ATTENTION:
+                        end.record()
+                        torch.cuda.synchronize()
+                        print(backend, start.elapsed_time(end), output.shape, end='\n')
                 else:
-                    query = query.unflatten(0, (batch_size, seq_len))
-                    key = key.unflatten(0, (batch_size, seq_len))
-                    value = value.unflatten(0, (batch_size, seq_len))
-
-                out = xops.memory_efficient_attention_forward(
-                    query,
-                    key,
-                    value,
-                    attn_bias=input_metadata.attn_bias,
-                    p=0.0,
-                    scale=self.scale,
-                    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
-                    (is_hip()) else None,
+                    # prefix-enabled attention
+                    output = torch.empty_like(query)
+                    context_attention_fwd(
+                        query,
+                        key,
+                        value,
+                        output,
+                        key_cache,
+                        value_cache,
+                        input_metadata.block_tables,  # [BS, max_block_per_request]
+                        input_metadata.start_loc,
+                        input_metadata.prompt_lens,
+                        input_metadata.context_lens,
+                        input_metadata.max_seq_len,
+                        getattr(self, "alibi_slopes", None),
+                    )
+            elif backend == 'timber':
+                # timber supports MQA/GQA
+                warnings.warn('prompt attention backend is timber')
+
+                TDST, H, HID = query.shape
+                TSRC, H_KV, _HID = key.shape
+                assert key.shape[:-1] == value.shape[:-1]
+                assert HID == _HID
+
+                query = query.permute(1, 0, 2)
+                key = key.permute(1, 0, 2)
+                value = value.permute(1, 0, 2)
+
+                if BENCHMARK_PROMPT_ATTENTION:
+                    start = torch.cuda.Event(enable_timing=True)
+                    end = torch.cuda.Event(enable_timing=True)
+                    start.record()
+
+                assert input_metadata.attn_bias is None
+                assert self.alibi_slopes is None
+
+                output, _ = timber_attention(
+                    q=query * self.scale,
+                    k=key,
+                    v=value,
+                    attention_mask=None,
+                    mask_k=hip_k,
+                    block_size_q=32,
+                    block_size_k=2,
                 )
-                output = out.view_as(query)
+
+                output = output.permute(1, 0, 2)
+                output = output.view(
+                    1,
+                    TDST,
+                    H,
+                    HID,
+                ).contiguous()
+
+                if BENCHMARK_PROMPT_ATTENTION:
+                    end.record()
+                    torch.cuda.synchronize()
+                    print(backend, start.elapsed_time(end), output.shape, end='\n')
             else:
-                # prefix-enabled attention
-                output = torch.empty_like(query)
-                context_attention_fwd(
-                    query,
-                    key,
-                    value,
-                    output,
-                    key_cache,
-                    value_cache,
-                    input_metadata.block_tables,  # [BS, max_block_per_request]
-                    input_metadata.start_loc,
-                    input_metadata.prompt_lens,
-                    input_metadata.context_lens,
-                    input_metadata.max_seq_len,
-                    getattr(self, "alibi_slopes", None),
-                )
-
+                raise Exception(backend)
         else:
             # Decoding run.
             BENCHMARK_PAGED_ATTENTION = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'
+
+            # print(f'[{os.getpid()}, {self.layer_index}] query_size: {query.shape}, block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')
+
             if BENCHMARK_PAGED_ATTENTION:
-                warnings.warn(f'query_size: {query.shape}, block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')
+                warnings.warn(f'query_size: {query.shape}({query.dtype}), block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')
                 torch.cuda.synchronize()
                 start = torch.cuda.Event(enable_timing=True)
                 end = torch.cuda.Event(enable_timing=True)
@@ -203,9 +282,9 @@ def forward(
                     self.num_kv_heads,
                     self.scale,
                     self.alibi_slopes,
-                )
+                    )
             elif backend == 'timber':
-                warnings.warn('backend is timber')
+                warnings.warn('paged attention backend is timber')
 
                 output, _ = paged_timber_attention(
                     q=query,
@@ -216,9 +295,9 @@ def forward(
                    context_lens=input_metadata.context_lens,
                    max_context_len=input_metadata.max_context_len,
                    attention_mask=None,
-                   mask_k=1024,
+                   mask_k=hip_k,
+                   block_size_q=32,
                    block_size_k=2,
-                   block_size_q=16
                )
 
                N_H, _, HID = output.shape
@@ -243,11 +322,12 @@ def forward(
                    "alibi_slopes": self.alibi_slopes,
                    "output": output,
                }, 'cache/llama/vllmout.pth')
+               print('saved cache/llama/vllmout.pth')
 
            if BENCHMARK_PAGED_ATTENTION:
                end.record()
                torch.cuda.synchronize()
-               print(start.elapsed_time(end))
+               print(f'({backend}) {start.elapsed_time(end)}', end='\r')
 
        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)

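The attention.py change above routes prompt attention through environment variables: PROMPT_ATTENTION_BACKEND selects 'vllm' or 'timber', HIP_K sets the timber mask size, and BENCHMARK_PAGED_ATTENTION=1 brackets the call with CUDA events. The snippet below is a minimal, hedged sketch of that pattern only, not vLLM code; dummy_attention stands in for the real kernel.

# Minimal sketch of the env-var driven dispatch and CUDA-event timing used above.
# dummy_attention is a placeholder; the real code calls xformers / timber kernels.
import os
import torch

def dummy_attention(q, k, v):
    scores = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
    return scores @ v

def run_prompt_attention(q, k, v):
    backend = os.environ.get('PROMPT_ATTENTION_BACKEND', 'vllm')
    hip_k = int(os.environ.get('HIP_K', '1024'))
    benchmark = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'
    timing = benchmark and torch.cuda.is_available()
    if timing:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    out = dummy_attention(q, k, v)  # the real code dispatches on `backend` here
    if timing:
        end.record()
        torch.cuda.synchronize()
        print(backend, hip_k, start.elapsed_time(end), out.shape)
    return out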
vllm/model_executor/layers/rotary_embedding.py

Lines changed: 3 additions & 1 deletion
@@ -306,7 +306,9 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         # Get n-d rotational scaling corrected for extrapolation
         inv_freq_mask = (1 - _yarn_linear_ramp_mask(
             low, high, self.rotary_dim // 2,
-            dtype=torch.float)) * self.extrapolation_factor
+            dtype=torch.float,
+            device=pos_freqs.device
+        )) * self.extrapolation_factor
         inv_freq = inv_freq_interpolation * (
             1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
         return inv_freq

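The rotary_embedding.py fix passes device=pos_freqs.device so the ramp mask is allocated on the same device as the frequency tensors; combining a CPU mask with CUDA tensors would otherwise raise a device-mismatch RuntimeError. A hedged illustration, with linear_ramp_mask as a stand-in for vLLM's _yarn_linear_ramp_mask:

# Stand-in ramp mask; the point is the device= argument matching pos_freqs.
import torch

def linear_ramp_mask(low, high, dim, dtype, device=None):
    ramp = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low)
    return torch.clamp(ramp, 0, 1)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
pos_freqs = torch.rand(32, device=device)
mask = linear_ramp_mask(4, 20, 32, torch.float, device=pos_freqs.device)
print(((1 - mask) * pos_freqs).device)  # same device as pos_freqs, no mismatch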
vllm/model_executor/models/llama.py

Lines changed: 7 additions & 0 deletions
@@ -342,11 +342,15 @@ def load_weights(self,
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
+                    # print('vllm.load_weight: ignore', weight_name)
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name not in params_dict:
+                    print('vllm.load_weight: ignore', name)
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -355,6 +359,9 @@ def load_weights(self,
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if name not in params_dict:
+                    print('vllm.load_weight: ignore', name)
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)

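The llama.py change makes load_weights tolerant of checkpoint tensors that have no matching model parameter: instead of a KeyError on params_dict[name], the name is printed and skipped. A toy sketch of that guard (both dicts here are hypothetical):

# Skip-unknown-weights guard, illustrated with toy dictionaries.
import torch

params_dict = {'model.embed_tokens.weight': torch.zeros(4, 2)}
checkpoint = {
    'model.embed_tokens.weight': torch.ones(4, 2),
    'some.unexpected.weight': torch.ones(3, 3),  # hypothetical extra tensor
}

for name, loaded_weight in checkpoint.items():
    if name not in params_dict:
        print('vllm.load_weight: ignore', name)
        continue
    params_dict[name].copy_(loaded_weight)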
vllm/transformers_utils/config.py

Lines changed: 4 additions & 3 deletions
@@ -16,7 +16,7 @@
 }
 
 # NOTE: For benchmarking
-FORCE_SIGNLE_LAYER = False
+FORCE_SIGNLE_LAYER = 0
 
 def get_config(
     model: str,
@@ -44,7 +44,8 @@ def get_config(
     config = config_class.from_pretrained(model, revision=revision)
 
     # NOTE: DEBUG
-    if FORCE_SIGNLE_LAYER:
-        config.num_hidden_layers = 1
+    if FORCE_SIGNLE_LAYER > 0:
+        assert isinstance(FORCE_SIGNLE_LAYER, int)
+        config.num_hidden_layers = FORCE_SIGNLE_LAYER
 
     return config

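In config.py, FORCE_SIGNLE_LAYER (the identifier is spelled that way in the source) changes from a boolean to an integer layer count: any positive value overrides config.num_hidden_layers, which is useful for benchmarking a truncated model. A minimal sketch with a stand-in config object:

# Layer-truncation switch; SimpleNamespace stands in for the HF config object.
from types import SimpleNamespace

FORCE_SIGNLE_LAYER = 2  # keep only two decoder layers; 0 disables the override

config = SimpleNamespace(num_hidden_layers=32)
if FORCE_SIGNLE_LAYER > 0:
    assert isinstance(FORCE_SIGNLE_LAYER, int)
    config.num_hidden_layers = FORCE_SIGNLE_LAYER
print(config.num_hidden_layers)  # -> 2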