Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions audio_tokenizer/configuration_audio_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
class AudioVAEconfig(PretrainedConfig):
def __init__(
self,
sample_rate: int=16000,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The sample_rate parameter is introduced with a default value of 16000. While this is a common sample rate, it might be beneficial to explicitly document the expected range or common values for this parameter, especially if the model is intended to support various audio types (speech, sound, music) as suggested by the PR title.

enc_kwargs: dict = None,
semantic_module_kwargs: dict = None,
dec_kwargs: dict = None,
Expand All @@ -18,6 +19,7 @@ def __init__(
patch_size=-1,
**kwargs
):
self.sample_rate = sample_rate
self.enc_kwargs = enc_kwargs
self.semantic_module_kwargs = semantic_module_kwargs
self.dec_kwargs = dec_kwargs
Expand Down
4 changes: 4 additions & 0 deletions audio_tokenizer/modeling_audio_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self, config: AudioVAEconfig):
input_dim=config.enc_kwargs['input_dim'],
hop_size=config.enc_kwargs.get('hop_size', 320),
latent_dim=config.enc_kwargs['latent_dim'],
patch_size=config.patch_size

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The patch_size is directly passed from config.patch_size to the Encoder constructor. It's good that this parameter is now configurable. Ensure that the patch_size value is validated within the AudioVAEconfig or Encoder to prevent unexpected behavior, such as non-positive values or values that might lead to inefficient patching.

)

if config.semantic_module_kwargs is not None:
Expand All @@ -33,6 +34,7 @@ def __init__(self, config: AudioVAEconfig):
output_dim=config.dec_kwargs['output_dim'],
latent_dim=config.dec_kwargs['latent_dim'],
semantic_model=semantic_model,
patch_size=config.patch_size

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the encoder, patch_size is passed to the Decoder. Consider adding validation for patch_size to ensure it's a positive integer, as negative or zero values could lead to issues in the upsampling logic.

)

self.post_init()
Expand Down Expand Up @@ -66,6 +68,8 @@ def encode_latent(self, waveform, waveform_length):
- frame_num (torch.Tensor): The number of frames for each audio, shape (B,).
"""
frame_num = torch.ceil(waveform_length/self.config.enc_kwargs['input_dim']).to(torch.int32)
if self.config.patch_size != -1:
frame_num = torch.ceil(frame_num/self.config.patch_size)
Comment on lines +71 to +72

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The frame_num calculation is updated to account for patch_size, and torch.ceil correctly rounds up so all frames are covered. However, torch.ceil returns a floating-point tensor, so the earlier .to(torch.int32) cast is lost after the division — the result should be cast back to an integer dtype (e.g. .to(torch.int64)) to keep frame_num usable as a length/index tensor.

            frame_num = torch.ceil(frame_num/self.config.patch_size).to(torch.int64)

h, y = self.encoder(waveform)
h = h.transpose(1, 2) # [B, d, T]

Expand Down
37 changes: 36 additions & 1 deletion audio_tokenizer/vae_modules.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import torch.nn as nn
import torch.nn.functional as F
from transformers import Qwen2Model, Qwen2Config
import torch

from .istft import ISTFTHead


class Encoder(nn.Module):
def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64):
def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64, patch_size=-1):
super().__init__()
config = Qwen2Config.from_dict(config_dict=encoder_args)
self.encoder = Qwen2Model(config)
Expand All @@ -17,6 +18,12 @@ def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64):
self.fc2 = nn.Linear(config.hidden_size, config.hidden_size)
self.fc3 = nn.Linear(config.hidden_size, latent_dim*2)
self.norm = nn.LayerNorm(config.hidden_size)
self.patch_size = patch_size
if patch_size != -1:
config.num_hidden_layers = 4

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The config.num_hidden_layers is hardcoded to 4 when patch_size is enabled. This might limit the flexibility of the aggregator model. Consider making this configurable through encoder_args or config to allow for more fine-grained control over the aggregator's architecture.

self.aggregator = Qwen2Model(config)
self.cls_embed = nn.Parameter(torch.rand(1, 1, config.hidden_size))
self.cls_embed.data.normal_(0, 0.02)

def get_frames(self, x):
num_frames_total = (x.size(-1) + self.hop_size - 1) // self.hop_size # 向上取整的帧数
Expand All @@ -27,6 +34,17 @@ def get_frames(self, x):
frames = waveform.unfold(dimension=-1, size=self.input_dim, step=self.hop_size) # [B, T, d]
return frames

def pad_patch_insert_cls(self, x):
bsz, _, dim = x.size()
num_frame = x.size(1)
r = num_frame % self.patch_size
pad_num = self.patch_size-r if r else 0
x = F.pad(x, (0, 0, 0, pad_num), value=0.0) # 帧数对齐到patch_size倍数
x = x.reshape(-1, self.patch_size, dim)
x = torch.cat((x, self.cls_embed.expand(x.size(0), -1, -1)), dim=1) # 每个patch后插入一个cls
x = x.reshape(bsz, -1, dim)
return x

def forward(self, waveform):
x = self.get_frames(waveform)

Expand All @@ -35,6 +53,15 @@ def forward(self, waveform):
x = self.encoder(inputs_embeds=x)
x = x.last_hidden_state

# downsample
if self.patch_size != -1:
x = self.pad_patch_insert_cls(x)
x = self.aggregator(inputs_embeds=x)
x = x.last_hidden_state
bsz, _, dim = x.size()
x = x.reshape(-1, self.patch_size+1, dim)
x = x[:, -1:, :].reshape(bsz, -1, dim)

x = self.fc3(x)
return x, waveform.unsqueeze(1)

Expand All @@ -59,6 +86,9 @@ def __init__(self, decoder_args, output_dim=320, latent_dim=64, semantic_model=N
self.hop_length = output_dim
self.head = ISTFTHead(dim=config.hidden_size, n_fft=self.hop_length * 4, hop_length=self.hop_length, padding="same")
self.patch_size = patch_size
if self.patch_size != -1:
self.upsampling = nn.Upsample(scale_factor=patch_size, mode='linear')

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The nn.Upsample module is initialized with mode='linear'. For audio data, linear interpolation might not always be the best choice, especially if preserving specific audio characteristics is crucial. Consider if other interpolation modes like nearest or bicubic (if applicable for the data dimensions) might be more suitable, or if this should be a configurable parameter.



def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=False):
x = self.fc1(x)
Expand All @@ -73,6 +103,9 @@ def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=Fa
else:
unified_emb = None

if self.patch_size != -1:
x = self.upsampling(x.transpose(1, 2)).transpose(1, 2)

x = self.decoder(inputs_embeds=x)
x = x.last_hidden_state

Expand All @@ -82,6 +115,8 @@ def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=Fa

def low_level_reconstruct(self, x):
x = self.fc1(x)
if self.patch_size != -1:
x = self.upsampling(x.transpose(1, 2)).transpose(1, 2)
x = self.decoder(inputs_embeds=x)
x = x.last_hidden_state

Expand Down
8 changes: 6 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@

from audio_tokenizer.modeling_audio_vae import AudioVAE

model = AudioVAE.from_pretrained('inclusionAI/MingTok-Audio')
model = AudioVAE.from_pretrained('inclusionAI/Ming-omni-tts-tokenizer-12Hz')

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The model name has been updated to inclusionAI/Ming-omni-tts-tokenizer-12Hz. It's good to see the test script reflecting the new model. Ensure that this model is publicly available or accessible in the environment where the tests will run.

model = model.cuda()
model.eval()
model = model.to(torch.bfloat16)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The model is cast to torch.bfloat16. This is a good practice for optimizing memory and computation with compatible hardware. Ensure that the target environment supports bfloat16 operations, as not all GPUs or PyTorch versions do.


sample_rate = model.config.sample_rate
waveform, sr = torchaudio.load('data/1089-134686-0000.flac', backend='soundfile')
if sr != sample_rate:
waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)(waveform)
Comment on lines +13 to +14

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The torchaudio.transforms.Resample is used to match the waveform's sample rate with the model's expected sample_rate. This is a critical step for ensuring correct input to the model. It's good that this is handled explicitly.

sample = {'waveform': waveform.cuda(), 'waveform_length': torch.tensor([waveform.size(-1)]).cuda()}

with torch.no_grad():
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
latent, frame_num = model.encode_latent(**sample)
output_waveform = model.decode(latent)

torchaudio.save('./1089-134686-0000_reconstruct.wav', output_waveform.cpu()[0], sample_rate=16000)
torchaudio.save('./1089-134686-0000_reconstruct.wav', output_waveform.cpu()[0], sample_rate=sample_rate)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The output sample_rate for saving the reconstructed audio is now dynamically set from model.config.sample_rate. This is an improvement over the hardcoded 16000 as it ensures consistency with the model's configuration.

Loading