|
1 | | -import os |
2 | 1 | import pathlib |
3 | 2 |
|
4 | | -import librosa |
| 3 | +import numpy as np |
5 | 4 | import torch |
6 | 5 | import torch.nn.functional as F |
7 | 6 | import yaml |
8 | | -import numpy as np |
9 | 7 | from librosa.filters import mel as librosa_mel_fn |
| 8 | + |
10 | 9 | from basics.base_vocoder import BaseVocoder |
11 | 10 | from modules.vocoders.registry import register_vocoder |
12 | 11 | from utils.hparams import hparams |
@@ -35,83 +34,6 @@ def load_model(model_path: pathlib.Path, device='cpu'): |
35 | 34 | return model, args |
36 | 35 |
|
37 | 36 |
|
38 | | -class Audio2Mel(torch.nn.Module): |
39 | | - def __init__( |
40 | | - self, |
41 | | - hop_length, |
42 | | - sampling_rate, |
43 | | - n_mel_channels, |
44 | | - win_length, |
45 | | - n_fft=None, |
46 | | - mel_fmin=0, |
47 | | - mel_fmax=None, |
48 | | - clamp=1e-5 |
49 | | - ): |
50 | | - super().__init__() |
51 | | - n_fft = win_length if n_fft is None else n_fft |
52 | | - self.hann_window = {} |
53 | | - mel_basis = librosa_mel_fn( |
54 | | - sr=sampling_rate, |
55 | | - n_fft=n_fft, |
56 | | - n_mels=n_mel_channels, |
57 | | - fmin=mel_fmin, |
58 | | - fmax=mel_fmax) |
59 | | - mel_basis = torch.from_numpy(mel_basis).float() |
60 | | - self.register_buffer("mel_basis", mel_basis) |
61 | | - self.n_fft = n_fft |
62 | | - self.hop_length = hop_length |
63 | | - self.win_length = win_length |
64 | | - self.sampling_rate = sampling_rate |
65 | | - self.n_mel_channels = n_mel_channels |
66 | | - self.clamp = clamp |
67 | | - |
68 | | - def forward(self, audio, keyshift=0, speed=1): |
69 | | - ''' |
70 | | - audio: B x C x T |
71 | | - log_mel_spec: B x T_ x C x n_mel |
72 | | - ''' |
73 | | - factor = 2 ** (keyshift / 12) |
74 | | - n_fft_new = int(np.round(self.n_fft * factor)) |
75 | | - win_length_new = int(np.round(self.win_length * factor)) |
76 | | - hop_length_new = int(np.round(self.hop_length * speed)) |
77 | | - |
78 | | - keyshift_key = str(keyshift) + '_' + str(audio.device) |
79 | | - if keyshift_key not in self.hann_window: |
80 | | - self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device) |
81 | | - |
82 | | - B, C, T = audio.shape |
83 | | - audio = audio.reshape(B * C, T) |
84 | | - fft = torch.stft( |
85 | | - audio, |
86 | | - n_fft=n_fft_new, |
87 | | - hop_length=hop_length_new, |
88 | | - win_length=win_length_new, |
89 | | - window=self.hann_window[keyshift_key], |
90 | | - center=True, |
91 | | - return_complex=False) |
92 | | - real_part, imag_part = fft.unbind(-1) |
93 | | - magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) |
94 | | - |
95 | | - if keyshift != 0: |
96 | | - size = self.n_fft // 2 + 1 |
97 | | - resize = magnitude.size(1) |
98 | | - if resize < size: |
99 | | - magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) |
100 | | - magnitude = magnitude[:, :size, :] * self.win_length / win_length_new |
101 | | - |
102 | | - mel_output = torch.matmul(self.mel_basis, magnitude) |
103 | | - log_mel_spec = torch.log10(torch.clamp(mel_output, min=self.clamp)) |
104 | | - |
105 | | - # log_mel_spec: B x C, M, T |
106 | | - T_ = log_mel_spec.shape[-1] |
107 | | - log_mel_spec = log_mel_spec.reshape(B, C, self.n_mel_channels, T_) |
108 | | - log_mel_spec = log_mel_spec.permute(0, 3, 1, 2) |
109 | | - |
110 | | - # print('og_mel_spec:', log_mel_spec.shape) |
111 | | - log_mel_spec = log_mel_spec.squeeze(2) # mono |
112 | | - return log_mel_spec |
113 | | - |
114 | | - |
115 | 37 | @register_vocoder |
116 | 38 | class DDSP(BaseVocoder): |
117 | 39 | def __init__(self, device='cpu'): |
@@ -149,8 +71,15 @@ def spec2wav_torch(self, mel, f0): # mel: [B, T, bins] f0: [B, T] |
149 | 71 | print('Mismatch parameters: hparams[\'fmax\']=', hparams['fmax'], '!=', self.args.data.mel_fmax, |
150 | 72 | '(vocoder)') |
151 | 73 | with torch.no_grad(): |
152 | | - f0 = f0.unsqueeze(-1) |
153 | | - signal, _, (s_h, s_n) = self.model(mel.to(self.device), f0.to(self.device)) |
| 74 | + mel = mel.to(self.device) |
| 75 | + mel_base = hparams.get('mel_base', 10) |
| 76 | + if mel_base != 'e': |
| 77 | + assert mel_base in [10, '10'], "mel_base must be 'e', '10' or 10." |
| 78 | + else: |
| 79 | + # log mel to log10 mel |
| 80 | + mel = 0.434294 * mel |
| 81 | + f0 = f0.unsqueeze(-1).to(self.device) |
| 82 | + signal, _, (s_h, s_n) = self.model(mel, f0) |
154 | 83 | signal = signal.view(-1) |
155 | 84 | return signal |
156 | 85 |
|
@@ -178,38 +107,14 @@ def spec2wav(self, mel, f0): |
178 | 107 | '(vocoder)') |
179 | 108 | with torch.no_grad(): |
180 | 109 | mel = torch.FloatTensor(mel).unsqueeze(0).to(self.device) |
| 110 | + mel_base = hparams.get('mel_base', 10) |
| 111 | + if mel_base != 'e': |
| 112 | + assert mel_base in [10, '10'], "mel_base must be 'e', '10' or 10." |
| 113 | + else: |
| 114 | + # log mel to log10 mel |
| 115 | + mel = 0.434294 * mel |
181 | 116 | f0 = torch.FloatTensor(f0).unsqueeze(0).unsqueeze(-1).to(self.device) |
182 | | - signal, _, (s_h, s_n) = self.model(mel.to(self.device), f0.to(self.device)) |
| 117 | + signal, _, (s_h, s_n) = self.model(mel, f0) |
183 | 118 | signal = signal.view(-1) |
184 | 119 | wav_out = signal.cpu().numpy() |
185 | 120 | return wav_out |
186 | | - |
187 | | - @staticmethod |
188 | | - def wav2spec(inp_path, keyshift=0, speed=1, device=None): |
189 | | - if device is None: |
190 | | - device = 'cuda' if torch.cuda.is_available() else 'cpu' |
191 | | - sampling_rate = hparams['audio_sample_rate'] |
192 | | - n_mel_channels = hparams['audio_num_mel_bins'] |
193 | | - n_fft = hparams['fft_size'] |
194 | | - win_length = hparams['win_size'] |
195 | | - hop_length = hparams['hop_size'] |
196 | | - mel_fmin = hparams['fmin'] |
197 | | - mel_fmax = hparams['fmax'] |
198 | | - |
199 | | - # load input |
200 | | - x, _ = librosa.load(inp_path, sr=sampling_rate) |
201 | | - x_t = torch.from_numpy(x).float().to(device) |
202 | | - x_t = x_t.unsqueeze(0).unsqueeze(0) # (T,) --> (1, 1, T) |
203 | | - |
204 | | - # mel analysis |
205 | | - mel_extractor = Audio2Mel( |
206 | | - hop_length=hop_length, |
207 | | - sampling_rate=sampling_rate, |
208 | | - n_mel_channels=n_mel_channels, |
209 | | - win_length=win_length, |
210 | | - n_fft=n_fft, |
211 | | - mel_fmin=mel_fmin, |
212 | | - mel_fmax=mel_fmax).to(device) |
213 | | - |
214 | | - mel = mel_extractor(x_t, keyshift=keyshift, speed=speed) |
215 | | - return x, mel.squeeze(0).cpu().numpy() |
0 commit comments