-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_torch.py
More file actions
34 lines (24 loc) · 922 Bytes
/
test_torch.py
File metadata and controls
34 lines (24 loc) · 922 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os

# Select the torch backend BEFORE importing rwkv_kernel: presumably the package
# (via Keras) reads KERAS_BACKEND at import time, so this ordering matters.
os.environ['KERAS_BACKEND'] = 'torch'
from rwkv_kernel.torch_rwkv_kernel import RWKVKernelOperator  # noqa: E402
import numpy as np
import torch

# Smoke test: build constant r/k/v/w/u inputs, run the RWKV kernel operator
# once, and print its output.
seq_len = 512
head_size = 64
channels = 512
bz = 4


def _const_tensor(shape, value):
    """Return a float32 torch tensor of ``shape`` with every element ``value``.

    The legacy ``torch.Tensor(array, dtype=...)`` constructor does not accept a
    ``dtype`` keyword and raises ``TypeError``, so build a float32 NumPy array
    and wrap it with ``torch.from_numpy`` instead.
    """
    return torch.from_numpy(np.full(shape, value, dtype=np.float32))


inputs_r = _const_tensor((bz, seq_len, channels), 1.0)
inputs_k = _const_tensor((bz, seq_len, channels), 1.0)
inputs_v = _const_tensor((bz, seq_len, channels), 1.0)
# w is the constant -4 everywhere.
inputs_w = _const_tensor((bz, seq_len, channels), -4.0)
# u is a 1-D tensor with one entry per channel.
inputs_u = _const_tensor((channels,), 1.0)

# NOTE(review): all inputs are created on the CPU. If RWKVKernelOperator
# requires CUDA tensors, move them with .cuda() first — confirm.
kernel = RWKVKernelOperator(head_size, seq_len)
y = kernel(inputs_r, inputs_k, inputs_v, inputs_w, inputs_u)
print(y)