@@ -119,7 +119,13 @@ struct OrtStatus {
119119#define BACKEND_ARMNN " "
120120#endif
121121
122- #define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN
122+ #if USE_DML
123+ #define BACKEND_DML " -DML"
124+ #else
125+ #define BACKEND_DML " "
126+ #endif
127+
128+ #define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN BACKEND_DML
123129#include " core/session/onnxruntime_cxx_api.h"
124130#include " core/providers/providers.h"
125131#include " core/providers/cpu/cpu_execution_provider.h"
@@ -159,6 +165,9 @@ std::string nuphar_settings;
159165#ifdef USE_ARMNN
160166#include " core/providers/armnn/armnn_provider_factory.h"
161167#endif
168+ #ifdef USE_DML
169+ #include " core/providers/dml/dml_provider_factory.h"
170+ #endif
162171
163172#define PYBIND_UNREFERENCED_PARAMETER (parameter ) ((void )(parameter))
164173
@@ -176,6 +185,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar
176185std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_VITISAI (const char * backend_type, int device_id);
177186std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ACL (int use_arena);
178187std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ArmNN (int use_arena);
188+ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML (int device_id);
179189} // namespace onnxruntime
180190
181191#if defined(_MSC_VER)
@@ -374,7 +384,7 @@ const std::vector<std::string>& GetAllProviders() {
374384 static std::vector<std::string> all_providers = {kTensorrtExecutionProvider , kCudaExecutionProvider , kMIGraphXExecutionProvider ,
375385 kNGraphExecutionProvider , kOpenVINOExecutionProvider , kDnnlExecutionProvider ,
376386 kNupharExecutionProvider , kVitisAIExecutionProvider , kArmNNExecutionProvider ,
377- kAclExecutionProvider , kCpuExecutionProvider };
387+ kAclExecutionProvider , kDmlExecutionProvider , kCpuExecutionProvider };
378388 return all_providers;
379389}
380390
@@ -572,6 +582,9 @@ void RegisterExecutionProviders(InferenceSession* sess, const std::vector<std::s
572582 sess, *onnxruntime::CreateExecutionProviderFactory_ArmNN (sess->GetSessionOptions ().enable_cpu_mem_arena ));
573583#endif
574584 } else if (type == kDmlExecutionProvider ) {
585+ #ifdef USE_DML
586+ RegisterExecutionProvider (sess, *onnxruntime::CreateExecutionProviderFactory_DML (0 ));
587+ #endif
575588 } else {
576589 // unknown provider
577590 throw std::runtime_error (" Unknown Provider Type: " + type);
@@ -721,6 +734,9 @@ void addGlobalMethods(py::module& m, const Environment& env) {
721734#endif
722735#ifdef USE_ARMNN
723736 onnxruntime::CreateExecutionProviderFactory_ArmNN (0 )
737+ #endif
738+ #ifdef USE_DML
739+ onnxruntime::CreateExecutionProviderFactory_DML (0 )
724740#endif
725741 };
726742
0 commit comments