diff --git a/paddle/fluid/eager/api/manual/eager_manual/forwards/add_n_fwd_func.cc b/paddle/fluid/eager/api/manual/eager_manual/forwards/add_n_fwd_func.cc
index 2842e54a6f1c65..add0359ccf25d3 100644
--- a/paddle/fluid/eager/api/manual/eager_manual/forwards/add_n_fwd_func.cc
+++ b/paddle/fluid/eager/api/manual/eager_manual/forwards/add_n_fwd_func.cc
@@ -42,7 +42,7 @@ paddle::Tensor add_n_ad_func(const std::vector<paddle::Tensor>& x) {
 
     {
       paddle::imperative::AutoCastGuard guard(
-          egr::Controller::Instance().GetCurrentAMPState(),
+          egr::Controller::Instance().GetCurrentAmpAttrs(),
           paddle::imperative::AmpLevel::O0);
       return add_n_ad_func(NEW_x);
     }
diff --git a/paddle/fluid/eager/api/manual/eager_manual/forwards/conv2d_fwd_function.cc b/paddle/fluid/eager/api/manual/eager_manual/forwards/conv2d_fwd_function.cc
index fef6b57659c2c9..33e9393a615bc2 100644
--- a/paddle/fluid/eager/api/manual/eager_manual/forwards/conv2d_fwd_function.cc
+++ b/paddle/fluid/eager/api/manual/eager_manual/forwards/conv2d_fwd_function.cc
@@ -53,7 +53,7 @@ paddle::Tensor conv2d_ad_func(const paddle::Tensor& input,
 
     {
       paddle::imperative::AutoCastGuard guard(
-          egr::Controller::Instance().GetCurrentAMPState(),
+          egr::Controller::Instance().GetCurrentAmpAttrs(),
           paddle::imperative::AmpLevel::O0);
       return conv2d_ad_func(new_input,
                             new_filter,
diff --git a/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc b/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc
index f6902fd014cd54..18e36264ebe6bb 100644
--- a/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc
+++ b/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc
@@ -52,7 +52,7 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
 
     {
       paddle::imperative::AutoCastGuard guard(
-          egr::Controller::Instance().GetCurrentAMPState(),
+          egr::Controller::Instance().GetCurrentAmpAttrs(),
           paddle::imperative::AmpLevel::O0);
       return multiply_ad_func(new_x, new_y);
     }
@@ -399,7 +399,7 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,
 
     {
       paddle::imperative::AutoCastGuard guard(
-          egr::Controller::Instance().GetCurrentAMPState(),
+          egr::Controller::Instance().GetCurrentAmpAttrs(),
           paddle::imperative::AmpLevel::O0);
       return multiply_ad_func(new_x, new_y);
     }
diff --git a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_attention_fwd_func.cc b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_attention_fwd_func.cc
index 23f63df2c846ad..f3612c2830dd07 100644
--- a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_attention_fwd_func.cc
+++ b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_attention_fwd_func.cc
@@ -128,7 +128,7 @@ fused_attention_dygraph_function(
 
     {
       paddle::imperative::AutoCastGuard guard(
-          egr::Controller::Instance().GetCurrentAMPState(),
+          egr::Controller::Instance().GetCurrentAmpAttrs(),
           paddle::imperative::AmpLevel::O0);
       return fused_attention_dygraph_function(NEW_X,
                                               NEW_LnScale,
diff --git a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_bias_dropout_residual_layer_norm_fwd_func.cc b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_bias_dropout_residual_layer_norm_fwd_func.cc
index a337bc7b80030b..c76073ba0b5745 100644
--- a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_bias_dropout_residual_layer_norm_fwd_func.cc
+++ b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_bias_dropout_residual_layer_norm_fwd_func.cc
@@ -83,7 +83,7 @@ fused_bias_dropout_residual_layer_norm_dygraph_function(
 
    {
      paddle::imperative::AutoCastGuard guard(
-          egr::Controller::Instance().GetCurrentAMPState(),
+          egr::Controller::Instance().GetCurrentAmpAttrs(),
           paddle::imperative::AmpLevel::O0);
      return fused_bias_dropout_residual_layer_norm_dygraph_function(
          NEW_X, NEW_Residual, NEW_Bias, NEW_LnScale, NEW_LnBias, attr_map);
diff --git a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_feedforward_fwd_func.cc b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_feedforward_fwd_func.cc
index 11e0df8d20e970..b2f5238c5be322 100644
--- a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_feedforward_fwd_func.cc
+++ b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_feedforward_fwd_func.cc
@@ -122,7 +122,7 @@ fused_feedforward_dygraph_function(
 
    {
      paddle::imperative::AutoCastGuard guard(
-          egr::Controller::Instance().GetCurrentAMPState(),
+          egr::Controller::Instance().GetCurrentAmpAttrs(),
          paddle::imperative::AmpLevel::O0);
      return fused_feedforward_dygraph_function(NEW_X,
                                                NEW_Dropout1Seed,
diff --git a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gate_attention_fwd_func.cc b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gate_attention_fwd_func.cc
index 16583c4dcf287c..c42a099cef4b07 100644
--- a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gate_attention_fwd_func.cc
+++ b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gate_attention_fwd_func.cc
@@ -127,7 +127,7 @@ fused_gate_attention_dygraph_function(
 
    {
      paddle::imperative::AutoCastGuard guard(
-          egr::Controller::Instance().GetCurrentAMPState(),
+          egr::Controller::Instance().GetCurrentAmpAttrs(),
          paddle::imperative::AmpLevel::O0);
      return fused_gate_attention_dygraph_function(NEW_Query,
                                                   NEW_Key,
diff --git a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gemm_epilogue_fwd_func.cc b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gemm_epilogue_fwd_func.cc
index f35770a2b8a22c..c4ae0840c294f4 100644
--- a/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gemm_epilogue_fwd_func.cc
+++ b/paddle/fluid/eager/api/manual/fluid_manual/forwards/fused_gemm_epilogue_fwd_func.cc
@@ -49,7 +49,7 @@ paddle::Tensor fused_gemm_epilogue_dygraph_function(
 
    {
      paddle::imperative::AutoCastGuard guard(
-          egr::Controller::Instance().GetCurrentAMPState(),
+          egr::Controller::Instance().GetCurrentAmpAttrs(),
          paddle::imperative::AmpLevel::O0);
      return fused_gemm_epilogue_dygraph_function(
          NEW_X, NEW_Y, NEW_Bias, attr_map);
diff --git a/paddle/fluid/eager/api/utils/global_utils.h b/paddle/fluid/eager/api/utils/global_utils.h
index e048e4ed684989..facfbde6970751 100644
--- a/paddle/fluid/eager/api/utils/global_utils.h
+++ b/paddle/fluid/eager/api/utils/global_utils.h
@@ -91,8 +91,8 @@ class Controller {
     VLOG(6) << "Set current tracer for Controller: " << tracer_;
   }
 
-  const std::shared_ptr<paddle::imperative::AMPState>& GetCurrentAMPState() {
-    return paddle::imperative::GetCurrentAMPState();
+  const std::shared_ptr<paddle::imperative::AmpAttrs>& GetCurrentAmpAttrs() {
+    return paddle::imperative::GetCurrentAmpAttrs();
   }
 
   const std::unordered_map<std::string, std::vector<paddle::OpMetaInfo>>&
diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
index be4b32910e36d3..b9e04b3e318ac9 100644
--- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc
+++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -1763,7 +1763,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
     const char* CALL_BACK_TEMPLATE =
         " {\n"
         " paddle::imperative::AutoCastGuard "
-        "guard(egr::Controller::Instance().GetCurrentAMPState(), "
+        "guard(egr::Controller::Instance().GetCurrentAmpAttrs(), "
         "paddle::imperative::AmpLevel::O0);\n"
        " return %s_dygraph_function(%s);\n"
        " }";
diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
index 17c36a1681bb36..e1ad1a0dc81b20 100644
--- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -520,7 +520,7 @@ class {} : public egr::GradNodeBase {{
  {}
  {}
  {{
-    paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentAMPState(), paddle::imperative::AmpLevel::O0);
+    paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentAmpAttrs(), paddle::imperative::AmpLevel::O0);
    {}
  }}
 }}
diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc
index 6637ddbe9ee670..0dd5bc5779d21c 100644
--- a/paddle/fluid/imperative/amp_auto_cast.cc
+++ b/paddle/fluid/imperative/amp_auto_cast.cc
@@ -119,7 +119,7 @@ OpSupportedInfos(const std::string& place,
       std::move(all_ops), std::move(supported_ops), std::move(unsupported_ops));
 }
 
-AutoCastGuard::AutoCastGuard(std::shared_ptr<AMPState> state, AmpLevel level)
+AutoCastGuard::AutoCastGuard(std::shared_ptr<AmpAttrs> state, AmpLevel level)
     : state_(state) {
   pre_amp_level_ = state_->GetAmpLevel();
 
@@ -231,25 +231,25 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
   return os;
 }
 
-thread_local bool AMPState::use_promote_ = false;
+thread_local bool AmpAttrs::use_promote_ = false;
 
-thread_local AmpLevel AMPState::amp_level_ = AmpLevel::O0;
+thread_local AmpLevel AmpAttrs::amp_level_ = AmpLevel::O0;
 
-thread_local phi::DataType AMPState::amp_dtype_ = phi::DataType::FLOAT32;
+thread_local phi::DataType AmpAttrs::amp_dtype_ = phi::DataType::FLOAT32;
 
-AMPState::AMPState() {}
+AmpAttrs::AmpAttrs() {}
 
-AMPState::~AMPState() = default;
+AmpAttrs::~AmpAttrs() = default;
 
-bool AMPState::GetUsePromote() const { return use_promote_; }
+bool AmpAttrs::GetUsePromote() const { return use_promote_; }
 
-void AMPState::SetUsePromote(bool use_promote) { use_promote_ = use_promote; }
+void AmpAttrs::SetUsePromote(bool use_promote) { use_promote_ = use_promote; }
 
-AmpLevel AMPState::GetAmpLevel() const { return amp_level_; }
+AmpLevel AmpAttrs::GetAmpLevel() const { return amp_level_; }
 
-void AMPState::SetAmpLevel(AmpLevel level) { amp_level_ = level; }
+void AmpAttrs::SetAmpLevel(AmpLevel level) { amp_level_ = level; }
 
-std::string AMPState::GetAmpDtype() const {
+std::string AmpAttrs::GetAmpDtype() const {
   if (amp_dtype_ == phi::DataType::FLOAT16) {
     return std::string("float16");
   } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
@@ -259,7 +259,7 @@ std::string AMPState::GetAmpDtype() const {
   }
 }
 
-void AMPState::SetAmpDtype(std::string amp_dtype) {
+void AmpAttrs::SetAmpDtype(std::string amp_dtype) {
   if (amp_dtype == "float16") {
     amp_dtype_ = phi::DataType::FLOAT16;
   } else if (amp_dtype == "bfloat16") {
@@ -269,7 +269,7 @@ void AMPState::SetAmpDtype(std::string amp_dtype) {
   }
 }
 
-phi::DataType AMPState::GetAmpPhiDtype() const { return amp_dtype_; }
+phi::DataType AmpAttrs::GetAmpPhiDtype() const { return amp_dtype_; }
 
 template <typename VarType>
 inline std::string GetDtypeStr(const std::shared_ptr<VarType>& var) {
@@ -308,7 +308,7 @@ static inline std::shared_ptr<VarType> CastToType(
   imperative::NameVarMap<VarType> outs = {{"Out", {out}}};
 
   {
-    AutoCastGuard guard(imperative::GetCurrentAMPState(), AmpLevel::O0);
+    AutoCastGuard guard(imperative::GetCurrentAmpAttrs(), AmpLevel::O0);
tracer->TraceOp("cast", ins, outs, std::move(attrs)); } diff --git a/paddle/fluid/imperative/amp_auto_cast.h b/paddle/fluid/imperative/amp_auto_cast.h index 35969fb5188180..1864f990576b12 100644 --- a/paddle/fluid/imperative/amp_auto_cast.h +++ b/paddle/fluid/imperative/amp_auto_cast.h @@ -84,10 +84,10 @@ class AmpOperators { std::ostream& operator<<(std::ostream& os, AmpOperators& ops); -class AMPState { +class AmpAttrs { public: - AMPState(); - ~AMPState(); + AmpAttrs(); + ~AmpAttrs(); bool GetUsePromote() const; void SetUsePromote(bool use_promote); AmpLevel GetAmpLevel() const; @@ -105,7 +105,7 @@ class AMPState { // NOTE(zhiqiu): AutoCastGuard is used for RAII. class AutoCastGuard { public: - AutoCastGuard(std::shared_ptr state, AmpLevel guard_level); + AutoCastGuard(std::shared_ptr state, AmpLevel guard_level); ~AutoCastGuard(); @@ -114,7 +114,7 @@ class AutoCastGuard { AutoCastGuard& operator=(const AutoCastGuard& guard) = delete; private: - std::shared_ptr state_; + std::shared_ptr state_; AmpLevel pre_amp_level_; }; diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index decb93a4435690..48b51265421c5a 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -52,8 +52,8 @@ thread_local bool Tracer::use_layout_autotune_ = false; static thread_local std::shared_ptr g_current_tracer(nullptr); -static thread_local std::shared_ptr g_current_amp_state = - std::make_shared(); +static thread_local std::shared_ptr g_current_amp_attrs = + std::make_shared(); TEST_API void Tracer::DisableLayoutAutoTune() { use_layout_autotune_ = false; } TEST_API void Tracer::EnableLayoutAutoTune() { @@ -89,8 +89,8 @@ TEST_API void SetCurrentTracer(const std::shared_ptr& tracer) { VLOG(6) << "Set current tracer: " << g_current_tracer; } -const std::shared_ptr& GetCurrentAMPState() { - return g_current_amp_state; +const std::shared_ptr& GetCurrentAmpAttrs() { + return g_current_amp_attrs; } void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) { @@ -276,23 +276,23 @@ void Tracer::TraceOpImpl(const std::string& type, : attr_checker->GetDefaultAttrMap(); std::unique_ptr> ins_amp = nullptr; - if (GetCurrentAMPState()->GetAmpLevel() == AmpLevel::O1) { - if (GetCurrentAMPState()->GetAmpPhiDtype() == phi::DataType::FLOAT16) { + if (GetCurrentAmpAttrs()->GetAmpLevel() == AmpLevel::O1) { + if (GetCurrentAmpAttrs()->GetAmpPhiDtype() == phi::DataType::FLOAT16) { VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type; ins_amp = std::make_unique>( AutoCastInputs(type, ins)); - } else if (GetCurrentAMPState()->GetAmpPhiDtype() == + } else if (GetCurrentAmpAttrs()->GetAmpPhiDtype() == phi::DataType::BFLOAT16) { VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type; ins_amp = std::make_unique>( AutoCastBF16Inputs(type, ins)); } - } else if (GetCurrentAMPState()->GetAmpLevel() == AmpLevel::O2) { - if (GetCurrentAMPState()->GetAmpPhiDtype() == phi::DataType::FLOAT16) { + } else if (GetCurrentAmpAttrs()->GetAmpLevel() == AmpLevel::O2) { + if (GetCurrentAmpAttrs()->GetAmpPhiDtype() == phi::DataType::FLOAT16) { VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type; ins_amp = std::make_unique>( CastPureFp16Inputs(type, ins)); - } else if (GetCurrentAMPState()->GetAmpPhiDtype() == + } else if (GetCurrentAmpAttrs()->GetAmpPhiDtype() == phi::DataType::BFLOAT16) { VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type; ins_amp = std::make_unique>( @@ -560,20 +560,20 @@ TEST_API void 
Tracer::SetHasGrad(bool has_grad) { has_grad_ = has_grad; } TEST_API void Tracer::SetUsePromote(bool use_promote) { VLOG(4) << "set use_promote to " << use_promote; - g_current_amp_state->SetUsePromote(use_promote); + g_current_amp_attrs->SetUsePromote(use_promote); } TEST_API bool Tracer::GetUsePromote() const { - return g_current_amp_state->GetUsePromote(); + return g_current_amp_attrs->GetUsePromote(); } TEST_API void Tracer::SetAmpLevel(AmpLevel level) { VLOG(4) << "set amp_level to " << static_cast(level); - g_current_amp_state->SetAmpLevel(level); + g_current_amp_attrs->SetAmpLevel(level); } TEST_API AmpLevel Tracer::GetAmpLevel() const { - return g_current_amp_state->GetAmpLevel(); + return g_current_amp_attrs->GetAmpLevel(); } bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins, @@ -604,15 +604,15 @@ bool Tracer::IsProgramDescTracingEnabled() const { void Tracer::SetAmpDtype(std::string amp_dtype) { VLOG(4) << "set amp_dtype to " << amp_dtype; - g_current_amp_state->SetAmpDtype(amp_dtype); + g_current_amp_attrs->SetAmpDtype(amp_dtype); } std::string Tracer::GetAmpDtype() const { - return g_current_amp_state->GetAmpDtype(); + return g_current_amp_attrs->GetAmpDtype(); } phi::DataType Tracer::GetAmpPhiDtype() const { - return g_current_amp_state->GetAmpPhiDtype(); + return g_current_amp_attrs->GetAmpPhiDtype(); } bool Tracer::ComputeRequiredGrad(const NameTensorMap& ins, diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 9e0033c4269bd2..b6f61c36f670bc 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -200,7 +200,7 @@ class Tracer { // To access static variable current_tracer const std::shared_ptr& GetCurrentTracer(); TEST_API void SetCurrentTracer(const std::shared_ptr& tracer_); -const std::shared_ptr& GetCurrentAMPState(); +const std::shared_ptr& GetCurrentAmpAttrs(); void IncreaseVarbaseReferenceCountUntilCopyComplete( const std::shared_ptr& var, const platform::Place& place); diff --git a/paddle/fluid/pir/dialect/op_generator/api_gen.py b/paddle/fluid/pir/dialect/op_generator/api_gen.py index 3f805b9c29b331..1d62d14213ddad 100644 --- a/paddle/fluid/pir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/api_gen.py @@ -99,7 +99,7 @@ AMP_LOGIC_TEMPLATE = """ - if (egr::Controller::Instance().GetCurrentAMPState()->GetAmpLevel() != paddle::imperative::AmpLevel::O0){{ + if (egr::Controller::Instance().GetCurrentAmpAttrs()->GetAmpLevel() != paddle::imperative::AmpLevel::O0){{ VLOG(5) << "Check and Prepare For AMP"; auto op_name = phi::TransToFluidOpName("{op_name}"); std::vector> amp_values_vector = {{ {no_optional_inputs} }}; @@ -107,7 +107,7 @@ auto amp_dst_dtype = paddle::dialect::GetAmpDestDtype("{op_name}", amp_values_vector); {new_inputs} {{ - paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentAMPState(), paddle::imperative::AmpLevel::O0); + paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentAmpAttrs(), paddle::imperative::AmpLevel::O0); return paddle::dialect::{op_name}({args}); }} }} diff --git a/paddle/fluid/pir/dialect/operator/utils/amp_utils.cc b/paddle/fluid/pir/dialect/operator/utils/amp_utils.cc index 4d224befc3514f..ac631baf66ee61 100644 --- a/paddle/fluid/pir/dialect/operator/utils/amp_utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/amp_utils.cc @@ -36,7 +36,7 @@ phi::DataType GetPromoteType( return dst_type; } - if (egr::Controller::Instance().GetCurrentAMPState()->GetAmpDtype() == + if 
(egr::Controller::Instance().GetCurrentAmpAttrs()->GetAmpDtype() == "float16") { if (op_name == "fused_attention") { for (size_t i = 0; i < amp_values_vector.size(); i++) { @@ -75,7 +75,7 @@ phi::DataType GetPromoteType( pir::Value Cast(const pir::Value& input, const phi::DataType& dst_dtype) { paddle::imperative::AutoCastGuard guard( - egr::Controller::Instance().GetCurrentAMPState(), + egr::Controller::Instance().GetCurrentAmpAttrs(), paddle::imperative::AmpLevel::O0); return paddle::dialect::cast(input, dst_dtype); } @@ -168,13 +168,13 @@ phi::DataType GetAmpDestDtype( const std::vector>& amp_values_vector) { auto amp_level = egr::Controller::Instance().GetAMPLevel(); auto amp_setting_dtype = - egr::Controller::Instance().GetCurrentAMPState()->GetAmpPhiDtype(); + egr::Controller::Instance().GetCurrentAmpAttrs()->GetAmpPhiDtype(); auto dst_type = amp_setting_dtype; bool use_promote = true; if (amp_level == paddle::imperative::AmpLevel::O2) { use_promote = - egr::Controller::Instance().GetCurrentAMPState()->GetUsePromote(); + egr::Controller::Instance().GetCurrentAmpAttrs()->GetUsePromote(); } if (use_promote) { diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 82437281af0e3e..c540fe0687d883 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -634,8 +634,8 @@ void BindImperative(py::module *m_ptr) { egr::Controller::Instance().SetCurrentTracer(tracer); imperative::SetCurrentTracer(tracer); }); - m.def("_get_amp_state", - []() { return egr::Controller::Instance().GetCurrentAMPState(); }); + m.def("_get_amp_attrs", + []() { return egr::Controller::Instance().GetCurrentAmpAttrs(); }); m.def("_set_amp_op_list", [](std::unordered_set &allow_ops, std::unordered_set &block_ops) { @@ -664,17 +664,17 @@ void BindImperative(py::module *m_ptr) { .value("O3", paddle::imperative::AmpLevel::O3) .export_values(); - py::class_>( - m, "AMPState", R"DOC()DOC") + py::class_>( + m, "AmpAttrs", R"DOC()DOC") .def_property("_use_promote", - &imperative::AMPState::GetUsePromote, - &imperative::AMPState::SetUsePromote) + &imperative::AmpAttrs::GetUsePromote, + &imperative::AmpAttrs::SetUsePromote) .def_property("_amp_level", - &imperative::AMPState::GetAmpLevel, - &imperative::AMPState::SetAmpLevel) + &imperative::AmpAttrs::GetAmpLevel, + &imperative::AmpAttrs::SetAmpLevel) .def_property("_amp_dtype", - &imperative::AMPState::GetAmpDtype, - &imperative::AMPState::SetAmpDtype); + &imperative::AmpAttrs::GetAmpDtype, + &imperative::AmpAttrs::SetAmpDtype); py::class_>( m, "Tracer", R"DOC()DOC") diff --git a/python/paddle/base/core.py b/python/paddle/base/core.py index 6b97c60c4e061f..fcb1f5605e8236 100644 --- a/python/paddle/base/core.py +++ b/python/paddle/base/core.py @@ -291,8 +291,8 @@ def to_list(s): _device_synchronize, _dygraph_debug_level, _get_all_register_op_kernels, + _get_amp_attrs, _get_amp_op_list, - _get_amp_state, _get_current_stream, _get_eager_deletion_vars, _get_phi_kernel_name, diff --git a/test/amp/test_pir_amp.py b/test/amp/test_pir_amp.py index bb435a7209f916..6b4cd5e13c60df 100644 --- a/test/amp/test_pir_amp.py +++ b/test/amp/test_pir_amp.py @@ -21,21 +21,21 @@ from paddle.base import core -class TestAMPState(unittest.TestCase): - def test_pir_amp_state(self): +class TestAmpAttrs(unittest.TestCase): + def test_pir_amp_attrs(self): with paddle.pir_utils.IrGuard(): - amp_state = core._get_amp_state() - amp_state._use_promote = True - amp_state._amp_level = core.AmpLevel.O2 - amp_state._amp_dtype = 'float16' - 
np.testing.assert_equal(core._get_amp_state()._use_promote, True) + amp_attrs = core._get_amp_attrs() + amp_attrs._use_promote = True + amp_attrs._amp_level = core.AmpLevel.O2 + amp_attrs._amp_dtype = 'float16' + np.testing.assert_equal(core._get_amp_attrs()._use_promote, True) np.testing.assert_equal( - core._get_amp_state()._amp_level, core.AmpLevel.O2 + core._get_amp_attrs()._amp_level, core.AmpLevel.O2 ) - np.testing.assert_equal(core._get_amp_state()._amp_dtype, 'float16') - amp_state._use_promote = False - amp_state._amp_level = core.AmpLevel.O0 - amp_state._amp_dtype = 'float32' + np.testing.assert_equal(core._get_amp_attrs()._amp_dtype, 'float16') + amp_attrs._use_promote = False + amp_attrs._amp_level = core.AmpLevel.O0 + amp_attrs._amp_dtype = 'float32' class TestPirAMPProgram(unittest.TestCase): @@ -49,10 +49,10 @@ def test_linear_amp_o1(self): x = paddle.static.data('x', [3, 4], 'float32') linear = paddle.nn.Linear(4, 5) - amp_state = core._get_amp_state() - amp_state._use_promote = True - amp_state._amp_level = core.AmpLevel.O1 - amp_state._amp_dtype = 'float16' + amp_attrs = core._get_amp_attrs() + amp_attrs._use_promote = True + amp_attrs._amp_level = core.AmpLevel.O1 + amp_attrs._amp_dtype = 'float16' ( original_white_list, original_black_list, @@ -73,9 +73,9 @@ def test_linear_amp_o1(self): np.testing.assert_equal(out2.dtype, core.DataType.FLOAT32) np.testing.assert_equal(cast_op_count, 3) - amp_state._use_promote = False - amp_state._amp_level = core.AmpLevel.O0 - amp_state._amp_dtype = 'float32' + amp_attrs._use_promote = False + amp_attrs._amp_level = core.AmpLevel.O0 + amp_attrs._amp_dtype = 'float32' core._set_amp_op_list(original_white_list, original_black_list) _white_list, _black_list = core._get_amp_op_list() np.testing.assert_equal(len(_white_list), 0)
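
Note (not part of the diff): a minimal sketch of how the renamed Python-side accessor is exercised after this change, mirroring test/amp/test_pir_amp.py above. core._get_amp_attrs() replaces the old core._get_amp_state(); the returned AmpAttrs object keeps the same _use_promote, _amp_level and _amp_dtype properties.

import paddle
from paddle.base import core

with paddle.pir_utils.IrGuard():
    amp_attrs = core._get_amp_attrs()
    # Enable O2 float16 autocast through the thread-local AmpAttrs object.
    amp_attrs._use_promote = True
    amp_attrs._amp_level = core.AmpLevel.O2
    amp_attrs._amp_dtype = 'float16'
    assert core._get_amp_attrs()._amp_level == core.AmpLevel.O2
    # Restore the defaults so later code is unaffected.
    amp_attrs._use_promote = False
    amp_attrs._amp_level = core.AmpLevel.O0
    amp_attrs._amp_dtype = 'float32'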