@@ -71,21 +71,28 @@ def forward(self, text_tokens, **kwargs):
         return self.decoder(model_input, passed_x=model_input)[0]
 
 
-# class LongNet(Module):
-#     def __init__(self):
-#         super().__init__()
-
-#         self.model = LongNet(
-#             num_tokens = 16000,       # number of tokens
-#             dim = (512, 256),         # transformer model dimension (512 for coarsest, 256 for fine in this example)
-#             max_seq_len = (1024, 4),  # sequence length for global and then local. this can be more than 2
-#             depth = (6, 4),           # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
-#             dim_head = 64,            # dimension per head
-#             heads = 8,                # number of attention heads
-#             flash_attn = True         # use flash attention
-#         )
-
-#     def forward(self, text_tokens, temperature: int = None, filter_thres: int = None, **kwargs):
-#         sampled = self.model.generate(temperature=temperature, filter_thres=filter_thres)
-#         return sampled
+class DilatedLongNet(Module):
+    def __init__(self):
+        super().__init__()
+
+        self.model = LongNet(
+            num_tokens = 16000,          # number of tokens
+            dim = (512, 256),            # transformer model dimension (512 for coarsest, 256 for fine in this example)
+            max_seq_len = (1024, 4),     # sequence length for global and then local. this can be more than 2
+            depth = (6, 4),              # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
+            dim_head = 64,               # dimension per head
+            heads = 8,                   # number of attention heads
+            flash_attn = True,           # use flash attention
+            dilation_rate = 1,           # dilation rate for DilatedAttention
+            segment_size = 0,            # segment size for DilatedAttention
+            casual = False,              # whether to use causal attention for DilatedAttention
+            use_xpos = False,            # whether to use xpos (length-extrapolatable) positional embeddings for DilatedAttention
+            use_rel_pos_bias = False,    # whether to use relative positional bias for DilatedAttention
+            distributed = False          # whether to distribute attention for DilatedAttention
+        )
+
+    def forward(self, text_tokens, temperature: float = None, filter_thres: float = None, **kwargs):
+        sampled = self.model.generate(temperature=temperature, filter_thres=filter_thres)
+        return sampled
+
 
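For reference, a minimal usage sketch of the DilatedLongNet class added in this hunk. It assumes the wrapped LongNet and its generate() method accept the keyword arguments exactly as the diff passes them; the batch shape and sampling values below are illustrative assumptions, not library defaults.

    import torch

    # Assumes DilatedLongNet (and the LongNet it wraps) is importable in the current module.
    model = DilatedLongNet()

    # Dummy token ids drawn from the 16000-token vocabulary configured above.
    text_tokens = torch.randint(0, 16000, (1, 1024))

    # forward() hands temperature and filter_thres straight to LongNet.generate();
    # these sampling values are illustrative, not defaults.
    sampled = model(text_tokens, temperature=0.8, filter_thres=0.9)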