@@ -233,14 +233,25 @@ def __init__(self, model, strategy):
 
         ss = strategy.split(' ')
         DEVICE = ss[0]
+        self.ffn_fp8 = False
         if ss[1] == 'fp16':
             DTYPE = torch.half
         elif ss[1] == 'fp32':
             DTYPE = torch.float32
         elif ss[1] == 'bf16':
             DTYPE = torch.bfloat16
+        elif ss[1] == 'fp8':
+            DTYPE = torch.half
+            # Check PyTorch support for fp8
+            if hasattr(torch, 'float8_e4m3fn') and hasattr(torch, 'float8_e5m2') \
+                and hasattr(torch, '_scaled_mm'):
+                print("FP8 is only supported on modern CUDA devices, and it only quantizes the FFN weights")
+                print("The other weights are still in FP16\n\n")
+                self.ffn_fp8 = True
+            else:
+                print('Your PyTorch version does not support fp8, please update to 2.6 or later')
         else:
-            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16"
+            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16/fp8"
 
         temp_z = torch.load(args.MODEL_NAME + '.pth', map_location='cpu', mmap=True)
 
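The fp8 path is selected through the same space-separated strategy string as the existing dtypes; if the running PyTorch lacks the fp8 dtypes or `torch._scaled_mm`, the loader prints a message and falls back to plain fp16 (`ffn_fp8` stays False). A rough usage sketch, with the import path and model path as placeholders rather than part of this diff:

```python
# Hypothetical usage of the new strategy string; the import path and model file are placeholders.
from rwkv.model import RWKV   # assumed location of the RWKV class shown at the end of this diff

model = RWKV(model='path/to/rwkv7-model', strategy='cuda fp8')    # FFN weights quantized to fp8, rest fp16
# model = RWKV(model='path/to/rwkv7-model', strategy='cuda fp16') # unchanged fp16 path
```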
@@ -257,14 +268,20 @@ def __init__(self, model, strategy):
             args.n_layer = max(args.n_layer, layer_id + 1)
             tensor = temp_z[k]
             if 'key.weight' in k or 'value.weight' in k or 'receptance.weight' in k or 'output.weight' in k or 'head.weight' in k:
-                tensor = tensor.t().contiguous()
+                tensor = tensor.t()
             tensor = tensor.squeeze()
             if k.endswith('att.r_k'):
                 tensor = tensor.flatten()
-            self.z[k] = tensor.to(DEVICE).to(DTYPE)
+
+            if ('ffn.key.weight' in k or 'ffn.value.weight' in k) and self.ffn_fp8:
+                # convert to fp8
+                tensor, reciprocal = to_float8(tensor)
+                self.z[k] = tensor.to(DEVICE)
+                self.z[k + '_scale'] = reciprocal.to(DEVICE)
+            else:
+                self.z[k] = tensor.contiguous().to(DEVICE).to(DTYPE)
             del temp_z[k]
-            if keys.index(k) % 5 == 0:
-                torch.cuda.empty_cache()
+
 
         self.n_embd = args.n_embd
         self.n_layer = args.n_layer
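For each FFN key/value weight, the loader now stores the quantized fp8 tensor under the original key and the reciprocal of the quantization scale under `<key>_scale`. A minimal standalone sketch of what that stored pair represents, using the same per-tensor e4m3 scheme as `to_float8` defined further down in this diff (variable names here are illustrative only):

```python
import torch

w = torch.randn(768, 3072)                         # stand-in for an FFN weight
scale = 448.0 / w.abs().max().clamp(min=1e-12)     # per-tensor scale into the e4m3 range [-448, 448]
w_fp8 = (w * scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
w_scale = scale.reciprocal()                       # this reciprocal is what lands in self.z[k + '_scale']
w_approx = w_fp8.float() * w_scale                 # dequantize to check the round trip
print((w - w_approx).abs().max())                  # error bounded by the e4m3 quantization step
```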
@@ -315,7 +332,10 @@ def forward_one(self, idx:int, state:List[torch.Tensor]):
 
             xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
 
-            xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            if not self.ffn_fp8:
+                xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            else:
+                xx, state[i*3+2] = RWKV_x070_CMix_one_fp8(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'key.weight_scale'], z[ffn+'value.weight'], z[ffn+'value.weight_scale'])
             x = x + xx
 
             # if math.isnan(torch.min(x).item()): print(idx, i)
@@ -347,8 +367,10 @@ def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_output:bool=
             x = x + xx
 
             xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
-
-            xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            if not self.ffn_fp8:
+                xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            else:
+                xx, state[i*3+2] = RWKV_x070_CMix_seq_fp8(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'key.weight_scale'], z[ffn+'value.weight'], z[ffn+'value.weight_scale'])
             x = x + xx
 
         if not full_output: x = x[-1,:]
@@ -442,21 +464,63 @@ def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x
     return (xx * g) @ O_, x[-1,:], state, v_first
 
 ########################################################################################################
+@MyStatic
+def to_float8(x):
+    dtype_min = torch.tensor(-448).to(x.device)
+    dtype_max = torch.tensor(448).to(x.device)
+    scale = dtype_max / x.abs().max().clamp(min=1e-12)
+    x_scl_sat = (x * scale).clamp(min=dtype_min, max=dtype_max)
+    return x_scl_sat.to(torch.float8_e4m3fn), scale.float().reciprocal()
+
+@MyStatic
+def to_float8_e5m2(x):
+    dtype_min = torch.tensor(-57344).to(x.device)
+    dtype_max = torch.tensor(57344).to(x.device)
+    abs_max = x.abs().max().clamp(min=1e-12)
+    scale = dtype_max / abs_max
+    x_scl_sat = (x * scale).clamp(min=dtype_min, max=dtype_max)
+    return x_scl_sat.to(torch.float8_e5m2), scale.float().reciprocal()
 
 @MyStatic
 def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_):
     xx = x_prev - x
-    k = x + xx * x_k
+    k = torch.addcmul(x, xx, x_k)
     k = torch.relu(k @ K_) ** 2
     return k @ V_, x
 
+@MyStatic
+def RWKV_x070_CMix_one_fp8(x, x_prev, x_k, K_, K_scale, V_, V_scale):
+    xx = x_prev - x
+    k = torch.addcmul(x, xx, x_k)
+    k_fp8, k_scale_1 = to_float8(k)
+    k_fp8 = k_fp8.unsqueeze(0)
+    k_new = torch._scaled_mm(k_fp8, K_, k_scale_1, K_scale, out_dtype=x.dtype)
+    k = torch.relu(k_new) ** 2
+    k_fp8, k_scale_2 = to_float8(k)
+    output1 = torch._scaled_mm(k_fp8, V_, k_scale_2, V_scale, out_dtype=x.dtype)
+    return output1.squeeze(0), x
+
+
 @MyStatic
 def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_):
     xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
-    k = x + xx * x_k
+    k = torch.addcmul(x, xx, x_k)
     k = torch.relu(k @ K_) ** 2
     return k @ V_, x[-1,:]
 
+@MyStatic
+def RWKV_x070_CMix_seq_fp8(x, x_prev, x_k, K_, K_scale, V_, V_scale):
+    xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
+    k = torch.addcmul(x, xx, x_k)
+    k_fp8, k_scale_1 = to_float8(k)
+    k_new = torch._scaled_mm(k_fp8, K_, k_scale_1, K_scale, out_dtype=x.dtype)
+    k = torch.relu(k_new) ** 2
+    k_fp8, k_scale_2 = to_float8(k)
+    output1 = torch._scaled_mm(k_fp8, V_, k_scale_2, V_scale, out_dtype=x.dtype)
+    return output1, x[-1,:]
+
+
+
 ########################################################################################################
 
 class RWKV(MyModule):
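The fp8 CMix kernels quantize the activations on the fly and call the private `torch._scaled_mm` op for both FFN matmuls, passing each fp8 operand together with its reciprocal (dequantization) scale and requesting the output in the activation dtype. A minimal self-contained sketch of that call pattern, assuming a recent PyTorch in which `_scaled_mm` returns a single tensor and a CUDA GPU with fp8 support (e.g. compute capability 8.9+); shapes and names are illustrative only:

```python
import torch

def fp8_quant(t: torch.Tensor):
    # Per-tensor e4m3 quantization, same scheme as to_float8 in this diff.
    scale = 448.0 / t.abs().float().max().clamp(min=1e-12)
    t_fp8 = (t.float() * scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return t_fp8, scale.reciprocal()   # reciprocal scale = dequantization factor

# Illustrative shapes; requires a GPU with fp8 matmul support.
x = torch.randn(16, 1024, dtype=torch.float16, device='cuda')    # activations (row-major)
w = torch.randn(4096, 1024, dtype=torch.float16, device='cuda')  # weight in checkpoint layout (out, in)

x_fp8, x_scale = fp8_quant(x)
w_fp8, w_scale = fp8_quant(w)
w_fp8_t = w_fp8.t()   # column-major view: _scaled_mm expects its second operand laid out this way

y = torch._scaled_mm(x_fp8, w_fp8_t, x_scale, w_scale, out_dtype=torch.float16)
print((y - x @ w.t()).abs().max())   # rough error check against the fp16 reference
```

The column-major requirement is presumably also why this diff stops calling `.contiguous()` after transposing the fp8 FFN weights.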