diff --git a/examples/language_model/gpt-3/README.md b/examples/language_model/gpt-3/README.md index fee001f9a4ac..d58d2a15a16a 100644 --- a/examples/language_model/gpt-3/README.md +++ b/examples/language_model/gpt-3/README.md @@ -10,6 +10,18 @@ GPT-[3](https://arxiv.org/pdf/2005.14165.pdf) 是以[Transformer](https://arxiv. ## 使用方法 +### 环境依赖 + +- regex +- sentencepiece +- tqdm +- visualdl +- paddlepaddle-gpu >= 2.2rc + +安装命令 `pip install regex sentencepiece tqdm visualdl`。 +注:需要PaddlePaddle版本大于等于2.2rc,或者使用最新develop版本,安装方法请参见Paddle[官网](https://www.paddlepaddle.org.cn)。 + + ```shell cd static # 或者 cd dygraph # 下载样例数据 diff --git a/examples/language_model/gpt-3/static/run_pretrain_static.py b/examples/language_model/gpt-3/static/run_pretrain_static.py index 4415796521e3..6a14fa83a5ee 100644 --- a/examples/language_model/gpt-3/static/run_pretrain_static.py +++ b/examples/language_model/gpt-3/static/run_pretrain_static.py @@ -294,33 +294,17 @@ def do_train(args): p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"]) ] - # TODO @ZHUI Use paddle.optimizer.AdamW - if ops.optimizer._jit_compile(): - logger.info("Using paddlenlp custom AdamW optimizer.") - optimizer = ops.optimizer.AdamwOptimizer( - learning_rate=lr_scheduler, - beta1=args.adam_beta1, - beta2=args.adam_beta2, - epsilon=args.adam_epsilon, - grad_clip=clip, - weight_decay=args.weight_decay, - apply_decay_param_fun=lambda x: x in decay_param) - else: - if args.sharding_degree > 1: - raise ValueError( - "The paddle.optimizer.AdamW not compatible with Sharding!" - ) - logger.info("Using paddle.optimizer.AdamW.") - optimizer = paddle.optimizer.AdamW( - learning_rate=lr_scheduler, - beta1=args.adam_beta1, - beta2=args.adam_beta2, - epsilon=args.adam_epsilon, - grad_clip=clip, - weight_decay=args.weight_decay, - apply_decay_param_fun=lambda x: x in decay_param) - # alias - optimizer.apply_optimize = optimizer._apply_optimize + + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + beta1=args.adam_beta1, + beta2=args.adam_beta2, + epsilon=args.adam_epsilon, + grad_clip=clip, + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in decay_param) + # alias + optimizer.apply_optimize = optimizer._apply_optimize if args.use_recompute: dist_strategy.recompute = True @@ -341,12 +325,12 @@ def do_train(args): if not os.path.isdir(program_desc_dir): os.mkdir(program_desc_dir) - with open(program_desc_dir + "/main_program.txt.%d" % - (int(os.environ.get('FLAGS_selected_gpus', 0))), 'w') as f: + with open(program_desc_dir + "/main_program.txt.%d" % worker_index, + 'w') as f: f.write(str(main_program)) - with open(program_desc_dir + "/startup_program.txt.%d" % - (int(os.environ.get('FLAGS_selected_gpus', 0))), 'w') as f: + with open(program_desc_dir + "/startup_program.txt.%d" % worker_index, + 'w') as f: f.write(str(startup_program)) # Define the Executor for running the static model diff --git a/examples/language_model/gpt/README.md b/examples/language_model/gpt/README.md index 3b7496df212b..4bb3ec5cd872 100644 --- a/examples/language_model/gpt/README.md +++ b/examples/language_model/gpt/README.md @@ -26,11 +26,15 @@ GPT-[2](https://cdn.openai.com/better-language-models/language_models_are_unsupe ## 快速开始 ### 环境依赖 + - regex - sentencepiece - tqdm - visualdl -安装命令 `pip install regex sentencepiece tqdm visualdl` +- paddlepaddle-gpu >= 2.2rc + +安装命令 `pip install regex sentencepiece tqdm visualdl`。 +注:需要PaddlePaddle版本大于等于2.2rc,或者使用最新develop版本,安装方法请参见Paddle[官网](https://www.paddlepaddle.org.cn)。 ### 数据准备 diff --git a/examples/language_model/gpt/run_pretrain_static.py b/examples/language_model/gpt/run_pretrain_static.py index ba96d1d8a28f..27649c9405d4 100644 --- a/examples/language_model/gpt/run_pretrain_static.py +++ b/examples/language_model/gpt/run_pretrain_static.py @@ -294,33 +294,18 @@ def do_train(args): p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"]) ] - # TODO @ZHUI Use paddle.optimizer.AdamW - if ops.optimizer._jit_compile(): - logger.info("Using paddlenlp custom AdamW optimizer.") - optimizer = ops.optimizer.AdamwOptimizer( - learning_rate=lr_scheduler, - beta1=args.adam_beta1, - beta2=args.adam_beta2, - epsilon=args.adam_epsilon, - grad_clip=clip, - weight_decay=args.weight_decay, - apply_decay_param_fun=lambda x: x in decay_param) - else: - if args.sharding_degree > 1: - raise ValueError( - "The paddle.optimizer.AdamW not compatible with Sharding!" - ) - logger.info("Using paddle.optimizer.AdamW.") - optimizer = paddle.optimizer.AdamW( - learning_rate=lr_scheduler, - beta1=args.adam_beta1, - beta2=args.adam_beta2, - epsilon=args.adam_epsilon, - grad_clip=clip, - weight_decay=args.weight_decay, - apply_decay_param_fun=lambda x: x in decay_param) - # alias - optimizer.apply_optimize = optimizer._apply_optimize + + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + beta1=args.adam_beta1, + beta2=args.adam_beta2, + epsilon=args.adam_epsilon, + grad_clip=clip, + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in decay_param) + + # alias + optimizer.apply_optimize = optimizer._apply_optimize if args.use_recompute: dist_strategy.recompute = True @@ -341,12 +326,12 @@ def do_train(args): if not os.path.isdir(program_desc_dir): os.mkdir(program_desc_dir) - with open(program_desc_dir + "/main_program.txt.%d" % - (int(os.environ.get('FLAGS_selected_gpus', 0))), 'w') as f: + with open(program_desc_dir + "/main_program.txt.%d" % worker_index, + 'w') as f: f.write(str(main_program)) - with open(program_desc_dir + "/startup_program.txt.%d" % - (int(os.environ.get('FLAGS_selected_gpus', 0))), 'w') as f: + with open(program_desc_dir + "/startup_program.txt.%d" % worker_index, + 'w') as f: f.write(str(startup_program)) # Define the Executor for running the static model diff --git a/paddlenlp/ops/optimizer/AdamwOptimizer.py b/paddlenlp/ops/optimizer/AdamwOptimizer.py deleted file mode 100644 index 7592323087f5..000000000000 --- a/paddlenlp/ops/optimizer/AdamwOptimizer.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -try: - from collections.abc import Callable # noqa -except ImportError: - from collections import Callable # noqa - -import paddle -from paddle.fluid.framework import Variable -from paddle.fluid import framework -from paddle.fluid import layers -from paddle.fluid.layers import ops -from paddle.fluid import core -from paddle.fluid.optimizer import Optimizer - -__all__ = ['AdamwOptimizer', ] - - -class AdamwOptimizer(Optimizer): - r""" - The AdamW optimizer is implemented based on the AdamW Optimization - in paper `DECOUPLED WEIGHT DECAY REGULARIZATION `_. - it can resolves the problem of L2 regularization failure in the Adam optimizer. - .. math:: - t & = t + 1 - moment\_1\_out & = {\\beta}_1 * moment\_1 + (1 - {\\beta}_1) * grad - moemnt\_2\_out & = {\\beta}_2 * moment\_2 + (1 - {\\beta}_2) * grad * grad - learning\_rate & = learning\_rate * \\ - \\frac{\sqrt{1 - {\\beta}_2^t}}{1 - {beta}_1^t} - param\_out & = param - learning\_rate * (\\frac{moment\_1}{\sqrt{moment\_2} + \epsilon} + \lambda * param) - - Args: - learning_rate (float|Variable, optional): The learning rate used to update ``Parameter``. - It can be a float value or a ``Variable`` with a float type. The default value is 0.001. - beta1 (float|Variable, optional): The exponential decay rate for the 1st moment estimates. - It should be a float number or a Variable with shape [1] and data type as float32. - The default value is 0.9. - beta2 (float|Variable, optional): The exponential decay rate for the 2nd moment estimates. - It should be a float number or a Variable with shape [1] and data type as float32. - The default value is 0.999. - epsilon (float, optional): A small float value for numerical stability. - The default value is 1e-08. - parameter_list (Iterable, optional): Iterable of ``Variable`` names to update to minimize ``loss``. \ - This parameter is required in dygraph mode. \ - The default value is None in static mode, at this time all parameters will be updated. - weight_decay (float, optional): The weight decay coefficient, it can be float or Tensor. The default value is 0.01. - apply_decay_param_fun (function|None, optional): If it is not None, - only tensors that makes apply_decay_param_fun(Tensor.name)==True - will be updated. It only works when we want to specify tensors. - Default: None. - regularization (WeightDecayRegularizer, optional): The strategy of regularization. There are two method: \ - :ref:`api_fluid_regularizer_L1Decay` , :ref:`api_fluid_regularizer_L2Decay` . If a parameter has set \ - regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ - ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ - Default None, meaning there is no regularization. - grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of - some derived class of ``GradientClipBase`` . There are three cliping strategies - ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , - :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. - name (str, optional): Normally there is no need for user to set this property. - For more information, please refer to :ref:`api_guide_Name`. - The default value is None. - lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators. - The accumulators are updated at every step. Every element of the two moving-average - is updated in both dense mode and sparse mode. If the size of parameter is very large, - then the update may be very slow. The lazy mode only update the element that has - gradient in current mini-batch, so it will be much more faster. But this mode has - different semantics with the original Adam algorithm and may lead to different result. - The default value is False. - - """ - _moment1_acc_str = "moment1" - _moment2_acc_str = "moment2" - _beta1_pow_acc_str = "beta1_pow_acc" - _beta2_pow_acc_str = "beta2_pow_acc" - - def __init__(self, - learning_rate=0.001, - beta1=0.9, - beta2=0.999, - epsilon=1e-8, - parameter_list=None, - regularization=None, - grad_clip=None, - weight_decay=None, - lr_ratio=None, - apply_decay_param_fun=None, - name=None, - lazy_mode=False): - assert learning_rate is not None - assert beta1 is not None - assert beta2 is not None - assert epsilon is not None - super(AdamwOptimizer, self).__init__( - learning_rate=learning_rate, - parameter_list=parameter_list, - regularization=regularization, - grad_clip=grad_clip, - name=name) - self.type = "adamw" - self._beta1 = beta1 - self._beta2 = beta2 - self._epsilon = epsilon - self._lazy_mode = lazy_mode - self._weight_decay = weight_decay - self._apply_decay_param_fun = apply_decay_param_fun - if lr_ratio is not None: - assert isinstance(lr_ratio, Callable) - self._lr_ratio = lr_ratio - - def _create_accumulators(self, block, parameters): - assert isinstance(block, framework.Block) - - # Create accumulator tensors for first and second moments - for p in parameters: - self._add_accumulator(self._moment1_acc_str, p) - self._add_accumulator(self._moment2_acc_str, p) - self._add_accumulator( - name=self._beta1_pow_acc_str, - param=p, - fill_value=0.9 if isinstance(self._beta1, Variable) \ - else self._beta1, - shape=[1], - type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') - self._add_accumulator( - name=self._beta2_pow_acc_str, - param=p, - fill_value=0.999 if isinstance(self._beta2, Variable) \ - else self._beta2, - shape=[1], - type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') - - def _append_optimize_op(self, block, param_and_grad): - assert isinstance(block, framework.Block) - - moment1 = self._get_accumulator(self._moment1_acc_str, - param_and_grad[0]) - moment2 = self._get_accumulator(self._moment2_acc_str, - param_and_grad[0]) - beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, - param_and_grad[0]) - beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, - param_and_grad[0]) - lr = self._create_param_lr(param_and_grad) - - # create the adam optimize op - if self._apply_decay_param_fun is not None \ - and not self._apply_decay_param_fun(param_and_grad[0].name): - weight_decay = 0.0 - else: - weight_decay = self._weight_decay - - if framework.in_dygraph_mode(): - _beta1 = self._beta1 if not isinstance( - self._beta1, Variable) else self._beta1.numpy().item(0) - _beta2 = self._beta2 if not isinstance( - self._beta2, Variable) else self._beta2.numpy().item(0) - ins = { - 'Param': param_and_grad[0], - 'Grad': param_and_grad[1], - 'LearningRate': lr, - 'Moment1': moment1, - 'Moment2': moment2, - 'Beta1Pow': beta1_pow_acc, - 'Beta2Pow': beta2_pow_acc, - } - attrs = { - 'beta1': _beta1, - 'beta2': _beta2, - 'epsilon': self._epsilon, - 'lazy_mode': self._lazy_mode, - 'min_row_size_to_use_multithread': 1000, - 'multi_precision': False, - 'weight_decay': weight_decay, - 'lr_ratio': 1.0 - } - outs = { - 'ParamOut': param_and_grad[0], - 'Moment1Out': moment1, - 'Moment2Out': moment2, - 'Beta1PowOut': beta1_pow_acc, - 'Beta2PowOut': beta2_pow_acc, - } - - framework._dygraph_tracer().trace_op( - type="adamw", inputs=ins, outputs=outs, attrs=attrs) - - return None - - inputs = { - "Param": [param_and_grad[0]], - "Grad": [param_and_grad[1]], - "LearningRate": [lr], - "Moment1": [moment1], - "Moment2": [moment2], - "Beta1Pow": [beta1_pow_acc], - "Beta2Pow": [beta2_pow_acc] - } - outputs = { - "ParamOut": [param_and_grad[0]], - "Moment1Out": [moment1], - "Moment2Out": [moment2], - "Beta1PowOut": [beta1_pow_acc], - "Beta2PowOut": [beta2_pow_acc], - } - attrs = { - "epsilon": self._epsilon, - "lazy_mode": self._lazy_mode, - "min_row_size_to_use_multithread": 1000, - "weight_decay": weight_decay, - "lr_ratio": 1. - if self._lr_ratio is None else self._lr_ratio(param_and_grad[0]) - } - - if isinstance(self._beta1, Variable): - inputs['Beta1Tensor'] = self._beta1 - else: - attrs['beta1'] = self._beta1 - if isinstance(self._beta2, Variable): - inputs['Beta2Tensor'] = self._beta2 - else: - attrs['beta2'] = self._beta2 - - for name in ["Beta1Tensor", "Beta2Tensor", "MasterParam"]: - if name in inputs: - raise ValueError("Custom Adam should NOT have input: {}".format( - name)) - - adam_op = block.append_op( - type=self.type, - inputs=inputs, - outputs=outputs, - attrs=attrs, - stop_gradient=True) - - return adam_op diff --git a/paddlenlp/ops/optimizer/__init__.py b/paddlenlp/ops/optimizer/__init__.py index 92f3b0fb601c..b6a41f8147d5 100644 --- a/paddlenlp/ops/optimizer/__init__.py +++ b/paddlenlp/ops/optimizer/__init__.py @@ -12,37 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from paddle.utils.cpp_extension import load -from .adamw import AdamW -from .AdamwOptimizer import AdamwOptimizer from .adamwdl import AdamWDL - -def _jit_compile(): - try: - load( - name="custom_jit_ops", - sources=[ - os.path.join(os.path.dirname(__file__), x) - for x in [ - "adamw.cc", - "adamw.cu", - ] - ]) - return True - except RuntimeError as e: - import sys - sys.stderr.write(str(e) + "\n\n") - sys.stderr.write( - '''Warning with compile custom ops: compile custom adamw op failed. \nIf you do not use custom ops, please ignore this warning! \n\n''' - ) - return False - - -__all__ = [ - '_jit_compile', - 'AdamW', - 'AdamwOptimizer', - 'AdamWDL', -] +__all__ = ['AdamWDL', ] diff --git a/paddlenlp/ops/optimizer/adamw.cc b/paddlenlp/ops/optimizer/adamw.cc deleted file mode 100644 index 418c0c2d6c42..000000000000 --- a/paddlenlp/ops/optimizer/adamw.cc +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "paddle/extension.h" - -std::vector adam_cuda_forward( - // Tensor inputs - const paddle::Tensor& Param, - const paddle::Tensor& Grad, - const paddle::Tensor& LearningRate, - const paddle::Tensor& Moment1, - const paddle::Tensor& Moment2, - const paddle::Tensor& Beta1Pow, - const paddle::Tensor& Beta2Pow, - // const paddle::Tensor& Beta1Tensor, - // const paddle::Tensor& Beta2Tensor, - // const paddle::Tensor& MasterParam, - - // Attrs inputs - float beta1, - float beta2, - float epsilon, - bool lazy_mode, - int64_t min_row_size_to_use_multithread, - bool multi_precision, - float weight_decay, - float lr_ratio); - -std::vector AdamForward( - // Tensor inputs - const paddle::Tensor& Param, - const paddle::Tensor& Grad, - const paddle::Tensor& LearningRate, - const paddle::Tensor& Moment1, - const paddle::Tensor& Moment2, - const paddle::Tensor& Beta1Pow, - const paddle::Tensor& Beta2Pow, - // const paddle::Tensor& Beta1Tensor, - // const paddle::Tensor& Beta2Tensor, - // const paddle::Tensor& MasterParam, - - // Attrs inputs - float beta1, - float beta2, - float epsilon, - bool lazy_mode, - int64_t min_row_size_to_use_multithread, - bool multi_precision, - float weight_decay, - float lr_ratio) { - // TODO: Check Input - if (Param.place() == paddle::PlaceType::kCPU) { - PD_THROW("Not implemented."); - } else if (Param.place() == paddle::PlaceType::kGPU) { - return adam_cuda_forward(Param, - Grad, - LearningRate, - Moment1, - Moment2, - Beta1Pow, - Beta2Pow, - beta1, - beta2, - epsilon, - lazy_mode, - min_row_size_to_use_multithread, - multi_precision, - weight_decay, - lr_ratio); - } else { - PD_THROW("Not implemented."); - } -} - - -std::vector> AdamInferShape( - std::vector param_shape, - std::vector grad_shape, - std::vector lr_shape, - std::vector m1_shape, - std::vector m2_shape, - std::vector b1_shape, - std::vector b2_shape) { - return {param_shape, m1_shape, m2_shape, b1_shape, b2_shape}; -} - -std::vector AdamInferDtype(paddle::DataType param_dtype, - paddle::DataType grad_dtype, - paddle::DataType lr_dtype, - paddle::DataType m1_dtype, - paddle::DataType m2_dtype, - paddle::DataType b1_dtype, - paddle::DataType b2_dtype) { - return {param_dtype, m1_dtype, m2_dtype, b1_dtype, b2_dtype}; -} - - -PD_BUILD_OP(adamw) - .Inputs({ - "Param", // "(Tensor) Input parameter" - "Grad", // "(Tensor) Input gradient" - "LearningRate", // "(Tensor) Learning rate" - "Moment1", // "(Tensor) Input first moment" - "Moment2", // "(Tensor) Input second moment" - "Beta1Pow", // "(Tensor) Input beta1 power accumulator" - "Beta2Pow", // "(Tensor) Input beta2 power accumulator" - // "Beta1Tensor", // "(Tensor, optional) If provided, Adam - // will use this as beta1, this has a higher priority than attr(beta1), - // the shape of this tensor MUST BE [1].").AsDispensable(); - // "Beta2Tensor", // "(Tensor, optional) If provided, Adam - // will use this as beta2, this has a higher priority than attr(beta2), - // the shape of this tensor MUST BE [1].").AsDispensable(); - // "MasterParam", // "FP32 master weight for AMP.").AsDispensable() - }) - .Outputs({ - "ParamOut", // "(Tensor) Output parameter"); - "Moment1Out", // "(Tensor) Output first moment"); - "Moment2Out", // "(Tensor) Output second moment"); - "Beta1PowOut", // "(Tensor) Output beta1 power accumulator"); - "Beta2PowOut", // "(Tensor) Output beta2 power accumulator"); - // "MasterParamOut" // "The updated FP32 master weight for AMP. It - // shared memory with Input(MasterParam).").AsDispensable(); - }) - .Attrs({ - "beta1: float", // "(float, default 0.9) " "Exponential decay rate for - // the ""first moment estimates.").SetDefault(0.9f); - "beta2: float", // "(float, default 0.999) ""exponential decay rate for - // the ""second moment estimates.").SetDefault(0.999f); - "epsilon: float", // "(float, default 1.0e-8) ""Constant for numerical - // stability").SetDefault(1.0e-8f); - "lazy_mode: bool", // "(bool, default false) ""only update the - // parameter that has gradient in sparse - // update").SetDefault(false); - "min_row_size_to_use_multithread: int64_t", // "(int64_t, default 0) - // ""when not zero, if - // param row size is larger - // then - // ""min_row_size_to_use_multithread - // and - // ""inner_op_parallelism - // is larger then 0, sparse - // update ""will run in - // multithread - // mode").SetDefault(1000); - "multi_precision: bool", // "(bool, default false) ""Whether to use - // multi-precision during weight - // updating.").SetDefault(false); - "weight_decay: float", // "(float, default 0.0) ""Weight decay - // rate.").SetDefault(0.0f); - "lr_ratio: float", // "(float, default 1.0) ""Weight decay - // rate.").SetDefault(1.0f); - }) - .SetKernelFn(PD_KERNEL(AdamForward)) - .SetInferShapeFn(PD_INFER_SHAPE(AdamInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(AdamInferDtype)); diff --git a/paddlenlp/ops/optimizer/adamw.cu b/paddlenlp/ops/optimizer/adamw.cu deleted file mode 100644 index 1b635b164270..000000000000 --- a/paddlenlp/ops/optimizer/adamw.cu +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/extension.h" - - -template -__global__ void AdamKernelREG(MT beta1, - MT beta2, - MT epsilon, - MT beta1_pow_, - MT beta2_pow_, - const MT* moment1, - MT* moment1_out, - const MT* moment2, - MT* moment2_out, - const MT* lr_, - MT weight_decay, - MT lr_ratio, - const T* grad, - const T* param, - T* param_out, - const MT* master_param, - MT* master_param_out, - int ndim) { - MT lr = *lr_ * lr_ratio; - MT lr_orig = lr; - MT beta1_pow = beta1_pow_; - MT beta2_pow = beta2_pow_; - - lr *= sqrt(static_cast(1.0) - beta2_pow) / - (static_cast(1.0) - beta1_pow); - - int id = blockIdx.x * blockDim.x + threadIdx.x; - - for (; id < ndim; id += gridDim.x * blockDim.x) { - MT p = master_param ? master_param[id] : static_cast(param[id]); - MT g = static_cast(grad[id]); - MT mom1 = moment1[id]; - MT mom2 = moment2[id]; - mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; - mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - p -= lr_orig * weight_decay * p; - p -= lr * (mom1 / - (sqrt(mom2) + epsilon * sqrt(static_cast(1.0) - beta2_pow))); - - moment1_out[id] = mom1; - moment2_out[id] = mom2; - param_out[id] = static_cast(p); - if (master_param_out) { - master_param_out[id] = p; - } - } -} - -template -__global__ void AdamKernelMEM(MT beta1, - MT beta2, - MT epsilon, - const MT* beta1_pow_, - const MT* beta2_pow_, - const MT* moment1, - MT* moment1_out, - const MT* moment2, - MT* moment2_out, - const MT* lr_, - MT weight_decay, - MT lr_ratio, - const T* grad, - const T* param, - T* param_out, - const MT* master_param, - MT* master_param_out, - int ndim) { - MT lr = *lr_ * lr_ratio; - MT lr_orig = lr; - MT beta1_pow = *beta1_pow_; - MT beta2_pow = *beta2_pow_; - - lr *= sqrt(static_cast(1.0) - beta2_pow) / - (static_cast(1.0) - beta1_pow); - - int id = blockIdx.x * blockDim.x + threadIdx.x; - - for (; id < ndim; id += gridDim.x * blockDim.x) { - MT p = master_param ? master_param[id] : static_cast(param[id]); - MT g = static_cast(grad[id]); - MT mom1 = static_cast(moment1[id]); - MT mom2 = static_cast(moment2[id]); - mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; - mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - p -= lr_orig * weight_decay * p; - p -= lr * (mom1 / - (sqrt(mom2) + epsilon * sqrt(static_cast(1.0) - beta2_pow))); - - moment1_out[id] = mom1; - moment2_out[id] = mom2; - param_out[id] = static_cast(p); - if (master_param_out) { - master_param_out[id] = p; - } - } -} - -template -__global__ void UpdateBetaPow(T beta1, - T beta2, - const T* beta1_pow_, - const T* beta2_pow_, - T* beta1_pow_out, - T* beta2_pow_out) { - *beta1_pow_out = beta1 * beta1_pow_[0]; - *beta2_pow_out = beta2 * beta2_pow_[0]; -} - - -std::vector adam_cuda_forward( - // Tensor inputs - const paddle::Tensor& Param, - const paddle::Tensor& Grad, - const paddle::Tensor& LearningRate, - const paddle::Tensor& Moment1, - const paddle::Tensor& Moment2, - const paddle::Tensor& Beta1Pow, - const paddle::Tensor& Beta2Pow, - // const paddle::Tensor& Beta1Tensor, - // const paddle::Tensor& Beta2Tensor, - // const paddle::Tensor& MasterParam, - - // Attrs inputs - float beta1, - float beta2, - float epsilon, - bool lazy_mode, - int64_t min_row_size_to_use_multithread, - bool multi_precision, - float weight_decay, - float lr_ratio) { - auto ParamOut = paddle::Tensor(paddle::PlaceType::kGPU); - auto Moment1Out = paddle::Tensor(paddle::PlaceType::kGPU); - auto Moment2Out = paddle::Tensor(paddle::PlaceType::kGPU); - auto Beta1PowOut = paddle::Tensor(Beta1Pow.place()); - auto Beta2PowOut = paddle::Tensor(Beta2Pow.place()); - // auto MasterParamOut = paddle::Tensor(paddle::PlaceType::kGPU); - - ParamOut.reshape(Param.shape()); - Moment1Out.reshape(Moment1.shape()); - Moment2Out.reshape(Moment2.shape()); - Beta1PowOut.reshape(Beta1Pow.shape()); - Beta2PowOut.reshape(Beta2Pow.shape()); - - PD_CHECK(Beta1PowOut.size() == 1, - "beta1 pow output size should be 1, but received " - "value is:", - Beta1PowOut.size()); - PD_CHECK(Beta2PowOut.size() == 1, - "beta2 pow output size should be 1, but received " - "value is:", - Beta2PowOut.size()); - - PD_CHECK(Param.type() == paddle::DataType::FLOAT32, - "Custom adam support fp32 for now."); - - using T = float; - auto place = Param.place(); - T beta1_t = static_cast(beta1); - T beta2_t = static_cast(beta2); - T epsilon_t = static_cast(epsilon); - T weight_decay_t = static_cast(weight_decay); - T lr_ratio_t = static_cast(lr_ratio); - - int threads = 512; - int blocks = (Param.size() + threads - 1) / threads; - - auto Moment1Out_data = Moment1Out.mutable_data(place); - auto Moment2Out_data = Moment2Out.mutable_data(place); - auto ParamOut_data = ParamOut.mutable_data(place); - - if (Beta1Pow.place() == paddle::PlaceType::kCPU && - Beta2Pow.place() == paddle::PlaceType::kCPU) { - // Compute with betapow in REG - AdamKernelREG<<>>( - beta1_t, - beta2_t, - epsilon_t, - *Beta1Pow.data(), - *Beta2Pow.data(), - Moment1.data(), - Moment1Out_data, - Moment2.data(), - Moment2Out_data, - LearningRate.data(), - weight_decay_t, - lr_ratio_t, - Grad.data(), - Param.data(), - ParamOut_data, - nullptr, - nullptr, - Param.size()); - // Cpu update - Beta1PowOut.mutable_data(Beta1Pow.place())[0] = - beta1_t * Beta1Pow.data()[0]; - Beta2PowOut.mutable_data(Beta2Pow.place())[0] = - beta2_t * Beta2Pow.data()[0]; - } else { - // Compute with betapow in MEM - AdamKernelMEM<<>>( - beta1_t, - beta2_t, - epsilon_t, - Beta1Pow.data(), - Beta2Pow.data(), - Moment1.data(), - Moment1Out_data, - Moment2.data(), - Moment2Out_data, - LearningRate.data(), - weight_decay_t, - lr_ratio_t, - Grad.data(), - Param.data(), - ParamOut_data, - nullptr, - nullptr, - int(Param.size())); - // Update with gpu - UpdateBetaPow<<<1, 32, 0, Param.stream()>>>( - beta1_t, - beta2_t, - Beta1Pow.data(), - Beta2Pow.data(), - Beta1PowOut.mutable_data(place), - Beta2PowOut.mutable_data(place)); - } - - return {ParamOut, Moment1Out, Moment2Out, Beta1PowOut, Beta2PowOut}; -} diff --git a/paddlenlp/ops/optimizer/adamw.py b/paddlenlp/ops/optimizer/adamw.py deleted file mode 100644 index fd343172dd68..000000000000 --- a/paddlenlp/ops/optimizer/adamw.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -import os -import paddle -from paddle.optimizer.optimizer import Optimizer -from paddle.fluid import core -from paddle.fluid import framework -from paddle.fluid.framework import Variable -from paddle.fluid import layers -from paddle.fluid import unique_name -from paddle.fluid.framework import in_dygraph_mode, _dygraph_tracer -from paddle.fluid.layer_helper import LayerHelper -from paddle.fluid.dygraph import base as imperative_base - -__all__ = ["AdamW", ] - - -class AdamW(Optimizer): - r""" - The AdamW optimizer is implemented based on the AdamW Optimization - in paper `DECOUPLED WEIGHT DECAY REGULARIZATION `_. - it can resolves the problem of L2 regularization failure in the Adam optimizer. - .. math:: - t & = t + 1 - moment\_1\_out & = {\\beta}_1 * moment\_1 + (1 - {\\beta}_1) * grad - moemnt\_2\_out & = {\\beta}_2 * moment\_2 + (1 - {\\beta}_2) * grad * grad - learning\_rate & = learning\_rate * \\ - \\frac{\sqrt{1 - {\\beta}_2^t}}{1 - {beta}_1^t} - param\_out & = param - learning\_rate * (\\frac{moment\_1}{\sqrt{moment\_2} + \epsilon} + \lambda * param) - - Args: - learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``. - It can be a float value or a LRScheduler. The default value is 0.001. - beta1 (float, optional): The exponential decay rate for the 1st moment estimates. - It should be a float number or a Tensor with shape [1] and data type as float32. - The default value is 0.9. - beta2 (float, optional): The exponential decay rate for the 2nd moment estimates. - It should be a float number or a Tensor with shape [1] and data type as float32. - The default value is 0.999. - epsilon (float, optional): A small float value for numerical stability. - It should be a float number or a Tensor with shape [1] and data type as float32. - The default value is 1e-08. - parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \ - This parameter is required in dygraph mode. \ - The default value is None in static mode, at this time all parameters will be updated. - weight_decay (float, optional): The weight decay coefficient, it can be float or Tensor. The default value is 0.01. - apply_decay_param_fun (function|None, optional): If it is not None, - only tensors that makes apply_decay_param_fun(Tensor.name)==True - will be updated. It only works when we want to specify tensors. - Default: None. - grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of - some derived class of ``GradientClipBase`` . There are three cliping strategies - ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , - :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. - lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators. - The accumulators are updated at every step. Every element of the two moving-average - is updated in both dense mode and sparse mode. If the size of parameter is very large, - then the update may be very slow. The lazy mode only update the element that has - gradient in current mini-batch, so it will be much more faster. But this mode has - different semantics with the original Adam algorithm and may lead to different result. - The default value is False. - multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false. - name (str, optional): Normally there is no need for user to set this property. - For more information, please refer to :ref:`api_guide_Name`. - The default value is None. - - Examples: - .. code-block:: python - - import paddle - import paddlenlp - - linear = paddle.nn.Linear(10, 10) - inp = paddle.rand([10,10], dtype="float32") - out = linear(inp) - loss = paddle.mean(out) - adamw = paddlenlp.ops.optimizer.Adam(learning_rate=0.1, - parameters=linear.parameters()) - out.backward() - adamw.step() - adamw.clear_grad() - - """ - _moment1_acc_str = "moment1" - _moment2_acc_str = "moment2" - _beta1_pow_acc_str = "beta1_pow_acc" - _beta2_pow_acc_str = "beta2_pow_acc" - - def __init__(self, - learning_rate=0.001, - beta1=0.9, - beta2=0.999, - epsilon=1e-8, - parameters=None, - weight_decay=0.0, - grad_clip=None, - lazy_mode=False, - multi_precision=False, - apply_decay_param_fun=None, - name=None): - assert learning_rate is not None - assert beta1 is not None - assert beta2 is not None - assert epsilon is not None - if not isinstance(beta1, Variable): - if not 0 <= beta1 < 1: - raise ValueError( - "Invaild value of beta1, expect beta1 in [0,1).") - if not isinstance(beta2, Variable): - if not 0 <= beta2 < 1: - raise ValueError( - "Invaild value of beta2, expect beta2 in [0,1).") - if not isinstance(epsilon, Variable): - if not 0 <= epsilon: - raise ValueError( - "Invaild value of epsilon, expect epsilon >= 0.") - super(AdamW, self).__init__( - learning_rate=learning_rate, - parameters=parameters, - weight_decay=None, - grad_clip=grad_clip, - name=name) - self.type = "adamw" - self._beta1 = beta1 - self._beta2 = beta2 - self._epsilon = epsilon - self._lazy_mode = lazy_mode - self._multi_precision = multi_precision - self._weight_decay = weight_decay - self._apply_decay_param_fun = apply_decay_param_fun - self._master_weights = {} - - def _create_master_weight(self, param): - assert isinstance(self.helper, LayerHelper) - - var_name = param.name + "_fp32_master" - var_name = unique_name.generate(var_name) - var = layers.create_global_var( - name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True) - block = self.helper.startup_program.global_block() - block.append_op( - type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32 - }) - self._master_weights[param.name] = var - return var - - def _get_accumulator(self, name, param): - """Utility function to fetch an accumulator for a parameter - Args: - name: name of the accumulator - param: parameter variable for which accumulator is to be fetched - Returns: - accumulator variable for the parameter - """ - if self._name is not None: - name = self._name + "_" + name - find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 - target_param = self._master_weights[ - param.name] if find_master else param - target_name = target_param.name - if (name not in self._accumulators or - target_name not in self._accumulators[name]): - raise Exception("Accumulator {} does not exist for parameter {}". - format(name, target_name)) - return self._accumulators[name][target_name] - - def _add_moments_pows(self, p): - acc_dtype = p.dtype - if acc_dtype == core.VarDesc.VarType.FP16: - acc_dtype = core.VarDesc.VarType.FP32 - self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) - self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) - self._add_accumulator( - name=self._beta1_pow_acc_str, - param=p, - dtype=acc_dtype, - fill_value=0.9 if isinstance(self._beta1, Variable) \ - else self._beta1, - shape=[1], - type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') - self._add_accumulator( - name=self._beta2_pow_acc_str, - param=p, - dtype=acc_dtype, - fill_value=0.999 if isinstance(self._beta2, Variable) \ - else self._beta2, - shape=[1], - type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') - - def _create_accumulators(self, block, parameters): - assert isinstance(block, framework.Block) - - # Create accumulator tensors for first and second moments - for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: - master_p = self._create_master_weight(p) - self._add_moments_pows(master_p) - continue - if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision: - warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." - "Consider using multi_precision=True option of the Adam optimizer." - ) - self._add_moments_pows(p) - - def _append_optimize_op(self, block, param_and_grad): - assert isinstance(block, framework.Block) - - moment1 = self._get_accumulator(self._moment1_acc_str, - param_and_grad[0]) - moment2 = self._get_accumulator(self._moment2_acc_str, - param_and_grad[0]) - beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, - param_and_grad[0]) - beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, - param_and_grad[0]) - find_master = self._multi_precision and param_and_grad[ - 0].dtype == core.VarDesc.VarType.FP16 - master_weight = (self._master_weights[param_and_grad[0].name] - if find_master else None) - lr = self._create_param_lr(param_and_grad) - - # create the adam optimize op - if self._apply_decay_param_fun is not None \ - and not self._apply_decay_param_fun(param_and_grad[0].name): - weight_decay = 0.0 - else: - weight_decay = self._weight_decay - - if framework.in_dygraph_mode(): - _beta1 = self._beta1 if not isinstance( - self._beta1, Variable) else self._beta1.numpy().item(0) - _beta2 = self._beta2 if not isinstance( - self._beta2, Variable) else self._beta2.numpy().item(0) - - ins = { - 'Param': param_and_grad[0], - 'Grad': param_and_grad[1], - 'LearningRate': lr, - 'Moment1': moment1, - 'Moment2': moment2, - 'Beta1Pow': beta1_pow_acc, - 'Beta2Pow': beta2_pow_acc, - } - attrs = { - 'beta1': _beta1, - 'beta2': _beta2, - 'epsilon': self._epsilon, - 'lazy_mode': self._lazy_mode, - 'min_row_size_to_use_multithread': 1000, - 'multi_precision': False, - 'weight_decay': weight_decay, - 'lr_ratio': 1.0 - } - outs = { - 'ParamOut': param_and_grad[0], - 'Moment1Out': moment1, - 'Moment2Out': moment2, - 'Beta1PowOut': beta1_pow_acc, - 'Beta2PowOut': beta2_pow_acc, - } - - framework._dygraph_tracer().trace_op( - type="adamw", inputs=ins, outputs=outs, attrs=attrs) - - return None - - inputs = { - "Param": [param_and_grad[0]], - "Grad": [param_and_grad[1]], - "LearningRate": [lr], - "Moment1": [moment1], - "Moment2": [moment2], - "Beta1Pow": [beta1_pow_acc], - "Beta2Pow": [beta2_pow_acc] - } - outputs = { - "ParamOut": [param_and_grad[0]], - "Moment1Out": [moment1], - "Moment2Out": [moment2], - "Beta1PowOut": [beta1_pow_acc], - "Beta2PowOut": [beta2_pow_acc], - } - attrs = { - "lazy_mode": self._lazy_mode, - "min_row_size_to_use_multithread": 1000, - "multi_precision": find_master, - 'weight_decay': weight_decay, - 'lr_ratio': 1.0 - } - - if isinstance(self._beta1, Variable): - inputs['Beta1Tensor'] = self._beta1 - else: - attrs['beta1'] = self._beta1 - if isinstance(self._beta2, Variable): - inputs['Beta2Tensor'] = self._beta2 - else: - attrs['beta2'] = self._beta2 - if isinstance(self._epsilon, Variable): - inputs['EpsilonTensor'] = self._epsilon - else: - attrs['epsilon'] = self._epsilon - - if find_master: - inputs["MasterParam"] = master_weight - outputs["MasterParamOut"] = master_weight - - for name in ["Beta1Tensor", "Beta2Tensor", "MasterParam"]: - if name in inputs: - raise ValueError( - "Custom AdamW should NOT have input: {}".format(name)) - - adam_op = block.append_op( - type=self.type, - inputs=inputs, - outputs=outputs, - attrs=attrs, - stop_gradient=True) - - return adam_op - - @imperative_base.no_grad - @framework.dygraph_only - def step(self): - params_grads = [] - for param in self._parameter_list: - if param.stop_gradient: - continue - if param._grad_ivar() is not None: - grad_var = param._grad_ivar() - if hasattr(grad_var, "_is_sparse") and grad_var._is_sparse( - ) and self.regularization is not None: - raise RuntimeError( - "AdamW don't support weight_decay with sparse parameters, please set it to None." - ) - params_grads.append((param, grad_var)) - - optimize_ops = self._apply_optimize( - loss=None, startup_program=None, params_grads=params_grads)