Skip to content

Commit e6b93cd

Browse files
committed
refine code, reuse common function
1 parent c0345be commit e6b93cd

File tree

1 file changed

+38
-25
lines changed

1 file changed

+38
-25
lines changed

paddle/fluid/pybind/imperative.cc

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -146,21 +146,33 @@ static const platform::Place PyObjectToPlace(const py::object &place_obj) {
146146
}
147147
}
148148

149-
static void InitTensorForVarBase(imperative::VarBase *self,
150-
const py::array &array,
151-
const platform::Place place,
152-
bool persistable = false,
153-
bool zero_copy = false, std::string name = "",
154-
int stop_gradient = -1) {
155-
if (name == "") {
156-
name =
157-
imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor");
158-
}
159-
VLOG(5) << "Init Tensor as: / name: " << name
160-
<< " / persistable: " << persistable << " / zero_copy: " << zero_copy
149+
// only initialize varbase, but not its tensor.
150+
static void InitVarBaseOnly(imperative::VarBase *self, const std::string &name,
151+
bool persistable = false, int stop_gradient = -1) {
152+
auto name_ = name == ""
153+
? imperative::GetCurrentTracer()->GenerateUniqueName(
154+
"generated_tensor")
155+
: name;
156+
157+
VLOG(5) << "Init Tensor as: / name: " << name_
158+
<< " / persistable: " << persistable
161159
<< " / stop_gradient: " << stop_gradient;
162-
new (self) imperative::VarBase(name);
160+
new (self) imperative::VarBase(name_);
161+
if (stop_gradient != -1) {
162+
self->SetOverridedStopGradient(stop_gradient);
163+
}
164+
self->SetPersistable(persistable);
165+
self->SetType(framework::proto::VarType::LOD_TENSOR);
166+
}
167+
168+
// initialize varbase and its tensor.
169+
static void InitVarBaseAndTensor(
170+
imperative::VarBase *self, const py::array &array,
171+
const platform::Place &place, const std::string &name,
172+
bool persistable = false, bool zero_copy = false, int stop_gradient = -1) {
173+
InitVarBaseOnly(self, name, persistable, stop_gradient);
163174
auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
175+
VLOG(4) << "zero_copy: " << zero_copy;
164176
if (platform::is_cpu_place(place)) {
165177
SetTensorFromPyArray<platform::CPUPlace>(
166178
tensor, array, BOOST_GET_CONST(platform::CPUPlace, place), zero_copy);
@@ -182,11 +194,6 @@ static void InitTensorForVarBase(imperative::VarBase *self,
182194
"Place should be one of "
183195
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace"));
184196
}
185-
if (stop_gradient != -1) {
186-
self->SetOverridedStopGradient(stop_gradient);
187-
}
188-
self->SetPersistable(persistable);
189-
self->SetType(framework::proto::VarType::LOD_TENSOR);
190197
self->SetDataType(tensor->type());
191198
}
192199

@@ -196,19 +203,25 @@ static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
196203
auto persistable = kwargs.contains("persistable")
197204
? kwargs["persistable"].cast<bool>()
198205
: false;
199-
auto array = kwargs.contains("value") ? kwargs["value"].cast<py::array>()
200-
: py::array();
201206
auto zero_copy =
202207
kwargs.contains("zero_copy") ? kwargs["zero_copy"].cast<bool>() : false;
203208
auto name = kwargs.contains("name") ? kwargs["name"].cast<std::string>() : "";
204209
auto stop_gradient = kwargs.contains("stop_gradient")
205210
? kwargs["stop_gradient"].cast<int>()
206211
: -1;
207212
auto default_place = imperative::GetCurrentTracer()->ExpectedPlace();
208-
auto place = kwargs.contains("place") ? PyObjectToPlace(kwargs["place"])
209-
: default_place;
210-
InitTensorForVarBase(self, array, place, persistable, zero_copy, name,
211-
stop_gradient);
213+
214+
if (kwargs.contains("value")) {
215+
auto array = kwargs["value"].cast<py::array>();
216+
// place is only used when array is given, otherwise, it is meaningless and
217+
// ignored
218+
auto place = kwargs.contains("place") ? PyObjectToPlace(kwargs["place"])
219+
: default_place;
220+
InitVarBaseAndTensor(self, array, place, name, persistable, zero_copy,
221+
stop_gradient);
222+
} else {
223+
InitVarBaseOnly(self, name, persistable, stop_gradient);
224+
}
212225
}
213226

214227
template <typename P>
@@ -243,7 +256,7 @@ static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
243256
const py::array &array) {
244257
auto place = imperative::GetCurrentTracer()->ExpectedPlace();
245258
VLOG(4) << "Init VarBase from numpy at " << place;
246-
InitTensorForVarBase(self, array, place);
259+
InitVarBaseAndTensor(self, array, place, "");
247260
}
248261

249262
static void InitVarBaseFromTensorWithArgDefault(

0 commit comments

Comments
 (0)