112112#define BACKEND_ARMNN " "
113113#endif
114114
115- #define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN
115+ #if USE_DML
116+ #define BACKEND_DML " -DML"
117+ #else
118+ #define BACKEND_DML " "
119+ #endif
120+
121+ #define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN BACKEND_DML
116122#include " core/session/onnxruntime_cxx_api.h"
117123#include " core/providers/providers.h"
118124#include " core/providers/cpu/cpu_execution_provider.h"
@@ -152,6 +158,9 @@ std::string nuphar_settings;
152158#ifdef USE_ARMNN
153159#include " core/providers/armnn/armnn_provider_factory.h"
154160#endif
161+ #ifdef USE_DML
162+ #include " core/providers/dml/dml_provider_factory.h"
163+ #endif
155164
156165#define PYBIND_UNREFERENCED_PARAMETER (parameter ) ((void )(parameter))
157166
@@ -169,6 +178,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar
169178std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_VITISAI (const char * backend_type, int device_id);
170179std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ACL (int use_arena);
171180std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ArmNN (int use_arena);
181+ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML (int device_id);
172182} // namespace onnxruntime
173183
174184#if defined(_MSC_VER)
@@ -486,6 +496,10 @@ void RegisterExecutionProviders(InferenceSession* sess, const std::vector<std::s
486496 } else if (type == kArmNNExecutionProvider ) {
487497#ifdef USE_ARMNN
488498 RegisterExecutionProvider (sess, *onnxruntime::CreateExecutionProviderFactory_ArmNN (sess->GetSessionOptions ().enable_cpu_mem_arena ));
499+ #endif
500+ } else if (type == kDmlExecutionProvider ) {
501+ #ifdef USE_DML
502+ RegisterExecutionProvider (sess, *onnxruntime::CreateExecutionProviderFactory_DML (0 ));
489503#endif
490504 } else {
491505 // unknown provider
@@ -692,6 +706,9 @@ void addGlobalMethods(py::module& m, const Environment& env) {
692706#endif
693707#ifdef USE_ARMNN
694708 onnxruntime::CreateExecutionProviderFactory_ArmNN (0 )
709+ #endif
710+ #ifdef USE_DML
711+ onnxruntime::CreateExecutionProviderFactory_DML (0 )
695712#endif
696713 };
697714
0 commit comments