Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/distributed/table/common_dense_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ void CommonDenseTable::create_initializer(const std::string& attr,
initializers_[name] = new FillConstantInitializer(slices);
} else if (slices[0] == "uniform_random") {
initializers_[name] = new UniformInitializer(slices);
} else if (slices[0] == "truncated_gaussian_random") {
initializers_[name] = new TruncatedGaussianInitializer(slices);
} else {
PADDLE_THROW(
platform::errors::InvalidArgument("%s can not be supported", name));
Expand Down
37 changes: 37 additions & 0 deletions paddle/fluid/distributed/table/depends/initializers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@

#include <functional>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>
#include "gflags/gflags.h"

#include "paddle/fluid/framework/generator.h"

#include "paddle/fluid/operators/truncated_gaussian_random_op.h"

namespace paddle {
namespace distributed {

Expand Down Expand Up @@ -108,6 +111,40 @@ class GaussianInitializer : public Initializer {
std::normal_distribution<float> dist_;
};

class TruncatedGaussianInitializer : public Initializer {
public:
explicit TruncatedGaussianInitializer(const std::vector<std::string> &attrs) {
name_ = attrs[0];
seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]);

std::uniform_real_distribution<float> dist_(
std::numeric_limits<float>::min(), 1.0);
random_engine_ = framework::GetCPURandomEngine(seed_);
}

float GetValue() override {
paddle::operators::TruncatedNormal<float> truncated_normal(mean_, std_);
float value = truncated_normal(dist_(*random_engine_));
return value;
}

void GetValue(float *value, int numel) {
paddle::operators::TruncatedNormal<float> truncated_normal(mean_, std_);
for (int x = 0; x < numel; ++x) {
value[x] = truncated_normal(dist_(*random_engine_));
}
}

private:
float std_;
float mean_;

std::shared_ptr<std::mt19937_64> random_engine_;
std::uniform_real_distribution<float> dist_;
};

class FillConstantInitializer : public Initializer {
public:
explicit FillConstantInitializer(const std::vector<std::string> &attrs) {
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/distributed/table/depends/large_scale_kv.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ class ValueBlock {
} else if (slices[0] == "uniform_random") {
initializers_.emplace_back(
std::make_shared<UniformInitializer>(slices));
} else if (slices[0] == "truncated_gaussian_random") {
initializers_.emplace_back(
std::make_shared<TruncatedGaussianInitializer>(slices));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s can not be supported", attr));
Expand Down