@@ -3,7 +3,9 @@ use std::{
33 sync:: { Arc , Mutex , MutexGuard } ,
44} ;
55
6- use candle_core:: { DType , Device , Result , Tensor } ;
6+ use candle_core:: {
7+ from_storage_no_op, DType , Device , MetalStorage , Result , Shape , Storage , Tensor ,
8+ } ;
79use mistralrs_paged_attn:: { copy_blocks, swap_blocks} ;
810
911use super :: config:: ModelConfigLike ;
@@ -70,30 +72,88 @@ impl CacheEngine {
7072 . take ( model_config. num_layers ( ) )
7173 . map ( |x| x. as_ref ( ) . unwrap_or ( device) )
7274 {
73- let key_blocks = unsafe {
74- Tensor :: empty (
75- (
76- cache_config. num_gpu_blocks ,
77- key_block_shape. 0 ,
78- key_block_shape. 1 ,
79- key_block_shape. 2 ,
80- key_block_shape. 3 ,
81- ) ,
82- dtype,
83- device,
84- ) ?
75+ let key_blocks = if let Device :: Metal ( dev) = & device {
76+ #[ cfg( feature = "metal" ) ]
77+ {
78+ let elem_count = cache_config. num_gpu_blocks
79+ * key_block_shape. 0
80+ * key_block_shape. 1
81+ * key_block_shape. 2
82+ * key_block_shape. 3 ;
83+ let buffer = dev. new_buffer_private ( elem_count, dtype, "k_cache" ) ?;
84+ let storage =
85+ Storage :: Metal ( MetalStorage :: new ( buffer, dev. clone ( ) , elem_count, dtype) ) ;
86+ from_storage_no_op (
87+ storage,
88+ Shape :: from_dims ( & [
89+ cache_config. num_gpu_blocks ,
90+ key_block_shape. 0 ,
91+ key_block_shape. 1 ,
92+ key_block_shape. 2 ,
93+ key_block_shape. 3 ,
94+ ] ) ,
95+ false ,
96+ )
97+ }
98+
99+ #[ cfg( not( feature = "metal" ) ) ]
100+ {
101+ unreachable ! ( )
102+ }
103+ } else {
104+ unsafe {
105+ Tensor :: empty (
106+ (
107+ cache_config. num_gpu_blocks ,
108+ key_block_shape. 0 ,
109+ key_block_shape. 1 ,
110+ key_block_shape. 2 ,
111+ key_block_shape. 3 ,
112+ ) ,
113+ dtype,
114+ device,
115+ ) ?
116+ }
85117 } ;
86- let value_blocks = unsafe {
87- Tensor :: empty (
88- (
89- cache_config. num_gpu_blocks ,
90- value_block_shape. 0 ,
91- value_block_shape. 1 ,
92- value_block_shape. 2 ,
93- ) ,
94- dtype,
95- device,
96- ) ?
118+ let value_blocks = if let Device :: Metal ( dev) = & device {
119+ #[ cfg( feature = "metal" ) ]
120+ {
121+ let elem_count = cache_config. num_gpu_blocks
122+ * value_block_shape. 0
123+ * value_block_shape. 1
124+ * value_block_shape. 2 ;
125+ let buffer = dev. new_buffer_private ( elem_count, dtype, "v_cache" ) ?;
126+ let storage =
127+ Storage :: Metal ( MetalStorage :: new ( buffer, dev. clone ( ) , elem_count, dtype) ) ;
128+ from_storage_no_op (
129+ storage,
130+ Shape :: from_dims ( & [
131+ cache_config. num_gpu_blocks ,
132+ value_block_shape. 0 ,
133+ value_block_shape. 1 ,
134+ value_block_shape. 2 ,
135+ ] ) ,
136+ false ,
137+ )
138+ }
139+
140+ #[ cfg( not( feature = "metal" ) ) ]
141+ {
142+ unreachable ! ( )
143+ }
144+ } else {
145+ unsafe {
146+ Tensor :: empty (
147+ (
148+ cache_config. num_gpu_blocks ,
149+ value_block_shape. 0 ,
150+ value_block_shape. 1 ,
151+ value_block_shape. 2 ,
152+ ) ,
153+ dtype,
154+ device,
155+ ) ?
156+ }
97157 } ;
98158 gpu_cache. push ( ( key_blocks, value_blocks) ) ;
99159 }
0 commit comments