From c55a0ad9af4bae5386535f68a3ab94cdc9b34a09 Mon Sep 17 00:00:00 2001 From: Jan Lasek Date: Fri, 30 May 2025 17:24:37 +0200 Subject: [PATCH] Cast to int for ONNX tracing Signed-off-by: Jan Lasek --- nemo/collections/llm/gpt/model/hf_llama_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/model/hf_llama_embedding.py b/nemo/collections/llm/gpt/model/hf_llama_embedding.py index 3d64b2d32d37..db0e3229c708 100755 --- a/nemo/collections/llm/gpt/model/hf_llama_embedding.py +++ b/nemo/collections/llm/gpt/model/hf_llama_embedding.py @@ -237,7 +237,7 @@ def forward( fill_value = torch.tensor(float("-inf"), dtype=embeddings.dtype, device=embeddings.device) - clipped_dimensions = torch.clamp(dimensions, max=embeddings.shape[1]) + clipped_dimensions = torch.clamp(dimensions, max=int(embeddings.shape[1])) embeddings = embeddings.masked_fill( torch.arange(embeddings.shape[1], device=embeddings.device) >= clipped_dimensions.unsqueeze(-1),