
Commit 22d0e4d

Merge pull request #4 from yongjie-lv/main
44.1kHz acoustic tokenizer supports speech & sound & music
2 parents 6bcc224 + 6f64176 · commit 22d0e4d

File tree: 4 files changed, +48 −3 lines changed

audio_tokenizer/configuration_audio_vae.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@
 class AudioVAEconfig(PretrainedConfig):
     def __init__(
         self,
+        sample_rate: int=16000,
         enc_kwargs: dict = None,
         semantic_module_kwargs: dict = None,
         dec_kwargs: dict = None,
@@ -18,6 +19,7 @@ def __init__(
         patch_size=-1,
         **kwargs
     ):
+        self.sample_rate = sample_rate
         self.enc_kwargs = enc_kwargs
         self.semantic_module_kwargs = semantic_module_kwargs
         self.dec_kwargs = dec_kwargs
```
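
The only functional change here is threading a `sample_rate` field through the config so checkpoints can declare their native rate (e.g. 44.1 kHz) instead of the previously hard-coded 16 kHz. A minimal sketch of constructing such a config; the kwarg values below are illustrative placeholders, not the shipped checkpoint's settings:

```python
from audio_tokenizer.configuration_audio_vae import AudioVAEconfig

# Placeholder values for illustration only; real checkpoints ship their own config.json.
config = AudioVAEconfig(
    sample_rate=44100,  # new field; older configs fall back to the 16000 default
    enc_kwargs={'input_dim': 320, 'hop_size': 320, 'latent_dim': 64},
    dec_kwargs={'output_dim': 320, 'latent_dim': 64},
    patch_size=4,       # -1 keeps the original one-token-per-frame behaviour
)
assert config.sample_rate == 44100
```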

audio_tokenizer/modeling_audio_vae.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@ def __init__(self, config: AudioVAEconfig):
             input_dim=config.enc_kwargs['input_dim'],
             hop_size=config.enc_kwargs.get('hop_size', 320),
             latent_dim=config.enc_kwargs['latent_dim'],
+            patch_size=config.patch_size
         )
 
         if config.semantic_module_kwargs is not None:
@@ -33,6 +34,7 @@ def __init__(self, config: AudioVAEconfig):
             output_dim=config.dec_kwargs['output_dim'],
             latent_dim=config.dec_kwargs['latent_dim'],
             semantic_model=semantic_model,
+            patch_size=config.patch_size
         )
 
         self.post_init()
@@ -66,6 +68,8 @@ def encode_latent(self, waveform, waveform_length):
             - frame_num (torch.Tensor): The number of frames for each audio, shape (B,).
         """
         frame_num = torch.ceil(waveform_length/self.config.enc_kwargs['input_dim']).to(torch.int32)
+        if self.config.patch_size != -1:
+            frame_num = torch.ceil(frame_num/self.config.patch_size)
         h, y = self.encoder(waveform)
         h = h.transpose(1, 2) # [B, d, T]
```
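
The second `ceil` keeps the reported frame count in step with the encoder's patch aggregation below. As a worked example, assuming the default `input_dim=320` and a hypothetical `patch_size=4`, one second of 16 kHz audio yields:

```python
import torch

# Hypothetical numbers: 16000 samples, input_dim=320, patch_size=4.
waveform_length = torch.tensor([16000])
frame_num = torch.ceil(waveform_length / 320).to(torch.int32)  # 50 encoder frames
frame_num = torch.ceil(frame_num / 4)                          # 13 latent frames after patching
print(frame_num)  # tensor([13.])
```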

audio_tokenizer/vae_modules.py

Lines changed: 36 additions & 1 deletion
```diff
@@ -1,12 +1,13 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import Qwen2Model, Qwen2Config
+import torch
 
 from .istft import ISTFTHead
 
 
 class Encoder(nn.Module):
-    def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64):
+    def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64, patch_size=-1):
         super().__init__()
         config = Qwen2Config.from_dict(config_dict=encoder_args)
         self.encoder = Qwen2Model(config)
@@ -17,6 +18,12 @@ def __init__(self, encoder_args, input_dim=320, hop_size=320, latent_dim=64):
         self.fc2 = nn.Linear(config.hidden_size, config.hidden_size)
         self.fc3 = nn.Linear(config.hidden_size, latent_dim*2)
         self.norm = nn.LayerNorm(config.hidden_size)
+        self.patch_size = patch_size
+        if patch_size != -1:
+            config.num_hidden_layers = 4
+            self.aggregator = Qwen2Model(config)
+            self.cls_embed = nn.Parameter(torch.rand(1, 1, config.hidden_size))
+            self.cls_embed.data.normal_(0, 0.02)
 
     def get_frames(self, x):
         num_frames_total = (x.size(-1) + self.hop_size - 1) // self.hop_size  # number of frames, rounded up
@@ -27,6 +34,17 @@ def get_frames(self, x):
         frames = waveform.unfold(dimension=-1, size=self.input_dim, step=self.hop_size)  # [B, T, d]
         return frames
 
+    def pad_patch_insert_cls(self, x):
+        bsz, _, dim = x.size()
+        num_frame = x.size(1)
+        r = num_frame % self.patch_size
+        pad_num = self.patch_size-r if r else 0
+        x = F.pad(x, (0, 0, 0, pad_num), value=0.0)  # pad the frame count up to a multiple of patch_size
+        x = x.reshape(-1, self.patch_size, dim)
+        x = torch.cat((x, self.cls_embed.expand(x.size(0), -1, -1)), dim=1)  # append one cls token after each patch
+        x = x.reshape(bsz, -1, dim)
+        return x
+
     def forward(self, waveform):
         x = self.get_frames(waveform)
 
@@ -35,6 +53,15 @@ def forward(self, waveform):
         x = self.encoder(inputs_embeds=x)
         x = x.last_hidden_state
 
+        # downsample
+        if self.patch_size != -1:
+            x = self.pad_patch_insert_cls(x)
+            x = self.aggregator(inputs_embeds=x)
+            x = x.last_hidden_state
+            bsz, _, dim = x.size()
+            x = x.reshape(-1, self.patch_size+1, dim)
+            x = x[:, -1:, :].reshape(bsz, -1, dim)
+
         x = self.fc3(x)
         return x, waveform.unsqueeze(1)
 
@@ -59,6 +86,9 @@ def __init__(self, decoder_args, output_dim=320, latent_dim=64, semantic_model=N
         self.hop_length = output_dim
         self.head = ISTFTHead(dim=config.hidden_size, n_fft=self.hop_length * 4, hop_length=self.hop_length, padding="same")
         self.patch_size = patch_size
+        if self.patch_size != -1:
+            self.upsampling = nn.Upsample(scale_factor=patch_size, mode='linear')
+
 
     def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=False):
         x = self.fc1(x)
@@ -73,6 +103,9 @@ def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=Fa
         else:
             unified_emb = None
 
+        if self.patch_size != -1:
+            x = self.upsampling(x.transpose(1, 2)).transpose(1, 2)
+
         x = self.decoder(inputs_embeds=x)
         x = x.last_hidden_state
 
@@ -82,6 +115,8 @@ def forward(self, x, only_semantic_emb=False, past_key_values=None, use_cache=Fa
 
     def low_level_reconstruct(self, x):
         x = self.fc1(x)
+        if self.patch_size != -1:
+            x = self.upsampling(x.transpose(1, 2)).transpose(1, 2)
         x = self.decoder(inputs_embeds=x)
         x = x.last_hidden_state
```
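
Taken together, the encoder changes implement patch-based downsampling: frames are padded to a multiple of `patch_size`, a learned cls embedding is appended to each patch, a 4-layer aggregator transformer runs over the sequence, and only the cls positions are kept, cutting the latent frame rate by a factor of `patch_size`. The decoder undoes this with linear upsampling before its own transformer. A shape-only sketch of the round trip (all sizes hypothetical; the aggregator itself is elided):

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: batch 2, 50 frames, hidden dim 8, patch_size 4.
bsz, num_frame, dim, patch_size = 2, 50, 8, 4
x = torch.randn(bsz, num_frame, dim)

# Encoder side: pad to a multiple of patch_size, append one cls slot per patch.
r = num_frame % patch_size
x = F.pad(x, (0, 0, 0, patch_size - r if r else 0))       # [2, 52, 8]
patches = x.reshape(-1, patch_size, dim)                  # [26, 4, 8]
cls = torch.zeros(patches.size(0), 1, dim)                # stands in for the learned cls_embed
patches = torch.cat((patches, cls), dim=1)                # [26, 5, 8]
# ... the aggregator transformer would run here ...
latent = patches[:, -1:, :].reshape(bsz, -1, dim)         # keep cls outputs only -> [2, 13, 8]

# Decoder side: linear upsampling restores the original frame rate.
up = torch.nn.Upsample(scale_factor=patch_size, mode='linear')
restored = up(latent.transpose(1, 2)).transpose(1, 2)     # [2, 52, 8]
print(latent.shape, restored.shape)
```

Note that in this sketch the two padding frames survive the round trip (52 frames back, not 50); trimming the reconstruction to the source length is left to the caller.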

test.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -3,16 +3,20 @@
 
 from audio_tokenizer.modeling_audio_vae import AudioVAE
 
-model = AudioVAE.from_pretrained('inclusionAI/MingTok-Audio')
+model = AudioVAE.from_pretrained('inclusionAI/Ming-omni-tts-tokenizer-12Hz')
 model = model.cuda()
 model.eval()
+model = model.to(torch.bfloat16)
 
+sample_rate = model.config.sample_rate
 waveform, sr = torchaudio.load('data/1089-134686-0000.flac', backend='soundfile')
+if sr != sample_rate:
+    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)(waveform)
 sample = {'waveform': waveform.cuda(), 'waveform_length': torch.tensor([waveform.size(-1)]).cuda()}
 
 with torch.no_grad():
     with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
         latent, frame_num = model.encode_latent(**sample)
         output_waveform = model.decode(latent)
 
-torchaudio.save('./1089-134686-0000_reconstruct.wav', output_waveform.cpu()[0], sample_rate=16000)
+torchaudio.save('./1089-134686-0000_reconstruct.wav', output_waveform.cpu()[0], sample_rate=sample_rate)
```
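
Since the checkpoint now carries its own `sample_rate`, a quick sanity check after loading is to derive the latent frame rate from the config. A sketch assuming the config keys used in `modeling_audio_vae.py` above; the printed value depends on the shipped checkpoint:

```python
hop_size = model.config.enc_kwargs.get('hop_size', 320)
fps = model.config.sample_rate / hop_size          # encoder frames per second
if model.config.patch_size != -1:
    fps /= model.config.patch_size                 # patch aggregation divides the token rate
print(f'~{fps:.1f} latent frames per second')
```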
