1313// limitations under the License.
1414
1515#include " paddle/fluid/prim/utils/utils.h"
16+ #include < sstream>
1617#include " paddle/fluid/platform/flags.h"
1718#include " paddle/fluid/prim/utils/static/static_global_utils.h"
1819
1920PADDLE_DEFINE_EXPORTED_bool (prim_enabled, false , " enable_prim or not" );
20- PADDLE_DEFINE_EXPORTED_string (prim_blacklist, " " , " prim ops blacklist" );
21+ PADDLE_DEFINE_EXPORTED_bool (prim_all, false , " enable prim_all or not" );
22+ PADDLE_DEFINE_EXPORTED_bool (prim_forward, false , " enable prim_forward or not" );
23+ PADDLE_DEFINE_EXPORTED_bool (prim_backward, false , " enable prim_backward not" );
2124
2225namespace paddle {
2326namespace prim {
24-
2527bool PrimCommonUtils::IsBwdPrimEnabled () {
26- return StaticCompositeContext::Instance ().IsBwdPrimEnabled ();
28+ bool res = StaticCompositeContext::Instance ().IsBwdPrimEnabled ();
29+ return res || FLAGS_prim_all || FLAGS_prim_backward;
2730}
2831
2932void PrimCommonUtils::SetBwdPrimEnabled (bool enable_prim) {
@@ -39,16 +42,15 @@ void PrimCommonUtils::SetEagerPrimEnabled(bool enable_prim) {
3942}
4043
4144bool PrimCommonUtils::IsFwdPrimEnabled () {
42- return StaticCompositeContext::Instance ().IsFwdPrimEnabled ();
45+ bool res = StaticCompositeContext::Instance ().IsFwdPrimEnabled ();
46+ return res || FLAGS_prim_all || FLAGS_prim_forward;
4347}
4448
4549void PrimCommonUtils::SetFwdPrimEnabled (bool enable_prim) {
46- VLOG (0 ) << " FLAGS_prim_enabled ====================== " << FLAGS_prim_enabled;
4750 StaticCompositeContext::Instance ().SetFwdPrimEnabled (enable_prim);
4851}
4952
5053void PrimCommonUtils::SetAllPrimEnabled (bool enable_prim) {
51- VLOG (0 ) << " FLAGS_prim_enabled ====================== " << FLAGS_prim_enabled;
5254 StaticCompositeContext::Instance ().SetAllPrimEnabled (enable_prim);
5355}
5456
0 commit comments