Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions paddle/fluid/operators/label_smooth_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,21 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
auto label_dim = in_t->dims()[in_t->dims().size() - 1];
out_t->mutable_data<T>(ctx.GetPlace());

auto epsilon = ctx.Attr<float>("epsilon");
auto out = framework::EigenVector<T>::Flatten(*out_t);
auto in = framework::EigenVector<T>::Flatten(*in_t);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
if (dist_t) {
auto dist = framework::EigenVector<T>::Flatten(*dist_t);
out.device(dev) =
static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon) *
dist.broadcast(Eigen::DSizes<int, 1>(in_t->numel() / label_dim));
} else {
out.device(dev) = static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon / label_dim);
if (label_dim != 0) {
auto epsilon = ctx.Attr<float>("epsilon");
auto out = framework::EigenVector<T>::Flatten(*out_t);
auto in = framework::EigenVector<T>::Flatten(*in_t);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
if (dist_t) {
auto dist = framework::EigenVector<T>::Flatten(*dist_t);
out.device(dev) = static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon) *
dist.broadcast(Eigen::DSizes<int, 1>(
in_t->numel() / label_dim));
} else {
out.device(dev) = static_cast<T>(1 - epsilon) * in +
static_cast<T>(epsilon / label_dim);
}
}
}
};
Expand All @@ -54,13 +55,15 @@ class LabelSmoothGradKernel : public framework::OpKernel<T> {
auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_in_t = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_in_t->mutable_data<T>(ctx.GetPlace());
auto d_out_dim = d_out_t->dims()[d_out_t->dims().size() - 1];
if (d_out_dim != 0) {
auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
auto d_in = framework::EigenVector<T>::Flatten(*d_in_t);

auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
auto d_in = framework::EigenVector<T>::Flatten(*d_in_t);

auto epsilon = ctx.Attr<float>("epsilon");
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
d_in.device(dev) = static_cast<T>(1 - epsilon) * d_out;
auto epsilon = ctx.Attr<float>("epsilon");
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
d_in.device(dev) = static_cast<T>(1 - epsilon) * d_out;
}
}
};
} // namespace operators
Expand Down