-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Feature extraction #3210
Feature extraction #3210
Changes from 24 commits
daa5b4a
63736cf
37d5ccd
67216b5
d891038
18d5b70
40ec7a5
d51b193
ea237b5
f1ea4ac
8e45833
e14a97b
7dc12a1
8cc488d
62aa5e2
e0a301a
1aa4dc1
4bfedd9
e7195d6
c874c97
049c1e5
c02ac72
26a9e61
80dd166
b563393
c4b4ada
b35b2e7
c4798d3
dd0c184
faaf5f2
6d308e8
ab26484
4ff6d46
0efc31b
ca49ab2
b1c27ba
3b4455d
61ac0c2
fd0cbd5
199e1ce
3b97d6f
6a29bc5
0c5ff0c
df86dbd
fd01bcd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,9 +12,12 @@ | |
| from turicreate.toolkits._internal_utils import _mac_ver | ||
| import tempfile | ||
| from . import util as test_util | ||
|
|
||
| from turicreate.toolkits._main import ToolkitError as _ToolkitError | ||
| from turicreate.toolkits.image_analysis.image_analysis import MODEL_TO_FEATURE_SIZE_MAPPING, get_deep_features | ||
|
|
||
| import coremltools | ||
|
||
| import numpy as np | ||
| from turicreate.toolkits._main import ToolkitError as _ToolkitError | ||
|
|
||
|
|
||
| def get_test_data(): | ||
|
|
@@ -63,19 +66,27 @@ def get_test_data(): | |
| ) | ||
| images.append(tc_image) | ||
|
|
||
| return tc.SFrame({"awesome_image": images}) | ||
| data_dict = {"awesome_image": images} | ||
| data = tc.SFrame(data_dict) | ||
|
|
||
| for model_name, feature_length in MODEL_TO_FEATURE_SIZE_MAPPING.items(): | ||
| if str(feature_length) in data_dict.keys(): | ||
| continue | ||
| data[str(feature_length)] = get_deep_features(data["awesome_image"], model_name) | ||
|
|
||
| return data | ||
TobyRoseman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| data = get_test_data() | ||
|
|
||
|
|
||
| class ImageSimilarityTest(unittest.TestCase): | ||
| @classmethod | ||
| def setUpClass(self, input_image_shape=(3, 224, 224), model="resnet-50"): | ||
| def setUpClass(self, input_image_shape=(3, 224, 224), model="resnet-50", feature="awesome_image"): | ||
| """ | ||
| The setup class method for the basic test case with all default values. | ||
| """ | ||
| self.feature = "awesome_image" | ||
| self.feature = feature | ||
| self.label = None | ||
| self.input_image_shape = input_image_shape | ||
| self.pre_trained_model = model | ||
|
|
@@ -286,11 +297,27 @@ def test_save_and_load(self): | |
| print("Export coreml passed") | ||
|
|
||
|
|
||
| class ImageSimilarityResnetTestWithDeepFeatures(ImageSimilarityTest): | ||
| @classmethod | ||
| def setUpClass(self): | ||
| super(ImageSimilaritySqueezeNetTest, self).setUpClass( | ||
| model="resnet-50", input_image_shape=(3, 227, 227), feature="2048" | ||
| ) | ||
|
|
||
|
|
||
| class ImageSimilaritySqueezeNetTest(ImageSimilarityTest): | ||
| @classmethod | ||
| def setUpClass(self): | ||
| super(ImageSimilaritySqueezeNetTest, self).setUpClass( | ||
| model="squeezenet_v1.1", input_image_shape=(3, 227, 227) | ||
| model="squeezenet_v1.1", input_image_shape=(3, 227, 227), feature="awesome_image" | ||
| ) | ||
|
|
||
|
|
||
| class ImageSimilaritySqueezeNetTestWithDeepFeatures(ImageSimilarityTest): | ||
| @classmethod | ||
| def setUpClass(self): | ||
| super(ImageSimilaritySqueezeNetTest, self).setUpClass( | ||
| model="squeezenet_v1.1", input_image_shape=(3, 227, 227), feature="1000" | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -301,10 +328,20 @@ class ImageSimilarityVisionFeaturePrintSceneTest(ImageSimilarityTest): | |
| @classmethod | ||
| def setUpClass(self): | ||
| super(ImageSimilarityVisionFeaturePrintSceneTest, self).setUpClass( | ||
| model="VisionFeaturePrint_Scene", input_image_shape=(3, 299, 299) | ||
| model="VisionFeaturePrint_Scene", input_image_shape=(3, 299, 299), feature="awesome_image" | ||
| ) | ||
|
|
||
|
|
||
| @unittest.skipIf( | ||
| _mac_ver() < (10, 14), "VisionFeaturePrint_Scene only supported on macOS 10.14+" | ||
| ) | ||
| class ImageSimilarityVisionFeaturePrintSceneTestWithDeepFeatures(ImageSimilarityTest): | ||
| @classmethod | ||
| def setUpClass(self): | ||
| super(ImageSimilarityVisionFeaturePrintSceneTest, self).setUpClass( | ||
| model="VisionFeaturePrint_Scene", input_image_shape=(3, 299, 299), feature="2048" | ||
| ) | ||
|
|
||
| # A test to gaurantee that old code using the incorrect name still works. | ||
| @unittest.skipIf( | ||
| _mac_ver() < (10, 14), "VisionFeaturePrint_Scene only supported on macOS 10.14+" | ||
|
|
@@ -313,5 +350,16 @@ class ImageSimilarityVisionFeaturePrintSceneTest_bad_name(ImageSimilarityTest): | |
| @classmethod | ||
| def setUpClass(self): | ||
| super(ImageSimilarityVisionFeaturePrintSceneTest_bad_name, self).setUpClass( | ||
| model="VisionFeaturePrint_Screen", input_image_shape=(3, 299, 299) | ||
| model="VisionFeaturePrint_Screen", input_image_shape=(3, 299, 299), feature="awesome_image" | ||
| ) | ||
|
|
||
|
|
||
| @unittest.skipIf( | ||
| _mac_ver() < (10, 14), "VisionFeaturePrint_Scene only supported on macOS 10.14+" | ||
| ) | ||
| class ImageSimilarityVisionFeaturePrintSceneTestWithDeepFeatures_bad_name(ImageSimilarityTest): | ||
| @classmethod | ||
| def setUpClass(self): | ||
| super(ImageSimilarityVisionFeaturePrintSceneTest_bad_name, self).setUpClass( | ||
| model="VisionFeaturePrint_Screen", input_image_shape=(3, 299, 299), feature="2048" | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.