Changes from all commits
211 commits
fd28881
[Metax_change_ut]
duqimeng Jul 23, 2025
a9d2aa7
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Jul 24, 2025
1695f36
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Jul 31, 2025
b931d38
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 1, 2025
bef21bf
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 8, 2025
f4e5004
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 13, 2025
55422eb
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 18, 2025
815a63a
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 19, 2025
1739a15
fix sum&collect_fpn_proposals op register
StareAtYou Aug 19, 2025
af0bae5
fix sum&collect_fpn_proposals op register
metax666 Aug 19, 2025
be61f06
modify profile
jxwangmetax Aug 20, 2025
0fc2dd1
modify profile
metax666 Aug 20, 2025
1ad95c5
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 20, 2025
f12b3e4
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 21, 2025
789c9fc
[Metax] fix paddle bug replace 'MoeGradDispatchKernel' to 'MoeGateDis…
StareAtYou Aug 21, 2025
a0116fb
[Metax] fix paddle bug
metax666 Aug 21, 2025
a2da5e0
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 22, 2025
f9e6d2c
[Metax] register bce_loss_grad & bce_loss & index_add_grad kernels
StareAtYou Aug 22, 2025
4b4f562
Merge branch 'develop' of https://github.com/duqimeng/PaddleCustomDev…
duqimeng Aug 22, 2025
662e22e
[Metax] con2d_grad use gpudnn
duqimeng Aug 22, 2025
3e8d6ce
Merge branch 'metax666:develop' into develop
StareAtYou Aug 25, 2025
9dae9b7
[Metax] register bce_loss_grad & bce_loss & index_add_grad kernels
metax666 Aug 25, 2025
47fef62
blas handle support
jxwangmetax Aug 25, 2025
266c0df
blas handle support
metax666 Aug 25, 2025
a0b340b
[Metax] register some kernels & update CMakeLists
StareAtYou Aug 25, 2025
aa9bd35
Merge branch 'metax666:develop' into develop
StareAtYou Aug 26, 2025
8c6ac05
[Metax] register some kernels & update CMakeLists
metax666 Aug 26, 2025
9510f7d
Merge branch 'metax666:develop' into develop
duqimeng Aug 26, 2025
fa7cc1a
[Metax] fix metax unittest fail
StareAtYou Aug 26, 2025
a907545
[Metax] fix metax unittest fail
metax666 Aug 26, 2025
7a6312e
[Metax] add group_norm & label_smooth kernel and update matmul kernel
StareAtYou Aug 26, 2025
90bb94e
[Metax] add group_norm & label_smooth kernel and update matmul kernel
metax666 Aug 27, 2025
9f130fe
[Metax] fix rmsprop kernel register and add meshgrid & meshgrid_grad …
StareAtYou Aug 27, 2025
ca38fb5
Merge branch 'metax666:develop' into develop
StareAtYou Aug 27, 2025
f0cc1e0
add test
zhang-chenyi Aug 27, 2025
8e8b732
add test
zhang-chenyi Aug 27, 2025
8d7efbd
Merge branch 'metax666:develop' into develop
zhang-chenyi Aug 27, 2025
28c992b
Merge branch 'develop' of https://github.com/zhang-chenyi/PaddleCusto…
zhang-chenyi Aug 27, 2025
d3470bb
[test] chang the logic of workspace_host in cholesky_kernel_register
zhang-chenyi Aug 27, 2025
db17ebf
Merge branch 'develop' of https://github.com/zhang-chenyi/PaddleCusto…
zhang-chenyi Aug 27, 2025
83bc87f
[Metax] fix compile fail
StareAtYou Aug 27, 2025
f1e8d0c
Revert "[Metax] fix compile fail"
StareAtYou Aug 27, 2025
a13daa8
[Metax] fix compile fail by 'conv_transpose_grad_kernel_impl.h'
StareAtYou Aug 27, 2025
95a179b
[Metax] fix bug & add some kernel register
metax666 Aug 28, 2025
4576ef4
[Metax]fix bug and add qr lstsq logsoftmax
duqimeng Aug 28, 2025
ca51a1e
Merge branch 'metax666:develop' into develop
duqimeng Aug 28, 2025
7789e9b
[Metax] con2d_grad use gpudnn
duqimeng Aug 22, 2025
afd0863
[Metax]fix bug and add qr lstsq logsoftmax
duqimeng Aug 28, 2025
6da0f0d
Merge branch 'metax666:develop' into develop
duqimeng Aug 28, 2025
e1e07ba
[Metax] change_patch
duqimeng Aug 28, 2025
046637c
[Metax] change_patch
metax666 Aug 28, 2025
c27b492
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 28, 2025
05ecd9d
[Metax] update unit test CMakeLists.txt
StareAtYou Aug 28, 2025
b1bf7e8
[Metax] update unit test CMakeLists.txt
StareAtYou Aug 28, 2025
f90d585
Merge branch 'metax666:develop' into develop
StareAtYou Aug 28, 2025
874d9b6
Merge branch 'metax666:develop' into develop
zhang-chenyi Aug 28, 2025
0ca02b9
[feature] add unique_consecutive kernel
zhang-chenyi Aug 28, 2025
40d8f21
[metax-feature] add kernel for test_math_op_patch_var_base
metax666 Aug 28, 2025
3e9b526
[metax] add some kernel
duqimeng Aug 28, 2025
8911576
[metax] add some kernel
duqimeng Aug 28, 2025
8471597
Merge branch 'metax666:develop' into develop
duqimeng Aug 28, 2025
0758887
Merge branch 'metax666:develop' into develop
StareAtYou Aug 29, 2025
61be33d
[Metax] register baddbmm kernel & update blas api
StareAtYou Aug 29, 2025
2fe962e
[Metax] register baddbmm kernel & update blas api
StareAtYou Aug 29, 2025
531fedb
Merge branch 'metax666:develop' into develop
StareAtYou Aug 29, 2025
c0dcfff
[Metax] register deformable_conv kernel & fix 'ModulatedDeformableCol…
StareAtYou Aug 29, 2025
bd65451
[feature] add add unique_consecutive kernel.cu
zhang-chenyi Aug 29, 2025
0def63d
[fix] fix some test case due to missing op register
zhang-chenyi Aug 29, 2025
e503c9e
[fix] fix some fail text
zhang-chenyi Aug 29, 2025
9844878
[metax]fix lu eigvalshsqueeze rnn kernel
duqimeng Aug 29, 2025
70b86e7
[metax]fix lu eigvalshsqueeze rnn kernel
duqimeng Aug 29, 2025
1e90757
add and fix some kernels
1184319564 Aug 30, 2025
f93307d
[Metax] register deformable_conv kernel & fix 'ModulatedDeformableCol…
StareAtYou Aug 29, 2025
c4b0eb9
[Metax] fix conflict
StareAtYou Sep 1, 2025
06dda18
[Metax] fix conflict
StareAtYou Sep 1, 2025
dae6ce8
[Metax] adapt to paddle-cpu-20250901 & resolve the issue of 'test_ele…
StareAtYou Sep 1, 2025
b4a5c62
[Metax] update repeat_interleave kernel & ignore max op test
StareAtYou Sep 2, 2025
7cf4405
Merge branch 'metax666:develop' into develop
StareAtYou Sep 2, 2025
0015f2e
[Metax] register deformable_conv kernel & fix 'ModulatedDeformableCol…
metax666 Sep 2, 2025
fc2c0f5
Merge branch 'metax666:develop' into develop
duqimeng Sep 2, 2025
829c3b6
Merge dev
duqimeng Sep 2, 2025
3104a9c
【metax】add and fix some kernels
metax666 Sep 2, 2025
175cca6
[metax]fix lu eigvalshsqueeze rnn kernel
metax666 Sep 2, 2025
c7db810
[metax]fix lu eigvalshsqueeze rnn kernel
duqimeng Aug 29, 2025
f5813ed
[metax] chang patch fix copy
duqimeng Sep 2, 2025
6f0b705
[metax] chang patch fix copy
duqimeng Sep 2, 2025
8f47f0e
[metax] chang patch fix copy
metax666 Sep 2, 2025
b420f97
[Metax] update metax_gpu unit test
StareAtYou Sep 2, 2025
c08533e
[Metax] update metax_gpu unit test
metax666 Sep 2, 2025
414715f
[Metax] fix test CMakeList.txt
StareAtYou Sep 2, 2025
aa6b5bf
[Metax] fix test CMakeList.txt
metax666 Sep 2, 2025
69f3721
[fix] fix fail test when backend is mack
zhang-chenyi Sep 4, 2025
e45d324
[Metax] fix fail test when backend is mack
metax666 Sep 4, 2025
ef9d554
Merge branch 'PaddlePaddle:develop' into develop
metax666 Sep 4, 2025
0bfc6e7
[metax]change_cupti_and_fix_softmax
duqimeng Sep 9, 2025
cb93f6a
[metax]change_cupti_and_fix_softmax
duqimeng Sep 9, 2025
2e99f62
[metax]change_patch
duqimeng Sep 9, 2025
026551a
[metax]change_patch
duqimeng Sep 9, 2025
a1530d2
[metax]change_cupti_and_fix_softmax (#7)
duqimeng Sep 9, 2025
352f02e
[Metax] fix dgc & mklml compile product path problem (#8)
StareAtYou Sep 9, 2025
b09babb
Merge branch 'metax666:develop' into develop
duqimeng Sep 9, 2025
8f13fae
[Metax] fix accuracy kernel & add test_accuracy_op_metax.py unit test…
StareAtYou Sep 11, 2025
8938293
[Metax] update metax_gpu CMakeLists.txt (#10)
StareAtYou Sep 11, 2025
31594f8
[metax] updata_qr_kernel
duqimeng Sep 11, 2025
4fb467c
[metax] updata_qr_kernel
duqimeng Sep 11, 2025
f54187f
[metax] updata_qr_kernel (#11)
duqimeng Sep 11, 2025
5dc60a3
Merge branch 'metax666:develop' into develop
duqimeng Sep 11, 2025
7964c35
Merge branch 'PaddlePaddle:develop' into develop
metax666 Sep 12, 2025
1e04216
[Metax] fix illegal address access error in test_momentum_op (#12)
StareAtYou Sep 15, 2025
e4fd192
Merge branch 'metax666:develop' into develop
duqimeng Sep 15, 2025
471b184
[Metax] fix cufft and fix some blas kernel apply
duqimeng Sep 15, 2025
aca80a4
[Metax] fix cufft and fix some blas kernel apply (#13)
duqimeng Sep 15, 2025
1c54010
Merge branch 'PaddlePaddle:develop' into develop
metax666 Sep 15, 2025
a0d237c
Merge branch 'metax666:develop' into develop
duqimeng Sep 15, 2025
4c86266
[metax] fix bug
duqimeng Sep 15, 2025
fb547db
[metax] add warpctc_warprnn (#14)
duqimeng Sep 15, 2025
8e98198
[Metax] update metax CI (#15)
StareAtYou Sep 15, 2025
528ec55
[Metax] update metax CI CMakeLists (#16)
StareAtYou Sep 16, 2025
a8b4696
[Metax] add github action
duqimeng Sep 16, 2025
5b31405
[Metax] add github action (#18)
duqimeng Sep 16, 2025
8dff471
[metax]chaneg build
duqimeng Sep 16, 2025
ee4eefd
[metax]chaneg build
duqimeng Sep 16, 2025
b93c971
[metax] chang build (#19)
duqimeng Sep 16, 2025
8a36c4c
[metax]chaneg build
duqimeng Sep 16, 2025
bd5ac4d
Merge branch 'develop' into develop
duqimeng Sep 16, 2025
656d684
[metax]chaneg build
duqimeng Sep 16, 2025
2c224ad
[metax]chaneg build
duqimeng Sep 16, 2025
4c65070
Merge branch 'develop' of https://github.com/duqimeng/PaddleCustomDev…
duqimeng Sep 16, 2025
6dbbe84
change_build (#20)
duqimeng Sep 16, 2025
a7f6ed7
[metax]chaneg build
duqimeng Sep 16, 2025
9bfec7e
Merge branch 'develop' into develop
metax666 Sep 16, 2025
ef1b28e
change_build (#21)
duqimeng Sep 16, 2025
00014e2
[metax]chaneg build
duqimeng Sep 16, 2025
25e76dc
Merge branch 'develop' of https://github.com/duqimeng/PaddleCustomDev…
duqimeng Sep 16, 2025
3737e48
change_build (#22)
duqimeng Sep 16, 2025
16f3584
【metax】modify cmake for warpctc and warprnnt (#17)
jxwangmetax Sep 16, 2025
e95cc2c
Merge branch 'metax666:develop' into develop
duqimeng Sep 16, 2025
ce54693
[metax]modify library to static library (#24)
jxwangmetax Sep 16, 2025
4cda637
[Metax] organize documents (#25)
StareAtYou Sep 16, 2025
a7f53dd
Merge branch 'metax666:develop' into develop
duqimeng Sep 16, 2025
6ada0e9
[metax]fix_code style and index_elementwise_put_kernel
duqimeng Sep 17, 2025
23fca59
[metax]fix_code style and index_elementwise_put_kernel (#27)
duqimeng Sep 17, 2025
a513aae
change_build_917 (#29)
duqimeng Sep 17, 2025
3834990
[metax]change_build
duqimeng Sep 17, 2025
4eb455e
chang_build (#30)
duqimeng Sep 17, 2025
77ebcb8
[metax]change_build
duqimeng Sep 17, 2025
19c9184
Merge branch 'develop' into develop
metax666 Sep 17, 2025
1773978
[metax]modify kernel (#31)
jxwangmetax Sep 17, 2025
4339ed4
Merge branch 'metax666:develop' into develop
duqimeng Sep 17, 2025
44532ba
change_metax_work
duqimeng Sep 17, 2025
69af381
change_metax_work (#32)
duqimeng Sep 17, 2025
02047f9
change_metax_work
duqimeng Sep 17, 2025
bda901e
change_metax_work
duqimeng Sep 17, 2025
7fe6f2d
change_build (#33)
duqimeng Sep 17, 2025
b22fc13
[metax] modify fused_bias_dropout_residual_layer_norm (#34)
jxwangmetax Sep 17, 2025
1c7d32a
change_metax_work
duqimeng Sep 17, 2025
ed8f128
Merge branch 'develop' into develop
metax666 Sep 17, 2025
c3d1444
change_build (#35)
duqimeng Sep 17, 2025
287691f
Merge branch 'metax666:develop' into develop
duqimeng Sep 17, 2025
569a867
change_build (#36)
duqimeng Sep 17, 2025
976ecec
change_metax_work
duqimeng Sep 17, 2025
0c6ebe2
change_warpctc.cmake
duqimeng Sep 18, 2025
0edc6f6
change_warpctc.cmake (#38)
duqimeng Sep 18, 2025
5e7a84b
change warpctc.cmake
duqimeng Sep 18, 2025
2688c86
change_warpctc.cmake (#39)
duqimeng Sep 18, 2025
6f031fe
test (#40)
duqimeng Sep 18, 2025
542efeb
test
duqimeng Sep 18, 2025
40daeb9
change_run_ut
duqimeng Sep 18, 2025
4c21a9c
Merge branch 'develop' into develop
metax666 Sep 18, 2025
e84d399
test_ut (#41)
duqimeng Sep 18, 2025
322dc15
remove_tets
duqimeng Sep 18, 2025
0e4b75d
Merge branch 'develop' of https://github.com/duqimeng/PaddleCustomDev…
duqimeng Sep 18, 2025
b5f2feb
tets (#43)
duqimeng Sep 18, 2025
bd106bd
Merge branch 'metax666:develop' into develop
duqimeng Sep 18, 2025
e20eca7
test (#44)
duqimeng Sep 18, 2025
7dbab02
test
duqimeng Sep 18, 2025
27ebafe
Merge branch 'develop' of https://github.com/duqimeng/PaddleCustomDev…
duqimeng Sep 18, 2025
e37f633
[metax] modify compile (#42)
jxwangmetax Sep 19, 2025
1af5148
[Metax] add log analysis script (#46)
StareAtYou Sep 19, 2025
bd39be9
Merge branch 'metax666:develop' into develop
duqimeng Sep 19, 2025
f79b1bd
add_generate_pb
duqimeng Sep 19, 2025
518bee8
add_generate_pb (#47)
duqimeng Sep 19, 2025
bc02549
modify blas (#51)
jxwangmetax Sep 22, 2025
1977ca8
[metax] modify tf32 (#52)
jxwangmetax Sep 22, 2025
f8a0cca
Merge branch 'metax666:develop' into develop
duqimeng Sep 22, 2025
1ae2618
[Metax] update metax backend CI test (#53)
StareAtYou Sep 22, 2025
37aa236
Merge branch 'metax666:develop' into develop
duqimeng Sep 22, 2025
76d5eb0
[Metax] fix log_analysis.py bug (#54)
StareAtYou Sep 23, 2025
6f925da
Merge branch 'metax666:develop' into develop
duqimeng Sep 23, 2025
9c17b6e
[Metax] update metax CI CMakeLists & scripts (#56)
StareAtYou Sep 23, 2025
e08b161
[metax]fix paddle bug
duqimeng Sep 23, 2025
51c98a2
[Metax] fix MatmulKernel problem (#57)
StareAtYou Sep 23, 2025
d113018
[metax]fix paddle bug" (#58)
duqimeng Sep 23, 2025
9404022
Merge branch 'metax666:develop' into develop
duqimeng Sep 23, 2025
1a0a84e
change_ut
duqimeng Sep 23, 2025
8991299
change—ut (#59)
duqimeng Sep 23, 2025
ece9f09
change_ut
duqimeng Sep 23, 2025
a770e6f
change_ut (#60)
duqimeng Sep 23, 2025
d1d25ad
change_ut
duqimeng Sep 24, 2025
902112b
change_ut (#63)
duqimeng Sep 24, 2025
9a88a09
Merge branch 'PaddlePaddle:develop' into develop
metax666 Sep 24, 2025
8ff82b6
Merge branch 'metax666:develop' into develop
duqimeng Sep 24, 2025
4ae65f7
Merge branch 'PaddlePaddle:develop' into develop
metax666 Sep 24, 2025
cfe44ce
[Metax] add keyword filter in CI CMakeLists.txt (#64)
StareAtYou Sep 25, 2025
78946fd
[metax] modify kernels (#67)
jxwangmetax Sep 26, 2025
ac78af2
Fix part of the missing kernel issues (#66)
Theendlessofhell Sep 26, 2025
bfdf3da
Merge branch 'metax666:develop' into develop
duqimeng Sep 26, 2025
4ce9fe6
[Metax] fix index_elementwise_get kernel (#68)
StareAtYou Sep 26, 2025
be4aeff
Merge branch 'metax666:develop' into develop
duqimeng Sep 26, 2025
d75ccc7
[metax]fix patch and fix missing kernel
duqimeng Sep 29, 2025
6ce9e13
Update Paddle submodule to latest develop
tianshuo78520a Sep 29, 2025
2 changes: 1 addition & 1 deletion Paddle
Submodule Paddle updated 477 files
5 changes: 4 additions & 1 deletion backends/metax_gpu/CMakeLists.txt
@@ -326,7 +326,7 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/im2sequence_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/im2sequence_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/increment_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu
@@ -535,6 +535,7 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/clip_by_norm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/uniform_random_batch_size_like_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/get_tensor_from_selected_rows_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/batch_norm_kernel.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/empty_kernel.cc
@@ -642,6 +643,8 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rms_norm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/lars_momentum_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/partial_sum_kernel.cu
# ############################################################################
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
# kernels/kps
@@ -0,0 +1,41 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
#include "paddle/phi/kernels/selected_rows/adam_kernel.h"

PD_CUSTOM_KERNEL_REGISTER(adam_dense_param_sparse_grad,
metax_gpu,
ALL_LAYOUT,
phi::sr::AdamDenseParamSparseGradKernel,
float,
double,
phi::float16) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND);

if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
}
kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED);
}
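For context, the new file above registers the selected-rows (sparse-gradient) Adam kernel for the metax_gpu backend. A minimal sketch of the kind of workload that would exercise this registration, assuming the plugin is installed and exposed as the custom device "metax_gpu" (the device string and script are illustrative, not part of this change):

import paddle

paddle.set_device("metax_gpu")  # assumed custom-device name registered by this plugin
emb = paddle.nn.Embedding(1000, 64, sparse=True)  # sparse=True produces SelectedRows gradients
ids = paddle.randint(0, 1000, shape=[32], dtype="int64")
loss = emb(ids).sum()
loss.backward()
opt = paddle.optimizer.Adam(parameters=emb.parameters())  # sparse grads dispatch to adam_dense_param_sparse_grad
opt.step()
opt.clear_grad()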
@@ -23,18 +23,18 @@ PD_CUSTOM_KERNEL_REGISTER(einsum,
phi::EinsumKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::float16,
phi::bfloat16,
phi::complex64,
phi::complex128) {}

PD_CUSTOM_KERNEL_REGISTER(einsum_infer,
metax_gpu,
ALL_LAYOUT,
phi::EinsumInferKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::float16,
phi::bfloat16,
phi::complex64,
phi::complex128) {}
@@ -13,7 +13,7 @@
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/index_elementwise_get_kernel.h"
#include "paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu" // NOLINT

PD_CUSTOM_KERNEL_REGISTER(index_elementwise_get,
metax_gpu,
@@ -27,7 +27,7 @@ PD_CUSTOM_KERNEL_REGISTER(index_elementwise_get,
int64_t,
int16_t,
uint8_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::float16,
phi::bfloat16,
phi::complex64,
phi::complex128) {}
@@ -0,0 +1,29 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/lars_momentum_kernel.h"

PD_CUSTOM_KERNEL_REGISTER(lars_momentum,
metax_gpu,
ALL_LAYOUT,
phi::LarsMomentumKernel,
float,
double,
phi::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
@@ -21,6 +21,7 @@ PD_CUSTOM_KERNEL_REGISTER(multinomial,
phi::MultinomialKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float) {
float,
double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
@@ -23,11 +23,13 @@ PD_CUSTOM_KERNEL_REGISTER(nonzero,
int64_t,
int,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::float16,
phi::bfloat16,
bool,
float,
double) {
double,
phi::complex64,
phi::complex128) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}

@@ -23,6 +23,8 @@ PD_CUSTOM_KERNEL_REGISTER(put_along_axis,
float,
double,
int64_t,
uint8_t,
int16_t,
int,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::float16,
phi::bfloat16) {}
@@ -25,4 +25,7 @@ PD_CUSTOM_KERNEL_REGISTER(take_along_axis,
int64_t,
int,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
uint8_t,  // support uint8
int16_t   // support int16
) {}
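The uint8/int16 registrations added above could be smoke-tested on the custom backend with a short script like the following (illustrative sketch only; the device string "metax_gpu" is assumed, not defined by this diff):

import paddle

paddle.set_device("metax_gpu")  # assumed custom-device name
x = paddle.arange(12, dtype="int32").reshape([3, 4]).astype("uint8")
idx = paddle.to_tensor([[0, 1, 2, 3], [3, 2, 1, 0], [0, 0, 0, 0]], dtype="int64")
out = paddle.take_along_axis(x, idx, axis=1)  # now covered by the uint8 registration on this backend
print(out.numpy())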
@@ -22,5 +22,6 @@ PD_REGISTER_PLUGIN_KERNEL(addmm,
ALL_LAYOUT,
phi::AddmmKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
@@ -115,6 +115,7 @@ PD_REGISTER_PLUGIN_KERNEL(layer_norm_grad,
ALL_LAYOUT,
phi::LayerNormGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
26 changes: 0 additions & 26 deletions backends/metax_gpu/kernels/metax_kernel/metax_context.cc
@@ -15,24 +15,6 @@
#include "kernels/metax_kernel/metax_context.h"

namespace phi {
const bool allow_tf32_cublas = []() -> bool {
const char* v = std::getenv("ALLOW_TF32_CUBLAS");
if (v) {
return std::atoi(v);
}
return true;
}();

const bool allow_tf32_cudnn = []() -> bool {
const char* v = std::getenv("ALLOW_TF32_CUDNN");
if (v) {
return std::atoi(v);
}
return false;
}();

bool AllowTF32Cublas() { return allow_tf32_cublas; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
void DnnWorkspaceHandle::RunFuncSync(
const std::function<void(void*)>& cudnn_func,
size_t required_workspace_bytes,
@@ -42,19 +24,11 @@ void DnnWorkspaceHandle::RunFuncSync(
void* workspace_ptr = nullptr;
size_t size = ((required_workspace_bytes + 255) >> 8) << 8;
std::lock_guard<std::mutex> guard(*mtx_);
#ifdef PADDLE_WITH_HIP
auto status = hipMalloc(&workspace_ptr, size);
#else
auto status = cudaMalloc(&workspace_ptr, size);
#endif
if (status == gpuSuccess) {
cudnn_func(workspace_ptr);
phi::backends::gpu::GpuStreamSync(stream_);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipFree(workspace_ptr));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaFree(workspace_ptr));
#endif
return;
}
}
3 changes: 1 addition & 2 deletions backends/metax_gpu/kernels/metax_kernel/metax_context.h
@@ -18,6 +18,7 @@
#include <mutex>

#include "kernels/funcs/blas/cublasLt.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/forwards.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
@@ -30,8 +31,6 @@
cublasLtHandle_t GetBlasLtHandle();

namespace phi {
bool AllowTF32Cublas();
bool AllowTF32Cudnn();
class DnnWorkspaceHandle {
public:
inline DnnWorkspaceHandle(Allocator* allocator, gpuStream_t stream)
Expand Down
65 changes: 0 additions & 65 deletions backends/metax_gpu/patch/paddle.patch
@@ -869,19 +869,6 @@

namespace phi {
namespace fusion {
diff --git a/paddle/phi/kernels/gpu/correlation_kernel.cu b/paddle/phi/kernels/gpu/correlation_kernel.cu
index 4c93778bde..c7bdf8a2cc 100644
--- a/paddle/phi/kernels/gpu/correlation_kernel.cu
+++ b/paddle/phi/kernels/gpu/correlation_kernel.cu
@@ -103,7 +103,7 @@ void CorrelationCUDAKernel(const Context &dev_ctx,
int stride2,
int corr_type_multiply,
DenseTensor *out) {
- bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU;
+ bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM;
PADDLE_ENFORCE_EQ(
is_gpu_place,
true,
diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h
index f0cca0f701..02ea957240 100644
--- a/paddle/phi/kernels/gpu/depthwise_conv.h
@@ -897,19 +884,6 @@

namespace phi {
// To determine use cudnn or not.
diff --git a/paddle/phi/kernels/gpu/dgc_kernel.cu b/paddle/phi/kernels/gpu/dgc_kernel.cu
index c2ddfa1347..c6adf5a6de 100644
--- a/paddle/phi/kernels/gpu/dgc_kernel.cu
+++ b/paddle/phi/kernels/gpu/dgc_kernel.cu
@@ -188,7 +188,7 @@ void DGCKernel(const Context& dev_ctx,
int buf_size = paddle::communication::dgc::get_buffer_size(k);
phi::Allocator::AllocationPtr tmp_ious_data;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
- if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
+ if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
tmp_ious_data = phi::memory_utils::Alloc(
dev_ctx.GetPlace(),
buf_size,
diff --git a/paddle/phi/kernels/gpu/gelu_funcs.h b/paddle/phi/kernels/gpu/gelu_funcs.h
index 29fa252e96..4ae72b0935 100644
--- a/paddle/phi/kernels/gpu/gelu_funcs.h
@@ -974,19 +948,6 @@
#include "paddle/phi/kernels/impl/qr_kernel_impl.h"
#include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h"
#include "paddle/phi/kernels/lstsq_kernel.h"
diff --git a/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu b/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu
index 05a977828f..5136608c41 100644
--- a/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu
+++ b/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu
@@ -58,7 +58,7 @@ void ShuffleBatchKernel(const Context& dev_ctx,
int64_t seed_int = 0;
if (seed.initialized()) {
const auto& seed_place = seed.place().GetType();
- bool is_gpu_place = seed_place == phi::AllocationType::GPU;
+ bool is_gpu_place = seed_place == phi::AllocationType::GPU || seed_place == phi::AllocationType::CUSTOM;
if (is_gpu_place) {
// NOTE: We have overwritten GetKernelTypeForVar, so seed_place would
// not be CUDAPlace in practice. This case would only happen in Python
diff --git a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h b/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
index 9bc5326c90..79b57a8203 100644
--- a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
@@ -1144,32 +1105,6 @@
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/matrix_inverse.h"

diff --git a/paddle/phi/kernels/impl/merged_momentum_impl.h b/paddle/phi/kernels/impl/merged_momentum_impl.h
index 7b85903776..3f4b298807 100644
--- a/paddle/phi/kernels/impl/merged_momentum_impl.h
+++ b/paddle/phi/kernels/impl/merged_momentum_impl.h
@@ -297,7 +297,7 @@ void MergedMomentumInnerCompute(
params_out[idx],
velocities_out[idx]);
VLOG(10) << "Launch MergedMomentum cpu kernel.";
- } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
+ } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
phi::funcs::ForRange<Context> for_range(
static_cast<const Context &>(dev_ctx), params[idx]->numel());
const auto grad_type = grads[idx]->dtype();
diff --git a/paddle/phi/kernels/impl/momentum_kernel_impl.h b/paddle/phi/kernels/impl/momentum_kernel_impl.h
index de5bcfc30b..eb2a9714f5 100644
--- a/paddle/phi/kernels/impl/momentum_kernel_impl.h
+++ b/paddle/phi/kernels/impl/momentum_kernel_impl.h
@@ -457,7 +457,7 @@ void MomentumDenseImpl(const Context& dev_ctx,
regularization_coeff,
param_out,
velocity_out);
- } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
+ } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
funcs::ForRange<Context> for_range(dev_ctx, param.numel());
const auto grad_type = grad.dtype();
#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \
diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
index 4099d8b506..baef2cd643 100644
--- a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h