-
Notifications
You must be signed in to change notification settings - Fork 8
44.1kHz acoustic tokenizer supports speech & sound & music #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ def __init__(self, config: AudioVAEconfig): | |
| input_dim=config.enc_kwargs['input_dim'], | ||
| hop_size=config.enc_kwargs.get('hop_size', 320), | ||
| latent_dim=config.enc_kwargs['latent_dim'], | ||
| patch_size=config.patch_size | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
| ) | ||
|
|
||
| if config.semantic_module_kwargs is not None: | ||
|
|
@@ -33,6 +34,7 @@ def __init__(self, config: AudioVAEconfig): | |
| output_dim=config.dec_kwargs['output_dim'], | ||
| latent_dim=config.dec_kwargs['latent_dim'], | ||
| semantic_model=semantic_model, | ||
| patch_size=config.patch_size | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| ) | ||
|
|
||
| self.post_init() | ||
|
|
@@ -66,6 +68,8 @@ def encode_latent(self, waveform, waveform_length): | |
| - frame_num (torch.Tensor): The number of frames for each audio, shape (B,). | ||
| """ | ||
| frame_num = torch.ceil(waveform_length/self.config.enc_kwargs['input_dim']).to(torch.int32) | ||
| if self.config.patch_size != -1: | ||
| frame_num = torch.ceil(frame_num/self.config.patch_size) | ||
|
Comment on lines
+71
to
+72
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The frame_num = torch.ceil(frame_num/self.config.patch_size).to(torch.int64) |
||
| h, y = self.encoder(waveform) | ||
| h = h.transpose(1, 2) # [B, d, T] | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,13 @@ | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
| from transformers import Qwen2Model, Qwen2Config | ||
| import torch | ||
|
|
||
| from .istft import ISTFTHead | ||
|
|
||
|
|
||
| class Encoder(nn.Module): | ||
| def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64): | ||
| def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64, patch_size=-1): | ||
| super().__init__() | ||
| config = Qwen2Config.from_dict(config_dict=encoder_args) | ||
| self.encoder = Qwen2Model(config) | ||
|
|
@@ -17,6 +18,12 @@ def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64): | |
| self.fc2 = nn.Linear(config.hidden_size, config.hidden_size) | ||
| self.fc3 = nn.Linear(config.hidden_size, latent_dim*2) | ||
| self.norm = nn.LayerNorm(config.hidden_size) | ||
| self.patch_size = patch_size | ||
| if patch_size != -1: | ||
| config.num_hidden_layers = 4 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| self.aggregator = Qwen2Model(config) | ||
| self.cls_embed = nn.Parameter(torch.rand(1, 1, config.hidden_size)) | ||
| self.cls_embed.data.normal_(0, 0.02) | ||
|
|
||
| def get_frames(self, x): | ||
| num_frames_total = (x.size(-1) + self.hop_size - 1) // self.hop_size # 向上取整的帧数 | ||
|
|
@@ -27,6 +34,17 @@ def get_frames(self, x): | |
| frames = waveform.unfold(dimension=-1, size=self.input_dim, step=self.hop_size) # [B, T, d] | ||
| return frames | ||
|
|
||
| def pad_patch_insert_cls(self, x): | ||
| bsz, _, dim = x.size() | ||
| num_frame = x.size(1) | ||
| r = num_frame % self.patch_size | ||
| pad_num = self.patch_size-r if r else 0 | ||
| x = F.pad(x, (0, 0, 0, pad_num), value=0.0) # 帧数对齐到patch_size倍数 | ||
| x = x.reshape(-1, self.patch_size, dim) | ||
| x = torch.cat((x, self.cls_embed.expand(x.size(0), -1, -1)), dim=1) # 每个patch后插入一个cls | ||
| x = x.reshape(bsz, -1, dim) | ||
| return x | ||
|
|
||
| def forward(self, waveform): | ||
| x = self.get_frames(waveform) | ||
|
|
||
|
|
@@ -35,6 +53,15 @@ def forward(self, waveform): | |
| x = self.encoder(inputs_embeds=x) | ||
| x = x.last_hidden_state | ||
|
|
||
| # downsample | ||
| if self.patch_size != -1: | ||
| x = self.pad_patch_insert_cls(x) | ||
| x = self.aggregator(inputs_embeds=x) | ||
| x = x.last_hidden_state | ||
| bsz, _, dim = x.size() | ||
| x = x.reshape(-1, self.patch_size+1, dim) | ||
| x = x[:, -1:, :].reshape(bsz, -1, dim) | ||
|
|
||
| x = self.fc3(x) | ||
| return x, waveform.unsqueeze(1) | ||
|
|
||
|
|
@@ -59,6 +86,9 @@ def __init__(self, decoder_args, output_dim=320, latent_dim=64, semantic_model=N | |
| self.hop_length = output_dim | ||
| self.head = ISTFTHead(dim=config.hidden_size, n_fft=self.hop_length * 4, hop_length=self.hop_length, padding="same") | ||
| self.patch_size = patch_size | ||
| if self.patch_size != -1: | ||
| self.upsampling = nn.Upsample(scale_factor=patch_size, mode='linear') | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
|
|
||
| def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=False): | ||
| x = self.fc1(x) | ||
|
|
@@ -73,6 +103,9 @@ def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=Fa | |
| else: | ||
| unified_emb = None | ||
|
|
||
| if self.patch_size != -1: | ||
| x = self.upsampling(x.transpose(1, 2)).transpose(1, 2) | ||
|
|
||
| x = self.decoder(inputs_embeds=x) | ||
| x = x.last_hidden_state | ||
|
|
||
|
|
@@ -82,6 +115,8 @@ def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=Fa | |
|
|
||
| def low_level_reconstruct(self, x): | ||
| x = self.fc1(x) | ||
| if self.patch_size != -1: | ||
| x = self.upsampling(x.transpose(1, 2)).transpose(1, 2) | ||
| x = self.decoder(inputs_embeds=x) | ||
| x = x.last_hidden_state | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,16 +3,20 @@ | |
|
|
||
| from audio_tokenizer.modeling_audio_vae import AudioVAE | ||
|
|
||
| model = AudioVAE.from_pretrained('inclusionAI/MingTok-Audio') | ||
| model = AudioVAE.from_pretrained('inclusionAI/Ming-omni-tts-tokenizer-12Hz') | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| model = model.cuda() | ||
| model.eval() | ||
| model = model.to(torch.bfloat16) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| sample_rate = model.config.sample_rate | ||
| waveform, sr = torchaudio.load('data/1089-134686-0000.flac', backend='soundfile') | ||
| if sr != sample_rate: | ||
| waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)(waveform) | ||
|
Comment on lines
+13
to
+14
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| sample = {'waveform': waveform.cuda(), 'waveform_length': torch.tensor([waveform.size(-1)]).cuda()} | ||
|
|
||
| with torch.no_grad(): | ||
| with torch.autocast(device_type='cuda', dtype=torch.bfloat16): | ||
| latent, frame_num = model.encode_latent(**sample) | ||
| output_waveform = model.decode(latent) | ||
|
|
||
| torchaudio.save('./1089-134686-0000_reconstruct.wav', output_waveform.cpu()[0], sample_rate=16000) | ||
| torchaudio.save('./1089-134686-0000_reconstruct.wav', output_waveform.cpu()[0], sample_rate=sample_rate) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `sample_rate` parameter is introduced with a default value of `16000`. While this is a common sample rate, it might be beneficial to explicitly document the expected range or common values for this parameter, especially if the model is intended to support various audio types (speech, sound, music) as suggested by the PR title.