
Commit c72e361

Author: Kye (committed)
Commit message: longnet class
1 parent 4d3ada3, commit c72e361

File tree: 1 file changed, +24 -17 lines


LongNet/model.py

Lines changed: 24 additions & 17 deletions
@@ -71,21 +71,28 @@ def forward(self, text_tokens, **kwargs):
         return self.decoder(model_input, passed_x=model_input)[0]
 
 
-# class LongNet(Module):
-#     def __init__(self):
-#         super().__init__()
-
-#         self.model = LongNet(
-#             num_tokens = 16000,        # number of tokens
-#             dim = (512, 256),          # transformer model dimension (512 for coarsest, 256 for fine in this example)
-#             max_seq_len = (1024, 4),   # sequence length for global and then local. this can be more than 2
-#             depth = (6, 4),            # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
-#             dim_head = 64,             # dimension per head
-#             heads = 8,                 # number of attention heads
-#             flash_attn = True          # use flash attention
-#         )
-
-#     def forward(self, text_tokens, temperature: int = None, filter_thres: int = None, **kwargs):
-#         sampled = self.model.generate(temperature=temperature, filter_thres=filter_thres)
-#         return sampled
+class DilatedLongNet(Module):
+    def __init__(self):
+        super().__init__()
+
+        self.model = LongNet(
+            num_tokens = 16000,          # number of tokens
+            dim = (512, 256),            # transformer model dimension (512 for coarsest, 256 for fine in this example)
+            max_seq_len = (1024, 4),     # sequence length for global and then local. this can be more than 2
+            depth = (6, 4),              # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
+            dim_head = 64,               # dimension per head
+            heads = 8,                   # number of attention heads
+            flash_attn = True,           # use flash attention
+            dilation_rate = 1,           # dilation rate for DilatedAttention
+            segment_size = 0,            # segment size for DilatedAttention
+            casual = False,              # whether to use causal attention for DilatedAttention
+            use_xpos = False,            # whether to use absolute positional embeddings for DilatedAttention
+            use_rel_pos_bias = False,    # whether to use relative positional bias for DilatedAttention
+            distributed = False          # whether to distribute attention for DilatedAttention
+        )
+
+    def forward(self, text_tokens, temperature: int = None, filter_thres: int = None, **kwargs):
+        sampled = self.model.generate(temperature=temperature, filter_thres=filter_thres)
+        return sampled
+
 
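For context, a minimal usage sketch of the new DilatedLongNet wrapper follows. It assumes the class is importable from LongNet.model as committed above, that the wrapped LongNet constructor accepts the keyword arguments shown in the diff, and that generate() returns a tensor of sampled token ids; the prompt tensor and sampling values are illustrative, not taken from this commit. Note that in this revision forward() accepts text_tokens but does not pass it (or any prime) on to generate().

import torch

from LongNet.model import DilatedLongNet  # assumed import path for this file

model = DilatedLongNet()

# Dummy prompt of token ids in [0, 16000), matching num_tokens above.
# In this revision forward() ignores text_tokens and samples unconditionally.
text_tokens = torch.randint(0, 16000, (1, 1024))

# temperature / filter_thres are forwarded to self.model.generate();
# the values here are placeholders, not defaults from the commit.
sampled = model(text_tokens, temperature=0.9, filter_thres=0.95)
print(sampled.shape)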
