Skip to content

Commit 28563a1

Browse files
yqzhishenyxlllc
andauthored
Support using e as the base number of the log mel-spectrogram (#175)
* Add `mel_base` and set `e` as default value * Export dynamic `mel_base` * Restore old configs * Refactor mel extraction * add mel-base check for ddsp vocoder --------- Co-authored-by: yxlllc <llc1995@sina.com>
1 parent 8066e72 commit 28563a1

12 files changed

Lines changed: 95 additions & 186 deletions

File tree

augmentation/spec_stretch.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from copy import deepcopy
22

3+
import librosa
34
import numpy as np
45
import torch
56

67
from basics.base_augmentation import BaseAugmentation, require_same_keys
78
from basics.base_pe import BasePE
89
from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST
910
from modules.fastspeech.tts_modules import LengthRegulator
10-
from modules.vocoders.registry import VOCODERS
11-
from utils.binarizer_utils import get_mel2ph_torch
11+
from utils.binarizer_utils import get_mel_torch, get_mel2ph_torch
1212
from utils.hparams import hparams
1313
from utils.infer_utils import resample_align_curve
1414

@@ -27,14 +27,13 @@ def __init__(self, data_dirs: list, augmentation_args: dict, pe: BasePE = None):
2727
@require_same_keys
2828
def process_item(self, item: dict, key_shift=0., speed=1., replace_spk_id=None) -> dict:
2929
aug_item = deepcopy(item)
30-
if hparams['vocoder'] in VOCODERS:
31-
wav, mel = VOCODERS[hparams['vocoder']].wav2spec(
32-
aug_item['wav_fn'], keyshift=key_shift, speed=speed
33-
)
34-
else:
35-
wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(
36-
aug_item['wav_fn'], keyshift=key_shift, speed=speed
37-
)
30+
waveform, _ = librosa.load(aug_item['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
31+
mel = get_mel_torch(
32+
waveform, hparams['audio_sample_rate'], num_mel_bins=hparams['audio_num_mel_bins'],
33+
hop_size=hparams['hop_size'], win_size=hparams['win_size'], fft_size=hparams['fft_size'],
34+
fmin=hparams['fmin'], fmax=hparams['fmax'], mel_base=hparams['mel_base'],
35+
keyshift=key_shift, speed=speed, device=self.device
36+
)
3837

3938
aug_item['mel'] = mel
4039

@@ -48,7 +47,7 @@ def process_item(self, item: dict, key_shift=0., speed=1., replace_spk_id=None)
4847
).cpu().numpy()
4948

5049
f0, _ = self.pe.get_pitch(
51-
wav, samplerate=hparams['audio_sample_rate'], length=aug_item['length'],
50+
waveform, samplerate=hparams['audio_sample_rate'], length=aug_item['length'],
5251
hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
5352
speed=speed, interp_uv=True
5453
)

basics/base_vocoder.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,3 @@ def spec2wav(self, mel, **kwargs):
2121
"""
2222

2323
raise NotImplementedError()
24-
25-
@staticmethod
26-
def wav2spec(wav_fn):
27-
"""
28-
29-
:param wav_fn: str
30-
:return: wav, mel: [T, 80]
31-
"""
32-
raise NotImplementedError()

configs/acoustic.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ binarizer_cls: preprocessing.acoustic_binarizer.AcousticBinarizer
4747
dictionary: dictionaries/opencpop-extension.txt
4848
spec_min: [-5]
4949
spec_max: [0]
50-
mel_vmin: -6. #-6.
50+
mel_vmin: -6.
5151
mel_vmax: 1.5
52+
mel_base: '10'
5253
energy_smooth_width: 0.12
5354
breathiness_smooth_width: 0.12
5455
voicing_smooth_width: 0.12

deployment/exporters/acoustic_exporter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def export_attachments(self, path: Path):
148148
dsconfig['num_mel_bins'] = hparams['audio_num_mel_bins']
149149
dsconfig['mel_fmin'] = hparams['fmin']
150150
dsconfig['mel_fmax'] = hparams['fmax'] if hparams['fmax'] is not None else hparams['audio_sample_rate'] / 2
151-
dsconfig['mel_base'] = '10'
151+
dsconfig['mel_base'] = str(hparams.get('mel_base', '10'))
152152
dsconfig['mel_scale'] = 'slaney'
153153
config_path = path / 'dsconfig.yaml'
154154
with open(config_path, 'w', encoding='utf8') as fw:

deployment/exporters/nsf_hifigan_exporter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def build_model(self) -> nn.Module:
3333
config_path = self.model_path.with_name('config.json')
3434
with open(config_path, 'r', encoding='utf8') as f:
3535
config = json.load(f)
36-
model = NSFHiFiGANONNX(config).eval().to(self.device)
36+
model = NSFHiFiGANONNX(config, mel_base=hparams.get('mel_base', '10')).eval().to(self.device)
3737
load_ckpt(model.generator, str(self.model_path),
3838
prefix_in_ckpt=None, key_in_ckpt='generator',
3939
strict=True, device=self.device)
@@ -67,7 +67,7 @@ def export_attachments(self, path: Path):
6767
'num_mel_bins': hparams['audio_num_mel_bins'],
6868
'mel_fmin': hparams['fmin'],
6969
'mel_fmax': hparams['fmax'] if hparams['fmax'] is not None else hparams['audio_sample_rate'] / 2,
70-
'mel_base': '10',
70+
'mel_base': str(hparams.get('mel_base', '10')),
7171
'mel_scale': 'slaney',
7272
}, fw, sort_keys=False)
7373
print(f'| export configs => {config_path} **PLEASE EDIT BEFORE USE**')

deployment/modules/nsf_hifigan.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,16 @@
66

77
# noinspection SpellCheckingInspection
88
class NSFHiFiGANONNX(torch.nn.Module):
9-
def __init__(self, attrs: dict):
9+
def __init__(self, attrs: dict, mel_base='e'):
1010
super().__init__()
11+
self.mel_base = str(mel_base)
12+
assert self.mel_base in ['e', '10'], "mel_base must be 'e', '10' or 10."
1113
self.generator = Generator(AttrDict(attrs))
1214

1315
def forward(self, mel: torch.Tensor, f0: torch.Tensor):
14-
mel = mel.transpose(1, 2) * 2.30259
16+
mel = mel.transpose(1, 2)
17+
if self.mel_base != 'e':
18+
# log10 to log mel
19+
mel = mel * 2.30259
1520
wav = self.generator(mel, f0)
1621
return wav.squeeze(1)

inference/val_nsf_hifigan.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import os
22
import sys
33

4+
import librosa
45
import numpy as np
56
import resampy
67
import torch
78
import torchcrepe
89
import tqdm
910

10-
from utils.binarizer_utils import get_pitch_parselmouth
11+
from utils.binarizer_utils import get_pitch_parselmouth, get_mel_torch
1112
from modules.vocoders.nsf_hifigan import NsfHifiGAN
1213
from utils.infer_utils import save_wav
1314
from utils.hparams import set_hparams, hparams
@@ -60,7 +61,14 @@ def get_pitch(wav_data, mel, hparams, threshold=0.3):
6061
for filename in tqdm.tqdm(os.listdir(in_path)):
6162
if not filename.endswith('.wav'):
6263
continue
63-
wav, mel = vocoder.wav2spec(os.path.join(in_path, filename))
64+
wav, _ = librosa.load(os.path.join(in_path, filename), sr=hparams['audio_sample_rate'], mono=True)
65+
mel = get_mel_torch(
66+
wav, hparams['audio_sample_rate'], num_mel_bins=hparams['audio_num_mel_bins'],
67+
hop_size=hparams['hop_size'], win_size=hparams['win_size'], fft_size=hparams['fft_size'],
68+
fmin=hparams['fmin'], fmax=hparams['fmax'], mel_base=hparams['mel_base'],
69+
device=device
70+
)
71+
6472
f0, _ = get_pitch_parselmouth(
6573
wav, samplerate=hparams['audio_sample_rate'], length=len(mel),
6674
hop_size=hparams['hop_size']

modules/nsf_hifigan/nvSTFT.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,10 @@
44
import torch
55
import torch.utils.data
66
import numpy as np
7-
import librosa
87
from librosa.filters import mel as librosa_mel_fn
98
import torch.nn.functional as F
109

1110

12-
def load_wav_to_torch(full_path, target_sr=None):
13-
data, sr = librosa.load(full_path, sr=target_sr, mono=True)
14-
return torch.from_numpy(data), sr
15-
16-
1711
def dynamic_range_compression(x, C=1, clip_val=1e-5):
1812
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
1913

@@ -96,8 +90,3 @@ def get_mel(self, y, keyshift=0, speed=1, center=False):
9690
spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
9791

9892
return spec
99-
100-
def __call__(self, audiopath):
101-
audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
102-
spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
103-
return spect

modules/vocoders/ddsp.py

Lines changed: 18 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
import os
21
import pathlib
32

4-
import librosa
3+
import numpy as np
54
import torch
65
import torch.nn.functional as F
76
import yaml
8-
import numpy as np
97
from librosa.filters import mel as librosa_mel_fn
8+
109
from basics.base_vocoder import BaseVocoder
1110
from modules.vocoders.registry import register_vocoder
1211
from utils.hparams import hparams
@@ -35,83 +34,6 @@ def load_model(model_path: pathlib.Path, device='cpu'):
3534
return model, args
3635

3736

38-
class Audio2Mel(torch.nn.Module):
39-
def __init__(
40-
self,
41-
hop_length,
42-
sampling_rate,
43-
n_mel_channels,
44-
win_length,
45-
n_fft=None,
46-
mel_fmin=0,
47-
mel_fmax=None,
48-
clamp=1e-5
49-
):
50-
super().__init__()
51-
n_fft = win_length if n_fft is None else n_fft
52-
self.hann_window = {}
53-
mel_basis = librosa_mel_fn(
54-
sr=sampling_rate,
55-
n_fft=n_fft,
56-
n_mels=n_mel_channels,
57-
fmin=mel_fmin,
58-
fmax=mel_fmax)
59-
mel_basis = torch.from_numpy(mel_basis).float()
60-
self.register_buffer("mel_basis", mel_basis)
61-
self.n_fft = n_fft
62-
self.hop_length = hop_length
63-
self.win_length = win_length
64-
self.sampling_rate = sampling_rate
65-
self.n_mel_channels = n_mel_channels
66-
self.clamp = clamp
67-
68-
def forward(self, audio, keyshift=0, speed=1):
69-
'''
70-
audio: B x C x T
71-
log_mel_spec: B x T_ x C x n_mel
72-
'''
73-
factor = 2 ** (keyshift / 12)
74-
n_fft_new = int(np.round(self.n_fft * factor))
75-
win_length_new = int(np.round(self.win_length * factor))
76-
hop_length_new = int(np.round(self.hop_length * speed))
77-
78-
keyshift_key = str(keyshift) + '_' + str(audio.device)
79-
if keyshift_key not in self.hann_window:
80-
self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
81-
82-
B, C, T = audio.shape
83-
audio = audio.reshape(B * C, T)
84-
fft = torch.stft(
85-
audio,
86-
n_fft=n_fft_new,
87-
hop_length=hop_length_new,
88-
win_length=win_length_new,
89-
window=self.hann_window[keyshift_key],
90-
center=True,
91-
return_complex=False)
92-
real_part, imag_part = fft.unbind(-1)
93-
magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
94-
95-
if keyshift != 0:
96-
size = self.n_fft // 2 + 1
97-
resize = magnitude.size(1)
98-
if resize < size:
99-
magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
100-
magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
101-
102-
mel_output = torch.matmul(self.mel_basis, magnitude)
103-
log_mel_spec = torch.log10(torch.clamp(mel_output, min=self.clamp))
104-
105-
# log_mel_spec: B x C, M, T
106-
T_ = log_mel_spec.shape[-1]
107-
log_mel_spec = log_mel_spec.reshape(B, C, self.n_mel_channels, T_)
108-
log_mel_spec = log_mel_spec.permute(0, 3, 1, 2)
109-
110-
# print('og_mel_spec:', log_mel_spec.shape)
111-
log_mel_spec = log_mel_spec.squeeze(2) # mono
112-
return log_mel_spec
113-
114-
11537
@register_vocoder
11638
class DDSP(BaseVocoder):
11739
def __init__(self, device='cpu'):
@@ -149,8 +71,15 @@ def spec2wav_torch(self, mel, f0): # mel: [B, T, bins] f0: [B, T]
14971
print('Mismatch parameters: hparams[\'fmax\']=', hparams['fmax'], '!=', self.args.data.mel_fmax,
15072
'(vocoder)')
15173
with torch.no_grad():
152-
f0 = f0.unsqueeze(-1)
153-
signal, _, (s_h, s_n) = self.model(mel.to(self.device), f0.to(self.device))
74+
mel = mel.to(self.device)
75+
mel_base = hparams.get('mel_base', 10)
76+
if mel_base != 'e':
77+
assert mel_base in [10, '10'], "mel_base must be 'e', '10' or 10."
78+
else:
79+
# log mel to log10 mel
80+
mel = 0.434294 * mel
81+
f0 = f0.unsqueeze(-1).to(self.device)
82+
signal, _, (s_h, s_n) = self.model(mel, f0)
15483
signal = signal.view(-1)
15584
return signal
15685

@@ -178,38 +107,14 @@ def spec2wav(self, mel, f0):
178107
'(vocoder)')
179108
with torch.no_grad():
180109
mel = torch.FloatTensor(mel).unsqueeze(0).to(self.device)
110+
mel_base = hparams.get('mel_base', 10)
111+
if mel_base != 'e':
112+
assert mel_base in [10, '10'], "mel_base must be 'e', '10' or 10."
113+
else:
114+
# log mel to log10 mel
115+
mel = 0.434294 * mel
181116
f0 = torch.FloatTensor(f0).unsqueeze(0).unsqueeze(-1).to(self.device)
182-
signal, _, (s_h, s_n) = self.model(mel.to(self.device), f0.to(self.device))
117+
signal, _, (s_h, s_n) = self.model(mel, f0)
183118
signal = signal.view(-1)
184119
wav_out = signal.cpu().numpy()
185120
return wav_out
186-
187-
@staticmethod
188-
def wav2spec(inp_path, keyshift=0, speed=1, device=None):
189-
if device is None:
190-
device = 'cuda' if torch.cuda.is_available() else 'cpu'
191-
sampling_rate = hparams['audio_sample_rate']
192-
n_mel_channels = hparams['audio_num_mel_bins']
193-
n_fft = hparams['fft_size']
194-
win_length = hparams['win_size']
195-
hop_length = hparams['hop_size']
196-
mel_fmin = hparams['fmin']
197-
mel_fmax = hparams['fmax']
198-
199-
# load input
200-
x, _ = librosa.load(inp_path, sr=sampling_rate)
201-
x_t = torch.from_numpy(x).float().to(device)
202-
x_t = x_t.unsqueeze(0).unsqueeze(0) # (T,) --> (1, 1, T)
203-
204-
# mel analysis
205-
mel_extractor = Audio2Mel(
206-
hop_length=hop_length,
207-
sampling_rate=sampling_rate,
208-
n_mel_channels=n_mel_channels,
209-
win_length=win_length,
210-
n_fft=n_fft,
211-
mel_fmin=mel_fmin,
212-
mel_fmax=mel_fmax).to(device)
213-
214-
mel = mel_extractor(x_t, keyshift=keyshift, speed=speed)
215-
return x, mel.squeeze(0).cpu().numpy()

0 commit comments

Comments
 (0)