/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/execution_policy.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/scan.h>
#include <thrust/transform.h>

#include <random>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/multinomial_op.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void NormalizeProbability(T* norm_probs, const T* in_data,
                                     T* sum_rows, int64_t num_categories) {
  // One grid row (blockIdx.y) handles one distribution; threads along the
  // x-dimension cover that row's categories. Guard the tail threads so the
  // kernel stays in bounds when num_categories is not a multiple of the
  // block size.
  int col = threadIdx.x + blockIdx.x * blockDim.x;
  if (col < num_categories) {
    int64_t id = blockIdx.y * num_categories + col;
    norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
  }
}
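
// A sketch of the data flow with illustrative values (not taken from this
// file): for a row in_data = {1., 1., 2.} with sum_rows = {4.}, the kernel
// writes norm_probs = {0.25, 0.25, 0.5}, turning each row into a proper
// probability distribution.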

template <typename T>
__global__ void GetCumulativeProbs(T* norm_probs_data,
                                   int64_t num_distributions,
                                   int64_t num_categories,
                                   T* cumulative_probs) {
  // Each block computes the inclusive prefix sum of one distribution's
  // normalized probabilities.
  for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) {
    thrust::inclusive_scan(thrust::device,
                           norm_probs_data + id * num_categories,
                           norm_probs_data + (id + 1) * num_categories,
                           cumulative_probs + id * num_categories);
  }
}
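
// Continuing the example above: an inclusive scan of {0.25, 0.25, 0.5}
// yields {0.25, 0.5, 1.0}. Note that thrust::inclusive_scan invoked from
// device code with the thrust::device policy effectively executes
// sequentially in the calling thread, which is why this kernel is launched
// below with a single thread per block.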

template <typename T>
struct RandomGeneratorCudaFunctor {
  unsigned int seed_;
  __host__ __device__ RandomGeneratorCudaFunctor(unsigned int seed)
      : seed_(seed) {}

  __host__ __device__ T operator()(const unsigned int n) const {
    // Re-seed a lightweight engine on every call and skip ahead to the n-th
    // variate, so each index n yields an independent uniform sample.
    thrust::minstd_rand rng;
    rng.seed(seed_);
    thrust::uniform_real_distribution<T> dist(0.0, 1.0);
    rng.discard(n);
    return dist(rng);
  }
};
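
// This functor is applied below over a thrust::counting_iterator sequence
// 0, 1, 2, ..., so the n-th output element is the n-th draw from a single
// logical random stream. Seeding once and calling discard(n) is a common
// stateless pattern for drawing random numbers in parallel without shared
// generator state.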

template <typename T>
__device__ int binarySearchFunctor(T* cumulative_probs, T* norm_probs_data,
                                   int num_categories, T rng_number) {
  // Find the smallest index whose cumulative probability is >= rng_number.
  int left = 0;
  int right = num_categories;

  while (right - left > 0) {
    int mid = left + (right - left) / 2;

    T temp_prob = cumulative_probs[mid];
    if (temp_prob < rng_number) {
      left = mid + 1;
    } else {
      right = mid;
    }
  }

  if (left == num_categories) {
    left = num_categories - 1;
  }

  // Step past any zero-probability categories so a category with zero
  // probability is never returned.
  while (left >= 1 && norm_probs_data[left] == 0) left--;

  return left;
}
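
// A worked example with illustrative values (not from this file): given
// cumulative_probs = {0.25, 0.5, 1.0} and rng_number = 0.6, the search
// narrows [0, 3) -> [2, 3) and returns category 2, since
// 0.5 < 0.6 <= 1.0.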

template <typename T>
__global__ void sampleMultinomialWithReplacement(
    T* rng_data, const int64_t num_samples, int64_t* out_data,
    const int64_t num_distributions, const int64_t num_categories,
    T* cumulative_probs, T* norm_probs_data) {
  // Use binary search to get the selected category id, i.e. the id such
  // that cumulative_probs[id - 1] < rng_data <= cumulative_probs[id].

  // for every distribution
  for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
    // for every sample
    for (int sample = blockIdx.x * blockDim.x + threadIdx.x;
         sample < num_samples; sample += blockDim.x * gridDim.x) {
      T rng_number = rng_data[sample + dist * num_samples];

      // Find the bucket that the uniform random number lies in.
      int selected_category = binarySearchFunctor<T>(
          cumulative_probs + dist * num_categories,
          norm_probs_data + dist * num_categories, num_categories, rng_number);

      out_data[sample + dist * num_samples] = selected_category;
    }
  }
}
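
// Both loops above are grid-stride loops: blockIdx.y strides over
// distributions and (blockIdx.x, threadIdx.x) stride over samples, so the
// kernel remains correct for any launch configuration, even when there are
// more distributions or samples than blocks and threads.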

template <typename T>
class MultinomialOpKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto x = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");

    const int64_t num_samples = ctx.Attr<int>("num_samples");
    const bool replacement = ctx.Attr<bool>("replacement");

    auto* in_data = x->data<T>();
    int64_t* out_data = out->mutable_data<int64_t>(ctx.GetPlace());

    auto in_dims = x->dims();
    int64_t in_rank = in_dims.size();
    const int64_t num_categories = in_dims[in_rank - 1];
    const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;

    // If replacement is False, every category can be sampled only once, so
    // the distribution changes after each draw and the sampling can't be
    // parallelized. Fall back to the CPU implementation
    // ``MultinomialFunctor`` in that case.
    if (!replacement) {
      int64_t in_data_numel = x->numel();
      int64_t out_data_numel = out->numel();

      T* cpu_in_data = new T[in_data_numel];
      int64_t* cpu_out_data = new int64_t[out_data_numel];

      cudaMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T),
                 cudaMemcpyDeviceToHost);

      MultinomialFunctor<T>(cpu_out_data, cpu_in_data, num_samples,
                            replacement, num_categories, num_distributions);
      cudaMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(int64_t),
                 cudaMemcpyHostToDevice);

      delete[] cpu_in_data;
      delete[] cpu_out_data;
      return;
    }
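
    // With replacement, each draw is independent, so sampling can run fully
    // in parallel on the GPU. The steps below: (1) normalize each row into
    // a probability distribution, (2) compute per-row cumulative sums,
    // (3) draw one uniform number per sample, and (4) binary-search the
    // cumulative sums to map each uniform number to a category.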

    // The rows of the input may not sum to 1. To get probabilities in the
    // range [0, 1], compute the sum of each row of the input, then use it
    // to normalize that row.
    // sum_rows_data: sum of each row
    framework::Tensor sum_rows_tensor;
    auto* sum_rows_data =
        sum_rows_tensor.mutable_data<T>({num_distributions}, ctx.GetPlace());

    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
                       .eigen_device();

    if (num_distributions == 1) {
      auto eigen_input = framework::EigenVector<T>::Flatten(*x);
      auto eigen_sum_rows = framework::EigenVector<T>::Flatten(sum_rows_tensor);
      eigen_sum_rows.device(place) =
          eigen_input.sum(Eigen::DSizes<int, 1>(1))
              .eval()
              .reshape(Eigen::DSizes<int, 1>(sum_rows_tensor.dims()[0]));
    } else {
      auto eigen_input = framework::EigenMatrix<T>::From(*x);
      auto eigen_sum_rows = framework::EigenVector<T>::Flatten(sum_rows_tensor);
      eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes<int, 1>(1));
    }
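
    // Shape sketch: for an input x of shape [2, 3], summing along dimension
    // 1 produces sum_rows_tensor of shape [2], one sum per distribution
    // (row).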

    // Normalize each row of the input to get probabilities in the range
    // [0, 1].
    // norm_probs_data: probability of each category in each distribution
    framework::Tensor norm_probs_tensor;
    auto* norm_probs_data = norm_probs_tensor.mutable_data<T>(
        {num_distributions, num_categories}, ctx.GetPlace());

    // The number of threads in a block is min(num_categories, 512).
    dim3 block_norm(num_categories < 512 ? num_categories : 512);
    dim3 grid_norm((num_categories - 1) / block_norm.x + 1, num_distributions);
    NormalizeProbability<
        T><<<grid_norm, block_norm, 0, ctx.cuda_device_context().stream()>>>(
        norm_probs_data, in_data, sum_rows_data, num_categories);

    // Compute the cumulative probability of each distribution. This is the
    // same computation as the ``cumsum`` op.
    framework::Tensor cumulative_probs_tensor;
    auto* cumulative_probs = cumulative_probs_tensor.mutable_data<T>(
        {num_distributions, num_categories}, ctx.GetPlace());
    dim3 block_cumsum(1);
    dim3 grid_cumsum(num_distributions);
    GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0,
                            ctx.cuda_device_context().stream()>>>(
        norm_probs_data, num_distributions, num_categories, cumulative_probs);

    // Generate a uniform random number for each sample.
    std::random_device rd;
    auto seed = rd();

    framework::Tensor rng_data_tensor;
    auto* rng_data = rng_data_tensor.mutable_data<T>(
        {num_distributions, num_samples}, ctx.GetPlace());

    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
    platform::Transform<platform::CUDADeviceContext> trans;
    auto* context =
        static_cast<const platform::CUDADeviceContext*>(&ctx.device_context());
    trans(*context, index_sequence_begin,
          index_sequence_begin + num_distributions * num_samples, rng_data,
          RandomGeneratorCudaFunctor<T>(seed));
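
    // Transform fills rng_data elementwise: element n of the output is
    // RandomGeneratorCudaFunctor(seed)(n), i.e. the n-th uniform variate in
    // [0, 1) from the stream seeded above.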

    // Sample the multinomial distributions.
    dim3 block_sample(128);
    dim3 grid_sample((num_samples - 1) / block_sample.x + 1, num_distributions);
    sampleMultinomialWithReplacement<T><<<grid_sample, block_sample, 0,
                                          ctx.cuda_device_context().stream()>>>(
        rng_data, num_samples, out_data, num_distributions, num_categories,
        cumulative_probs, norm_probs_data);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
    multinomial, ops::MultinomialOpKernel<plat::CUDADeviceContext, float>,
    ops::MultinomialOpKernel<plat::CUDADeviceContext, double>);