Skip to content
Merged
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
ALL_LAYOUT,
phi::TriangularSolveGradKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/triangular_solve_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,6 @@ PD_REGISTER_KERNEL(triangular_solve,
ALL_LAYOUT,
phi::TriangularSolveKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/blas/blas_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ struct CBlas<phi::dtype::complex<float>> {
const phi::dtype::complex<float> alpha,
const phi::dtype::complex<float> *A,
const int lda,
phi::dtype::complex<double> *B,
phi::dtype::complex<float> *B,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为什么改动了,有看过添加这个pr吗,为什么之前要用phi::dtype::complex&lt;double&gt;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数调用的是cblas中用于解方程组的单精度复数版本的 cblas_ctrsm,
同时下面也有调用cblas中用于解方程组的双精度复数版本的cblas_ztrsm,
所以说A,B的类型应该是一样的。个人感觉这里应该是笔误。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

您好,可以再帮忙review一下吗~

const int ldb) {
cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb);
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class MatrixReduceSumFunctor<T, CPUContext> {

template class MatrixReduceSumFunctor<float, CPUContext>;
template class MatrixReduceSumFunctor<double, CPUContext>;
template class MatrixReduceSumFunctor<phi::dtype::complex<float>, CPUContext>;
template class MatrixReduceSumFunctor<phi::dtype::complex<double>, CPUContext>;

} // namespace funcs
} // namespace phi
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class MatrixReduceSumFunctor<T, GPUContext> {

template class MatrixReduceSumFunctor<float, GPUContext>;
template class MatrixReduceSumFunctor<double, GPUContext>;
template class MatrixReduceSumFunctor<phi::dtype::complex<float>, GPUContext>;
template class MatrixReduceSumFunctor<phi::dtype::complex<double>, GPUContext>;

} // namespace funcs
} // namespace phi
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,6 @@ PD_REGISTER_KERNEL(triangular_solve_grad,
ALL_LAYOUT,
phi::TriangularSolveGradKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/triangular_solve_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,6 @@ PD_REGISTER_KERNEL(triangular_solve,
ALL_LAYOUT,
phi::TriangularSolveKernel,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
14 changes: 10 additions & 4 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3186,9 +3186,9 @@ def triangular_solve(

Args:
x (Tensor): The input triangular coefficient matrix. Its shape should be `[*, M, M]`, where `*` is zero or
more batch dimensions. Its data type should be float32 or float64.
more batch dimensions. Its data type should be float32, float64, complex64, complex128.
y (Tensor): Multiple right-hand sides of system of equations. Its shape should be `[*, M, K]`, where `*` is
zero or more batch dimensions. Its data type should be float32 or float64.
zero or more batch dimensions. Its data type should be float32, float64, complex64, complex128.
upper (bool, optional): Whether to solve the upper-triangular system of equations (default) or the lower-triangular
system of equations. Default: True.
transpose (bool, optional): whether `x` should be transposed before calculation. Default: False.
Expand Down Expand Up @@ -3227,10 +3227,16 @@ def triangular_solve(
inputs = {"X": [x], "Y": [y]}
helper = LayerHelper("triangular_solve", **locals())
check_variable_and_dtype(
x, 'x', ['float32', 'float64'], 'triangular_solve'
x,
'x',
['float32', 'float64', 'complex64', 'complex128'],
'triangular_solve',
)
check_variable_and_dtype(
y, 'y', ['float32', 'float64'], 'triangular_solve'
y,
'y',
['float32', 'float64', 'complex64', 'complex128'],
'triangular_solve',
)
out = helper.create_variable_for_type_inference(dtype=x.dtype)

Expand Down
Loading