diff --git a/nemo/collections/llm/gpt/model/hf_llama_embedding.py b/nemo/collections/llm/gpt/model/hf_llama_embedding.py
index 3d64b2d32d37..db0e3229c708 100755
--- a/nemo/collections/llm/gpt/model/hf_llama_embedding.py
+++ b/nemo/collections/llm/gpt/model/hf_llama_embedding.py
@@ -237,7 +237,7 @@ def forward(
         fill_value = torch.tensor(float("-inf"), dtype=embeddings.dtype, device=embeddings.device)
 
-        clipped_dimensions = torch.clamp(dimensions, max=embeddings.shape[1])
+        clipped_dimensions = torch.clamp(dimensions, max=int(embeddings.shape[1]))
 
         embeddings = embeddings.masked_fill(
             torch.arange(embeddings.shape[1], device=embeddings.device) >= clipped_dimensions.unsqueeze(-1),
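
For context, here is a minimal standalone sketch of the dimension-truncation masking this hunk touches. The names `embeddings` and `dimensions` come from the diff; the tensor values and surrounding setup are illustrative, not from the source. The change wraps `embeddings.shape[1]` in `int()`, presumably so that `torch.clamp`'s `max=` argument receives a plain Python number rather than a symbolic size (e.g. under tracing or `torch.compile` with dynamic shapes):

```python
import torch

# Illustrative inputs (not from the source file).
embeddings = torch.randn(2, 8)      # (batch, embedding_dim)
dimensions = torch.tensor([4, 16])  # requested output dims per sample

# Cast shape[1] to a plain int so clamp's `max=` gets a Python number.
clipped_dimensions = torch.clamp(dimensions, max=int(embeddings.shape[1]))

# Fill every position at or beyond each sample's clipped dimension with -inf.
fill_value = torch.tensor(float("-inf"), dtype=embeddings.dtype, device=embeddings.device)
embeddings = embeddings.masked_fill(
    torch.arange(embeddings.shape[1], device=embeddings.device) >= clipped_dimensions.unsqueeze(-1),
    fill_value,
)
```

In eager mode both forms behave the same, since `shape[1]` is already an `int`; the explicit cast only matters when the size is symbolic.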