# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512

-from timber.models.timber_attention.attention1_block_gpu import paged_timber_attention
-
+from timber.models.timber_attention.attention1_block_gpu import (
+    paged_timber_attention,
+    timber_attention
+)
+from vllm.transformers_utils import config as vllm_transformers_config
+from timber.utils import get_bench
BENCHMARK_ITERATION = 0

class PagedAttention(nn.Module):
@@ -44,6 +48,7 @@ def __init__(
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
+        layer_index: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
@@ -61,6 +66,8 @@ def __init__(
        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
+
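+        # layer_index identifies this attention layer within the model; in this
+        # change it is only read by the optional per-layer debug print in forward().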
+        self.layer_index = layer_index

    def forward(
        self,
@@ -106,88 +113,160 @@ def forward(
                input_metadata.slot_mapping.flatten(),
                input_metadata.kv_cache_dtype,
            )
+
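+        # HIP_K sets the top-k key-token budget (mask_k) passed to the timber/HiP
+        # attention calls below; it defaults to 1024 tokens.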
+        hip_k = int(os.environ.get('HIP_K', '1024'))

        if input_metadata.is_prompt:
            # Prompt run.
-            if self.num_kv_heads != self.num_heads:
-                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-                # project the key and value tensors to the desired number of
-                # heads.
-                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
-                query = query.view(query.shape[0], self.num_kv_heads,
-                                   self.num_queries_per_kv, query.shape[-1])
-                key = key[:, :,
-                          None, :].expand(key.shape[0], self.num_kv_heads,
-                                          self.num_queries_per_kv,
-                                          key.shape[-1])
-                value = value[:, :, None, :].expand(value.shape[0],
-                                                    self.num_kv_heads,
-                                                    self.num_queries_per_kv,
-                                                    value.shape[-1])
-            # normal attention
-            if (key_cache is None or value_cache is None
-                    or input_metadata.block_tables.numel() == 0):
-                # Set attention bias if not provided. This typically happens at
-                # the very attention layer of every iteration.
-                # FIXME(woosuk): This is a hack.
-                if input_metadata.attn_bias is None:
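+            # Backend selection for the prompt (prefill) pass: PROMPT_ATTENTION_BACKEND
+            # chooses between the stock xformers path ('vllm') and the HiP ('timber')
+            # sparse-attention path; BENCHMARK_PAGED_ATTENTION='1' times each call
+            # with CUDA events.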
+            BENCHMARK_PROMPT_ATTENTION = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'
+            backend = os.environ.get('PROMPT_ATTENTION_BACKEND', 'vllm')
+            is_normal_attention = (key_cache is None) or (value_cache is None) or (input_metadata.block_tables.numel() == 0)
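+            # "Normal" attention here means there is no populated KV cache / block
+            # table to read from, so attention runs only over the freshly computed
+            # prompt keys and values.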
+            if backend == 'vllm':
+                if self.num_kv_heads != self.num_heads:
+                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+                    # project the key and value tensors to the desired number of
+                    # heads.
+                    # TODO(woosuk): Use MQA/GQA kernels for higher performance.
+                    query = query.view(
+                        query.shape[0],
+                        self.num_kv_heads,
+                        self.num_queries_per_kv,
+                        query.shape[-1],
+                    )
+                    key = key[:, :, None, :]\
+                        .expand(
+                            key.shape[0],
+                            self.num_kv_heads,
+                            self.num_queries_per_kv,
+                            key.shape[-1]
+                        )
+                    value = value[:, :, None, :]\
+                        .expand(
+                            value.shape[0],
+                            self.num_kv_heads,
+                            self.num_queries_per_kv,
+                            value.shape[-1]
+                        )
+                # normal attention
+                if is_normal_attention:
+                    # Set attention bias if not provided. This typically happens at
+                    # the very attention layer of every iteration.
+                    # FIXME(woosuk): This is a hack.
+                    if input_metadata.attn_bias is None:
+                        if self.alibi_slopes is None:
+                            attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                                [seq_len] * batch_size)
+                            if self.sliding_window is not None:
+                                attn_bias = attn_bias.make_local_attention(
+                                    self.sliding_window)
+                            input_metadata.attn_bias = attn_bias
+                        else:
+                            input_metadata.attn_bias = _make_alibi_bias(
+                                self.alibi_slopes, self.num_kv_heads, batch_size,
+                                seq_len, query.dtype)
+
+                    # TODO(woosuk): Too many view operations. Let's try to reduce
+                    # them in the future for code readability.
                    if self.alibi_slopes is None:
-                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                            [seq_len] * batch_size)
-                        if self.sliding_window is not None:
-                            attn_bias = attn_bias.make_local_attention(
-                                self.sliding_window)
-                        input_metadata.attn_bias = attn_bias
+                        query = query.unsqueeze(0)
+                        key = key.unsqueeze(0)
+                        value = value.unsqueeze(0)
                    else:
-                        input_metadata.attn_bias = _make_alibi_bias(
-                            self.alibi_slopes, self.num_kv_heads, batch_size,
-                            seq_len, query.dtype)
+                        query = query.unflatten(0, (batch_size, seq_len))
+                        key = key.unflatten(0, (batch_size, seq_len))
+                        value = value.unflatten(0, (batch_size, seq_len))

-                # TODO(woosuk): Too many view operations. Let's try to reduce
-                # them in the future for code readability.
-                if self.alibi_slopes is None:
-                    query = query.unsqueeze(0)
-                    key = key.unsqueeze(0)
-                    value = value.unsqueeze(0)
+                    if BENCHMARK_PROMPT_ATTENTION:
+                        start = torch.cuda.Event(enable_timing=True)
+                        end = torch.cuda.Event(enable_timing=True)
+                        start.record()
+
+                    out = xops.memory_efficient_attention_forward(
+                        query,
+                        key,
+                        value,
+                        attn_bias=input_metadata.attn_bias,
+                        p=0.0,
+                        scale=self.scale,
+                        op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
+                        (is_hip()) else None,
+                    )
+                    output = out.view_as(query)
+
+                    if BENCHMARK_PROMPT_ATTENTION:
+                        end.record()
+                        torch.cuda.synchronize()
+                        print(backend, start.elapsed_time(end), output.shape, end='\n')
                else:
-                    query = query.unflatten(0, (batch_size, seq_len))
-                    key = key.unflatten(0, (batch_size, seq_len))
-                    value = value.unflatten(0, (batch_size, seq_len))
-
-                out = xops.memory_efficient_attention_forward(
-                    query,
-                    key,
-                    value,
-                    attn_bias=input_metadata.attn_bias,
-                    p=0.0,
-                    scale=self.scale,
-                    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
-                    (is_hip()) else None,
+                    # prefix-enabled attention
+                    output = torch.empty_like(query)
+                    context_attention_fwd(
+                        query,
+                        key,
+                        value,
+                        output,
+                        key_cache,
+                        value_cache,
+                        input_metadata.block_tables,  # [BS, max_block_per_request]
+                        input_metadata.start_loc,
+                        input_metadata.prompt_lens,
+                        input_metadata.context_lens,
+                        input_metadata.max_seq_len,
+                        getattr(self, "alibi_slopes", None),
+                    )
+            elif backend == 'timber':
+                # timber supports MQA/GQA natively, so no head projection is needed here
+                warnings.warn('prompt attention backend is timber')
+
+                TDST, H, HID = query.shape
+                TSRC, H_KV, _HID = key.shape
+                assert key.shape[:-1] == value.shape[:-1]
+                assert HID == _HID
+
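+                # NOTE: the permutes below assume timber_attention takes head-major
+                # [num_heads, seq_len, head_dim] tensors rather than the
+                # [seq_len, num_heads, head_dim] layout used above.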
+                query = query.permute(1, 0, 2)
+                key = key.permute(1, 0, 2)
+                value = value.permute(1, 0, 2)
+
+                if BENCHMARK_PROMPT_ATTENTION:
+                    start = torch.cuda.Event(enable_timing=True)
+                    end = torch.cuda.Event(enable_timing=True)
+                    start.record()
+
+                assert input_metadata.attn_bias is None
+                assert self.alibi_slopes is None
+
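+                # The softmax scale is folded into the query; mask_k (from HIP_K) is
+                # the per-query budget of key tokens kept by the sparse mask, and
+                # block_size_q / block_size_k appear to set the block granularity of
+                # the timber/HiP kernel.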
+                output, _ = timber_attention(
+                    q=query * self.scale,
+                    k=key,
+                    v=value,
+                    attention_mask=None,
+                    mask_k=hip_k,
+                    block_size_q=32,
+                    block_size_k=2,
                )
-                output = out.view_as(query)
+
+                output = output.permute(1, 0, 2)
+                output = output.view(
+                    1,
+                    TDST,
+                    H,
+                    HID,
+                ).contiguous()
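+                # Back to the [1, TDST, H, HID] layout that the code below expects
+                # before the final view(batch_size, seq_len, hidden_size).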
+
+                if BENCHMARK_PROMPT_ATTENTION:
+                    end.record()
+                    torch.cuda.synchronize()
+                    print(backend, start.elapsed_time(end), output.shape, end='\n')
            else:
-                # prefix-enabled attention
-                output = torch.empty_like(query)
-                context_attention_fwd(
-                    query,
-                    key,
-                    value,
-                    output,
-                    key_cache,
-                    value_cache,
-                    input_metadata.block_tables,  # [BS, max_block_per_request]
-                    input_metadata.start_loc,
-                    input_metadata.prompt_lens,
-                    input_metadata.context_lens,
-                    input_metadata.max_seq_len,
-                    getattr(self, "alibi_slopes", None),
-                )
-
+                raise Exception(backend)
        else:
            # Decoding run.
            BENCHMARK_PAGED_ATTENTION = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'
+
+            # print(f'[{os.getpid()}, {self.layer_index}] query_size: {query.shape}, block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')
+
            if BENCHMARK_PAGED_ATTENTION:
-                warnings.warn(f'query_size: {query.shape}, block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')
+                warnings.warn(f'query_size: {query.shape}({query.dtype}), block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')
                torch.cuda.synchronize()
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
@@ -203,9 +282,9 @@ def forward(
                    self.num_kv_heads,
                    self.scale,
                    self.alibi_slopes,
-                    )
+                )
            elif backend == 'timber':
-                warnings.warn('backend is timber')
+                warnings.warn('paged attention backend is timber')

                output, _ = paged_timber_attention(
                    q=query,
@@ -216,9 +295,9 @@ def forward(
                    context_lens=input_metadata.context_lens,
                    max_context_len=input_metadata.max_context_len,
                    attention_mask=None,
-                    mask_k=1024,
+                    mask_k=hip_k,
+                    block_size_q=32,
                    block_size_k=2,
-                    block_size_q=16
                )

                N_H, _, HID = output.shape
@@ -243,11 +322,12 @@ def forward(
243322 "alibi_slopes" : self .alibi_slopes ,
244323 "output" : output ,
245324 }, 'cache/llama/vllmout.pth' )
325+ print ('saved cache/llama/vllmout.pth' )
246326
247327 if BENCHMARK_PAGED_ATTENTION :
248328 end .record ()
249329 torch .cuda .synchronize ()
250- print (start .elapsed_time (end ))
330+ print (f'( { backend } ) { start .elapsed_time (end )} ' , end = ' \r ' )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)