Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddle/fluid/framework/unused_var_check.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ static const std::unordered_set<std::string> &GetOpWithUnusedVarAllowSet() {
"fused_batch_norm_act", // 2
"fused_batch_norm_act_grad", // 2
"data_norm", // 0
"data_norm_grad", // 0);
"data_norm_grad", // 0
"update_loss_scaling", // 0
});
return *allow_set;
}
Expand Down
30 changes: 19 additions & 11 deletions paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License. */
#include <cuda.h>

#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
Expand All @@ -25,22 +27,25 @@ __global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) {
*found_inf = false;
}

template <typename T>
__global__ void CheckFiniteAndUnscale(const T* in, const T* scale, int num,
template <typename T, typename MT>
__global__ void CheckFiniteAndUnscale(const T* in, const MT* scale, int num,
bool* found_inf, T* out) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;

if (idx < num) {
T val = in[idx] * (*scale);
out[idx] = val;
if (!isfinite(val)) {
MT val = static_cast<MT>(in[idx]) * (*scale);
T narrow_val = static_cast<T>(val);
out[idx] = narrow_val;
if (!isfinite(narrow_val)) {
*found_inf = true;
}
}
}

template <typename T>
class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;

public:
void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
Expand All @@ -49,14 +54,15 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");

const T* scale_data = scale->data<T>();
const MPDType* scale_data = scale->data<MPDType>();
bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());

framework::Tensor inverse_scale =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({1}, dev_ctx);
T* inverse_scale_v = inverse_scale.template data<T>();
ctx.AllocateTmpTensor<MPDType, platform::CUDADeviceContext>({1},
dev_ctx);
MPDType* inverse_scale_v = inverse_scale.template data<MPDType>();

InverseAndMemset<T><<<1, 1, 0, dev_ctx.stream()>>>(
InverseAndMemset<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
scale_data, inverse_scale_v, found_inf_data);

for (size_t i = 0; i < xs.size(); ++i) {
Expand All @@ -69,7 +75,7 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
int block = 1024;
int grid = (num + block - 1) / block;
VLOG(3) << "launch kernel";
CheckFiniteAndUnscale<T><<<grid, block, 0, dev_ctx.stream()>>>(
CheckFiniteAndUnscale<T, MPDType><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, inverse_scale_v, num, found_inf_data, out_data);
VLOG(3) << "finish kernel";
}
Expand All @@ -79,6 +85,8 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(check_finite_and_unscale,
ops::CheckFiniteAndUnscaleGpuKernel<float>,
ops::CheckFiniteAndUnscaleGpuKernel<double>);
ops::CheckFiniteAndUnscaleGpuKernel<double>,
ops::CheckFiniteAndUnscaleGpuKernel<plat::float16>);
37 changes: 37 additions & 0 deletions paddle/fluid/operators/amp/fp16_type_traits.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace details {

// Primary template of the type trait: for any type T without a dedicated
// specialization, the nested `Type` is T itself (identity mapping).
//
// Usage: `typename details::MPTypeTrait<T>::Type`.
template <typename T>
class MPTypeTrait {
 public:
  // Identity mapping: the exposed type is the template argument unchanged.
  typedef T Type;
};

// Explicit specialization for platform::float16: the nested `Type` is `float`
// rather than float16 itself, so code that selects its working type through
// this trait gets a 32-bit float when instantiated with float16.
template <>
class MPTypeTrait<platform::float16> {
 public:
  // float16 is mapped to float.
  typedef float Type;
};

} // namespace details
} // namespace operators
} // namespace paddle
6 changes: 4 additions & 2 deletions paddle/fluid/operators/amp/update_loss_scaling_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "PrevLossScaling"),
ctx.device_context());
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};

Expand Down Expand Up @@ -107,6 +106,9 @@ class UpdateLossScalingOpMaker : public framework::OpProtoAndCheckerMaker {
"the received is %f",
decr_ratio));
});
AddAttr<bool>("stop_update",
"Stop updating loss scaling, and just zero inputs.")
.SetDefault(false);
AddComment(R"DOC(
Update loss scaling according to overall gradients. If all gradients is
finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/operators/amp/update_loss_scaling_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -83,8 +84,10 @@ class LazyZeros<platform::CUDADeviceContext, T> {
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
using GPU = paddle::platform::CUDADeviceContext;

REGISTER_OP_CUDA_KERNEL(update_loss_scaling,
ops::UpdateLossScalingKernel<GPU, float>,
ops::UpdateLossScalingKernel<GPU, double>);
ops::UpdateLossScalingKernel<GPU, double>,
ops::UpdateLossScalingKernel<GPU, plat::float16>);
34 changes: 21 additions & 13 deletions paddle/fluid/operators/amp/update_loss_scaling_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
Expand Down Expand Up @@ -79,30 +80,38 @@ class LazyZeros {

template <typename DeviceContext, typename T>
class UpdateLossScalingKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;

public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<DeviceContext>();

const auto xs = ctx.MultiInput<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
const auto* found_inf = ctx.Input<Tensor>("FoundInfinite");
PADDLE_ENFORCE_EQ(found_inf->numel(), 1,
platform::errors::InvalidArgument(
"FoundInfinite must has only one element."));
const bool* found_inf_data = found_inf->data<bool>();

LazyZeros<DeviceContext, T>{}(dev_ctx, found_inf_data, xs, outs);
const bool stop_update = ctx.Attr<bool>("stop_update");
if (stop_update) {
return;
}

const auto* pre_loss_scaling = ctx.Input<Tensor>("PrevLossScaling");
const auto* good_in = ctx.Input<Tensor>("InGoodSteps");
const auto* bad_in = ctx.Input<Tensor>("InBadSteps");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* updated_loss_scaling = ctx.Output<Tensor>("LossScaling");
auto* good_out = ctx.Output<Tensor>("OutGoodSteps");
auto* bad_out = ctx.Output<Tensor>("OutBadSteps");

PADDLE_ENFORCE_EQ(found_inf->numel(), 1,
platform::errors::InvalidArgument(
"FoundInfinite must has only one element."));

const bool* found_inf_data = found_inf->data<bool>();
const T* pre_loss_scaling_data = pre_loss_scaling->data<T>();
const MPDType* pre_loss_scaling_data = pre_loss_scaling->data<MPDType>();
const int* good_in_data = good_in->data<int>();
const int* bad_in_data = bad_in->data<int>();

auto& dev_ctx = ctx.template device_context<DeviceContext>();
T* updated_loss_scaling_data =
updated_loss_scaling->mutable_data<T>(dev_ctx.GetPlace());
MPDType* updated_loss_scaling_data =
updated_loss_scaling->mutable_data<MPDType>(dev_ctx.GetPlace());
int* good_out_data = good_out->mutable_data<int>(dev_ctx.GetPlace());
int* bad_out_data = bad_out->mutable_data<int>(dev_ctx.GetPlace());

Expand All @@ -111,11 +120,10 @@ class UpdateLossScalingKernel : public framework::OpKernel<T> {
ctx.Attr<int>("decr_every_n_nan_or_inf");
const float incr_ratio = ctx.Attr<float>("incr_ratio");
const float decr_ratio = ctx.Attr<float>("decr_ratio");
UpdateLossScalingFunctor<DeviceContext, T>{}(
UpdateLossScalingFunctor<DeviceContext, MPDType>{}(
dev_ctx, found_inf_data, pre_loss_scaling_data, good_in_data,
bad_in_data, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio,
decr_ratio, updated_loss_scaling_data, good_out_data, bad_out_data);
LazyZeros<DeviceContext, T>{}(dev_ctx, found_inf_data, xs, outs);
}
};

Expand Down
20 changes: 20 additions & 0 deletions paddle/fluid/operators/optimizers/adam_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/optimizers/adam_op.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -150,12 +151,17 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
"as beta2, this has a higher priority than attr(beta2), the "
"shape of this tensor MUST BE [1].")
.AsDispensable();
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();

AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("Moment1Out", "(Tensor) Output first moment");
AddOutput("Moment2Out", "(Tensor) Output second moment");
AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();

AddAttr<float>("beta1",
"(float, default 0.9) "
Expand Down Expand Up @@ -183,6 +189,10 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
"inner_op_parallelism is larger then 0, sparse update "
"will run in multithread mode")
.SetDefault(1000);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);

AddComment(R"DOC(
Adam Optimizer.
Expand Down Expand Up @@ -213,3 +223,13 @@ REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker);
REGISTER_OP_CPU_KERNEL(
adam, ops::AdamOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::AdamOpKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_VERSION(adam)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adam属于训练的op,其实没有必要设置op version,加了也没有什么影响

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,谢谢提醒。

.AddCheckpoint(
R"ROC(
Upgrade adam add 1 attribute [multi_precision].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"multi_precision",
"(bool) Whether to use multi-precision during weight updating.",
false));
Loading