diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index 73210ac9fbc56b..bd2471e0f7e1df 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -29,7 +29,7 @@ if(NOT DEFINED XPU_BASE_DATE)
   set(XPU_BASE_DATE "20240104")
 endif()
 if(NOT DEFINED XPU_XHPC_BASE_DATE)
-  set(XPU_XHPC_BASE_DATE "20240218")
+  set(XPU_XHPC_BASE_DATE "20240222")
 endif()
 set(XPU_XCCL_BASE_VERSION "1.1.8.1")
 if(NOT DEFINED XPU_XFT_BASE_VERSION)
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 491b47442725ac..55aae9f24c1a61 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -36,7 +36,10 @@ XPUOpMap& get_kl2_ops() {
       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
      {"accuracy", XPUKernelSet({phi::DataType::FLOAT32})},
      {"adadelta", XPUKernelSet({phi::DataType::FLOAT32})},
-     {"adamw", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+     {"adamw",
+      XPUKernelSet({phi::DataType::FLOAT32,
+                    phi::DataType::FLOAT16,
+                    phi::DataType::BFLOAT16})},
      {"adam", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
      {"adam_dense_param_sparse_grad",
       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -723,7 +726,8 @@ XPUOpMap& get_kl2_ops() {
                     phi::DataType::INT32,
                     phi::DataType::INT64,
                     phi::DataType::FLOAT16})},
-     {"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})},
+     {"reduce_mean_grad",
+      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
      {"reduce_mean",
       XPUKernelSet({phi::DataType::FLOAT32,
                     phi::DataType::FLOAT16,
diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc
index 1d3d6001bca9c9..39e79ba0c4934c 100644
--- a/paddle/phi/backends/xpu/xpu3_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -721,7 +721,8 @@ XPUOpMap& get_kl3_ops() {
                     phi::DataType::BFLOAT16,
                     phi::DataType::INT32,
                     phi::DataType::INT64})},
-     {"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})},
+     {"reduce_mean_grad",
+      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
      {"reduce_mean",
       XPUKernelSet({phi::DataType::FLOAT32,
                     phi::DataType::FLOAT16,
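Note: the op-list entries above only gate kernel registration by dtype; the numerics are unchanged. For reference, the update that both the legacy `xpu::adamw` path and the new `xpu::adamw_v2` path are expected to implement is standard AdamW (Adam with decoupled weight decay). A minimal NumPy sketch of one step, mirroring the reference `adamw_step` used in the tests further down; `adamw_reference` and its argument layout are illustrative only, not Paddle API:

```python
import numpy as np

def adamw_reference(param, grad, m1, m2, lr, beta1, beta2, eps,
                    beta1_pow, beta2_pow, coeff, lr_ratio=1.0):
    """One AdamW step; all states are fp32, as in the multi-precision path."""
    lr = lr * lr_ratio
    param = param * (1.0 - lr * coeff)             # decoupled weight decay
    m1 = beta1 * m1 + (1.0 - beta1) * grad         # first moment
    m2 = beta2 * m2 + (1.0 - beta2) * grad * grad  # second moment
    # bias-corrected update; eps is added after the sqrt bias correction
    denom = np.sqrt(m2) / np.sqrt(1.0 - beta2_pow) + eps
    return param - (lr / (1.0 - beta1_pow)) * (m1 / denom), m1, m2
```

For fp16/bf16 parameters, the moments, the `beta*_pow` accumulators, and the master copy of `param` stay in fp32; only the visible parameter (and optionally the gradient) is low precision.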
diff --git a/paddle/phi/kernels/xpu/adamw_kernel.cc b/paddle/phi/kernels/xpu/adamw_kernel.cc
index 4df7ab633ab4e4..ca39a9932a609e 100644
--- a/paddle/phi/kernels/xpu/adamw_kernel.cc
+++ b/paddle/phi/kernels/xpu/adamw_kernel.cc
@@ -24,6 +24,8 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 
+#include "paddle/phi/common/amp_type_traits.h"
+
 namespace phi {
 
 template <typename T, typename Context>
@@ -44,6 +46,234 @@ float GetAbsMax(const Context& dev_ctx,
   return *std::max_element(buffer_cpu.begin(), buffer_cpu.end());
 }
 
+template <typename T, typename Context>
+void AdamwDenseKernelKL3(const Context& dev_ctx,
+                         const DenseTensor& param,
+                         const DenseTensor& grad,
+                         const DenseTensor& learning_rate,
+                         const DenseTensor& moment1,
+                         const DenseTensor& moment2,
+                         const DenseTensor& beta1_pow,
+                         const DenseTensor& beta2_pow,
+                         const paddle::optional<DenseTensor>& master_param,
+                         const paddle::optional<DenseTensor>& skip_update,
+                         const Scalar& beta1,
+                         const Scalar& beta2,
+                         const Scalar& epsilon,
+                         float lr_ratio,
+                         float coeff,
+                         bool with_decay,
+                         bool lazy_mode,
+                         int64_t min_row_size_to_use_multithread,
+                         bool multi_precision,
+                         bool use_global_beta_pow,
+                         DenseTensor* param_out,
+                         DenseTensor* moment1_out,
+                         DenseTensor* moment2_out,
+                         DenseTensor* beta1_pow_out,
+                         DenseTensor* beta2_pow_out,
+                         DenseTensor* master_param_outs) {
+  // TODO(houj04): once KL3 is stable and KL1/KL2 no longer need to be
+  // supported, replace AdamwDenseKernel with this AdamwDenseKernelKL3.
+  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+  const auto grad_type = grad.dtype();
+
+  VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
+
+  MPDType coeff_ = static_cast<MPDType>(coeff);
+  MPDType lr_ratio_ = static_cast<MPDType>(lr_ratio);
+
+  bool skip_update_ = false;
+  if (skip_update.is_initialized()) {
+    PADDLE_ENFORCE_EQ(
+        skip_update->numel(),
+        1,
+        errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
+                                skip_update->numel()));
+    std::vector<bool> skip_update_vec;
+    phi::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
+    skip_update_ = skip_update_vec[0];
+  }
+
+  // skip_update=true, just copy input to output
+  if (skip_update_) {
+    VLOG(4) << "Adamw skip update";
+    phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
+    phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
+    phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
+    if (!use_global_beta_pow) {
+      phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
+      phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
+    }
+    return;
+  }
+
+  // if with_decay = false, coeff = 0
+  if (!with_decay) {
+    coeff_ = static_cast<MPDType>(0.0);
+  }
+
+  MPDType beta1_ = beta1.to<MPDType>();
+  MPDType beta2_ = beta2.to<MPDType>();
+  MPDType epsilon_ = epsilon.to<MPDType>();
+  VLOG(3) << "beta1_pow.numel() : " << beta1_pow.numel()
+          << "beta2_pow.numel() : " << beta2_pow.numel();
+  VLOG(3) << "param.numel(): " << param.numel();
+  PADDLE_ENFORCE_EQ(
+      beta1_pow_out->numel(),
+      1,
+      errors::InvalidArgument("beta1 pow output size should be 1, but received "
+                              "value is:%d.",
+                              beta1_pow_out->numel()));
+
+  PADDLE_ENFORCE_EQ(
+      beta2_pow_out->numel(),
+      1,
+      errors::InvalidArgument("beta2 pow output size should be 1, but received "
+                              "value is:%d.",
+                              beta2_pow_out->numel()));
+
+  const MPDType* master_in_data =
+      multi_precision ? master_param->data<MPDType>() : nullptr;
+  MPDType* master_out_data =
+      multi_precision ? dev_ctx.template Alloc<MPDType>(master_param_outs)
+                      : nullptr;
+  // template <typename T, typename TG, typename MT> DLL_EXPORT int
+  // adamw_v2(Context* ctx, MT beta1, MT beta2, MT epsilon, MT coeff, MT
+  // lr_ratio, const MT* beta1_pow, MT* beta1_pow_out, const MT* beta2_pow, MT*
+  // beta2_pow_out, const MT* moment1, MT* moment1_out, const MT* moment2, MT*
+  // moment2_out, const MT* lr, const TG* grad, const T* param, T* param_out,
+  // const MT* master_param, MT* master_param_out, int64_t n);
+
+  if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
+    DenseTensor xpu_beta1_pow;
+    DenseTensor xpu_beta2_pow;
+    phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, &xpu_beta1_pow);
+    phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, &xpu_beta2_pow);
+    dev_ctx.Wait();
+    const MPDType* beta1_pow_ptr = xpu_beta1_pow.data<MPDType>();
+    const MPDType* beta2_pow_ptr = xpu_beta2_pow.data<MPDType>();
+
+    if (grad_type == phi::DataType::FLOAT32) {
+      int r = xpu::adamw_v2(
+          dev_ctx.x_context(),
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          beta1_pow_ptr,
+          nullptr,
+          beta2_pow_ptr,
+          nullptr,
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<float>(),
+          reinterpret_cast<const XPUType*>(param.data<T>()),
+          reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
+          master_in_data,
+          master_out_data,
+          param.numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+    } else {
+      int r = xpu::adamw_v2(
+          dev_ctx.x_context(),
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          beta1_pow_ptr,
+          nullptr,
+          beta2_pow_ptr,
+          nullptr,
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          reinterpret_cast<const XPUType*>(grad.data<T>()),
+          reinterpret_cast<const XPUType*>(param.data<T>()),
+          reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
+          master_in_data,
+          master_out_data,
+          param.numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+    }
+    if (!use_global_beta_pow) {
+      // Cpu update
+      dev_ctx.template HostAlloc<MPDType>(beta1_pow_out)[0] =
+          beta1_ * beta1_pow.data<MPDType>()[0];
+      dev_ctx.template HostAlloc<MPDType>(beta2_pow_out)[0] =
+          beta2_ * beta2_pow.data<MPDType>()[0];
+    }
+  } else {
+    MPDType* beta1_pow_out_ptr = nullptr;
+    MPDType* beta2_pow_out_ptr = nullptr;
+
+    if (!use_global_beta_pow) {
+      beta1_pow_out_ptr = dev_ctx.template Alloc<MPDType>(beta1_pow_out);
+      beta2_pow_out_ptr = dev_ctx.template Alloc<MPDType>(beta2_pow_out);
+    }
+
+    if (grad_type == phi::DataType::FLOAT32) {
+      int r = xpu::adamw_v2(
+          dev_ctx.x_context(),
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          beta1_pow.data<MPDType>(),
+          beta1_pow_out_ptr,
+          beta2_pow.data<MPDType>(),
+          beta2_pow_out_ptr,
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          grad.data<float>(),
+          reinterpret_cast<const XPUType*>(param.data<T>()),
+          reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
+          master_in_data,
+          master_out_data,
+          param.numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+    } else {
+      int r = xpu::adamw_v2(
+          dev_ctx.x_context(),
+          beta1_,
+          beta2_,
+          epsilon_,
+          coeff_,
+          lr_ratio_,
+          beta1_pow.data<MPDType>(),
+          beta1_pow_out_ptr,
+          beta2_pow.data<MPDType>(),
+          beta2_pow_out_ptr,
+          moment1.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment1_out),
+          moment2.data<MPDType>(),
+          dev_ctx.template Alloc<MPDType>(moment2_out),
+          learning_rate.data<MPDType>(),
+          reinterpret_cast<const XPUType*>(grad.data<T>()),
+          reinterpret_cast<const XPUType*>(param.data<T>()),
+          reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
+          master_in_data,
+          master_out_data,
+          param.numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+    }
+  }
+  return;
+}
+
 template <typename T, typename Context>
 void AdamwDenseKernel(const Context& dev_ctx,
                       const DenseTensor& param,
@@ -71,6 +301,38 @@ void AdamwDenseKernel(const Context& dev_ctx,
                       DenseTensor* beta1_pow_out,
                       DenseTensor* beta2_pow_out,
                       DenseTensor* master_param_outs) {
+  auto dev_version =
+      phi::backends::xpu::get_xpu_version(dev_ctx.GetPlace().GetDeviceId());
+  if (dev_version == phi::backends::xpu::XPUVersion::XPU3) {
+    AdamwDenseKernelKL3<T, Context>(dev_ctx,
+                                    param,
+                                    grad,
+                                    learning_rate,
+                                    moment1,
+                                    moment2,
+                                    beta1_pow,
+                                    beta2_pow,
+                                    master_param,
+                                    skip_update,
+                                    beta1,
+                                    beta2,
+                                    epsilon,
+                                    lr_ratio,
+                                    coeff,
+                                    with_decay,
+                                    lazy_mode,
+                                    min_row_size_to_use_multithread,
+                                    multi_precision,
+                                    use_global_beta_pow,
+                                    param_out,
+                                    moment1_out,
+                                    moment2_out,
+                                    beta1_pow_out,
+                                    beta2_pow_out,
+                                    master_param_outs);
+    return;
+  }
+
   // check moment_dtype
   auto moment1_dtype = moment1.dtype();
   auto moment2_dtype = moment2.dtype();
@@ -228,30 +490,85 @@ void AdamwDenseKernel(const Context& dev_ctx,
         0.0f);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
 
-  // int adamw(Context* ctx, const T* g, const float* mom1, const float* mom2,
-  // const T* param, const float* beta1_pow, const float* beta2_pow, const
-  // float* lr, float* moment1_out, float* moment2_out, T* param_out, float
-  // beta1, float beta2, float epsilon, float coeff, int64_t n);
-  r = xpu::adamw(
-      dev_ctx.x_context(),
-      reinterpret_cast<const XPUType*>(grad.template data<T>()),
-      moment_in_fp16 ? moment1_input_for_xdnn : moment1.template data<float>(),
-      moment_in_fp16 ? moment2_input_for_xdnn : moment2.template data<float>(),
-      reinterpret_cast<const XPUType*>(param.template data<T>()),
-      beta1_pow_ptr,
-      beta2_pow_ptr,
-      new_lr,
-      moment_in_fp16 ? moment1_output_for_xdnn
-                     : dev_ctx.template Alloc<float>(moment1_out),
-      moment_in_fp16 ? moment2_output_for_xdnn
-                     : dev_ctx.template Alloc<float>(moment2_out),
-      reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
-      beta1_,
-      beta2_,
-      epsilon_,
-      coeff,
-      param.numel());
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+  if (multi_precision) {
+    const float* master_param_in_data = master_param->data<float>();
+    float* master_param_out_data =
+        dev_ctx.template Alloc<float>(master_param_outs);
+    // convert grad to float if necessary
+    float* grad_fp32 = nullptr;
+    const auto grad_type = grad.dtype();
+    if (grad_type != phi::DataType::FLOAT32) {
+      grad_fp32 = RAII_GUARD.alloc_l3_or_gm<float>(grad.numel());
+      PADDLE_ENFORCE_XDNN_NOT_NULL(grad_fp32);
+      // int cast(Context* ctx, const TX* x, TY* y, int64_t len);
+      int r = xpu::cast(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUType*>(grad.template data<T>()),
+          grad_fp32,
+          grad.numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    }
+    // int adamw(Context* ctx, const T* g, const float* mom1, const float*
+    // mom2, const T* param, const float* beta1_pow, const float* beta2_pow,
+    // const float* lr, float* moment1_out, float* moment2_out, T* param_out,
+    // float beta1, float beta2, float epsilon, float coeff, int64_t n);
+    r = xpu::adamw(
+        dev_ctx.x_context(),
+        (grad_type == phi::DataType::FLOAT32) ? grad.data<float>() : grad_fp32,
+        moment_in_fp16 ? moment1_input_for_xdnn
+                       : moment1.template data<float>(),
+        moment_in_fp16 ? moment2_input_for_xdnn
+                       : moment2.template data<float>(),
+        master_param_in_data,
+        beta1_pow_ptr,
+        beta2_pow_ptr,
+        new_lr,
+        moment_in_fp16 ? moment1_output_for_xdnn
+                       : dev_ctx.template Alloc<float>(moment1_out),
+        moment_in_fp16 ? moment2_output_for_xdnn
+                       : dev_ctx.template Alloc<float>(moment2_out),
+        master_param_out_data,
+        beta1_,
+        beta2_,
+        epsilon_,
+        coeff,
+        param.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+    // convert master_param_out(fp32) to param_out(T)
+    r = xpu::cast(
+        dev_ctx.x_context(),
+        master_param_out_data,
+        reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
+        param_out->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+  } else {
+    // int adamw(Context* ctx, const T* g, const float* mom1, const float*
+    // mom2, const T* param, const float* beta1_pow, const float* beta2_pow,
+    // const float* lr, float* moment1_out, float* moment2_out, T* param_out,
+    // float beta1, float beta2, float epsilon, float coeff, int64_t n);
+    r = xpu::adamw(
+        dev_ctx.x_context(),
+        reinterpret_cast<const XPUType*>(grad.template data<T>()),
+        moment_in_fp16 ? moment1_input_for_xdnn
+                       : moment1.template data<float>(),
+        moment_in_fp16 ? moment2_input_for_xdnn
+                       : moment2.template data<float>(),
+        reinterpret_cast<const XPUType*>(param.template data<T>()),
+        beta1_pow_ptr,
+        beta2_pow_ptr,
+        new_lr,
+        moment_in_fp16 ? moment1_output_for_xdnn
+                       : dev_ctx.template Alloc<float>(moment1_out),
+        moment_in_fp16 ? moment2_output_for_xdnn
+                       : dev_ctx.template Alloc<float>(moment2_out),
+        reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(param_out)),
+        beta1_,
+        beta2_,
+        epsilon_,
+        coeff,
+        param.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw");
+  }
 
   if (moment_in_fp16) {
     int r = 0;
@@ -369,11 +686,15 @@ PD_REGISTER_KERNEL(adamw,
   kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
-  // Skip beta1_pow, beta2_pow, skip_update data transform
-  kernel->OutputAt(3)
-      .SetBackend(phi::Backend::UNDEFINED)
-      .SetDataType(phi::DataType::FLOAT32);
-  kernel->OutputAt(4)
-      .SetBackend(phi::Backend::UNDEFINED)
-      .SetDataType(phi::DataType::FLOAT32);
+
+  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
+      kernel_key.dtype() == phi::DataType::BFLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+  }
+  kernel->OutputAt(3).SetBackend(phi::Backend::UNDEFINED);
+  kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
 }
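The kernel and registration changes above are exercised end to end through an O2 AMP loop in which AdamW keeps fp32 master weights next to bf16/fp16 parameters. A minimal usage sketch, assuming a bf16-capable device such as an XPU; the `dtype='bfloat16'` arguments and the `scaler.update()` call are ordinary AMP usage, not part of this patch:

```python
import paddle

model = paddle.nn.Linear(5, 5)
optimizer = paddle.optimizer.AdamW(
    parameters=model.parameters(), multi_precision=True
)
model = paddle.amp.decorate(models=model, level='O2', dtype='bfloat16')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

for _ in range(2):
    with paddle.amp.auto_cast(level='O2', dtype='bfloat16'):
        loss = paddle.mean(model(paddle.randn((5, 5))))
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # dispatches to the adamw XPU kernel
    scaler.update()
    optimizer.clear_grad()
```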
diff --git a/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc
index c5b0950552629d..37ace904b2b807 100644
--- a/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc
@@ -84,5 +84,9 @@ void ReduceMeanGradKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    mean_grad, XPU, ALL_LAYOUT, phi::ReduceMeanGradKernel, float) {}
+PD_REGISTER_KERNEL(mean_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::ReduceMeanGradKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py
index b14f8603be89e1..f3a23ce846bf13 100644
--- a/python/paddle/optimizer/adamw.py
+++ b/python/paddle/optimizer/adamw.py
@@ -177,9 +177,9 @@ def __init__(
         assert epsilon is not None
         if not isinstance(beta1, Value) and not 0 <= beta1 < 1:
             raise ValueError("Invalid value of beta1, expect beta1 in [0,1).")
-        if not isinstance(beta1, Value) and not 0 <= beta2 < 1:
+        if not isinstance(beta2, Value) and not 0 <= beta2 < 1:
             raise ValueError("Invalid value of beta2, expect beta2 in [0,1).")
-        if not isinstance(beta1, Value) and not 0 <= epsilon:
+        if not isinstance(epsilon, Value) and not 0 <= epsilon:
             raise ValueError("Invalid value of epsilon, expect epsilon >= 0.")
         if not isinstance(weight_decay, float) and not isinstance(
             weight_decay, (framework.Variable, Value)
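The two renamed conditions above previously re-validated `beta1`, so out-of-range `beta2` or `epsilon` values were silently accepted. With the fix, construction fails fast; a small illustrative check:

```python
import paddle

linear = paddle.nn.Linear(2, 2)
try:
    paddle.optimizer.AdamW(parameters=linear.parameters(), beta2=1.5)
except ValueError as e:
    print(e)  # Invalid value of beta2, expect beta2 in [0,1).
```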
diff --git a/test/xpu/test_adamw_op_xpu.py b/test/xpu/test_adamw_op_xpu.py
index b9120779c40f6b..f8e0b7cd545bfb 100644
--- a/test/xpu/test_adamw_op_xpu.py
+++ b/test/xpu/test_adamw_op_xpu.py
@@ -59,8 +59,8 @@ def adamw_step(inputs, attributes):
     moment1_out = beta1 * moment1 + (1 - beta1) * grad
     moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
 
-    lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
-    param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
+    denom = (np.sqrt(moment2_out) / np.sqrt(1.0 - beta2_pow)) + epsilon
+    param_out = param + ((moment1_out / denom) * (-(lr / (1.0 - beta1_pow))))
 
     return param_out, moment1_out, moment2_out
 
@@ -650,6 +650,200 @@ def get_numpy_output(param, grad, moment1, moment2, lr_ratio, t):
     paddle.disable_static()
 
 
+class TestAdamWOpMultiPrecisonWithMainGrad(unittest.TestCase):
+    def _test_adamw_op_dygraph_place_amp_with_maingrad(
+        self, place, shape, use_main_grad
+    ):
+        paddle.disable_static()
+        paddle.seed(10)
+        paddle.set_device(place)
+
+        found_inf = None
+
+        _weight_decay = 0.1
+        with_decay = True
+        _lazy_mode = False
+        find_master = True
+
+        _epsilon = 1e-8
+
+        _beta1 = 0.9
+        _beta2 = 0.99
+        lr_ratio_ = 1.0
+
+        lr_rate = 1e-8
+
+        param = paddle.randn(shape).astype(paddle.bfloat16)
+        master_weight = param.astype(paddle.float32)
+        grad = paddle.randn(shape).astype(paddle.bfloat16)
+        main_grad = grad.astype(paddle.float32)
+        moment1 = paddle.randn(shape).astype(paddle.float32)
+        moment2 = paddle.randn(shape).astype(paddle.float32).abs()
+        lr = paddle.zeros([1]).astype(paddle.float32)
+        lr[0] = lr_rate
+        beta1_pow_acc = paddle.ones([1]).astype(paddle.float32)
+        beta1_pow_acc[0] = _beta1**10
+        beta2_pow_acc = paddle.ones([1]).astype(paddle.float32)
+        beta2_pow_acc[0] = _beta2**10
+
+        ref_param = param.astype(paddle.float32)
+        ref_beta1_pow_acc = beta1_pow_acc.astype(paddle.float32)
+        ref_beta2_pow_acc = beta2_pow_acc.astype(paddle.float32)
+        ref_moment_1 = moment1.astype(paddle.float32)
+        ref_moment_2 = moment2.astype(paddle.float32)
+
+        # reference code
+        _, _, _, _, _, _ = paddle._C_ops.adamw_(
+            ref_param,
+            main_grad,
+            lr,
+            ref_moment_1,
+            ref_moment_2,
+            ref_beta1_pow_acc,
+            ref_beta2_pow_acc,
+            master_weight,
+            found_inf,
+            _beta1,
+            _beta2,
+            _epsilon,
+            lr_ratio_,
+            _weight_decay,
+            with_decay,
+            _lazy_mode,
+            1000,
+            False,
+            False,
+        )
+
+        if use_main_grad:
+            _, _, _, _, _, _ = paddle._C_ops.adamw_(
+                param,
+                main_grad,
+                lr,
+                moment1,
+                moment2,
+                beta1_pow_acc,
+                beta2_pow_acc,
+                master_weight,
+                found_inf,
+                _beta1,
+                _beta2,
+                _epsilon,
+                lr_ratio_,
+                _weight_decay,
+                with_decay,
+                _lazy_mode,
+                1000,
+                find_master,
+                False,
+            )
+            np.testing.assert_allclose(
+                param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2
+            )
+            np.testing.assert_allclose(
+                master_weight.numpy(), ref_param.numpy(), rtol=1e-6
+            )
+        else:
+            _, _, _, _, _, _ = paddle._C_ops.adamw_(
+                param,
+                grad,
+                lr,
+                moment1,
+                moment2,
+                beta1_pow_acc,
+                beta2_pow_acc,
+                master_weight,
+                found_inf,
+                _beta1,
+                _beta2,
+                _epsilon,
+                lr_ratio_,
+                _weight_decay,
+                with_decay,
+                _lazy_mode,
+                1000,
+                find_master,
+                False,
+            )
+            np.testing.assert_allclose(
+                param.astype("float32").numpy(), ref_param.numpy(), rtol=1e-2
+            )
+            np.testing.assert_allclose(
+                master_weight.numpy(), ref_param.numpy(), rtol=1e-6
+            )
+
+    def _get_places(self):
+        places = []
+        if paddle.is_compiled_with_xpu():
+            places.append('xpu')
+        return places
+
+    def test_main(self):
+        for _ in range(1):
+            shape = paddle.randint(1, 1024, [2])
+            for place in self._get_places():
+                use_main_grad_list = [True, False]
+                for use_main_grad in use_main_grad_list:
+                    self._test_adamw_op_dygraph_place_amp_with_maingrad(
+                        place, shape, use_main_grad
+                    )
+
+
+class TestAdamWOpMultiPrecison(unittest.TestCase):
+    def _test_adamw_op_dygraph_place_amp(self, place, use_amp=False):
+        paddle.disable_static()
+        paddle.seed(10)
+        paddle.set_device(place)
+
+        input = paddle.randn((5, 5))
+
+        model = paddle.nn.Linear(5, 5)
+
+        optimizer = paddle.optimizer.AdamW(
+            parameters=[
+                {
+                    'params': model.parameters(),
+                    'weight_decay': 0.001,
+                    'beta1': 0.1,
+                    'beta2': 0.99,
+                }
+            ],
+            multi_precision=use_amp,
+        )
+
+        for idx in range(2):
+            if place == 'xpu' and use_amp:
+                model = paddle.amp.decorate(models=model, level='O2')
+                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+
+            if place == 'xpu' and use_amp:
+                with paddle.amp.auto_cast(level='O2'):
+                    output = model(input)
+                    loss = paddle.mean(output)
+                scaled = scaler.scale(loss)
+                scaled.backward()
+                scaler.step(optimizer)
+                optimizer.clear_grad()
+            else:
+                output = model(input)
+                loss = paddle.mean(output)
+                loss.backward()
+                optimizer.step()
+                optimizer.clear_grad()
+
+    def _get_places(self):
+        places = ['cpu']
+        if paddle.is_compiled_with_xpu():
+            places.append('xpu')
+        return places
+
+    def test_main(self):
+        for place in self._get_places():
+            use_amp_list = [True, False]
+            for use_amp in use_amp_list:
+                self._test_adamw_op_dygraph_place_amp(place, use_amp)
+
+
 support_types = get_xpu_op_support_types('adamw')
 for stype in support_types:
     create_test_class(globals(), XPUTestAdamwOp1, stype)
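A note on the tolerances used above: comparing a bf16 parameter against the fp32 reference with `rtol=1e-2` is about as tight as bf16 permits, since bfloat16 carries only 8 significand bits (7 stored plus the implicit one), i.e. a unit roundoff of 2^-8. A quick sanity check in NumPy:

```python
import numpy as np

bf16_unit_roundoff = 2.0**-8        # 8 significand bits -> roundoff 2**-8
print(bf16_unit_roundoff)           # 0.00390625, a bit under 1e-2
print(np.finfo(np.float32).eps)     # ~1.19e-07, for comparison
```

The same reasoning motivates the bf16 tolerance bump (5e-3 to 6e-3) in the flash-attention test that follows.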
diff --git a/test/xpu/test_flash_attention_op_xpu.py b/test/xpu/test_flash_attention_op_xpu.py
index 8aadadfc40ecc2..372a2ee91f1dd6 100644
--- a/test/xpu/test_flash_attention_op_xpu.py
+++ b/test/xpu/test_flash_attention_op_xpu.py
@@ -79,7 +79,7 @@ def setUp(self):
     def test_all(self):
         self.run_case(dtype="float32", tolerance=5e-4, tolerance_dv=5e-4)
         self.run_case(dtype="float16", tolerance=5e-4, tolerance_dv=1e-3)
-        self.run_case(dtype="bfloat16", tolerance=5e-3, tolerance_dv=1e-2)
+        self.run_case(dtype="bfloat16", tolerance=6e-3, tolerance_dv=1e-2)
 
     def run_case(self, dtype, tolerance, tolerance_dv):
         # TODO(houj04) remove debug codes after correctness check