Skip to content

Commit 7cd2c13

Browse files
authored
add multinomial op (#27219)
* add multinomial cpu kernel * fix C++ notype error * fix windows ci array len error * let array len be const * change array to vector * add cuda kernrl with num_distribution is 1, and not support replacement=False * add multinomial python api * support num_distribution different multinomial distributions * add multinomial python api unittest * change output dtype to int64 * fix coverage prob * optimize format * fix dtype of output error, should be int64_t
1 parent d2369dd commit 7cd2c13

File tree

7 files changed

+722
-0
lines changed

7 files changed

+722
-0
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
#include "paddle/fluid/operators/multinomial_op.h"
15+
16+
#include <algorithm>
17+
#include <string>
18+
#include <vector>
19+
20+
#include "paddle/fluid/framework/generator.h"
21+
#include "paddle/fluid/framework/op_registry.h"
22+
#include "paddle/fluid/framework/operator.h"
23+
#include "paddle/fluid/operators/common_infer_shape_functions.h"
24+
25+
namespace paddle {
26+
namespace operators {
27+
28+
class MultinomialOpMaker : public framework::OpProtoAndCheckerMaker {
29+
public:
30+
void Make() override {
31+
AddInput("X", "A tensor contains probabilities of categories");
32+
AddOutput("Out", "The output tensor of multinomial op");
33+
AddAttr<int>("num_samples", "number of the generated samples")
34+
.SetDefault(1);
35+
AddAttr<bool>("replacement", "can a category be sampled more than once")
36+
.SetDefault(false);
37+
AddComment(R"DOC(
38+
This OP returns a Tensor filled with the sampled categoris according to Multinomial probabilities.
39+
40+
Out ~ Multinomial(X)
41+
42+
)DOC");
43+
}
44+
};
45+
46+
class MultinomialOp : public framework::OperatorWithKernel {
47+
public:
48+
using framework::OperatorWithKernel::OperatorWithKernel;
49+
50+
void InferShape(framework::InferShapeContext *ctx) const override {
51+
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Multinomial");
52+
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multinomial");
53+
54+
auto x_dim = ctx->GetInputDim("X");
55+
int64_t x_rank = x_dim.size();
56+
std::vector<int64_t> out_dims(x_rank);
57+
for (int64_t i = 0; i < x_rank - 1; i++) {
58+
out_dims[i] = x_dim[i];
59+
}
60+
61+
int64_t num_samples = ctx->Attrs().Get<int>("num_samples");
62+
out_dims[x_rank - 1] = num_samples;
63+
64+
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
65+
}
66+
};
67+
68+
template <typename T>
69+
class MultinomialOpKernel<platform::CPUDeviceContext, T>
70+
: public framework::OpKernel<T> {
71+
public:
72+
void Compute(const framework::ExecutionContext &ctx) const override {
73+
const auto x = ctx.Input<framework::Tensor>("X");
74+
auto out = ctx.Output<framework::Tensor>("Out");
75+
const int64_t num_samples = ctx.Attr<int>("num_samples");
76+
const bool replacement = ctx.Attr<bool>("replacement");
77+
78+
auto *in_data = x->data<T>();
79+
int64_t *out_data = out->mutable_data<int64_t>(ctx.GetPlace());
80+
81+
auto in_dims = x->dims();
82+
int64_t in_rank = in_dims.size();
83+
const int64_t num_categories = in_dims[in_rank - 1];
84+
const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;
85+
86+
MultinomialFunctor<T>(out_data, in_data, num_samples, replacement,
87+
num_categories, num_distributions);
88+
}
89+
};
90+
91+
} // namespace operators
92+
} // namespace paddle
93+
94+
namespace ops = paddle::operators;
95+
namespace plat = paddle::platform;
96+
REGISTER_OPERATOR(
97+
multinomial, ops::MultinomialOp, ops::MultinomialOpMaker,
98+
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
99+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
100+
101+
REGISTER_OP_CPU_KERNEL(
102+
multinomial, ops::MultinomialOpKernel<plat::CPUDeviceContext, float>,
103+
ops::MultinomialOpKernel<plat::CPUDeviceContext, double>);
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <thrust/execution_policy.h>
16+
#include <thrust/random.h>
17+
#include <thrust/scan.h>
18+
#include <thrust/transform.h>
19+
20+
#include "paddle/fluid/framework/eigen.h"
21+
#include "paddle/fluid/framework/op_registry.h"
22+
#include "paddle/fluid/framework/operator.h"
23+
#include "paddle/fluid/operators/multinomial_op.h"
24+
#include "paddle/fluid/platform/transform.h"
25+
26+
namespace paddle {
27+
namespace operators {
28+
29+
template <typename T>
30+
__global__ void NormalizeProbability(T* norm_probs, const T* in_data,
31+
T* sum_rows) {
32+
int id = threadIdx.x + blockIdx.x * blockDim.x +
33+
blockIdx.y * gridDim.x * blockDim.x;
34+
norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
35+
}
36+
37+
template <typename T>
38+
__global__ void GetCumulativeProbs(T* norm_probs_data,
39+
int64_t num_distributions,
40+
int64_t num_categories,
41+
T* cumulative_probs) {
42+
for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) {
43+
thrust::inclusive_scan(thrust::device,
44+
norm_probs_data + id * num_categories,
45+
norm_probs_data + (id + 1) * num_categories,
46+
cumulative_probs + id * num_categories);
47+
}
48+
}
49+
50+
template <typename T>
51+
struct RandomGeneratorCudaFunctor {
52+
unsigned int seed_;
53+
__host__ __device__ RandomGeneratorCudaFunctor(int seed) : seed_(seed) {}
54+
55+
__host__ __device__ T operator()(const unsigned int n) const {
56+
thrust::minstd_rand rng;
57+
rng.seed(seed_);
58+
thrust::uniform_real_distribution<T> dist(0.0, 1.0);
59+
rng.discard(n);
60+
return dist(rng);
61+
}
62+
};
63+
64+
template <typename T>
65+
__device__ int binarySearchFunctor(T* cumulative_probs, T* norm_probs_data,
66+
int num_categories, T rng_number) {
67+
int left = 0;
68+
int right = num_categories;
69+
70+
while (right - left > 0) {
71+
int mid = left + (right - left) / 2;
72+
73+
T temp_prob = cumulative_probs[mid];
74+
if (temp_prob < rng_number) {
75+
left = mid + 1;
76+
} else {
77+
right = mid;
78+
}
79+
}
80+
81+
if (left == num_categories) {
82+
left = num_categories - 1;
83+
}
84+
85+
while (left >= 1 && norm_probs_data[left] == 0) left--;
86+
87+
return left;
88+
}
89+
90+
template <typename T>
91+
__global__ void sampleMultinomialWithReplacement(
92+
T* rng_data, const int64_t num_samples, int64_t* out_data,
93+
const int64_t num_distributions, const int64_t num_categories,
94+
T* cumulative_probs, T* norm_probs_data) {
95+
// use binary search to get the selected category sample id.
96+
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].
97+
98+
int idx = threadIdx.x + blockIdx.x * blockDim.x +
99+
blockIdx.y * gridDim.x * blockDim.x;
100+
101+
// for every distribution
102+
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
103+
// for every sample
104+
for (int sample = blockIdx.x * blockDim.x + threadIdx.x;
105+
sample < num_samples; sample += blockDim.x * gridDim.x) {
106+
T rng_number = rng_data[sample + dist * num_samples];
107+
108+
// Find the bucket that a uniform random number lies in
109+
int selected_category = binarySearchFunctor<T>(
110+
cumulative_probs + dist * num_categories,
111+
norm_probs_data + dist * num_categories, num_categories, rng_number);
112+
113+
out_data[sample + dist * num_samples] = selected_category;
114+
}
115+
}
116+
}
117+
118+
template <typename T>
119+
class MultinomialOpKernel<platform::CUDADeviceContext, T>
120+
: public framework::OpKernel<T> {
121+
public:
122+
void Compute(const framework::ExecutionContext& ctx) const override {
123+
const auto x = ctx.Input<framework::Tensor>("X");
124+
auto out = ctx.Output<framework::Tensor>("Out");
125+
126+
const int64_t num_samples = ctx.Attr<int>("num_samples");
127+
const bool replacement = ctx.Attr<bool>("replacement");
128+
129+
auto* in_data = x->data<T>();
130+
int64_t* out_data = out->mutable_data<int64_t>(ctx.GetPlace());
131+
132+
auto in_dims = x->dims();
133+
int64_t in_rank = in_dims.size();
134+
const int64_t num_categories = in_dims[in_rank - 1];
135+
const int64_t num_distributions = in_rank > 1 ? in_dims[in_rank - 2] : 1;
136+
137+
// If replacement is False, it's not a replaceable sample. Every category
138+
// can
139+
// be used only once. So after every sample, probability of the distribution
140+
// will change. The implementation can't be parallelizable. Thus, call CPU
141+
// implementation ``MultinomialFunctor`` to sample the distribution.
142+
if (!replacement) {
143+
int64_t in_data_numel = x->numel();
144+
int64_t out_data_numel = out->numel();
145+
146+
T* cpu_in_data = new T[in_data_numel];
147+
int64_t* cpu_out_data = new int64_t[out_data_numel];
148+
149+
cudaMemcpy(cpu_in_data, in_data, in_data_numel * sizeof(T),
150+
cudaMemcpyDeviceToHost);
151+
152+
MultinomialFunctor<T>(cpu_out_data, cpu_in_data, num_samples, replacement,
153+
num_categories, num_distributions);
154+
cudaMemcpy(out_data, cpu_out_data, out_data_numel * sizeof(int64_t),
155+
cudaMemcpyHostToDevice);
156+
157+
delete[] cpu_in_data;
158+
delete[] cpu_out_data;
159+
return;
160+
}
161+
162+
// Sum of input may not be 1. To get probability in range [0, 1], calculate
163+
// sum of each row of input, and then use the sum to normalize the input.
164+
// sum_row_data: sum of each row
165+
framework::Tensor sum_rows_tensor;
166+
auto* sum_rows_data =
167+
sum_rows_tensor.mutable_data<T>({num_distributions}, ctx.GetPlace());
168+
169+
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
170+
.eigen_device();
171+
172+
if (num_distributions == 1) {
173+
auto eigen_input = framework::EigenVector<T>::Flatten(*x);
174+
auto eigen_sum_rows = framework::EigenVector<T>::Flatten(sum_rows_tensor);
175+
eigen_sum_rows.device(place) =
176+
eigen_input.sum(Eigen::DSizes<int, 1>(1))
177+
.eval()
178+
.reshape(Eigen::DSizes<int, 1>(sum_rows_tensor.dims()[0]));
179+
} else {
180+
auto eigen_input = framework::EigenMatrix<T>::From(*x);
181+
auto eigen_sum_rows = framework::EigenVector<T>::Flatten(sum_rows_tensor);
182+
eigen_sum_rows.device(place) = eigen_input.sum(Eigen::DSizes<int, 1>(1));
183+
}
184+
185+
// Normalize row of each distribution to get the probability in range [0,
186+
// 1].
187+
// norm_probs_data: probability of the distribution
188+
framework::Tensor norm_probs_tensor;
189+
auto* norm_probs_data = norm_probs_tensor.mutable_data<T>(
190+
{num_distributions, num_categories}, ctx.GetPlace());
191+
192+
// number of threads in a block is min(num_categories, 512)
193+
dim3 block_norm(num_categories < 512 ? num_categories : 512);
194+
dim3 grid_norm((num_categories - 1) / block_norm.x + 1, num_distributions);
195+
NormalizeProbability<
196+
T><<<grid_norm, block_norm, 0, ctx.cuda_device_context().stream()>>>(
197+
norm_probs_data, in_data, sum_rows_data);
198+
199+
// Get cumulative probability of each distribution. It's the same function
200+
// of
201+
// ``cumsum`` op.
202+
framework::Tensor cumulative_probs_tensor;
203+
auto* cumulative_probs = cumulative_probs_tensor.mutable_data<T>(
204+
{num_distributions, num_categories}, ctx.GetPlace());
205+
dim3 block_cumsum(1);
206+
dim3 grid_cumsum(num_distributions);
207+
GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0,
208+
ctx.cuda_device_context().stream()>>>(
209+
norm_probs_data, num_distributions, num_categories, cumulative_probs);
210+
211+
// Generate random number for each sample.
212+
std::random_device rd;
213+
auto seed = rd();
214+
215+
framework::Tensor rng_data_tensor;
216+
auto* rng_data = rng_data_tensor.mutable_data<T>(
217+
{num_distributions, num_samples}, ctx.GetPlace());
218+
219+
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
220+
platform::Transform<platform::CUDADeviceContext> trans;
221+
auto* context =
222+
static_cast<const platform::CUDADeviceContext*>(&ctx.device_context());
223+
trans(*context, index_sequence_begin,
224+
index_sequence_begin + num_distributions * num_samples, rng_data,
225+
RandomGeneratorCudaFunctor<T>(seed));
226+
227+
// Sample the multinomial distributions.
228+
dim3 block_sample(128);
229+
dim3 grid_sample((num_samples - 1) / block_sample.x + 1, num_distributions);
230+
sampleMultinomialWithReplacement<T><<<grid_sample, block_sample, 0,
231+
ctx.cuda_device_context().stream()>>>(
232+
rng_data, num_samples, out_data, num_distributions, num_categories,
233+
cumulative_probs, norm_probs_data);
234+
}
235+
};
236+
237+
} // namespace operators
238+
} // namespace paddle
239+
240+
namespace ops = paddle::operators;
241+
namespace plat = paddle::platform;
242+
243+
REGISTER_OP_CUDA_KERNEL(
244+
multinomial, ops::MultinomialOpKernel<plat::CUDADeviceContext, float>,
245+
ops::MultinomialOpKernel<plat::CUDADeviceContext, double>);

0 commit comments

Comments
 (0)