|
| 1 | +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include <thrust/execution_policy.h> |
| 16 | +#include <thrust/random.h> |
| 17 | +#include <thrust/scan.h> |
| 18 | +#include <thrust/transform.h> |
| 19 | + |
| 20 | +#include "paddle/fluid/framework/eigen.h" |
| 21 | +#include "paddle/fluid/framework/op_registry.h" |
| 22 | +#include "paddle/fluid/framework/operator.h" |
| 23 | +#include "paddle/fluid/operators/multinomial_op.h" |
| 24 | +#include "paddle/fluid/platform/transform.h" |
| 25 | + |
| 26 | +namespace paddle { |
| 27 | +namespace operators { |
| 28 | + |
| 29 | +/* |
| 30 | +template <typename T, int MajorType = Eigen::RowMajor, |
| 31 | + typename IndexType = Eigen::DenseIndex> |
| 32 | +using EigenVector = framework::EigenVector<T, MajorType, IndexType>; |
| 33 | +template <typename T, int MajorType = Eigen::RowMajor, |
| 34 | + typename IndexType = Eigen::DenseIndex> |
| 35 | +using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; |
| 36 | +*/ |
| 37 | + |
| 38 | +/* |
| 39 | +template <class T> |
| 40 | +__global__ void SumArrayCUDAKernel(T **in, T *out, size_t in_size) { |
| 41 | + int id = blockIdx.x * blockDim.x + threadIdx.x; |
| 42 | + // T total(read_dst ? out[id] : static_cast<T>(0)); |
| 43 | + T total(static_cast<T>(0)) |
| 44 | + for (int i = 0; i < in_size; ++i) { |
| 45 | + const T *tmp = in[i]; |
| 46 | + if (tmp) { |
| 47 | + total += tmp[id]; |
| 48 | + } |
| 49 | + } |
| 50 | + out[id] = total; |
| 51 | + id += blockDim.x * gridDim.x; |
| 52 | +}*/ |
| 53 | + |
| 54 | +/* |
| 55 | +template <typename T> |
| 56 | +__global__ void NormalizeProbability(T* probs, int64_t rows, int64_t cols) { |
| 57 | + extern __shared__ std::vector<T> sum_rows(rows); |
| 58 | + T val; |
| 59 | + for (int64_t i = blockId.x; i < rows; i += gridDim.x) { |
| 60 | + T sum = static_cast<T>(0); |
| 61 | + for (int64_t j = threadIdx.x; j < cols; j += blockDim.x) { |
| 62 | + val = probs[i * cols + j]; |
| 63 | + sum += val; |
| 64 | + } |
| 65 | +
|
| 66 | + } |
| 67 | +}*/ |
| 68 | + |
| 69 | +template <typename T> |
| 70 | +__global__ void NormalizeProbability(T* norm_probs, const T* in_data, |
| 71 | + T* sum_rows) { |
| 72 | + // int id = blockIdx.x * blockDim.x + threadIdx.x; |
| 73 | + int id = threadIdx.x; |
| 74 | + norm_probs[id] = in_data[id] / sum_rows[0]; |
| 75 | +} |
| 76 | + |
| 77 | +template <typename T> |
| 78 | +struct RandomGeneratorCudaFunctor { |
| 79 | + unsigned int seed_; |
| 80 | + __host__ __device__ RandomGeneratorCudaFunctor(int seed) : seed_(seed) {} |
| 81 | + |
| 82 | + __host__ __device__ T operator()(const unsigned int n) const { |
| 83 | + thrust::minstd_rand rng; |
| 84 | + rng.seed(seed_); |
| 85 | + thrust::uniform_real_distribution<T> dist(0.0, 1.0); |
| 86 | + rng.discard(n); |
| 87 | + return dist(rng); |
| 88 | + } |
| 89 | +}; |
| 90 | + |
| 91 | +/* |
| 92 | +template <typename T> |
| 93 | +class MultinomialCudaFunctor(T* out_data, const T* in_data, |
| 94 | + const int64_t num_samples, const bool replacement, |
| 95 | + const int64_t num_categories, |
| 96 | + const int64_t num_distributions) { |
| 97 | +
|
| 98 | +}*/ |
| 99 | + |
| 100 | +template <typename T> |
| 101 | +__device__ int binarySearchForMultinomial(T* cumdist, T* dist, int size, |
| 102 | + T val) { |
| 103 | + int start = 0; |
| 104 | + int end = size; |
| 105 | + // cumdist[size - 1] = 0 => all zero prob dist |
| 106 | + // CUDA_KERNEL_ASSERT(cumdist[size - 1] > static_cast<T>(0)); |
| 107 | + |
| 108 | + while (end - start > 0) { |
| 109 | + int mid = start + (end - start) / 2; |
| 110 | + |
| 111 | + T midVal = cumdist[mid]; |
| 112 | + if (midVal < val) { |
| 113 | + start = mid + 1; |
| 114 | + } else { |
| 115 | + end = mid; |
| 116 | + } |
| 117 | + } |
| 118 | + |
| 119 | + if (start == size) { |
| 120 | + // No probability mass or precision problems; just return the |
| 121 | + // first non-zero element by setting start to size-1 here, |
| 122 | + // the code below will move it to the last non-zero probability |
| 123 | + // this actually can happen when the random number is 1 |
| 124 | + // (github pytorch issue #4858). |
| 125 | + start = size - 1; |
| 126 | + } |
| 127 | + |
| 128 | + while (start >= 1 && dist[start] == 0) start--; |
| 129 | + |
| 130 | + return start; |
| 131 | +} |
| 132 | + |
| 133 | +template <typename T> |
| 134 | +__global__ void sampleMultinomialWithReplacement( |
| 135 | + T* rng, const int64_t totalSamples, T* dest, const int64_t distributions, |
| 136 | + const int64_t categories, T* normDistPrefixSum, T* normDist) { |
| 137 | + // At the moment, each warp computes one sample value in the binary |
| 138 | + // search due to divergence. It seems possible to compute multiple |
| 139 | + // values and limit divergence though later on. |
| 140 | + |
| 141 | + // global index formula for 2D grid of 1D blocks |
| 142 | + // int idx = blockIdx.y * gridDim.x * blockDim.x + blockIdx.x * blockDim.x + |
| 143 | + // threadIdx.x; |
| 144 | + int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 145 | + |
| 146 | + for (int sample = blockIdx.x * blockDim.x + threadIdx.x; |
| 147 | + sample < totalSamples; sample += blockDim.x * gridDim.x) { |
| 148 | + // we are losing 3 out of 4 generated numbers but it's ok |
| 149 | + // this kernel is not very efficient anyway |
| 150 | + |
| 151 | + // T uniform_random = dist(rng); |
| 152 | + T uniform_random = rng[sample]; |
| 153 | + |
| 154 | + // Find the bucket that a uniform sample lies in |
| 155 | + int choice = binarySearchForMultinomial<T>(normDistPrefixSum, normDist, |
| 156 | + categories, uniform_random); |
| 157 | + |
| 158 | + dest[sample] = choice; |
| 159 | + } |
| 160 | +} |
| 161 | + |
| 162 | +template <typename T> |
| 163 | +class MultinomialOpKernel<platform::CUDADeviceContext, T> |
| 164 | + : public framework::OpKernel<T> { |
| 165 | + public: |
| 166 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 167 | + const auto x = ctx.Input<framework::Tensor>("X"); |
| 168 | + auto out = ctx.Output<framework::Tensor>("Out"); |
| 169 | + |
| 170 | + const int64_t num_samples = ctx.Attr<int>("num_samples"); |
| 171 | + const bool replacement = ctx.Attr<bool>("replacement"); |
| 172 | + |
| 173 | + auto* in_data = x->data<T>(); |
| 174 | + auto* out_data = out->mutable_data<T>(ctx.GetPlace()); |
| 175 | + |
| 176 | + auto in_dims = x->dims(); |
| 177 | + int64_t in_rank = in_dims.size(); |
| 178 | + const int64_t num_categories = in_dims[in_rank - 1]; |
| 179 | + const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1; |
| 180 | + |
| 181 | + // std::vector<T> sum_rows(num_distributions); |
| 182 | + // SumArrayCUDAKernel<T>(in_data, sum_rows,) |
| 183 | + |
| 184 | + VLOG(3) << "Print num_distributions " << num_distributions << "\n"; |
| 185 | + |
| 186 | + VLOG(3) << "Print num_categories " << num_categories << "\n"; |
| 187 | + |
| 188 | + VLOG(3) << "Print in_rank " << in_rank << "\n"; |
| 189 | + |
| 190 | + framework::Tensor sum_rows_t; |
| 191 | + auto* sum_rows_data = sum_rows_t.mutable_data<T>({1}, ctx.GetPlace()); |
| 192 | + // auto* sum_rows_data = |
| 193 | + // sum_rows_t->mutable_data<T>(framework::make_ddim({1}), ctx.GetPlace()); |
| 194 | + |
| 195 | + auto& place = *ctx.template device_context<platform::CUDADeviceContext>() |
| 196 | + .eigen_device(); |
| 197 | + |
| 198 | + auto eigen_input = framework::EigenVector<T>::Flatten(*x); |
| 199 | + // auto eigen_sum_rows = framework::EigenVector<T>::From(sum_rows_t); |
| 200 | + auto eigen_sum_rows = framework::EigenScalar<T>::From(sum_rows_t); |
| 201 | + eigen_sum_rows.device(place) = |
| 202 | + eigen_input.sum(Eigen::DSizes<int, 1>(0)) |
| 203 | + .eval() |
| 204 | + .reshape(Eigen::DSizes<int, 1>(sum_rows_t.dims()[0])); |
| 205 | + // eigen_sum_rows.device(place) = |
| 206 | + // eigen_input.sum().eval().reshape(Eigen::DSizes<int, 1>(1)); |
| 207 | + |
| 208 | + dim3 grid(num_distributions); |
| 209 | + dim3 block(num_categories); |
| 210 | + |
| 211 | + // std::vector<T> in_data_norm(num_categories); |
| 212 | + framework::Tensor norm_probs_t; |
| 213 | + auto* norm_probs_data = |
| 214 | + norm_probs_t.mutable_data<T>({num_categories}, ctx.GetPlace()); |
| 215 | + NormalizeProbability< |
| 216 | + T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>( |
| 217 | + norm_probs_data, in_data, sum_rows_data); |
| 218 | + |
| 219 | + // num_distributions can only be 1. |
| 220 | + // std::vector<T> cumulative_probs(num_categories); |
| 221 | + framework::Tensor cumulative_probs_t; |
| 222 | + auto* cumulative_probs = |
| 223 | + cumulative_probs_t.mutable_data<T>({num_categories}, ctx.GetPlace()); |
| 224 | + // T cumulative_probs[num_categories]; |
| 225 | + int64_t size = num_categories; |
| 226 | + thrust::inclusive_scan(thrust::device, norm_probs_data, |
| 227 | + norm_probs_data + num_categories, cumulative_probs); |
| 228 | + |
| 229 | + if (replacement) { |
| 230 | + dim3 block(128); |
| 231 | + // int grid_y = 1; |
| 232 | + dim3 grid((num_samples - 1) / block.x + 1); |
| 233 | + |
| 234 | + /* |
| 235 | + // std::vector<T> rng(num_samples); |
| 236 | + T rng[num_samples]; |
| 237 | + std::uniform_real_distribution<T> dist(0, 1); |
| 238 | + auto gen_ptr = framework::DefaultCPUGenerator(); |
| 239 | + auto engine = gen_ptr->GetCPUEngine(); |
| 240 | +
|
| 241 | + for (int s = 0; s < num_samples; s++) { |
| 242 | + rng[s] = dist(*engine); |
| 243 | + } |
| 244 | + */ |
| 245 | + |
| 246 | + std::random_device rd; |
| 247 | + auto seed = rd(); |
| 248 | + |
| 249 | + framework::Tensor rng_data_t; |
| 250 | + auto* rng_data = |
| 251 | + rng_data_t.mutable_data<T>({num_samples}, ctx.GetPlace()); |
| 252 | + |
| 253 | + thrust::counting_iterator<unsigned int> index_sequence_begin(0); |
| 254 | + platform::Transform<platform::CUDADeviceContext> trans; |
| 255 | + auto* context = static_cast<const platform::CUDADeviceContext*>( |
| 256 | + &ctx.device_context()); |
| 257 | + trans(*context, index_sequence_begin, index_sequence_begin + num_samples, |
| 258 | + rng_data, RandomGeneratorCudaFunctor<T>(seed)); |
| 259 | + |
| 260 | + VLOG(3) << "Print enter\n"; |
| 261 | + // VLOG(3) << "Print size in_data " << |
| 262 | + // sizeof(in_data)/sizeof(in_data[num_categories-1]) << "\n"; |
| 263 | + // VLOG(3) << "Print norm_probs_data0 " << |
| 264 | + // sizeof(norm_probs_data[num_categories-1]) << "\n"; |
| 265 | + |
| 266 | + sampleMultinomialWithReplacement< |
| 267 | + T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>( |
| 268 | + rng_data, num_samples, out_data, num_distributions, num_categories, |
| 269 | + cumulative_probs, norm_probs_data); |
| 270 | + } |
| 271 | + |
| 272 | + // MultinomialCudaFunctor<T>(out_data, in_data, num_samples, replacement, |
| 273 | + // num_categories, num_distributions); |
| 274 | + } |
| 275 | +}; |
| 276 | + |
| 277 | +} // namespace operators |
| 278 | +} // namespace paddle |
| 279 | + |
| 280 | +namespace ops = paddle::operators; |
| 281 | +namespace plat = paddle::platform; |
| 282 | + |
| 283 | +REGISTER_OP_CUDA_KERNEL( |
| 284 | + multinomial, ops::MultinomialOpKernel<plat::CUDADeviceContext, float>, |
| 285 | + ops::MultinomialOpKernel<plat::CUDADeviceContext, double>); |
0 commit comments