44.1kHz acoustic tokenizer supports speech & sound & music #4
yongjie-lv merged 1 commit into inclusionAI:main
Conversation
Summary of Changes (from Gemini Code Assist): This pull request enhances the acoustic tokenizer by introducing a configurable sample rate and a patch-based processing mechanism. These changes improve the model's adaptability across diverse audio inputs, including speech, sound, and music, by allowing it to handle different sampling frequencies and to process audio more efficiently through patching and aggregation.
Code Review
The pull request introduces support for a sample_rate parameter in the AudioVAEconfig and integrates patch_size functionality into the encoder and decoder modules. This allows for more flexible audio processing, including downsampling and upsampling based on the patch_size. The test.py script has also been updated to reflect these changes, including sample rate resampling and using the new model. Overall, the changes enhance the model's capabilities for handling different audio resolutions and introduce a new aggregation mechanism in the encoder.
```python
class AudioVAEconfig(PretrainedConfig):
    def __init__(
        self,
        sample_rate: int = 16000,
```
The sample_rate parameter is introduced with a default value of 16000. While this is a common sample rate, it might be beneficial to explicitly document the expected range or common values for this parameter, especially if the model is intended to support various audio types (speech, sound, music) as suggested by the PR title.
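For illustration, such validation could look like the following standalone sketch (the class name and the set of supported rates are assumptions for the example, not taken from this PR):

```python
# Hypothetical validation sketch; the supported rates below are an
# assumption for illustration, not taken from the PR.
SUPPORTED_SAMPLE_RATES = (16000, 22050, 24000, 44100, 48000)

class AudioVAEConfigSketch:
    def __init__(self, sample_rate: int = 16000):
        if sample_rate not in SUPPORTED_SAMPLE_RATES:
            raise ValueError(
                f"sample_rate must be one of {SUPPORTED_SAMPLE_RATES}, "
                f"got {sample_rate}"
            )
        self.sample_rate = sample_rate
```

Rejecting unexpected rates at construction time surfaces misconfiguration immediately instead of producing silently wrong audio downstream.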
```python
input_dim=config.enc_kwargs['input_dim'],
hop_size=config.enc_kwargs.get('hop_size', 320),
latent_dim=config.enc_kwargs['latent_dim'],
patch_size=config.patch_size
```
The patch_size is directly passed from config.patch_size to the Encoder constructor. It's good that this parameter is now configurable. Ensure that the patch_size value is validated within the AudioVAEconfig or Encoder to prevent unexpected behavior, such as non-positive values or values that might lead to inefficient patching.
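A minimal check could look like this sketch (the helper name is hypothetical; `-1` is the "patching disabled" sentinel used in the diff):

```python
def validate_patch_size(patch_size: int) -> int:
    # In the diff, -1 is the sentinel meaning "patching disabled";
    # any other value should be a positive integer.
    if patch_size != -1 and patch_size < 1:
        raise ValueError(
            f"patch_size must be -1 (disabled) or a positive integer, "
            f"got {patch_size}"
        )
    return patch_size
```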
```python
output_dim=config.dec_kwargs['output_dim'],
latent_dim=config.dec_kwargs['latent_dim'],
semantic_model=semantic_model,
patch_size=config.patch_size
```
```python
if self.config.patch_size != -1:
    frame_num = torch.ceil(frame_num / self.config.patch_size)
```
The frame_num calculation is updated to account for patch_size. The use of torch.ceil is appropriate for ensuring all frames are covered. However, it might be clearer to cast the result to torch.int64 instead of torch.int32 to avoid potential overflow issues with very long audio inputs, although int32 is likely sufficient for typical audio lengths.
Suggested change:

```python
frame_num = torch.ceil(frame_num / self.config.patch_size).to(torch.int64)
```

```python
self.norm = nn.LayerNorm(config.hidden_size)
self.patch_size = patch_size
if patch_size != -1:
    config.num_hidden_layers = 4
```
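The ceil-based frame count from the diff can be sanity-checked with plain integer arithmetic (a sketch; `torch.ceil` on a float tensor behaves like `math.ceil` per element):

```python
import math

def patched_frame_num(frame_num: int, patch_size: int) -> int:
    # Mirrors the diff: with patching enabled, frames are grouped into
    # ceil(frame_num / patch_size) patches; patch_size == -1 disables patching.
    if patch_size == -1:
        return frame_num
    return math.ceil(frame_num / patch_size)
```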
```python
self.head = ISTFTHead(dim=config.hidden_size, n_fft=self.hop_length * 4, hop_length=self.hop_length, padding="same")
self.patch_size = patch_size
if self.patch_size != -1:
    self.upsampling = nn.Upsample(scale_factor=patch_size, mode='linear')
```
The nn.Upsample module is initialized with mode='linear'. For audio data, linear interpolation might not always be the best choice, especially if preserving specific audio characteristics is crucial. Consider whether nearest interpolation might be more suitable (bicubic applies only to 4D inputs, so it is not an option for this 3D audio tensor), or whether the interpolation mode should be a configurable parameter.
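To see why the mode matters, compare nearest and linear upsampling on a toy 1D signal (a simplified pure-Python sketch; it is not numerically identical to nn.Upsample's implementation):

```python
def upsample_nearest(x, factor):
    # Each input sample is repeated `factor` times; output length is len(x) * factor.
    return [v for v in x for _ in range(factor)]

def upsample_linear(x, factor):
    # Linearly interpolate between consecutive samples (simplified,
    # align_corners-style endpoints); output length is (len(x) - 1) * factor + 1.
    out = []
    for i in range(len(x) - 1):
        for j in range(factor):
            t = j / factor
            out.append(x[i] * (1 - t) + x[i + 1] * t)
    out.append(x[-1])
    return out
```

Nearest upsampling produces a staircase (which adds high-frequency artifacts to audio), while linear fills in intermediate values, which is why the mode choice is audible and worth exposing as a parameter.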
```diff
 from audio_tokenizer.modeling_audio_vae import AudioVAE

-model = AudioVAE.from_pretrained('inclusionAI/MingTok-Audio')
+model = AudioVAE.from_pretrained('inclusionAI/Ming-omni-tts-tokenizer-12Hz')
 model = model.cuda()
 model.eval()
 model = model.to(torch.bfloat16)
```
```python
if sr != sample_rate:
    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)(waveform)
```
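As a sanity check on the resampling step, the output length should follow the ratio of the rates; a pure-Python sketch (an approximation only — exact lengths depend on torchaudio's resampling filter):

```python
def resampled_length(num_samples: int, orig_freq: int, new_freq: int) -> int:
    # A resampler produces roughly num_samples * new_freq / orig_freq output
    # samples; exact edge handling depends on the filter implementation.
    return round(num_samples * new_freq / orig_freq)
```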
```diff
 output_waveform = model.decode(latent)

-torchaudio.save('./1089-134686-0000_reconstruct.wav', output_waveform.cpu()[0], sample_rate=16000)
+torchaudio.save('./1089-134686-0000_reconstruct.wav', output_waveform.cpu()[0], sample_rate=sample_rate)
```
No description provided.