2020
2121#include < coremltools/mlmodel/src/Format.hpp>
2222
23+ #ifdef HAS_MACOS_10_15
24+ #import < ml/neural_net/mps_device_manager.h>
25+ #endif
26+
2327namespace turi {
2428namespace image_deep_feature_extractor {
2529
@@ -158,6 +162,27 @@ void build_vision_feature_print_scene_spec(const std::string& model_path) {
158162
159163}
160164
165+ API_AVAILABLE (macos(10.13 ), ios(11.0 ))
166+ MLModel* create_model (NSURL * url, NSError * _Nullable* error) {
167+ #if defined(HAS_MACOS_10_15) && !defined(TC_BUILD_IOS)
168+ if (@available (macos 10.15 , ios 10.13 , *)) {
169+ MLModelConfiguration* config = [[MLModelConfiguration alloc ] init ];
170+ config.preferredMetalDevice = [TCMPSDeviceManager sharedInstance ].preferredDevice ;
171+
172+ if (config.preferredMetalDevice != nil ) {
173+ logprogress_stream << " Using GPU (" << config.preferredMetalDevice .name .UTF8String
174+ << " ) to extract features." ;
175+ } else {
176+ // Assume that CoreML will fall back to CPU if no Metal devices are available.
177+ logprogress_stream << " Using CPU to extract features." ;
178+ }
179+
180+ return [MLModel modelWithContentsOfURL: url configuration: config error: error];
181+ }
182+ #endif // defined(HAS_MACOS_10_15) && !defined(TC_BUILD_IOS)
183+ return [MLModel modelWithContentsOfURL: url error: error];
184+ }
185+
161186API_AVAILABLE (macos(10.13 ),ios(11.0 ))
162187static MLModel *create_model (const std::string& download_path,
163188 const std::string& model_name) {
@@ -173,7 +198,7 @@ void build_vision_feature_print_scene_spec(const std::string& model_path) {
173198 if (boost::filesystem::exists (compiled_modified_model_path)) {
174199
175200 NSError * error = nil ;
176- result = [MLModel modelWithContentsOfURL: compiledModelURL error: &error] ;
201+ result = create_model ( compiledModelURL, &error) ;
177202
178203 if (error || !result) {
179204
@@ -240,7 +265,7 @@ void build_vision_feature_print_scene_spec(const std::string& model_path) {
240265
241266 // Load the compiled modified model
242267 NSError * error = nil ;
243- result = [MLModel modelWithContentsOfURL: compiledModelURL error: &error] ;
268+ result = create_model ( compiledModelURL, &error) ;
244269 checkNSError (error);
245270 }
246271
0 commit comments