Skip to content

Commit 4553b2e

Browse files
authored
Expose DirectML provider to python (conflicts resolved from #3359) (#4630)
1 parent c239ff0 commit 4553b2e

File tree

5 files changed

+37
-3
lines changed

5 files changed

+37
-3
lines changed

cmake/external/dml.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
2121
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
2222
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
2323
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/DirectML.3.0.0)
24+
set(DML_SHARED_LIB DirectML.dll)
2425

2526
# Restore nuget packages, which will pull down the DirectML redist package
2627
add_custom_command(

cmake/onnxruntime_python.cmake

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,15 @@ if (onnxruntime_USE_NUPHAR)
377377
)
378378
endif()
379379

380+
if (onnxruntime_USE_DML)
381+
add_custom_command(
382+
TARGET onnxruntime_pybind11_state POST_BUILD
383+
COMMAND ${CMAKE_COMMAND} -E copy
384+
${DML_PACKAGE_DIR}/bin/${onnxruntime_target_platform}/${DML_SHARED_LIB}
385+
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/capi/
386+
)
387+
endif()
388+
380389
if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS)
381390
include(onnxruntime_language_interop_ops.cmake)
382391
endif()

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,13 @@ struct OrtStatus {
119119
#define BACKEND_ARMNN ""
120120
#endif
121121

122-
#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN
122+
#if USE_DML
123+
#define BACKEND_DML "-DML"
124+
#else
125+
#define BACKEND_DML ""
126+
#endif
127+
128+
#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX BACKEND_ACL BACKEND_ARMNN BACKEND_DML
123129
#include "core/session/onnxruntime_cxx_api.h"
124130
#include "core/providers/providers.h"
125131
#include "core/providers/cpu/cpu_execution_provider.h"
@@ -159,6 +165,9 @@ std::string nuphar_settings;
159165
#ifdef USE_ARMNN
160166
#include "core/providers/armnn/armnn_provider_factory.h"
161167
#endif
168+
#ifdef USE_DML
169+
#include "core/providers/dml/dml_provider_factory.h"
170+
#endif
162171

163172
#define PYBIND_UNREFERENCED_PARAMETER(parameter) ((void)(parameter))
164173

@@ -176,6 +185,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar
176185
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_VITISAI(const char* backend_type, int device_id);
177186
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ACL(int use_arena);
178187
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ArmNN(int use_arena);
188+
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(int device_id);
179189
} // namespace onnxruntime
180190

181191
#if defined(_MSC_VER)
@@ -374,7 +384,7 @@ const std::vector<std::string>& GetAllProviders() {
374384
static std::vector<std::string> all_providers = {kTensorrtExecutionProvider, kCudaExecutionProvider, kMIGraphXExecutionProvider,
375385
kNGraphExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider,
376386
kNupharExecutionProvider, kVitisAIExecutionProvider, kArmNNExecutionProvider,
377-
kAclExecutionProvider, kCpuExecutionProvider};
387+
kAclExecutionProvider, kDmlExecutionProvider, kCpuExecutionProvider};
378388
return all_providers;
379389
}
380390

@@ -572,6 +582,9 @@ void RegisterExecutionProviders(InferenceSession* sess, const std::vector<std::s
572582
sess, *onnxruntime::CreateExecutionProviderFactory_ArmNN(sess->GetSessionOptions().enable_cpu_mem_arena));
573583
#endif
574584
} else if (type == kDmlExecutionProvider) {
585+
#ifdef USE_DML
586+
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_DML(0));
587+
#endif
575588
} else {
576589
// unknown provider
577590
throw std::runtime_error("Unknown Provider Type: " + type);
@@ -721,6 +734,9 @@ void addGlobalMethods(py::module& m, const Environment& env) {
721734
#endif
722735
#ifdef USE_ARMNN
723736
onnxruntime::CreateExecutionProviderFactory_ArmNN(0)
737+
#endif
738+
#ifdef USE_DML
739+
onnxruntime::CreateExecutionProviderFactory_DML(0)
724740
#endif
725741
};
726742

setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@
6565
elif '--use_armnn' in sys.argv:
6666
package_name = 'onnxruntime-armnn'
6767
sys.argv.remove('--use_armnn')
68+
elif '--use_dml' in sys.argv:
69+
package_name = 'onnxruntime-dml'
70+
sys.argv.remove('--use_dml')
6871

6972
# PEP 513 defined manylinux1_x86_64 and manylinux1_i686
7073
# PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686
@@ -188,6 +191,8 @@ def run(self):
188191
libs.extend(['onnxruntime_providers_tensorrt.dll'])
189192
# nGraph Libs
190193
libs.extend(['ngraph.dll', 'cpu_backend.dll', 'tbb.dll', 'mimalloc-override.dll', 'mimalloc-redirect.dll', 'mimalloc-redirect32.dll'])
194+
# DirectML Libs
195+
libs.extend(['directml.dll'])
191196
# Nuphar Libs
192197
libs.extend(['tvm.dll'])
193198
if nightly_build:

tools/ci_build/build.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1353,7 +1353,7 @@ def run_nodejs_tests(nodejs_binding_dir):
13531353

13541354
def build_python_wheel(
13551355
source_dir, build_dir, configs, use_cuda, use_ngraph, use_dnnl,
1356-
use_tensorrt, use_openvino, use_nuphar, use_vitisai, use_acl, use_armnn,
1356+
use_tensorrt, use_openvino, use_nuphar, use_vitisai, use_acl, use_armnn, use_dml,
13571357
wheel_name_suffix, enable_training, nightly_build=False, featurizers_build=False, use_ninja=False):
13581358
for config in configs:
13591359
cwd = get_config_build_dir(build_dir, config)
@@ -1402,6 +1402,8 @@ def build_python_wheel(
14021402
args.append('--use_acl')
14031403
elif use_armnn:
14041404
args.append('--use_armnn')
1405+
elif use_dml:
1406+
args.append('--use_dml')
14051407

14061408
run_subprocess(args, cwd=cwd)
14071409

@@ -1794,6 +1796,7 @@ def main():
17941796
args.use_vitisai,
17951797
args.use_acl,
17961798
args.use_armnn,
1799+
args.use_dml,
17971800
args.wheel_name_suffix,
17981801
args.enable_training,
17991802
nightly_build=nightly_build,

0 commit comments

Comments
 (0)