diff --git a/audio_tokenizer/configuration_audio_vae.py b/audio_tokenizer/configuration_audio_vae.py
index 37f541e..72b1f05 100644
--- a/audio_tokenizer/configuration_audio_vae.py
+++ b/audio_tokenizer/configuration_audio_vae.py
@@ -4,6 +4,7 @@ class AudioVAEconfig(PretrainedConfig):
     def __init__(
         self,
+        sample_rate: int=16000,
         enc_kwargs: dict = None,
         semantic_module_kwargs: dict = None,
         dec_kwargs: dict = None,
@@ -18,6 +19,7 @@ def __init__(
         patch_size=-1,
         **kwargs
     ):
+        self.sample_rate = sample_rate
         self.enc_kwargs = enc_kwargs
         self.semantic_module_kwargs = semantic_module_kwargs
         self.dec_kwargs = dec_kwargs
diff --git a/audio_tokenizer/modeling_audio_vae.py b/audio_tokenizer/modeling_audio_vae.py
index 932c17f..7ea7a45 100644
--- a/audio_tokenizer/modeling_audio_vae.py
+++ b/audio_tokenizer/modeling_audio_vae.py
@@ -18,6 +18,7 @@ def __init__(self, config: AudioVAEconfig):
             input_dim=config.enc_kwargs['input_dim'],
             hop_size=config.enc_kwargs.get('hop_size', 320),
             latent_dim=config.enc_kwargs['latent_dim'],
+            patch_size=config.patch_size
         )

         if config.semantic_module_kwargs is not None:
@@ -33,6 +34,7 @@ def __init__(self, config: AudioVAEconfig):
             output_dim=config.dec_kwargs['output_dim'],
             latent_dim=config.dec_kwargs['latent_dim'],
             semantic_model=semantic_model,
+            patch_size=config.patch_size
         )

         self.post_init()
@@ -66,6 +68,8 @@ def encode_latent(self, waveform, waveform_length):
             - frame_num (torch.Tensor): The number of frames for each audio, shape (B,).
         """
         frame_num = torch.ceil(waveform_length/self.config.enc_kwargs['input_dim']).to(torch.int32)
+        if self.config.patch_size != -1:
+            frame_num = torch.ceil(frame_num/self.config.patch_size)
         h, y = self.encoder(waveform)
         h = h.transpose(1, 2) # [B, d, T]
diff --git a/audio_tokenizer/vae_modules.py b/audio_tokenizer/vae_modules.py
index 5e38b69..6fdb889 100644
--- a/audio_tokenizer/vae_modules.py
+++ b/audio_tokenizer/vae_modules.py
@@ -1,12 +1,13 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import Qwen2Model, Qwen2Config
+import torch

 from .istft import ISTFTHead


 class Encoder(nn.Module):
-    def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64):
+    def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64, patch_size=-1):
         super().__init__()
         config = Qwen2Config.from_dict(config_dict=encoder_args)
         self.encoder = Qwen2Model(config)
@@ -17,6 +18,12 @@ def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64):
         self.fc2 = nn.Linear(config.hidden_size, config.hidden_size)
         self.fc3 = nn.Linear(config.hidden_size, latent_dim*2)
         self.norm = nn.LayerNorm(config.hidden_size)
+        self.patch_size = patch_size
+        if patch_size != -1:
+            config.num_hidden_layers = 4
+            self.aggregator = Qwen2Model(config)
+            self.cls_embed = nn.Parameter(torch.rand(1, 1, config.hidden_size))
+            self.cls_embed.data.normal_(0, 0.02)

     def get_frames(self, x):
         num_frames_total = (x.size(-1) + self.hop_size - 1) // self.hop_size # number of frames, rounded up
@@ -27,6 +34,17 @@ def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64):
         frames = waveform.unfold(dimension=-1, size=self.input_dim, step=self.hop_size) # [B, T, d]
         return frames

+    def pad_patch_insert_cls(self, x):
+        bsz, _, dim = x.size()
+        num_frame = x.size(1)
+        r = num_frame % self.patch_size
+        pad_num = self.patch_size-r if r else 0
+        x = F.pad(x, (0, 0, 0, pad_num), value=0.0) # pad the frame count to a multiple of patch_size
+        x = x.reshape(-1, self.patch_size, dim)
+        x = torch.cat((x, self.cls_embed.expand(x.size(0), -1, -1)), dim=1) # insert one cls token after each patch
+        x = x.reshape(bsz, -1, dim)
+        return x
+
     def forward(self, waveform):
         x = self.get_frames(waveform)
@@ -35,6 +53,15 @@ def forward(self, waveform):
         x = self.encoder(inputs_embeds=x)
         x = x.last_hidden_state

+        # downsample
+        if self.patch_size != -1:
+            x = self.pad_patch_insert_cls(x)
+            x = self.aggregator(inputs_embeds=x)
+            x = x.last_hidden_state
+            bsz, _, dim = x.size()
+            x = x.reshape(-1, self.patch_size+1, dim)
+            x = x[:, -1:, :].reshape(bsz, -1, dim)
+
         x = self.fc3(x)
         return x, waveform.unsqueeze(1)
@@ -59,6 +86,9 @@ def __init__(self, decoder_args, output_dim=320, latent_dim=64, semantic_model=N
         self.hop_length = output_dim
         self.head = ISTFTHead(dim=config.hidden_size, n_fft=self.hop_length * 4, hop_length=self.hop_length, padding="same")
         self.patch_size = patch_size
+        if self.patch_size != -1:
+            self.upsampling = nn.Upsample(scale_factor=patch_size, mode='linear')
+
     def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=False):
         x = self.fc1(x)
@@ -73,6 +103,9 @@ def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=Fa
         else:
             unified_emb = None

+        if self.patch_size != -1:
+            x = self.upsampling(x.transpose(1, 2)).transpose(1, 2)
+
         x = self.decoder(inputs_embeds=x)
         x = x.last_hidden_state
@@ -82,6 +115,8 @@ def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=Fa
     def low_level_reconstruct(self, x):
         x = self.fc1(x)
+        if self.patch_size != -1:
+            x = self.upsampling(x.transpose(1, 2)).transpose(1, 2)
         x = self.decoder(inputs_embeds=x)
         x = x.last_hidden_state
diff --git a/test.py b/test.py
index 8cbe649..11f589e 100644
--- a/test.py
+++ b/test.py
@@ -3,11 +3,15 @@
 from audio_tokenizer.modeling_audio_vae import AudioVAE

-model = AudioVAE.from_pretrained('inclusionAI/MingTok-Audio')
+model = AudioVAE.from_pretrained('inclusionAI/Ming-omni-tts-tokenizer-12Hz')
 model = model.cuda()
 model.eval()
+model = model.to(torch.bfloat16)
+sample_rate = model.config.sample_rate

 waveform, sr = torchaudio.load('data/1089-134686-0000.flac', backend='soundfile')
+if sr != sample_rate:
+    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)(waveform)
 sample = {'waveform': waveform.cuda(), 'waveform_length': torch.tensor([waveform.size(-1)]).cuda()}

 with torch.no_grad():
@@ -15,4 +19,4 @@
     latent, frame_num = model.encode_latent(**sample)
     output_waveform = model.decode(latent)

-torchaudio.save('./1089-134686-0000_reconstruct.wav', output_waveform.cpu()[0], sample_rate=16000)
\ No newline at end of file
+torchaudio.save('./1089-134686-0000_reconstruct.wav', output_waveform.cpu()[0], sample_rate=sample_rate)
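
Note (not part of the patch): a minimal standalone sketch of the patch-based downsampling that Encoder.forward gains above, with illustrative sizes (patch_size=4, hidden=8, T=10) and the 4-layer Qwen2 aggregator omitted. It pads the frame axis to a multiple of patch_size, appends a learned cls embedding after every patch, and keeps only the cls positions as the downsampled sequence.

import torch
import torch.nn as nn
import torch.nn.functional as F

patch_size, hidden = 4, 8                             # illustrative sizes
cls_embed = nn.Parameter(torch.zeros(1, 1, hidden).normal_(0, 0.02))  # cls init follows the diff

x = torch.randn(2, 10, hidden)                        # [B, T, d] encoder output
bsz, num_frame, dim = x.size()
r = num_frame % patch_size
x = F.pad(x, (0, 0, 0, patch_size - r if r else 0))   # pad T up to a multiple of patch_size
x = x.reshape(-1, patch_size, dim)                    # [B * T/patch_size, patch_size, d]
x = torch.cat((x, cls_embed.expand(x.size(0), -1, -1)), dim=1)  # cls token after each patch
x = x.reshape(bsz, -1, dim)

# ... the aggregator (a 4-layer Qwen2Model) would run on x here ...

x = x.reshape(-1, patch_size + 1, dim)[:, -1:, :]     # keep only the cls positions
x = x.reshape(bsz, -1, dim)                           # [B, ceil(T/patch_size), d]
print(x.shape)                                        # torch.Size([2, 3, 8])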