Skip to content

Commit c833770

Browse files
committed
polish code
1 parent 8a5f39f commit c833770

File tree

3 files changed

+31
-171
lines changed

3 files changed

+31
-171
lines changed

paddle/fluid/operators/set_value_op.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,17 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
151151

152152
protected:
153153
void Apply(GradOpPtr<T> op) const override {
154-
op->SetType("set_value_grad");
154+
if (this->HasInput("ValueTensor")) {
155+
op->SetType("set_value_grad");
156+
op->SetInput("ValueTensor", this->Input("ValueTensor"));
157+
op->SetOutput(framework::GradVarName("ValueTensor"),
158+
this->InputGrad("ValueTensor"));
159+
} else {
160+
op->SetType("set_value_with_scalar_grad");
161+
}
162+
155163
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
164+
156165
if (this->HasInput("StartsTensorList")) {
157166
op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
158167
}
@@ -166,12 +175,6 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
166175
op->SetAttrMap(this->Attrs());
167176

168177
op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
169-
170-
if (this->HasInput("ValueTensor")) {
171-
op->SetInput("ValueTensor", this->Input("ValueTensor"));
172-
op->SetOutput(framework::GradVarName("ValueTensor"),
173-
this->InputGrad("ValueTensor"));
174-
}
175178
}
176179
};
177180

paddle/phi/kernels/impl/set_value_grad_kernel_impl.h

Lines changed: 11 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -351,87 +351,16 @@ void SetValueWithScalarGradKernel(const Context& dev_ctx,
351351
const std::vector<int64_t>& decrease_axes,
352352
const std::vector<int64_t>& none_axes,
353353
DenseTensor* x_grad) {
354-
const int rank = out_grad.dims().size();
355-
356-
switch (rank) {
357-
case 1:
358-
SetValueGradImpl<T, Context, 1>(dev_ctx,
359-
out_grad,
360-
starts,
361-
ends,
362-
steps,
363-
axes,
364-
decrease_axes,
365-
none_axes,
366-
x_grad,
367-
nullptr);
368-
break;
369-
case 2:
370-
SetValueGradImpl<T, Context, 2>(dev_ctx,
371-
out_grad,
372-
starts,
373-
ends,
374-
steps,
375-
axes,
376-
decrease_axes,
377-
none_axes,
378-
x_grad,
379-
nullptr);
380-
break;
381-
case 3:
382-
SetValueGradImpl<T, Context, 3>(dev_ctx,
383-
out_grad,
384-
starts,
385-
ends,
386-
steps,
387-
axes,
388-
decrease_axes,
389-
none_axes,
390-
x_grad,
391-
nullptr);
392-
break;
393-
case 4:
394-
SetValueGradImpl<T, Context, 4>(dev_ctx,
395-
out_grad,
396-
starts,
397-
ends,
398-
steps,
399-
axes,
400-
decrease_axes,
401-
none_axes,
402-
x_grad,
403-
nullptr);
404-
break;
405-
case 5:
406-
SetValueGradImpl<T, Context, 5>(dev_ctx,
407-
out_grad,
408-
starts,
409-
ends,
410-
steps,
411-
axes,
412-
decrease_axes,
413-
none_axes,
414-
x_grad,
415-
nullptr);
416-
break;
417-
case 6:
418-
SetValueGradImpl<T, Context, 6>(dev_ctx,
419-
out_grad,
420-
starts,
421-
ends,
422-
steps,
423-
axes,
424-
decrease_axes,
425-
none_axes,
426-
x_grad,
427-
nullptr);
428-
break;
429-
default:
430-
PADDLE_THROW(phi::errors::InvalidArgument(
431-
"The rank of set_value_with_scalar_grad's input should be less than "
432-
"7, but "
433-
"received %d.",
434-
rank));
435-
}
354+
SetValueGradKernel<T, Context>(dev_ctx,
355+
out_grad,
356+
starts,
357+
ends,
358+
steps,
359+
axes,
360+
decrease_axes,
361+
none_axes,
362+
x_grad,
363+
nullptr);
436364
}
365+
437366
} // namespace phi

paddle/phi/kernels/xpu/set_value_grad_kernel.cc

Lines changed: 10 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -407,88 +407,16 @@ void SetValueWithScalarGradKernel(const Context& dev_ctx,
407407
const std::vector<int64_t>& decrease_axes,
408408
const std::vector<int64_t>& none_axes,
409409
DenseTensor* x_grad) {
410-
const int rank = out_grad.dims().size();
411-
412-
switch (rank) {
413-
case 1:
414-
SetValueGradImpl<T, Context, 1>(dev_ctx,
415-
out_grad,
416-
starts,
417-
ends,
418-
steps,
419-
axes,
420-
decrease_axes,
421-
none_axes,
422-
x_grad,
423-
nullptr);
424-
break;
425-
case 2:
426-
SetValueGradImpl<T, Context, 2>(dev_ctx,
427-
out_grad,
428-
starts,
429-
ends,
430-
steps,
431-
axes,
432-
decrease_axes,
433-
none_axes,
434-
x_grad,
435-
nullptr);
436-
break;
437-
case 3:
438-
SetValueGradImpl<T, Context, 3>(dev_ctx,
439-
out_grad,
440-
starts,
441-
ends,
442-
steps,
443-
axes,
444-
decrease_axes,
445-
none_axes,
446-
x_grad,
447-
nullptr);
448-
break;
449-
case 4:
450-
SetValueGradImpl<T, Context, 4>(dev_ctx,
451-
out_grad,
452-
starts,
453-
ends,
454-
steps,
455-
axes,
456-
decrease_axes,
457-
none_axes,
458-
x_grad,
459-
nullptr);
460-
break;
461-
case 5:
462-
SetValueGradImpl<T, Context, 5>(dev_ctx,
463-
out_grad,
464-
starts,
465-
ends,
466-
steps,
467-
axes,
468-
decrease_axes,
469-
none_axes,
470-
x_grad,
471-
nullptr);
472-
break;
473-
case 6:
474-
SetValueGradImpl<T, Context, 6>(dev_ctx,
475-
out_grad,
476-
starts,
477-
ends,
478-
steps,
479-
axes,
480-
decrease_axes,
481-
none_axes,
482-
x_grad,
483-
nullptr);
484-
break;
485-
default:
486-
PADDLE_THROW(phi::errors::InvalidArgument(
487-
"The rank of set_value_with_scalar_grad's input should be less than "
488-
"7, but "
489-
"received %d.",
490-
rank));
491-
}
410+
SetValueGradKernel<T, Context>(dev_ctx,
411+
out_grad,
412+
starts,
413+
ends,
414+
steps,
415+
axes,
416+
decrease_axes,
417+
none_axes,
418+
x_grad,
419+
nullptr);
492420
}
493421

494422
} // namespace phi

0 commit comments

Comments
 (0)