Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions examples/speechlm2/salm_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class SalmEvalConfig:
verbose: bool = True
use_normalizer: Optional[str] = "english" # "english", "basic", or "none" / "None"
device: str = "cuda"
dtype: str = "bfloat16"
extra_eos_tokens: Optional[list[str]] = None
system_prompt: Optional[str] = None
user_prompt: Optional[str] = None
Expand All @@ -56,10 +57,7 @@ class SalmEvalConfig:
def main(cfg: SalmEvalConfig):
logging.info(f'Hydra config:\n{OmegaConf.to_yaml(cfg)}')

with torch.device(cfg.device):
torch.set_default_dtype(torch.bfloat16)
model = SALM.from_pretrained(cfg.pretrained_name).eval().to(torch.bfloat16).to(cfg.device)
torch.set_default_dtype(torch.float32)
model = SALM.from_pretrained(cfg.pretrained_name).eval().to(getattr(torch, cfg.dtype)).to(cfg.device)

cuts = guess_parse_cutset(cfg.inputs).sort_by_duration()
dloader = torch.utils.data.DataLoader(
Expand Down
6 changes: 2 additions & 4 deletions examples/speechlm2/salm_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class SalmEvalConfig:
output_manifest: str = "generations.jsonl"
verbose: bool = True
device: str = "cuda"
dtype: str = "bfloat16"
extra_eos_tokens: Optional[list[str]] = None
system_prompt: Optional[str] = None
user_prompt: Optional[str] = None
Expand All @@ -51,10 +52,7 @@ class SalmEvalConfig:
def main(cfg: SalmEvalConfig):
logging.info(f"Hydra config:\n{OmegaConf.to_yaml(cfg)}")

with torch.device(cfg.device):
torch.set_default_dtype(torch.bfloat16)
model = SALM.from_pretrained(cfg.pretrained_name).eval().to(torch.bfloat16).to(cfg.device)
torch.set_default_dtype(torch.float32)
model = SALM.from_pretrained(cfg.pretrained_name).eval().to(getattr(torch, cfg.dtype)).to(cfg.device)

conversations = (
guess_parse_cutset(cfg.inputs)
Expand Down
4 changes: 4 additions & 0 deletions examples/speechlm2/to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class HfExportConfig:
# Path where we should save the HuggingFace Hub compatible checkpoint
output_dir: str

# Dtype used for stored parameters
dtype: str = "bfloat16"


def load_checkpoint(model: torch.nn.Module, checkpoint_path: str):
if Path(checkpoint_path).is_dir():
Expand All @@ -60,6 +63,7 @@ def main(cfg: HfExportConfig):
cls = import_class_by_path(cfg.class_path)
model = cls(OmegaConf.to_container(model_cfg, resolve=True))
load_checkpoint(model, cfg.ckpt_path)
model = model.to(getattr(torch, cfg.dtype))
model.save_pretrained(cfg.output_dir)


Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/speechlm2/models/duplex_s2s_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self, cfg: dict) -> None:
maybe_install_lora(self)

# Load the pretrained ASR model.
setup_speech_encoder(self)
setup_speech_encoder(self, pretrained_weights=self.cfg.pretrained_weights)

self.embed_audio_tokens = torch.nn.ModuleList(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, cfg: dict) -> None:
maybe_install_lora(self)

# Load the pretrained streaming ASR model and copy its parameters into the audio perception module.
setup_speech_encoder(self)
setup_speech_encoder(self, pretrained_weights=self.cfg.pretrained_weights)

self.speech_generation = TransformerARSpeechDecoder(
speech_decoder_parms=OmegaConf.to_container(self.cfg.speech_decoder),
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/speechlm2/models/salm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __init__(self, cfg) -> None:
maybe_install_lora(self)

# Load the pretrained streaming ASR model and copy its parameters into the audio perception module.
setup_speech_encoder(self)
setup_speech_encoder(self, pretrained_weights=self.cfg.pretrained_weights)

self._use_fsdp = False
self._use_tp = False
Expand Down
19 changes: 11 additions & 8 deletions nemo/collections/speechlm2/parts/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,19 @@ def setup_audio_codec(model: torch.nn.Module):
del model.audio_codec.discriminator # free up some memory


def setup_speech_encoder(model: torch.nn.Module):
def setup_speech_encoder(model: torch.nn.Module, pretrained_weights: bool = True):
"""
Sets up an ``AudioPerceptionModule``, initializing its ``encoder`` and ``preprocessor``
with a pretrained NeMo ``ASRModel``.
The result is assigned to ``model.perception`` attribute and is trainable.
"""
asr = load_pretrained_nemo(ASRModel, model.cfg.pretrained_asr).eval()
with open_dict(model.cfg):
model.cfg.perception.preprocessor = asr.cfg.preprocessor
model.cfg.perception.encoder = asr.cfg.encoder
model.cfg.perception.output_dim = model.llm.config.hidden_size
model.perception = AudioPerceptionModule(model.cfg.perception).train()
model.perception.load_state_dict(asr.state_dict(), strict=False)
if pretrained_weights:
asr = load_pretrained_nemo(ASRModel, model.cfg.pretrained_asr).eval()
with open_dict(model.cfg):
model.cfg.perception.preprocessor = asr.cfg.preprocessor
model.cfg.perception.encoder = asr.cfg.encoder
model.cfg.perception.output_dim = model.llm.config.hidden_size
model.perception = AudioPerceptionModule(model.cfg.perception).train()
model.perception.load_state_dict(asr.state_dict(), strict=False)
else:
model.perception = AudioPerceptionModule(model.cfg.perception).train()
51 changes: 43 additions & 8 deletions tests/collections/speechlm2/test_duplex_s2s.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,49 @@ def model():
"audio_loss_weight": 1,
"text_loss_weight": 3,
"perception": {
"_target_": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule",
"modality_adapter": {
"target": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule",
"output_dim": 2048,
"encoder": {
"_target_": "nemo.collections.asr.modules.ConformerEncoder",
"feat_in": 512,
"att_context_size": [-1, -1],
"causal_downsampling": False,
"conv_context_size": None,
"conv_kernel_size": 9,
"conv_norm_type": "batch_norm",
"d_model": 1024,
"dropout": 0.1,
"dropout_att": 0.1,
"dropout_emb": 0.0,
"dropout_pre_encoder": 0.1,
"feat_in": 128,
"feat_out": -1,
"n_layers": 1,
"d_model": 512,
"subsampling_factor": 1,
"ff_expansion_factor": 4,
"n_heads": 8,
"n_layers": 2,
"pos_emb_max_len": 5000,
"self_attention_model": "rel_pos",
"subsampling": "dw_striding",
"subsampling_conv_channels": 256,
"subsampling_factor": 8,
},
"modality_adapter": {
"_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector",
"d_model": 1024,
},
"preprocessor": {
"_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor",
"dither": 1e-05,
"features": 128,
"frame_splicing": 1,
"log": True,
"n_fft": 512,
"normalize": "per_feature",
"pad_to": 0,
"pad_value": 0.0,
"sample_rate": 16000,
"window": "hann",
"window_size": 0.025,
"window_stride": 0.01,
},
},
"optimizer": {"_target_": "torch.optim.AdamW"},
Expand Down Expand Up @@ -177,13 +212,13 @@ def test_s2s_offline_generation(model):
assert isinstance(ans["text"][0], str)

gen_text = ans["tokens_text"]
assert gen_text.shape == (1, 14)
assert gen_text.shape == (1, 13)
assert gen_text.dtype == torch.long
assert (gen_text >= 0).all()
assert (gen_text < model.text_vocab_size).all()

gen_audio_codes = ans["tokens_audio"]
assert gen_audio_codes.shape == (1, 14, 8)
assert gen_audio_codes.shape == (1, 13, 8)
assert gen_audio_codes.dtype == torch.long
assert (gen_audio_codes >= 0).all()
assert (gen_audio_codes < model.speech_vocab_size).all()
Expand Down
51 changes: 43 additions & 8 deletions tests/collections/speechlm2/test_duplex_s2s_speech_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,49 @@ def model():
"audio_loss_weight": 1,
"text_loss_weight": 3,
"perception": {
"_target_": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule",
"modality_adapter": {
"target": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule",
"output_dim": 2048,
"encoder": {
"_target_": "nemo.collections.asr.modules.ConformerEncoder",
"feat_in": 512,
"att_context_size": [-1, -1],
"causal_downsampling": False,
"conv_context_size": None,
"conv_kernel_size": 9,
"conv_norm_type": "batch_norm",
"d_model": 1024,
"dropout": 0.1,
"dropout_att": 0.1,
"dropout_emb": 0.0,
"dropout_pre_encoder": 0.1,
"feat_in": 128,
"feat_out": -1,
"n_layers": 1,
"d_model": 512,
"subsampling_factor": 1,
"ff_expansion_factor": 4,
"n_heads": 8,
"n_layers": 2,
"pos_emb_max_len": 5000,
"self_attention_model": "rel_pos",
"subsampling": "dw_striding",
"subsampling_conv_channels": 256,
"subsampling_factor": 8,
},
"modality_adapter": {
"_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector",
"d_model": 1024,
},
"preprocessor": {
"_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor",
"dither": 1e-05,
"features": 128,
"frame_splicing": 1,
"log": True,
"n_fft": 512,
"normalize": "per_feature",
"pad_to": 0,
"pad_value": 0.0,
"sample_rate": 16000,
"window": "hann",
"window_size": 0.025,
"window_stride": 0.01,
},
},
"speech_decoder": {
Expand Down Expand Up @@ -164,13 +199,13 @@ def test_s2s_speech_decoder_offline_generation(model):
assert isinstance(ans["text"][0], str)

gen_text = ans["tokens_text"]
assert gen_text.shape == (1, 14)
assert gen_text.shape == (1, 13)
assert gen_text.dtype == torch.long
assert (gen_text >= 0).all()
assert (gen_text < model.text_vocab_size).all()

gen_audio_codes = ans["tokens_audio"]
assert gen_audio_codes.shape == (1, 14, 8)
assert gen_audio_codes.shape == (1, 13, 8)
assert gen_audio_codes.dtype == torch.long
assert (gen_audio_codes >= 0).all()
assert (gen_audio_codes < model.speech_vocab_size).all()
Expand Down
40 changes: 39 additions & 1 deletion tests/collections/speechlm2/test_salm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,50 @@ def model():
"prompt_format": PROMPT,
"audio_locator_tag": AUDIO_LOCATOR_TAG,
"perception": {
"_target_": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule",
"target": "nemo.collections.speechlm2.modules.perception.AudioPerceptionModule",
"output_dim": 2048,
"encoder": {
"_target_": "nemo.collections.asr.modules.ConformerEncoder",
"att_context_size": [-1, -1],
"causal_downsampling": False,
"conv_context_size": None,
"conv_kernel_size": 9,
"conv_norm_type": "batch_norm",
"d_model": 1024,
"dropout": 0.1,
"dropout_att": 0.1,
"dropout_emb": 0.0,
"dropout_pre_encoder": 0.1,
"feat_in": 128,
"feat_out": -1,
"ff_expansion_factor": 4,
"n_heads": 8,
"n_layers": 2,
"pos_emb_max_len": 5000,
"self_attention_model": "rel_pos",
"subsampling": "dw_striding",
"subsampling_conv_channels": 256,
"subsampling_factor": 8,
},
"modality_adapter": {
"_target_": "nemo.collections.speechlm2.modules.perception.IdentityConnector",
"d_model": 1024,
},
"preprocessor": {
"_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor",
"dither": 1e-05,
"features": 128,
"frame_splicing": 1,
"log": True,
"n_fft": 512,
"normalize": "per_feature",
"pad_to": 0,
"pad_value": 0.0,
"sample_rate": 16000,
"window": "hann",
"window_size": 0.025,
"window_stride": 0.01,
},
},
"optimizer": {"_target_": "torch.optim.AdamW"},
}
Expand Down
Loading