@@ -42,7 +42,7 @@ paddle::Tensor add_n_ad_func(const std::vector<paddle::Tensor>& x) {

   {
     paddle::imperative::AutoCastGuard guard(
-        egr::Controller::Instance().GetCurrentAMPState(),
+        egr::Controller::Instance().GetCurrentAmpAttrs(),
         paddle::imperative::AmpLevel::O0);
     return add_n_ad_func(NEW_x);
   }

@@ -53,7 +53,7 @@ paddle::Tensor conv2d_ad_func(const paddle::Tensor& input,

   {
     paddle::imperative::AutoCastGuard guard(
-        egr::Controller::Instance().GetCurrentAMPState(),
+        egr::Controller::Instance().GetCurrentAmpAttrs(),
         paddle::imperative::AmpLevel::O0);
     return conv2d_ad_func(new_input,
                           new_filter,

@@ -52,7 +52,7 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,

   {
     paddle::imperative::AutoCastGuard guard(
-        egr::Controller::Instance().GetCurrentAMPState(),
+        egr::Controller::Instance().GetCurrentAmpAttrs(),
         paddle::imperative::AmpLevel::O0);
     return multiply_ad_func(new_x, new_y);
   }
@@ -399,7 +399,7 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x,

   {
     paddle::imperative::AutoCastGuard guard(
-        egr::Controller::Instance().GetCurrentAMPState(),
+        egr::Controller::Instance().GetCurrentAmpAttrs(),
         paddle::imperative::AmpLevel::O0);
     return multiply_ad_func(new_x, new_y);
   }

@@ -128,7 +128,7 @@ fused_attention_dygraph_function(

   {
     paddle::imperative::AutoCastGuard guard(
-        egr::Controller::Instance().GetCurrentAMPState(),
+        egr::Controller::Instance().GetCurrentAmpAttrs(),
         paddle::imperative::AmpLevel::O0);
     return fused_attention_dygraph_function(NEW_X,
                                             NEW_LnScale,

@@ -83,7 +83,7 @@ fused_bias_dropout_residual_layer_norm_dygraph_function(

   {
     paddle::imperative::AutoCastGuard guard(
-        egr::Controller::Instance().GetCurrentAMPState(),
+        egr::Controller::Instance().GetCurrentAmpAttrs(),
         paddle::imperative::AmpLevel::O0);
     return fused_bias_dropout_residual_layer_norm_dygraph_function(
         NEW_X, NEW_Residual, NEW_Bias, NEW_LnScale, NEW_LnBias, attr_map);

@@ -122,7 +122,7 @@ fused_feedforward_dygraph_function(

   {
     paddle::imperative::AutoCastGuard guard(
-        egr::Controller::Instance().GetCurrentAMPState(),
+        egr::Controller::Instance().GetCurrentAmpAttrs(),
         paddle::imperative::AmpLevel::O0);
     return fused_feedforward_dygraph_function(NEW_X,
                                               NEW_Dropout1Seed,

@@ -127,7 +127,7 @@ fused_gate_attention_dygraph_function(

   {
     paddle::imperative::AutoCastGuard guard(
-        egr::Controller::Instance().GetCurrentAMPState(),
+        egr::Controller::Instance().GetCurrentAmpAttrs(),
         paddle::imperative::AmpLevel::O0);
     return fused_gate_attention_dygraph_function(NEW_Query,
                                                  NEW_Key,

@@ -49,7 +49,7 @@ paddle::Tensor fused_gemm_epilogue_dygraph_function(

   {
     paddle::imperative::AutoCastGuard guard(
-        egr::Controller::Instance().GetCurrentAMPState(),
+        egr::Controller::Instance().GetCurrentAmpAttrs(),
         paddle::imperative::AmpLevel::O0);
     return fused_gemm_epilogue_dygraph_function(
         NEW_X, NEW_Y, NEW_Bias, attr_map);

4 changes: 2 additions & 2 deletions paddle/fluid/eager/api/utils/global_utils.h
@@ -91,8 +91,8 @@ class Controller {
     VLOG(6) << "Set current tracer for Controller: " << tracer_;
   }

-  const std::shared_ptr<paddle::imperative::AMPState>& GetCurrentAMPState() {
-    return paddle::imperative::GetCurrentAMPState();
+  const std::shared_ptr<paddle::imperative::AmpAttrs>& GetCurrentAmpAttrs() {
+    return paddle::imperative::GetCurrentAmpAttrs();
  }

   const std::unordered_map<std::string, std::vector<paddle::OpMetaInfo>>&

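For context, a minimal sketch of how the renamed accessor is reached through the Controller singleton. The helper LogAmpConfig is hypothetical and not part of this PR; GetCurrentAmpAttrs, GetAmpLevel, and GetAmpDtype are the APIs shown in the hunks above and below.

    // Hedged sketch, not part of this PR: read the thread-local AMP
    // attributes via the Controller singleton after the rename.
    #include "paddle/fluid/eager/api/utils/global_utils.h"

    void LogAmpConfig() {  // hypothetical helper, for illustration only
      const auto& amp_attrs = egr::Controller::Instance().GetCurrentAmpAttrs();
      VLOG(4) << "amp_level: "
              << static_cast<unsigned int>(amp_attrs->GetAmpLevel())
              << ", amp_dtype: " << amp_attrs->GetAmpDtype();
    }
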
2 changes: 1 addition & 1 deletion paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -1763,7 +1763,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
   const char* CALL_BACK_TEMPLATE =
       " {\n"
       " paddle::imperative::AutoCastGuard "
-      "guard(egr::Controller::Instance().GetCurrentAMPState(), "
+      "guard(egr::Controller::Instance().GetCurrentAmpAttrs(), "
       "paddle::imperative::AmpLevel::O0);\n"
       " return %s_dygraph_function(%s);\n"
       " }";

@@ -520,7 +520,7 @@ class {} : public egr::GradNodeBase {{
 {}
 {}
 {{
-  paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentAMPState(), paddle::imperative::AmpLevel::O0);
+  paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentAmpAttrs(), paddle::imperative::AmpLevel::O0);
 {}
 }}
 }}

28 changes: 14 additions & 14 deletions paddle/fluid/imperative/amp_auto_cast.cc
@@ -119,7 +119,7 @@ OpSupportedInfos(const std::string& place,
       std::move(all_ops), std::move(supported_ops), std::move(unsupported_ops));
 }

-AutoCastGuard::AutoCastGuard(std::shared_ptr<AMPState> state, AmpLevel level)
+AutoCastGuard::AutoCastGuard(std::shared_ptr<AmpAttrs> state, AmpLevel level)
     : state_(state) {
   pre_amp_level_ = state_->GetAmpLevel();

@@ -231,25 +231,25 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
   return os;
 }

-thread_local bool AMPState::use_promote_ = false;
+thread_local bool AmpAttrs::use_promote_ = false;

-thread_local AmpLevel AMPState::amp_level_ = AmpLevel::O0;
+thread_local AmpLevel AmpAttrs::amp_level_ = AmpLevel::O0;

-thread_local phi::DataType AMPState::amp_dtype_ = phi::DataType::FLOAT32;
+thread_local phi::DataType AmpAttrs::amp_dtype_ = phi::DataType::FLOAT32;

-AMPState::AMPState() {}
+AmpAttrs::AmpAttrs() {}

-AMPState::~AMPState() = default;
+AmpAttrs::~AmpAttrs() = default;

-bool AMPState::GetUsePromote() const { return use_promote_; }
+bool AmpAttrs::GetUsePromote() const { return use_promote_; }

-void AMPState::SetUsePromote(bool use_promote) { use_promote_ = use_promote; }
+void AmpAttrs::SetUsePromote(bool use_promote) { use_promote_ = use_promote; }

-AmpLevel AMPState::GetAmpLevel() const { return amp_level_; }
+AmpLevel AmpAttrs::GetAmpLevel() const { return amp_level_; }

-void AMPState::SetAmpLevel(AmpLevel level) { amp_level_ = level; }
+void AmpAttrs::SetAmpLevel(AmpLevel level) { amp_level_ = level; }

-std::string AMPState::GetAmpDtype() const {
+std::string AmpAttrs::GetAmpDtype() const {
   if (amp_dtype_ == phi::DataType::FLOAT16) {
     return std::string("float16");
   } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
@@ -259,7 +259,7 @@ std::string AMPState::GetAmpDtype() const {
   }
 }

-void AMPState::SetAmpDtype(std::string amp_dtype) {
+void AmpAttrs::SetAmpDtype(std::string amp_dtype) {
   if (amp_dtype == "float16") {
     amp_dtype_ = phi::DataType::FLOAT16;
   } else if (amp_dtype == "bfloat16") {
@@ -269,7 +269,7 @@ void AMPState::SetAmpDtype(std::string amp_dtype) {
   }
 }

-phi::DataType AMPState::GetAmpPhiDtype() const { return amp_dtype_; }
+phi::DataType AmpAttrs::GetAmpPhiDtype() const { return amp_dtype_; }

 template <typename VarType>
 inline std::string GetDtypeStr(const std::shared_ptr<VarType>& var) {
@@ -308,7 +308,7 @@ static inline std::shared_ptr<VarType> CastToType(
   imperative::NameVarMap<VarType> outs = {{"Out", {out}}};

   {
-    AutoCastGuard guard(imperative::GetCurrentAMPState(), AmpLevel::O0);
+    AutoCastGuard guard(imperative::GetCurrentAmpAttrs(), AmpLevel::O0);
     tracer->TraceOp("cast", ins, outs, std::move(attrs));
   }

10 changes: 5 additions & 5 deletions paddle/fluid/imperative/amp_auto_cast.h
@@ -84,10 +84,10 @@ class AmpOperators {

 std::ostream& operator<<(std::ostream& os, AmpOperators& ops);

-class AMPState {
+class AmpAttrs {
  public:
-  AMPState();
-  ~AMPState();
+  AmpAttrs();
+  ~AmpAttrs();
   bool GetUsePromote() const;
   void SetUsePromote(bool use_promote);
   AmpLevel GetAmpLevel() const;
@@ -105,7 +105,7 @@ class AMPState {
 // NOTE(zhiqiu): AutoCastGuard is used for RAII.
 class AutoCastGuard {
  public:
-  AutoCastGuard(std::shared_ptr<AMPState> state, AmpLevel guard_level);
+  AutoCastGuard(std::shared_ptr<AmpAttrs> state, AmpLevel guard_level);

   ~AutoCastGuard();

@@ -114,7 +114,7 @@ class AutoCastGuard {
   AutoCastGuard& operator=(const AutoCastGuard& guard) = delete;

  private:
-  std::shared_ptr<AMPState> state_;
+  std::shared_ptr<AmpAttrs> state_;
   AmpLevel pre_amp_level_;
 };

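A minimal usage sketch of the RAII idiom the NOTE above refers to, mirroring the call sites elsewhere in this diff. That the destructor restores the saved pre_amp_level_ is an assumption, since the destructor body lies outside the shown hunks.

    // Hedged sketch: force AMP off for a scoped region, e.g. while
    // tracing a nested call that must not be auto-cast again.
    {
      paddle::imperative::AutoCastGuard guard(
          paddle::imperative::GetCurrentAmpAttrs(),
          paddle::imperative::AmpLevel::O0);  // guard_level
      // ... trace ops here with auto-cast disabled ...
    }  // ~AutoCastGuard() presumably restores pre_amp_level_
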
34 changes: 17 additions & 17 deletions paddle/fluid/imperative/tracer.cc
@@ -52,8 +52,8 @@ thread_local bool Tracer::use_layout_autotune_ = false;

 static thread_local std::shared_ptr<Tracer> g_current_tracer(nullptr);

-static thread_local std::shared_ptr<AMPState> g_current_amp_state =
-    std::make_shared<AMPState>();
+static thread_local std::shared_ptr<AmpAttrs> g_current_amp_attrs =
+    std::make_shared<AmpAttrs>();

 TEST_API void Tracer::DisableLayoutAutoTune() { use_layout_autotune_ = false; }
 TEST_API void Tracer::EnableLayoutAutoTune() {
@@ -89,8 +89,8 @@ TEST_API void SetCurrentTracer(const std::shared_ptr<Tracer>& tracer) {
   VLOG(6) << "Set current tracer: " << g_current_tracer;
 }

-const std::shared_ptr<AMPState>& GetCurrentAMPState() {
-  return g_current_amp_state;
+const std::shared_ptr<AmpAttrs>& GetCurrentAmpAttrs() {
+  return g_current_amp_attrs;
 }

 void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) {
@@ -276,23 +276,23 @@ void Tracer::TraceOpImpl(const std::string& type,
           : attr_checker->GetDefaultAttrMap();

   std::unique_ptr<NameVarMap<VarType>> ins_amp = nullptr;
-  if (GetCurrentAMPState()->GetAmpLevel() == AmpLevel::O1) {
-    if (GetCurrentAMPState()->GetAmpPhiDtype() == phi::DataType::FLOAT16) {
+  if (GetCurrentAmpAttrs()->GetAmpLevel() == AmpLevel::O1) {
+    if (GetCurrentAmpAttrs()->GetAmpPhiDtype() == phi::DataType::FLOAT16) {
       VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
       ins_amp = std::make_unique<NameVarMap<VarType>>(
           AutoCastInputs<VarType>(type, ins));
-    } else if (GetCurrentAMPState()->GetAmpPhiDtype() ==
+    } else if (GetCurrentAmpAttrs()->GetAmpPhiDtype() ==
                phi::DataType::BFLOAT16) {
       VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
       ins_amp = std::make_unique<NameVarMap<VarType>>(
           AutoCastBF16Inputs<VarType>(type, ins));
     }
-  } else if (GetCurrentAMPState()->GetAmpLevel() == AmpLevel::O2) {
-    if (GetCurrentAMPState()->GetAmpPhiDtype() == phi::DataType::FLOAT16) {
+  } else if (GetCurrentAmpAttrs()->GetAmpLevel() == AmpLevel::O2) {
+    if (GetCurrentAmpAttrs()->GetAmpPhiDtype() == phi::DataType::FLOAT16) {
       VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
       ins_amp = std::make_unique<NameVarMap<VarType>>(
           CastPureFp16Inputs<VarType>(type, ins));
-    } else if (GetCurrentAMPState()->GetAmpPhiDtype() ==
+    } else if (GetCurrentAmpAttrs()->GetAmpPhiDtype() ==
                phi::DataType::BFLOAT16) {
       VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
       ins_amp = std::make_unique<NameVarMap<VarType>>(
@@ -560,20 +560,20 @@ TEST_API void Tracer::SetHasGrad(bool has_grad) { has_grad_ = has_grad; }

 TEST_API void Tracer::SetUsePromote(bool use_promote) {
   VLOG(4) << "set use_promote to " << use_promote;
-  g_current_amp_state->SetUsePromote(use_promote);
+  g_current_amp_attrs->SetUsePromote(use_promote);
 }

 TEST_API bool Tracer::GetUsePromote() const {
-  return g_current_amp_state->GetUsePromote();
+  return g_current_amp_attrs->GetUsePromote();
 }

 TEST_API void Tracer::SetAmpLevel(AmpLevel level) {
   VLOG(4) << "set amp_level to " << static_cast<unsigned int>(level);
-  g_current_amp_state->SetAmpLevel(level);
+  g_current_amp_attrs->SetAmpLevel(level);
 }

 TEST_API AmpLevel Tracer::GetAmpLevel() const {
-  return g_current_amp_state->GetAmpLevel();
+  return g_current_amp_attrs->GetAmpLevel();
 }

 bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
@@ -604,15 +604,15 @@ bool Tracer::IsProgramDescTracingEnabled() const {

 void Tracer::SetAmpDtype(std::string amp_dtype) {
   VLOG(4) << "set amp_dtype to " << amp_dtype;
-  g_current_amp_state->SetAmpDtype(amp_dtype);
+  g_current_amp_attrs->SetAmpDtype(amp_dtype);
 }

 std::string Tracer::GetAmpDtype() const {
-  return g_current_amp_state->GetAmpDtype();
+  return g_current_amp_attrs->GetAmpDtype();
 }

 phi::DataType Tracer::GetAmpPhiDtype() const {
-  return g_current_amp_state->GetAmpPhiDtype();
+  return g_current_amp_attrs->GetAmpPhiDtype();
 }

 bool Tracer::ComputeRequiredGrad(const NameTensorMap& ins,

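Because g_current_amp_attrs is thread_local, AMP settings made through a Tracer are per-thread. A small sketch under that assumption; ConfigureAmpOnThisThread and the assert are illustrative, not from this PR.

    // Hedged sketch: the Tracer's AMP setters delegate to the
    // thread-local AmpAttrs, so one thread's settings do not leak
    // into another thread.
    #include <cassert>
    #include "paddle/fluid/imperative/tracer.h"

    void ConfigureAmpOnThisThread(paddle::imperative::Tracer* tracer) {
      tracer->SetAmpLevel(paddle::imperative::AmpLevel::O1);
      tracer->SetAmpDtype("float16");  // stored as phi::DataType::FLOAT16
      // The free-function accessor observes the same thread-local object.
      assert(paddle::imperative::GetCurrentAmpAttrs()->GetAmpLevel() ==
             paddle::imperative::AmpLevel::O1);
    }
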
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/tracer.h
@@ -200,7 +200,7 @@ class Tracer {
 // To access static variable current_tracer
 const std::shared_ptr<Tracer>& GetCurrentTracer();
 TEST_API void SetCurrentTracer(const std::shared_ptr<Tracer>& tracer_);
-const std::shared_ptr<AMPState>& GetCurrentAMPState();
+const std::shared_ptr<AmpAttrs>& GetCurrentAmpAttrs();
 void IncreaseVarbaseReferenceCountUntilCopyComplete(
     const std::shared_ptr<imperative::VarBase>& var,
     const platform::Place& place);

4 changes: 2 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/api_gen.py
@@ -99,15 +99,15 @@


 AMP_LOGIC_TEMPLATE = """
-  if (egr::Controller::Instance().GetCurrentAMPState()->GetAmpLevel() != paddle::imperative::AmpLevel::O0){{
+  if (egr::Controller::Instance().GetCurrentAmpAttrs()->GetAmpLevel() != paddle::imperative::AmpLevel::O0){{
     VLOG(5) << "Check and Prepare For AMP";
     auto op_name = phi::TransToFluidOpName("{op_name}");
     std::vector<std::vector<pir::Value>> amp_values_vector = {{ {no_optional_inputs} }};
     {optional_inputs}
     auto amp_dst_dtype = paddle::dialect::GetAmpDestDtype("{op_name}", amp_values_vector);
     {new_inputs}
     {{
-      paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentAMPState(), paddle::imperative::AmpLevel::O0);
+      paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentAmpAttrs(), paddle::imperative::AmpLevel::O0);
       return paddle::dialect::{op_name}({args});
     }}
   }}

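For illustration, roughly what AMP_LOGIC_TEMPLATE expands to for a hypothetical single-input op my_op. The input value x and the elided {optional_inputs}/{new_inputs} expansions are assumptions; everything else follows the template verbatim.

    // Hedged sketch of generated code for a hypothetical op "my_op".
    if (egr::Controller::Instance().GetCurrentAmpAttrs()->GetAmpLevel() !=
        paddle::imperative::AmpLevel::O0) {
      VLOG(5) << "Check and Prepare For AMP";
      auto op_name = phi::TransToFluidOpName("my_op");
      std::vector<std::vector<pir::Value>> amp_values_vector = {{x}};
      auto amp_dst_dtype =
          paddle::dialect::GetAmpDestDtype("my_op", amp_values_vector);
      // {new_inputs} would expand to op-specific casts of each input to
      // amp_dst_dtype; elided here because it is generator-specific.
      {
        paddle::imperative::AutoCastGuard guard(
            egr::Controller::Instance().GetCurrentAmpAttrs(),
            paddle::imperative::AmpLevel::O0);
        return paddle::dialect::my_op(/* cast of x, per {new_inputs} */ x);
      }
    }
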
8 changes: 4 additions & 4 deletions paddle/fluid/pir/dialect/operator/utils/amp_utils.cc
@@ -36,7 +36,7 @@ phi::DataType GetPromoteType(
     return dst_type;
   }

-  if (egr::Controller::Instance().GetCurrentAMPState()->GetAmpDtype() ==
+  if (egr::Controller::Instance().GetCurrentAmpAttrs()->GetAmpDtype() ==
       "float16") {
     if (op_name == "fused_attention") {
       for (size_t i = 0; i < amp_values_vector.size(); i++) {
@@ -75,7 +75,7 @@ phi::DataType GetPromoteType(

 pir::Value Cast(const pir::Value& input, const phi::DataType& dst_dtype) {
   paddle::imperative::AutoCastGuard guard(
-      egr::Controller::Instance().GetCurrentAMPState(),
+      egr::Controller::Instance().GetCurrentAmpAttrs(),
       paddle::imperative::AmpLevel::O0);
   return paddle::dialect::cast(input, dst_dtype);
 }
@@ -168,13 +168,13 @@ phi::DataType GetAmpDestDtype(
     const std::vector<std::vector<pir::Value>>& amp_values_vector) {
   auto amp_level = egr::Controller::Instance().GetAMPLevel();
   auto amp_setting_dtype =
-      egr::Controller::Instance().GetCurrentAMPState()->GetAmpPhiDtype();
+      egr::Controller::Instance().GetCurrentAmpAttrs()->GetAmpPhiDtype();
   auto dst_type = amp_setting_dtype;

   bool use_promote = true;
   if (amp_level == paddle::imperative::AmpLevel::O2) {
     use_promote =
-        egr::Controller::Instance().GetCurrentAMPState()->GetUsePromote();
+        egr::Controller::Instance().GetCurrentAmpAttrs()->GetUsePromote();
   }

   if (use_promote) {