diff --git a/optimum/habana/transformers/models/clip/modeling_clip.py b/optimum/habana/transformers/models/clip/modeling_clip.py index 0d9a16c375..6f045669c4 100644 --- a/optimum/habana/transformers/models/clip/modeling_clip.py +++ b/optimum/habana/transformers/models/clip/modeling_clip.py @@ -132,10 +132,8 @@ def forward( attn_output = self.bmm2(attn_weights, values) - # attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() attn_output = self.out_proj(attn_output) if not output_attentions: