-
Notifications
You must be signed in to change notification settings - Fork 31.3k
Adds CLIP to models exportable with ONNX #18515
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 18 commits
cf55224
845d2e2
8418aee
eae965e
118ca35
4a22e42
7204e1d
b108224
3a8d870
bfed078
2306fbc
19e0423
af3e2fc
32295b5
3737ec2
82d4a1b
1b30df9
933b12e
0f7a95a
3028908
7663f29
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -68,7 +68,8 @@ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: | |
|
|
||
| def clip_loss(similarity: torch.Tensor) -> torch.Tensor: | ||
| caption_loss = contrastive_loss(similarity) | ||
| image_loss = contrastive_loss(similarity.T) | ||
| # .T doesn't work while converting to ONNX because the aten::numpy_T operator is not supported yet | ||
| image_loss = contrastive_loss(similarity.t()) | ||
| return (caption_loss + image_loss) / 2.0 | ||
|
|
||
|
|
||
|
|
@@ -660,7 +661,10 @@ def forward( | |
|
|
||
| # text_embeds.shape = [batch_size, sequence_length, transformer.width] | ||
| # take features from the eot embedding (eot_token is the highest number in each sequence) | ||
| pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)] | ||
| # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 | ||
| pooled_output = last_hidden_state[ | ||
| torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would add a comment here to say that the cast
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure, added a comment and pushed
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually ONNX does support
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah, my bad. updated the comment and pushed |
||
| ] | ||
|
|
||
| if not return_dict: | ||
| return (last_hidden_state, pooled_output) + encoder_outputs[1:] | ||
|
|
@@ -1050,7 +1054,8 @@ def forward( | |
| # cosine similarity as logits | ||
| logit_scale = self.logit_scale.exp() | ||
| logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale | ||
| logits_per_image = logits_per_text.T | ||
| # .T doesn't work while converting to ONNX because the aten::numpy_T operator is not supported yet | ||
| logits_per_image = logits_per_text.t() | ||
|
|
||
| loss = None | ||
| if return_loss: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.