diff --git a/paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h b/paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h index 934ea15d80a444..869494da59cbe3 100644 --- a/paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h @@ -82,8 +82,54 @@ void SlogDeterminantGradKernel(const Context& dev_ctx, inverse_A.Resize(x.dims()); dev_ctx.template Alloc(&inverse_A); - phi::funcs::MatrixInverseFunctor mat_inv; - mat_inv(dev_ctx, x, &inverse_A); + const auto& mat_dims = x.dims(); + const int rank = mat_dims.size(); + int n = mat_dims[rank - 1]; + int64_t total_batch_size = rank > 2 ? x.numel() / (n * n) : 1; + + // Divide the batch into chunks because of cublasMatInv limitation + if (total_batch_size <= 65536) { + phi::funcs::MatrixInverseFunctor mat_inv; + mat_inv(dev_ctx, x, &inverse_A); + } else { + constexpr int64_t max_batch_size = 65536; + int64_t processed = 0; + + VLOG(3) << "Large batch size detected (" << total_batch_size + << "), processing in chunks of " << max_batch_size; + + while (processed < total_batch_size) { + int64_t current_batch = + std::min(max_batch_size, total_batch_size - processed); + + // Extract current batch data + DenseTensor x_batch; + x_batch.ShareDataWith(x); + x_batch.Resize({total_batch_size, n, n}); + x_batch = x_batch.Slice(processed, processed + current_batch); + x_batch.Resize({current_batch, n, n}); + + DenseTensor inverse_batch; + inverse_batch.Resize({current_batch, n, n}); + dev_ctx.template Alloc(&inverse_batch); + + // Compute the inverse matrix for the current batch + phi::funcs::MatrixInverseFunctor mat_inv; + mat_inv(dev_ctx, x_batch, &inverse_batch); + + // Copy the result to the output tensor + DenseTensor output_slice; + output_slice.ShareDataWith(inverse_A); + output_slice.Resize({total_batch_size, n, n}); + output_slice = output_slice.Slice(processed, processed + current_batch); + output_slice.Resize({current_batch, n, n}); + + phi::Copy( + dev_ctx, inverse_batch, dev_ctx.GetPlace(), false, &output_slice); + + processed += current_batch; + } + } VLOG(3) << "inverse(A) dims: " << inverse_A.dims();