@@ -12,6 +12,7 @@ function(op_library TARGET)
1212 set (hip_cc_srcs)
1313 set (xpu_cc_srcs)
1414 set (npu_cc_srcs)
15+ set (mlu_cc_srcs)
1516 set (cudnn_cu_cc_srcs)
1617 set (miopen_cu_cc_srcs)
1718 set (cudnn_cu_srcs)
@@ -24,6 +25,10 @@ function(op_library TARGET)
2425 if (WITH_ASCEND_CL)
2526 set (op_common_deps ${op_common_deps} npu_op_runner)
2627 endif ()
28+ if (WITH_MLU)
29+ set (op_common_deps ${op_common_deps} mlu_baseop)
30+ endif ()
31+
2732 # Option `UNITY` is used to specify that operator `TARGET` will compiles with Unity Build.
2833 set (options UNITY)
2934 set (oneValueArgs "" )
@@ -98,6 +103,12 @@ function(op_library TARGET)
98103 list (APPEND npu_cc_srcs ${NPU_FILE} .cc)
99104 endif ()
100105 endif ()
106+ if (WITH_MLU)
107+ string (REPLACE "_op" "_op_mlu" MLU_FILE "${TARGET} " )
108+ if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /${MLU_FILE} .cc)
109+ list (APPEND mlu_cc_srcs ${MLU_FILE} .cc)
110+ endif ()
111+ endif ()
101112 else ()
102113 foreach (src ${op_library_SRCS} )
103114 if (WITH_ROCM AND ${src} MATCHES ".*_cudnn_op.cu$" )
@@ -122,6 +133,8 @@ function(op_library TARGET)
122133 list (APPEND xpu_cc_srcs ${src} )
123134 elseif (WITH_ASCEND_CL AND ${src} MATCHES ".*_op_npu.cc$" )
124135 list (APPEND npu_cc_srcs ${src} )
136+ elseif (WITH_MLU AND ${src} MATCHES ".*_op_mlu.cc$" )
137+ list (APPEND mlu_cc_srcs ${src} )
125138 elseif (${src} MATCHES ".*\\ .cc$" )
126139 list (APPEND cc_srcs ${src} )
127140 else ()
@@ -196,7 +209,7 @@ function(op_library TARGET)
196209 # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
197210 if (WITH_UNITY_BUILD AND op_library_UNITY)
198211 # Combine the cc source files.
199- compose_unity_target_sources (${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} )
212+ compose_unity_target_sources (${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} ${mlu_cc_srcs} )
200213 if (TARGET ${UNITY_TARGET} )
201214 # If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`.
202215 target_sources (${UNITY_TARGET} PRIVATE ${unity_target_cc_sources} )
@@ -207,7 +220,7 @@ function(op_library TARGET)
207220 # Add alias library to handle dependencies.
208221 add_library (${TARGET} ALIAS ${UNITY_TARGET} )
209222 else ()
210- cc_library (${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} DEPS ${op_library_DEPS}
223+ cc_library (${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} ${mlu_cc_srcs} DEPS ${op_library_DEPS}
211224 ${op_common_deps} )
212225 endif ()
213226 endif ()
@@ -262,8 +275,10 @@ function(op_library TARGET)
262275 list (LENGTH xpu_cc_srcs xpu_cc_srcs_len)
263276 list (LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len)
264277 list (LENGTH npu_cc_srcs npu_cc_srcs_len)
278+ list (LENGTH mlu_cc_srcs mlu_cc_srcs_len)
265279 if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
266- ${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0 AND ${npu_cc_srcs_len} EQUAL 0)
280+ ${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0 AND
281+ ${npu_cc_srcs_len} EQUAL 0 AND ${mlu_cc_srcs_len} EQUAL 0)
267282 file (APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET} );\n " )
268283 set (pybind_flag 1)
269284 endif ()
@@ -322,6 +337,24 @@ function(op_library TARGET)
322337 endif ()
323338 file (APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${NPU_TARGET} , NPU);\n " )
324339 endif ()
340+ if (WITH_MLU AND ${mlu_cc_srcs_len} GREATER 0)
341+ file (READ ${ORIGINAL_TARGET} _mlu.cc TARGET_MLU_CONTENT )
342+ # It is different from the logic above, becareful
343+ string (REGEX MATCH "REGISTER_OP_MLU_KERNEL\\ (.*" multi_mlu_register "${TARGET_MLU_CONTENT} " )
344+ # [ \t\r\n]* is used for blank characters
345+ string (REGEX MATCH "REGISTER_OP_MLU_KERNEL\\ ([ \t\r\n ]*[a-z0-9_]*," one_mlu_register "${multi_mlu_register} " )
346+
347+ if (one_mlu_register STREQUAL "" )
348+ string (REPLACE "_op" "" MLU_TARGET "${TARGET} " )
349+ else ()
350+ string (REPLACE "REGISTER_OP_MLU_KERNEL(" "" MLU_TARGET "${one_mlu_register} " )
351+ string (REPLACE "," "" MLU_TARGET "${MLU_TARGET} " )
352+ # [ \t\r\n]+ is used for blank characters.
353+ # Here we use '+' instead of '*' since it is a REPLACE operation.
354+ string (REGEX REPLACE "[ \t\r\n ]+" "" MLU_TARGET "${MLU_TARGET} " )
355+ endif ()
356+ file (APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${MLU_TARGET} , MLU);\n " )
357+ endif ()
325358
326359 # pybind USE_OP_DEVICE_KERNEL for MKLDNN
327360 if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
@@ -369,11 +402,11 @@ function(register_operators)
369402 set (multiValueArgs EXCLUDES DEPS)
370403 cmake_parse_arguments (register_operators "${options} " "${oneValueArgs} "
371404 "${multiValueArgs} " ${ARGN} )
372-
373405 file (GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR} " "*_op.cc" )
374406 string (REPLACE "_mkldnn" "" OPS "${OPS} " )
375407 string (REPLACE "_xpu" "" OPS "${OPS} " )
376408 string (REPLACE "_npu" "" OPS "${OPS} " )
409+ string (REPLACE "_mlu" "" OPS "${OPS} " )
377410 string (REPLACE ".cc" "" OPS "${OPS} " )
378411 list (REMOVE_DUPLICATES OPS)
379412 list (LENGTH register_operators_DEPS register_operators_DEPS_len)
0 commit comments