Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.

Commit b35829e

Browse files
Authored by Yang Yang (Tony)
remove vardesc in variable (#26)
1 parent e424cab commit b35829e

6 files changed

Lines changed: 63 additions & 127 deletions

File tree

src/function.h

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,41 @@ class Fill {
4040
: initializer_(initializer), attrs_(attrs) {}
4141

4242
void operator()(VariableHandle var) {
43-
get_global_tape().AddOp(initializer_, {}, {{"Out", {var}}}, attrs_);
43+
if (initializer_ == "fill_constant") {
44+
// fill_constant is not OperatorWithKernel, so we can't add it to the tape
45+
framework::OpDesc op_desc =
46+
CreateOpDesc(initializer_, {}, {{"Out", {var}}}, attrs_);
47+
ScopeWrapper scope({}, {{"Out", {var}}});
48+
framework::OpRegistry::CreateOp(op_desc)->Run(scope,
49+
platform::CPUPlace());
50+
} else {
51+
get_global_tape().AddOp(initializer_, {}, {{"Out", {var}}}, attrs_);
52+
}
4453
}
4554

4655
private:
4756
const std::string initializer_;
4857
const framework::AttributeMap attrs_;
4958
};
5059

60+
void init_params(VariableHandle v,
61+
const std::string &initializer,
62+
const framework::AttributeMap &attrs) {
63+
if (initializer == "fill_constant") {
64+
// fill_constant is not OperatorWithKernel, so we can't add it to the tape
65+
framework::OpDesc op_desc =
66+
CreateOpDesc(initializer, {}, {{"Out", {v}}}, attrs);
67+
ScopeWrapper scope({}, {{"Out", {v}}});
68+
framework::OpRegistry::CreateOp(op_desc)->Run(scope, platform::CPUPlace());
69+
} else {
70+
Tape init_tape;
71+
init_tape.AddOp(initializer, {}, {{"Out", {v}}}, attrs);
72+
init_tape.Forward();
73+
}
74+
}
75+
76+
// TODO(tonyyang-svail): change this to a function
77+
// https://github.com/PaddlePaddle/tape/issues/23
5178
class Mean {
5279
public:
5380
VariableHandle operator()(VariableHandle var) {
@@ -82,8 +109,6 @@ class Linear {
82109
: w_(new Variable("LinearWeight")),
83110
b_(new Variable("LinearBias")),
84111
act_(act) {
85-
Tape init_tape;
86-
87112
// Use Xavier to initialize Weight
88113
float limit = sqrt(6.0 / static_cast<float>(in_dim + out_dim));
89114
framework::AttributeMap attrs;
@@ -92,15 +117,13 @@ class Linear {
92117
attrs["min"] = -limit;
93118
attrs["max"] = limit;
94119
attrs["seed"] = RandomSeed::GetRandomSeed();
95-
init_tape.AddOp("uniform_random", {}, {{"Out", {w_}}}, attrs);
120+
init_params(w_, "uniform_random", attrs);
96121

97122
// Use fill zero to initialize Bias
98123
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
99124
attrs["shape"] = std::vector<int>{out_dim};
100125
attrs["value"] = 0.0f;
101-
init_tape.AddOp("fill_constant", {}, {{"Out", {b_}}}, attrs);
102-
103-
init_tape.Forward();
126+
init_params(b_, "fill_constant", attrs);
104127
}
105128

106129
VariableHandle operator()(VariableHandle input) {
@@ -134,8 +157,6 @@ class Convolution2D {
134157
: w_(new Variable("ConvolutionWeight")),
135158
b_(new Variable("ConvolutionBias")),
136159
act_(act) {
137-
Tape init_tape;
138-
139160
// Use Xavier to initialize Weight
140161
float fan_in = c_in * f * f, fan_out = c_out * f * f;
141162
float limit = sqrt(6.0 / (fan_in + fan_out));
@@ -145,15 +166,13 @@ class Convolution2D {
145166
attrs["min"] = -limit;
146167
attrs["max"] = limit;
147168
attrs["seed"] = RandomSeed::GetRandomSeed();
148-
init_tape.AddOp("uniform_random", {}, {{"Out", {w_}}}, attrs);
169+
init_params(w_, "uniform_random", attrs);
149170

150171
// Use fill zero to initialize Bias
151172
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
152173
attrs["shape"] = std::vector<int>{c_out};
153174
attrs["value"] = 0.0f;
154-
init_tape.AddOp("fill_constant", {}, {{"Out", {b_}}}, attrs);
155-
156-
init_tape.Forward();
175+
init_params(b_, "fill_constant", attrs);
157176
}
158177

159178
VariableHandle operator()(VariableHandle input) {
@@ -190,16 +209,12 @@ class Convolution2D {
190209
class SGD {
191210
public:
192211
explicit SGD(float learning_rate) : learning_rate_(new Variable("sgd")) {
193-
Tape init_tape;
194-
195212
std::string initializer = "fill_constant";
196213
framework::AttributeMap attrs;
197214
attrs["dtype"] = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
198215
attrs["shape"] = std::vector<int>{1};
199216
attrs["value"] = learning_rate;
200-
init_tape.AddOp(initializer, {}, {{"Out", {learning_rate_}}}, attrs);
201-
202-
init_tape.Forward();
217+
init_params(learning_rate_, initializer, attrs);
203218
}
204219

205220
void Update(VariableHandle input) {
@@ -224,7 +239,6 @@ VariableHandle CreateRecordioFileReader(std::string filename,
224239
std::vector<int> ranks,
225240
std::vector<int> lod_levels) {
226241
VariableHandle reader(new paddle::tape::Variable("reader"));
227-
reader->MutableDesc()->SetType(paddle::framework::proto::VarType::READER);
228242

229243
framework::OpDesc op_desc = CreateOpDesc("create_recordio_file_reader",
230244
{},
@@ -240,10 +254,7 @@ VariableHandle CreateRecordioFileReader(std::string filename,
240254
}
241255

242256
void ReadNext(VariableHandle reader, VariableHandle data_holder) {
243-
PADDLE_ENFORCE_EQ(reader->Desc().GetType(),
244-
paddle::framework::proto::VarType::READER);
245-
PADDLE_ENFORCE_EQ(data_holder->Desc().GetType(),
246-
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY);
257+
PADDLE_ENFORCE(reader->Var().IsType<framework::ReaderHolder>());
247258

248259
reader->GetMutable<paddle::framework::ReaderHolder>()->ReadNext(
249260
data_holder->GetMutable<paddle::framework::LoDTensorArray>());

src/tape.cc

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ std::string to_string(const std::string &type,
5858
ss << type << " ";
5959
for (auto &param_name : in_vars) {
6060
for (auto &var : param_name.second) {
61-
ss << param_name.first << ":(" << var->Desc() << ") ";
61+
ss << param_name.first << ":(" << var << ") ";
6262
}
6363
}
6464
for (auto &param_name : out_vars) {
6565
for (auto &var : param_name.second) {
66-
ss << param_name.first << ":(" << var->Desc() << ") ";
66+
ss << param_name.first << ":(" << var << ") ";
6767
}
6868
}
6969
return ss.str();
@@ -85,41 +85,33 @@ framework::OpDesc CreateOpDesc(const std::string &type,
8585
outputs[param_name.first].emplace_back(var->Name());
8686
}
8787
}
88-
return framework::OpDesc(type, inputs, outputs, attrs);
88+
framework::OpDesc op_desc(type, inputs, outputs, attrs);
89+
op_desc.CheckAttrs();
90+
return op_desc;
8991
}
9092

9193
void InferShapeAndVarType(const std::string &type,
9294
const VariableHandleMap &in_vars,
9395
VariableHandleMap *out_vars,
9496
const framework::AttributeMap &attrs) {
95-
framework::OpDesc op_desc = CreateOpDesc(type, in_vars, *out_vars, attrs);
96-
op_desc.CheckAttrs();
97-
98-
// Create a temporary block for compile-time
99-
framework::ProgramDesc program_desc;
100-
framework::BlockDesc *block_desc = program_desc.MutableBlock(0);
101-
PADDLE_ENFORCE(block_desc);
102-
103-
for (auto &param_name : in_vars) {
104-
for (auto &var : param_name.second) {
105-
*block_desc->Var(var->Name())->Proto() = *var->MutableDesc()->Proto();
106-
}
107-
}
108-
for (auto &param_name : *out_vars) {
109-
for (auto &var : param_name.second) {
110-
*block_desc->Var(var->Name())->Proto() = *var->MutableDesc()->Proto();
97+
// Tape only supports LoDTensor
98+
for (auto &param2var : *out_vars) {
99+
for (auto &var : param2var.second) {
100+
var->GetMutable<framework::LoDTensor>();
111101
}
112102
}
113103

114-
LOG(INFO) << "- " << to_string(type, in_vars, *out_vars, attrs);
115-
op_desc.InferShape(*block_desc);
116-
op_desc.InferVarType(block_desc);
117-
for (auto &param_name : *out_vars) {
118-
for (auto &var : param_name.second) {
119-
*var->MutableDesc()->Proto() = *block_desc->Var(var->Name())->Proto();
120-
}
121-
}
122-
LOG(INFO) << "+ " << to_string(type, in_vars, *out_vars, attrs);
104+
framework::OpDesc op_desc = CreateOpDesc(type, in_vars, *out_vars, attrs);
105+
ScopeWrapper scope(in_vars, *out_vars);
106+
107+
// Tape only supports OperatorWithKernel
108+
auto op = framework::OpRegistry::CreateOp(op_desc);
109+
auto *op_with_kernel =
110+
dynamic_cast<framework::OperatorWithKernel *>(op.get());
111+
PADDLE_ENFORCE_NOT_NULL(op_with_kernel, "%s doesn't have kernel", type);
112+
paddle::framework::RuntimeInferShapeContext infer_shape_ctx(*op_with_kernel,
113+
scope);
114+
op_with_kernel->InferShape(&infer_shape_ctx);
123115
}
124116

125117
void Tape::AddOp(const std::string &type,
@@ -135,14 +127,6 @@ void Tape::Forward() {
135127
PADDLE_ENFORCE(!has_been_backwarded_);
136128
while (current_position_ < tape_.size()) {
137129
OpHandle &op = tape_[current_position_];
138-
139-
// Create Output Tensor, this is only necessary for OpWithKernel
140-
for (auto &param2var : op.outputs_) {
141-
for (auto &var : param2var.second) {
142-
var->InitializeVariable();
143-
}
144-
}
145-
146130
framework::OpDesc op_desc =
147131
CreateOpDesc(op.type_, op.inputs_, op.outputs_, op.attrs_);
148132
ScopeWrapper scope(op.inputs_, op.outputs_);
@@ -161,14 +145,9 @@ void Tape::Backward(VariableHandle target) {
161145
// TODO(tonyyang-svail): check output of last op is target
162146
backward_tape_.reset(new Tape());
163147

164-
framework::AttributeMap attrs;
165-
166148
// FIXME(tonyyang-svail): Need to infer_data_type
167-
attrs["dtype"] = framework::proto::VarType::Type::VarType_Type_FP32;
168-
attrs["shape"] = std::vector<int>{1};
169-
attrs["value"] = 1.0f;
170149
backward_tape_->AddOp(
171-
"fill_constant", {}, {{"Out", {target->Grad()}}}, attrs);
150+
"fill_ones_like", {{"X", {target}}}, {{"Out", {target->Grad()}}}, {});
172151

173152
for (auto it = tape_.rbegin(); it != tape_.rend(); ++it) {
174153
framework::OpDesc op_desc =

src/test_tape.cc

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -31,46 +31,12 @@ using paddle::tape::ReadNext;
3131

3232
TEST(Tape, TestReader) {
3333
VariableHandle data_label(new paddle::tape::Variable("data_label"));
34-
data_label->MutableDesc()->SetType(
35-
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY);
36-
3734
VariableHandle reader = CreateRecordioFileReader(
3835
"/tape/src/data/mnist.recordio", {32, 1, 28, 28, 32, 1}, {4, 2}, {0, 0});
39-
4036
ReadNext(reader, data_label);
4137
LOG(INFO) << *data_label;
4238
}
4339

44-
TEST(Tape, TestSoftmax) {
45-
std::string data_initializer = "uniform_random";
46-
paddle::framework::AttributeMap data_attrs;
47-
data_attrs["min"] = -1.0f;
48-
data_attrs["max"] = 1.0f;
49-
data_attrs["dtype"] =
50-
paddle::framework::proto::VarType::Type::VarType_Type_FP32;
51-
data_attrs["shape"] = std::vector<int>{10, 10};
52-
data_attrs["seed"] = 123;
53-
Fill data_filler(data_initializer, data_attrs);
54-
55-
std::string label_initializer = "fill_constant";
56-
paddle::framework::AttributeMap label_attrs;
57-
label_attrs["dtype"] =
58-
paddle::framework::proto::VarType::Type::VarType_Type_INT64;
59-
label_attrs["shape"] = std::vector<int>{10, 1};
60-
label_attrs["value"] = 1.0f;
61-
Fill label_filler(label_initializer, label_attrs);
62-
63-
VariableHandle input(new Variable("input"));
64-
data_filler(input);
65-
VariableHandle label(new Variable("input"));
66-
label_filler(label);
67-
68-
auto loss = cross_entropy(softmax(input), label);
69-
70-
LOG(INFO) << input->Value();
71-
LOG(INFO) << loss->Value();
72-
}
73-
7440
TEST(Tape, TestRelu) {
7541
std::string initializer = "uniform_random";
7642
paddle::framework::AttributeMap attrs;

src/variable.cc

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@ namespace tape {
2222

2323
std::ostream& operator<<(std::ostream& os, const Variable& var) {
2424
LOG(INFO) << "Printing " << var.Name();
25-
framework::proto::VarType::Type var_type = var.Desc().GetType();
26-
if (var_type == framework::proto::VarType::LOD_TENSOR) {
25+
if (var.Var().IsType<framework::LoDTensor>()) {
2726
os << var.Var().Get<framework::LoDTensor>();
28-
} else if (var_type = framework::proto::VarType::LOD_TENSOR_ARRAY) {
27+
} else if (var.Var().IsType<framework::LoDTensorArray>()) {
2928
framework::LoDTensorArray array =
3029
var.Var().Get<framework::LoDTensorArray>();
3130
for (size_t i = 0; i < array.size(); ++i) {
@@ -39,19 +38,6 @@ std::ostream& operator<<(std::ostream& os, const Variable& var) {
3938
return os;
4039
}
4140

42-
void Variable::InitializeVariable() {
43-
LOG(INFO) << "Initialzing " << desc_.Name() << " as " << desc_.GetType();
44-
framework::proto::VarType::Type var_type = desc_.GetType();
45-
if (var_type == framework::proto::VarType::LOD_TENSOR) {
46-
var_.GetMutable<framework::LoDTensor>();
47-
} else if (var_type == framework::proto::VarType::SELECTED_ROWS) {
48-
var_.GetMutable<framework::SelectedRows>();
49-
} else {
50-
PADDLE_THROW("Variable type %d is not in [LOD_TENSOR, SELECTED_ROWS]",
51-
var_type);
52-
}
53-
}
54-
5541
const Variable& Variable::Value() {
5642
get_global_tape().Forward();
5743
return *this;

src/variable.h

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,17 @@ std::ostream& operator<<(std::ostream&, const Variable&);
3636
class Variable {
3737
public:
3838
explicit Variable(const std::string pre_fix)
39-
: desc_(pre_fix + std::to_string(count())) {}
39+
: name_(pre_fix + std::to_string(count())) {}
4040

4141
Variable(const std::string pre_fix, bool is_grad)
42-
: desc_(pre_fix + (is_grad ? framework::kGradVarSuffix
42+
: name_(pre_fix + (is_grad ? framework::kGradVarSuffix
4343
: std::to_string(count()))) {}
4444

4545
~Variable() { LOG(INFO) << "Deleting " << Name(); }
4646

47-
// Instantiate LoDTensor/SelectedRow
48-
void InitializeVariable();
49-
5047
VariableHandle Grad() {
5148
if (grad_.expired()) {
52-
VariableHandle new_grad(new Variable(desc_.Name(), true));
49+
VariableHandle new_grad(new Variable(name_, true));
5350
grad_ = new_grad;
5451
return new_grad;
5552
} else {
@@ -66,11 +63,8 @@ class Variable {
6663
// Evaluate a variable by running Forward() on the global tape
6764
const Variable& Value();
6865

69-
const framework::VarDesc& Desc() const { return desc_; }
70-
framework::VarDesc* MutableDesc() { return &desc_; }
71-
7266
// TODO(tonyyang-svail): No need to expose name
73-
std::string Name() const { return desc_.Name(); }
67+
std::string Name() const { return name_; }
7468

7569
const framework::Variable& Var() const { return var_; }
7670
framework::Variable* MutableVar() { return &var_; }
@@ -91,7 +85,7 @@ class Variable {
9185
return counter++;
9286
}
9387

94-
framework::VarDesc desc_;
88+
std::string name_;
9589
framework::Variable var_;
9690

9791
// Not own

0 commit comments

Comments (0)