Commit ed6a84c

janekl authored and nasretdinovr committed
Cast to int for ONNX tracing (NVIDIA-NeMo#13782)
Signed-off-by: Jan Lasek <[email protected]>
1 parent 85778c5 commit ed6a84c

1 file changed: +1 -1 lines changed
nemo/collections/llm/gpt/model/hf_llama_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -237,7 +237,7 @@ def forward(

         fill_value = torch.tensor(float("-inf"), dtype=embeddings.dtype, device=embeddings.device)

-        clipped_dimensions = torch.clamp(dimensions, max=embeddings.shape[1])
+        clipped_dimensions = torch.clamp(dimensions, max=int(embeddings.shape[1]))

         embeddings = embeddings.masked_fill(
             torch.arange(embeddings.shape[1], device=embeddings.device) >= clipped_dimensions.unsqueeze(-1),
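
To illustrate the changed line, here is a minimal standalone sketch of the masking step in PyTorch. The function name `mask_beyond_dimensions` is hypothetical, and passing `fill_value` as the second argument to `masked_fill` is an assumption, since the diff is cut off before that argument. The `int(...)` cast hands `torch.clamp` a plain Python integer as its upper bound, which per the commit title is what ONNX tracing needs.

```python
import torch


def mask_beyond_dimensions(embeddings: torch.Tensor, dimensions: torch.Tensor) -> torch.Tensor:
    """Fill every column at or beyond each row's requested dimension with -inf.

    Hypothetical helper that mirrors the lines shown in the diff; it is not
    the NeMo function itself.
    """
    fill_value = torch.tensor(float("-inf"), dtype=embeddings.dtype, device=embeddings.device)

    # Cast the size to a Python int so clamp receives a plain number as its bound.
    clipped_dimensions = torch.clamp(dimensions, max=int(embeddings.shape[1]))

    # Boolean mask of shape (batch, hidden): True where the column index is
    # at or past the clipped per-row dimension.
    mask = torch.arange(embeddings.shape[1], device=embeddings.device) >= clipped_dimensions.unsqueeze(-1)

    # Assumption: fill_value is the second argument (not visible in the diff).
    return embeddings.masked_fill(mask, fill_value)


if __name__ == "__main__":
    emb = torch.arange(8, dtype=torch.float32).reshape(2, 4)
    dims = torch.tensor([2, 10])  # the second request exceeds the width and is clipped to 4
    print(mask_beyond_dimensions(emb, dims))
```

In eager mode the cast does not change the result, because `embeddings.shape[1]` is already a Python integer there.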