Add model and architecture for Omnivore model #43
Closed: YosuaMichael wants to merge 29 commits into facebookresearch:main from YosuaMichael:omnivore/add-model.
Changes shown from 16 of the 29 commits.

Commits (29, all by YosuaMichael):
- b5edd05 Add swin_transformer_3d
- 47168af Add architecture and models for swin_t
- 5d15888 Fix architecture by adding super().__init__() and add examples for co…
- b9666f6 Fix typing issue
- fa5b6a6 Update swin_transformer_3d encoder to make sure it is jit scriptable …
- c76fa7f Add test for omnivore and fix bug on the avgpool
- 146c5e9 Putting PatchEmbedOmnivore into omnivore model and resolves some comm…
- 579552e Add license header
- 092d339 Update the test to follow the update on model API
- f2ad2d1 Improve omnivore test and resolve comment from Lan
- d40b29c Add option to get encoder_only for omnivore model, cleanup swin_trans…
- 058b729 Update the test code for omnivore
- 96b9211 Add test on ShiftedWindowAttention3d when there is zero shift
- cd8d199 Update swin_transformer_3d_encoder to be easier to upstream
- 44786d8 Fix format and mypy
- d6077ff Merge branch 'main' into omnivore/add-model
- 8ab13bc Create separate encoder function and simplify architecture a bit
- 063a403 Fix mypy problem
- c8e384a Upstream swin_transformer components from torchvision
- b2f010c Use lowercase variable name, format with ufmt
- 13f4605 Fix format and add license
- b9fccab Fix formatting and expected test result
- c505b2e Fix formatting
- e7869d7 Loosen test comparison
- fbcf807 Remove lru_cache usage to make module jit scriptable
- a8a59a3 Merge branch 'main' into omnivore/add-model
- e707b96 Use absolute path for test_utils
- c20e1cd Ufmt format
- 6cccb83 Merge branch 'main' into omnivore/add-model
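Taken together with the notebook and tests below, the PR's public surface is an `omnivore_swin_t()` builder whose forward pass takes a 5D tensor plus an `input_type` selecting the classification head. A minimal usage sketch (not part of the PR itself; shapes and head sizes are taken from the diff):

```python
import torch
import torchmultimodal.models.omnivore as omnivore

model = omnivore.omnivore_swin_t().eval()

# Inputs are 5D: B C D H W, where D is the time/depth axis (1 for still images)
image = torch.randn(1, 3, 1, 112, 112)
with torch.no_grad():
    logits = model(image, input_type="image")  # (1, 1000) ImageNet logits per the tests
    # Other heads per the tests: "rgbd" -> (1, 19), "video" -> (1, 400)
```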
A new two-line file listing the assets that the example notebook downloads:

```
library.jpg
imagenet_class_index.json
```
examples/omnivore/LoadOriginalPretrainedWeightAndCompare.ipynb (276 additions, 0 deletions)
The notebook loads the original pretrained Omnivore model from torch.hub and checks that this implementation produces the same outputs. Its cells, in order:

```python
# einops and timm are needed by the original omnivore model; matplotlib is for visualization
!pip install einops timm matplotlib

import collections
import json

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
import torchmultimodal.models.omnivore as omnivore
from PIL import Image
```
```python
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def custom_load_state_dict(model, pretrained_state_dict):
    # Convert pretrained_state_dict so it has the same keys as the model,
    # then load the weights into the model
    pretrained_keys = list(pretrained_state_dict.keys())
    model_keys = list(model.state_dict().keys())
    key_mapping = {pretrained_keys[i]: model_keys[i] for i in range(len(model_keys))}
    updated_pretrained_state_dict = collections.OrderedDict(
        {key_mapping[key]: val for key, val in pretrained_state_dict.items()}
    )
    model.load_state_dict(updated_pretrained_state_dict)
```
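Because the key mapping is purely positional, it silently assumes both models enumerate their parameters in the same order. A hedged sanity check (not part of the notebook; the helper name is hypothetical) that would catch an ordering mismatch:

```python
# Hypothetical guard: positionally corresponding tensors must at least agree on shape
def check_positional_mapping(model, pretrained_state_dict):
    for (pk, pv), (mk, mv) in zip(
        pretrained_state_dict.items(), model.state_dict().items()
    ):
        assert pv.shape == mv.shape, f"shape mismatch: {pk} vs {mk}"
```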
```python
# Load the original model from torch.hub
mhub = torch.hub.load("facebookresearch/omnivore:main", model="omnivore_swinT")
mhub.eval()
print(count_parameters(mhub))

# Build the torchmultimodal model and check that it has the same number of parameters
m = omnivore.omnivore_swin_t()
print(count_parameters(m))

# Copy the pretrained weights over
custom_load_state_dict(m, mhub.state_dict())
m = m.eval()
```
Inference test

```python
# Download the ImageNet class index and build an id-to-label-name mapping
!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json -O imagenet_class_index.json
with open("imagenet_class_index.json", "r") as f:
    imagenet_classnames = json.load(f)

imagenet_id_to_classname = {}
for k, v in imagenet_classnames.items():
    imagenet_id_to_classname[k] = v[1]

# Download the example image
!wget -O library.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/13-11-02-olb-by-RalfR-03.jpg/800px-13-11-02-olb-by-RalfR-03.jpg

image_path = "library.jpg"
image_pil = Image.open(image_path).convert("RGB")
plt.figure(figsize=(6, 6))
plt.imshow(image_pil)

# Standard ImageNet preprocessing
image_transform = T.Compose(
    [
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
image = image_transform(image_pil)  # C H W

# Add batch and time (D) dimensions
image = image.unsqueeze(0).unsqueeze(2)  # B C D H W
```
```python
def infer(model):
    with torch.no_grad():
        prediction = model(image, input_type="image")
        pred_classes = prediction.topk(k=5).indices

    pred_class_names = [imagenet_id_to_classname[str(i.item())] for i in pred_classes[0]]
    print("Top 5 predicted labels: %s" % ", ".join(pred_class_names))


# Run both models on the same image and check that the predicted classes match
infer(m)
infer(mhub)
```

Make sure the outputs of the trunk / encoder are the same

```python
m_feature = m.encoder(image)
mhub_feature = mhub.trunk(image)

# Spot-check that the first 10 features match
m_feature.flatten()[:10], mhub_feature[0].flatten()[:10]

# Check that all the features match exactly
np.all(np.array(m_feature == mhub_feature[0]))
```
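If the exact-equality check above ever fails from numerical noise, the largest absolute difference is usually more informative than a boolean; for example:

```python
# Report the worst-case deviation between the two encoders (not in the notebook)
print((m_feature - mhub_feature[0]).abs().max().item())
```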
Test on randomly generated input

```python
mock_video = torch.randn(1, 3, 10, 112, 112)

m_output = m(mock_video, input_type="video")
mhub_output = mhub(mock_video, input_type="video")

np.all(np.array(m_output == mhub_output[0]))

# The rgbd head takes a 4-channel input
mock_depth = torch.randn(1, 4, 1, 112, 112)

m_output = m(mock_depth, input_type="rgbd")
mhub_output = mhub(mock_depth, input_type="rgbd")

np.all(np.array(m_output == mhub_output[0]))
```
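Exact equality between two independent implementations can be brittle across devices and dtypes (the later commit "Loosen test comparison" suggests as much); a tolerance-based comparison is the usual alternative. A sketch, not part of the notebook:

```python
# Tolerance-based check instead of bitwise equality
print(torch.allclose(m_output, mhub_output[0], rtol=1e-4, atol=1e-5))
```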
A new unit test for the Omnivore model (49 additions):
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torchmultimodal.models.omnivore as omnivore
from torchmultimodal.utils.common import get_current_device

from ..test_utils import set_rng_seed


class TestOmnivoreModel(unittest.TestCase):
    def setUp(self):
        set_rng_seed(42)
        self.device = get_current_device()

    def test_omnivore_swin_t_forward(self):
        model = omnivore.omnivore_swin_t().to(self.device)
        self.assertTrue(isinstance(model, torch.nn.Module))

        image = torch.randn(1, 3, 1, 112, 112)  # B C D H W
        image_score = model(image, input_type="image")
        self.assertEqual(image_score.size(), torch.Size((1, 1000)))
        self.assertAlmostEqual(image_score.abs().sum().item(), 178.92318, 3)

        rgbd = torch.randn(1, 4, 1, 112, 112)
        rgbd_score = model(rgbd, input_type="rgbd")
        self.assertEqual(rgbd_score.size(), torch.Size((1, 19)))
        self.assertAlmostEqual(rgbd_score.abs().sum().item(), 3.39016, 3)

        video = torch.randn(1, 3, 4, 112, 112)
        video_score = model(video, input_type="video")
        self.assertEqual(video_score.size(), torch.Size((1, 400)))
        self.assertAlmostEqual(video_score.abs().sum().item(), 102.76638, 3)

    def test_omnivore_forward_wrong_input_type(self):
        model = omnivore.omnivore_swin_t().to(self.device)

        image = torch.randn(1, 3, 1, 112, 112)  # B C D H W
        with self.assertRaises(AssertionError) as cm:
            _ = model(image, input_type="_WRONG_TYPE_")
        self.assertEqual(
            "Unsupported input_type: _WRONG_TYPE_, please use one of {'video', 'rgbd', 'image'}",
            str(cm.exception),
        )
```
Review discussion:

Reviewer: Please use our test utility to assert on values and shapes. (We need to assert values too.)

YosuaMichael: Thanks for the suggestion! Will do.
Reviewer: Use the utility for assertion: multimodal/test/test_utils.py, line 74 in e857541.
YosuaMichael: Ah okay, will change.

Reviewer (suggested change): replace `self.assertEqual` with `assert_expected`.
YosuaMichael: @langong347, after reading the utility, `assert_expected` seems to compare two tensors of float type. In this case, I think `assertEqual` is better for comparing the size, since it is an integer?
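For reference, a sketch of how the two checks might be combined, assuming `assert_expected` in `test/test_utils.py` is a tolerance-aware tensor comparison (its exact signature is not shown in this thread, so the import path and keyword arguments below are assumptions):

```python
# Hypothetical refactor of one block of test_omnivore_swin_t_forward:
# keep assertEqual for the integer shape, use assert_expected for float values.
from ..test_utils import assert_expected  # import path assumed

def test_omnivore_swin_t_forward(self):
    model = omnivore.omnivore_swin_t().to(self.device)
    image = torch.randn(1, 3, 1, 112, 112)  # B C D H W
    image_score = model(image, input_type="image")
    self.assertEqual(image_score.size(), torch.Size((1, 1000)))  # exact integer shape check
    assert_expected(  # tolerance-based float value check (keywords assumed)
        image_score.abs().sum(), torch.tensor(178.92318), rtol=0, atol=1e-3
    )
```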