Commit 8c2846b

Support FP8 quantize

1 parent 7c941c9 commit 8c2846b

1 file changed: rwkv_pip_package/src/rwkv/model.py (+69 additions, -10 deletions)
@@ -233,14 +233,20 @@ def __init__(self, model, strategy):
 
         ss = strategy.split(' ')
         DEVICE = ss[0]
+        self.ffn_fp8 = False
         if ss[1] == 'fp16':
             DTYPE = torch.half
         elif ss[1] == 'fp32':
             DTYPE = torch.float32
         elif ss[1] == 'bf16':
             DTYPE = torch.bfloat16
+        elif ss[1] == 'fp8':
+            DTYPE = torch.half
+            print("FP8 is only supported on modern CUDA devices, and it only quantizes the FFN weights")
+            print("The other weights are still in FP16\n\n")
+            self.ffn_fp8 = True
         else:
-            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16"
+            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16/fp8"
 
         temp_z = torch.load(args.MODEL_NAME + '.pth', map_location='cpu', mmap=True)
 
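With this change, the strategy string gains an fp8 option alongside fp16/fp32/bf16; as the strategy.split(' ') parsing above shows, the device comes first and the dtype second. A minimal usage sketch (the model path is illustrative, not from the commit):

    from rwkv.model import RWKV

    # Quantizes only the FFN key/value projections to float8_e4m3fn;
    # all other weights stay in FP16, per the printed notice above.
    model = RWKV(model='/path/to/rwkv7-model', strategy='cuda fp8')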
@@ -257,14 +263,20 @@ def __init__(self, model, strategy):
             args.n_layer = max(args.n_layer, layer_id+1)
             tensor = temp_z[k]
             if 'key.weight' in k or 'value.weight' in k or 'receptance.weight' in k or 'output.weight' in k or 'head.weight' in k:
-                tensor = tensor.t().contiguous()
+                tensor = tensor.t()
             tensor = tensor.squeeze()
             if k.endswith('att.r_k'):
                 tensor = tensor.flatten()
-            self.z[k] = tensor.to(DEVICE).to(DTYPE)
+
+            if ('ffn.key.weight' in k or 'ffn.value.weight' in k) and self.ffn_fp8:
+                # convert to fp8
+                tensor, reciprocal = to_float8(tensor)
+                self.z[k] = tensor.to(DEVICE)
+                self.z[k+'_scale'] = reciprocal.to(DEVICE)
+            else:
+                self.z[k] = tensor.contiguous().to(DEVICE).to(DTYPE)
             del temp_z[k]
-            if keys.index(k) % 5 == 0:
-                torch.cuda.empty_cache()
+
 
         self.n_embd = args.n_embd
         self.n_layer = args.n_layer
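With the fp8 strategy active, each FFN key.weight and value.weight is stored as a float8_e4m3fn tensor plus a float32 reciprocal scale under a '_scale'-suffixed key; every other weight takes the usual fp16 path. Note that the fp8 branch quantizes the transposed weight without restoring .contiguous(), likely because this keeps the tensor column-major, the layout torch._scaled_mm expects for its second operand, while the non-fp8 branch re-adds .contiguous() itself. A hypothetical inspection loop (key names follow the RWKV checkpoint convention):

    for k, v in self.z.items():
        if k.endswith('_scale'):
            print(k, v.item())          # scalar float32 reciprocal scale, one per quantized weight
        elif v.dtype == torch.float8_e4m3fn:
            print(k, v.shape, v.dtype)  # quantized FFN key/value projection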
@@ -315,7 +327,10 @@ def forward_one(self, idx:int, state:List[torch.Tensor]):
 
             xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
 
-            xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            if not self.ffn_fp8:
+                xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            else:
+                xx, state[i*3+2] = RWKV_x070_CMix_one_fp8(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'key.weight_scale'], z[ffn+'value.weight'], z[ffn+'value.weight_scale'])
             x = x + xx
 
             # if math.isnan(torch.min(x).item()): print(idx, i)
@@ -347,8 +362,10 @@ def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_output:bool=
             x = x + xx
 
             xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
-
-            xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            if not self.ffn_fp8:
+                xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            else:
+                xx, state[i*3+2] = RWKV_x070_CMix_seq_fp8(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'key.weight_scale'], z[ffn+'value.weight'], z[ffn+'value.weight_scale'])
             x = x + xx
 
             if not full_output: x = x[-1,:]
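Both forward paths now branch on self.ffn_fp8 and thread the stored weight scales through to the fp8 kernels defined below. One shape detail separates the two fp8 variants: torch._scaled_mm only accepts 2-D operands, so the single-token path unsqueezes its activation row before the matmul and squeezes the result back, while the sequence path already has a 2-D activation. A standalone sketch of such a scaled matmul (names and sizes are illustrative; needs a GPU with FP8 support, e.g. Ada or Hopper, and a recent PyTorch):

    import torch

    def quantize_e4m3(x):
        # Per-tensor symmetric scaling into the e4m3 range, mirroring the to_float8 helper below.
        scale = 448.0 / x.abs().max().clamp(min=1e-12)
        x_fp8 = (x * scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
        return x_fp8, scale.float().reciprocal()

    act = torch.randn(1, 4096, device='cuda', dtype=torch.half)   # one token's activation row
    w = torch.randn(4096, 4096, device='cuda', dtype=torch.half)  # an FFN projection
    act_fp8, act_scale = quantize_e4m3(act)
    w_fp8, w_scale = quantize_e4m3(w)
    # _scaled_mm wants mat1 row-major and mat2 column-major; .t() on the fp8 weight gives the latter.
    out = torch._scaled_mm(act_fp8, w_fp8.t(), act_scale, w_scale, out_dtype=torch.half)
    print(out.shape)  # torch.Size([1, 4096])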
@@ -442,21 +459,63 @@ def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x
     return (xx * g) @ O_, x[-1,:], state, v_first
 
 ########################################################################################################
+@MyStatic
+def to_float8(x):
+    dtype_min = torch.tensor(-448).to(x.device)
+    dtype_max = torch.tensor(448).to(x.device)
+    scale = dtype_max / x.abs().max().clamp(min=1e-12)
+    x_scl_sat = (x * scale).clamp(min=dtype_min, max=dtype_max)
+    return x_scl_sat.to(torch.float8_e4m3fn), scale.float().reciprocal()
+
+@MyStatic
+def to_float8_e5m2(x):
+    dtype_min = torch.tensor(-57344).to(x.device)
+    dtype_max = torch.tensor(57344).to(x.device)
+    abs_max = x.abs().max().clamp(min=1e-12)
+    scale = dtype_max / abs_max
+    x_scl_sat = (x * scale).clamp(min=dtype_min, max=dtype_max)
+    return x_scl_sat.to(torch.float8_e5m2), scale.float().reciprocal()
 
 @MyStatic
 def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_):
     xx = x_prev - x
-    k = x + xx * x_k
+    k = torch.addcmul(x, xx, x_k)
     k = torch.relu(k @ K_) ** 2
     return k @ V_, x
 
+@MyStatic
+def RWKV_x070_CMix_one_fp8(x, x_prev, x_k, K_, K_scale, V_, V_scale):
+    xx = x_prev - x
+    k = torch.addcmul(x, xx, x_k)
+    k_fp8, k_scale_1 = to_float8(k)
+    k_fp8 = k_fp8.unsqueeze(0)
+    k_new = torch._scaled_mm(k_fp8, K_, k_scale_1, K_scale, out_dtype=x.dtype)
+    k = torch.relu(k_new) ** 2
+    k_fp8, k_scale_2 = to_float8(k)
+    output1 = torch._scaled_mm(k_fp8, V_, k_scale_2, V_scale, out_dtype=x.dtype)
+    return output1.squeeze(0), x
+
 @MyStatic
 def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_):
     xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
-    k = x + xx * x_k
+    k = torch.addcmul(x, xx, x_k)
     k = torch.relu(k @ K_) ** 2
     return k @ V_, x[-1,:]
 
+@MyStatic
+def RWKV_x070_CMix_seq_fp8(x, x_prev, x_k, K_, K_scale, V_, V_scale):
+    xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
+    k = torch.addcmul(x, xx, x_k)
+    k_fp8, k_scale_1 = to_float8(k)
+    k_new = torch._scaled_mm(k_fp8, K_, k_scale_1, K_scale, out_dtype=x.dtype)
+    k = torch.relu(k_new) ** 2
+    k_fp8, k_scale_2 = to_float8(k)
+    output1 = torch._scaled_mm(k_fp8, V_, k_scale_2, V_scale, out_dtype=x.dtype)
+    return output1, x[-1,:]
+
 ########################################################################################################
 
 class RWKV(MyModule):
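For intuition on the two quantizers added above: to_float8 targets e4m3 (max finite value 448, more mantissa bits), while to_float8_e5m2 targets e5m2 (max finite value 57344, coarser precision); the inference paths in this commit call only the e4m3 variant. Both scale the tensor so its largest magnitude maps to the format's maximum, saturate, cast, and return the reciprocal scale needed to dequantize. A quick round-trip comparison, runnable on CPU (illustrative only):

    import torch

    x = torch.randn(4096)
    for dtype, max_val in [(torch.float8_e4m3fn, 448.0), (torch.float8_e5m2, 57344.0)]:
        scale = max_val / x.abs().max().clamp(min=1e-12)
        x_q = (x * scale).clamp(-max_val, max_val).to(dtype)
        x_back = x_q.float() * scale.reciprocal()  # dequantize with the reciprocal scale
        print(dtype, (x - x_back).abs().max().item())
    # e4m3 typically shows the smaller round-trip error once values are scaled into range.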
