From 818dbfbcaa6a3179fe7136c5db00b0b5d64d49db Mon Sep 17 00:00:00 2001 From: yaoxuefeng6 Date: Thu, 20 Aug 2020 17:19:56 +0800 Subject: [PATCH 1/7] add first generator fix test=develop --- .../fluid/memory/allocation/mmap_allocator.cc | 7 +- .../detection/generate_proposal_labels_op.cc | 9 +- .../detection/rpn_target_assign_op.cc | 10 +- .../operators/distributed/large_scale_kv.h | 15 +- paddle/fluid/operators/dropout_op.h | 9 +- paddle/fluid/operators/gaussian_random_op.cc | 26 +- paddle/fluid/operators/math/sampler.cc | 25 +- .../mkldnn/gaussian_random_mkldnn_op.cc | 28 +- paddle/fluid/operators/randint_op.cc | 23 +- paddle/fluid/operators/randperm_op.h | 16 +- paddle/fluid/operators/sampling_id_op.h | 5 +- .../operators/truncated_gaussian_random_op.cc | 26 +- .../fluid/tests/unittests/test_random_seed.py | 352 +++++++++++++++++- 13 files changed, 497 insertions(+), 54 deletions(-) diff --git a/paddle/fluid/memory/allocation/mmap_allocator.cc b/paddle/fluid/memory/allocation/mmap_allocator.cc index 0ef084bafd0c9f..67e9b5c4a7345c 100644 --- a/paddle/fluid/memory/allocation/mmap_allocator.cc +++ b/paddle/fluid/memory/allocation/mmap_allocator.cc @@ -27,6 +27,8 @@ #include #include +#include "paddle/fluid/framework/generator.h" + namespace paddle { namespace memory { namespace allocation { @@ -60,7 +62,10 @@ std::string GetIPCName() { handle += std::to_string(getpid()); #endif handle += "_"; - handle += std::to_string(rd()); + handle += + framework::Generator::GetInstance()->is_init_py + ? std::to_string(framework::Generator::GetInstance()->Random64()) + : std::to_string(rd()); return std::move(handle); } diff --git a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc index 884aa1f6f4e996..dd9e9fd2d0ca10 100644 --- a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc @@ -13,6 +13,7 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/gather.h" @@ -190,7 +191,13 @@ std::vector> SampleFgBgGt( const int64_t fg_size = static_cast(fg_inds.size()); if (fg_size > fg_rois_per_this_image) { for (int64_t i = fg_rois_per_this_image; i < fg_size; ++i) { - int rng_ind = std::floor(uniform(engine) * i); + int rng_ind = + framework::Generator::GetInstance()->is_init_py + ? std::floor(uniform(framework::Generator::GetInstance() + ->GetCPUEngine()) * + i) + : std::floor(uniform(engine) * i); + // int rng_ind = std::floor(uniform(engine) * i); if (rng_ind < fg_rois_per_this_image) { std::iter_swap(fg_inds.begin() + rng_ind, fg_inds.begin() + i); std::iter_swap(mapped_gt_inds.begin() + rng_ind, diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index 2a16e20c2a7235..109da856ba21b5 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/math/math_function.h" @@ -159,7 +160,14 @@ void ReservoirSampling(const int num, std::vector* inds, if (len > static_cast(num)) { if (use_random) { for (size_t i = num; i < len; ++i) { - int rng_ind = std::floor(uniform(engine) * i); + int rng_ind = + framework::Generator::GetInstance()->is_init_py + ? std::floor( + uniform( + framework::Generator::GetInstance()->GetCPUEngine()) * + i) + : std::floor(uniform(engine) * i); + // int rng_ind = std::floor(uniform(engine) * i); if (rng_ind < num) std::iter_swap(inds->begin() + rng_ind, inds->begin() + i); } diff --git a/paddle/fluid/operators/distributed/large_scale_kv.h b/paddle/fluid/operators/distributed/large_scale_kv.h index fb7a0691154de7..b32210d4fffab2 100644 --- a/paddle/fluid/operators/distributed/large_scale_kv.h +++ b/paddle/fluid/operators/distributed/large_scale_kv.h @@ -28,6 +28,7 @@ #include // NOLINT #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/rw_lock.h" #include "paddle/fluid/framework/selected_rows.h" @@ -96,7 +97,12 @@ class UniformInitializer : public Initializer { dist_ = std::uniform_real_distribution(min_, max_); } - float GetValue() override { return dist_(random_engine_); } + float GetValue() override { + return framework::Generator::GetInstance()->is_init_py + ? dist_(framework::Generator::GetInstance()->GetCPUEngine()) + : dist_(random_engine_); + // return dist_(random_engine_); + } private: float min_; @@ -141,7 +147,12 @@ class GaussianInitializer : public Initializer { dist_ = std::normal_distribution(mean_, std_); } - float GetValue() override { return dist_(random_engine_); } + float GetValue() override { + return framework::Generator::GetInstance()->is_init_py + ? dist_(framework::Generator::GetInstance()->GetCPUEngine()) + : dist_(random_engine_); + return dist_(random_engine_); + } private: float std_; diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h index 676361289e888a..bce4c7ca19a603 100644 --- a/paddle/fluid/operators/dropout_op.h +++ b/paddle/fluid/operators/dropout_op.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { @@ -55,6 +56,8 @@ class CPUDropoutKernel : public framework::OpKernel { return; } + bool init_generator_py = framework::Generator::GetInstance()->is_init_py; + // NOTE: fixed seed should only be used in unittest or for debug. // Guarantee to use random seed in training. std::random_device rnd; @@ -71,7 +74,11 @@ class CPUDropoutKernel : public framework::OpKernel { std::uniform_real_distribution dist(0, 1); for (size_t i = 0; i < size; ++i) { - if (dist(engine) < dropout_prob) { + float cur_random = + init_generator_py + ? dist(framework::Generator::GetInstance()->GetCPUEngine()) + : dist(engine); + if (cur_random < dropout_prob) { mask_data[i] = 0; y_data[i] = 0; } else { diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index 898c063afdd43c..111d4ad4490074 100644 --- a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/fill_constant_op.h" #ifdef PADDLE_WITH_MKLDNN @@ -31,23 +32,30 @@ class CPUGaussianRandomKernel : public framework::OpKernel { float mean = context.Attr("mean"); float std = context.Attr("std"); auto* tensor = context.Output("Out"); - unsigned int seed = static_cast(context.Attr("seed")); - std::minstd_rand engine; - if (seed == 0) { - seed = std::random_device()(); - } - engine.seed(seed); std::normal_distribution dist(mean, std); - const std::string op_type = "gaussian_random"; auto shape = GetShape(context, op_type); tensor->Resize(shape); int64_t size = tensor->numel(); T* data = tensor->mutable_data(context.GetPlace()); - for (int64_t i = 0; i < size; ++i) { - data[i] = dist(engine); + if (framework::Generator::GetInstance()->is_init_py) { + std::mt19937_64& gen_engine = + framework::Generator::GetInstance()->GetCPUEngine(); + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(gen_engine); + } + } else { + unsigned int seed = static_cast(context.Attr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } } } }; diff --git a/paddle/fluid/operators/math/sampler.cc b/paddle/fluid/operators/math/sampler.cc index 238d9f2905058d..86feaa72d5fa69 100644 --- a/paddle/fluid/operators/math/sampler.cc +++ b/paddle/fluid/operators/math/sampler.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/generator.h" namespace paddle { namespace operators { @@ -31,7 +32,12 @@ UniformSampler::UniformSampler(int64_t range, unsigned int seed) dist_ = std::make_shared>(0, range); } -int64_t UniformSampler::Sample() const { return (*dist_)(*random_engine_); } +int64_t UniformSampler::Sample() const { + return framework::Generator::GetInstance()->is_init_py + ? (*dist_)(framework::Generator::GetInstance()->GetCPUEngine()) + : (*dist_)(*random_engine_); + // return (*dist_)(*random_engine_); +} float UniformSampler::Probability(int64_t value) const { return inv_range_; } @@ -46,8 +52,11 @@ int64_t LogUniformSampler::Sample() const { // inverse_transform_sampling method // More details: // https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/ - const int64_t value = - static_cast(exp((*dist_)(*random_engine_) * log_range_)) - 1; + auto cur_random = + framework::Generator::GetInstance()->is_init_py + ? (*dist_)(framework::Generator::GetInstance()->GetCPUEngine()) + : (*dist_)(*random_engine_); + const int64_t value = static_cast(exp(cur_random * log_range_)) - 1; // Mathematically, value should be <= range_, but might not be due to some // floating point roundoff, so we mod by range_. return value % range_; @@ -75,8 +84,14 @@ CustomSampler::CustomSampler(int64_t range, const float *probabilities, } int64_t CustomSampler::Sample() const { - auto index = (*int_dist_)(*random_engine_); - auto p = (*real_dist_)(*random_engine_); + auto index = + framework::Generator::GetInstance()->is_init_py + ? (*int_dist_)(framework::Generator::GetInstance()->GetCPUEngine()) + : (*int_dist_)(*random_engine_); + auto p = + framework::Generator::GetInstance()->is_init_py + ? (*real_dist_)(framework::Generator::GetInstance()->GetCPUEngine()) + : (*real_dist_)(*random_engine_); if (p > alias_probs_[index]) { int alias = alias_[index]; diff --git a/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc index 37b6e3bb803a2b..58902831356d22 100644 --- a/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc @@ -28,21 +28,29 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel { float std = context.Attr("std"); auto* tensor = context.Output("Out"); - unsigned int seed = static_cast(context.Attr("seed")); - std::minstd_rand engine; - if (seed == 0) { - seed = std::random_device()(); - } - engine.seed(seed); - std::normal_distribution dist(mean, std); - const std::string op_type = "gaussian_random"; auto shape = GetShape(context, op_type); tensor->Resize(shape); T* data = tensor->mutable_data(context.GetPlace()); int64_t size = tensor->numel(); - for (int64_t i = 0; i < size; ++i) { - data[i] = dist(engine); + + if (framework::Generator::GetInstance()->is_init_py) { + std::mt19937_64& gen_engine = + framework::Generator::GetInstance()->GetCPUEngine(); + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(gen_engine); + } + } else { + unsigned int seed = static_cast(context.Attr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::normal_distribution dist(mean, std); + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } } tensor->set_layout(DataLayout::kMKLDNN); diff --git a/paddle/fluid/operators/randint_op.cc b/paddle/fluid/operators/randint_op.cc index 11ce738e001517..cd309f1810a380 100644 --- a/paddle/fluid/operators/randint_op.cc +++ b/paddle/fluid/operators/randint_op.cc @@ -15,6 +15,7 @@ #include #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/uniform_random_op.h" @@ -43,15 +44,23 @@ class CPURandintKernel : public framework::OpKernel { T* data = out->mutable_data(ctx.GetPlace()); int64_t size = out->numel(); - unsigned int seed = static_cast(ctx.Attr("seed")); - std::minstd_rand engine; - if (seed == 0) { - seed = std::random_device()(); - } - engine.seed(seed); std::uniform_int_distribution dist(ctx.Attr("low"), ctx.Attr("high") - 1); - for (int64_t i = 0; i < size; ++i) data[i] = dist(engine); + + if (framework::Generator::GetInstance()->is_init_py) { + std::mt19937_64& gen_engine = + framework::Generator::GetInstance()->GetCPUEngine(); + for (int64_t i = 0; i < size; ++i) data[i] = dist(gen_engine); + } else { + unsigned int seed = static_cast(ctx.Attr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + + for (int64_t i = 0; i < size; ++i) data[i] = dist(engine); + } } }; diff --git a/paddle/fluid/operators/randperm_op.h b/paddle/fluid/operators/randperm_op.h index 64ef1c771423f2..0eb028ad806848 100644 --- a/paddle/fluid/operators/randperm_op.h +++ b/paddle/fluid/operators/randperm_op.h @@ -19,6 +19,7 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/platform/place.h" @@ -31,11 +32,17 @@ static inline void random_permate(T* data_ptr, int num, unsigned int seed) { for (int i = 0; i < num; ++i) { data_ptr[i] = static_cast(i); } - if (seed == 0) { - seed = std::random_device()(); + if (framework::Generator::GetInstance()->is_init_py) { + std::shuffle(data_ptr, data_ptr + num, + framework::Generator::GetInstance()->GetCPUEngine()); + + } else { + if (seed == 0) { + seed = std::random_device()(); + } + std::srand(seed); + std::random_shuffle(data_ptr, data_ptr + num); } - std::srand(seed); - std::random_shuffle(data_ptr, data_ptr + num); } template @@ -51,6 +58,7 @@ class RandpermKernel : public framework::OpKernel { if (platform::is_cpu_place(ctx.GetPlace())) { T* out_data = out_tensor->mutable_data(platform::CPUPlace()); random_permate(out_data, n, seed); + } else { framework::Tensor tmp_tensor; tmp_tensor.Resize(framework::make_ddim({n})); diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 5ec32c98f7f84a..a09220b1ccd136 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -21,6 +21,7 @@ #include #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { @@ -61,7 +62,9 @@ class SamplingIdKernel : public framework::OpKernel { std::vector ids(batch_size); for (int i = 0; i < batch_size; ++i) { - T r = dist(engine); + T r = framework::Generator::GetInstance()->is_init_py + ? dist(framework::Generator::GetInstance()->GetCPUEngine()) + : dist(engine); int idx = width - 1; for (int j = 0; j < width; ++j) { if ((r -= ins_vector[i * width + j]) < 0) { diff --git a/paddle/fluid/operators/truncated_gaussian_random_op.cc b/paddle/fluid/operators/truncated_gaussian_random_op.cc index 9e158abba747d1..3aa9ff544af639 100644 --- a/paddle/fluid/operators/truncated_gaussian_random_op.cc +++ b/paddle/fluid/operators/truncated_gaussian_random_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { @@ -161,18 +162,27 @@ class CPUTruncatedGaussianRandomKernel : public framework::OpKernel { auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); - unsigned int seed = static_cast(context.Attr("seed")); - std::minstd_rand engine; - if (seed == 0) { - seed = std::random_device()(); - } - engine.seed(seed); std::uniform_real_distribution dist(std::numeric_limits::min(), 1.0); TruncatedNormal truncated_normal(mean, std); int64_t size = tensor->numel(); - for (int64_t i = 0; i < size; ++i) { - data[i] = truncated_normal(dist(engine)); + + if (framework::Generator::GetInstance()->is_init_py) { + std::mt19937_64& gen_engine = + framework::Generator::GetInstance()->GetCPUEngine(); + for (int64_t i = 0; i < size; ++i) { + data[i] = truncated_normal(dist(gen_engine)); + } + } else { + unsigned int seed = static_cast(context.Attr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + for (int64_t i = 0; i < size; ++i) { + data[i] = truncated_normal(dist(engine)); + } } } }; diff --git a/python/paddle/fluid/tests/unittests/test_random_seed.py b/python/paddle/fluid/tests/unittests/test_random_seed.py index 31120a73042c98..2933abe46c1b87 100644 --- a/python/paddle/fluid/tests/unittests/test_random_seed.py +++ b/python/paddle/fluid/tests/unittests/test_random_seed.py @@ -92,6 +92,118 @@ def test_generator_uniform_random_static(self): self.assertTrue(np.allclose(out1_res2, out2_res2)) self.assertTrue(not np.allclose(out1_res2, out1_res1)) + def test_gen_dropout_dygraph(self): + gen = generator.Generator() + + fluid.enable_dygraph() + + gen.manual_seed(111111111) + st = gen.get_state() + # x = np.arange(1,101).reshape(2,50).astype("float32") + x = fluid.layers.uniform_random( + [2, 10], dtype="float32", min=0.0, max=1.0) + y = fluid.layers.dropout(x, 0.5) + gen.manual_seed(111111111) + #gen.set_state(st) + x1 = fluid.layers.uniform_random( + [2, 10], dtype="float32", min=0.0, max=1.0) + y1 = fluid.layers.dropout(x1, 0.5) + y_np = y.numpy() + y1_np = y1.numpy() + #print(y_np) + #print(y1_np) + if not core.is_compiled_with_cuda(): + print(">>>>>>> dropout dygraph >>>>>>>") + self.assertTrue(np.allclose(y_np, y1_np)) + + def test_gen_dropout_static(self): + fluid.disable_dygraph() + + gen = generator.Generator() + gen.manual_seed(123123143) + + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + # example 1: + # attr shape is a list which doesn't contain tensor Variable. + x_1 = fluid.layers.uniform_random(shape=[2, 10]) + y_1 = fluid.layers.dropout(x_1, 0.5) + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + out1 = exe.run(train_program, feed={}, fetch_list=[y_1]) + #gen.set_state(cur_state) + gen.manual_seed(123123143) + out2 = exe.run(train_program, feed={}, fetch_list=[y_1]) + out1_np = np.array(out1[0]) + out2_np = np.array(out2[0]) + # print(out1_np) + # print(out2_np) + if not core.is_compiled_with_cuda(): + print(">>>>>>> dropout static >>>>>>>") + self.assertTrue(np.allclose(out1_np, out2_np)) + + def test_generator_gaussian_random_dygraph(self): + """Test Generator seed.""" + gen = generator.Generator() + + fluid.enable_dygraph() + + gen.manual_seed(12312321111) + x = fluid.layers.gaussian_random([10], dtype="float32") + st1 = gen.get_state() + x1 = fluid.layers.gaussian_random([10], dtype="float32") + gen.set_state(st1) + x2 = fluid.layers.gaussian_random([10], dtype="float32") + gen.manual_seed(12312321111) + x3 = fluid.layers.gaussian_random([10], dtype="float32") + x_np = x.numpy() + x1_np = x1.numpy() + x2_np = x2.numpy() + x3_np = x3.numpy() + + if not core.is_compiled_with_cuda(): + print(">>>>>>> gaussian random dygraph >>>>>>>") + self.assertTrue(np.allclose(x1_np, x2_np)) + self.assertTrue(np.allclose(x_np, x3_np)) + + def test_generator_gaussian_random_static(self): + + fluid.disable_dygraph() + + gen = generator.Generator() + gen.manual_seed(123123143) + + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + # example 1: + # attr shape is a list which doesn't contain tensor Variable. + result_1 = fluid.layers.gaussian_random(shape=[3, 4]) + result_2 = fluid.layers.gaussian_random(shape=[3, 4]) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + out1 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + #gen.set_state(cur_state) + gen.manual_seed(123123143) + out2 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + + out1_res1 = np.array(out1[0]) + out1_res2 = np.array(out1[1]) + out2_res1 = np.array(out2[0]) + out2_res2 = np.array(out2[1]) + + if not core.is_compiled_with_cuda(): + print(">>>>>>> gaussian random static >>>>>>>") + self.assertTrue(np.allclose(out1_res1, out2_res1)) + self.assertTrue(np.allclose(out1_res2, out2_res2)) + self.assertTrue(not np.allclose(out1_res2, out1_res1)) + def test_generator_randint_dygraph(self): """Test Generator seed.""" gen = generator.Generator() @@ -99,21 +211,253 @@ def test_generator_randint_dygraph(self): fluid.enable_dygraph() gen.manual_seed(12312321111) - x = paddle.randint(low=1) + x = paddle.randint(low=10, shape=[10], dtype="int32") st1 = gen.get_state() - x1 = paddle.randint(low=1) + x1 = paddle.randint(low=10, shape=[10], dtype="int32") gen.set_state(st1) - x2 = paddle.randint(low=1) + x2 = paddle.randint(low=10, shape=[10], dtype="int32") gen.manual_seed(12312321111) - x3 = paddle.randint(low=1) + x3 = paddle.randint(low=10, shape=[10], dtype="int32") x_np = x.numpy() x1_np = x1.numpy() x2_np = x2.numpy() x3_np = x3.numpy() + if not core.is_compiled_with_cuda(): + print(">>>>>>> randint dygraph >>>>>>>") self.assertTrue(np.allclose(x1_np, x2_np)) self.assertTrue(np.allclose(x_np, x3_np)) + def test_generator_ranint_static(self): + + fluid.disable_dygraph() + + gen = generator.Generator() + gen.manual_seed(123123143) + + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + # example 1: + # attr shape is a list which doesn't contain tensor Variable. + result_1 = paddle.randint(low=10, shape=[3, 4]) + result_2 = paddle.randint(low=10, shape=[3, 4]) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + out1 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + #gen.set_state(cur_state) + gen.manual_seed(123123143) + out2 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + + out1_res1 = np.array(out1[0]) + out1_res2 = np.array(out1[1]) + out2_res1 = np.array(out2[0]) + out2_res2 = np.array(out2[1]) + + if not core.is_compiled_with_cuda(): + print(">>>>>>> randint static >>>>>>>") + self.assertTrue(np.allclose(out1_res1, out2_res1)) + self.assertTrue(np.allclose(out1_res2, out2_res2)) + self.assertTrue(not np.allclose(out1_res2, out1_res1)) + + def test_generator_randperm_dygraph(self): + """Test Generator seed.""" + gen = generator.Generator() + + fluid.enable_dygraph() + + gen.manual_seed(12312321111) + x = paddle.randperm(10) + st1 = gen.get_state() + x1 = paddle.randperm(10) + gen.set_state(st1) + x2 = paddle.randperm(10) + gen.manual_seed(12312321111) + x3 = paddle.randperm(10) + x_np = x.numpy() + x1_np = x1.numpy() + x2_np = x2.numpy() + x3_np = x3.numpy() + + # print("## {}".format(x1_np)) + # print("## {}".format(x2_np)) + + if not core.is_compiled_with_cuda(): + print(">>>>>>> randperm dygraph >>>>>>>") + self.assertTrue(np.allclose(x1_np, x2_np)) + self.assertTrue(np.allclose(x_np, x3_np)) + + def test_generator_randperm_static(self): + + fluid.disable_dygraph() + + gen = generator.Generator() + gen.manual_seed(123123143) + + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + # example 1: + # attr shape is a list which doesn't contain tensor Variable. + result_1 = paddle.randperm(10) + result_2 = paddle.randperm(10) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + out1 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + #gen.set_state(cur_state) + gen.manual_seed(123123143) + out2 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + + out1_res1 = np.array(out1[0]) + out1_res2 = np.array(out1[1]) + out2_res1 = np.array(out2[0]) + out2_res2 = np.array(out2[1]) + + if not core.is_compiled_with_cuda(): + print(">>>>>>> randperm static >>>>>>>") + self.assertTrue(np.allclose(out1_res1, out2_res1)) + self.assertTrue(np.allclose(out1_res2, out2_res2)) + self.assertTrue(not np.allclose(out1_res2, out1_res1)) + + def test_generator_sampling_id_dygraph(self): + """Test Generator seed.""" + gen = generator.Generator() + + fluid.enable_dygraph() + + gen.manual_seed(12312321111) + x = fluid.layers.uniform_random( + [10, 10], dtype="float32", min=0.0, max=1.0) + y = fluid.layers.sampling_id(x) + st1 = gen.get_state() + x1 = fluid.layers.uniform_random( + [10, 10], dtype="float32", min=0.0, max=1.0) + y1 = fluid.layers.sampling_id(x) + gen.set_state(st1) + x2 = fluid.layers.uniform_random( + [10, 10], dtype="float32", min=0.0, max=1.0) + y2 = fluid.layers.sampling_id(x) + gen.manual_seed(12312321111) + x3 = fluid.layers.uniform_random( + [10, 10], dtype="float32", min=0.0, max=1.0) + y3 = fluid.layers.sampling_id(x) + + x_np = y.numpy() + x1_np = y1.numpy() + x2_np = y2.numpy() + x3_np = y3.numpy() + + print("## {}".format(x1_np)) + print("## {}".format(x2_np)) + + if not core.is_compiled_with_cuda(): + print(">>>>>>> sampling id dygraph >>>>>>>") + self.assertTrue(np.allclose(x1_np, x2_np)) + self.assertTrue(np.allclose(x_np, x3_np)) + + def test_generator_randperm_static(self): + + fluid.disable_dygraph() + + gen = generator.Generator() + gen.manual_seed(123123143) + + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + # example 1: + # attr shape is a list which doesn't contain tensor Variable. + x = fluid.layers.uniform_random(shape=[10, 10]) + result_1 = fluid.layers.sampling_id(x) + result_2 = fluid.layers.sampling_id(x) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + out1 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + #gen.set_state(cur_state) + gen.manual_seed(123123143) + out2 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + + out1_res1 = np.array(out1[0]) + out1_res2 = np.array(out1[1]) + out2_res1 = np.array(out2[0]) + out2_res2 = np.array(out2[1]) + + if not core.is_compiled_with_cuda(): + print(">>>>>>> sampling id static >>>>>>>") + self.assertTrue(np.allclose(out1_res1, out2_res1)) + self.assertTrue(np.allclose(out1_res2, out2_res2)) + self.assertTrue(not np.allclose(out1_res2, out1_res1)) + + def test_gen_TruncatedNormal_initializer(self): + fluid.disable_dygraph() + + gen = generator.Generator() + gen.manual_seed(123123143) + cur_state = gen.get_state() + + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + # example 1: + # attr shape is a list which doesn't contain tensor Variable. + x = fluid.layers.uniform_random(shape=[2, 10]) + result_1 = fluid.layers.fc( + input=x, + size=10, + param_attr=fluid.initializer.TruncatedNormal( + loc=0.0, scale=2.0)) + result_2 = fluid.layers.fc( + input=x, + size=10, + param_attr=fluid.initializer.TruncatedNormal( + loc=0.0, scale=2.0)) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + out1 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + #gen.set_state(cur_state) + + #gen.set_state(cur_state) + gen.manual_seed(123123143) + with fluid.program_guard(train_program, startup_program): + exe.run(startup_program) + out2 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + + out1_res1 = np.array(out1[0]) + out1_res2 = np.array(out1[1]) + out2_res1 = np.array(out2[0]) + out2_res2 = np.array(out2[1]) + + print(out1_res1) + print(out1_res2) + print(out2_res1) + print(out2_res2) + + if not core.is_compiled_with_cuda(): + print(">>>>>>> sampling id static >>>>>>>") + self.assertTrue(np.allclose(out1_res1, out2_res1)) + self.assertTrue(np.allclose(out1_res2, out2_res2)) + self.assertTrue(not np.allclose(out1_res2, out1_res1)) + if __name__ == "__main__": unittest.main() From e62450ac79c79d0eb4f7124958f23c70c72acfcc Mon Sep 17 00:00:00 2001 From: yaoxuefeng6 Date: Thu, 20 Aug 2020 17:47:27 +0800 Subject: [PATCH 2/7] fix default init seed test=develop --- paddle/fluid/framework/generator.h | 6 ++++-- python/paddle/fluid/generator.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/generator.h b/paddle/fluid/framework/generator.h index 17870782ba72a3..38c21ba3a553ac 100644 --- a/paddle/fluid/framework/generator.h +++ b/paddle/fluid/framework/generator.h @@ -37,8 +37,10 @@ struct Generator { Generator() { GeneratorState default_gen_state_cpu; default_gen_state_cpu.device = -1; - default_gen_state_cpu.current_seed = 34342423252; - std::seed_seq seq({34342423252}); + std::random_device rnd; + uint64_t init_seed = (((uint64_t)rnd()) << 32) + rnd(); + default_gen_state_cpu.current_seed = init_seed; + std::seed_seq seq({init_seed}); default_gen_state_cpu.cpu_engine = std::mt19937_64(seq); this->state_ = std::make_shared(default_gen_state_cpu); } diff --git a/python/paddle/fluid/generator.py b/python/paddle/fluid/generator.py index 24262e3f5666ab..e11b2e484dce1d 100644 --- a/python/paddle/fluid/generator.py +++ b/python/paddle/fluid/generator.py @@ -29,7 +29,7 @@ def __init__(self, device="CPU"): seed_in = default_rng_seed_val if self.device == "CPU": self.generator = core.Generator() - self.generator.manual_seed(seed_in) + # self.generator.manual_seed(seed_in) else: raise ValueError( "generator class with device %s does not exist, currently only support generator with device 'CPU' " From ef81b1ef39172c40b2ccb28bd609733c24fb588a Mon Sep 17 00:00:00 2001 From: yaoxuefeng6 Date: Thu, 20 Aug 2020 19:12:50 +0800 Subject: [PATCH 3/7] fix gaussian random mkldnn test=develop --- paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc index 58902831356d22..d0ecca78ae8b27 100644 --- a/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/operators/fill_constant_op.h" #include "paddle/fluid/operators/mean_op.h" @@ -33,6 +34,7 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel { tensor->Resize(shape); T* data = tensor->mutable_data(context.GetPlace()); int64_t size = tensor->numel(); + std::normal_distribution dist(mean, std); if (framework::Generator::GetInstance()->is_init_py) { std::mt19937_64& gen_engine = @@ -47,7 +49,6 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel { seed = std::random_device()(); } engine.seed(seed); - std::normal_distribution dist(mean, std); for (int64_t i = 0; i < size; ++i) { data[i] = dist(engine); } From 1992e7d2fbf181b1c6948061b449684d3c3f8043 Mon Sep 17 00:00:00 2001 From: yaoxuefeng6 Date: Fri, 21 Aug 2020 14:46:41 +0800 Subject: [PATCH 4/7] fix mmap allocator build error test=develop --- paddle/fluid/memory/allocation/mmap_allocator.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/memory/allocation/mmap_allocator.cc b/paddle/fluid/memory/allocation/mmap_allocator.cc index 67e9b5c4a7345c..9a9648720ae7cc 100644 --- a/paddle/fluid/memory/allocation/mmap_allocator.cc +++ b/paddle/fluid/memory/allocation/mmap_allocator.cc @@ -27,8 +27,6 @@ #include #include -#include "paddle/fluid/framework/generator.h" - namespace paddle { namespace memory { namespace allocation { @@ -62,10 +60,13 @@ std::string GetIPCName() { handle += std::to_string(getpid()); #endif handle += "_"; + handle += std::to_string(rd()); + /* handle += framework::Generator::GetInstance()->is_init_py ? std::to_string(framework::Generator::GetInstance()->Random64()) : std::to_string(rd()); + */ return std::move(handle); } From 650858434cb1330d2a3231e972542685dee88d75 Mon Sep 17 00:00:00 2001 From: yaoxuefeng6 Date: Fri, 21 Aug 2020 17:28:25 +0800 Subject: [PATCH 5/7] fix mac build communicator test=develop --- paddle/fluid/operators/distributed/CMakeLists.txt | 2 +- paddle/fluid/operators/distributed/large_scale_kv.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index cff3993a068cee..a033611f478f9e 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -61,7 +61,7 @@ cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope) cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory) cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory) cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory) -cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool parameter_send parameter_recv) +cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool parameter_send parameter_recv generator) cc_test(communicator_test SRCS communicator_test.cc DEPS communicator) if(WITH_GPU) cc_test(collective_server_test SRCS collective_server_test.cc diff --git a/paddle/fluid/operators/distributed/large_scale_kv.h b/paddle/fluid/operators/distributed/large_scale_kv.h index b32210d4fffab2..0d7032e286caab 100644 --- a/paddle/fluid/operators/distributed/large_scale_kv.h +++ b/paddle/fluid/operators/distributed/large_scale_kv.h @@ -151,7 +151,7 @@ class GaussianInitializer : public Initializer { return framework::Generator::GetInstance()->is_init_py ? dist_(framework::Generator::GetInstance()->GetCPUEngine()) : dist_(random_engine_); - return dist_(random_engine_); + // return dist_(random_engine_); } private: From 8bfc90b335cafed51b9e5a6259d43439ec1b53d5 Mon Sep 17 00:00:00 2001 From: yaoxuefeng6 Date: Sat, 22 Aug 2020 16:09:11 +0800 Subject: [PATCH 6/7] fix ut fail test=develop --- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/randint_op.cc | 4 ++- .../test_generate_proposal_labels_op.py | 28 ++++++++++++++++++- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 48d1ec9461a880..6e8ff52ed4a884 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -123,7 +123,7 @@ cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_t cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) -nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) +nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor generator) if (WITH_GPU) nv_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc test_leaky_relu_grad_grad_functor.cu DEPS tensor device_context eigen3) else() diff --git a/paddle/fluid/operators/randint_op.cc b/paddle/fluid/operators/randint_op.cc index cd309f1810a380..662fe3bcb3b3b2 100644 --- a/paddle/fluid/operators/randint_op.cc +++ b/paddle/fluid/operators/randint_op.cc @@ -59,7 +59,9 @@ class CPURandintKernel : public framework::OpKernel { } engine.seed(seed); - for (int64_t i = 0; i < size; ++i) data[i] = dist(engine); + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } } } }; diff --git a/python/paddle/fluid/tests/unittests/test_generate_proposal_labels_op.py b/python/paddle/fluid/tests/unittests/test_generate_proposal_labels_op.py index a5d36203b0ad56..5054256ca72477 100644 --- a/python/paddle/fluid/tests/unittests/test_generate_proposal_labels_op.py +++ b/python/paddle/fluid/tests/unittests/test_generate_proposal_labels_op.py @@ -224,7 +224,8 @@ def _expand_bbox_targets(bbox_targets_input, class_nums, is_cls_agnostic): class TestGenerateProposalLabelsOp(OpTest): def set_data(self): - self.use_random = False + #self.use_random = False + self.init_use_random() self.init_test_cascade() self.init_test_params() self.init_test_input() @@ -267,6 +268,9 @@ def setUp(self): def init_test_cascade(self, ): self.is_cascade_rcnn = False + def init_use_random(self): + self.use_random = False + def init_test_params(self): self.batch_size_per_im = 512 self.fg_fraction = 0.25 @@ -329,6 +333,28 @@ def init_test_cascade(self): self.is_cascade_rcnn = True +class TestUseRandom(TestGenerateProposalLabelsOp): + def init_use_random(self): + self.use_random = True + self.is_cascade_rcnn = False + + def test_check_output(self): + self.check_output_customized(self.verify_out) + + def verify_out(self, outs): + print("skip") + + def init_test_params(self): + self.batch_size_per_im = 512 + self.fg_fraction = 0.025 + self.fg_thresh = 0.5 + self.bg_thresh_hi = 0.5 + self.bg_thresh_lo = 0.0 + self.bbox_reg_weights = [0.1, 0.1, 0.2, 0.2] + self.is_cls_agnostic = False + self.class_nums = 2 if self.is_cls_agnostic else 81 + + class TestClsAgnostic(TestCascade): def init_test_params(self): self.batch_size_per_im = 512 From 0133f230977faa76efa71c2abbf3e67d87e45827 Mon Sep 17 00:00:00 2001 From: yaoxuefeng6 Date: Mon, 24 Aug 2020 04:45:48 +0800 Subject: [PATCH 7/7] fix ut test=develop --- paddle/fluid/framework/generator.h | 6 ++---- paddle/fluid/memory/allocation/mmap_allocator.cc | 6 ------ .../operators/detection/generate_proposal_labels_op.cc | 9 +-------- .../fluid/operators/detection/rpn_target_assign_op.cc | 10 +--------- 4 files changed, 4 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/framework/generator.h b/paddle/fluid/framework/generator.h index 38c21ba3a553ac..17870782ba72a3 100644 --- a/paddle/fluid/framework/generator.h +++ b/paddle/fluid/framework/generator.h @@ -37,10 +37,8 @@ struct Generator { Generator() { GeneratorState default_gen_state_cpu; default_gen_state_cpu.device = -1; - std::random_device rnd; - uint64_t init_seed = (((uint64_t)rnd()) << 32) + rnd(); - default_gen_state_cpu.current_seed = init_seed; - std::seed_seq seq({init_seed}); + default_gen_state_cpu.current_seed = 34342423252; + std::seed_seq seq({34342423252}); default_gen_state_cpu.cpu_engine = std::mt19937_64(seq); this->state_ = std::make_shared(default_gen_state_cpu); } diff --git a/paddle/fluid/memory/allocation/mmap_allocator.cc b/paddle/fluid/memory/allocation/mmap_allocator.cc index 9a9648720ae7cc..0ef084bafd0c9f 100644 --- a/paddle/fluid/memory/allocation/mmap_allocator.cc +++ b/paddle/fluid/memory/allocation/mmap_allocator.cc @@ -61,12 +61,6 @@ std::string GetIPCName() { #endif handle += "_"; handle += std::to_string(rd()); - /* - handle += - framework::Generator::GetInstance()->is_init_py - ? std::to_string(framework::Generator::GetInstance()->Random64()) - : std::to_string(rd()); - */ return std::move(handle); } diff --git a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc index dd9e9fd2d0ca10..884aa1f6f4e996 100644 --- a/paddle/fluid/operators/detection/generate_proposal_labels_op.cc +++ b/paddle/fluid/operators/detection/generate_proposal_labels_op.cc @@ -13,7 +13,6 @@ limitations under the License. */ #include #include #include -#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/gather.h" @@ -191,13 +190,7 @@ std::vector> SampleFgBgGt( const int64_t fg_size = static_cast(fg_inds.size()); if (fg_size > fg_rois_per_this_image) { for (int64_t i = fg_rois_per_this_image; i < fg_size; ++i) { - int rng_ind = - framework::Generator::GetInstance()->is_init_py - ? std::floor(uniform(framework::Generator::GetInstance() - ->GetCPUEngine()) * - i) - : std::floor(uniform(engine) * i); - // int rng_ind = std::floor(uniform(engine) * i); + int rng_ind = std::floor(uniform(engine) * i); if (rng_ind < fg_rois_per_this_image) { std::iter_swap(fg_inds.begin() + rng_ind, fg_inds.begin() + i); std::iter_swap(mapped_gt_inds.begin() + rng_ind, diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc index 109da856ba21b5..2a16e20c2a7235 100644 --- a/paddle/fluid/operators/detection/rpn_target_assign_op.cc +++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/math/math_function.h" @@ -160,14 +159,7 @@ void ReservoirSampling(const int num, std::vector* inds, if (len > static_cast(num)) { if (use_random) { for (size_t i = num; i < len; ++i) { - int rng_ind = - framework::Generator::GetInstance()->is_init_py - ? std::floor( - uniform( - framework::Generator::GetInstance()->GetCPUEngine()) * - i) - : std::floor(uniform(engine) * i); - // int rng_ind = std::floor(uniform(engine) * i); + int rng_ind = std::floor(uniform(engine) * i); if (rng_ind < num) std::iter_swap(inds->begin() + rng_ind, inds->begin() + i); }