Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions python/paddle/audio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@

import paddle
from paddle import Tensor
from paddle.base.framework import Variable
from paddle.pir import Value


def hz_to_mel(
freq: Union[Tensor, float], htk: bool = False
) -> Union[Tensor, float]:
freq: Union[Tensor, Value, Variable, float], htk: bool = False
) -> Union[Tensor, Value, Variable, float]:
"""Convert Hz to Mels.

Args:
Expand All @@ -43,7 +45,7 @@ def hz_to_mel(
"""

if htk:
if isinstance(freq, Tensor):
if isinstance(freq, (Tensor, Variable, Value)):
return 2595.0 * paddle.log10(1.0 + freq / 700.0)
else:
return 2595.0 * math.log10(1.0 + freq / 700.0)
Expand All @@ -60,7 +62,7 @@ def hz_to_mel(
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = math.log(6.4) / 27.0 # step size for log region

if isinstance(freq, Tensor):
if isinstance(freq, (Tensor, Variable, Value)):
target = (
min_log_mel + paddle.log(freq / min_log_hz + 1e-10) / logstep
) # prevent nan with 1e-10
Expand All @@ -76,8 +78,8 @@ def hz_to_mel(


def mel_to_hz(
mel: Union[float, Tensor], htk: bool = False
) -> Union[float, Tensor]:
mel: Union[float, Tensor, Variable, Value], htk: bool = False
) -> Union[float, Tensor, Variable, Value]:
"""Convert mel bin numbers to frequencies.

Args:
Expand Down Expand Up @@ -108,7 +110,7 @@ def mel_to_hz(
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = math.log(6.4) / 27.0 # step size for log region
if isinstance(mel, Tensor):
if isinstance(mel, (Tensor, Variable, Value)):
target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel))
mask = (mel > min_log_mel).astype(mel.dtype)
freqs = target * mask + freqs * (
Expand All @@ -126,7 +128,7 @@ def mel_frequencies(
f_max: float = 11025.0,
htk: bool = False,
dtype: str = 'float32',
) -> Tensor:
) -> Union[Tensor, Variable, Value]:
"""Compute mel frequencies.

Args:
Expand Down Expand Up @@ -257,11 +259,11 @@ def compute_fbank_matrix(


def power_to_db(
spect: Tensor,
spect: Union[Tensor, Variable, Value],
ref_value: float = 1.0,
amin: float = 1e-10,
top_db: Optional[float] = 80.0,
) -> Tensor:
) -> Union[Tensor, Variable, Value]:
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units. The function computes the scaling `10 * log10(x / ref)` in a numerically stable way.

Args:
Expand Down
14 changes: 7 additions & 7 deletions python/paddle/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from . import _C_ops
from .base.data_feeder import check_variable_and_dtype
from .base.layer_helper import LayerHelper
from .framework import in_dynamic_mode
from .framework import in_dynamic_or_pir_mode
from .tensor.attribute import is_floating_point, is_integer
from .tensor.creation import _complex_to_real_dtype, _real_to_complex_dtype

Expand Down Expand Up @@ -1404,7 +1404,7 @@ def fft_c2c(x, n, axis, norm, forward, name):
s = [n]
x = _resize_fft_input(x, s, axes)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
out = _C_ops.fft_c2c(x, axes, norm, forward)
else:
op_type = 'fft_c2c'
Expand Down Expand Up @@ -1435,7 +1435,7 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name):
_check_fft_n(n)
s = [n]
x = _resize_fft_input(x, s, axes)
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
out = _C_ops.fft_r2c(x, axes, norm, forward, onesided)
else:
op_type = 'fft_r2c'
Expand Down Expand Up @@ -1478,7 +1478,7 @@ def fft_c2r(x, n, axis, norm, forward, name):
s = [n // 2 + 1]
x = _resize_fft_input(x, s, axes)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
if n is not None:
out = _C_ops.fft_c2r(x, axes, norm, forward, n)
else:
Expand Down Expand Up @@ -1537,7 +1537,7 @@ def fftn_c2c(x, s, axes, norm, forward, name):
if s is not None:
x = _resize_fft_input(x, s, axes)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
out = _C_ops.fft_c2c(x, axes, norm, forward)
else:
op_type = 'fft_c2c'
Expand Down Expand Up @@ -1587,7 +1587,7 @@ def fftn_r2c(x, s, axes, norm, forward, onesided, name):
if s is not None:
x = _resize_fft_input(x, s, axes)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
out = _C_ops.fft_r2c(x, axes, norm, forward, onesided)
else:
op_type = 'fft_r2c'
Expand Down Expand Up @@ -1651,7 +1651,7 @@ def fftn_c2r(x, s, axes, norm, forward, name):
fft_input_shape[-1] = fft_input_shape[-1] // 2 + 1
x = _resize_fft_input(x, fft_input_shape, axes)

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
if s is not None:
out = _C_ops.fft_c2r(x, axes, norm, forward, s[-1])
else:
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/nn/functional/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,7 +1494,7 @@ def mish(x, name=None):
Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
[-0.03357624, 0. , 4.99955177])
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.mish(x, 20)
else:
check_variable_and_dtype(
Expand Down
10 changes: 6 additions & 4 deletions test/legacy_test/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4662,12 +4662,12 @@ def init_shape(self):
self.shape = [10, 12]

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)


class TestMish_ZeroDim(TestMish):
Expand All @@ -4686,6 +4686,7 @@ def setUp(self):
else paddle.CPUPlace()
)

@test_with_pir_api
def test_static_api(self):
with static_guard():
with paddle.static.program_guard(paddle.static.Program()):
Expand Down Expand Up @@ -4719,6 +4720,7 @@ def test_base_api(self):
out_ref = ref_mish(self.x_np)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)

@test_with_pir_api
def test_errors(self):
with static_guard():
with paddle.static.program_guard(paddle.static.Program()):
Expand Down Expand Up @@ -4914,7 +4916,7 @@ def test_check_grad(self):
create_test_act_fp16_class(TestHardSigmoid, check_pir=True)
create_test_act_fp16_class(TestSwish)
create_test_act_fp16_class(TestHardSwish, check_prim=True, check_pir=True)
create_test_act_fp16_class(TestMish)
create_test_act_fp16_class(TestMish, check_pir=True)
create_test_act_fp16_class(
TestLeakyRelu, check_prim=True, enable_cinn=True, check_pir=True
)
Expand Down Expand Up @@ -5073,7 +5075,7 @@ def test_check_grad(self):
create_test_act_bf16_class(TestHardSigmoid, check_pir=True)
create_test_act_bf16_class(TestSwish)
create_test_act_bf16_class(TestHardSwish, check_prim=True, check_pir=True)
create_test_act_bf16_class(TestMish)
create_test_act_bf16_class(TestMish, check_pir=True)
create_test_act_bf16_class(TestLeakyRelu, check_prim=True, check_pir=True)
create_test_act_bf16_class(TestLeakyReluAlpha1, check_prim=True)
create_test_act_bf16_class(TestLeakyReluAlpha2, check_prim=True)
Expand Down
83 changes: 82 additions & 1 deletion test/legacy_test/test_audio_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import paddle
import paddle.audio
from paddle.pir_utils import test_with_pir_api


def parameterize(*params):
Expand All @@ -29,6 +30,7 @@ def parameterize(*params):

class TestAudioFuncitons(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.initParmas()

def initParmas(self):
Expand All @@ -52,7 +54,9 @@ def get_wav_data(dtype: str, num_channels: int, num_frames: int):
self.dtype = "float32"
self.window_size = 1024
waveform_tensor = get_wav_data(
self.dtype, self.num_channels, num_frames=self.duration * self.sr
self.dtype,
self.num_channels,
num_frames=int(self.duration * self.sr),
)
self.waveform = waveform_tensor.numpy()

Expand Down Expand Up @@ -86,6 +90,56 @@ def test_audio_function(self, val: float, htk_flag: bool):
decibel_paddle.numpy(), decibel_paddle, decimal=5
)

@parameterize([1.0, 3.0, 9.0, 25.0], [True, False])
@test_with_pir_api
def test_audio_function_static(self, val: float, htk_flag: bool):
paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
mel_paddle_tensor = paddle.audio.functional.hz_to_mel(
paddle.to_tensor([val]), htk_flag
)

hz_paddle_tensor = paddle.audio.functional.mel_to_hz(
paddle.to_tensor([val]), htk_flag
)

decibel_paddle = paddle.audio.functional.power_to_db(
paddle.to_tensor([val])
)

exe = paddle.static.Executor()
(
mel_paddle_tensor_ret,
hz_paddle_tensor_ret,
decibel_paddle_ret,
) = exe.run(
main,
fetch_list=[
mel_paddle_tensor,
hz_paddle_tensor,
decibel_paddle,
],
)

mel_librosa = librosa.hz_to_mel(val, htk_flag)
np.testing.assert_almost_equal(
mel_paddle_tensor_ret, mel_librosa, decimal=4
)

hz_librosa = librosa.mel_to_hz(val, htk_flag)
np.testing.assert_almost_equal(
hz_paddle_tensor_ret, hz_librosa, decimal=4
)

decibel_librosa = librosa.power_to_db(val)
np.testing.assert_almost_equal(
decibel_paddle_ret, decibel_librosa, decimal=5
)

paddle.disable_static()

@parameterize(
[64, 128, 256], [0.0, 0.5, 1.0], [10000, 11025], [False, True]
)
Expand All @@ -102,6 +156,33 @@ def test_audio_function_mel(
paddle_mel_freq, librosa_mel_freq, decimal=3
)

@parameterize(
[64, 128, 256], [0.0, 0.5, 1.0], [10000, 11025], [False, True]
)
# TODO(MarioLulab) May cause precision error. Fix it soon
# @test_with_pir_api
def test_audio_function_mel_static(
self, n_mels: int, f_min: float, f_max: float, htk_flag: bool
):
paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
paddle_mel_freq = paddle.audio.functional.mel_frequencies(
n_mels, f_min, f_max, htk_flag, 'float64'
)

exe = paddle.static.Executor()
(paddle_mel_freq_ret,) = exe.run(main, fetch_list=[paddle_mel_freq])
librosa_mel_freq = librosa.mel_frequencies(
n_mels, f_min, f_max, htk_flag
)
np.testing.assert_almost_equal(
paddle_mel_freq_ret, librosa_mel_freq, decimal=3
)

paddle.disable_static()

@parameterize([8000, 16000], [64, 128, 256])
def test_audio_function_fft(self, sr: int, n_fft: int):
librosa_fft = librosa.fft_frequencies(sr, n_fft)
Expand Down