
Commit b629642

Support FP8 quantize
1 parent 7c941c9 commit b629642

File tree

1 file changed: +74, -10 lines


rwkv_pip_package/src/rwkv/model.py

Lines changed: 74 additions & 10 deletions
@@ -233,14 +233,25 @@ def __init__(self, model, strategy):
 
         ss = strategy.split(' ')
         DEVICE = ss[0]
+        self.ffn_fp8 = False
         if ss[1] == 'fp16':
             DTYPE = torch.half
         elif ss[1] == 'fp32':
             DTYPE = torch.float32
         elif ss[1] == 'bf16':
             DTYPE = torch.bfloat16
+        elif ss[1] == 'fp8':
+            DTYPE = torch.half
+            # Check pytorch support for fp8
+            if hasattr(torch, 'float8_e4m3fn') and hasattr(torch, 'float8_e5m2') \
+                and hasattr(torch, '_scaled_mm'):
+                print("FP8 is only supported on modern CUDA devices, and it only quantizes the FFN weights")
+                print("The other weights are still in FP16\n\n")
+                self.ffn_fp8 = True
+            else:
+                print('Your pytorch version does not support fp8, please update to 2.6 or later')
         else:
-            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16"
+            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16/fp8"
 
         temp_z = torch.load(args.MODEL_NAME + '.pth', map_location='cpu', mmap=True)
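The strategy string keeps its existing '<device> <dtype>' form, with 'fp8' accepted as a new dtype token. A minimal usage sketch (the model path is a placeholder; the fp8 path also needs a CUDA device and a PyTorch build that provides torch._scaled_mm):

```python
# Usage sketch (placeholder path): select the new fp8 strategy.
# Only the FFN key/value matrices are stored in FP8; the rest stays FP16.
from rwkv.model import RWKV

model = RWKV(model='/path/to/rwkv7-model', strategy='cuda fp8')
```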

@@ -257,14 +268,20 @@ def __init__(self, model, strategy):
             args.n_layer = max(args.n_layer, layer_id+1)
             tensor = temp_z[k]
             if 'key.weight' in k or 'value.weight' in k or 'receptance.weight' in k or 'output.weight' in k or 'head.weight' in k:
-                tensor = tensor.t().contiguous()
+                tensor = tensor.t()
             tensor = tensor.squeeze()
             if k.endswith('att.r_k'):
                 tensor = tensor.flatten()
-            self.z[k] = tensor.to(DEVICE).to(DTYPE)
+
+            if ('ffn.key.weight' in k or 'ffn.value.weight' in k) and self.ffn_fp8:
+                # convert to fp8
+                tensor, reciprocal = to_float8(tensor)
+                self.z[k] = tensor.to(DEVICE)
+                self.z[k+'_scale'] = reciprocal.to(DEVICE)
+            else:
+                self.z[k] = tensor.contiguous().to(DEVICE).to(DTYPE)
             del temp_z[k]
-            if keys.index(k) % 5 == 0:
-                torch.cuda.empty_cache()
+
 
         self.n_embd = args.n_embd
         self.n_layer = args.n_layer
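On the .contiguous() change above: the transpose is now kept as a view for the FP8 FFN weights (the non-FP8 branch restores contiguity when it stores the tensor), which is consistent with torch._scaled_mm expecting a column-major second operand. A small illustration of that layout, not part of the diff:

```python
# Illustration only: a plain .t() produces a column-major view,
# the layout torch._scaled_mm generally expects for its second operand.
import torch

w = torch.randn(8, 16, dtype=torch.half)
wt = w.t()                  # shape (16, 8), stride (1, 16): column-major view
print(wt.is_contiguous())   # False -- contiguity is deliberately not restored here
print(wt.stride())          # (1, 16)
```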
@@ -315,7 +332,10 @@ def forward_one(self, idx:int, state:List[torch.Tensor]):
 
             xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
 
-            xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            if not self.ffn_fp8:
+                xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            else:
+                xx, state[i*3+2] = RWKV_x070_CMix_one_fp8(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'key.weight_scale'], z[ffn+'value.weight'], z[ffn+'value.weight_scale'])
             x = x + xx
 
             # if math.isnan(torch.min(x).item()): print(idx, i)
@@ -347,8 +367,10 @@ def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_output:bool=
             x = x + xx
 
             xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
-
-            xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            if not self.ffn_fp8:
+                xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            else:
+                xx, state[i*3+2] = RWKV_x070_CMix_seq_fp8(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'key.weight_scale'], z[ffn+'value.weight'], z[ffn+'value.weight_scale'])
             x = x + xx
 
         if not full_output: x = x[-1,:]
@@ -442,21 +464,63 @@ def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x
     return (xx * g) @ O_, x[-1,:], state, v_first
 
 ########################################################################################################
+@MyStatic
+def to_float8(x):
+    dtype_min = torch.tensor(-448).to(x.device)
+    dtype_max = torch.tensor(448).to(x.device)
+    scale = dtype_max / x.abs().max().clamp(min=1e-12)
+    x_scl_sat = (x * scale).clamp(min=dtype_min, max=dtype_max)
+    return x_scl_sat.to(torch.float8_e4m3fn), scale.float().reciprocal()
+
+@MyStatic
+def to_float8_e5m2(x):
+    dtype_min = torch.tensor(-57344).to(x.device)
+    dtype_max = torch.tensor(57344).to(x.device)
+    abs_max = x.abs().max().clamp(min=1e-12)
+    scale = dtype_max / abs_max
+    x_scl_sat = (x * scale).clamp(min=dtype_min, max=dtype_max)
+    return x_scl_sat.to(torch.float8_e5m2), scale.float().reciprocal()
 
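to_float8 does per-tensor dynamic quantization: the scale maps the tensor's absolute maximum onto the e4m3 limit of ±448, and the returned reciprocal is the dequantization factor later passed to torch._scaled_mm. The e5m2 variant (range ±57344) is added alongside but is not called by the code in this commit. A standalone round-trip sketch of the same idea, using torch.finfo instead of hard-coded limits:

```python
# Standalone sketch (not the shipped helper): per-tensor FP8 e4m3 quantization round trip.
import torch

def to_float8_demo(x: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)             # finfo.max == 448.0
    scale = finfo.max / x.abs().max().clamp(min=1e-12)   # map abs-max onto the FP8 range
    x_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_sat.to(torch.float8_e4m3fn), scale.float().reciprocal()

w = torch.randn(4, 4)
w_fp8, w_scale = to_float8_demo(w)
w_back = w_fp8.to(torch.float32) * w_scale               # approximate reconstruction
print((w - w_back).abs().max())                          # small quantization error
```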
 @MyStatic
 def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_):
     xx = x_prev - x
-    k = x + xx * x_k
+    k = torch.addcmul(x, xx, x_k)
     k = torch.relu(k @ K_) ** 2
     return k @ V_, x
 
+@MyStatic
+def RWKV_x070_CMix_one_fp8(x, x_prev, x_k, K_, K_scale, V_, V_scale):
+    xx = x_prev - x
+    k = torch.addcmul(x, xx, x_k)
+    k_fp8, k_scale_1 = to_float8(k)
+    k_fp8 = k_fp8.unsqueeze(0)
+    k_new = torch._scaled_mm(k_fp8, K_, k_scale_1, K_scale, out_dtype=x.dtype)
+    k = torch.relu(k_new) ** 2
+    k_fp8, k_scale_2 = to_float8(k)
+    output1 = torch._scaled_mm(k_fp8, V_, k_scale_2, V_scale, out_dtype=x.dtype)
+    return output1.squeeze(0), x
+
+
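torch._scaled_mm multiplies two FP8 matrices, applying the per-tensor dequantization scales and writing the result in out_dtype; it expects 2D inputs, which is why the single-token path unsqueezes k to shape (1, n_embd). A hedged standalone sketch of the same call pattern, reusing the to_float8 helper above and assuming a CUDA device with FP8 support (e.g. Hopper/Ada) and inner dimensions that are multiples of 16:

```python
# Sketch only: the _scaled_mm pattern used by the fp8 CMix kernels above.
import torch

x = torch.randn(1, 64, device='cuda', dtype=torch.half)       # one token, 2D as _scaled_mm requires
w = torch.randn(64, 128, device='cuda', dtype=torch.half)     # weight already transposed, like K_ above

x_fp8, x_scale = to_float8(x)                                  # per-tensor quantization, helper above
w_fp8, w_scale = to_float8(w)
w_fp8 = w_fp8.t().contiguous().t()                             # column-major layout for the second operand
y = torch._scaled_mm(x_fp8, w_fp8, x_scale, w_scale, out_dtype=torch.half)   # y approximates x @ w
```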
 @MyStatic
 def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_):
     xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
-    k = x + xx * x_k
+    k = torch.addcmul(x, xx, x_k)
     k = torch.relu(k @ K_) ** 2
     return k @ V_, x[-1,:]
 
+@MyStatic
+def RWKV_x070_CMix_seq_fp8(x, x_prev, x_k, K_, K_scale, V_, V_scale):
+    xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
+    k = torch.addcmul(x, xx, x_k)
+    k_fp8, k_scale_1 = to_float8(k)
+    k_new = torch._scaled_mm(k_fp8, K_, k_scale_1, K_scale, out_dtype=x.dtype)
+    k = torch.relu(k_new) ** 2
+    k_fp8, k_scale_2 = to_float8(k)
+    output1 = torch._scaled_mm(k_fp8, V_, k_scale_2, V_scale, out_dtype=x.dtype)
+    return output1, x[-1,:]
+
+
+
 ########################################################################################################
 
 class RWKV(MyModule):
