@@ -234,10 +234,11 @@ def __init__(self, args, layer_id):
234234 self .ffn = RWKV_ChannelMix (args , layer_id )
235235
236236 if args .tiny_att_dim > 0 and self .layer_id == args .tiny_att_layer :
237- self .head_q = nn .Linear (args .n_embd , args .tiny_att_dim , bias = False )
238- self .head_k = nn .Linear (args .n_embd , args .tiny_att_dim , bias = False )
239- self .head_v = nn .Linear (args .n_embd , args .n_embd , bias = False )
240- self .register_buffer ("head_mask" , torch .tril (torch .ones (args .ctx_len , args .ctx_len )))
237+ self .tiny_ln = nn .LayerNorm (args .n_embd )
238+ self .tiny_q = nn .Linear (args .n_embd , args .tiny_att_dim , bias = False )
239+ self .tiny_k = nn .Linear (args .n_embd , args .tiny_att_dim , bias = False )
240+ self .tiny_v = nn .Linear (args .n_embd , args .n_embd , bias = False )
241+ self .register_buffer ("tiny_mask" , torch .tril (torch .ones (args .ctx_len , args .ctx_len )))
241242
242243 def forward (self , x , x_emb = None ):
243244 args = self .args
@@ -255,11 +256,12 @@ def forward(self, x, x_emb=None):
255256 x = x + self .ffn (self .ln2 (x ))
256257
257258 if args .tiny_att_dim > 0 and self .layer_id == args .tiny_att_layer :
258- q = self .head_q (x )[:, :T , :]
259- k = self .head_k (x )[:, :T , :]
260- c = (q @ k .transpose (- 2 , - 1 )) * (1.0 / args .tiny_att_downscale )
261- c = c .masked_fill (self .head_mask [:T , :T ] == 0 , 0 )
262- x = x + c @ self .head_v (x_emb )
259+ xx = self .tiny_ln (x )
260+ q = self .tiny_q (xx )[:, :T , :]
261+ k = self .tiny_k (xx )[:, :T , :]
262+ c = (q @ k .transpose (- 2 , - 1 )) * (args .tiny_att_dim ** (- 0.5 ))
263+ c = c .masked_fill (self .tiny_mask [:T , :T ] == 0 , 0 )
264+ x = x + c @ self .tiny_v (x_emb )
263265 return x
264266
265267
0 commit comments