
Commit aa33bc3

[hybrid] static model parallel dropout support deterministic RandomSeedGenerator (PaddlePaddle#36228)

1 parent 17b2d71 commit aa33bc3

File tree

13 files changed: +354 -32 lines changed

paddle/fluid/framework/generator.cc

Lines changed: 37 additions & 0 deletions

@@ -63,6 +63,43 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator() {
   return default_cpu_generator;
 }
 
+using RNGMap = std::unordered_map<std::string, std::shared_ptr<Generator>>;
+
+static RNGMap& GetRandomSeedGeneratorMap() {
+  static auto random_seed_generator_map = RNGMap();
+  return random_seed_generator_map;
+}
+
+const std::shared_ptr<Generator>& SetRandomSeedGenerator(
+    const std::string& name, uint64_t seed) {
+  auto& rng_map = GetRandomSeedGeneratorMap();
+  auto iter = rng_map.find(name);
+  PADDLE_ENFORCE_EQ(iter == rng_map.end(), true,
+                    platform::errors::AlreadyExists(
+                        "%s RandomSeedGenerator already exists", name));
+
+  auto generator = std::make_shared<Generator>(seed);
+  bool emplace_success = rng_map.emplace(name, generator).second;
+  PADDLE_ENFORCE_EQ(
+      emplace_success, true,
+      platform::errors::PermissionDenied(
+          "SetRandomSeedGenerator cannot emplace %s RandomSeedGenerator",
+          name));
+  return rng_map[name];
+}
+
+const std::shared_ptr<Generator>& GetRandomSeedGenerator(
+    const std::string& name) {
+  auto& rng_map = GetRandomSeedGeneratorMap();
+  auto iter = rng_map.find(name);
+  PADDLE_ENFORCE_EQ(iter != rng_map.end(), true,
+                    platform::errors::NotFound(
+                        "%s RandomSeedGenerator is not found, please "
+                        "use `set_random_seed_generator` to set rng first",
+                        name));
+  return iter->second;
+}
+
 std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine() {
   static auto op_default_cpu_engine = std::make_shared<std::mt19937_64>();
   return op_default_cpu_engine;

paddle/fluid/framework/generator.h

Lines changed: 6 additions & 0 deletions

@@ -126,5 +126,11 @@ std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
 const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(
     int64_t device_id = -1);
 
+const std::shared_ptr<Generator>& SetRandomSeedGenerator(
+    const std::string& name, uint64_t seed);
+
+const std::shared_ptr<Generator>& GetRandomSeedGenerator(
+    const std::string& name);
+
 }  // namespace framework
 }  // namespace paddle

paddle/fluid/operators/dropout_impl_util.h

Lines changed: 3 additions & 7 deletions

@@ -29,7 +29,7 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
       BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
   auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
 
-  if ((seed) && platform::is_gpu_place(seed->place())) {
+  if (seed) {
     framework::Tensor seed_cpu_tensor;
     TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
     *seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
@@ -39,12 +39,8 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
     *seed_data = seed_offset.first;
     *increment = seed_offset.second;
   } else {
-    if (seed) {
-      *seed_data = *(seed->data<int>());
-    } else {
-      std::random_device rnd;
-      *seed_data = is_fix_seed ? seed_val : rnd();
-    }
+    std::random_device rnd;
+    *seed_data = is_fix_seed ? seed_val : rnd();
     *increment = offset;
   }
 }
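
The net effect of this change is that an explicit `Seed` input now always wins, regardless of where the tensor lives. A minimal Python sketch of the resulting selection order (illustrative names only, not Paddle API; `gen_random_pair` stands in for the default CUDA generator's seed/offset query, and the generator branch condition is inferred from the surrounding context):

import random

def get_seed_data_and_increment(seed_tensor, is_init_py, is_fix_seed,
                                seed_val, offset, gen_random_pair):
    # 1) An explicit `Seed` input always wins; the C++ code copies it to
    #    CPU first, so GPU-resident seed tensors now work too.
    if seed_tensor is not None:
        return int(seed_tensor[0]), offset
    # 2) A Python-initialized CUDA generator supplies a (seed, increment) pair.
    if is_init_py and not is_fix_seed:
        return gen_random_pair()
    # 3) Otherwise: fixed seed for tests/debugging, random seed in training.
    seed = seed_val if is_fix_seed else random.SystemRandom().randrange(2**31)
    return seed, offset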

paddle/fluid/operators/seed_op.cc

Lines changed: 11 additions & 0 deletions

@@ -39,6 +39,17 @@ class SeedOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override {
     AddOutput("Out", "The output of seed op.");
     AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
+    AddAttr<bool>("deterministic",
+                  "(bool, default false) Whether to use the deterministic "
+                  "RandomSeedGenerator created by "
+                  "`set_random_seed_generator`")
+        .SetDefault(false)
+        .AsExtra();
+    AddAttr<std::string>(
+        "rng_name",
+        "The name of the deterministic RandomSeedGenerator to use")
+        .SetDefault("")
+        .AsExtra();
     AddAttr<bool>("force_cpu",
                   "(bool, default false) Force fill output variable to cpu "
                   "memory. Otherwise, fill output variable to the running "

paddle/fluid/operators/seed_op.cu

Lines changed: 2 additions & 9 deletions

@@ -23,16 +23,9 @@ class GPUSeedKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     auto *out = context.Output<Tensor>("Out");
-    int user_seed = context.Attr<int>("seed");
-    auto force_cpu = context.Attr<bool>("force_cpu");
-    std::random_device rnd;
-    int seed;
-    if (user_seed != 0) {
-      seed = user_seed;
-    } else {
-      seed = rnd();
-    }
+    int seed = get_seed(context);
 
+    auto force_cpu = context.Attr<bool>("force_cpu");
     bool cpu_place = force_cpu || context.GetPlace() == platform::CPUPlace();
     if (cpu_place) {
       platform::DeviceContextPool &pool =

paddle/fluid/operators/seed_op.h

Lines changed: 24 additions & 10 deletions

@@ -13,31 +13,45 @@
 // limitations under the License.
 #pragma once
 
+#include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 
 namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;
 
-template <typename DeviceContext, typename T>
-class CPUSeedKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* out = context.Output<Tensor>("Out");
-    auto* out_data = out->mutable_data<T>(context.GetPlace());
-    int user_seed = context.Attr<int>("seed");
+static int get_seed(const framework::ExecutionContext& context) {
+  int user_seed = context.Attr<int>("seed");
+  bool deterministic = context.Attr<bool>("deterministic");
 
+  int seed = 0;
+  if (!deterministic) {
     // NOTE: fixed seed should only be used in unittest or for debug.
     // Guarantee to use random seed in training.
-    std::random_device rnd;
-    int seed;
     if (user_seed != 0) {
       seed = user_seed;
     } else {
+      std::random_device rnd;
       seed = rnd();
     }
-    out_data[0] = seed;
+  } else {
+    std::string name = context.Attr<std::string>("rng_name");
+    auto rng = framework::GetRandomSeedGenerator(name);
+    do {  // NOTE(wangxi): cpu dropout will use random seed if seed == 0
+      seed = static_cast<int>(rng->Random64());
+    } while (seed == 0);
+  }
+  return seed;
+}
+
+template <typename DeviceContext, typename T>
+class CPUSeedKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* out = context.Output<Tensor>("Out");
+    auto* out_data = out->mutable_data<T>(context.GetPlace());
+    out_data[0] = get_seed(context);
   }
 };

paddle/fluid/pybind/generator_py.cc

Lines changed: 2 additions & 0 deletions

@@ -60,6 +60,8 @@ void BindGenerator(py::module* m_ptr) {
            &framework::Generator::SetIsInitPy);
   m.def("default_cpu_generator", &framework::DefaultCPUGenerator);
   m.def("default_cuda_generator", &framework::GetDefaultCUDAGenerator);
+  m.def("set_random_seed_generator", &framework::SetRandomSeedGenerator);
+  m.def("get_random_seed_generator", &framework::GetRandomSeedGenerator);
 }
 }  // namespace pybind
 }  // namespace paddle
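
These two bindings expose the registry to Python. A small usage sketch of the expected semantics (the `paddle.fluid.core` module path is an assumption, matching the other generator bindings in this file):

from paddle.fluid import core

core.set_random_seed_generator("mp_dropout", 1234)   # register once per name
gen = core.get_random_seed_generator("mp_dropout")   # look it up thereafter

# Per the C++ checks above: registering the same name twice raises
# AlreadyExists, and looking up an unregistered name raises NotFound.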

python/paddle/distributed/fleet/meta_parallel/parallel_layers/random.py

Lines changed: 137 additions & 0 deletions

@@ -15,6 +15,11 @@
 import paddle
 import contextlib
 import numpy as np
+from paddle import _C_ops
+from paddle.fluid import core
+from paddle.fluid.data_feeder import check_variable_and_dtype
+from paddle.fluid.framework import in_dygraph_mode, default_main_program
+from paddle.fluid.layer_helper import LayerHelper
 
 __all__ = []
 
@@ -93,3 +98,135 @@ def model_parallel_random_seed(seed=None):
     RNG_STATE_TRACKER.reset()
     RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
     paddle.seed(global_seed)
+
+
+def determinate_seed(rng_name):
+    assert rng_name is not None and rng_name != ""
+    helper = LayerHelper('seed', **locals())
+    out = helper.create_variable_for_type_inference(dtype=paddle.int32)
+    # set force_cpu to reduce sync copy from CPU->GPU->CPU, and reduce pipeline hang
+    helper.append_op(
+        type='seed',
+        outputs={'Out': out},
+        attrs={'deterministic': True,
+               'rng_name': rng_name,
+               'force_cpu': True})
+    return out
+
+
+def dropout(x,
+            p=0.5,
+            axis=None,
+            rng_name=None,
+            training=True,
+            mode="upscale_in_train",
+            name=None):
+    """
+    Dropout is a regularization technique for reducing overfitting by preventing
+    neuron co-adaptation during training. The dropout operator randomly sets the
+    outputs of some units to zero, while upscaling the others according to the
+    given dropout probability.
+
+    Args:
+        x (Tensor): The input tensor. The data type is float32 or float64.
+        p (float|int): Probability of setting units to zero. Default 0.5.
+        axis (int|list|tuple): The axis along which the dropout is performed. Default None.
+        rng_name (str): The name of the random seed generator, which is used to obtain deterministic results.
+        training (bool): A flag indicating whether it is in the training phase or not. Default True.
+        mode(str): ['upscale_in_train'(default) | 'downscale_in_infer'].
+
+            1. upscale_in_train(default), upscale the output at training time
+
+                - train: out = input * mask / ( 1.0 - dropout_prob )
+                - inference: out = input
+
+            2. downscale_in_infer, downscale the output at inference
+
+                - train: out = input * mask
+                - inference: out = input * (1.0 - dropout_prob)
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        A Tensor representing the dropout, which has the same shape and data type as `x` .
+
+
+    Examples:
+        We use ``p=0.5`` in the following description for simplicity.
+
+        1. When ``axis=None`` , this is the commonly used dropout, which drops out each element of x randomly.
+
+        .. code-block:: text
+
+            Let's see a simple case when x is a 2d tensor with shape 2*3:
+            [[1 2 3]
+             [4 5 6]]
+            we generate a mask with the same shape as x, which is 2*3. The value of the mask is
+            sampled randomly from a Bernoulli distribution. For example, we may get such a mask:
+            [[0 1 0]
+             [1 0 1]]
+            So the output is obtained from an elementwise multiply of x and mask:
+            [[0 2 0]
+             [4 0 6]]
+            Using the default setting, i.e. ``mode='upscale_in_train'`` ,
+            if in the training phase, the final upscaled output is:
+            [[0 4 0 ]
+             [8 0 12]]
+            if in the test phase, the output is the same as the input:
+            [[1 2 3]
+             [4 5 6]]
+            we can also set ``mode='downscale_in_infer'`` , then
+            if in the training phase, the final output is:
+            [[0 2 0]
+             [4 0 6]]
+            if in the test phase, the scaled output is:
+            [[0.5 1.  1.5]
+             [2.  2.5 3. ]]
+
+    """
+    if rng_name is None:
+        return paddle.nn.functional.dropout(x, p, axis, training, mode, name)
+
+    # fast return for p == 0
+    if p == 0: return x
+
+    assert isinstance(p, (float, int)), \
+        TypeError("p argument should be a number")
+    assert 0 <= p <= 1, ValueError("p argument should be between 0 and 1")
+    assert mode in ('downscale_in_infer', 'upscale_in_train'), \
+        ValueError(
+            "mode argument should be 'downscale_in_infer' or 'upscale_in_train'")
+
+    assert axis is None, \
+        TypeError("axis is unsupported when using a random seed generator")
+
+    mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode  # semantic transfer
+
+    # dygraph uses the tracker and does not need a deterministic seed
+    if in_dygraph_mode():
+        out, mask = _C_ops.dropout(x, 'dropout_prob', p, 'is_test',
+                                   not training, 'fix_seed', False, 'seed', 0,
+                                   'dropout_implementation', mode)
+        return out
+
+    seed = determinate_seed(rng_name)
+
+    helper = LayerHelper('dropout', **locals())
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                             'dropout')
+
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    mask = helper.create_variable_for_type_inference(
+        dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
+
+    helper.append_op(
+        type='dropout',
+        inputs={'X': [x],
+                'Seed': seed},
+        outputs={'Out': [out],
+                 'Mask': [mask]},
+        attrs={
+            'dropout_prob': p,
+            'is_test': not training,
+            'dropout_implementation': mode,
+        })
+    return out
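
Putting the pieces together, a hedged end-to-end sketch for static-graph use (import paths are inferred from the files in this commit; every rank that registers the same (name, seed) pair draws the same dropout seed stream):

import paddle
from paddle.fluid import core
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import dropout

paddle.enable_static()
# register the shared, deterministic seed stream once per process
core.set_random_seed_generator("mp_dropout", 1234)

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    # the seed op inserted by determinate_seed feeds this dropout's `Seed`
    # input, so replaying the program reproduces the same masks
    y = dropout(x, p=0.5, rng_name="mp_dropout")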

python/paddle/fluid/backward.py

Lines changed: 5 additions & 1 deletion

@@ -174,11 +174,15 @@ def modify_forward_desc_for_recompute(self):
             return
 
         op_idx = 0
-        while (op_idx < len(self.ops)):
+        while op_idx < len(self.ops):
             op = self.ops[op_idx]
             if op.desc.type() != "dropout":
                 op_idx += 1
                 continue
+            # a seed op has already been inserted before this dropout
+            if op.input('Seed') is not None and len(op.input('Seed')) == 1:
+                op_idx += 1
+                continue
             # add a seed op so that the two dropout ops can generate the same output
             op_unique_name = unique_name.generate("seed")
             var_unique_name = unique_name.generate_with_ignorable_key(".".join(
