From 4096b874175ea15d7e4f48a12560f8ada7c92fde Mon Sep 17 00:00:00 2001 From: Rauf Date: Fri, 2 May 2025 07:43:30 -0700 Subject: [PATCH 1/5] Adding tests for schrodinger bridge model Signed-off-by: Rauf --- .../test_audio_models_schroedinger_bridge.py | 163 ++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 tests/collections/audio/test_audio_models_schroedinger_bridge.py diff --git a/tests/collections/audio/test_audio_models_schroedinger_bridge.py b/tests/collections/audio/test_audio_models_schroedinger_bridge.py new file mode 100644 index 000000000000..7f35c052a884 --- /dev/null +++ b/tests/collections/audio/test_audio_models_schroedinger_bridge.py @@ -0,0 +1,163 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from omegaconf import DictConfig + +from nemo.collections.audio.models import SchroedingerBridgeAudioToAudioModel + + +@pytest.fixture() +def schroedinger_bridge_model_ncsn(): + + model = { + 'sample_rate': 16000, + 'num_outputs': 1, + 'normalize_input': True, + 'max_utts_evaluation_metrics': 50, + } + encoder = { + '_target_': 'nemo.collections.audio.modules.transforms.AudioToSpectrogram', + 'fft_length': 510, + 'hop_length': 128, + 'magnitude_power': 0.5, + 'scale': 0.33, + } + decoder = { + '_target_': 'nemo.collections.audio.modules.transforms.SpectrogramToAudio', + 'fft_length': encoder['fft_length'], + 'hop_length': encoder['hop_length'], + 'magnitude_power': encoder['magnitude_power'], + 'scale': encoder['scale'], + } + estimator = { + '_target_': 'nemo.collections.audio.parts.submodules.ncsnpp.SpectrogramNoiseConditionalScoreNetworkPlusPlus', + 'in_channels': 2, # single-channel noisy input + 'out_channels': 1, # single-channel estimate + 'conditioned_on_time': True, + 'num_res_blocks': 3, # increased number of res blocks + 'pad_time_to': 64, # pad to 64 frames for the time dimension + 'pad_dimension_to': 0, # no padding in the frequency dimension + } + + loss_encoded = { + '_target_': 'nemo.collections.audio.losses.MSELoss', # computed in the time domain + 'ndim': 4 + } + + loss_time = { + '_target_': 'nemo.collections.audio.losses.MAELoss' + } + + noise_schedule = { + '_target_': 'nemo.collections.audio.parts.submodules.schroedinger_bridge.SBNoiseScheduleVE', + 'k': 2.6, + 'c': 0.4, + 'time_min': 1e-4, + 'time_max': 1.0, + 'num_steps': 1000 # num steps for the forward process + } + + sampler = { + '_target_': 'nemo.collections.audio.parts.submodules.schroedinger_bridge.SBSampler', + 'time_min': 1e-4, + 'time_max': 1.0, + 'num_steps': 5 # num steps for the reverse process + } + + + + model_config = DictConfig( + { + 'sample_rate': model['sample_rate'], + 'num_outputs': model['num_outputs'], + 'normalize_input': model['normalize_input'], + 'max_utts_evaluation_metrics': model['max_utts_evaluation_metrics'], + 'encoder': DictConfig(encoder), + 'decoder': DictConfig(decoder), + 'estimator': DictConfig(estimator), + 'loss_encoded': DictConfig(loss_encoded), + 'loss_time': DictConfig(loss_time), + 'loss_time_weight': 0.001, + 'estimator_output': 'data_prediction', + 'noise_schedule': DictConfig(noise_schedule), + 'sampler': DictConfig(sampler), + + 'optim': { + 'optimizer': 'Adam', + 'lr': 0.001, + 'betas': (0.9, 0.98), + }, + } + ) + + model = SchroedingerBridgeAudioToAudioModel(cfg=model_config) + + return model + + +class TestSchroedingerBridgeModelNCSN: + """Test Schroedinger Bridge model with NCSN estimator.""" + + @pytest.mark.unit + def test_constructor(self, schroedinger_bridge_model_ncsn): + """Test that the model can be constructed from a config dict.""" + model = schroedinger_bridge_model_ncsn.train() + confdict = model.to_config_dict() + instance2 = SchroedingerBridgeAudioToAudioModel.from_config_dict(confdict) + assert isinstance(instance2, SchroedingerBridgeAudioToAudioModel) + + @pytest.mark.unit + @pytest.mark.parametrize( + "batch_size, sample_len", + [ + (4, 4), # Example 1 + (2, 8), # Example 2 + (1, 10), # Example 3 + ], + ) + def test_forward_infer(self, schroedinger_bridge_model_ncsn, batch_size, sample_len): + """Test that the model can run forward inference.""" + model = schroedinger_bridge_model_ncsn.eval() + confdict = model.to_config_dict() + sampling_rate = confdict['sample_rate'] + rng = torch.Generator() + rng.manual_seed(0) + input_signal = torch.randn(size=(batch_size, 1, sample_len * sampling_rate), generator=rng) + input_signal_length = (sample_len * sampling_rate) * torch.ones(batch_size, dtype=torch.int) + + with torch.no_grad(): + # batch size 1 + output_list = [] + output_length_list = [] + for i in range(input_signal.size(0)): + output, output_length = model.forward( + input_signal=input_signal[i : i + 1], input_length=input_signal_length[i : i + 1] + ) + output_list.append(output) + output_length_list.append(output_length) + output_instance = torch.cat(output_list, 0) + output_length_instance = torch.cat(output_length_list, 0) + + # batch size batch_size + output_batch, output_length_batch = model.forward( + input_signal=input_signal, input_length=input_signal_length + ) + + # It is generative model so we do not check the diffenence between output_instance and output_batch + + # Check that the output and output length are the same for the instance and batch + assert output_instance.shape == output_batch.shape + assert output_length_instance.shape == output_length_batch.shape \ No newline at end of file From 5c1ef74636708785e50e15a781d539b5f6bed9e0 Mon Sep 17 00:00:00 2001 From: nasretdinovr Date: Fri, 2 May 2025 15:00:11 +0000 Subject: [PATCH 2/5] Apply isort and black reformatting Signed-off-by: nasretdinovr --- .../test_audio_models_schroedinger_bridge.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/collections/audio/test_audio_models_schroedinger_bridge.py b/tests/collections/audio/test_audio_models_schroedinger_bridge.py index 7f35c052a884..af971d883b54 100644 --- a/tests/collections/audio/test_audio_models_schroedinger_bridge.py +++ b/tests/collections/audio/test_audio_models_schroedinger_bridge.py @@ -52,14 +52,9 @@ def schroedinger_bridge_model_ncsn(): 'pad_dimension_to': 0, # no padding in the frequency dimension } - loss_encoded = { - '_target_': 'nemo.collections.audio.losses.MSELoss', # computed in the time domain - 'ndim': 4 - } + loss_encoded = {'_target_': 'nemo.collections.audio.losses.MSELoss', 'ndim': 4} # computed in the time domain - loss_time = { - '_target_': 'nemo.collections.audio.losses.MAELoss' - } + loss_time = {'_target_': 'nemo.collections.audio.losses.MAELoss'} noise_schedule = { '_target_': 'nemo.collections.audio.parts.submodules.schroedinger_bridge.SBNoiseScheduleVE', @@ -67,18 +62,16 @@ def schroedinger_bridge_model_ncsn(): 'c': 0.4, 'time_min': 1e-4, 'time_max': 1.0, - 'num_steps': 1000 # num steps for the forward process + 'num_steps': 1000, # num steps for the forward process } sampler = { '_target_': 'nemo.collections.audio.parts.submodules.schroedinger_bridge.SBSampler', 'time_min': 1e-4, 'time_max': 1.0, - 'num_steps': 5 # num steps for the reverse process + 'num_steps': 5, # num steps for the reverse process } - - model_config = DictConfig( { 'sample_rate': model['sample_rate'], @@ -94,7 +87,6 @@ def schroedinger_bridge_model_ncsn(): 'estimator_output': 'data_prediction', 'noise_schedule': DictConfig(noise_schedule), 'sampler': DictConfig(sampler), - 'optim': { 'optimizer': 'Adam', 'lr': 0.001, @@ -157,7 +149,7 @@ def test_forward_infer(self, schroedinger_bridge_model_ncsn, batch_size, sample_ ) # It is generative model so we do not check the diffenence between output_instance and output_batch - + # Check that the output and output length are the same for the instance and batch assert output_instance.shape == output_batch.shape - assert output_length_instance.shape == output_length_batch.shape \ No newline at end of file + assert output_length_instance.shape == output_length_batch.shape From 079c3d2e48eafd7ade0d3cefd9777c6dfdc578eb Mon Sep 17 00:00:00 2001 From: Rauf Date: Fri, 2 May 2025 08:18:22 -0700 Subject: [PATCH 3/5] Updated the year Signed-off-by: Rauf --- .../collections/audio/test_audio_models_schroedinger_bridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/collections/audio/test_audio_models_schroedinger_bridge.py b/tests/collections/audio/test_audio_models_schroedinger_bridge.py index 7f35c052a884..889080ce1469 100644 --- a/tests/collections/audio/test_audio_models_schroedinger_bridge.py +++ b/tests/collections/audio/test_audio_models_schroedinger_bridge.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 5b2d0fb56f2657a472bf7f33e22dbd2947604990 Mon Sep 17 00:00:00 2001 From: nasretdinovr Date: Fri, 2 May 2025 19:23:59 +0400 Subject: [PATCH 4/5] Updated the year Signed-off-by: nasretdinovr --- .../collections/audio/test_audio_models_schroedinger_bridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/collections/audio/test_audio_models_schroedinger_bridge.py b/tests/collections/audio/test_audio_models_schroedinger_bridge.py index af971d883b54..701e516630aa 100644 --- a/tests/collections/audio/test_audio_models_schroedinger_bridge.py +++ b/tests/collections/audio/test_audio_models_schroedinger_bridge.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From eefba75520daf0786d47d66874e4dca7281eab9e Mon Sep 17 00:00:00 2001 From: Rauf Date: Tue, 6 May 2025 04:46:38 -0700 Subject: [PATCH 5/5] reduced model channels to make it smaller Signed-off-by: Rauf --- tests/collections/audio/test_audio_models_schroedinger_bridge.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/collections/audio/test_audio_models_schroedinger_bridge.py b/tests/collections/audio/test_audio_models_schroedinger_bridge.py index 701e516630aa..70c99135fafa 100644 --- a/tests/collections/audio/test_audio_models_schroedinger_bridge.py +++ b/tests/collections/audio/test_audio_models_schroedinger_bridge.py @@ -47,6 +47,7 @@ def schroedinger_bridge_model_ncsn(): 'in_channels': 2, # single-channel noisy input 'out_channels': 1, # single-channel estimate 'conditioned_on_time': True, + 'channels': [8, 8, 8, 8, 8], 'num_res_blocks': 3, # increased number of res blocks 'pad_time_to': 64, # pad to 64 frames for the time dimension 'pad_dimension_to': 0, # no padding in the frequency dimension