Skip to content

Commit a690c42

Browse files
YuanRishengAnnaTrainingG
authored andcommitted
Add New OP: gumbel_softmax (PaddlePaddle#35506)
* Add New Op: gumbel_softmax * Add New Op: gumbel_softmax * Add New Op: gumbel_softmax (amend) * add __main__ function in unit test * fix bugs when test in windows ci * update en docs * delete reletive error in unit test * delete relative error in unit test * set hard=True in unit test
1 parent 57ba97b commit a690c42

File tree

6 files changed

+852
-0
lines changed

6 files changed

+852
-0
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/gumbel_softmax_op.h"
16+
#include <string>
17+
#include <unordered_map>
18+
#include "paddle/fluid/operators/common_infer_shape_functions.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
class GumbelSoftmaxOp : public framework::OperatorWithKernel {
23+
public:
24+
using framework::OperatorWithKernel::OperatorWithKernel;
25+
26+
void InferShape(framework::InferShapeContext* ctx) const override {
27+
return UnaryOpUnchangedInferShapeCheckAxis(ctx);
28+
}
29+
30+
protected:
31+
framework::OpKernelType GetExpectedKernelType(
32+
const framework::ExecutionContext& ctx) const override {
33+
return framework::OpKernelType(
34+
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
35+
ctx.device_context());
36+
}
37+
};
38+
39+
class GumbelSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
40+
public:
41+
void Make() override {
42+
AddInput("X",
43+
"(Tensor) An N-D Tensor, N >= 1,"
44+
"The first N - 1 dimensions index into a batch of independent "
45+
"distributions "
46+
"and the last dimension represents a vector of probabilities for "
47+
"each class.");
48+
AddOutput("Out", "The sampled tensor with the same shape as X.");
49+
AddAttr<float>("temperature",
50+
"(float, default 1.0) non-negative scalar temperature.")
51+
.SetDefault(1.0);
52+
AddAttr<bool>(
53+
"hard",
54+
"(bool, default false) "
55+
"if True, the returned samples will be discretized as one-hot vectors, "
56+
"but will be differentiated as if it is the soft sample in autograd.")
57+
.SetDefault(false);
58+
AddAttr<int>("axis",
59+
"(int, default -1)"
60+
"The dimension index of Input(x) to perform gumbel_softmax.")
61+
.SetDefault(-1);
62+
AddComment(R"DOC(
63+
GumbelSoftmax Operator.
64+
65+
Samples from the Gumbel-Softmax distribution and optionally discretizes.
66+
67+
)DOC");
68+
}
69+
};
70+
71+
class GumbelSoftmaxGradOp : public framework::OperatorWithKernel {
72+
public:
73+
using framework::OperatorWithKernel::OperatorWithKernel;
74+
75+
void InferShape(framework::InferShapeContext* ctx) const override {
76+
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "gumbel_softmax_grad");
77+
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
78+
"Out@GRAD", "gumbel_softmax_grad");
79+
PADDLE_ENFORCE_EQ(
80+
ctx->GetInputDim("Out"),
81+
ctx->GetInputDim(framework::GradVarName("Out")),
82+
platform::errors::InvalidArgument("Input(Out) and its gradients "
83+
"should have the same shape."));
84+
85+
ctx->SetOutputDim(framework::GradVarName("X"),
86+
ctx->GetInputDim(framework::GradVarName("Out")));
87+
}
88+
};
89+
90+
template <typename T>
91+
class GumbelSoftmaxGradOpMaker : public framework::SingleGradOpMaker<T> {
92+
public:
93+
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
94+
95+
protected:
96+
void Apply(GradOpPtr<T> op) const override {
97+
op->SetType("gumbel_softmax_grad");
98+
op->SetInput("Out", this->Output("Out"));
99+
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
100+
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
101+
op->SetAttrMap(this->Attrs());
102+
}
103+
};
104+
105+
} // namespace operators
106+
} // namespace paddle
107+
108+
namespace ops = paddle::operators;
109+
110+
REGISTER_OPERATOR(gumbel_softmax, ops::GumbelSoftmaxOp,
111+
ops::GumbelSoftmaxOpMaker,
112+
ops::GumbelSoftmaxGradOpMaker<paddle::framework::OpDesc>,
113+
ops::GumbelSoftmaxGradOpMaker<paddle::imperative::OpBase>);
114+
REGISTER_OPERATOR(gumbel_softmax_grad, ops::GumbelSoftmaxGradOp);
115+
116+
REGISTER_OP_CPU_KERNEL(
117+
gumbel_softmax,
118+
ops::GumbelSoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
119+
ops::GumbelSoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
120+
REGISTER_OP_CPU_KERNEL(
121+
gumbel_softmax_grad,
122+
ops::GumbelSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
123+
ops::GumbelSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
#pragma once
15+
16+
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/fluid/framework/operator.h"
18+
#include "paddle/fluid/operators/gumbel_softmax_op.h"
19+
20+
#if defined(__NVCC__) || defined(__HIPCC__)
21+
#ifdef __NVCC__
22+
#include "cub/cub.cuh"
23+
#endif
24+
#ifdef __HIPCC__
25+
#include <hipcub/hipcub.hpp>
26+
namespace cub = hipcub;
27+
#endif
28+
29+
#include <thrust/device_vector.h>
30+
#include <thrust/host_vector.h>
31+
#include <thrust/random.h>
32+
#include <thrust/transform.h>
33+
#include "paddle/fluid/framework/generator.h"
34+
#include "paddle/fluid/memory/memcpy.h"
35+
36+
namespace paddle {
37+
namespace operators {
38+
39+
template <typename K, typename V>
40+
using KeyValuePair = cub::KeyValuePair<K, V>;
41+
42+
template <typename T>
43+
struct UniformCUDAGenerator {
44+
T min_, max_;
45+
unsigned int seed_;
46+
unsigned int offset_ = 0;
47+
HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed)
48+
: min_(min), max_(max), seed_(seed) {}
49+
HOSTDEVICE UniformCUDAGenerator(T min, T max, unsigned int seed,
50+
unsigned int offset)
51+
: min_(min), max_(max), seed_(seed), offset_(offset) {}
52+
53+
HOSTDEVICE T operator()(const unsigned int n) const {
54+
thrust::minstd_rand rng;
55+
rng.seed(seed_);
56+
thrust::uniform_real_distribution<T> dist(min_, max_);
57+
rng.discard(n + offset_);
58+
return dist(rng);
59+
}
60+
};
61+
62+
template <typename T, size_t BlockDim>
63+
__global__ void OneHotCUDAKernel(const int64_t height, const int64_t width,
64+
const int64_t size_out_axis, const T init,
65+
const T* in, T* out) {
66+
typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
67+
__shared__ typename BlockReduce::TempStorage temp_storage;
68+
69+
for (int64_t idx = blockIdx.x; idx < height; idx += gridDim.x) {
70+
KeyValuePair<int, T> kv_pair = {-1, init};
71+
int h = idx / size_out_axis;
72+
int w = idx % size_out_axis;
73+
cub::ArgMax reducer;
74+
for (int k = threadIdx.x; k < width; k += blockDim.x) {
75+
kv_pair = reducer(
76+
{k, in[h * width * size_out_axis + k * size_out_axis + w]}, kv_pair);
77+
}
78+
kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, reducer);
79+
if (threadIdx.x == 0) {
80+
int index = static_cast<int>(kv_pair.key);
81+
out[h * width * size_out_axis + index * size_out_axis + w] = 1;
82+
}
83+
__syncthreads();
84+
}
85+
}
86+
87+
template <typename T>
88+
struct OneHotGenerator<platform::CUDADeviceContext, T> {
89+
static void Transform(const platform::CUDADeviceContext& context,
90+
const Tensor& X, Tensor* Out, int axis) {
91+
const int size_to_axis = SizeToAxis(axis, X.dims());
92+
const int size_from_axis = SizeFromAxis(axis, X.dims());
93+
const int size_out_axis = SizeOutAxis(axis, X.dims());
94+
constexpr int thread_size = 512;
95+
int64_t max_grid_dimx = context.GetCUDAMaxGridDimSize().x;
96+
int64_t height = size_to_axis * size_out_axis;
97+
int block_size = height < max_grid_dimx ? height : max_grid_dimx;
98+
99+
Tensor input_tensor;
100+
input_tensor.mutable_data<T>(Out->dims(), platform::CUDAPlace());
101+
TensorCopy(*Out, context.GetPlace(), &input_tensor);
102+
math::set_constant(context, Out, 0.0);
103+
OneHotCUDAKernel<
104+
T, thread_size><<<block_size, thread_size, 0, context.stream()>>>(
105+
height, size_from_axis / size_out_axis, size_out_axis,
106+
std::numeric_limits<T>::lowest(), input_tensor.data<T>(),
107+
Out->data<T>());
108+
}
109+
};
110+
111+
template <typename T>
112+
__global__ void AddGumbelNoiseCUDAKernel(const T* input_data, T* output_data,
113+
T* noise, const float temperature,
114+
int64_t n) {
115+
int index = threadIdx.x + blockIdx.x * blockDim.x;
116+
int step = blockDim.x * gridDim.x;
117+
for (int64_t i = index; i < n; i += step) {
118+
T gumbel_noise = -log(-log(noise[i]));
119+
output_data[i] = (gumbel_noise + input_data[i]) / temperature;
120+
}
121+
}
122+
123+
template <typename T>
124+
struct GumbleNoiseGenerator<platform::CUDADeviceContext, T> {
125+
static void Transform(const platform::CUDADeviceContext& context,
126+
const T* input_data, T* output_data, int size_to_axis,
127+
int size_from_axis, const float temperature) {
128+
Tensor random_tensor;
129+
int64_t size = size_to_axis * size_from_axis;
130+
T* random_data =
131+
random_tensor.mutable_data<T>({size}, platform::CUDAPlace());
132+
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
133+
const unsigned int seed = std::random_device()();
134+
135+
// generate gumbel noise
136+
int device_id =
137+
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
138+
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
139+
if (gen_cuda->GetIsInitPy()) {
140+
auto seed_offset = gen_cuda->IncrementOffset(1);
141+
int gen_offset = size * seed_offset.second;
142+
thrust::transform(
143+
index_sequence_begin, index_sequence_begin + size,
144+
thrust::device_ptr<T>(random_data),
145+
UniformCUDAGenerator<T>(0.00001, 1, seed_offset.first, gen_offset));
146+
} else {
147+
thrust::transform(index_sequence_begin, index_sequence_begin + size,
148+
thrust::device_ptr<T>(random_data),
149+
UniformCUDAGenerator<T>(0.00001, 1, seed));
150+
}
151+
152+
// add gumbel noise to X
153+
const int thread_size = 512;
154+
int64_t block_size = (size + thread_size) / thread_size;
155+
AddGumbelNoiseCUDAKernel<
156+
T><<<block_size, thread_size, 0, context.stream()>>>(
157+
input_data, output_data, random_data, temperature, size);
158+
}
159+
};
160+
161+
#endif
162+
} // namespace operators
163+
} // namespace paddle
164+
165+
namespace ops = paddle::operators;
166+
namespace plat = paddle::platform;
167+
REGISTER_OP_CUDA_KERNEL(
168+
gumbel_softmax, ops::GumbelSoftmaxKernel<plat::CUDADeviceContext, float>,
169+
ops::GumbelSoftmaxKernel<plat::CUDADeviceContext, double>);
170+
REGISTER_OP_CUDA_KERNEL(
171+
gumbel_softmax_grad,
172+
ops::GumbelSoftmaxGradKernel<plat::CUDADeviceContext, float>,
173+
ops::GumbelSoftmaxGradKernel<plat::CUDADeviceContext, double>);

0 commit comments

Comments
 (0)