Commits
b5edd05 Add swin_transformer_3d (YosuaMichael, May 12, 2022)
47168af Add architecture and models for swin_t (YosuaMichael, May 12, 2022)
5d15888 Fix architecture by adding super().__init__() and add examples for co… (YosuaMichael, May 13, 2022)
b9666f6 Fix typing issue (YosuaMichael, May 13, 2022)
fa5b6a6 Update swin_transformer_3d encoder to make sure it is jit scriptable … (YosuaMichael, May 13, 2022)
c76fa7f Add test for omnivore and fix bug on the avgpool (YosuaMichael, May 13, 2022)
146c5e9 Putting PatchEmbedOmnivore into omnivore model and resolves some comm… (YosuaMichael, May 17, 2022)
579552e Add license header (YosuaMichael, May 17, 2022)
092d339 Update the test to follow the update on model API (YosuaMichael, May 17, 2022)
f2ad2d1 Improve omnivore test and resolve comment from Lan (YosuaMichael, May 18, 2022)
d40b29c Add option to get encoder_only for omnivore model, cleanup swin_trans… (YosuaMichael, May 20, 2022)
058b729 Update the test code for omnivore (YosuaMichael, May 20, 2022)
96b9211 Add test on ShiftedWindowAttention3d when there is zero shift (YosuaMichael, May 20, 2022)
cd8d199 Update swin_transformer_3d_encoder to be easier to upstream (YosuaMichael, May 25, 2022)
44786d8 Fix format and mypy (YosuaMichael, May 25, 2022)
d6077ff Merge branch 'main' into omnivore/add-model (YosuaMichael, Jun 15, 2022)
8ab13bc Create separate encoder function and simplify architecture a bit (YosuaMichael, Jun 15, 2022)
063a403 Fix mypy problem (YosuaMichael, Jun 15, 2022)
c8e384a Upstream swin_transformer components from torchvision (YosuaMichael, Jun 16, 2022)
b2f010c Use lowercase variable name, format with ufmt (YosuaMichael, Jun 16, 2022)
13f4605 Fix format and add license (YosuaMichael, Jun 20, 2022)
b9fccab Fix formatting and expected test result (YosuaMichael, Jun 22, 2022)
c505b2e Fix formatting (YosuaMichael, Jun 22, 2022)
e7869d7 Loosen test comparison (YosuaMichael, Jun 22, 2022)
fbcf807 Remove lru_cache usage to make module jit scriptable (YosuaMichael, Jun 23, 2022)
a8a59a3 Merge branch 'main' into omnivore/add-model (YosuaMichael, Jun 23, 2022)
e707b96 Use absolute path for test_utils (YosuaMichael, Jun 23, 2022)
c20e1cd Ufmt format (YosuaMichael, Jun 23, 2022)
6cccb83 Merge branch 'main' into omnivore/add-model (YosuaMichael, Jun 23, 2022)
2 changes: 2 additions & 0 deletions examples/omnivore/.gitignore
@@ -0,0 +1,2 @@
library.jpg
imagenet_class_index.json
276 changes: 276 additions & 0 deletions examples/omnivore/LoadOriginalPretrainedWeightAndCompare.ipynb
@@ -0,0 +1,276 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "59ef9a30-174d-4ea3-b4a0-7c465e4c3f53",
"metadata": {},
"outputs": [],
"source": [
"# Need to install einops and timm for original omnivore model, and matplotlib for visualization\n",
"! pip install einops timm matplotlib\n",
"\n",
"import torch\n",
"import torchvision.transforms as T\n",
"import torchmultimodal.models.omnivore as omnivore\n",
"\n",
"from PIL import Image\n",
"import collections\n",
"import json\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55e221ba-528e-4f09-9d49-d0f54a577bc0",
"metadata": {},
"outputs": [],
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"def custom_load_state_dict(model, pretrained_state_dict):\n",
" # Convert the pretrained_state_dict so it have the same keys as the model\n",
" # then load the value of the weight into the model\n",
" pretrained_keys = list(pretrained_state_dict.keys())\n",
" model_keys = list(model.state_dict().keys())\n",
" key_mapping = {pretrained_keys[i]: model_keys[i] for i in range(len(model_keys))}\n",
" updated_pretrained_state_dict = collections.OrderedDict({key_mapping[key]: val for key, val in pretrained_state_dict.items()})\n",
" model.load_state_dict(updated_pretrained_state_dict)"
]
},
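{
"cell_type": "code",
"execution_count": null,
"id": "sanity-check-shapes",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (sketch, not part of the original notebook): the positional\n",
"# key mapping above assumes both state dicts enumerate parameters in the same order.\n",
"# The helper below is illustrative; it only verifies that every mapped pair has\n",
"# matching tensor shapes, e.g. by running\n",
"# check_state_dict_alignment(m, mhub.state_dict()) before custom_load_state_dict.\n",
"def check_state_dict_alignment(model, pretrained_state_dict):\n",
" model_state_dict = model.state_dict()\n",
" pretrained_keys = list(pretrained_state_dict.keys())\n",
" model_keys = list(model_state_dict.keys())\n",
" assert len(pretrained_keys) == len(model_keys), \"parameter counts differ\"\n",
" for pkey, mkey in zip(pretrained_keys, model_keys):\n",
"  assert pretrained_state_dict[pkey].shape == model_state_dict[mkey].shape, (pkey, mkey)\n"
]
},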
{
"cell_type": "code",
"execution_count": null,
"id": "67b998df-4c99-4f18-a408-a1feef5c483a",
"metadata": {},
"outputs": [],
"source": [
"# Load model from torch_hub\n",
"\n",
"mhub = torch.hub.load(\"facebookresearch/omnivore:main\", model=\"omnivore_swinT\")\n",
"mhub.eval()\n",
"print(count_parameters(mhub))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d3de763-8be6-450d-958b-0a01d4dc8b3b",
"metadata": {},
"outputs": [],
"source": [
"m = omnivore.omnivore_swin_t()\n",
"\n",
"# Check that it have same number of parameter\n",
"print(count_parameters(m))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1168d33-0d11-46df-bef4-51f1973c2f95",
"metadata": {},
"outputs": [],
"source": [
"custom_load_state_dict(m, mhub.state_dict())\n",
"m = m.eval()\n"
]
},
{
"cell_type": "markdown",
"id": "f1d192ad-4902-49e1-baf6-77e7474a33bf",
"metadata": {},
"source": [
"# Inference test"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe6345f3-97b5-46cd-b068-aea129aedde4",
"metadata": {},
"outputs": [],
"source": [
"# Download imagenet class and image\n",
"# Uncomment to download\n",
"!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json -O imagenet_class_index.json\n",
"with open(\"imagenet_class_index.json\", \"r\") as f:\n",
" imagenet_classnames = json.load(f)\n",
"\n",
"# Create an id to label name mapping\n",
"imagenet_id_to_classname = {}\n",
"for k, v in imagenet_classnames.items():\n",
" imagenet_id_to_classname[k] = v[1] \n",
"\n",
"# Download the example image file\n",
"# Uncomment to download\n",
"!wget -O library.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/13-11-02-olb-by-RalfR-03.jpg/800px-13-11-02-olb-by-RalfR-03.jpg\n",
"\n",
"image_path = \"library.jpg\"\n",
"image_pil = Image.open(image_path).convert(\"RGB\")\n",
"plt.figure(figsize=(6, 6))\n",
"plt.imshow(image_pil)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19d91191-d461-4ae7-93cf-f8eb325133be",
"metadata": {},
"outputs": [],
"source": [
"image_transform = T.Compose(\n",
" [\n",
" T.Resize(224),\n",
" T.CenterCrop(224),\n",
" T.ToTensor(),\n",
" T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
" ]\n",
")\n",
"image = image_transform(image_pil) # C H W\n",
"\n",
"# Adding batch and time (D) dimension\n",
"image = image.unsqueeze(0).unsqueeze(2) # B C D H W"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22a4b65b-e9ce-48e3-9f1b-e79acb5a38f5",
"metadata": {},
"outputs": [],
"source": [
"def infer(model):\n",
" with torch.no_grad():\n",
" prediction = model(image, input_type=\"image\")\n",
" pred_classes = prediction.topk(k=5).indices\n",
"\n",
" pred_class_names = [imagenet_id_to_classname[str(i.item())] for i in pred_classes[0]]\n",
" print(\"Top 5 predicted labels: %s\" % \", \".join(pred_class_names))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4305dd1-64ef-4a0f-b817-f5efe5f23980",
"metadata": {},
"outputs": [],
"source": [
"# Test both model to infer the same image and make sure the output classes are the same\n",
"infer(m)\n",
"infer(mhub)"
]
},
{
"cell_type": "markdown",
"id": "779c2578-1cba-478c-8909-f47330e1b376",
"metadata": {},
"source": [
"# Make sure the output of the trunk / encoder are the same"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d85f5d7-f77d-47d3-8bec-3bfa51687953",
"metadata": {},
"outputs": [],
"source": [
"m_feature = m.encoder(image)\n",
"mhub_feature = mhub.trunk(image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5caf4aa-b964-4969-b531-318b9721bf17",
"metadata": {},
"outputs": [],
"source": [
"# See the first 10 features are the same\n",
"m_feature.flatten()[:10], mhub_feature[0].flatten()[:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ebd36e03-c86a-4d4f-a83a-b8ff3bd757b9",
"metadata": {},
"outputs": [],
"source": [
"# Make sure all the features are the same\n",
"np.all(np.array(m_feature == mhub_feature[0]))"
]
},
{
"cell_type": "markdown",
"id": "08adda52-d593-4d58-816a-4a6a6205ce3d",
"metadata": {},
"source": [
"# Test on randomly generated input"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1dcacbc6-5cbc-4e01-b0c9-8404a33f13da",
"metadata": {},
"outputs": [],
"source": [
"mock_video = torch.randn(1, 3, 10, 112, 112)\n",
"\n",
"m_output = m(mock_video, input_type=\"video\")\n",
"mhub_output = mhub(mock_video, input_type=\"video\")\n",
"\n",
"np.all(np.array(m_output == mhub_output[0]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "357fa8cd-a853-4b12-ad87-44dba43c133b",
"metadata": {},
"outputs": [],
"source": [
"mock_depth = torch.randn(1, 4, 1, 112, 112)\n",
"\n",
"m_output = m(mock_video, input_type=\"rgbd\")\n",
"mhub_output = mhub(mock_video, input_type=\"rgbd\")\n",
"\n",
"np.all(np.array(m_output == mhub_output[0]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6bc03203-430f-4ab9-86a0-acbfadc89f67",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
49 changes: 49 additions & 0 deletions test/models/test_omnivore.py
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torchmultimodal.models.omnivore as omnivore
from torchmultimodal.utils.common import get_current_device

from ..test_utils import set_rng_seed


class TestOmnivoreModel(unittest.TestCase):
def setUp(self):
set_rng_seed(42)
self.device = get_current_device()

def test_omnivore_swin_t_forward(self):
model = omnivore.omnivore_swin_t().to(self.device)
self.assertTrue(isinstance(model, torch.nn.Module))

image = torch.randn(1, 3, 1, 112, 112) # B C D H W
image_score = model(image, input_type="image")
self.assertEqual(image_score.size(), torch.Size((1, 1000)))

Review thread:

please use our test utility to assert on values and shapes. (we need to assert values too)

YosuaMichael (Contributor, Author): Thanks for the suggestion! Will do.

langong347 (May 18, 2022): Use the utility for assertion:

def assert_expected(

YosuaMichael (Contributor, Author, May 19, 2022): ah okay, will change self.assertEqual with assert_expected then

YosuaMichael (Contributor, Author): @langong347 after I read the utility assert_expected, it seems to compare two tensors with float type. In this case, I think assertEqual is better for comparing the size since it is an integer?

self.assertAlmostEqual(image_score.abs().sum().item(), 178.92318, 3)

rgbd = torch.randn(1, 4, 1, 112, 112)
rgbd_score = model(rgbd, input_type="rgbd")
self.assertEqual(rgbd_score.size(), torch.Size((1, 19)))
self.assertAlmostEqual(rgbd_score.abs().sum().item(), 3.39016, 3)

video = torch.randn(1, 3, 4, 112, 112)
video_score = model(video, input_type="video")
self.assertEqual(video_score.size(), torch.Size((1, 400)))
self.assertAlmostEqual(video_score.abs().sum().item(), 102.76638, 3)

def test_omnivore_forward_wrong_input_type(self):
model = omnivore.omnivore_swin_t().to(self.device)

image = torch.randn(1, 3, 1, 112, 112) # B C D H W
with self.assertRaises(AssertionError) as cm:
_ = model(image, input_type="_WRONG_TYPE_")
self.assertEqual(
"Unsupported input_type: _WRONG_TYPE_, please use one of {'video', 'rgbd', 'image'}",
str(cm.exception),
)
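
Following the review thread above, here is a minimal sketch of how the value assertions could look if switched to the assert_expected test utility. This is only a sketch: it assumes assert_expected(actual, expected, rtol=..., atol=...) wraps a tensor closeness check such as torch.testing.assert_close and can be imported alongside set_rng_seed; neither the signature nor the import path is shown in this diff.

# Sketch only: the assert_expected signature and import path are assumptions, not confirmed by this diff
from ..test_utils import assert_expected, set_rng_seed

image_score = model(image, input_type="image")
# Shape check stays as a plain integer comparison, per the discussion above
self.assertEqual(image_score.size(), torch.Size((1, 1000)))
# Value check with an explicit tolerance, replacing assertAlmostEqual
assert_expected(image_score.abs().sum(), torch.tensor(178.92318), rtol=0, atol=1e-3)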