@@ -67,7 +67,6 @@ def get_mask(self, i, j):
         return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 2)


-    # Forward function
     def forward(self, x):
         # Get batch size, sequence length and model dimension
         batch_size, seq_len, _ = x.shape
@@ -86,17 +85,17 @@ def forward(self, x):

         # Apply offset and segment for this head
         x_ = x[:, offset::self.dilation_rate, :]
-        x_ = x_.contiguous().view(batch_size, -1, self.segment_size, self.d_model)
-
+        x_ = x_.contiguous().view(batch_size, 1, -1, self.segment_size, self.d_model)  # Add an extra dimension for the number of heads

-
+        # Process each segment separately
         elements_attns = []
-        for idx in range(x_.shape[1]):
-            element = x_[:, idx, :, :].to(dtype)
+        for idx in range(x_.shape[2]):
+            element = x_[:, :, idx, :, :].to(dtype)
             element_attn = attention(element, element, element)
             elements_attns.append(element_attn)

-        attn_output = torch.cat(elements_attns, dim=1)
+        attn_output = torch.cat(elements_attns, dim=2)
+


         #option2
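
For reference, a minimal self-contained sketch of the segmented attention pattern the updated hunk implements: slice the sequence by offset and dilation rate, reshape into (batch, heads, segments, segment_size, d_model), attend within each segment, and concatenate the per-segment outputs along dim=2. The `attention` callable and the concrete parameter values below are illustrative stand-ins, not the repository's actual implementation.

import torch

# Stand-in for the module's attention call (assumption): plain scaled
# dot-product attention over the last two dimensions.
def attention(q, k, v):
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

# Illustrative shapes and dilation settings (hypothetical values).
batch_size, seq_len, d_model = 2, 32, 16
dilation_rate, segment_size, offset = 2, 4, 0

x = torch.randn(batch_size, seq_len, d_model)

# Keep every dilation_rate-th token starting at `offset`, then add the extra
# head dimension and split into fixed-size segments, as in the new view call.
x_ = x[:, offset::dilation_rate, :]
x_ = x_.contiguous().view(batch_size, 1, -1, segment_size, d_model)

# Attend within each segment independently; indexing segment idx drops dim 2,
# so concatenating along dim=2 stitches segments back into a sequence axis.
elements_attns = []
for idx in range(x_.shape[2]):
    element = x_[:, :, idx, :, :]
    elements_attns.append(attention(element, element, element))
attn_output = torch.cat(elements_attns, dim=2)

print(attn_output.shape)  # torch.Size([2, 1, 16, 16])

Iterating over dim=2 (rather than dim=1, as before the change) is what keeps the singleton head dimension intact through the per-segment loop.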