Skip to content

Commit b0b5151

Browse files
pzelasko and chtruong814
authored and committed
Add Nemotron-H prompt format, fix cut-to-conversation custom attr propagation (NVIDIA-NeMo#13963)
* Add Nemotron-H prompt format
  Signed-off-by: Piotr Żelasko <[email protected]>
* Fix propagation of custom attr in cut_to_conversation
  Signed-off-by: Piotr Żelasko <[email protected]>
* Fix CI
  Signed-off-by: Piotr Żelasko <[email protected]>
* Unit test for the conversion fix
  Signed-off-by: Piotr Żelasko <[email protected]>
---------
Signed-off-by: Piotr Żelasko <[email protected]>
Co-authored-by: Charlie Truong <[email protected]>
Signed-off-by: Amir Hussein <[email protected]>
1 parent cf6e859 commit b0b5151

File tree

10 files changed

+283
-6
lines changed

10 files changed

+283
-6
lines changed

nemo/collections/common/data/lhotse/cutset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,8 @@ def cut_to_conversation(
478478
]
479479
if hasattr(cut, "context"):
480480
turns = [TextTurn(value=cut.context, role="user")] + turns
481+
if hasattr(cut, "system_prompt"):
482+
turns = [TextTurn(value=cut.system_prompt, role="system")] + turns
481483
return NeMoMultimodalConversation(
482484
id=cut.id,
483485
turns=turns,
@@ -489,6 +491,10 @@ def cut_to_conversation(
489491
@data_type_parser(["lhotse_as_conversation"])
490492
def read_lhotse_as_conversation(config) -> tuple[CutSet, bool]:
491493
cuts, is_tarred = read_cutset_from_config(config)
494+
# Attach extra tags to every utterance dynamically, if provided.
495+
# We need to attach them before cuts are converted to conversations.
496+
if (extra_tags := config.get("tags")) is not None:
497+
cuts = cuts.map(partial(attach_tags, tags=extra_tags), apply_fn=None)
492498
cuts = cuts.map(
493499
partial(
494500
cut_to_conversation,

nemo/collections/common/prompts/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from nemo.collections.common.prompts.gemma import GemmaPromptFormatter
1919
from nemo.collections.common.prompts.llama import Llama2PromptFormatter, Llama3PromptFormatter
2020
from nemo.collections.common.prompts.mistral import MistralPromptFormatter
21+
from nemo.collections.common.prompts.nemotron_h import NemotronHPromptFormatter
2122
from nemo.collections.common.prompts.phi2 import (
2223
Phi2ChatPromptFormatter,
2324
Phi2CodePromptFormatter,
nemo/collections/common/prompts/nemotron_h.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# pylint: disable=missing-function-docstring,missing-class-docstring
15+
from lhotse.cut import Cut, MixedCut
16+
17+
from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn
18+
from nemo.collections.common.prompts.formatter import Modality, PromptFormatter
19+
20+
SYSTEM_BOS = "<SPECIAL_10>"
21+
TURN_BOS = "<SPECIAL_11>"
22+
23+
24+
class NemotronHPromptFormatter(PromptFormatter):
    """Prompt formatter registered under the name ``nemotron-h``.

    Three roles are defined, each with a single text slot called ``message``:

    * ``system`` -- its template starts with ``SYSTEM_BOS`` followed by ``System``.
    * ``user`` -- its template starts with a newline, ``TURN_BOS`` and ``User``.
    * ``assistant`` (the output role) -- its template starts with
      ``INFERENCE_PREFIX``, i.e. a newline, ``TURN_BOS`` and ``Assistant``.
    """

    NAME = "nemotron-h"
    OUTPUT_ROLE = "assistant"
    # Opening of the assistant turn; reused verbatim as the head of the assistant template.
    INFERENCE_PREFIX = f"\n{TURN_BOS}Assistant\n"
    TEMPLATE = {
        "system": {
            "template": f"{SYSTEM_BOS}System\n|message|",
            "slots": {"message": Modality.Text},
        },
        "user": {
            "template": f"\n{TURN_BOS}User\n|message|",
            "slots": {"message": Modality.Text},
        },
        OUTPUT_ROLE: {
            "template": f"{INFERENCE_PREFIX}|message|",
            "slots": {"message": Modality.Text},
        },
    }
48+
49+
50+
@registered_prompt_format_fn(Cut, NemotronHPromptFormatter)
def nemotron_h(cut: Cut, prompt: NemotronHPromptFormatter):
    """Turn a single cut into a Nemotron-H dialog and encode it with ``prompt``.

    The dialog always contains a system turn and a user turn. Their text is
    taken from the cut's custom ``system_prompt`` and ``context`` fields, and
    is an empty string when the field is absent. An assistant turn is added
    only when the text of the first supervision is not ``None``.

    Args:
        cut: The cut to convert. For a ``MixedCut`` the first non-padding
            cut is used instead.
        prompt: Formatter used to encode the assembled turns.

    Returns:
        Whatever ``prompt.encode_dialog`` returns for the assembled turns.
    """
    source = cut.first_non_padding_cut if isinstance(cut, MixedCut) else cut

    system_text = source.system_prompt if source.has_custom("system_prompt") else ""
    user_text = source.context if source.has_custom("context") else ""
    dialog = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]

    # A missing transcript means there is no assistant turn to append.
    transcript = source.supervisions[0].text
    if transcript is not None:
        dialog.append({"role": "assistant", "content": transcript})

    return prompt.encode_dialog(dialog)

tests/collections/common/prompt_formatters/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
Feel free to add new tokens for your own tests!?
4242
But know that if you do so, you may need to update the token IDs in the existing tests!
4343
So, it might be a good idea to create a new tokenizer instead when adding new prompt formats.
44+
SYSTEM
4445
"""
4546

4647

@@ -58,7 +59,7 @@ def bpe_tokenizer(tmp_path_factory):
5859
remove_extra_whitespaces=True,
5960
bos=True,
6061
eos=True,
61-
user_defined_symbols=['\n', '<|im_start|>', '<|im_end|>'],
62+
user_defined_symbols=['\n', '<|im_start|>', '<|im_end|>', '<SPECIAL_10>', '<SPECIAL_11>'],
6263
)
6364
return SentencePieceTokenizer(str(tmpdir / "tokenizer.model"))
6465

tests/collections/common/prompt_formatters/test_canary_prompt_formatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_canary_prompt_formatter_training(canary_tokenizer):
3737
assert canary_tokenizer.ids_to_text(ans["input_ids"].tolist()) == '<|startoftranscript|><|en|><|transcribe|><|en|><|pnc|> TEST<|endoftext|>'
3838
assert canary_tokenizer.ids_to_text(ans["context_ids"].tolist()) == '<|startoftranscript|><|en|><|transcribe|><|en|><|pnc|>'
3939
assert canary_tokenizer.ids_to_text(ans["answer_ids"].tolist()) == ' TEST<|endoftext|>'
40-
assert ans["mask"].tolist() == [False] * 5 + [True] * 5
40+
assert ans["mask"].shape[0] == ans["input_ids"].shape[0]
4141
# fmt: on
4242

4343

tests/collections/common/prompt_formatters/test_gemma_prompt_formatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_gemma_prompt_formatter_training(bpe_tokenizer):
2828
assert bpe_tokenizer.ids_to_text(ans["input_ids"].tolist()) == '<start_of_turn>user\nTEST<end_of_turn>\n<start_of_turn>model\n TEST<end_of_turn>\n'
2929
assert bpe_tokenizer.ids_to_text(ans["context_ids"].tolist()) == '<start_of_turn>user\nTEST<end_of_turn>\n<start_of_turn>model\n'
3030
assert bpe_tokenizer.ids_to_text(ans["answer_ids"].tolist()) == 'TEST<end_of_turn>\n'
31-
assert ans["mask"].tolist() == [False] * 39 + [True] * 15
31+
assert ans["mask"].shape[0] == ans["input_ids"].shape[0]
3232
# fmt: on
3333

3434

tests/collections/common/prompt_formatters/test_llama2_prompt_formatter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_llama2_prompt_formatter_training(bpe_tokenizer):
2828
assert bpe_tokenizer.ids_to_text(ans["input_ids"].tolist()[1:-1]) == '[INST] TEST [/INST] TEST'
2929
assert bpe_tokenizer.ids_to_text(ans["context_ids"].tolist()[1:]) == '[INST] TEST [/INST]'
3030
assert bpe_tokenizer.ids_to_text(ans["answer_ids"].tolist()[:-1]) == 'TEST'
31-
assert ans["mask"].tolist() == [False] * 16 + [True] * 5
31+
assert ans["mask"].shape[0] == ans["input_ids"].shape[0]
3232
# fmt: on
3333

3434

@@ -59,7 +59,7 @@ def test_llama2_prompt_formatter_training_with_system(bpe_tokenizer):
5959
assert bpe_tokenizer.ids_to_text(ans["input_ids"].tolist()[1:-1]) == '[INST] <<SYS>>\nTEST\n<</SYS>>\n\nTEST [/INST] TEST'
6060
assert bpe_tokenizer.ids_to_text(ans["context_ids"].tolist()[1:]) == '[INST] <<SYS>>\nTEST\n<</SYS>>\n\nTEST [/INST]'
6161
assert bpe_tokenizer.ids_to_text(ans["answer_ids"].tolist()[:-1]) == 'TEST'
62-
assert ans["mask"].tolist() == [False] * 36 + [True] * 5
62+
assert ans["mask"].shape[0] == ans["input_ids"].shape[0]
6363
# fmt: on
6464

6565

tests/collections/common/prompt_formatters/test_mistral_prompt_formatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_mistral_prompt_formatter_training(bpe_tokenizer):
2828
assert bpe_tokenizer.ids_to_text(ans["input_ids"].tolist()) == '<s> [INST] TEST [/INST] TEST</s>'
2929
assert bpe_tokenizer.ids_to_text(ans["context_ids"].tolist()) == '<s> [INST] TEST [/INST]'
3030
assert bpe_tokenizer.ids_to_text(ans["answer_ids"].tolist()) == 'TEST</s>'
31-
assert ans["mask"].tolist() == [False] * 18 + [True] * 8
31+
assert ans["mask"].shape[0] == ans["input_ids"].shape[0]
3232
# fmt: on
3333

3434

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from nemo.collections.common.prompts.nemotron_h import NemotronHPromptFormatter
16+
17+
18+
def test_nemotronh_prompt_formatter_training(bpe_tokenizer):
    """Encode system(empty)/user/assistant turns and check all four output fields."""
    dialog = [
        {"role": "system", "slots": {"message": ""}},
        {"role": "user", "slots": {"message": "TEST"}},
        {"role": "assistant", "slots": {"message": "TEST"}},
    ]
    encoded = NemotronHPromptFormatter(bpe_tokenizer).encode_dialog(dialog)
    assert set(encoded) == {"input_ids", "context_ids", "answer_ids", "mask"}

    def decoded(field):
        # Round-trip the token IDs of one output field back to text.
        return bpe_tokenizer.ids_to_text(encoded[field].tolist())

    # fmt: off
    assert decoded("input_ids") == '<SPECIAL_10>System\n \n<SPECIAL_11>User\nTEST \n<SPECIAL_11>Assistant\nTEST'
    assert decoded("context_ids") == '<SPECIAL_10>System\n \n<SPECIAL_11>User\nTEST'
    assert decoded("answer_ids") == '\n<SPECIAL_11>Assistant\nTEST'
    assert encoded["mask"].shape[0] == encoded["input_ids"].shape[0]
    # fmt: on
34+
35+
36+
def test_nemotronh_prompt_formatter_inference(bpe_tokenizer):
    """Encode system(empty)/user turns only; input and context IDs must coincide."""
    dialog = [
        {"role": "system", "slots": {"message": ""}},
        {"role": "user", "slots": {"message": "TEST"}},
    ]
    encoded = NemotronHPromptFormatter(bpe_tokenizer).encode_dialog(dialog)
    assert set(encoded) == {"input_ids", "context_ids"}

    prompt_ids = encoded["input_ids"].tolist()
    # fmt: off
    assert prompt_ids == encoded["context_ids"].tolist()
    # The first ID is dropped before decoding (presumably a BOS token -- confirm against the tokenizer fixture).
    assert bpe_tokenizer.ids_to_text(prompt_ids[1:]) == '<SPECIAL_10>System\n \n<SPECIAL_11>User\nTEST \n<SPECIAL_11>Assistant\n'
    # fmt: on
49+
50+
51+
def test_nemotronh_prompt_formatter_training_with_system(bpe_tokenizer):
    """Encode system("SYSTEM")/user/assistant turns and check all four output fields."""
    dialog = [
        {"role": "system", "slots": {"message": "SYSTEM"}},
        {"role": "user", "slots": {"message": "TEST"}},
        {"role": "assistant", "slots": {"message": "TEST"}},
    ]
    encoded = NemotronHPromptFormatter(bpe_tokenizer).encode_dialog(dialog)
    assert set(encoded) == {"input_ids", "context_ids", "answer_ids", "mask"}

    def decoded(field):
        # Round-trip the token IDs of one output field back to text.
        return bpe_tokenizer.ids_to_text(encoded[field].tolist())

    # fmt: off
    assert decoded("input_ids") == '<SPECIAL_10>System\nSYSTEM \n<SPECIAL_11>User\nTEST \n<SPECIAL_11>Assistant\nTEST'
    assert decoded("context_ids") == '<SPECIAL_10>System\nSYSTEM \n<SPECIAL_11>User\nTEST'
    assert decoded("answer_ids") == '\n<SPECIAL_11>Assistant\nTEST'
    assert encoded["mask"].shape[0] == encoded["input_ids"].shape[0]
    # fmt: on
67+
68+
69+
def test_nemotronh_prompt_formatter_inference_with_system(bpe_tokenizer):
    """Encode system("SYSTEM")/user turns only; input and context IDs must coincide."""
    dialog = [
        {"role": "system", "slots": {"message": "SYSTEM"}},
        {"role": "user", "slots": {"message": "TEST"}},
    ]
    encoded = NemotronHPromptFormatter(bpe_tokenizer).encode_dialog(dialog)
    assert set(encoded) == {"input_ids", "context_ids"}

    prompt_ids = encoded["input_ids"].tolist()
    # fmt: off
    assert prompt_ids == encoded["context_ids"].tolist()
    # The first ID is dropped before decoding (presumably a BOS token -- confirm against the tokenizer fixture).
    assert bpe_tokenizer.ids_to_text(prompt_ids[1:]) == '<SPECIAL_10>System\nSYSTEM \n<SPECIAL_11>User\nTEST \n<SPECIAL_11>Assistant\n'
    # fmt: on

tests/collections/common/test_lhotse_multimodal_dataloading.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from pathlib import Path
15+
1416
import lhotse
1517
import numpy as np
1618
import pytest
1719
import torch
20+
from lhotse import CutSet, SupervisionSegment
1821
from lhotse.testing.dummies import dummy_cut, dummy_recording
1922
from omegaconf import OmegaConf
2023

@@ -485,3 +488,117 @@ def test_multimodal_conversation_duration_filter():
485488
],
486489
)
487490
assert fltr(conv_s2s_7s) is False
491+
492+
493+
@pytest.fixture(scope="session")
def cutset_path(tmp_path_factory) -> Path:
    """Save a three-cut Lhotse CutSet to a temporary directory and return the manifest path.

    The cuts last 1s, 2s and 3s. The first has neither ``context`` nor
    ``system_prompt``, the second has only ``context``, and the third has both.
    Audio is written to an ``audio`` sub-directory next to the manifest.
    """
    plain_cut = dummy_cut(
        0,
        duration=1.0,
        supervisions=[SupervisionSegment("e1", "e1", 0.0, 1.0, text="transcript")],
        with_data=True,
    )
    context_cut = dummy_cut(
        1,
        duration=2.0,
        recording_duration=2.0,
        supervisions=[SupervisionSegment("e2", "e2", 0.0, 2.0, text="context and transcript")],
        with_data=True,
    )
    system_cut = dummy_cut(
        2,
        duration=3.0,
        recording_duration=3.0,
        supervisions=[SupervisionSegment("e3", "e3", 0.0, 2.0, text="system context and transcript")],
        with_data=True,
    )
    cuts = CutSet([plain_cut, context_cut, system_cut])

    # Custom fields are attached after the CutSet is built, by index.
    cuts[1].context = "some prompt"
    cuts[2].context = "other prompt"
    cuts[2].system_prompt = "system prompt"

    workdir = tmp_path_factory.mktemp("data")
    manifest = workdir / "cuts.jsonl.gz"
    audio_dir = workdir / "audio"
    # Persist the audio, then drop in-memory data before writing the manifest.
    stored = cuts.save_audios(audio_dir)
    stored.drop_in_memory_data().to_file(manifest)
    return manifest
529+
530+
531+
def test_cut_to_conversation_conversion(cutset_path, tokenizer):
    """Load cuts through ``lhotse_as_conversation`` and verify the resulting conversations.

    Three cases are checked: a cut with no extra fields, a cut with only
    ``context``, and a cut with both ``context`` and ``system_prompt``.
    In each case the configured ``tags`` must appear on the conversation
    and on the cut held by its audio turn.
    """
    reference_cuts = CutSet.from_file(cutset_path)
    input_spec = {
        "type": "lhotse_as_conversation",
        "cuts_path": cutset_path,
        "audio_locator_tag": "[audio]",
        "tags": {"test_key": "test_value"},
    }
    config = OmegaConf.create(
        {
            "input_cfg": [input_spec],
            "token_equivalent_duration": 0.08,
            "prompt_format": "llama3",
            "force_finite": True,
            "num_workers": 0,
            "batch_size": 4,
            "seed": 0,
            "shard_seed": 0,
        }
    )
    loader = get_lhotse_dataloader_from_config(
        config=config, global_rank=0, world_size=1, dataset=Identity(), tokenizer=tokenizer
    )
    batches = list(loader)
    assert len(batches) == 1
    batch = batches[0]

    # Case 1: no 'context' and no 'system_prompt' -> audio turn + assistant turn.
    conv = batch[0]
    assert isinstance(conv, NeMoMultimodalConversation)
    assert conv.id == reference_cuts[0].id
    assert len(conv.turns) == 2
    audio, reply = conv.turns
    assert isinstance(audio, AudioTurn)
    assert audio.role == "user"
    assert isinstance(reply, TextTurn)
    assert reply.role == "assistant"
    assert reply.value == "transcript"
    assert conv.custom["test_key"] == "test_value"
    assert audio.cut.custom["test_key"] == "test_value"

    # Case 2: only 'context' -> a user text turn precedes the audio turn.
    conv = batch[1]
    assert isinstance(conv, NeMoMultimodalConversation)
    assert conv.id == reference_cuts[1].id
    assert len(conv.turns) == 3
    context, audio, reply = conv.turns
    assert isinstance(context, TextTurn)
    assert context.role == "user"
    assert context.value == "some prompt"
    assert isinstance(audio, AudioTurn)
    assert audio.role == "user"
    assert isinstance(reply, TextTurn)
    assert reply.role == "assistant"
    assert reply.value == "context and transcript"
    assert conv.custom["test_key"] == "test_value"
    assert audio.cut.custom["test_key"] == "test_value"

    # Case 3: both 'context' and 'system_prompt' -> the system turn comes first.
    conv = batch[2]
    assert isinstance(conv, NeMoMultimodalConversation)
    assert conv.id == reference_cuts[2].id
    assert len(conv.turns) == 4
    system, context, audio, reply = conv.turns
    assert isinstance(system, TextTurn)
    assert system.role == "system"
    assert system.value == "system prompt"
    assert isinstance(context, TextTurn)
    assert context.role == "user"
    assert context.value == "other prompt"
    assert isinstance(audio, AudioTurn)
    assert audio.role == "user"
    assert isinstance(reply, TextTurn)
    assert reply.role == "assistant"
    assert reply.value == "system context and transcript"
    assert conv.custom["test_key"] == "test_value"
    assert audio.cut.custom["test_key"] == "test_value"

0 commit comments

Comments
 (0)