diff --git a/python/paddle/audio/functional/functional.py b/python/paddle/audio/functional/functional.py index 930f412384523d..4d4ec21caddaa8 100644 --- a/python/paddle/audio/functional/functional.py +++ b/python/paddle/audio/functional/functional.py @@ -17,11 +17,13 @@ import paddle from paddle import Tensor +from paddle.base.framework import Variable +from paddle.pir import Value def hz_to_mel( - freq: Union[Tensor, float], htk: bool = False -) -> Union[Tensor, float]: + freq: Union[Tensor, Value, Variable, float], htk: bool = False +) -> Union[Tensor, Value, Variable, float]: """Convert Hz to Mels. Args: @@ -43,7 +45,7 @@ def hz_to_mel( """ if htk: - if isinstance(freq, Tensor): + if isinstance(freq, (Tensor, Variable, Value)): return 2595.0 * paddle.log10(1.0 + freq / 700.0) else: return 2595.0 * math.log10(1.0 + freq / 700.0) @@ -60,7 +62,7 @@ def hz_to_mel( min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) logstep = math.log(6.4) / 27.0 # step size for log region - if isinstance(freq, Tensor): + if isinstance(freq, (Tensor, Variable, Value)): target = ( min_log_mel + paddle.log(freq / min_log_hz + 1e-10) / logstep ) # prevent nan with 1e-10 @@ -76,8 +78,8 @@ def hz_to_mel( def mel_to_hz( - mel: Union[float, Tensor], htk: bool = False -) -> Union[float, Tensor]: + mel: Union[float, Tensor, Variable, Value], htk: bool = False +) -> Union[float, Tensor, Variable, Value]: """Convert mel bin numbers to frequencies. Args: @@ -108,7 +110,7 @@ def mel_to_hz( min_log_hz = 1000.0 # beginning of log region (Hz) min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) logstep = math.log(6.4) / 27.0 # step size for log region - if isinstance(mel, Tensor): + if isinstance(mel, (Tensor, Variable, Value)): target = min_log_hz * paddle.exp(logstep * (mel - min_log_mel)) mask = (mel > min_log_mel).astype(mel.dtype) freqs = target * mask + freqs * ( @@ -126,7 +128,7 @@ def mel_frequencies( f_max: float = 11025.0, htk: bool = False, dtype: str = 'float32', -) -> Tensor: +) -> Union[Tensor, Variable, Value]: """Compute mel frequencies. Args: @@ -257,11 +259,11 @@ def compute_fbank_matrix( def power_to_db( - spect: Tensor, + spect: Union[Tensor, Variable, Value], ref_value: float = 1.0, amin: float = 1e-10, top_db: Optional[float] = 80.0, -) -> Tensor: +) -> Union[Tensor, Variable, Value]: """Convert a power spectrogram (amplitude squared) to decibel (dB) units. The function computes the scaling `10 * log10(x / ref)` in a numerically stable way. Args: diff --git a/python/paddle/fft.py b/python/paddle/fft.py index dc38b60ac5db5e..5f83985c6d273b 100644 --- a/python/paddle/fft.py +++ b/python/paddle/fft.py @@ -21,7 +21,7 @@ from . import _C_ops from .base.data_feeder import check_variable_and_dtype from .base.layer_helper import LayerHelper -from .framework import in_dynamic_mode +from .framework import in_dynamic_or_pir_mode from .tensor.attribute import is_floating_point, is_integer from .tensor.creation import _complex_to_real_dtype, _real_to_complex_dtype @@ -1404,7 +1404,7 @@ def fft_c2c(x, n, axis, norm, forward, name): s = [n] x = _resize_fft_input(x, s, axes) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out = _C_ops.fft_c2c(x, axes, norm, forward) else: op_type = 'fft_c2c' @@ -1435,7 +1435,7 @@ def fft_r2c(x, n, axis, norm, forward, onesided, name): _check_fft_n(n) s = [n] x = _resize_fft_input(x, s, axes) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out = _C_ops.fft_r2c(x, axes, norm, forward, onesided) else: op_type = 'fft_r2c' @@ -1478,7 +1478,7 @@ def fft_c2r(x, n, axis, norm, forward, name): s = [n // 2 + 1] x = _resize_fft_input(x, s, axes) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if n is not None: out = _C_ops.fft_c2r(x, axes, norm, forward, n) else: @@ -1537,7 +1537,7 @@ def fftn_c2c(x, s, axes, norm, forward, name): if s is not None: x = _resize_fft_input(x, s, axes) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out = _C_ops.fft_c2c(x, axes, norm, forward) else: op_type = 'fft_c2c' @@ -1587,7 +1587,7 @@ def fftn_r2c(x, s, axes, norm, forward, onesided, name): if s is not None: x = _resize_fft_input(x, s, axes) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): out = _C_ops.fft_r2c(x, axes, norm, forward, onesided) else: op_type = 'fft_r2c' @@ -1651,7 +1651,7 @@ def fftn_c2r(x, s, axes, norm, forward, name): fft_input_shape[-1] = fft_input_shape[-1] // 2 + 1 x = _resize_fft_input(x, fft_input_shape, axes) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if s is not None: out = _C_ops.fft_c2r(x, axes, norm, forward, s[-1]) else: diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 0cda9a1e7480c8..f89885a7119583 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -1494,7 +1494,7 @@ def mish(x, name=None): Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True, [-0.03357624, 0. , 4.99955177]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.mish(x, 20) else: check_variable_and_dtype( diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index 5d8eb99da4d5a9..f25a9ce3a78dcc 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -4662,12 +4662,12 @@ def init_shape(self): self.shape = [10, 12] def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_pir=True) class TestMish_ZeroDim(TestMish): @@ -4686,6 +4686,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -4719,6 +4720,7 @@ def test_base_api(self): out_ref = ref_mish(self.x_np) np.testing.assert_allclose(out_ref, res[0], rtol=1e-05) + @test_with_pir_api def test_errors(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): @@ -4914,7 +4916,7 @@ def test_check_grad(self): create_test_act_fp16_class(TestHardSigmoid, check_pir=True) create_test_act_fp16_class(TestSwish) create_test_act_fp16_class(TestHardSwish, check_prim=True, check_pir=True) -create_test_act_fp16_class(TestMish) +create_test_act_fp16_class(TestMish, check_pir=True) create_test_act_fp16_class( TestLeakyRelu, check_prim=True, enable_cinn=True, check_pir=True ) @@ -5073,7 +5075,7 @@ def test_check_grad(self): create_test_act_bf16_class(TestHardSigmoid, check_pir=True) create_test_act_bf16_class(TestSwish) create_test_act_bf16_class(TestHardSwish, check_prim=True, check_pir=True) -create_test_act_bf16_class(TestMish) +create_test_act_bf16_class(TestMish, check_pir=True) create_test_act_bf16_class(TestLeakyRelu, check_prim=True, check_pir=True) create_test_act_bf16_class(TestLeakyReluAlpha1, check_prim=True) create_test_act_bf16_class(TestLeakyReluAlpha2, check_prim=True) diff --git a/test/legacy_test/test_audio_functions.py b/test/legacy_test/test_audio_functions.py index 47adbdd4905017..c2a9d2904e230e 100644 --- a/test/legacy_test/test_audio_functions.py +++ b/test/legacy_test/test_audio_functions.py @@ -21,6 +21,7 @@ import paddle import paddle.audio +from paddle.pir_utils import test_with_pir_api def parameterize(*params): @@ -29,6 +30,7 @@ def parameterize(*params): class TestAudioFuncitons(unittest.TestCase): def setUp(self): + paddle.disable_static() self.initParmas() def initParmas(self): @@ -52,7 +54,9 @@ def get_wav_data(dtype: str, num_channels: int, num_frames: int): self.dtype = "float32" self.window_size = 1024 waveform_tensor = get_wav_data( - self.dtype, self.num_channels, num_frames=self.duration * self.sr + self.dtype, + self.num_channels, + num_frames=int(self.duration * self.sr), ) self.waveform = waveform_tensor.numpy() @@ -86,6 +90,56 @@ def test_audio_function(self, val: float, htk_flag: bool): decibel_paddle.numpy(), decibel_paddle, decimal=5 ) + @parameterize([1.0, 3.0, 9.0, 25.0], [True, False]) + @test_with_pir_api + def test_audio_function_static(self, val: float, htk_flag: bool): + paddle.enable_static() + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + mel_paddle_tensor = paddle.audio.functional.hz_to_mel( + paddle.to_tensor([val]), htk_flag + ) + + hz_paddle_tensor = paddle.audio.functional.mel_to_hz( + paddle.to_tensor([val]), htk_flag + ) + + decibel_paddle = paddle.audio.functional.power_to_db( + paddle.to_tensor([val]) + ) + + exe = paddle.static.Executor() + ( + mel_paddle_tensor_ret, + hz_paddle_tensor_ret, + decibel_paddle_ret, + ) = exe.run( + main, + fetch_list=[ + mel_paddle_tensor, + hz_paddle_tensor, + decibel_paddle, + ], + ) + + mel_librosa = librosa.hz_to_mel(val, htk_flag) + np.testing.assert_almost_equal( + mel_paddle_tensor_ret, mel_librosa, decimal=4 + ) + + hz_librosa = librosa.mel_to_hz(val, htk_flag) + np.testing.assert_almost_equal( + hz_paddle_tensor_ret, hz_librosa, decimal=4 + ) + + decibel_librosa = librosa.power_to_db(val) + np.testing.assert_almost_equal( + decibel_paddle_ret, decibel_librosa, decimal=5 + ) + + paddle.disable_static() + @parameterize( [64, 128, 256], [0.0, 0.5, 1.0], [10000, 11025], [False, True] ) @@ -102,6 +156,33 @@ def test_audio_function_mel( paddle_mel_freq, librosa_mel_freq, decimal=3 ) + @parameterize( + [64, 128, 256], [0.0, 0.5, 1.0], [10000, 11025], [False, True] + ) + # TODO(MarioLulab) May cause precision error. Fix it soon + # @test_with_pir_api + def test_audio_function_mel_static( + self, n_mels: int, f_min: float, f_max: float, htk_flag: bool + ): + paddle.enable_static() + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + paddle_mel_freq = paddle.audio.functional.mel_frequencies( + n_mels, f_min, f_max, htk_flag, 'float64' + ) + + exe = paddle.static.Executor() + (paddle_mel_freq_ret,) = exe.run(main, fetch_list=[paddle_mel_freq]) + librosa_mel_freq = librosa.mel_frequencies( + n_mels, f_min, f_max, htk_flag + ) + np.testing.assert_almost_equal( + paddle_mel_freq_ret, librosa_mel_freq, decimal=3 + ) + + paddle.disable_static() + @parameterize([8000, 16000], [64, 128, 256]) def test_audio_function_fft(self, sr: int, n_fft: int): librosa_fft = librosa.fft_frequencies(sr, n_fft)