117 changes: 116 additions & 1 deletion paddle/fluid/operators/clip_by_norm_op.cu
@@ -13,8 +13,123 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/clip_by_norm_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Tx, typename Ty = Tx>
struct SquareTransformer {
HOSTDEVICE explicit inline SquareTransformer(int n) {}

HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x) * static_cast<Ty>(x);
}

HOSTDEVICE inline Ty operator()(const Tx* x) const {
return static_cast<Ty>(x[0]) * static_cast<Ty>(x[0]);
}
};

template <typename Tx, typename Ty = Tx>
struct SquareSum {
using Transformer = SquareTransformer<Tx, Ty>;

inline Ty initial() { return static_cast<Ty>(0.0f); }

__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return b + a;
}
};

template <>
class ClipByNormKernel<platform::CUDADeviceContext, platform::float16>
: public framework::OpKernel<platform::float16> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max_norm = context.Attr<float>("max_norm");
auto in_var = context.InputVar("X");
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();

Tensor* output = nullptr;
const Tensor* input = nullptr;
if (in_var->IsType<framework::LoDTensor>()) {
input = context.Input<Tensor>("X");

output = context.Output<Tensor>("Out");
output->mutable_data<platform::float16>(context.GetPlace());
} else if (in_var->IsType<SelectedRows>()) {
auto* x = context.Input<SelectedRows>("X");

// merge ids in selected rows first
math::scatter::MergeAdd<platform::CUDADeviceContext, platform::float16>
merge_func;
SelectedRows* merged_input =
Contributor commented on this line:
I don't quite understand why the temporary merged_input is put into the scope here; it looks like it is only a temporary variable used inside the kernel.

Contributor (Author) replied:
Right, we'll optimize this part of the code in a follow-up.

(A sketch of the kernel-local alternative the reviewer describes appears after this file's diff.)

const_cast<framework::Scope&>(context.scope())
.Var()
->GetMutable<SelectedRows>();
merge_func(context.template device_context<platform::CUDADeviceContext>(),
*x, merged_input);
input = &(merged_input->value());

SelectedRows* output_selected_rows = context.Output<SelectedRows>("Out");
output_selected_rows->set_rows(merged_input->rows());
output_selected_rows->set_height(merged_input->height());
output = output_selected_rows->mutable_value();
output->Resize(merged_input->value().dims());
output->mutable_data<platform::float16>(context.GetPlace());
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Invalid input variable type, only support LodTensor and "
"SelectedRows types, but got type is %s.",
framework::ToTypeName(in_var->Type())));
}

PADDLE_ENFORCE_NOT_NULL(input,
platform::errors::InvalidArgument(
"Input(X) of ClipByNormOp should not be null. "
"Please check if it is created correctly."));
std::vector<int> reduce_dims;
reduce_dims.resize(input->dims().size());
for (int i = 0; i < reduce_dims.size(); ++i) {
reduce_dims[i] = i;
}
Tensor tmp = context.AllocateTmpTensor<float, platform::CUDADeviceContext>(
{1}, dev_ctx);
TensorReduceFunctorImpl<platform::float16, float, SquareSum>(
*input, &tmp, reduce_dims, dev_ctx.stream());
auto tmp_eigen = EigenVector<float>::Flatten(tmp);
auto x_norm = tmp_eigen.sqrt();

auto x = EigenVector<platform::float16>::Flatten(*input);
auto out = EigenVector<platform::float16>::Flatten(*output);

auto& place =
*context.template device_context<platform::CUDADeviceContext>()
.eigen_device();

auto temp = (x_norm <= max_norm).template cast<float>();
auto epsilon =
((x_norm <= static_cast<float>(1e-30)).all().template cast<float>()) *
static_cast<float>(1e-6);

auto scaling =
(temp + (static_cast<float>(1) - temp) * max_norm / (x_norm + epsilon))
.template cast<platform::float16>();
Eigen::array<int, 1> one_dim{{1}};
Eigen::DSizes<int, 1> m_dsize(input->numel());

out.device(place) = x * scaling.reshape(one_dim).broadcast(m_dsize);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
clip_by_norm,
ops::ClipByNormKernel<paddle::platform::CUDADeviceContext, float>);
ops::ClipByNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::ClipByNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
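
The review thread above questions why the merged SelectedRows temporary is created as a Var in the scope. A minimal sketch of the kernel-local alternative, assuming the same MergeAdd functor and reusing the names from the kernel above (illustrative only, not the merged code):

// Sketch: hold the merged rows in a local SelectedRows declared before the
// branch, so the `input` pointer still refers to live data after the branch.
SelectedRows tmp_merged_input;  // hypothetical local temporary, no scope Var
const Tensor* input = nullptr;
if (in_var->IsType<SelectedRows>()) {
  auto* x = context.Input<SelectedRows>("X");
  math::scatter::MergeAdd<platform::CUDADeviceContext, platform::float16>
      merge_func;
  merge_func(dev_ctx, *x, &tmp_merged_input);  // merge duplicate row ids
  input = &(tmp_merged_input.value());         // values owned by tmp_merged_input
}
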
35 changes: 34 additions & 1 deletion python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py
@@ -25,8 +25,9 @@
class TestClipByNormOp(OpTest):
def setUp(self):
self.max_relative_error = 0.006
self.init_dtype()
self.initTestCase()
input = np.random.random(self.shape).astype("float32")
input = np.random.random(self.shape).astype(self.dtype)
input[np.abs(input) < self.max_relative_error] = 0.5
self.op_type = "clip_by_norm"
self.inputs = {'X': input, }
@@ -46,6 +47,9 @@ def initTestCase(self):
self.shape = (100, )
self.max_norm = 1.0

def init_dtype(self):
self.dtype = np.float32


class TestCase1(TestClipByNormOp):
def initTestCase(self):
@@ -65,6 +69,35 @@ def initTestCase(self):
self.max_norm = 1.0


class TestClipByNormOpFp16(TestClipByNormOp):
def init_dtype(self):
self.dtype = np.float16

def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=0.001)


class TestClipByNormOpFp16Case1(TestClipByNormOpFp16):
def initTestCase(self):
self.shape = (100, )
self.max_norm = 1e20


class TestClipByNormOpFp16Case2(TestClipByNormOpFp16):
def initTestCase(self):
self.shape = (16, 16)
self.max_norm = 0.1


class TestClipByNormOpFp16Case3(TestClipByNormOpFp16):
def initTestCase(self):
self.shape = (4, 8, 16)
self.max_norm = 1.0


class TestClipByNormOpWithSelectedRows(unittest.TestCase):
def check_with_place(self, place):
self.config_test_case()
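
For reference, a minimal NumPy sketch (not part of the PR) of the clipping rule the FP16 cases above exercise; the helper name clip_by_norm_ref is illustrative, and the norm is accumulated in float32 to mirror how the CUDA kernel reduces the squared sum in float before casting the scaling back to float16:

import numpy as np

def clip_by_norm_ref(x, max_norm):
    # out = x when ||x||_2 <= max_norm, otherwise x scaled down to norm max_norm.
    norm = np.sqrt(np.sum(np.square(x.astype(np.float32))))
    if norm <= max_norm:
        return x
    return (x.astype(np.float32) * (max_norm / norm)).astype(x.dtype)

# Example matching the TestClipByNormOpFp16Case2 shape and max_norm:
x = np.random.random((16, 16)).astype(np.float16)
out = clip_by_norm_ref(x, max_norm=0.1)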