Skip to content
Merged
10 changes: 8 additions & 2 deletions paddle/fluid/prim/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
#include "paddle/fluid/prim/utils/static/static_global_utils.h"

PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not");
PADDLE_DEFINE_EXPORTED_bool(prim_all, false, "enable prim_all or not");
PADDLE_DEFINE_EXPORTED_bool(prim_forward, false, "enable prim_forward or not");
PADDLE_DEFINE_EXPORTED_bool(prim_backward, false, "enable prim_backward not");

namespace paddle {
namespace prim {
bool PrimCommonUtils::IsBwdPrimEnabled() {
return StaticCompositeContext::Instance().IsBwdPrimEnabled();
bool res = StaticCompositeContext::Instance().IsBwdPrimEnabled();
return res || FLAGS_prim_all || FLAGS_prim_backward;
}

void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
Expand All @@ -36,7 +41,8 @@ void PrimCommonUtils::SetEagerPrimEnabled(bool enable_prim) {
}

bool PrimCommonUtils::IsFwdPrimEnabled() {
return StaticCompositeContext::Instance().IsFwdPrimEnabled();
bool res = StaticCompositeContext::Instance().IsFwdPrimEnabled();
return res || FLAGS_prim_all || FLAGS_prim_forward;
}

void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) {
Expand Down