@@ -146,21 +146,33 @@ static const platform::Place PyObjectToPlace(const py::object &place_obj) {
146146 }
147147}
148148
149- static void InitTensorForVarBase (imperative::VarBase *self,
150- const py::array &array,
151- const platform::Place place,
152- bool persistable = false ,
153- bool zero_copy = false , std::string name = " " ,
154- int stop_gradient = -1 ) {
155- if (name == " " ) {
156- name =
157- imperative::GetCurrentTracer ()->GenerateUniqueName (" generated_tensor" );
158- }
159- VLOG (5 ) << " Init Tensor as: / name: " << name
160- << " / persistable: " << persistable << " / zero_copy: " << zero_copy
149+ // only initialize varbase, but not its tensor.
150+ static void InitVarBaseOnly (imperative::VarBase *self, const std::string &name,
151+ bool persistable = false , int stop_gradient = -1 ) {
152+ auto name_ = name == " "
153+ ? imperative::GetCurrentTracer ()->GenerateUniqueName (
154+ " generated_tensor" )
155+ : name;
156+
157+ VLOG (5 ) << " Init Tensor as: / name: " << name_
158+ << " / persistable: " << persistable
161159 << " / stop_gradient: " << stop_gradient;
162- new (self) imperative::VarBase (name);
160+ new (self) imperative::VarBase (name_);
161+ if (stop_gradient != -1 ) {
162+ self->SetOverridedStopGradient (stop_gradient);
163+ }
164+ self->SetPersistable (persistable);
165+ self->SetType (framework::proto::VarType::LOD_TENSOR);
166+ }
167+
168+ // initialize varbase and its tensor.
169+ static void InitVarBaseAndTensor (
170+ imperative::VarBase *self, const py::array &array,
171+ const platform::Place &place, const std::string &name,
172+ bool persistable = false , bool zero_copy = false , int stop_gradient = -1 ) {
173+ InitVarBaseOnly (self, name, persistable, stop_gradient);
163174 auto *tensor = self->MutableVar ()->GetMutable <framework::LoDTensor>();
175+ VLOG (4 ) << " zero_copy: " << zero_copy;
164176 if (platform::is_cpu_place (place)) {
165177 SetTensorFromPyArray<platform::CPUPlace>(
166178 tensor, array, BOOST_GET_CONST (platform::CPUPlace, place), zero_copy);
@@ -182,11 +194,6 @@ static void InitTensorForVarBase(imperative::VarBase *self,
182194 " Place should be one of "
183195 " CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace" ));
184196 }
185- if (stop_gradient != -1 ) {
186- self->SetOverridedStopGradient (stop_gradient);
187- }
188- self->SetPersistable (persistable);
189- self->SetType (framework::proto::VarType::LOD_TENSOR);
190197 self->SetDataType (tensor->type ());
191198}
192199
@@ -196,19 +203,25 @@ static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
196203 auto persistable = kwargs.contains (" persistable" )
197204 ? kwargs[" persistable" ].cast <bool >()
198205 : false ;
199- auto array = kwargs.contains (" value" ) ? kwargs[" value" ].cast <py::array>()
200- : py::array ();
201206 auto zero_copy =
202207 kwargs.contains (" zero_copy" ) ? kwargs[" zero_copy" ].cast <bool >() : false ;
203208 auto name = kwargs.contains (" name" ) ? kwargs[" name" ].cast <std::string>() : " " ;
204209 auto stop_gradient = kwargs.contains (" stop_gradient" )
205210 ? kwargs[" stop_gradient" ].cast <int >()
206211 : -1 ;
207212 auto default_place = imperative::GetCurrentTracer ()->ExpectedPlace ();
208- auto place = kwargs.contains (" place" ) ? PyObjectToPlace (kwargs[" place" ])
209- : default_place;
210- InitTensorForVarBase (self, array, place, persistable, zero_copy, name,
211- stop_gradient);
213+
214+ if (kwargs.contains (" value" )) {
215+ auto array = kwargs[" value" ].cast <py::array>();
216+ // place is only used when array is given, otherwise, it is meaningless and
217+ // ignored
218+ auto place = kwargs.contains (" place" ) ? PyObjectToPlace (kwargs[" place" ])
219+ : default_place;
220+ InitVarBaseAndTensor (self, array, place, name, persistable, zero_copy,
221+ stop_gradient);
222+ } else {
223+ InitVarBaseOnly (self, name, persistable, stop_gradient);
224+ }
212225}
213226
214227template <typename P>
@@ -243,7 +256,7 @@ static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
243256 const py::array &array) {
244257 auto place = imperative::GetCurrentTracer ()->ExpectedPlace ();
245258 VLOG (4 ) << " Init VarBase from numpy at " << place;
246- InitTensorForVarBase (self, array, place);
259+ InitVarBaseAndTensor (self, array, place, " " );
247260}
248261
249262static void InitVarBaseFromTensorWithArgDefault (
0 commit comments