diff --git a/rwkv_pip_package/src/rwkv/model.py b/rwkv_pip_package/src/rwkv/model.py
index 83e6192d..cb937f72 100644
--- a/rwkv_pip_package/src/rwkv/model.py
+++ b/rwkv_pip_package/src/rwkv/model.py
@@ -233,14 +233,25 @@ def __init__(self, model, strategy):
 
         ss = strategy.split(' ')
         DEVICE = ss[0]
+        self.ffn_fp8 = False
         if ss[1] == 'fp16':
             DTYPE = torch.half
         elif ss[1] == 'fp32':
             DTYPE = torch.float32
         elif ss[1] == 'bf16':
             DTYPE = torch.bfloat16
+        elif ss[1] == 'fp8':
+            DTYPE = torch.half
+            # check PyTorch support for fp8
+            if hasattr(torch, 'float8_e4m3fn') and hasattr(torch, 'float8_e5m2') \
+                and hasattr(torch, '_scaled_mm'):
+                print("FP8 is only supported on modern CUDA devices, and only the FFN weights are quantized")
+                print("The other weights stay in FP16\n\n")
+                self.ffn_fp8 = True
+            else:
+                print('Your PyTorch version does not support fp8, please update to 2.6 or later')
         else:
-            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16"
+            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16/fp8"
 
         temp_z = torch.load(args.MODEL_NAME + '.pth', map_location='cpu', mmap=True)
@@ -257,14 +268,20 @@ def __init__(self, model, strategy):
             args.n_layer = max(args.n_layer, layer_id+1)
 
             tensor = temp_z[k]
            if 'key.weight' in k or 'value.weight' in k or 'receptance.weight' in k or 'output.weight' in k or 'head.weight' in k:
-                tensor = tensor.t().contiguous()
+                tensor = tensor.t()
             tensor = tensor.squeeze()
             if k.endswith('att.r_k'): tensor = tensor.flatten()
-            self.z[k] = tensor.to(DEVICE).to(DTYPE)
+
+            if ('ffn.key.weight' in k or 'ffn.value.weight' in k) and self.ffn_fp8:
+                # convert FFN weights to fp8, keeping the reciprocal scale for dequantization
+                tensor, reciprocal = to_float8(tensor)
+                self.z[k] = tensor.to(DEVICE)
+                self.z[k+'_scale'] = reciprocal.to(DEVICE)
+            else:
+                self.z[k] = tensor.contiguous().to(DEVICE).to(DTYPE)
             del temp_z[k]
-            if keys.index(k) % 5 == 0:
-                torch.cuda.empty_cache()
+
 
         self.n_embd = args.n_embd
         self.n_layer = args.n_layer
@@ -315,7 +332,10 @@ def forward_one(self, idx:int, state:List[torch.Tensor]):
 
             xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
-            xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            if not self.ffn_fp8:
+                xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            else:
+                xx, state[i*3+2] = RWKV_x070_CMix_one_fp8(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'key.weight_scale'], z[ffn+'value.weight'], z[ffn+'value.weight_scale'])
             x = x + xx
 
             # if math.isnan(torch.min(x).item()): print(idx, i)
@@ -347,8 +367,10 @@ def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_output:bool=False):
             x = x + xx
 
             xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])
-
-            xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            if not self.ffn_fp8:
+                xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            else:
+                xx, state[i*3+2] = RWKV_x070_CMix_seq_fp8(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'key.weight_scale'], z[ffn+'value.weight'], z[ffn+'value.weight_scale'])
             x = x + xx
 
         if not full_output: x = x[-1,:]
@@ -442,21 +464,63 @@ def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x
     return (xx * g) @ O_, x[-1,:], state, v_first
 ########################################################################################################
 
+@MyStatic
+def to_float8(x):
+    dtype_min = torch.tensor(-448).to(x.device)
+    dtype_max = torch.tensor(448).to(x.device)
+    scale = dtype_max / x.abs().max().clamp(min=1e-12)  # per-tensor scale: largest |x| maps to the e4m3 max
+    x_scl_sat = (x * scale).clamp(min=dtype_min, max=dtype_max)
+    return x_scl_sat.to(torch.float8_e4m3fn), scale.float().reciprocal()
+
+@MyStatic
+def to_float8_e5m2(x):
+    dtype_min = torch.tensor(-57344).to(x.device)  # e5m2 variant: wider range, lower precision
+    dtype_max = torch.tensor(57344).to(x.device)
+    abs_max = x.abs().max().clamp(min=1e-12)
+    scale = dtype_max / abs_max
+    x_scl_sat = (x * scale).clamp(min=dtype_min, max=dtype_max)
+    return x_scl_sat.to(torch.float8_e5m2), scale.float().reciprocal()
 
 @MyStatic
 def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_):
     xx = x_prev - x
-    k = x + xx * x_k
+    k = torch.addcmul(x, xx, x_k)
     k = torch.relu(k @ K_) ** 2
     return k @ V_, x
 
+@MyStatic
+def RWKV_x070_CMix_one_fp8(x, x_prev, x_k, K_, K_scale, V_, V_scale):
+    xx = x_prev - x
+    k = torch.addcmul(x, xx, x_k)
+    k_fp8, k_scale_1 = to_float8(k)  # quantize activations on the fly
+    k_fp8 = k_fp8.unsqueeze(0)  # _scaled_mm needs a 2D left operand
+    k_new = torch._scaled_mm(k_fp8, K_, k_scale_1, K_scale, out_dtype=x.dtype)
+    k = torch.relu(k_new) ** 2
+    k_fp8, k_scale_2 = to_float8(k)
+    output1 = torch._scaled_mm(k_fp8, V_, k_scale_2, V_scale, out_dtype=x.dtype)
+    return output1.squeeze(0), x
+
 @MyStatic
 def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_):
     xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
-    k = x + xx * x_k
+    k = torch.addcmul(x, xx, x_k)
     k = torch.relu(k @ K_) ** 2
     return k @ V_, x[-1,:]
 
+@MyStatic
+def RWKV_x070_CMix_seq_fp8(x, x_prev, x_k, K_, K_scale, V_, V_scale):
+    xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
+    k = torch.addcmul(x, xx, x_k)
+    k_fp8, k_scale_1 = to_float8(k)
+    k_new = torch._scaled_mm(k_fp8, K_, k_scale_1, K_scale, out_dtype=x.dtype)
+    k = torch.relu(k_new) ** 2
+    k_fp8, k_scale_2 = to_float8(k)
+    output1 = torch._scaled_mm(k_fp8, V_, k_scale_2, V_scale, out_dtype=x.dtype)
+    return output1, x[-1,:]
+
+
 ########################################################################################################
 
 class RWKV(MyModule):
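
Note for reviewers: the FP8 path is selected through the existing strategy string, so no new API is introduced. A minimal usage sketch (the model path is a placeholder, and the RWKV_V7_ON flag is only needed if your package version gates the v7 implementation behind it):

    import os
    os.environ['RWKV_V7_ON'] = '1'  # enable the rwkv7 code path, if gated in your version
    from rwkv.model import RWKV

    # 'cuda fp8' loads everything in FP16 except the ffn.key/ffn.value weights,
    # which are quantized to float8_e4m3fn at load time (plus per-tensor scales)
    model = RWKV(model='/path/to/rwkv7-model', strategy='cuda fp8')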
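
For context, the numeric pattern both new kernels rely on: quantize the activations per tensor to e4m3 on the fly, then let torch._scaled_mm multiply the two fp8 operands and fold the float32 dequantization scales back in. The sketch below isolates that pattern outside the model code; fp8_linear and the shapes are illustrative assumptions, torch._scaled_mm is a private PyTorch API, and the positional scale_a/scale_b call form used here (and in the patch) is the PyTorch 2.5+ signature. It also shows why the loader now keeps fp8 FFN weights as non-contiguous .t() views: _scaled_mm expects its second operand in column-major layout.

    import torch

    def to_float8(x, dtype_max=448.0):
        # same per-tensor scheme as the patch: the largest |value| saturates to the e4m3 max
        scale = dtype_max / x.abs().max().clamp(min=1e-12)
        x_sat = (x * scale).clamp(min=-dtype_max, max=dtype_max)
        # quantized tensor + reciprocal scale, to be folded back in by _scaled_mm
        return x_sat.to(torch.float8_e4m3fn), scale.float().reciprocal()

    def fp8_linear(x, w_fp8_colmajor, w_scale, out_dtype=torch.half):
        # x: (M, K) activations; w_fp8_colmajor: (K, N) fp8 weight that is a .t()
        # view of a contiguous (N, K) tensor, matching the loader change above
        x_fp8, x_scale = to_float8(x)
        return torch._scaled_mm(x_fp8, w_fp8_colmajor, x_scale, w_scale, out_dtype=out_dtype)

    if torch.cuda.is_available() and hasattr(torch, '_scaled_mm'):
        # needs an FP8-capable GPU (compute capability >= 8.9, e.g. Ada/Hopper)
        M, K, N = 16, 1024, 4096  # K and N should be multiples of 16 for _scaled_mm
        x = torch.randn(M, K, device='cuda', dtype=torch.half)
        w = torch.randn(N, K, device='cuda', dtype=torch.half)
        w_fp8, w_scale = to_float8(w)          # quantize the contiguous (N, K) weight once
        y = fp8_linear(x, w_fp8.t(), w_scale)  # .t() yields the column-major view _scaled_mm wants
        y_ref = x @ w.t()                      # fp16 reference
        print('max abs error vs fp16:', (y - y_ref).abs().max().item())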