11import torch .nn as nn
22import torch .nn .functional as F
33from transformers import Qwen2Model , Qwen2Config
4+ import torch
45
56from .istft import ISTFTHead
67
78
89class Encoder (nn .Module ):
9- def __init__ (self , encoder_args , input_dim = 320 , hop_size = 320 , latent_dim = 64 ):
10+ def __init__ (self , encoder_args , input_dim = 320 , hop_size = 320 , latent_dim = 64 , patch_size = - 1 ):
1011 super ().__init__ ()
1112 config = Qwen2Config .from_dict (config_dict = encoder_args )
1213 self .encoder = Qwen2Model (config )
@@ -17,6 +18,12 @@ def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64):
1718 self .fc2 = nn .Linear (config .hidden_size , config .hidden_size )
1819 self .fc3 = nn .Linear (config .hidden_size , latent_dim * 2 )
1920 self .norm = nn .LayerNorm (config .hidden_size )
21+ self .patch_size = patch_size
22+ if patch_size != - 1 :
23+ config .num_hidden_layers = 4
24+ self .aggregator = Qwen2Model (config )
25+ self .cls_embed = nn .Parameter (torch .rand (1 , 1 , config .hidden_size ))
26+ self .cls_embed .data .normal_ (0 , 0.02 )
2027
2128 def get_frames (self , x ):
2229 num_frames_total = (x .size (- 1 ) + self .hop_size - 1 ) // self .hop_size # 向上取整的帧数
@@ -27,6 +34,17 @@ def get_frames(self, x):
2734 frames = waveform .unfold (dimension = - 1 , size = self .input_dim , step = self .hop_size ) # [B, T, d]
2835 return frames
2936
def pad_patch_insert_cls(self, x):
    """Split frames into fixed-size patches and append a CLS slot to each.

    The frame axis is zero-padded on the right up to the next multiple of
    ``self.patch_size``, cut into patches of ``patch_size`` frames, and
    ``self.cls_embed`` is appended after the last frame of every patch.

    Args:
        x: 3-D tensor unpacked as ``(batch, num_frames, dim)``.

    Returns:
        Tensor of shape ``(batch, num_patches * (patch_size + 1), dim)``
        where ``num_patches = ceil(num_frames / patch_size)``.

    Raises:
        ValueError: if ``self.patch_size`` is not a positive integer.
    """
    patch_size = self.patch_size
    if patch_size < 1:
        raise ValueError(f"patch_size must be positive, got {patch_size}")

    bsz, num_frame, dim = x.size()
    num_patches = (num_frame + patch_size - 1) // patch_size  # ceil division
    pad_num = num_patches * patch_size - num_frame
    if pad_num:
        # Align the number of frames to a multiple of patch_size.
        x = F.pad(x, (0, 0, 0, pad_num), value=0.0)

    # Explicit sizes (rather than -1) keep reshape well-defined even when
    # the input holds zero elements (no frames or an empty batch).
    x = x.reshape(bsz * num_patches, patch_size, dim)
    cls = self.cls_embed.expand(bsz * num_patches, -1, -1)
    # Append one CLS embedding after every patch.
    x = torch.cat((x, cls), dim=1)
    return x.reshape(bsz, num_patches * (patch_size + 1), dim)
47+
3048 def forward (self , waveform ):
3149 x = self .get_frames (waveform )
3250
@@ -35,6 +53,15 @@ def forward(self, waveform):
3553 x = self .encoder (inputs_embeds = x )
3654 x = x .last_hidden_state
3755
56+ # downsample
57+ if self .patch_size != - 1 :
58+ x = self .pad_patch_insert_cls (x )
59+ x = self .aggregator (inputs_embeds = x )
60+ x = x .last_hidden_state
61+ bsz , _ , dim = x .size ()
62+ x = x .reshape (- 1 , self .patch_size + 1 , dim )
63+ x = x [:, - 1 :, :].reshape (bsz , - 1 , dim )
64+
3865 x = self .fc3 (x )
3966 return x , waveform .unsqueeze (1 )
4067
@@ -59,6 +86,9 @@ def __init__(self, decoder_args, output_dim=320, latent_dim=64, semantic_model=N
5986 self .hop_length = output_dim
6087 self .head = ISTFTHead (dim = config .hidden_size , n_fft = self .hop_length * 4 , hop_length = self .hop_length , padding = "same" )
6188 self .patch_size = patch_size
89+ if self .patch_size != - 1 :
90+ self .upsampling = nn .Upsample (scale_factor = patch_size , mode = 'linear' )
91+
6292
6393 def forward (self , x , only_semantic_emb = False , past_key_values = None , use_cache = False ):
6494 x = self .fc1 (x )
@@ -73,6 +103,9 @@ def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=Fa
73103 else :
74104 unified_emb = None
75105
106+ if self .patch_size != - 1 :
107+ x = self .upsampling (x .transpose (1 , 2 )).transpose (1 , 2 )
108+
76109 x = self .decoder (inputs_embeds = x )
77110 x = x .last_hidden_state
78111
@@ -82,6 +115,8 @@ def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=Fa
82115
83116 def low_level_reconstruct (self , x ):
84117 x = self .fc1 (x )
118+ if self .patch_size != - 1 :
119+ x = self .upsampling (x .transpose (1 , 2 )).transpose (1 , 2 )
85120 x = self .decoder (inputs_embeds = x )
86121 x = x .last_hidden_state
87122
0 commit comments