84 changes: 74 additions & 10 deletions rwkv_pip_package/src/rwkv/model.py
@@ -233,14 +233,25 @@ def __init__(self, model, strategy):

        ss = strategy.split(' ')
        DEVICE = ss[0]
+        self.ffn_fp8 = False
        if ss[1] == 'fp16':
            DTYPE = torch.half
        elif ss[1] == 'fp32':
            DTYPE = torch.float32
        elif ss[1] == 'bf16':
            DTYPE = torch.bfloat16
+        elif ss[1] == 'fp8':
+            DTYPE = torch.half
+            # check PyTorch support for fp8
+            if hasattr(torch, 'float8_e4m3fn') and hasattr(torch, 'float8_e5m2') \
+                    and hasattr(torch, '_scaled_mm'):
+                print("FP8 is only supported on modern CUDA devices, and only the FFN weights are quantized")
+                print("All other weights remain in FP16\n\n")
+                self.ffn_fp8 = True
+            else:
+                print('Your PyTorch version does not support fp8, please update to 2.6 or later')
        else:
-            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16"
+            assert False, "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16/fp8"

        temp_z = torch.load(args.MODEL_NAME + '.pth', map_location='cpu', mmap=True)
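
For reference, the new strategy string would be used like any other (a hedged sketch; the model path is a placeholder):

# hypothetical usage; only 'fp8' after the device is new here
from rwkv.model import RWKV
model = RWKV(model='/path/to/rwkv7-model', strategy='cuda fp8')  # FFN weights in fp8, all else fp16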

@@ -257,14 +268,20 @@ def __init__(self, model, strategy):
            args.n_layer = max(args.n_layer, layer_id+1)
            tensor = temp_z[k]
            if 'key.weight' in k or 'value.weight' in k or 'receptance.weight' in k or 'output.weight' in k or 'head.weight' in k:
-                tensor = tensor.t().contiguous()
+                tensor = tensor.t()
            tensor = tensor.squeeze()
            if k.endswith('att.r_k'):
                tensor = tensor.flatten()
-            self.z[k] = tensor.to(DEVICE).to(DTYPE)
+
+            if ('ffn.key.weight' in k or 'ffn.value.weight' in k) and self.ffn_fp8:
+                # convert to fp8, keeping the per-tensor dequantization scale alongside the weight
+                tensor, reciprocal = to_float8(tensor)
+                self.z[k] = tensor.to(DEVICE)
+                self.z[k+'_scale'] = reciprocal.to(DEVICE)
+            else:
+                self.z[k] = tensor.contiguous().to(DEVICE).to(DTYPE)
            del temp_z[k]
            if keys.index(k) % 5 == 0:
                torch.cuda.empty_cache()


        self.n_embd = args.n_embd
        self.n_layer = args.n_layer
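
Dropping .contiguous() after .t() above is deliberate: torch._scaled_mm expects its second operand in column-major layout, which is exactly what a transposed view provides. A minimal sketch of the call pattern this PR builds on (assumes PyTorch 2.5+ and an fp8-capable CUDA GPU such as Ada/Hopper; names and shapes are illustrative, with all dimensions multiples of 16 as _scaled_mm requires):

import torch

a = torch.randn(16, 32, device='cuda', dtype=torch.half).to(torch.float8_e4m3fn)
# quantize a (64, 32) weight, then transpose so the second operand is column-major
w = torch.randn(64, 32, device='cuda', dtype=torch.half).to(torch.float8_e4m3fn).t()
one = torch.tensor(1.0, device='cuda')  # identity dequantization scales for the sketch
out = torch._scaled_mm(a, w, one, one, out_dtype=torch.half)
print(out.shape)  # torch.Size([16, 64])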
@@ -315,7 +332,10 @@ def forward_one(self, idx:int, state:List[torch.Tensor]):

            xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])

-            xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            if not self.ffn_fp8:
+                xx, state[i*3+2] = RWKV_x070_CMix_one(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            else:
+                xx, state[i*3+2] = RWKV_x070_CMix_one_fp8(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'key.weight_scale'], z[ffn+'value.weight'], z[ffn+'value.weight_scale'])
            x = x + xx

            # if math.isnan(torch.min(x).item()): print(idx, i)
@@ -347,8 +367,10 @@ def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_output:bool=
            x = x + xx

            xx = F.layer_norm(x, (self.n_embd,), weight=z[bbb+'ln2.weight'], bias=z[bbb+'ln2.bias'])

-            xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            if not self.ffn_fp8:
+                xx, state[i*3+2] = RWKV_x070_CMix_seq(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'value.weight'])
+            else:
+                xx, state[i*3+2] = RWKV_x070_CMix_seq_fp8(xx, state[i*3+2], z[ffn+'x_k'], z[ffn+'key.weight'], z[ffn+'key.weight_scale'], z[ffn+'value.weight'], z[ffn+'value.weight_scale'])
            x = x + xx

        if not full_output: x = x[-1,:]
@@ -442,21 +464,63 @@ def RWKV_x070_TMix_seq(layer_id: int, H:int, N:int, x, x_prev, v_first, state, x
    return (xx * g) @ O_, x[-1,:], state, v_first

########################################################################################################
+@MyStatic
+def to_float8(x):
+    # scale into the float8_e4m3fn range (max finite value 448), saturate,
+    # and return the fp8 tensor together with the reciprocal (dequantization) scale
+    dtype_min = torch.tensor(-448).to(x.device)
+    dtype_max = torch.tensor(448).to(x.device)
+    scale = dtype_max / x.abs().max().clamp(min=1e-12)
+    x_scl_sat = (x * scale).clamp(min=dtype_min, max=dtype_max)
+    return x_scl_sat.to(torch.float8_e4m3fn), scale.float().reciprocal()
+
+@MyStatic
+def to_float8_e5m2(x):
+    # same scheme for float8_e5m2, whose max finite value is 57344
+    dtype_min = torch.tensor(-57344).to(x.device)
+    dtype_max = torch.tensor(57344).to(x.device)
+    abs_max = x.abs().max().clamp(min=1e-12)
+    scale = dtype_max / abs_max
+    x_scl_sat = (x * scale).clamp(min=dtype_min, max=dtype_max)
+    return x_scl_sat.to(torch.float8_e5m2), scale.float().reciprocal()
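
A quick round-trip sketch of this scheme (illustrative; multiplying the fp8 payload by the returned reciprocal scale approximately recovers the input):

w = torch.randn(64, 64, device='cuda', dtype=torch.half)
w_fp8, w_inv_scale = to_float8(w)
w_approx = w_fp8.half() * w_inv_scale  # dequantize
print((w - w_approx).abs().max())      # small, bounded by fp8 rounding error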

@MyStatic
def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_):
    xx = x_prev - x
-    k = x + xx * x_k
+    k = torch.addcmul(x, xx, x_k)
    k = torch.relu(k @ K_) ** 2
    return k @ V_, x

+@MyStatic
+def RWKV_x070_CMix_one_fp8(x, x_prev, x_k, K_, K_scale, V_, V_scale):
+    xx = x_prev - x
+    k = torch.addcmul(x, xx, x_k)
+    # quantize the activation on the fly; K_/V_ are pre-quantized fp8 weights
+    # with per-tensor dequantization scales K_scale/V_scale
+    k_fp8, k_scale_1 = to_float8(k)
+    k_fp8 = k_fp8.unsqueeze(0)  # _scaled_mm needs 2-D operands
+    k_new = torch._scaled_mm(k_fp8, K_, k_scale_1, K_scale, out_dtype=x.dtype)
+    k = torch.relu(k_new) ** 2
+    k_fp8, k_scale_2 = to_float8(k)
+    output1 = torch._scaled_mm(k_fp8, V_, k_scale_2, V_scale, out_dtype=x.dtype)
+    return output1.squeeze(0), x


@MyStatic
def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_):
    xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
-    k = x + xx * x_k
+    k = torch.addcmul(x, xx, x_k)
    k = torch.relu(k @ K_) ** 2
    return k @ V_, x[-1,:]

+@MyStatic
+def RWKV_x070_CMix_seq_fp8(x, x_prev, x_k, K_, K_scale, V_, V_scale):
+    xx = torch.cat((x_prev.unsqueeze(0), x[:-1,:])) - x
+    k = torch.addcmul(x, xx, x_k)
+    # x is already 2-D in the sequence path, so no unsqueeze is needed
+    k_fp8, k_scale_1 = to_float8(k)
+    k_new = torch._scaled_mm(k_fp8, K_, k_scale_1, K_scale, out_dtype=x.dtype)
+    k = torch.relu(k_new) ** 2
+    k_fp8, k_scale_2 = to_float8(k)
+    output1 = torch._scaled_mm(k_fp8, V_, k_scale_2, V_scale, out_dtype=x.dtype)
+    return output1, x[-1,:]
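
As a sanity check on the arithmetic: _scaled_mm dequantizes both operands with the supplied scales, so each call computes roughly (k_fp8 * 1/s_k) @ (W_fp8 * 1/s_W) ≈ k @ W, and the per-tensor scales cancel up to fp8 rounding. A quick numerical comparison under the same assumptions as the earlier sketch (names are illustrative):

x = torch.randn(16, 64, device='cuda', dtype=torch.half)
W = torch.randn(64, 64, device='cuda', dtype=torch.half)
ref = x @ W                                   # fp16 reference
x_fp8, x_inv = to_float8(x)
W_fp8, W_inv = to_float8(W.t().contiguous())  # quantize W^T, then view it column-major
out = torch._scaled_mm(x_fp8, W_fp8.t(), x_inv, W_inv, out_dtype=torch.half)
print((out - ref).abs().max())                # small relative to ref.abs().max()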



########################################################################################################

class RWKV(MyModule):