Merged

28 commits
880fa23
add Rprop
WintersMontagne10335 Nov 9, 2023
e87a153
Merge remote-tracking branch 'upstream/develop' into winters019
WintersMontagne10335 Nov 9, 2023
82174c1
Merge remote-tracking branch 'upstream/develop' into winters019
WintersMontagne10335 Nov 11, 2023
5b49639
modified: python/paddle/optimizer/__init__.py
WintersMontagne10335 Nov 11, 2023
5fa62c7
add CPU and GPU code
WintersMontagne10335 Nov 12, 2023
1a9779d
Merge remote-tracking branch 'upstream/develop' into winters019
WintersMontagne10335 Nov 12, 2023
ec692bb
add basic implementation of Rprop
WintersMontagne10335 Nov 13, 2023
14bb5b6
fix bugs
WintersMontagne10335 Nov 15, 2023
2e64f02
fix bugs
WintersMontagne10335 Nov 16, 2023
40ab781
Merge remote-tracking branch 'upstream/develop' into winters019
WintersMontagne10335 Nov 16, 2023
255bc40
Merge remote-tracking branch 'upstream/develop' into winters019
WintersMontagne10335 Nov 19, 2023
81a1b91
modify the type of some parameters
WintersMontagne10335 Nov 20, 2023
eff0754
Merge remote-tracking branch 'upstream/develop' into winters019
WintersMontagne10335 Nov 20, 2023
1b2894b
Merge remote-tracking branch 'upstream/develop' into winters019
WintersMontagne10335 Nov 23, 2023
2e975f8
Merge remote-tracking branch 'upstream/develop' into winters019
WintersMontagne10335 Dec 5, 2023
f2e6e1b
fix bugs
WintersMontagne10335 Dec 6, 2023
cd25d45
fix bugs
WintersMontagne10335 Dec 6, 2023
b29c40f
Merge remote-tracking branch 'upstream/develop' into winters019
WintersMontagne10335 Dec 6, 2023
6d4030f
fix bugs
WintersMontagne10335 Dec 7, 2023
4197a2d
fix bugs
WintersMontagne10335 Dec 7, 2023
3b7c123
verify parameters and add unit test cases
WintersMontagne10335 Dec 8, 2023
42ea643
fix bugs
WintersMontagne10335 Dec 8, 2023
a23e95e
add documents
WintersMontagne10335 Dec 11, 2023
2823805
Merge remote-tracking branch 'upstream/develop' into winters019
WintersMontagne10335 Dec 15, 2023
169de4d
update parameters description
WintersMontagne10335 Dec 15, 2023
7515047
add default value descriptions for learning_rate_range and etas
WintersMontagne10335 Dec 19, 2023
f9f5ac4
Merge remote-tracking branch 'upstream/develop' into winters019
WintersMontagne10335 Dec 19, 2023
07a1f95
fix bug
WintersMontagne10335 Dec 19, 2023
13 changes: 13 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -2197,6 +2197,19 @@
  inplace : (x -> out)
  backward : round_grad

- op : rprop_
  args : (Tensor param, Tensor grad, Tensor prev, Tensor learning_rate, Tensor master_param, Tensor learning_rate_range, Tensor etas, bool multi_precision=false)
  output : Tensor(param_out), Tensor(prev_out), Tensor(learning_rate_out), Tensor(master_param_out)
  infer_meta :
    func : RpropInferMeta
  kernel :
    func : rprop
    data_type : param
    data_transform :
      support_trans_dtype : learning_rate
  optional : master_param, master_param_out
  inplace : (param -> param_out), (prev -> prev_out), (learning_rate -> learning_rate_out), (master_param -> master_param_out)
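
For context, this op performs the element-wise Rprop update that the CPU and GPU kernels below implement: prev holds the previous step's gradient, learning_rate is a per-element step size updated in place, learning_rate_range stores (lr_min, lr_max), and etas stores (eta_negative, eta_positive). A sketch of the update in the usual Rprop notation, where theta_i is param, g_i is grad and eta_i is the element-wise learning rate (symbols are illustrative, not identifiers from the diff); as in the iRprop- variant, a gradient sign flip zeroes the gradient so that element takes no step:

\eta_i \leftarrow \operatorname{clip}\left(\eta_i \cdot s_i,\ \eta_{\min},\ \eta_{\max}\right),\qquad
s_i = \begin{cases} \eta^{+}, & g_i\, g_i^{\mathrm{prev}} > 0 \\ \eta^{-}, & g_i\, g_i^{\mathrm{prev}} < 0 \\ 1, & \text{otherwise} \end{cases}

g_i \leftarrow 0 \ \ \text{if } g_i\, g_i^{\mathrm{prev}} < 0,\qquad
\theta_i \leftarrow \theta_i - \operatorname{sign}(g_i)\,\eta_i,\qquad
g_i^{\mathrm{prev}} \leftarrow g_i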

- op : rsqrt
  args : (Tensor x)
  output : Tensor(out)
44 changes: 44 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -3470,6 +3470,50 @@ void RnnInferMeta(const MetaTensor& x,
}
}

void RpropInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& prev,
const MetaTensor& learning_rate,
const MetaTensor& master_param,
const MetaTensor& learning_rate_range,
const MetaTensor& etas,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* prev_out,
MetaTensor* learning_rate_out,
MetaTensor* master_param_out) {
PADDLE_ENFORCE_NOT_NULL(
param_out,
phi::errors::InvalidArgument(
"Output(ParamOut) of RpropOp should not be null."));

PADDLE_ENFORCE_NOT_NULL(
prev_out,
phi::errors::InvalidArgument(
"Output(PrevOut) of RpropOp should not be null."));

PADDLE_ENFORCE_NOT_NULL(
learning_rate_out,
phi::errors::InvalidArgument(
"Output(LearningRateOut) of RpropOp should not be null."));

param_out->set_dims(param.dims());
param_out->set_dtype(param.dtype());
prev_out->set_dims(prev.dims());
prev_out->set_dtype(prev.dtype());
learning_rate_out->set_dims(learning_rate.dims());
learning_rate_out->set_dtype(learning_rate.dtype());
if (multi_precision) {
master_param_out->set_dims(master_param.dims());
if (DataType::FLOAT16 == master_param.dtype() ||
DataType::BFLOAT16 == master_param.dtype()) {
master_param_out->set_dtype(DataType::FLOAT32);
} else {
master_param_out->set_dtype(master_param.dtype());
}
}
}

void SgdInferMeta(const MetaTensor& param,
const MetaTensor& learning_rate,
const MetaTensor& grad,
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -627,6 +627,19 @@ void RnnInferMeta(const MetaTensor& x,
std::vector<MetaTensor*> state,
MetaTensor* reserve);

void RpropInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& prev,
const MetaTensor& learning_rate,
const MetaTensor& master_param,
const MetaTensor& learning_rate_range,
const MetaTensor& etas,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* prev_out,
MetaTensor* learning_rate_out,
MetaTensor* master_param_out);

void SendUERecvInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
143 changes: 143 additions & 0 deletions paddle/phi/kernels/cpu/rprop_kernel.cc
@@ -0,0 +1,143 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/rprop_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/jit/kernels.h"

namespace phi {

template <typename T, typename Context>
void RpropKernelCPUImpl(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& prev,
const DenseTensor& learning_rate,
const DenseTensor& learning_rate_range,
const DenseTensor& etas,
DenseTensor* param_out,
DenseTensor* prev_out,
DenseTensor* learning_rate_out) {
auto param_eigen = EigenVector<T>::Flatten(param);
auto prev_eigen = EigenVector<T>::Flatten(prev);
auto param_out_eigen = EigenVector<T>::Flatten(*param_out);
auto prev_out_eigen = EigenVector<T>::Flatten(*prev_out);
auto learning_rate_out_eigen = EigenVector<T>::Flatten(*learning_rate_out);
auto learning_rate_min = learning_rate_range.data<T>()[0];
auto learning_rate_max = learning_rate_range.data<T>()[1];
auto eta_negative = etas.data<T>()[0];
auto eta_positive = etas.data<T>()[1];

// Stack-allocated scratch tensors hold working copies of grad and the
// element-wise learning rate, so no heap allocation is leaked per step.
DenseTensor grad_tensor;
grad_tensor.Resize(grad.dims());
dev_ctx.template Alloc<T>(&grad_tensor);
phi::Copy<Context>(dev_ctx, grad, dev_ctx.GetPlace(), true, &grad_tensor);
auto grad_eigen = EigenVector<T>::Flatten(grad_tensor);

DenseTensor product_tensor;
product_tensor.Resize(grad.dims());
dev_ctx.template Alloc<T>(&product_tensor);
auto product_eigen = EigenVector<T>::Flatten(product_tensor);

DenseTensor learning_rate_tensor;
learning_rate_tensor.Resize(learning_rate.dims());
dev_ctx.template Alloc<T>(&learning_rate_tensor);
phi::Copy<Context>(
dev_ctx, learning_rate, dev_ctx.GetPlace(), true, &learning_rate_tensor);
auto learning_rate_eigen = EigenVector<T>::Flatten(learning_rate_tensor);

DenseTensor eta_tensor;
eta_tensor.Resize(learning_rate.dims());
dev_ctx.template Alloc<T>(&eta_tensor);
auto eta_eigen = EigenVector<T>::Flatten(eta_tensor);

// The sign of grad * prev decides how each element's step size is scaled;
// a sign flip also zeroes the gradient so that element takes no step.
product_eigen = grad_eigen * prev_eigen;
T* product_data = product_tensor.data<T>();
T* grad_data = grad_tensor.data<T>();
T* eta_data = eta_tensor.data<T>();
T zero = static_cast<T>(0);
T one = static_cast<T>(1);
for (int i = 0, n = product_tensor.numel(); i < n; i++) {
if (product_data[i] > zero) {
eta_data[i] = eta_positive;
} else if (product_data[i] == zero) {
eta_data[i] = one;
} else if (product_data[i] < zero) {
grad_data[i] = zero;
eta_data[i] = eta_negative;
}
}

// Scale the element-wise learning rates and clip them to the allowed range.
learning_rate_eigen = learning_rate_eigen * eta_eigen;
T* learning_rate_data = learning_rate_tensor.data<T>();
for (int i = 0, n = learning_rate_tensor.numel(); i < n; i++) {
if (learning_rate_data[i] > learning_rate_max) {
learning_rate_data[i] = learning_rate_max;
} else if (learning_rate_data[i] < learning_rate_min) {
learning_rate_data[i] = learning_rate_min;
}
}

// Step against the gradient sign and persist the gradient and learning rate
// as state for the next iteration.
param_out_eigen = param_eigen - grad_eigen.sign() * learning_rate_eigen;
prev_out_eigen = grad_eigen;
learning_rate_out_eigen = learning_rate_eigen;
phi::Copy<Context>(dev_ctx, grad_tensor, dev_ctx.GetPlace(), true, prev_out);
phi::Copy<Context>(dev_ctx,
learning_rate_tensor,
dev_ctx.GetPlace(),
true,
learning_rate_out);
}

template <typename T, typename Context>
void RpropKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& prev,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param UNUSED,
const DenseTensor& learning_rate_range,
const DenseTensor& etas,
bool multi_precision UNUSED,
DenseTensor* param_out,
DenseTensor* prev_out,
DenseTensor* learning_rate_out,
DenseTensor* master_param_out UNUSED) {
dev_ctx.template Alloc<T>(param_out);
dev_ctx.template Alloc<T>(prev_out);
dev_ctx.template Alloc<T>(learning_rate_out);
RpropKernelCPUImpl<T, Context>(dev_ctx,
param,
grad,
prev,
learning_rate,
learning_rate_range,
etas,
param_out,
prev_out,
learning_rate_out);
}

} // namespace phi

PD_REGISTER_KERNEL(rprop,
CPU,
ALL_LAYOUT,
phi::RpropKernel,
phi::dtype::bfloat16,
float,
double) {}
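
For reviewers who want to cross-check the kernel numerically, the same per-element update can be reproduced on plain arrays. A minimal standalone C++ sketch, not part of this diff; the function name, the float-only types and the tiny main are illustrative:

#include <algorithm>
#include <cstdio>
#include <vector>

// One Rprop step over flat arrays, mirroring RpropKernelCPUImpl above.
// lr_range = {lr_min, lr_max}, etas = {eta_negative, eta_positive}.
void RpropStepReference(std::vector<float>* param,
                        std::vector<float> grad,  // by value: may be zeroed
                        std::vector<float>* prev,
                        std::vector<float>* lr,
                        const float lr_range[2],
                        const float etas[2]) {
  for (size_t i = 0; i < param->size(); ++i) {
    const float product = grad[i] * (*prev)[i];
    float eta = 1.0f;        // product == 0: keep the current step size
    if (product > 0.0f) {
      eta = etas[1];         // same sign as last step: grow the step size
    } else if (product < 0.0f) {
      grad[i] = 0.0f;        // sign flip: skip the update for this element
      eta = etas[0];         // and shrink the step size
    }
    (*lr)[i] = std::min(std::max((*lr)[i] * eta, lr_range[0]), lr_range[1]);
    const float sign = (grad[i] > 0.0f) - (grad[i] < 0.0f);
    (*param)[i] -= sign * (*lr)[i];
    (*prev)[i] = grad[i];    // store the (possibly zeroed) gradient
  }
}

int main() {
  // Element 0 keeps its gradient sign, element 1 flips it.
  std::vector<float> param = {1.0f, 1.0f};
  std::vector<float> grad = {0.2f, -0.3f};
  std::vector<float> prev = {0.1f, 0.4f};
  std::vector<float> lr = {0.5f, 0.5f};
  const float lr_range[2] = {1e-6f, 50.0f};
  const float etas[2] = {0.5f, 1.2f};
  RpropStepReference(&param, grad, &prev, &lr, lr_range, etas);
  // Expected: param = {0.4, 1.0}, lr = {0.6, 0.25}, prev = {0.2, 0.0}
  std::printf("param %.3f %.3f  lr %.3f %.3f  prev %.3f %.3f\n",
              param[0], param[1], lr[0], lr[1], prev[0], prev[1]);
  return 0;
}
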
157 changes: 157 additions & 0 deletions paddle/phi/kernels/gpu/rprop_kernel.cu
@@ -0,0 +1,157 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/rprop_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_helper.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/mixed_vector.h"

namespace phi {

template <typename T, typename MT>
__global__ void RpropKernelGPUImpl(const T* param,
const T* grad,
const T* prev,
const T* learning_rate,
const MT* master_param,
const T* learning_rate_range,
const T* etas,
int num,
T* param_out,
T* prev_out,
T* learning_rate_out,
MT* master_param_out) {
MT learning_rate_min_data = static_cast<MT>(learning_rate_range[0]);
MT learning_rate_max_data = static_cast<MT>(learning_rate_range[1]);
MT eta_negative_data = static_cast<MT>(etas[0]);
MT eta_positive_data = static_cast<MT>(etas[1]);
MT zero_data = static_cast<MT>(0);
MT one_data = static_cast<MT>(1);
MT negative_one_data = static_cast<MT>(-1);
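
// Note: CUDA_KERNEL_LOOP expands to a grid-stride loop, so each thread
// handles indices i, i + blockDim.x * gridDim.x, ... and the kernel stays
// correct even if the launch covers fewer threads than num.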

CUDA_KERNEL_LOOP(i, num) {
MT param_data = master_param ? master_param[i] : static_cast<MT>(param[i]);
MT grad_data = static_cast<MT>(grad[i]);
MT prev_data = static_cast<MT>(prev[i]);
MT learning_rate_data = static_cast<MT>(learning_rate[i]);
MT product_data = grad_data * prev_data;

MT eta_data = one_data;
if (product_data > zero_data) {
eta_data = eta_positive_data;
} else if (product_data < zero_data) {
grad_data = zero_data;
eta_data = eta_negative_data;
}

learning_rate_data = learning_rate_data * eta_data;
if (learning_rate_data > learning_rate_max_data) {
learning_rate_data = learning_rate_max_data;
} else if (learning_rate_data < learning_rate_min_data) {
learning_rate_data = learning_rate_min_data;
}

MT grad_sign_data = zero_data;
if (grad_data > zero_data) {
grad_sign_data = one_data;
} else if (grad_data < zero_data) {
grad_sign_data = negative_one_data;
}

param_data = param_data - grad_sign_data * learning_rate_data;
prev_data = grad_data;

param_out[i] = static_cast<T>(param_data);
prev_out[i] = static_cast<T>(prev_data);
learning_rate_out[i] = static_cast<T>(learning_rate_data);
if (master_param_out) {
master_param_out[i] = param_data;
}
}
}

template <typename T, typename Context>
void RpropKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& prev,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
const DenseTensor& learning_rate_range,
const DenseTensor& etas,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* prev_out,
DenseTensor* learning_rate_out,
DenseTensor* master_param_out) {
using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
const MPDType* master_in_data =
multi_precision ? master_param->data<MPDType>() : nullptr;
MPDType* master_out_data =
multi_precision ? dev_ctx.template Alloc<MPDType>(master_param_out)
: nullptr;

int block = 512;
int grid = (param.numel() + block - 1) / block;

RpropKernelGPUImpl<T, MPDType><<<grid, block, 0, dev_ctx.stream()>>>(
param.data<T>(),
grad.data<T>(),
prev.data<T>(),
learning_rate.data<T>(),
master_in_data,
learning_rate_range.data<T>(),
etas.data<T>(),
param.numel(),
dev_ctx.template Alloc<T>(param_out),
dev_ctx.template Alloc<T>(prev_out),
dev_ctx.template Alloc<T>(learning_rate_out),
master_out_data);
}

} // namespace phi

#ifdef PADDLE_WITH_CUDA
PD_REGISTER_KERNEL(rprop,
GPU,
ALL_LAYOUT,
phi::RpropKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double) {
if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
kernel_key.dtype() == phi::DataType::BFLOAT16) {
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
}
}
#endif

#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(rprop,
GPU,
ALL_LAYOUT,
phi::RpropKernel,
phi::dtype::float16,
float,
double) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
}
}
#endif
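
A note on the multi-precision path: when param is float16 or bfloat16, master_param carries a float32 copy that the kernel reads and writes at full precision; the registrations above pin output 3 (master_param_out) to float32 for those kernels, and RpropInferMeta applies the same float16/bfloat16-to-float32 promotion. The dtype rule in isolation, as a minimal C++ sketch (the helper name is hypothetical, not part of this diff):

#include "paddle/phi/common/data_type.h"

// Hypothetical helper: dtype of the master copy kept next to a parameter.
// Low-precision parameters get a float32 master; others keep their dtype.
inline phi::DataType MasterDtypeFor(phi::DataType stored) {
  return (stored == phi::DataType::FLOAT16 ||
          stored == phi::DataType::BFLOAT16)
             ? phi::DataType::FLOAT32
             : stored;
}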