@@ -162,38 +162,64 @@ template <DeviceType Device>
 class CrossMapNormalFunc : public FunctionBase {
  public:
   void init(const FuncConfig& config) override {
+    // function arguments
     size_ = config.get<size_t>("size");
     scale_ = config.get<real>("scale");
     pow_ = config.get<real>("pow");
+
+    // number of inputs and outputs
+    numInputs_ = 1;
+    numOutputs_ = 2;
   }
 
   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
-    CHECK_EQ((size_t)1, inputs.size());
-    CHECK_EQ((size_t)2, outputs.size());
-
-    CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
-    CHECK(inputs[0].shape() == outputs[0].shape());
-    CHECK(inputs[0].shape() == outputs[1].shape());
-
+    check(inputs, outputs);
+    // The ArgType check is still done here;
+    // it is not yet clear whether it belongs inside check().
     CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
     CHECK_EQ(outputs[1].getArgType(), ASSIGN_TO);
-    size_t samples = inputs[0].shape()[0];
-    size_t channels = inputs[0].shape()[1];
-    size_t height = inputs[0].shape()[2];
-    size_t width = inputs[0].shape()[3];
+    size_t batchSize = inputs[0].shape()[0];
+    size_t maps = inputs[0].shape()[1];
+    size_t rows = inputs[0].shape()[2];
+    size_t columns = inputs[0].shape()[3];
 
     CrossMapNormal<Device>(outputs[0].data<real>(),
                            outputs[1].data<real>(),
                            inputs[0].data<real>(),
-                           samples,
-                           channels,
-                           height,
-                           width,
+                           batchSize,
+                           maps,
+                           rows,
+                           columns,
                            size_,
                            scale_,
                            pow_);
   }
 
+  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(numInputs_, inputs.size());
+    CHECK_EQ(numOutputs_, outputs.size());
+
+    CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
+    CHECK(inputs[0].shape() == outputs[0].shape());
+    CHECK(inputs[0].shape() == outputs[1].shape());
+  }
+
+  // Only the shape of the input is needed to estimate the
+  // number of floating-point operations.
+  size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ((size_t)numInputs_, inputs.size());
+    size_t batchSize = inputs[0].shape()[0];
+    size_t maps = inputs[0].shape()[1];
+    size_t rows = inputs[0].shape()[2];
+    size_t columns = inputs[0].shape()[3];
+
+    // number of floating-point operations
+    // (an approximate value)
+    size_t ops = batchSize * maps * rows * columns * (size_ * 2 + 3);
+
+    return ops;
+  }
+
  private:
   size_t size_;
   real scale_;
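
The forward ops() hook above estimates roughly size_ * 2 + 3 floating-point operations per output element. As a rough illustration only, here is a minimal standalone sketch of the same arithmetic; the 64x32x28x28 input shape and the window size of 5 are assumptions for the example, not values taken from this commit.

// Standalone sketch of the forward FLOP estimate used in
// CrossMapNormalFunc::ops(); shape and window size are illustrative.
#include <cstddef>
#include <cstdio>

int main() {
  size_t batchSize = 64, maps = 32, rows = 28, columns = 28;  // assumed NCHW shape
  size_t size = 5;  // assumed normalization window (size_ in the class)
  // size_ * 2 + 3 operations per output element, as in ops() above
  size_t ops = batchSize * maps * rows * columns * (size * 2 + 3);
  std::printf("approximate forward FLOPs: %zu\n", ops);  // prints 20873216
  return 0;
}
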
@@ -236,21 +262,18 @@ template <DeviceType Device>
 class CrossMapNormalGradFunc : public FunctionBase {
  public:
   void init(const FuncConfig& config) override {
+    // function arguments
     size_ = config.get<size_t>("size");
     scale_ = config.get<real>("scale");
     pow_ = config.get<real>("pow");
+
+    // number of inputs and outputs
+    numInputs_ = 4;
+    numOutputs_ = 1;
   }
 
   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
-    CHECK_EQ((size_t)4, inputs.size());
-    CHECK_EQ((size_t)1, outputs.size());
-
-    CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
-    CHECK(inputs[0].shape() == inputs[1].shape());
-    CHECK(inputs[0].shape() == inputs[2].shape());
-    CHECK(inputs[0].shape() == inputs[3].shape());
-    CHECK(inputs[0].shape() == outputs[0].shape());
-
+    check(inputs, outputs);
     if (outputs[0].getArgType() != ADD_TO) {
       // Currently, some algorithm implementations are ASSIGN_TO mode;
       // to support the ADD_TO calculation, the output needs to be cleared first.
@@ -259,25 +282,52 @@ class CrossMapNormalGradFunc : public FunctionBase {
       tmp.zero();
     }
 
-    size_t samples = inputs[0].shape()[0];
-    size_t channels = inputs[0].shape()[1];
-    size_t height = inputs[0].shape()[2];
-    size_t width = inputs[0].shape()[3];
+    size_t batchSize = inputs[0].shape()[0];
+    size_t maps = inputs[0].shape()[1];
+    size_t rows = inputs[0].shape()[2];
+    size_t columns = inputs[0].shape()[3];
 
     CrossMapNormalGrad<Device>(outputs[0].data<real>(),
                                inputs[0].data<real>(),
                                inputs[1].data<real>(),
                                inputs[2].data<real>(),
                                inputs[3].data<real>(),
-                               samples,
-                               channels,
-                               height,
-                               width,
+                               batchSize,
+                               maps,
+                               rows,
+                               columns,
                                size_,
                                scale_,
                                pow_);
   }
 
+  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(numInputs_, inputs.size());
+    CHECK_EQ(numOutputs_, outputs.size());
+
+    CHECK_EQ(inputs[0].shape().ndims(), (size_t)4);
+    CHECK(inputs[0].shape() == inputs[1].shape());
+    CHECK(inputs[0].shape() == inputs[2].shape());
+    CHECK(inputs[0].shape() == inputs[3].shape());
+    CHECK(inputs[0].shape() == outputs[0].shape());
+  }
+
+  // Only the shape of one input is needed to estimate the
+  // number of floating-point operations.
+  size_t ops(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_LT((size_t)1, inputs.size());
+    size_t batchSize = inputs[0].shape()[0];
+    size_t maps = inputs[0].shape()[1];
+    size_t rows = inputs[0].shape()[2];
+    size_t columns = inputs[0].shape()[3];
+
+    // number of floating-point operations
+    // (an approximate value)
+    size_t ops = batchSize * maps * rows * columns * (size_ * 4 + 2);
+
+    return ops;
+  }
+
  private:
   size_t size_;
   real scale_;
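
For the backward pass, the ops() override uses size_ * 4 + 2 operations per element. Continuing the same illustrative assumptions as the sketch above (again, not values from this commit), the two estimates compare as follows.

// Companion sketch comparing the forward and backward FLOP estimates;
// the 64x32x28x28 shape and window size of 5 are assumed for illustration.
#include <cstddef>
#include <cstdio>

int main() {
  size_t elements = 64 * 32 * 28 * 28;        // batchSize * maps * rows * columns
  size_t size = 5;                            // assumed normalization window
  size_t fwdOps = elements * (size * 2 + 3);  // 20873216, forward estimate
  size_t bwdOps = elements * (size * 4 + 2);  // 35323904, backward estimate
  // the backward estimate is 22/13, roughly 1.7x, the forward one
  std::printf("forward ~%zu, backward ~%zu FLOPs\n", fwdOps, bwdOps);
  return 0;
}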