Commit fc76beb

fix slogdet bigtensor (#73706)
1 parent 8585b79 commit fc76beb

6 files changed: 48 additions, 22 deletions

paddle/common/ddim.cc (21 additions, 0 deletions)

@@ -255,6 +255,27 @@ DDim DDim::reshape(std::vector<int>& shape) const {
   return common::make_ddim(shape);
 }
 
+DDim DDim::reshape(std::vector<int64_t>& shape) const {
+  const DDim& in_dims = *this;
+
+  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
+    if (shape[i] == 0) {
+      shape[i] = static_cast<int64_t>(in_dims.at(i));
+    }
+  }
+
+  // Dim marked as "-1" must be inferred
+  auto it = std::find(shape.begin(), shape.end(), -1);
+  if (it != shape.end()) {
+    int index = static_cast<int>(std::distance(shape.begin(), it));
+    int64_t reshape_out_product =
+        std::accumulate(shape.begin(), shape.end(), -1, std::multiplies<>());
+    shape[index] = static_cast<int64_t>(product(in_dims)) / reshape_out_product;
+  }
+
+  return common::make_ddim(shape);
+}
+
 DDim DDim::transpose(const std::vector<int>& axis) const {
   const DDim& in_dims = *this;
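As context for the new overload, here is a minimal usage sketch; it is not part of the commit, and the wrapper function and include path are assumptions. With an int64_t shape vector, a -1 entry can be inferred even when the resulting dimension no longer fits in a 32-bit int.

    // Hypothetical sketch: flatten a tensor with more than 2^31 elements into a
    // 2-D view, letting the int64_t reshape overload infer the row count.
    #include <cstdint>
    #include <vector>
    #include "paddle/common/ddim.h"  // assumed include path

    common::DDim FlattenToRows(const common::DDim& in_dims, int64_t cols) {
      std::vector<int64_t> shape = {-1, cols};  // the -1 row count is inferred
      // product(in_dims) / cols may exceed INT32_MAX and is still representable.
      return in_dims.reshape(shape);
    }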

paddle/common/ddim.h (2 additions, 0 deletions)

@@ -128,6 +128,8 @@ class TEST_API DDim {
 
   DDim reshape(std::vector<int>& shape) const;  // NOLINT
 
+  DDim reshape(std::vector<int64_t>& shape) const;  // NOLINT
+
   DDim transpose(const std::vector<int>& axis) const;
 
  private:

paddle/phi/kernels/funcs/matrix_inverse.h (12 additions, 9 deletions)

@@ -27,8 +27,11 @@ namespace funcs {
 
 template <typename Context, typename T>
 struct MapMatrixInverseFunctor {
-  void operator()(
-      const Context& dev_ctx, const T* a_ptr, T* a_inv_ptr, int offset, int n) {
+  void operator()(const Context& dev_ctx,
+                  const T* a_ptr,
+                  T* a_inv_ptr,
+                  int64_t offset,
+                  int64_t n) {
     using Matrix =
         Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
     using EigenMatrixMap = Eigen::Map<Matrix>;
@@ -52,8 +55,8 @@ struct MapMatrixInverseFunctor<Context, phi::dtype::complex<T>> {
   void operator()(const Context& dev_ctx,
                   const phi::dtype::complex<T>* a_ptr,
                   phi::dtype::complex<T>* a_inv_ptr,
-                  int offset,
-                  int n) {
+                  int64_t offset,
+                  int64_t n) {
     using Matrix = Eigen::Matrix<std::complex<T>,
                                  Eigen::Dynamic,
                                  Eigen::Dynamic,
@@ -62,7 +65,7 @@ struct MapMatrixInverseFunctor<Context, phi::dtype::complex<T>> {
     using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
     std::complex<T>* std_ptr = new std::complex<T>[n * n];
     std::complex<T>* std_inv_ptr = new std::complex<T>[n * n];
-    for (int i = 0; i < n * n; i++) {
+    for (int64_t i = 0; i < n * n; i++) {
       *(std_ptr + i) = static_cast<std::complex<T>>(*(a_ptr + offset + i));
     }
     ConstEigenMatrixMap mat(std_ptr, n, n);
@@ -75,7 +78,7 @@ struct MapMatrixInverseFunctor<Context, phi::dtype::complex<T>> {
                       static_cast<std::complex<T>>(0),
                       errors::InvalidArgument("Input is not invertible."));
     mat_inv.noalias() = lu.inverse();
-    for (int i = 0; i < n * n; i++) {
+    for (int64_t i = 0; i < n * n; i++) {
       *(a_inv_ptr + offset + i) =
           static_cast<phi::dtype::complex<T>>(*(std_inv_ptr + i));
     }
@@ -90,8 +93,8 @@ void ComputeInverseEigen(const Context& dev_ctx,
                          DenseTensor* a_inv) {
   const auto& mat_dims = a.dims();
   const int rank = mat_dims.size();
-  int n = mat_dims[rank - 1];
-  int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
+  int64_t n = mat_dims[rank - 1];
+  int64_t batch_size = rank > 2 ? a.numel() / (n * n) : 1;
 
   const T* a_ptr = a.data<T>();
   T* a_inv_ptr = dev_ctx.template Alloc<T>(a_inv);
@@ -100,7 +103,7 @@ void ComputeInverseEigen(const Context& dev_ctx,
   // it's not going to get the right result,
   // so we're going to convert it to std::complex and
   // then we're going to put it into eigen::matrix.
-  for (int i = 0; i < batch_size; ++i) {
+  for (int64_t i = 0; i < batch_size; ++i) {
     MapMatrixInverseFunctor<Context, T> functor;
     functor(dev_ctx, a_ptr, a_inv_ptr, i * n * n, n);
   }
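The reason these offsets and loop counters are widened is that a single n x n matrix can hold more elements than a 32-bit int can count. A small standalone sketch of the arithmetic involved (illustrative values only, not from the commit):

    // Illustrative only: a 50000 x 50000 matrix has 2.5e9 elements, which does
    // not fit in a 32-bit int, so offsets like i * n * n must be 64-bit.
    #include <cstdint>
    #include <iostream>
    #include <limits>

    int main() {
      const int64_t n = 50000;
      const int64_t elems = n * n;  // 2'500'000'000
      std::cout << "elements per matrix: " << elems << "\n";
      std::cout << "INT32_MAX:           "
                << std::numeric_limits<int32_t>::max() << "\n";
      std::cout << "fits in 32-bit int?  " << std::boolalpha
                << (elems <= std::numeric_limits<int32_t>::max()) << "\n";
      return 0;
    }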

paddle/phi/kernels/funcs/unsqueeze.h (1 addition, 1 deletion)

@@ -156,7 +156,7 @@ inline DDim GetUnsqueezeShape(const std::vector<int64_t> unsqz_dims,
 inline const DenseTensor Unsqueeze(const DenseTensor& x, int axis = 0) {
   // don't copy data, only change the dims
   DenseTensor out(x);
-  std::vector<int> out_shape = common::vectorize<int>(x.dims());
+  std::vector<int64_t> out_shape = common::vectorize<int64_t>(x.dims());
   if (axis >= 0) {
     auto index = (out_shape.begin() + axis);
     out_shape.insert(index, 1);

paddle/phi/kernels/gpu/slogdeterminant_kernel.cu (10 additions, 10 deletions)

@@ -88,13 +88,13 @@ __global__ void GetSlogDetFromLUComplex(const T* lu_data,
                                         int64_t n,
                                         int64_t batch_size,
                                         T* out_data) {
-  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  int64_t idx = threadIdx.x + static_cast<int64_t>(blockIdx.x) * blockDim.x;
   if (idx < batch_size) {
-    int offset_lu = idx * n * n;
-    int offset_ipiv = idx * n;
+    int64_t offset_lu = idx * n * n;
+    int64_t offset_ipiv = idx * n;
     T det_val = T(1.0, 0.0);
     T negative = T(-1.0, 0.0);
-    for (int i = 0; i < n; ++i) {
+    for (int64_t i = 0; i < n; ++i) {
       det_val *= lu_data[offset_lu + i * n + i];
       if (ipiv[offset_ipiv + i] != i + 1) {
         det_val *= negative;
@@ -135,12 +135,12 @@ struct SlogDeterminantFunctor<phi::dtype::complex<T>, Context> {
                           tmp_gpu_mat_data->ptr());
 
     std::vector<const phi::dtype::complex<T>*> cpu_ptrs(batch_count);
-    for (int i = 0; i < batch_count; ++i) {
+    for (int64_t i = 0; i < batch_count; ++i) {
      cpu_ptrs[i] = gpu_mat + i * rank * rank;
    }
 
    // num_ints is for pivot (rank * batch_count) and info (batch_count)
-    int num_ints = batch_count * (rank + 1);
+    int64_t num_ints = batch_count * (rank + 1);
    size_t total_bytes =
        batch_count * sizeof(phi::dtype::complex<T>*) + num_ints * sizeof(int);
    phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
@@ -218,7 +218,7 @@ void SlogDeterminantKernel(const Context& dev_ctx,
   // shape [*, M, M], check whether it contains 0 in '*'.
   if (input_dim.size() > 2) {
     bool size_0 = false;
-    std::vector<int> tmp_dim_vec(input_dim.begin(), input_dim.end() - 2);
+    std::vector<int64_t> tmp_dim_vec(input_dim.begin(), input_dim.end() - 2);
     for (size_t i = 0; i < tmp_dim_vec.size(); ++i) {
       if (tmp_dim_vec[i] == 0) {
         size_0 = true;
@@ -234,7 +234,7 @@ void SlogDeterminantKernel(const Context& dev_ctx,
     }
   }
 
-  auto batch_count = detail::GetBatchCount(x.dims());
+  int64_t batch_count = detail::GetBatchCount(x.dims());
   VLOG(2) << "input dim:" << x.dims();
   PADDLE_ENFORCE_GE(
       input_dim_size,
@@ -245,9 +245,9 @@ void SlogDeterminantKernel(const Context& dev_ctx,
       input_dim[input_dim_size - 1],
       input_dim[input_dim_size - 2],
       errors::InvalidArgument("the input matrix should be square matrix."));
-  auto rank = input_dim[input_dim_size - 1];  // square matrix length
+  int64_t rank = input_dim[input_dim_size - 1];  // square matrix length
   SlogDeterminantFunctor<T, Context>()(dev_ctx, x, rank, batch_count, out);
-  std::vector<int> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
+  std::vector<int64_t> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
   if (input_dim.size() == static_cast<size_t>(2)) {
     // when input is a two-dimension matrix, The det value is a number.
     output_dim_vec = {};
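A note on the widened thread index: blockIdx.x and blockDim.x are 32-bit unsigned values, so without the static_cast the product would be computed modulo 2^32 before being stored in the wider idx. A host-side sketch of the difference (illustrative values, not from the commit):

    // Illustrative only: mimic blockIdx.x * blockDim.x computed in 32 bits
    // versus the widened static_cast<int64_t>(blockIdx.x) * blockDim.x.
    #include <cstdint>
    #include <iostream>

    int main() {
      const uint32_t block_idx = 5'000'000;  // a large 1-D grid
      const uint32_t block_dim = 1024;
      int64_t wrapped = block_idx * block_dim;                        // wraps mod 2^32
      int64_t widened = static_cast<int64_t>(block_idx) * block_dim;  // exact
      std::cout << "32-bit product: " << wrapped << "\n";   // 825032704
      std::cout << "64-bit product: " << widened << "\n";   // 5120000000
      return 0;
    }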

paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h (2 additions, 2 deletions)

@@ -104,8 +104,8 @@ void SlogDeterminantGradKernel(const Context& dev_ctx,
 
   // remove useless first dimension
   int det_grad_size = det_grad.dims().size();
-  std::vector<int> det_grad_vec;
-  for (int i = 1; i < det_grad_size; ++i) {
+  std::vector<int64_t> det_grad_vec;
+  for (int64_t i = 1; i < det_grad_size; ++i) {
     det_grad_vec.emplace_back(det_grad.dims()[i]);
   }
   det_grad.Resize(det_grad.dims().reshape(det_grad_vec));
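Because det_grad_vec is now a std::vector<int64_t>, this reshape call resolves to the DDim::reshape(std::vector<int64_t>&) overload added in paddle/common/ddim.cc above. A standalone sketch of the same dimension-dropping step (hypothetical helper, include path assumed):

    // Hypothetical helper: strip the leading dimension of a DDim using the
    // int64_t reshape overload, mirroring the grad kernel's Resize call.
    #include <cstdint>
    #include <vector>
    #include "paddle/common/ddim.h"  // assumed include path

    common::DDim DropLeadingDim(const common::DDim& dims) {
      std::vector<int64_t> rest;
      for (int i = 1; i < dims.size(); ++i) {
        rest.emplace_back(dims[i]);
      }
      return dims.reshape(rest);
    }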
