diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index e00e6c0c052585..1f8a7d158325ed 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -2197,6 +2197,19 @@ inplace : (x -> out) backward : round_grad +- op : rprop_ + args : (Tensor param, Tensor grad, Tensor prev, Tensor learning_rate, Tensor master_param, Tensor learning_rate_range, Tensor etas, bool multi_precision=false) + output : Tensor(param_out), Tensor(prev_out), Tensor(learning_rate_out), Tensor(master_param_out) + infer_meta : + func : RpropInferMeta + kernel : + func : rprop + data_type : param + data_transform : + support_trans_dtype : learning_rate + optional : master_param, master_param_out + inplace : (param -> param_out), (prev -> prev_out), (learning_rate -> learning_rate_out), (master_param -> master_param_out) + - op : rsqrt args : (Tensor x) output : Tensor(out) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index eee92aa1380449..0b2ef29389137c 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3470,6 +3470,50 @@ void RnnInferMeta(const MetaTensor& x, } } +void RpropInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& prev, + const MetaTensor& learning_rate, + const MetaTensor& master_param, + const MetaTensor& learning_rate_range, + const MetaTensor& etas, + bool multi_precision, + MetaTensor* param_out, + MetaTensor* prev_out, + MetaTensor* learning_rate_out, + MetaTensor* master_param_out) { + PADDLE_ENFORCE_NOT_NULL( + param_out, + phi::errors::InvalidArgument( + "Output(ParamOut) of RpropOp should not be null.")); + + PADDLE_ENFORCE_NOT_NULL( + prev_out, + phi::errors::InvalidArgument( + "Output(PrevOut) of RpropOp should not be null.")); + + PADDLE_ENFORCE_NOT_NULL( + learning_rate_out, + phi::errors::InvalidArgument( + "Output(LearningRateOut) of RpropOp should not be null.")); + + param_out->set_dims(param.dims()); + param_out->set_dtype(param.dtype()); + prev_out->set_dims(prev.dims()); + prev_out->set_dtype(prev.dtype()); + learning_rate_out->set_dims(learning_rate.dims()); + learning_rate_out->set_dtype(learning_rate.dtype()); + if (multi_precision) { + master_param_out->set_dims(master_param.dims()); + if (DataType::FLOAT16 == master_param.dtype() || + DataType::BFLOAT16 == master_param.dtype()) { + master_param_out->set_dtype(DataType::FLOAT32); + } else { + master_param_out->set_dtype(master_param.dtype()); + } + } +} + void SgdInferMeta(const MetaTensor& param, const MetaTensor& learning_rate, const MetaTensor& grad, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 8a5c9263acc9ad..be3f1fba94a800 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -627,6 +627,19 @@ void RnnInferMeta(const MetaTensor& x, std::vector state, MetaTensor* reserve); +void RpropInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& prev, + const MetaTensor& learning_rate, + const MetaTensor& master_param, + const MetaTensor& learning_rate_range, + const MetaTensor& etas, + bool multi_precision, + MetaTensor* param_out, + MetaTensor* prev_out, + MetaTensor* learning_rate_out, + MetaTensor* master_param_out); + void SendUERecvInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& src_index, diff --git a/paddle/phi/kernels/cpu/rprop_kernel.cc b/paddle/phi/kernels/cpu/rprop_kernel.cc new file mode 100644 index 00000000000000..e9950b6d986189 --- /dev/null +++ 
b/paddle/phi/kernels/cpu/rprop_kernel.cc
@@ -0,0 +1,143 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/rprop_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/jit/kernels.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RpropKernelCPUImpl(const Context& dev_ctx,
+                        const DenseTensor& param,
+                        const DenseTensor& grad,
+                        const DenseTensor& prev,
+                        const DenseTensor& learning_rate,
+                        const DenseTensor& learning_rate_range,
+                        const DenseTensor& etas,
+                        DenseTensor* param_out,
+                        DenseTensor* prev_out,
+                        DenseTensor* learning_rate_out) {
+  auto param_eigen = EigenVector<T>::Flatten(param);
+  auto prev_eigen = EigenVector<T>::Flatten(prev);
+  auto param_out_eigen = EigenVector<T>::Flatten(*param_out);
+  auto prev_out_eigen = EigenVector<T>::Flatten(*prev_out);
+  auto learning_rate_out_eigen = EigenVector<T>::Flatten(*learning_rate_out);
+  auto learning_rate_min = learning_rate_range.data<T>()[0];
+  auto learning_rate_max = learning_rate_range.data<T>()[1];
+  auto eta_negative = etas.data<T>()[0];
+  auto eta_positive = etas.data<T>()[1];
+
+  DenseTensor* grad_tensor = new DenseTensor();
+  grad_tensor->Resize(grad.dims());
+  dev_ctx.template Alloc<T>(grad_tensor);
+  phi::Copy(dev_ctx, grad, dev_ctx.GetPlace(), true, grad_tensor);
+  auto grad_eigen = EigenVector<T>::Flatten(*grad_tensor);
+
+  DenseTensor* product_tensor = new DenseTensor();
+  product_tensor->Resize(grad.dims());
+  dev_ctx.template Alloc<T>(product_tensor);
+  auto product_eigen = EigenVector<T>::Flatten(*product_tensor);
+
+  DenseTensor* learning_rate_tensor = new DenseTensor();
+  learning_rate_tensor->Resize(learning_rate.dims());
+  dev_ctx.template Alloc<T>(learning_rate_tensor);
+  phi::Copy(
+      dev_ctx, learning_rate, dev_ctx.GetPlace(), true, learning_rate_tensor);
+  auto learning_rate_eigen = EigenVector<T>::Flatten(*learning_rate_tensor);
+
+  DenseTensor* eta_tensor = new DenseTensor();
+  eta_tensor->Resize(learning_rate.dims());
+  dev_ctx.template Alloc<T>(eta_tensor);
+  auto eta_eigen = EigenVector<T>::Flatten(*eta_tensor);
+
+  product_eigen = grad_eigen * prev_eigen;
+  T* product_data = product_tensor->data<T>();
+  T* grad_data = grad_tensor->data<T>();
+  T* eta_data = eta_tensor->data<T>();
+  T zero = static_cast<T>(0);
+  T one = static_cast<T>(1);
+  for (int i = 0, n = product_tensor->numel(); i < n; i++) {
+    if (product_data[i] > zero) {
+      eta_data[i] = eta_positive;
+    } else if (product_data[i] == zero) {
+      eta_data[i] = one;
+    } else if (product_data[i] < zero) {
+      grad_data[i] = zero;
+      eta_data[i] = eta_negative;
+    }
+  }
+
+  learning_rate_eigen = learning_rate_eigen * eta_eigen;
+  T* learning_rate_data = learning_rate_tensor->data<T>();
+  for (int i = 0, n = learning_rate_tensor->numel(); i < n; i++) {
+    if (learning_rate_data[i] > learning_rate_max) {
+      learning_rate_data[i] = learning_rate_max;
+    } else if (learning_rate_data[i] < learning_rate_min) {
+      learning_rate_data[i] = learning_rate_min;
+    }
+  }
+
+  param_out_eigen = param_eigen - grad_eigen.sign() * learning_rate_eigen;
+  prev_out_eigen = grad_eigen;
+  learning_rate_out_eigen = learning_rate_eigen;
+  phi::Copy(dev_ctx, *grad_tensor, dev_ctx.GetPlace(), true, prev_out);
+  phi::Copy(dev_ctx,
+            *learning_rate_tensor,
+            dev_ctx.GetPlace(),
+            true,
+            learning_rate_out);
+}
+
+template <typename T, typename Context>
+void RpropKernel(const Context& dev_ctx,
+                 const DenseTensor& param,
+                 const DenseTensor& grad,
+                 const DenseTensor& prev,
+                 const DenseTensor& learning_rate,
+                 const paddle::optional<DenseTensor>& master_param UNUSED,
+                 const DenseTensor& learning_rate_range,
+                 const DenseTensor& etas,
+                 bool multi_precision UNUSED,
+                 DenseTensor* param_out,
+                 DenseTensor* prev_out,
+                 DenseTensor* learning_rate_out,
+                 DenseTensor* master_param_out UNUSED) {
+  dev_ctx.template Alloc<T>(param_out);
+  dev_ctx.template Alloc<T>(prev_out);
+  dev_ctx.template Alloc<T>(learning_rate_out);
+  RpropKernelCPUImpl<T, Context>(dev_ctx,
+                                 param,
+                                 grad,
+                                 prev,
+                                 learning_rate,
+                                 learning_rate_range,
+                                 etas,
+                                 param_out,
+                                 prev_out,
+                                 learning_rate_out);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(rprop,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::RpropKernel,
+                   phi::dtype::bfloat16,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/gpu/rprop_kernel.cu b/paddle/phi/kernels/gpu/rprop_kernel.cu
new file mode 100644
index 00000000000000..4ae95c16898417
--- /dev/null
+++ b/paddle/phi/kernels/gpu/rprop_kernel.cu
@@ -0,0 +1,157 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/rprop_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_helper.h"
+#include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/mixed_vector.h"
+
+namespace phi {
+
+template <typename T, typename MT>
+__global__ void RpropKernelGPUImpl(const T* param,
+                                   const T* grad,
+                                   const T* prev,
+                                   const T* learning_rate,
+                                   const MT* master_param,
+                                   const T* learning_rate_range,
+                                   const T* etas,
+                                   int num,
+                                   T* param_out,
+                                   T* prev_out,
+                                   T* learning_rate_out,
+                                   MT* master_param_out) {
+  MT learning_rate_min_data = static_cast<MT>(learning_rate_range[0]);
+  MT learning_rate_max_data = static_cast<MT>(learning_rate_range[1]);
+  MT eta_negative_data = static_cast<MT>(etas[0]);
+  MT eta_positive_data = static_cast<MT>(etas[1]);
+  MT zero_data = static_cast<MT>(0);
+  MT one_data = static_cast<MT>(1);
+  MT negative_one_data = static_cast<MT>(-1);
+
+  CUDA_KERNEL_LOOP(i, num) {
+    MT param_data = master_param ? master_param[i] : static_cast<MT>(param[i]);
+    MT grad_data = static_cast<MT>(grad[i]);
+    MT prev_data = static_cast<MT>(prev[i]);
+    MT learning_rate_data = static_cast<MT>(learning_rate[i]);
+    MT product_data = grad_data * prev_data;
+
+    MT eta_data = one_data;
+    if (product_data > zero_data) {
+      eta_data = eta_positive_data;
+    } else if (product_data < zero_data) {
+      grad_data = zero_data;
+      eta_data = eta_negative_data;
+    }
+
+    learning_rate_data = learning_rate_data * eta_data;
+    if (learning_rate_data > learning_rate_max_data) {
+      learning_rate_data = learning_rate_max_data;
+    } else if (learning_rate_data < learning_rate_min_data) {
+      learning_rate_data = learning_rate_min_data;
+    }
+
+    MT grad_sign_data = zero_data;
+    if (grad_data > zero_data) {
+      grad_sign_data = one_data;
+    } else if (grad_data < zero_data) {
+      grad_sign_data = negative_one_data;
+    }
+
+    param_data = param_data - grad_sign_data * learning_rate_data;
+    prev_data = grad_data;
+
+    param_out[i] = static_cast<T>(param_data);
+    prev_out[i] = static_cast<T>(prev_data);
+    learning_rate_out[i] = static_cast<T>(learning_rate_data);
+    if (master_param_out) {
+      master_param_out[i] = param_data;
+    }
+  }
+}
+
+template <typename T, typename Context>
+void RpropKernel(const Context& dev_ctx,
+                 const DenseTensor& param,
+                 const DenseTensor& grad,
+                 const DenseTensor& prev,
+                 const DenseTensor& learning_rate,
+                 const paddle::optional<DenseTensor>& master_param,
+                 const DenseTensor& learning_rate_range,
+                 const DenseTensor& etas,
+                 bool multi_precision,
+                 DenseTensor* param_out,
+                 DenseTensor* prev_out,
+                 DenseTensor* learning_rate_out,
+                 DenseTensor* master_param_out) {
+  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
+  const MPDType* master_in_data =
+      multi_precision ? master_param->data<MPDType>() : nullptr;
+  MPDType* master_out_data =
+      multi_precision ? dev_ctx.template Alloc<MPDType>(master_param_out)
+                      : nullptr;
+
+  int block = 512;
+  int grid = (param.numel() + block - 1) / block;
+
+  RpropKernelGPUImpl<T, MPDType><<<grid, block, 0, dev_ctx.stream()>>>(
+      param.data<T>(),
+      grad.data<T>(),
+      prev.data<T>(),
+      learning_rate.data<T>(),
+      master_in_data,
+      learning_rate_range.data<T>(),
+      etas.data<T>(),
+      param.numel(),
+      dev_ctx.template Alloc<T>(param_out),
+      dev_ctx.template Alloc<T>(prev_out),
+      dev_ctx.template Alloc<T>(learning_rate_out),
+      master_out_data);
+}
+
+}  // namespace phi
+
+#ifdef PADDLE_WITH_CUDA
+PD_REGISTER_KERNEL(rprop,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RpropKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   float,
+                   double) {
+  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
+      kernel_key.dtype() == phi::DataType::BFLOAT16) {
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+  }
+}
+#endif
+
+#ifdef PADDLE_WITH_HIP
+PD_REGISTER_KERNEL(rprop,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::RpropKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+  }
+}
+#endif
diff --git a/paddle/phi/kernels/rprop_kernel.h b/paddle/phi/kernels/rprop_kernel.h
new file mode 100644
index 00000000000000..adeefebcd46012
--- /dev/null
+++ b/paddle/phi/kernels/rprop_kernel.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/selected_rows.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void RpropKernel(const Context& dev_ctx,
+                 const DenseTensor& param,
+                 const DenseTensor& grad,
+                 const DenseTensor& prev,
+                 const DenseTensor& learning_rate,
+                 const paddle::optional<DenseTensor>& master_param,
+                 const DenseTensor& learning_rate_range,
+                 const DenseTensor& etas,
+                 bool multi_precision,
+                 DenseTensor* param_out,
+                 DenseTensor* prev_out,
+                 DenseTensor* learning_rate_out,
+                 DenseTensor* master_param_out);
+
+}  // namespace phi
diff --git a/python/paddle/optimizer/__init__.py b/python/paddle/optimizer/__init__.py
index bf8d63b2171237..516779cd924f6a 100644
--- a/python/paddle/optimizer/__init__.py
+++ b/python/paddle/optimizer/__init__.py
@@ -23,6 +23,7 @@
 from .momentum import Momentum
 from .optimizer import Optimizer
 from .rmsprop import RMSProp
+from .rprop import Rprop
 from .sgd import SGD
 
 __all__ = [
@@ -34,6 +35,7 @@
     'RMSProp',
     'Adadelta',
     'SGD',
+    'Rprop',
     'Momentum',
     'Lamb',
     'LBFGS',
diff --git a/python/paddle/optimizer/rprop.py b/python/paddle/optimizer/rprop.py
new file mode 100644
index 00000000000000..25b4be7170be20
--- /dev/null
+++ b/python/paddle/optimizer/rprop.py
@@ -0,0 +1,267 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+
+from paddle import _C_ops
+from paddle.tensor.creation import to_tensor
+
+from ..base import framework
+from ..base.dygraph import no_grad
+from ..base.framework import in_dynamic_or_pir_mode
+from .optimizer import Optimizer
+
+__all__ = []
+
+
+class Rprop(Optimizer):
+    r"""
+    **Notes: This optimizer is only applicable to full-batch training.**
+    Optimizer of the Rprop algorithm. Please refer to this for details:
+    `A direct adaptive method for faster backpropagation learning : The RPROP algorithm `_.
+
+    .. 
math:: + + \begin{aligned} + &\hspace{0mm} For\ all\ weights\ and\ biases\{ \\ + &\hspace{5mm} \textbf{if} \: (\frac{\partial E}{\partial w_{ij}}(t-1)*\frac{\partial E}{\partial w_{ij}}(t)> 0)\ \textbf{then} \: \{ \\ + &\hspace{10mm} learning\_rate_{ij}(t)=\mathrm{minimum}(learning\_rate_{ij}(t-1)*\eta^{+},learning\_rate_{max}) \\ + &\hspace{10mm} \Delta w_{ij}(t)=-sign(\frac{\partial E}{\partial w_{ij}}(t))*learning\_rate_{ij}(t) \\ + &\hspace{10mm} w_{ij}(t+1)=w_{ij}(t)+\Delta w_{ij}(t) \\ + &\hspace{5mm} \} \\ + &\hspace{5mm} \textbf{else if} \: (\frac{\partial E}{\partial w_{ij}}(t-1)*\frac{\partial E}{\partial w_{ij}}(t)< 0)\ \textbf{then} \: \{ \\ + &\hspace{10mm} learning\_rate_{ij}(t)=\mathrm{maximum}(learning\_rate_{ij}(t-1)*\eta^{-},learning\_rate_{min}) \\ + &\hspace{10mm} w_{ij}(t+1)=w_{ij}(t) \\ + &\hspace{10mm} \frac{\partial E}{\partial w_{ij}}(t)=0 \\ + &\hspace{5mm} \} \\ + &\hspace{5mm} \textbf{else if} \: (\frac{\partial E}{\partial w_{ij}}(t-1)*\frac{\partial E}{\partial w_{ij}}(t)= 0)\ \textbf{then} \: \{ \\ + &\hspace{10mm} \Delta w_{ij}(t)=-sign(\frac{\partial E}{\partial w_{ij}}(t))*learning\_rate_{ij}(t) \\ + &\hspace{10mm} w_{ij}(t+1)=w_{ij}(t)+\Delta w_{ij}(t) \\ + &\hspace{5mm} \} \\ + &\hspace{0mm} \} \\ + \end{aligned} + + Parameters: + learning_rate (float|Tensor|LearningRateDecay, optional): The initial learning rate used to update ``Parameter``. + It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001. + learning_rate_range (tuple, optional): The range of learning rate. + Learning rate cannot be smaller than the first element of the tuple; + learning rate cannot be larger than the second element of the tuple. + The default value is (1e-5, 50). + parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. + This parameter is required in dygraph mode. + The default value is None in static graph mode, at this time all parameters will be updated. + etas (tuple, optional): Tuple used to update learning rate. + The first element of the tuple is the multiplicative decrease factor; + the second element of the tuple is the multiplicative increase factor. + The default value is (0.5, 1.2). + grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of some derived class of ``GradientClipBase`` . + There are three clipping strategies ( :ref:`api_paddle_nn_ClipGradByGlobalNorm` , :ref:`api_paddle_nn_ClipGradByNorm` , :ref:`api_paddle_nn_ClipGradByValue` ). + Default None, meaning there is no gradient clipping. + multi_precision (bool, optional): In mixed precision training scenarios based on GPU, + this parameter is mainly used to ensure the numerical stability of gradient updates. + When it is set to True, the optimizer will save a backup of FP32 type parameters with an equal value for FP16 type parameters. + When updating gradients, first increase the gradient type to FP32, and then assign it to the FP32 type parameter backup. + Finally, the updated FP32 type value will be converted to FP16 type first, + and then assigned to the actual FP16 type parameters participating in the calculation. + The default value is False. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + + Examples: + .. 
code-block:: python + + >>> import paddle + + >>> inp = paddle.uniform(min=-0.1, max=0.1, shape=[1, 100], dtype='float32') + >>> linear = paddle.nn.Linear(100, 10) + >>> inp = paddle.to_tensor(inp) + >>> out = linear(inp) + >>> loss = paddle.mean(out) + >>> rprop = paddle.optimizer.Rprop(learning_rate=0.001, learning_rate_range=(0.0001,0.1), parameters=linear.parameters(), etas=(0.5,1.2)) + >>> out.backward() + >>> rprop.step() + >>> rprop.clear_grad() + """ + _prevs_acc_str = "prevs" + _learning_rates_acc_str = "learning_rates" + + def __init__( + self, + learning_rate=0.001, + learning_rate_range=(1e-5, 50), + parameters=None, + etas=(0.5, 1.2), + grad_clip=None, + multi_precision=False, + name=None, + ): + if learning_rate is None: + raise ValueError("learning_rate is not set") + if ( + not 0.0 + < learning_rate_range[0] + <= learning_rate + <= learning_rate_range[1] + ): + raise ValueError( + "'0.0 < learning_rate_range[0] <= learning_rate <= learning_rate_range[1]' must be true" + ) + if not 0.0 < etas[0] < 1.0 < etas[1]: + raise ValueError("'0.0 < etas[0] < 1.0 < etas[1]' must be true") + super().__init__( + learning_rate=learning_rate, + parameters=parameters, + weight_decay=0.0, + grad_clip=grad_clip, + name=name, + ) + self.type = "rprop" + self._initial_learning_rate = learning_rate + self._multi_precision = multi_precision + self._master_weights = {} + self._learning_rate_range = [learning_rate_range] + self._etas = [etas] + self._sign = True + + def _to_tensor(self, block, dtype): + assert isinstance(block, framework.Block) + self._learning_rate_range = to_tensor( + self._learning_rate_range, dtype=dtype + ) + self._etas = to_tensor(self._etas, dtype=dtype) + + def _create_accumulators(self, block, parameters): + assert isinstance(block, framework.Block) + if isinstance(parameters, dict): + parameters = self._update_param_group(parameters) + + # Create accumulator tensors for first and second moments + for p in parameters: + if p.name in self._already_create_accumulater: + continue + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): + master_p = self._create_master_weight(p) + self._add_accumulator( + self._prevs_acc_str, + master_p, + p.dtype, + 0, + ) + self._add_accumulator( + self._learning_rates_acc_str, + master_p, + p.dtype, + self._initial_learning_rate, + ) + self._already_create_accumulater.add(p.name) + continue + if ( + self._is_dtype_fp16_or_bf16(p.dtype) + and not self._multi_precision + ): + warnings.warn( + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." + "Consider using multi_precision=True option of the Adam optimizer." 
+ ) + self._add_accumulator( + self._prevs_acc_str, + p, + p.dtype, + 0, + ) + self._add_accumulator( + self._learning_rates_acc_str, + p, + p.dtype, + fill_value=self._initial_learning_rate, + ) + self._already_create_accumulater.add(p.name) + + @no_grad + def _append_optimize_op(self, block, param_and_grad): + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) + + if self._sign: + self._to_tensor(block, param_and_grad[0][0].dtype) + self._sign = False + + prevs = self._get_accumulator_master( + self._prevs_acc_str, param_and_grad[0] + ) + + learning_rates = self._get_accumulator_master( + self._learning_rates_acc_str, param_and_grad[0] + ) + + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype + ) + master_weight = ( + self._master_weights[param_and_grad[0].name] + if find_master + else None + ) + + if in_dynamic_or_pir_mode(): + _C_ops.rprop_( + param_and_grad[0], + param_and_grad[1], + prevs, + learning_rates, + master_weight, + self._learning_rate_range, + self._etas, + find_master, + ) + + return None + else: + assert isinstance(block, framework.Block) + # create the optimize op + inputs = { + "param": param_and_grad[0], + "grad": param_and_grad[1], + "prev": prevs, + "learning_rate": learning_rates, + "learning_rate_range": self._learning_rate_range, + "etas": self._etas, + } + + outputs = { + "param_out": param_and_grad[0], + "prev_out": prevs, + "learning_rate_out": learning_rates, + } + + attrs = {"multi_precision": find_master} + + if find_master: + inputs["master_param"] = master_weight + outputs["master_param_out"] = master_weight + + rprop_op = block.append_op( + type=self.type, + inputs=inputs, + outputs=outputs, + attrs=attrs, + stop_gradient=True, + ) + + return rprop_op + + def _update_param_group(self, parameters): + parameters = parameters.get('params') + return parameters diff --git a/test/legacy_test/test_rprop_op.py b/test/legacy_test/test_rprop_op.py new file mode 100644 index 00000000000000..e92203897746d8 --- /dev/null +++ b/test/legacy_test/test_rprop_op.py @@ -0,0 +1,455 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from op_test import ( + OpTest, + convert_float_to_uint16, +) +from utils import dygraph_guard + +import paddle +from paddle.base import core + +paddle.enable_static() + + +def rprop_wrapper( + param, + grad, + prev, + learning_rate, + master_param=None, + learning_rate_range=np.array((1e-5, 50)).astype("float32"), + etas=np.array((0.5, 1.2)).astype("float32"), + multi_precision=False, +): + paddle._C_ops.rprop_( + param, + grad, + prev, + learning_rate, + master_param, + learning_rate_range, + etas, + multi_precision, + ) + + +class TestRpropOp(OpTest): + def setUp(self): + self.op_type = "rprop" + self.python_api = rprop_wrapper + self.python_out_sig = ['Out'] + self.conf() + params = np.random.random((self.h, self.w)).astype("float32") + grads = np.random.random((self.h, self.w)).astype("float32") + prevs = np.random.random((self.h, self.w)).astype("float32") + learning_rates = np.random.random((self.h, self.w)).astype("float32") + + scale = 0.01 + np.subtract(params, 0.5, out=params) + np.multiply(params, scale, out=params) + np.subtract(grads, 0.5, out=grads) + np.multiply(grads, scale, out=grads) + np.subtract(prevs, 0.5, out=prevs) + np.multiply(prevs, scale, out=prevs) + np.multiply(learning_rates, scale, out=learning_rates) + + learning_rate_min = 0.1 * scale + learning_rate_max = 0.9 * scale + eta_negative = 0.5 + eta_positive = 1.2 + + param_outs = params.copy() + prev_outs = prevs.copy() + learning_rate_outs = learning_rates.copy() + + for i, param in enumerate(params): + grad = grads[i] + prev = prevs[i] + lr = learning_rate_outs[i] + param_out = param_outs[i] + prev_out = prev_outs[i] + + sign = np.sign(np.multiply(grad, prev)) + sign[np.greater(sign, 0)] = eta_positive + sign[np.less(sign, 0)] = eta_negative + sign[np.equal(sign, 0)] = 1 + np.multiply(lr, sign, out=lr) + lr[np.less(lr, learning_rate_min)] = learning_rate_min + lr[np.greater(lr, learning_rate_max)] = learning_rate_max + + grad = grad.copy() + grad[np.equal(sign, eta_negative)] = 0 + + learning_rate_outs[i] = lr + param_outs[i] = np.subtract( + param_out, np.multiply(np.sign(grad), lr) + ) + prev_outs[i] = grad.copy() + + self.inputs = { + "param": params, + "grad": grads, + "prev": prevs, + "learning_rate": learning_rates, + "learning_rate_range": np.array( + (learning_rate_min, learning_rate_max) + ).astype("float32"), + "etas": np.array((0.5, 1.2)).astype("float32"), + } + + self.outputs = { + "param_out": param_outs, + "prev_out": prev_outs, + "learning_rate_out": learning_rate_outs, + } + + def conf(self): + self.h = 102 + self.w = 105 + + def test_check_output(self): + self.check_output(check_pir=True) + + +class TestRpropOpCase8X(TestRpropOp): + def conf(self): + self.h = 10 + self.w = 64 + + +class TestRpropV2(unittest.TestCase): + def test_rprop_dygraph(self): + paddle.disable_static() + value = np.arange(26).reshape(1, 26).astype("float32") + a = paddle.to_tensor(value) + linear = paddle.nn.Linear(26, 5) + + rprop = paddle.optimizer.Rprop( + learning_rate=0.01, + parameters=linear.parameters(), + ) + out = linear(a) + out.backward() + rprop.step() + rprop.clear_gradients() + + def test_raise_error(self): + self.assertRaises( + ValueError, paddle.optimizer.Rprop, learning_rate=None + ) + self.assertRaises( + ValueError, + paddle.optimizer.Rprop, + learning_rate=1e-3, + learning_rate_range=np.array((1e-2, 1e-1)).astype("float32"), + ) + self.assertRaises( + ValueError, + paddle.optimizer.Rprop, + learning_rate=1e-3, + etas=np.array((-0.1, 
1.1)).astype("float32"), + ) + + def test_rprop_group_dygraph(self): + paddle.disable_static() + value = np.arange(26).reshape(1, 26).astype("float32") + a = paddle.to_tensor(value) + linear_1 = paddle.nn.Linear(26, 5) + linear_2 = paddle.nn.Linear(5, 3) + rprop = paddle.optimizer.Rprop( + learning_rate=0.01, + parameters=[ + {'params': linear_1.parameters()}, + { + 'params': linear_2.parameters(), + 'learning_rate': 0.1, + }, + ], + ) + out = linear_1(a) + out = linear_2(out) + out.backward() + rprop.step() + rprop.clear_gradients() + + +class TestRpropMultiPrecision2_0(unittest.TestCase): + def dygraph_rprop_mp(self, mp): + paddle.disable_static() + paddle.seed(10) + paddle.set_device('gpu') + input = paddle.randn((2, 2)) + model = paddle.nn.Linear(2, 2) + optimizer = paddle.optimizer.Rprop( + parameters=model.parameters(), multi_precision=mp + ) + if mp: + model = paddle.amp.decorate(models=model, level='O2') + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + for idx in range(5): + if mp: + with paddle.amp.auto_cast(level='O2'): + output = model(input) + loss = paddle.mean(output) + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(optimizer, scaled) + optimizer.clear_grad() + else: + output = model(input) + loss = paddle.mean(output) + optimizer.step() + optimizer.clear_grad() + + return output, model.parameters() + + def static_rprop_mp(self, mp): + paddle.enable_static() + paddle.seed(10) + np.random.seed(10) + exe = paddle.static.Executor('gpu') + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + optimizer = paddle.optimizer.Rprop(multi_precision=mp) + + if mp: + optimizer = paddle.static.amp.decorate( + optimizer, + init_loss_scaling=128.0, + use_dynamic_loss_scaling=True, + use_pure_fp16=True, + use_fp16_guard=False, + ) + with paddle.static.program_guard(train_program, startup_program): + if mp: + data = paddle.static.data( + shape=[2, 2], name='X', dtype='float16' + ) + else: + data = paddle.static.data( + shape=[2, 2], name='X', dtype='float32' + ) + hidden = paddle.static.nn.fc(x=data, size=10) + loss = paddle.mean(hidden) + optimizer.minimize(loss) + exe.run(startup_program) + + if mp: + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) + x = np.random.random(size=(2, 2)).astype('float16') + else: + x = np.random.random(size=(2, 2)).astype('float32') + out = [] + for idx in range(5): + (loss_data,) = exe.run( + train_program, feed={"X": x}, fetch_list=[loss.name] + ) + out.append(loss_data) + return out + + def test_main(self): + if not paddle.is_compiled_with_cuda(): + return + "Test dygraph mode" + output1_dy, params1_dy = self.dygraph_rprop_mp(mp=True) + output2_dy, params2_dy = self.dygraph_rprop_mp(mp=False) + np.testing.assert_allclose( + output1_dy.astype('float32').numpy(), + output2_dy.astype('float32').numpy(), + rtol=1e-05, + atol=0.1, + ) + for idx in range(len(params1_dy)): + np.testing.assert_allclose( + params1_dy[idx].astype('float32').numpy(), + params2_dy[idx].astype('float32').numpy(), + rtol=1e-05, + atol=0.1, + ) + "Test static graph mode" + output1_st = self.static_rprop_mp(mp=True) + output2_st = self.static_rprop_mp(mp=False) + for idx in range(len(output1_st)): + np.testing.assert_allclose( + output1_st[idx].astype('float32'), + output2_st[idx].astype('float32'), + rtol=1e-05, + atol=0.1, + ) + + +class TestRpropSimple(unittest.TestCase): + def setUp(self) -> None: + self.data = np.random.random(size=(2, 2)).astype('float32') + + def run_static(self): + 
with paddle.pir_utils.IrGuard(): + paddle.seed(10) + np.random.seed(10) + + exe = paddle.static.Executor('gpu') + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + + with paddle.static.program_guard(train_program, startup_program): + input = paddle.static.data( + shape=[2, 2], name='input', dtype='float32' + ) + model = paddle.nn.Linear(2, 2) + output = model(input) + loss = paddle.mean(output) + + optimizer = paddle.optimizer.Rprop() + optimizer.minimize(loss) + + exe.run(startup_program) + + out = [] + for _ in range(5): + (loss_data,) = exe.run( + train_program, feed={"input": self.data}, fetch_list=[loss] + ) + out.append(loss_data) + return out + + def run_dygraph(self): + with dygraph_guard(): + paddle.seed(10) + np.random.seed(10) + + out = [] + model = paddle.nn.Linear(2, 2) + optimizer = paddle.optimizer.Rprop(parameters=model.parameters()) + for _ in range(5): + output = model(paddle.to_tensor(self.data)) + loss = paddle.mean(output) + out.append(loss.numpy()) + loss.backward() + optimizer.step() + optimizer.clear_grad() + + return out + + def test_main(self): + if not paddle.is_compiled_with_cuda(): + return + out1 = self.run_dygraph() + out2 = self.run_static() + np.testing.assert_allclose(out1, out2) + + +@unittest.skipIf( + not core.supports_bfloat16(), 'place does not support BF16 evaluation' +) +class TestRpropOpBF16(OpTest): + def setUp(self): + self.op_type = "rprop" + self.dtype = np.uint16 + self.use_mkldnn = True + self.conf() + params = np.random.random((self.h, self.w)).astype("float32") + grads = np.random.random((self.h, self.w)).astype("float32") + prevs = np.random.random((self.h, self.w)).astype("float32") + learning_rates = np.random.random((self.h, self.w)).astype("float32") + + scale = 0.01 + np.subtract(params, 0.5, out=params) + np.multiply(params, scale, out=params) + np.subtract(grads, 0.5, out=grads) + np.multiply(grads, scale, out=grads) + np.subtract(prevs, 0.5, out=prevs) + np.multiply(prevs, scale, out=prevs) + np.multiply(learning_rates, scale, out=learning_rates) + + learning_rate_min = 0.1 * scale + learning_rate_max = 0.9 * scale + eta_negative = 0.5 + eta_positive = 1.2 + + param_outs = params.copy() + prev_outs = prevs.copy() + learning_rate_outs = learning_rates.copy() + + for i, param in enumerate(params): + grad = grads[i] + prev = prevs[i] + lr = learning_rate_outs[i] + param_out = param_outs[i] + prev_out = prev_outs[i] + + sign = np.sign(np.multiply(grad, prev)) + sign[np.greater(sign, 0)] = eta_positive + sign[np.less(sign, 0)] = eta_negative + sign[np.equal(sign, 0)] = 1 + np.multiply(lr, sign, out=lr) + lr[np.less(lr, learning_rate_min)] = learning_rate_min + lr[np.greater(lr, learning_rate_max)] = learning_rate_max + + grad = grad.copy() + grad[np.equal(sign, eta_negative)] = 0 + + learning_rate_outs[i] = lr + param_outs[i] = np.subtract( + param_out, np.multiply(np.sign(grad), lr) + ) + prev_outs[i] = grad.copy() + + learning_rate_range = np.array( + (learning_rate_min, learning_rate_max) + ).astype("float32") + etas = np.array((0.5, 1.2)).astype("float32") + + params_bf16 = convert_float_to_uint16(params) + grads_bf16 = convert_float_to_uint16(grads) + prevs_bf16 = convert_float_to_uint16(prevs) + learning_rates_bf16 = convert_float_to_uint16(learning_rates) + learning_rate_range_bf16 = convert_float_to_uint16(learning_rate_range) + etas_bf16 = convert_float_to_uint16(etas) + + param_outs_bf16 = convert_float_to_uint16(param_outs) + prev_outs_bf16 = convert_float_to_uint16(prev_outs) + 
learning_rate_outs_bf16 = convert_float_to_uint16(learning_rate_outs) + + self.inputs = { + "param": params_bf16, + "grad": grads_bf16, + "prev": prevs_bf16, + "learning_rate": learning_rates_bf16, + "learning_rate_range": learning_rate_range_bf16, + "etas": etas_bf16, + } + + self.outputs = { + "param_out": param_outs_bf16, + "prev_out": prev_outs_bf16, + "learning_rate_out": learning_rate_outs_bf16, + } + + def conf(self): + self.h = 102 + self.w = 105 + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace(), check_dygraph=False) + + +if __name__ == "__main__": + unittest.main()
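
For reference, the update rule implemented by `RpropKernelCPUImpl`, `RpropKernelGPUImpl`, and the per-row reference loops in the tests above can be summarised in a few lines of NumPy. This sketch is illustrative only and is not part of the patch: `rprop_step` and its argument names are hypothetical, and it assumes `param`, `grad`, `prev_grad`, and `lr` are float arrays of the same shape.

```python
import numpy as np


def rprop_step(param, grad, prev_grad, lr,
               lr_range=(1e-5, 50.0), etas=(0.5, 1.2)):
    """One full-batch Rprop update, mirroring the rprop kernels above."""
    eta_negative, eta_positive = etas
    lr_min, lr_max = lr_range

    product = grad * prev_grad               # compare signs with the previous step
    eta = np.ones_like(lr)                   # product == 0 keeps the step size
    eta[product > 0] = eta_positive          # same sign: grow the per-element step
    eta[product < 0] = eta_negative          # sign flipped: shrink the step ...
    grad = np.where(product < 0, 0.0, grad)  # ... and suppress this update

    lr = np.clip(lr * eta, lr_min, lr_max)   # clamp to learning_rate_range
    param = param - np.sign(grad) * lr       # move by step size, against grad sign
    return param, grad, lr                   # grad becomes prev_grad next step
```

Starting from a zero `prev_grad` and an `lr` array filled with the initial learning rate, iterating this function should track what `paddle.optimizer.Rprop` computes per parameter tensor, up to floating-point and multi-precision details.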