Add model and architecture for Omnivore model #43
Changes from 9 commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| library.jpg | ||
| imagenet_class_index.json |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,276 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "59ef9a30-174d-4ea3-b4a0-7c465e4c3f53", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Need to install einops and timm for original omnivore model, and matplotlib for visualization\n", | ||
| "! pip install einops timm matplotlib\n", | ||
| "\n", | ||
| "import torch\n", | ||
| "import torchvision.transforms as T\n", | ||
| "import torchmultimodal.models.omnivore as omnivore\n", | ||
| "\n", | ||
| "from PIL import Image\n", | ||
| "import collections\n", | ||
| "import json\n", | ||
| "import matplotlib.pyplot as plt\n", | ||
| "import numpy as np\n" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "55e221ba-528e-4f09-9d49-d0f54a577bc0", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "def count_parameters(model):\n", | ||
| " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", | ||
| "\n", | ||
| "def custom_load_state_dict(model, pretrained_state_dict):\n", | ||
| " # Convert the pretrained_state_dict so it have the same keys as the model\n", | ||
| " # then load the value of the weight into the model\n", | ||
| " pretrained_keys = list(pretrained_state_dict.keys())\n", | ||
| " model_keys = list(model.state_dict().keys())\n", | ||
| " key_mapping = {pretrained_keys[i]: model_keys[i] for i in range(len(model_keys))}\n", | ||
| " updated_pretrained_state_dict = collections.OrderedDict({key_mapping[key]: val for key, val in pretrained_state_dict.items()})\n", | ||
| " model.load_state_dict(updated_pretrained_state_dict)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "67b998df-4c99-4f18-a408-a1feef5c483a", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Load model from torch_hub\n", | ||
| "\n", | ||
| "mhub = torch.hub.load(\"facebookresearch/omnivore:main\", model=\"omnivore_swinT\")\n", | ||
| "mhub.eval()\n", | ||
| "print(count_parameters(mhub))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "5d3de763-8be6-450d-958b-0a01d4dc8b3b", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "m = omnivore.omnivore_swin_t()\n", | ||
| "\n", | ||
| "# Check that it have same number of parameter\n", | ||
| "print(count_parameters(m))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "a1168d33-0d11-46df-bef4-51f1973c2f95", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "custom_load_state_dict(m, mhub.state_dict())\n", | ||
| "m = m.eval()\n" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "f1d192ad-4902-49e1-baf6-77e7474a33bf", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "# Inference test" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "fe6345f3-97b5-46cd-b068-aea129aedde4", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Download imagenet class and image\n", | ||
| "# Uncomment to download\n", | ||
| "!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json -O imagenet_class_index.json\n", | ||
| "with open(\"imagenet_class_index.json\", \"r\") as f:\n", | ||
| " imagenet_classnames = json.load(f)\n", | ||
| "\n", | ||
| "# Create an id to label name mapping\n", | ||
| "imagenet_id_to_classname = {}\n", | ||
| "for k, v in imagenet_classnames.items():\n", | ||
| " imagenet_id_to_classname[k] = v[1] \n", | ||
| "\n", | ||
| "# Download the example image file\n", | ||
| "# Uncomment to download\n", | ||
| "!wget -O library.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/13-11-02-olb-by-RalfR-03.jpg/800px-13-11-02-olb-by-RalfR-03.jpg\n", | ||
| "\n", | ||
| "image_path = \"library.jpg\"\n", | ||
| "image_pil = Image.open(image_path).convert(\"RGB\")\n", | ||
| "plt.figure(figsize=(6, 6))\n", | ||
| "plt.imshow(image_pil)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "19d91191-d461-4ae7-93cf-f8eb325133be", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "image_transform = T.Compose(\n", | ||
| " [\n", | ||
| " T.Resize(224),\n", | ||
| " T.CenterCrop(224),\n", | ||
| " T.ToTensor(),\n", | ||
| " T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n", | ||
| " ]\n", | ||
| ")\n", | ||
| "image = image_transform(image_pil) # C H W\n", | ||
| "\n", | ||
| "# Adding batch and time (D) dimension\n", | ||
| "image = image.unsqueeze(0).unsqueeze(2) # B C D H W" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "22a4b65b-e9ce-48e3-9f1b-e79acb5a38f5", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "def infer(model):\n", | ||
| " with torch.no_grad():\n", | ||
| " prediction = model(image, input_type=\"image\")\n", | ||
| " pred_classes = prediction.topk(k=5).indices\n", | ||
| "\n", | ||
| " pred_class_names = [imagenet_id_to_classname[str(i.item())] for i in pred_classes[0]]\n", | ||
| " print(\"Top 5 predicted labels: %s\" % \", \".join(pred_class_names))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "b4305dd1-64ef-4a0f-b817-f5efe5f23980", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Test both model to infer the same image and make sure the output classes are the same\n", | ||
| "infer(m)\n", | ||
| "infer(mhub)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "779c2578-1cba-478c-8909-f47330e1b376", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "# Make sure the output of the trunk / encoder are the same" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "2d85f5d7-f77d-47d3-8bec-3bfa51687953", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "m_feature = m.encoder(image)\n", | ||
| "mhub_feature = mhub.trunk(image)" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "b5caf4aa-b964-4969-b531-318b9721bf17", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# See the first 10 features are the same\n", | ||
| "m_feature.flatten()[:10], mhub_feature[0].flatten()[:10]" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "ebd36e03-c86a-4d4f-a83a-b8ff3bd757b9", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Make sure all the features are the same\n", | ||
| "np.all(np.array(m_feature == mhub_feature[0]))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "08adda52-d593-4d58-816a-4a6a6205ce3d", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "# Test on randomly generated input" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "1dcacbc6-5cbc-4e01-b0c9-8404a33f13da", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "mock_video = torch.randn(1, 3, 10, 112, 112)\n", | ||
| "\n", | ||
| "m_output = m(mock_video, input_type=\"video\")\n", | ||
| "mhub_output = mhub(mock_video, input_type=\"video\")\n", | ||
| "\n", | ||
| "np.all(np.array(m_output == mhub_output[0]))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "357fa8cd-a853-4b12-ad87-44dba43c133b", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "mock_depth = torch.randn(1, 4, 1, 112, 112)\n", | ||
| "\n", | ||
| "m_output = m(mock_video, input_type=\"rgbd\")\n", | ||
| "mhub_output = mhub(mock_video, input_type=\"rgbd\")\n", | ||
| "\n", | ||
| "np.all(np.array(m_output == mhub_output[0]))" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "6bc03203-430f-4ab9-86a0-acbfadc89f67", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "kernelspec": { | ||
| "display_name": "Python 3 (ipykernel)", | ||
| "language": "python", | ||
| "name": "python3" | ||
| }, | ||
| "language_info": { | ||
| "codemirror_mode": { | ||
| "name": "ipython", | ||
| "version": 3 | ||
| }, | ||
| "file_extension": ".py", | ||
| "mimetype": "text/x-python", | ||
| "name": "python", | ||
| "nbconvert_exporter": "python", | ||
| "pygments_lexer": "ipython3", | ||
| "version": "3.9.12" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 5 | ||
| } |
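A side note on the equality checks in the notebook above: `np.all(np.array(m_output == mhub_output[0]))` requires bit-exact agreement, which holds here because both models run the same weights on the same inputs, but it can be brittle across devices or kernel versions. A tolerance-based check is a more forgiving alternative; the sketch below reuses the notebook's variable names, and the tolerances are illustrative:

```python
import torch


def outputs_match(a: torch.Tensor, b: torch.Tensor, rtol: float = 1e-5, atol: float = 1e-6) -> bool:
    """Return True if the two tensors agree within the given tolerances."""
    return torch.allclose(a, b, rtol=rtol, atol=atol)


# Usage with the notebook's tensors (m_output from the torchmultimodal model,
# mhub_output from the torch.hub reference model):
# outputs_match(m_output, mhub_output[0])
```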
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| import torchmultimodal.models.omnivore as omnivore | ||
| from torchmultimodal.utils.common import get_current_device | ||
|
|
||
|
|
||
| class TestOmnivoreModel(unittest.TestCase): | ||
| def setUp(self): | ||
| torch.manual_seed(42) | ||
| self.device = get_current_device() | ||
|
|
||
| def test_omnivore_swin_t_forward(self): | ||
| model = omnivore.omnivore_swin_t().to(self.device) | ||
| self.assertTrue(isinstance(model, torch.nn.Module)) | ||
|
|
||
| image = torch.randn(1, 3, 1, 112, 112) # B C D H W | ||
| image_score = model(image, input_type="image") | ||
| self.assertEqual(image_score.size(), torch.Size((1, 1000))) | ||
|
|
||
| rgbd = torch.randn(1, 4, 1, 112, 112) | ||
| rgbd_score = model(rgbd, input_type="rgbd") | ||
| self.assertEqual(rgbd_score.size(), torch.Size((1, 19))) | ||
|
|
||
| video = torch.randn(1, 3, 4, 112, 112) | ||
| video_score = model(video, input_type="video") | ||
| self.assertEqual(video_score.size(), torch.Size((1, 400))) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import unittest | ||
|
|
||
| import torch | ||
| from torchmultimodal.modules.encoders.swin_transformer_3d_encoder import ( | ||
| PatchEmbed3d, | ||
| SwinTransformer3dEncoder, | ||
| ) | ||
| from torchmultimodal.utils.common import get_current_device | ||
|
|
||
|
|
||
| class TestSwinTransformer3dEncoder(unittest.TestCase): | ||
| def setUp(self): | ||
| torch.manual_seed(42) | ||
| self.device = get_current_device() | ||
|
|
||
| # Set up the encoder to test | ||
| self.encoder = SwinTransformer3dEncoder( | ||
| patch_size=(2, 4, 4), | ||
| embed_dim=96, | ||
| depths=[2, 2, 6, 2], | ||
| num_heads=[3, 6, 12, 24], | ||
| window_size=(8, 7, 7), | ||
| stochastic_depth_prob=0.2, | ||
| norm_layer=torch.nn.LayerNorm, | ||
| patch_embed=PatchEmbed3d, | ||
| ).to(self.device) | ||
|
|
||
| def test_swin_transformer_3d_encoder(self): | ||
| self.assertTrue(isinstance(self.encoder, torch.nn.Module)) | ||
|
||
| image = torch.randn(1, 3, 1, 112, 112) # B C D H W | ||
|
|
||
| scores = self.encoder(image) | ||
| self.assertEqual(scores.size(), torch.Size([1, 768])) | ||
| self.assertAlmostEqual(scores.abs().sum().item(), 277.638336, 3) | ||
|
|
||
| def test_swin_transformer_3d_scripting(self): | ||
| torch.jit.script(self.encoder) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import torch | ||
| from torch import nn | ||
|
|
||
|
|
||
| class OmnivoreArchitecture(nn.Module): | ||
|
||
| """Omnivore is a model that accept multiple vision modality. | ||
|
|
||
| Omnivore (https://arxiv.org/abs/2201.08377) is a single model that is able to perform classification | ||
| on images, videos, and single-view 3D data using the same shared encoder parameters. | ||
|
|
||
| Args: encoder (nn.Module): Instantiated encoder. | ||
| See SwinTransformer3dEncoder class. | ||
| heads (nn.ModuleDict): Dictionary of heads, one for each dataset / input type. | ||
|
|
||
| Inputs: x (Tensor): 5-dimensional batched tensor in B C D H W format, | ||
| where B is batch, C is channel, D is time, H is height, and W is width. | ||
| input_type (str): The dataset type of the input; this will be used to choose | ||
| the correct head. | ||
| """ | ||
|
|
||
| def __init__(self, encoder: nn.Module, heads: nn.ModuleDict): | ||
| super().__init__() | ||
| self.encoder = encoder | ||
| self.heads = heads | ||
| self.input_types = set(heads.keys()) | ||
|
|
||
| def forward(self, x: torch.Tensor, input_type: str): | ||
|
Reviewer (@datumbox): The idiom of passing the `input_type` […] Speaking offline with @YosuaMichael, there are some concerns on whether the whole approach will work well in a distributed setup.

Author (@YosuaMichael, Contributor): Hi @datumbox, currently I plan to try doing the training first and see if there are any particular problems with the architecture (whether the current one or multiple models with a shared encoder). In particular, I need to check the behaviour of having 2 models …

Reviewer (@ankitade, Contributor): This is an interesting comment. Confirming: does the ckpt include all the heads?

Author (@YosuaMichael): @ankitade yes, the original checkpoint includes all the heads.

Reviewer (@datumbox): @ankitade @YosuaMichael It's worth keeping in mind that this pattern won't support FX. FX does tracing, which means the flow of execution of the model should not depend on the input. By adding the string, FX won't know what to execute. Even though you might not be interested right now in making it FX traceable, in the future you might want to adopt FX quantization or other FX-based utils. This is the reason my advice is to split this module into submodules depending on the head. It's a safer idiom that would be forward compatible with future core expansions. It might require some massaging of the original weights, but I think that's straightforward to do and worth it. Up to you. :)
||
| x = self.encoder(x) | ||
| assert ( | ||
| input_type in self.input_types | ||
| ), f"Unsupported input_type: {input_type}, please use one of {self.input_types}" | ||
| x = self.heads[input_type](x) | ||
| return x | ||
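As a usage sketch, `OmnivoreArchitecture` can be assembled from any encoder plus an `nn.ModuleDict` of heads keyed by input type; the stand-in encoder and head sizes below are illustrative, and the import path is an assumption based on the notebook above (the PR's `omnivore_swin_t()` builder presumably wires up the Swin-3D encoder with the 1000-, 400-, and 19-way heads asserted in the tests):

```python
import torch
from torch import nn

# Assumed import path: the notebook imports torchmultimodal.models.omnivore,
# so OmnivoreArchitecture is presumed reachable there (adjust to the repo layout).
from torchmultimodal.models.omnivore import OmnivoreArchitecture

# Illustrative stand-in encoder mapping a B C D H W input to a 768-d feature,
# the same feature width checked in the Swin-T encoder test above.
encoder = nn.Sequential(nn.Flatten(start_dim=1), nn.LazyLinear(768))

heads = nn.ModuleDict(
    {
        "image": nn.Linear(768, 1000),  # e.g. ImageNet-1k classes
        "video": nn.Linear(768, 400),   # e.g. Kinetics-400 classes
        "rgbd": nn.Linear(768, 19),     # e.g. SUN RGB-D scene classes
    }
)

model = OmnivoreArchitecture(encoder, heads)

image = torch.randn(1, 3, 1, 32, 32)  # B C D H W
scores = model(image, input_type="image")
print(scores.shape)  # torch.Size([1, 1000])
```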
Reviewer: Please use our test utility to assert on values and shapes. (We need to assert values too.)

Author: Thanks for the suggestion! Will do.

Reviewer (@langong347): Use the utility for assertion: multimodal/test/test_utils.py, line 74 in e857541.

Author: Ah okay, will change `self.assertEqual` with `assert_expected` then.

Author: @langong347 after I read the utility `assert_expected`, it seems to compare two tensors with float type. In this case, I think `assertEqual` is better for comparing the size, since it is an integer?