diff --git a/nemo/collections/common/prompts/canary2.py b/nemo/collections/common/prompts/canary2.py index 84b05dbfb253..c6ba2e2bc847 100644 --- a/nemo/collections/common/prompts/canary2.py +++ b/nemo/collections/common/prompts/canary2.py @@ -81,6 +81,16 @@ class Canary2PromptFormatter(PromptFormatter): "decodercontext": Modality.Text, }, }, + # User prompt. + # This role is used for injecting partial transcription for the current audio input. + # Use it as the last turn in the prompt to allow for resuming the transcription after a certain point. + # https://github.com/openai/whisper/discussions/117 + "user_prefix": { + "template": "|prefix|", + "slots": { + "prefix": Modality.Text, + }, + }, # System's reponse. OUTPUT_ROLE: { "template": f"|text|{CANARY_EOS}", diff --git a/tests/collections/common/prompt_formatters/conftest.py b/tests/collections/common/prompt_formatters/conftest.py index 0cbb729c424b..7d5126be8c01 100644 --- a/tests/collections/common/prompt_formatters/conftest.py +++ b/tests/collections/common/prompt_formatters/conftest.py @@ -74,3 +74,25 @@ def canary_tokenizer(bpe_tokenizer, tmp_path_factory): "en": bpe_tokenizer, } ) + + +@pytest.fixture(scope="session") +def canary2_tokenizer(bpe_tokenizer, tmp_path_factory): + tmpdir = tmp_path_factory.mktemp("spl_tokens_canary2") + spl_tokens = CanaryTokenizer.build_special_tokenizer( + [ + "startofcontext", + "en", + "emo:undefined", + "noitn", + "notimestamp", + "nodiarize", + ], + tmpdir, + ) + return CanaryTokenizer( + tokenizers={ + "spl_tokens": spl_tokens, + "en": bpe_tokenizer, + } + ) diff --git a/tests/collections/common/prompt_formatters/test_canary2_prompt_formatter.py b/tests/collections/common/prompt_formatters/test_canary2_prompt_formatter.py new file mode 100644 index 000000000000..ddec1ddd5b0f --- /dev/null +++ b/tests/collections/common/prompt_formatters/test_canary2_prompt_formatter.py @@ -0,0 +1,106 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.common.prompts.canary2 import Canary2PromptFormatter + + +def test_canary2_prompt_formatter_training(canary2_tokenizer): + formatter = Canary2PromptFormatter(canary2_tokenizer) + ans = formatter.encode_dialog( + [ + { + "role": "user", + "slots": { + "decodercontext": "", + "emotion": "<|emo:undefined|>", + "source_lang": "<|en|>", + "target_lang": "<|en|>", + "pnc": "<|pnc|>", + "itn": "<|noitn|>", + "timestamp": "<|notimestamp|>", + "diarize": "<|nodiarize|>", + "prompt_language": "spl_tokens", + }, + }, + {"role": "assistant", "slots": {"text": "TEST", "prompt_language": "en"}}, + ] + ) + assert set(ans) == {"input_ids", "context_ids", "answer_ids", "mask"} + # fmt: off + assert canary2_tokenizer.ids_to_text(ans["input_ids"].tolist()) == '<|startofcontext|><|startoftranscript|><|emo:undefined|><|en|><|en|><|pnc|><|noitn|><|notimestamp|><|nodiarize|> TEST<|endoftext|>' + assert canary2_tokenizer.ids_to_text(ans["context_ids"].tolist()) == '<|startofcontext|><|startoftranscript|><|emo:undefined|><|en|><|en|><|pnc|><|noitn|><|notimestamp|><|nodiarize|>' + assert canary2_tokenizer.ids_to_text(ans["answer_ids"].tolist()) == ' TEST<|endoftext|>' + assert ans["mask"].shape[0] == ans["input_ids"].shape[0] + # fmt: on + + +def test_canary2_prompt_formatter_inference(canary2_tokenizer): + formatter = Canary2PromptFormatter(canary2_tokenizer) + ans = formatter.encode_dialog( + [ + { + "role": "user", + 
"slots": { + "decodercontext": "", + "emotion": "<|emo:undefined|>", + "source_lang": "<|en|>", + "target_lang": "<|en|>", + "pnc": "<|pnc|>", + "itn": "<|noitn|>", + "timestamp": "<|notimestamp|>", + "diarize": "<|nodiarize|>", + "prompt_language": "spl_tokens", + }, + }, + ] + ) + assert set(ans) == {"input_ids", "context_ids"} + # fmt: off + assert ans["input_ids"].tolist() == ans["context_ids"].tolist() + assert canary2_tokenizer.ids_to_text(ans["input_ids"].tolist()) == '<|startofcontext|><|startoftranscript|><|emo:undefined|><|en|><|en|><|pnc|><|noitn|><|notimestamp|><|nodiarize|>' + # fmt: on + + +def test_canary2_prompt_formatter_inference_prefix(canary2_tokenizer): + formatter = Canary2PromptFormatter(canary2_tokenizer) + ans = formatter.encode_dialog( + [ + { + "role": "user", + "slots": { + "decodercontext": "", + "emotion": "<|emo:undefined|>", + "source_lang": "<|en|>", + "target_lang": "<|en|>", + "pnc": "<|pnc|>", + "itn": "<|noitn|>", + "timestamp": "<|notimestamp|>", + "diarize": "<|nodiarize|>", + "prompt_language": "spl_tokens", + }, + }, + { + "role": "user_prefix", + "slots": { + "prefix": "TEST", + "prompt_language": "en", + }, + }, + ] + ) + assert set(ans) == {"input_ids", "context_ids"} + # fmt: off + assert ans["input_ids"].tolist() == ans["context_ids"].tolist() + assert canary2_tokenizer.ids_to_text(ans["input_ids"].tolist()) == '<|startofcontext|><|startoftranscript|><|emo:undefined|><|en|><|en|><|pnc|><|noitn|><|notimestamp|><|nodiarize|> TEST' + # fmt: on