Add model and architecture for Omnivore model #43
Changes from 6 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| library.jpg | ||
| imagenet_class_index.json |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,280 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "59ef9a30-174d-4ea3-b4a0-7c465e4c3f53", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import torch\n", | ||
| "import torchvision.transforms as T\n", | ||
| "import torchmultimodal.models.omnivore as omnivore\n", | ||
| "\n", | ||
| "from PIL import Image\n", | ||
| "import collections\n", | ||
| "import json\n", | ||
| "import matplotlib.pyplot as plt\n", | ||
| "import numpy as np\n" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "55e221ba-528e-4f09-9d49-d0f54a577bc0", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "def count_parameters(model):\n", | ||
| " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", | ||
| "\n", | ||
| "def custom_load_state_dict(model, pretrained_state_dict):\n", | ||
| " # Convert the pretrained_state_dict so it have the same keys as the model\n", | ||
| " # then load the value of the weight into the model\n", | ||
| " pretrained_keys = list(pretrained_state_dict.keys())\n", | ||
| " model_keys = list(model.state_dict().keys())\n", | ||
| " key_mapping = {pretrained_keys[i]: model_keys[i] for i in range(len(model_keys))}\n", | ||
| " updated_pretrained_state_dict = collections.OrderedDict({key_mapping[key]: val for key, val in pretrained_state_dict.items()})\n", | ||
| " model.load_state_dict(updated_pretrained_state_dict)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "67b998df-4c99-4f18-a408-a1feef5c483a", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Load model from torch_hub\n", | ||
| "mhub = torch.hub.load(\"facebookresearch/omnivore:main\", model=\"omnivore_swinT\")\n", | ||
| "mhub.eval()\n", | ||
| "print(count_parameters(mhub))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "85ea7454-f1c1-449f-a33b-004c54a4ae8b", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "5d3de763-8be6-450d-958b-0a01d4dc8b3b", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "m = omnivore.omnivore_swin_t()\n", | ||
| "\n", | ||
| "# Check that it have same number of parameter\n", | ||
| "print(count_parameters(m))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "a1168d33-0d11-46df-bef4-51f1973c2f95", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "custom_load_state_dict(m, mhub.state_dict())\n", | ||
| "m = m.eval()\n" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "f1d192ad-4902-49e1-baf6-77e7474a33bf", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "# Inference test" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "fe6345f3-97b5-46cd-b068-aea129aedde4", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Download imagenet class and image\n", | ||
| "# Uncomment to download\n", | ||
| "!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json -O imagenet_class_index.json\n", | ||
| "with open(\"imagenet_class_index.json\", \"r\") as f:\n", | ||
| " imagenet_classnames = json.load(f)\n", | ||
| "\n", | ||
| "# Create an id to label name mapping\n", | ||
| "imagenet_id_to_classname = {}\n", | ||
| "for k, v in imagenet_classnames.items():\n", | ||
| " imagenet_id_to_classname[k] = v[1] \n", | ||
| "\n", | ||
| "# Download the example image file\n", | ||
| "# Uncomment to download\n", | ||
| "!wget -O library.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/13-11-02-olb-by-RalfR-03.jpg/800px-13-11-02-olb-by-RalfR-03.jpg\n", | ||
| "\n", | ||
| "image_path = \"library.jpg\"\n", | ||
| "image_pil = Image.open(image_path).convert(\"RGB\")\n", | ||
| "plt.figure(figsize=(6, 6))\n", | ||
| "plt.imshow(image_pil)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "19d91191-d461-4ae7-93cf-f8eb325133be", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "image_transform = T.Compose(\n", | ||
| " [\n", | ||
| " T.Resize(224),\n", | ||
| " T.CenterCrop(224),\n", | ||
| " T.ToTensor(),\n", | ||
| " T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", | ||
| " ]\n", | ||
| ")\n", | ||
| "image = image_transform(image_pil) # C H W\n", | ||
| "\n", | ||
| "# Adding batch and time (D) dimension\n", | ||
| "image = image.unsqueeze(0).unsqueeze(2) # B C D H W" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "22a4b65b-e9ce-48e3-9f1b-e79acb5a38f5", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "def infer(model):\n", | ||
| " with torch.no_grad():\n", | ||
| " prediction = model(image, input_type=\"image\")\n", | ||
| " pred_classes = prediction.topk(k=5).indices\n", | ||
| "\n", | ||
| " pred_class_names = [imagenet_id_to_classname[str(i.item())] for i in pred_classes[0]]\n", | ||
| " print(\"Top 5 predicted labels: %s\" % \", \".join(pred_class_names))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "b4305dd1-64ef-4a0f-b817-f5efe5f23980", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Test both model to infer the same image and make sure the output classes are the same\n", | ||
| "infer(m)\n", | ||
| "infer(mhub)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "779c2578-1cba-478c-8909-f47330e1b376", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "# Make sure the output of the trunk / encoder are the same" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "2d85f5d7-f77d-47d3-8bec-3bfa51687953", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "m_feature = m.encoder(image)\n", | ||
| "mhub_feature = mhub.trunk(image)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "b5caf4aa-b964-4969-b531-318b9721bf17", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# See the first 10 features are the same\n", | ||
| "m_feature.flatten()[:10], mhub_feature[0].flatten()[:10]" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "ebd36e03-c86a-4d4f-a83a-b8ff3bd757b9", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Make sure all the features are the same\n", | ||
| "np.all(np.array(m_feature == mhub_feature[0]))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "08adda52-d593-4d58-816a-4a6a6205ce3d", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "# Test on randomly generated input" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "1dcacbc6-5cbc-4e01-b0c9-8404a33f13da", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "mock_video = torch.randn(1, 3, 10, 112, 112)\n", | ||
| "\n", | ||
| "m_output = m(mock_video, input_type=\"video\")\n", | ||
| "mhub_output = mhub(mock_video, input_type=\"video\")\n", | ||
| "\n", | ||
| "np.all(np.array(m_output == mhub_output[0]))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "357fa8cd-a853-4b12-ad87-44dba43c133b", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "mock_depth = torch.randn(1, 4, 1, 112, 112)\n", | ||
| "\n", | ||
| "m_output = m(mock_video, input_type=\"rgbd\")\n", | ||
| "mhub_output = mhub(mock_video, input_type=\"rgbd\")\n", | ||
| "\n", | ||
| "np.all(np.array(m_output == mhub_output[0]))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "6bc03203-430f-4ab9-86a0-acbfadc89f67", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "kernelspec": { | ||
| "display_name": "Python 3 (ipykernel)", | ||
| "language": "python", | ||
| "name": "python3" | ||
| }, | ||
| "language_info": { | ||
| "codemirror_mode": { | ||
| "name": "ipython", | ||
| "version": 3 | ||
| }, | ||
| "file_extension": ".py", | ||
| "mimetype": "text/x-python", | ||
| "name": "python", | ||
| "nbconvert_exporter": "python", | ||
| "pygments_lexer": "ipython3", | ||
| "version": "3.9.10" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 5 | ||
| } |
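A note on the weight-loading approach in the notebook above: `custom_load_state_dict` pairs the pretrained keys with the model keys purely by position. A slightly more defensive variant (a sketch, not part of this PR) also checks that each paired tensor has the same shape before loading, so any mismatch between the two architectures fails loudly instead of silently loading the wrong weights:

```python
import collections

import torch


def custom_load_state_dict_checked(model: torch.nn.Module, pretrained_state_dict) -> None:
    # Pair the pretrained keys with the model keys by position, exactly as the
    # notebook's custom_load_state_dict does, but verify shapes before loading.
    model_state_dict = model.state_dict()
    pretrained_keys = list(pretrained_state_dict.keys())
    model_keys = list(model_state_dict.keys())
    if len(pretrained_keys) != len(model_keys):
        raise ValueError("state dicts have a different number of entries")

    remapped = collections.OrderedDict()
    for src_key, dst_key in zip(pretrained_keys, model_keys):
        src_tensor = pretrained_state_dict[src_key]
        dst_shape = model_state_dict[dst_key].shape
        if src_tensor.shape != dst_shape:
            raise ValueError(
                f"shape mismatch: {src_key} {tuple(src_tensor.shape)} "
                f"vs {dst_key} {tuple(dst_shape)}"
            )
        remapped[dst_key] = src_tensor
    model.load_state_dict(remapped)
```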
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -0,0 +1,33 @@ | ||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the BSD-style license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
|
|
||||
| import unittest | ||||
|
|
||||
| import torch | ||||
| import torchmultimodal.models.omnivore as omnivore | ||||
| from torchmultimodal.utils.common import get_current_device | ||||
|
|
||||
|
|
||||
| class TestOmnivoreModel(unittest.TestCase): | ||||
| def setUp(self): | ||||
| torch.manual_seed(42) | ||||
| self.device = get_current_device() | ||||
|
|
||||
| def test_omnivore_swin_t_forward(self): | ||||
| model = omnivore.omnivore_swin_t().to(self.device) | ||||
| self.assertTrue(isinstance(model, torch.nn.Module)) | ||||
|
|
||||
| image = torch.randn(1, 3, 1, 112, 112) # B C D H W | ||||
| image_score = model(image, input_type="image") | ||||
| self.assertEqual(image_score.size(), torch.Size((1, 1000))) | ||||
|
Reviewer: please use our test utility to assert on values and shapes. (we need to assert values too)
Author: Thanks for the suggestion! Will do.
Reviewer: Use the utility for assertion: Line 74 in e857541
Author: ah okay, will change
Author: @langong347 after I read the utility
(A sketch of such a value assertion is included after this file's diff.) |
||||
|
|
||||
| rgbd = torch.randn(1, 4, 1, 112, 112) | ||||
| rgbd_score = model(rgbd, input_type="rgbd") | ||||
| self.assertEqual(rgbd_score.size(), torch.Size((1, 19))) | ||||
|
|
||||
| video = torch.randn(1, 3, 4, 112, 112) | ||||
| video_score = model(video, input_type="video") | ||||
| self.assertEqual(video_score.size(), torch.Size((1, 400))) | ||||
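On the review comment above asking to assert on values as well as shapes: the repository's own test utility is only referenced by line number and is not reproduced here, but as a rough sketch of the idea, the value check that the encoder test below performs with `assertAlmostEqual` can be written with `torch.testing.assert_close`. The encoder configuration, seed, and the 268.00668 reference value are taken from that test; the tolerances are illustrative and the exact value may depend on the environment:

```python
import unittest

import torch
from torchmultimodal.modules.encoders.swin_transformer_3d_encoder import (
    PatchEmbedOmnivore,
    SwinTransformer3dEncoder,
)


class TestEncoderValueAssertion(unittest.TestCase):
    def setUp(self):
        # Same seed and configuration as the encoder test in this PR, so the
        # reference value used below applies.
        torch.manual_seed(42)
        patch_embed = PatchEmbedOmnivore(
            patch_size=(2, 4, 4), embed_dim=96, norm_layer=torch.nn.LayerNorm
        )
        self.encoder = SwinTransformer3dEncoder(
            embed_dim=96,
            depths=[2, 2, 6, 2],
            num_heads=[3, 6, 12, 24],
            window_size=(8, 7, 7),
            stochastic_depth_prob=0.2,
            norm_layer=torch.nn.LayerNorm,
            patch_embed=patch_embed,
        )

    def test_encoder_shape_and_values(self):
        image = torch.randn(1, 3, 1, 112, 112)  # B C D H W
        scores = self.encoder(image)

        # Assert on the shape ...
        self.assertEqual(scores.size(), torch.Size([1, 768]))

        # ... and on the values: assert_close compares the actual numbers within
        # tolerances, which is the "assert values too" the reviewer asked for.
        torch.testing.assert_close(
            scores.abs().sum().item(), 268.00668, rtol=0, atol=1e-3
        )
```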
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| from torchmultimodal.modules.encoders.swin_transformer_3d_encoder import ( | ||
| SwinTransformer3dEncoder, | ||
| PatchEmbedOmnivore, | ||
| ) | ||
| from torchmultimodal.utils.common import get_current_device | ||
|
|
||
|
|
||
| class TestSwinTransformer3dEncoder(unittest.TestCase): | ||
| def setUp(self): | ||
| torch.manual_seed(42) | ||
| self.device = get_current_device() | ||
|
|
||
| # Setup Encoder to test | ||
| embed_dim = 96 | ||
| norm_layer = torch.nn.LayerNorm | ||
| patch_embed = PatchEmbedOmnivore( | ||
| patch_size=(2, 4, 4), | ||
| embed_dim=embed_dim, | ||
| norm_layer=norm_layer, | ||
| ) | ||
| self.encoder = SwinTransformer3dEncoder( | ||
| embed_dim=embed_dim, | ||
| depths=[2, 2, 6, 2], | ||
| num_heads=[3, 6, 12, 24], | ||
| window_size=(8, 7, 7), | ||
| stochastic_depth_prob=0.2, | ||
| norm_layer=norm_layer, | ||
| patch_embed=patch_embed, | ||
| ).to(self.device) | ||
|
|
||
| def test_swin_transformer_3d_encoder(self): | ||
| self.assertTrue(isinstance(self.encoder, torch.nn.Module)) | ||
|
||
| image = torch.randn(1, 3, 1, 112, 112) # B C D H W | ||
|
|
||
| scores = self.encoder(image) | ||
| self.assertEqual(scores.size(), torch.Size([1, 768])) | ||
| self.assertAlmostEqual(scores.abs().sum().item(), 268.00668, 3) | ||
|
|
||
| def test_swin_transformer_3d_scripting(self): | ||
| torch.jit.script(self.encoder) | ||