@@ -100,7 +100,6 @@ function(kernel_library TARGET)
100100 set (xpu_srcs)
101101 set (gpudnn_srcs)
102102 set (kps_srcs)
103- set (selected_rows_srcs)
104103 # parse and save the deps kerenl targets
105104 set (all_srcs)
106105 set (kernel_deps)
@@ -111,6 +110,12 @@ function(kernel_library TARGET)
111110
112111 cmake_parse_arguments (kernel_library "${options} " "${oneValueArgs} "
113112 "${multiValueArgs} " ${ARGN} )
113+
114+ # used for cc_library selected_rows dir target
115+ set (target_suffix "" )
116+ if ("${kernel_library_SUB_DIR} " STREQUAL "selected_rows_kernel" )
117+ set (target_suffix "_sr" )
118+ endif ()
114119
115120 list (LENGTH kernel_library_SRCS kernel_library_SRCS_len)
116121 # one kernel only match one impl file in each backend
@@ -121,9 +126,6 @@ function(kernel_library TARGET)
121126 if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /cpu/${TARGET} .cc AND NOT WITH_XPU_KP)
122127 list (APPEND cpu_srcs ${CMAKE_CURRENT_SOURCE_DIR} /cpu/${TARGET} .cc)
123128 endif ()
124- if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /selected_rows/${TARGET} .cc)
125- list (APPEND selected_rows_srcs ${CMAKE_CURRENT_SOURCE_DIR} /selected_rows/${TARGET} .cc)
126- endif ()
127129 if (WITH_GPU OR WITH_ROCM)
128130 if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /gpu/${TARGET} .cu)
129131 list (APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR} /gpu/${TARGET} .cu)
@@ -169,119 +171,119 @@ function(kernel_library TARGET)
169171 list (APPEND all_srcs ${xpu_srcs} )
170172 list (APPEND all_srcs ${gpudnn_srcs} )
171173 list (APPEND all_srcs ${kps_srcs} )
174+
175+ set (all_include_kernels)
176+ set (all_kernel_name)
177+
172178 foreach (src ${all_srcs} )
173179 file (READ ${src} target_content )
180+ # "kernels/xxx"(DenseTensor Kernel) can only include each other, but can't include "SUB_DIR/xxx" (such as selected_rows Kernel)
174181 string (REGEX MATCHALL "#include \" paddle\/ phi\/ kernels\/ [a-z0-9_]+_kernel.h\" " include_kernels ${target_content} )
175- if ("${kernel_library_SUB_DIR} " STREQUAL "" )
176- string (REGEX MATCHALL "#include \" paddle\/ phi\/ kernels\/ [a-z0-9_]+_kernel.h\" " include_kernels ${target_content} )
177- else ()
182+ list (APPEND all_include_kernels ${include_kernels} )
183+
184+ # "SUB_DIR/xxx" can include "kernels/xx" and "SUB_DIR/xxx"
185+ if (NOT "${kernel_library_SUB_DIR} " STREQUAL "" )
178186 string (REGEX MATCHALL "#include \" paddle\/ phi\/ kernels\/ ${kernel_library_SUB_DIR} \/ [a-z0-9_]+_kernel.h\" " include_kernels ${target_content} )
187+ list (APPEND all_include_kernels ${include_kernels} )
179188 endif ()
180- foreach (include_kernel ${include_kernels} )
189+
190+ foreach (include_kernel ${all_include_kernels} )
181191 if ("${kernel_library_SUB_DIR} " STREQUAL "" )
182192 string (REGEX REPLACE "#include \" paddle\/ phi\/ kernels\/ " "" kernel_name ${include_kernel} )
193+ string (REGEX REPLACE ".h\" " "" kernel_name ${kernel_name} )
194+ list (APPEND all_kernel_name ${kernel_name} )
183195 else ()
184- string (REGEX REPLACE "#include \" paddle\/ phi\/ kernels\/ ${kernel_library_SUB_DIR} \/ " "" kernel_name ${include_kernel} )
196+ # NOTE(dev): we should firstly match kernel_library_SUB_DIR.
197+ if (${include_kernel} MATCHES "#include \" paddle\/ phi\/ kernels\/ ${kernel_library_SUB_DIR} \/ " )
198+ string (REGEX REPLACE "#include \" paddle\/ phi\/ kernels\/ ${kernel_library_SUB_DIR} \/ " "" kernel_name ${include_kernel} )
199+ # for selected_rows directory, add ${target_suffix}.
200+ string (REGEX REPLACE ".h\" " "${target_suffix} " kernel_name ${kernel_name} )
201+ list (APPEND all_kernel_name ${kernel_name} )
202+ else ()
203+ string (REGEX REPLACE "#include \" paddle\/ phi\/ kernels\/ " "" kernel_name ${include_kernel} )
204+ string (REGEX REPLACE ".h\" " "" kernel_name ${kernel_name} )
205+ list (APPEND all_kernel_name ${kernel_name} )
206+ endif ()
185207 endif ()
186- string (REGEX REPLACE ".h\" " "" kernel_name ${kernel_name} )
187- list (APPEND kernel_deps ${kernel_name} )
208+ list (APPEND kernel_deps ${all_kernel_name} )
188209 endforeach ()
189210 endforeach ()
190211 list (REMOVE_DUPLICATES kernel_deps)
191- list (REMOVE_ITEM kernel_deps ${TARGET} )
212+ list (REMOVE_ITEM kernel_deps ${TARGET}${target_suffix} )
192213
193214 list (LENGTH common_srcs common_srcs_len)
194215 list (LENGTH cpu_srcs cpu_srcs_len)
195216 list (LENGTH gpu_srcs gpu_srcs_len)
196217 list (LENGTH xpu_srcs xpu_srcs_len)
197218 list (LENGTH gpudnn_srcs gpudnn_srcs_len)
198219 list (LENGTH kps_srcs kps_srcs_len)
199- list (LENGTH selected_rows_srcs selected_rows_srcs_len)
200220
201221 # kernel source file level
202222 # level 1: base device kernel
203223 # - cpu_srcs / gpu_srcs / xpu_srcs / gpudnn_srcs / kps_srcs
204224 # level 2: device-independent kernel
205225 # - common_srcs
206- # level 3: Kernel implemented by reusing device-independent kernel
207- # - selected_rows_srcs
208226 set (base_device_kernels)
209227 set (device_independent_kernel)
210- set (high_level_kernels)
211228
212229 # 1. Base device kernel compile
213230 if (${cpu_srcs_len} GREATER 0)
214- cc_library (${TARGET} _cpu SRCS ${cpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
215- list (APPEND base_device_kernels ${TARGET} _cpu)
231+ cc_library (${TARGET} _cpu${target_suffix} SRCS ${cpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
232+ list (APPEND base_device_kernels ${TARGET} _cpu${target_suffix} )
216233 endif ()
217234 if (${gpu_srcs_len} GREATER 0)
218235 if (WITH_GPU)
219- nv_library (${TARGET} _gpu SRCS ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
236+ nv_library (${TARGET} _gpu${target_suffix} SRCS ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
220237 elseif (WITH_ROCM)
221- hip_library (${TARGET} _gpu SRCS ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
238+ hip_library (${TARGET} _gpu${target_suffix} SRCS ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
222239 endif ()
223- list (APPEND base_device_kernels ${TARGET} _gpu)
240+ list (APPEND base_device_kernels ${TARGET} _gpu${target_suffix} )
224241 endif ()
225242 if (${xpu_srcs_len} GREATER 0)
226- cc_library (${TARGET} _xpu SRCS ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
227- list (APPEND base_device_kernels ${TARGET} _xpu)
243+ cc_library (${TARGET} _xpu${target_suffix} SRCS ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
244+ list (APPEND base_device_kernels ${TARGET} _xpu${target_suffix} )
228245 endif ()
229246 if (${gpudnn_srcs_len} GREATER 0)
230247 if (WITH_GPU)
231- nv_library (${TARGET} _gpudnn SRCS ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
248+ nv_library (${TARGET} _gpudnn${target_suffix} SRCS ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
232249 elseif (WITH_ROCM)
233- hip_library (${TARGET} _gpudnn SRCS ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
250+ hip_library (${TARGET} _gpudnn${target_suffix} SRCS ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
234251 endif ()
235- list (APPEND base_device_kernels ${TARGET} _gpudnn)
252+ list (APPEND base_device_kernels ${TARGET} _gpudnn${target_suffix} )
236253 endif ()
237254 if (${kps_srcs_len} GREATER 0)
238255 # only when WITH_XPU_KP, the kps_srcs_len can be > 0
239- xpu_library (${TARGET} _kps SRCS ${kps_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
240- list (APPEND base_device_kernels ${TARGET} _kps)
256+ xpu_library (${TARGET} _kps${target_suffix} SRCS ${kps_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
257+ list (APPEND base_device_kernels ${TARGET} _kps${target_suffix} )
241258 endif ()
242259
243260 # 2. Device-independent kernel compile
244261 if (${common_srcs_len} GREATER 0)
245262 if (WITH_GPU)
246- nv_library (${TARGET} _common SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} )
263+ nv_library (${TARGET} _common${target_suffix} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} )
247264 elseif (WITH_ROCM)
248- hip_library (${TARGET} _common SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} )
265+ hip_library (${TARGET} _common${target_suffix} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} )
249266 elseif (WITH_XPU_KP)
250- xpu_library (${TARGET} _common SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} )
267+ xpu_library (${TARGET} _common${target_suffix} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} )
251268 else ()
252- cc_library (${TARGET} _common SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} )
269+ cc_library (${TARGET} _common${target_suffix} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} )
253270 endif ()
254- list (APPEND device_independent_kernel ${TARGET} _common)
271+ list (APPEND device_independent_kernel ${TARGET} _common${target_suffix} )
255272 endif ()
256273
257- # 3. Reusing kernel compile
258- if (${selected_rows_srcs_len} GREATER 0)
259- if (WITH_GPU)
260- nv_library (${TARGET} _sr SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel} )
261- elseif (WITH_ROCM)
262- hip_library (${TARGET} _sr SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel} )
263- elseif (WITH_XPU_KP)
264- xpu_library (${TARGET} _sr SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel} )
265- else ()
266- cc_library (${TARGET} _sr SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel} )
267- endif ()
268- list (APPEND high_level_kernels ${TARGET} _sr)
269- endif ()
270274
271- # 4 . Unify target compile
275+ # 3 . Unify target compile
272276 list (LENGTH base_device_kernels base_device_kernels_len)
273277 list (LENGTH device_independent_kernel device_independent_kernel_len)
274- list (LENGTH high_level_kernels high_level_kernels_len)
275- if (${base_device_kernels_len} GREATER 0 OR ${device_independent_kernel_len} GREATER 0 OR
276- ${high_level_kernels_len} GREATER 0)
278+ if (${base_device_kernels_len} GREATER 0 OR ${device_independent_kernel_len} GREATER 0)
277279 if (WITH_GPU)
278- nv_library (${TARGET} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel} ${high_level_kernels } )
280+ nv_library (${TARGET}${target_suffix} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel} )
279281 elseif (WITH_ROCM)
280- hip_library (${TARGET} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel} ${high_level_kernels } )
282+ hip_library (${TARGET}${target_suffix} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel} )
281283 elseif (WITH_XPU_KP)
282- xpu_library (${TARGET} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel} ${high_level_kernels } )
284+ xpu_library (${TARGET}${target_suffix} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel} )
283285 else ()
284- cc_library (${TARGET} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel} ${high_level_kernels } )
286+ cc_library (${TARGET}${target_suffix} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel} )
285287 endif ()
286288 else ()
287289 set (target_build_flag 0)
@@ -290,10 +292,10 @@ function(kernel_library TARGET)
290292 if (${target_build_flag} EQUAL 1)
291293 if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
292294 ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0 OR
293- ${gpudnn_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0 )
295+ ${gpudnn_srcs_len} GREATER 0)
294296 # append target into PHI_KERNELS property
295297 get_property (phi_kernels GLOBAL PROPERTY PHI_KERNELS )
296- set (phi_kernels ${phi_kernels} ${TARGET} )
298+ set (phi_kernels ${phi_kernels} ${TARGET}${target_suffix} )
297299 set_property (GLOBAL PROPERTY PHI_KERNELS ${phi_kernels} )
298300 endif ()
299301
@@ -318,9 +320,6 @@ function(kernel_library TARGET)
318320 if (${kps_srcs_len} GREATER 0)
319321 kernel_declare (${kps_srcs} )
320322 endif ()
321- if (${selected_rows_srcs_len} GREATER 0)
322- kernel_declare (${selected_rows_srcs} )
323- endif ()
324323 endif ()
325324endfunction ()
326325
0 commit comments