Skip to content

Commit 9db3d7a

Browse files
committed
Ported @RobinKa's changes to expose DirectML to python (microsoft#3359) against latest master
1 parent d5b98a1 commit 9db3d7a

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,13 @@
112112
#define BACKEND_ARMNN ""
113113
#endif
114114

115-
#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN
115+
#if USE_DML
116+
#define BACKEND_DML "-DML"
117+
#else
118+
#define BACKEND_DML ""
119+
#endif
120+
121+
#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN BACKEND_DML
116122
#include "core/session/onnxruntime_cxx_api.h"
117123
#include "core/providers/providers.h"
118124
#include "core/providers/cpu/cpu_execution_provider.h"
@@ -152,6 +158,9 @@ std::string nuphar_settings;
152158
#ifdef USE_ARMNN
153159
#include "core/providers/armnn/armnn_provider_factory.h"
154160
#endif
161+
#ifdef USE_DML
162+
#include "core/providers/dml/dml_provider_factory.h"
163+
#endif
155164

156165
#define PYBIND_UNREFERENCED_PARAMETER(parameter) ((void)(parameter))
157166

@@ -169,6 +178,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar
169178
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_VITISAI(const char* backend_type, int device_id);
170179
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ACL(int use_arena);
171180
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ArmNN(int use_arena);
181+
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(int device_id);
172182
} // namespace onnxruntime
173183

174184
#if defined(_MSC_VER)
@@ -486,6 +496,10 @@ void RegisterExecutionProviders(InferenceSession* sess, const std::vector<std::s
486496
} else if (type == kArmNNExecutionProvider) {
487497
#ifdef USE_ARMNN
488498
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_ArmNN(sess->GetSessionOptions().enable_cpu_mem_arena));
499+
#endif
500+
} else if (type == kDmlExecutionProvider) {
501+
#ifdef USE_DML
502+
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_DML(0));
489503
#endif
490504
} else {
491505
// unknown provider
@@ -692,6 +706,9 @@ void addGlobalMethods(py::module& m, const Environment& env) {
692706
#endif
693707
#ifdef USE_ARMNN
694708
onnxruntime::CreateExecutionProviderFactory_ArmNN(0)
709+
#endif
710+
#ifdef USE_DML
711+
onnxruntime::CreateExecutionProviderFactory_DML(0)
695712
#endif
696713
};
697714

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@
6565
elif '--use_armnn' in sys.argv:
6666
package_name = 'onnxruntime-armnn'
6767
sys.argv.remove('--use_armnn')
68+
elif '--use_dml' in sys.argv:
69+
package_name = 'onnxruntime-dml'
70+
sys.argv.remove('--use_dml')
6871

6972
# PEP 513 defined manylinux1_x86_64 and manylinux1_i686
7073
# PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686

tools/ci_build/build.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1337,7 +1337,7 @@ def run_nodejs_tests(nodejs_binding_dir):
13371337

13381338
def build_python_wheel(
13391339
source_dir, build_dir, configs, use_cuda, use_ngraph, use_dnnl,
1340-
use_tensorrt, use_openvino, use_nuphar, use_vitisai, use_acl, use_armnn,
1340+
use_tensorrt, use_openvino, use_nuphar, use_vitisai, use_acl, use_armnn, use_dml,
13411341
wheel_name_suffix, nightly_build=False, featurizers_build=False, use_ninja=False):
13421342
for config in configs:
13431343
cwd = get_config_build_dir(build_dir, config)
@@ -1384,6 +1384,8 @@ def build_python_wheel(
13841384
args.append('--use_acl')
13851385
elif use_armnn:
13861386
args.append('--use_armnn')
1387+
elif use_dml:
1388+
args.append('--use_dml')
13871389

13881390
run_subprocess(args, cwd=cwd)
13891391

@@ -1675,6 +1677,7 @@ def main():
16751677
args.use_vitisai,
16761678
args.use_acl,
16771679
args.use_armnn,
1680+
args.use_dml,
16781681
args.wheel_name_suffix,
16791682
nightly_build=nightly_build,
16801683
featurizers_build=args.use_featurizers,

0 commit comments

Comments
 (0)