File tree Expand file tree Collapse file tree 3 files changed +31
-171
lines changed
Expand file tree Collapse file tree 3 files changed +31
-171
lines changed Original file line number Diff line number Diff line change @@ -151,8 +151,17 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
151151
152152 protected:
153153 void Apply (GradOpPtr<T> op) const override {
154- op->SetType (" set_value_grad" );
154+ if (this ->HasInput (" ValueTensor" )) {
155+ op->SetType (" set_value_grad" );
156+ op->SetInput (" ValueTensor" , this ->Input (" ValueTensor" ));
157+ op->SetOutput (framework::GradVarName (" ValueTensor" ),
158+ this ->InputGrad (" ValueTensor" ));
159+ } else {
160+ op->SetType (" set_value_with_scalar_grad" );
161+ }
162+
155163 op->SetInput (framework::GradVarName (" Out" ), this ->OutputGrad (" Out" ));
164+
156165 if (this ->HasInput (" StartsTensorList" )) {
157166 op->SetInput (" StartsTensorList" , this ->Input (" StartsTensorList" ));
158167 }
@@ -166,12 +175,6 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
166175 op->SetAttrMap (this ->Attrs ());
167176
168177 op->SetOutput (framework::GradVarName (" Input" ), this ->InputGrad (" Input" ));
169-
170- if (this ->HasInput (" ValueTensor" )) {
171- op->SetInput (" ValueTensor" , this ->Input (" ValueTensor" ));
172- op->SetOutput (framework::GradVarName (" ValueTensor" ),
173- this ->InputGrad (" ValueTensor" ));
174- }
175178 }
176179};
177180
Original file line number Diff line number Diff line change @@ -351,87 +351,16 @@ void SetValueWithScalarGradKernel(const Context& dev_ctx,
351351 const std::vector<int64_t >& decrease_axes,
352352 const std::vector<int64_t >& none_axes,
353353 DenseTensor* x_grad) {
354- const int rank = out_grad.dims ().size ();
355-
356- switch (rank) {
357- case 1 :
358- SetValueGradImpl<T, Context, 1 >(dev_ctx,
359- out_grad,
360- starts,
361- ends,
362- steps,
363- axes,
364- decrease_axes,
365- none_axes,
366- x_grad,
367- nullptr );
368- break ;
369- case 2 :
370- SetValueGradImpl<T, Context, 2 >(dev_ctx,
371- out_grad,
372- starts,
373- ends,
374- steps,
375- axes,
376- decrease_axes,
377- none_axes,
378- x_grad,
379- nullptr );
380- break ;
381- case 3 :
382- SetValueGradImpl<T, Context, 3 >(dev_ctx,
383- out_grad,
384- starts,
385- ends,
386- steps,
387- axes,
388- decrease_axes,
389- none_axes,
390- x_grad,
391- nullptr );
392- break ;
393- case 4 :
394- SetValueGradImpl<T, Context, 4 >(dev_ctx,
395- out_grad,
396- starts,
397- ends,
398- steps,
399- axes,
400- decrease_axes,
401- none_axes,
402- x_grad,
403- nullptr );
404- break ;
405- case 5 :
406- SetValueGradImpl<T, Context, 5 >(dev_ctx,
407- out_grad,
408- starts,
409- ends,
410- steps,
411- axes,
412- decrease_axes,
413- none_axes,
414- x_grad,
415- nullptr );
416- break ;
417- case 6 :
418- SetValueGradImpl<T, Context, 6 >(dev_ctx,
419- out_grad,
420- starts,
421- ends,
422- steps,
423- axes,
424- decrease_axes,
425- none_axes,
426- x_grad,
427- nullptr );
428- break ;
429- default :
430- PADDLE_THROW (phi::errors::InvalidArgument (
431- " The rank of set_value_with_scalar_grad's input should be less than "
432- " 7, but "
433- " received %d." ,
434- rank));
435- }
354+ SetValueGradKernel<T, Context>(dev_ctx,
355+ out_grad,
356+ starts,
357+ ends,
358+ steps,
359+ axes,
360+ decrease_axes,
361+ none_axes,
362+ x_grad,
363+ nullptr );
436364}
365+
437366} // namespace phi
Original file line number Diff line number Diff line change @@ -407,88 +407,16 @@ void SetValueWithScalarGradKernel(const Context& dev_ctx,
407407 const std::vector<int64_t >& decrease_axes,
408408 const std::vector<int64_t >& none_axes,
409409 DenseTensor* x_grad) {
410- const int rank = out_grad.dims ().size ();
411-
412- switch (rank) {
413- case 1 :
414- SetValueGradImpl<T, Context, 1 >(dev_ctx,
415- out_grad,
416- starts,
417- ends,
418- steps,
419- axes,
420- decrease_axes,
421- none_axes,
422- x_grad,
423- nullptr );
424- break ;
425- case 2 :
426- SetValueGradImpl<T, Context, 2 >(dev_ctx,
427- out_grad,
428- starts,
429- ends,
430- steps,
431- axes,
432- decrease_axes,
433- none_axes,
434- x_grad,
435- nullptr );
436- break ;
437- case 3 :
438- SetValueGradImpl<T, Context, 3 >(dev_ctx,
439- out_grad,
440- starts,
441- ends,
442- steps,
443- axes,
444- decrease_axes,
445- none_axes,
446- x_grad,
447- nullptr );
448- break ;
449- case 4 :
450- SetValueGradImpl<T, Context, 4 >(dev_ctx,
451- out_grad,
452- starts,
453- ends,
454- steps,
455- axes,
456- decrease_axes,
457- none_axes,
458- x_grad,
459- nullptr );
460- break ;
461- case 5 :
462- SetValueGradImpl<T, Context, 5 >(dev_ctx,
463- out_grad,
464- starts,
465- ends,
466- steps,
467- axes,
468- decrease_axes,
469- none_axes,
470- x_grad,
471- nullptr );
472- break ;
473- case 6 :
474- SetValueGradImpl<T, Context, 6 >(dev_ctx,
475- out_grad,
476- starts,
477- ends,
478- steps,
479- axes,
480- decrease_axes,
481- none_axes,
482- x_grad,
483- nullptr );
484- break ;
485- default :
486- PADDLE_THROW (phi::errors::InvalidArgument (
487- " The rank of set_value_with_scalar_grad's input should be less than "
488- " 7, but "
489- " received %d." ,
490- rank));
491- }
410+ SetValueGradKernel<T, Context>(dev_ctx,
411+ out_grad,
412+ starts,
413+ ends,
414+ steps,
415+ axes,
416+ decrease_axes,
417+ none_axes,
418+ x_grad,
419+ nullptr );
492420}
493421
494422} // namespace phi
You can’t perform that action at this time.
0 commit comments