@@ -233,14 +233,20 @@ def __init__(self, model, strategy):
 
         ss = strategy.split(' ')
         DEVICE = ss[0]
+        self.ffn_fp8 = False
         if ss[1] == 'fp16':
             DTYPE = torch.half
         elif ss[1] == 'fp32':
             DTYPE = torch.float32
         elif ss[1] == 'bf16':
             DTYPE = torch.bfloat16
+        elif ss[1] == 'fp8':
+            DTYPE = torch.half
+            print("FP8 is only supported on modern CUDA devices, and it only quantizes the FFN weights")
+            print("The other weights are still in FP16\n\n")
+            self.ffn_fp8 = True
         else:
-            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16"
+            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16/fp8"
 
         temp_z = torch.load(args.MODEL_NAME + '.pth', map_location='cpu', mmap=True)
 
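Note: the new 'fp8' mode rides on the existing space-separated strategy string; a minimal sketch of how the two fields are consumed (the strategy value here is only for illustration):

ss = 'cuda fp8'.split(' ')
# ss[0] -> DEVICE ('cuda'); ss[1] == 'fp8' keeps DTYPE = torch.half for most weights
# and sets self.ffn_fp8 = True, so only the ffn key/value matrices below are quantized.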
@@ -257,14 +263,20 @@ def __init__(self, model, strategy):
             args.n_layer = max(args.n_layer, layer_id + 1)
             tensor = temp_z[k]
             if 'key.weight' in k or 'value.weight' in k or 'receptance.weight' in k or 'output.weight' in k or 'head.weight' in k:
-                tensor = tensor.t().contiguous()
+                tensor = tensor.t()
             tensor = tensor.squeeze()
             if k.endswith('att.r_k'):
                 tensor = tensor.flatten()
-            self.z[k] = tensor.to(DEVICE).to(DTYPE)
+
+            if ('ffn.key.weight' in k or 'ffn.value.weight' in k) and self.ffn_fp8:
+                # convert to fp8
+                tensor, reciprocal = to_float8(tensor)
+                self.z[k] = tensor.to(DEVICE)
+                self.z[k + '_scale'] = reciprocal.to(DEVICE)
+            else:
+                self.z[k] = tensor.contiguous().to(DEVICE).to(DTYPE)
             del temp_z[k]
-            if keys.index(k) % 5 == 0:
-                torch.cuda.empty_cache()
+
 
         self.n_embd = args.n_embd
         self.n_layer = args.n_layer
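Note on the dropped .contiguous(): as far as I can tell, torch._scaled_mm expects its second operand in column-major layout, which is exactly what a plain .t() view of a row-major checkpoint tensor provides, so the FP8 branch keeps the transposed view while the non-FP8 branch restores .contiguous() explicitly. A quick layout check (shapes arbitrary, no GPU needed just to inspect strides):

import torch
w = torch.randn(8, 16)       # (out_features, in_features) as stored in the checkpoint
wt = w.t()                   # (in, out) view with strides (1, 16): column-major
print(wt.is_contiguous())    # False -- the layout the FP8 matmul path relies on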
@@ -315,7 +327,10 @@ def forward_one(self, idx:int, state:List[torch.Tensor]):
 
             xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
 
-            xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            if not self.ffn_fp8:
+                xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            else:
+                xx, state[i*3+2] = RWKV_x070_CMix_one_fp8(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'key.weight_scale'], z[ffn+'value.weight'], z[ffn+'value.weight_scale'])
             x = x + xx
 
             # if math.isnan(torch.min(x).item()): print(idx, i)
@@ -347,8 +362,10 @@ def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_output:bool=
             x = x + xx
 
             xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
-
-            xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            if not self.ffn_fp8:
+                xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            else:
+                xx, state[i*3+2] = RWKV_x070_CMix_seq_fp8(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'key.weight_scale'], z[ffn+'value.weight'], z[ffn+'value.weight_scale'])
             x = x + xx
 
         if not full_output: x = x[-1,:]
@@ -442,21 +459,63 @@ def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x
     return (xx * g) @ O_, x[-1,:], state, v_first
 
 ########################################################################################################
+@MyStatic
+def to_float8(x):
+    dtype_min = torch.tensor(-448).to(x.device)
+    dtype_max = torch.tensor(448).to(x.device)
+    scale = dtype_max / x.abs().max().clamp(min=1e-12)
+    x_scl_sat = (x * scale).clamp(min=dtype_min, max=dtype_max)
+    return x_scl_sat.to(torch.float8_e4m3fn), scale.float().reciprocal()
+
+@MyStatic
+def to_float8_e5m2(x):
+    dtype_min = torch.tensor(-57344).to(x.device)
+    dtype_max = torch.tensor(57344).to(x.device)
+    abs_max = x.abs().max().clamp(min=1e-12)
+    scale = dtype_max / abs_max
+    x_scl_sat = (x * scale).clamp(min=dtype_min, max=dtype_max)
+    return x_scl_sat.to(torch.float8_e5m2), scale.float().reciprocal()
 
 @MyStatic
 def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_):
     xx = x_prev - x
-    k = x + xx * x_k
+    k = torch.addcmul(x, xx, x_k)
     k = torch.relu(k @ K_) ** 2
     return k @ V_, x
 
+@MyStatic
+def RWKV_x070_CMix_one_fp8(x, x_prev, x_k, K_, K_scale, V_, V_scale):
+    xx = x_prev - x
+    k = torch.addcmul(x, xx, x_k)
+    k_fp8, k_scale_1 = to_float8(k)
+    k_fp8 = k_fp8.unsqueeze(0)
+    k_new = torch._scaled_mm(k_fp8, K_, k_scale_1, K_scale, out_dtype=x.dtype)
+    k = torch.relu(k_new) ** 2
+    k_fp8, k_scale_2 = to_float8(k)
+    output1 = torch._scaled_mm(k_fp8, V_, k_scale_2, V_scale, out_dtype=x.dtype)
+    return output1.squeeze(0), x
+
+
 @MyStatic
 def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_):
     xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
-    k = x + xx * x_k
+    k = torch.addcmul(x, xx, x_k)
     k = torch.relu(k @ K_) ** 2
     return k @ V_, x[-1,:]
 
+@MyStatic
+def RWKV_x070_CMix_seq_fp8(x, x_prev, x_k, K_, K_scale, V_, V_scale):
+    xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
+    k = torch.addcmul(x, xx, x_k)
+    k_fp8, k_scale_1 = to_float8(k)
+    k_new = torch._scaled_mm(k_fp8, K_, k_scale_1, K_scale, out_dtype=x.dtype)
+    k = torch.relu(k_new) ** 2
+    k_fp8, k_scale_2 = to_float8(k)
+    output1 = torch._scaled_mm(k_fp8, V_, k_scale_2, V_scale, out_dtype=x.dtype)
+    return output1, x[-1,:]
+
+
+
 ########################################################################################################
 
 class RWKV(MyModule):
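For reference, a round-trip sketch of the per-tensor scaling that to_float8 above implements: 448 is the largest finite value of torch.float8_e4m3fn, so the tensor's absolute maximum is mapped onto it, and the returned reciprocal is what torch._scaled_mm later multiplies back in. The cast itself runs on CPU with a recent PyTorch; only the _scaled_mm calls need an FP8-capable GPU:

import torch

x = torch.randn(4, 8)
scale = 448.0 / x.abs().max().clamp(min=1e-12)               # use the full E4M3 range
x_fp8 = (x * scale).clamp(-448, 448).to(torch.float8_e4m3fn)
recip = scale.reciprocal()                                   # stored alongside the weight as '<name>_scale'
x_back = x_fp8.float() * recip                               # dequantize
print((x - x_back).abs().max().item())                       # small, bounded by FP8 quantization error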