Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions nemo/export/tarutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,3 @@ def keys(self):
Returns an iterator over the keys in the store.
"""
return self._path.iterdir()


def unpack_tarball(archive: str, dest_dir: str) -> None:
    """
    Unpacks a tarball into a destination directory.

    Members are validated before extraction so that a crafted archive cannot
    write outside ``dest_dir`` (e.g. via ``../`` components, absolute paths or
    links pointing outside the destination).

    Args:
        archive (str): The path to the tarball.
        dest_dir (str): The path to the destination directory.

    Raises:
        tarfile.TarError: If the archive cannot be read or contains a member
            that would be extracted outside of ``dest_dir``.
    """
    import os  # local import: keeps this block self-contained

    with tarfile.open(archive, mode="r") as tar:
        if hasattr(tarfile, "data_filter"):
            # Interpreters with extraction filters: let the stdlib reject
            # unsafe members (path traversal, escaping links, device files).
            tar.extractall(path=dest_dir, filter="data")
            return

        # Older interpreters without extraction filters: validate by hand.
        dest_root = os.path.realpath(dest_dir)

        def _escapes(path: str) -> bool:
            resolved = os.path.realpath(path)
            return os.path.commonpath([dest_root, resolved]) != dest_root

        members = tar.getmembers()
        for member in members:
            target = os.path.join(dest_root, member.name)
            if _escapes(target):
                raise tarfile.TarError(f"Refusing to extract {member.name!r}: path is outside of {dest_dir!r}")
            if member.isdev():
                raise tarfile.TarError(f"Refusing to extract {member.name!r}: device files are not allowed")
            if member.issym():
                # Symlink targets are relative to the directory holding the link.
                link_target = os.path.join(os.path.dirname(target), member.linkname)
            elif member.islnk():
                # Hard-link targets are relative to the archive root.
                link_target = os.path.join(dest_root, member.linkname)
            else:
                continue
            if os.path.isabs(member.linkname) or _escapes(link_target):
                raise tarfile.TarError(f"Refusing to extract {member.name!r}: link points outside of {dest_dir!r}")

        tar.extractall(path=dest_dir, members=members)
4 changes: 1 addition & 3 deletions nemo/export/tensorrt_mm_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
extract_lora_ckpt,
)
from nemo.export.multimodal.run import MultimodalModelRunner, SpeechllmModelRunner
from nemo.export.tarutils import unpack_tarball

use_deploy = True
try:
Expand Down Expand Up @@ -152,8 +151,7 @@ def export(
if os.path.isdir(lora_checkpoint_path):
lora_dir = lora_checkpoint_path
else:
lora_dir = os.path.join(tmp_dir.name, "unpacked_lora")
unpack_tarball(lora_checkpoint_path, lora_dir)
raise ValueError("lora_checkpoint_path in nemo1 is not supported. It must be a directory")

llm_lora_path = [extract_lora_ckpt(lora_dir, tmp_dir.name)]
else:
Expand Down
49 changes: 24 additions & 25 deletions nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import glob
import itertools
import os
import subprocess
import warnings
Expand Down Expand Up @@ -78,42 +79,40 @@ def qnemo_to_tensorrt_llm(

speculative_decoding_mode = "medusa" if "Medusa" in config.architecture else None

build_cmd = "trtllm-build "
build_cmd += f"--checkpoint_dir {nemo_checkpoint_path} "
build_cmd += f"--log_level {log_level} "
build_cmd += f"--output_dir {engine_dir} "
build_cmd += f"--workers {num_build_workers} "
build_cmd += f"--max_batch_size {max_batch_size} "
build_cmd += f"--max_input_len {max_input_len} "
build_cmd += f"--max_beam_width {max_beam_width} "
build_cmd += f"--max_prompt_embedding_table_size {max_prompt_embedding_table_size} "
build_cmd += f"--paged_kv_cache {'enable' if paged_kv_cache else 'disable'} "
build_cmd += f"--use_paged_context_fmha {'enable' if paged_context_fmha else 'disable'} "
build_cmd += f"--remove_input_padding {'enable' if remove_input_padding else 'disable'} "
build_cmd += f"--multiple_profiles {'enable' if multiple_profiles else 'disable'} "
build_cmd += f"--reduce_fusion {'enable' if reduce_fusion else 'disable'} "
build_cmd += f"--use_fused_mlp {'enable' if use_fused_mlp else 'disable'} "
build_cmd = ["trtllm-build"]
build_cmd.extend(["--checkpoint_dir", nemo_checkpoint_path])
build_cmd.extend(["--log_level", log_level])
build_cmd.extend(["--output_dir", engine_dir])
build_cmd.extend(["--workers", str(num_build_workers)])
build_cmd.extend(["--max_batch_size", str(max_batch_size)])
build_cmd.extend(["--max_input_len", str(max_input_len)])
build_cmd.extend(["--max_beam_width", str(max_beam_width)])
build_cmd.extend(["--max_prompt_embedding_table_size", str(max_prompt_embedding_table_size)])
build_cmd.extend(["--paged_kv_cache", "enable" if paged_kv_cache else "disable"])
build_cmd.extend(["--use_paged_context_fmha", "enable" if paged_context_fmha else "disable"])
build_cmd.extend(["--remove_input_padding", "enable" if remove_input_padding else "disable"])
build_cmd.extend(["--multiple_profiles", "enable" if multiple_profiles else "disable"])
build_cmd.extend(["--reduce_fusion", "enable" if reduce_fusion else "disable"])
build_cmd.extend(["--use_fused_mlp", "enable" if use_fused_mlp else "disable"])

if not use_qdq:
build_cmd += "--gemm_plugin auto "
build_cmd.extend(["--gemm_plugin", "auto"])

if max_seq_len is not None:
build_cmd += f"--max_seq_len {max_seq_len} "
build_cmd.extend(["--max_seq_len", str(max_seq_len)])

if max_num_tokens is not None:
build_cmd += f"--max_num_tokens {max_num_tokens} "
build_cmd.extend(["--max_num_tokens", str(max_num_tokens)])
else:
build_cmd += f"--max_num_tokens {max_batch_size * max_input_len} "
build_cmd.extend(["--max_num_tokens", str(max_batch_size * max_input_len)])

if opt_num_tokens is not None:
build_cmd += f"--opt_num_tokens {opt_num_tokens} "
build_cmd.extend(["--opt_num_tokens", str(opt_num_tokens)])

if speculative_decoding_mode:
build_cmd += f"--speculative_decoding_mode {speculative_decoding_mode} "

build_cmd = build_cmd.replace("--", "\\\n --") # Separate parameters line by line
build_cmd.extend(["--speculative_decoding_mode", speculative_decoding_mode])

print("trtllm-build command:")
print(build_cmd)
print("".join(itertools.chain.from_iterable(zip(build_cmd, itertools.cycle(["\n ", " "])))).strip())

subprocess.run(build_cmd, shell=True, check=True)
subprocess.run(build_cmd, shell=False, check=True)
Loading