Skip to content

Commit 34aaff2

Browse files
committed
add decomp guard
1 parent beba53f commit 34aaff2

3 files changed

Lines changed: 36 additions & 12 deletions

File tree

python/paddle/base/core.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,17 @@ def _get_batch_norm_none_var(op):
518518

519519
# This api is used for development for dynamic shape in prim, and will be removed in future.
520520
def _enable_prim_dynamic_shape():
521-
if os.getenv("FLAGS_prim_skip_dynamic") == "1":
521+
flag = os.getenv("FLAGS_prim_skip_dynamic")
522+
if flag and flag.lower() in ("1", "true"):
523+
return True
524+
else:
525+
return False
526+
527+
528+
# This api is used for development for sinking decomp in c++, and will be removed in future.
529+
def _enable_sink_decomp():
530+
flag = os.getenv("FLAGS_sink_decomp")
531+
if flag and flag.lower() in ("1", "true"):
522532
return True
523533
else:
524534
return False

python/paddle/decomposition/decomp.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import logging
16+
import os
1617
import typing
1718
import warnings
1819

@@ -24,11 +25,25 @@
2425
has_decomp,
2526
)
2627
from paddle.base.libpaddle.pir import Block, Operation, Program
28+
from paddle.base.wrapped_decorator import signature_safe_contextmanager
2729
from paddle.framework import core
2830

2931
from . import register
3032

3133

34+
# For sinking decomp in c++. In future, sinking decomp will be implemented in c++ by default and then this api will be removed.
35+
@signature_safe_contextmanager
36+
def sink_decomp_guard():
37+
sink_decomp = core._enable_sink_decomp()
38+
try:
39+
if not sink_decomp:
40+
os.environ['FLAGS_sink_decomp'] = 'true'
41+
yield
42+
finally:
43+
if not sink_decomp:
44+
os.environ['FLAGS_sink_decomp'] = 'false'
45+
46+
3247
def _build_tensor_tuple(xs):
3348
if isinstance(xs, pir.Value):
3449
return (xs,)
@@ -176,16 +191,6 @@ def decompose(
176191
src_vars,
177192
blacklist=frozenset(),
178193
whitelist=frozenset(),
179-
):
180-
blacklist = core.prim_config["forward_blacklist"] | blacklist
181-
return core.decomp_tmp(program, src_vars, blacklist, whitelist)
182-
183-
184-
def decompose_(
185-
program,
186-
src_vars,
187-
blacklist=frozenset(),
188-
whitelist=frozenset(),
189194
):
190195
"""
191196
Search nonbasic ops which have be registered composite rules and replace them with primitive ops.
@@ -208,6 +213,9 @@ def decompose_(
208213
Returns:
209214
dst_vars (list): A list contains all vars which replace origin ones in src_vars.
210215
"""
216+
if core._enable_sink_decomp():
217+
blacklist = core.prim_config["forward_blacklist"] | blacklist
218+
return core.decomp_tmp(program, src_vars, blacklist, whitelist)
211219
if not core._is_fwd_prim_enabled():
212220
return src_vars
213221
if not isinstance(program, Program):

test/prim/pir_prim/test_prim_dynamic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,18 @@ def base_net(self, flag=None):
6565
return outs
6666

6767
def test_prim_all_dynamic(self):
68-
# os.environ["FLAGS_prim_skip_dynamic"] = "1"
6968
res_ref = self.base_net()
7069
res = self.base_net("all")
7170
for ref, actual in zip(res_ref, res):
7271
np.testing.assert_allclose(ref, actual, rtol=1e-6)
7372

73+
def test_aprim_all_dynamic_sink(self):
74+
with decomp.sink_decomp_guard():
75+
res_ref = self.base_net()
76+
res = self.base_net("all")
77+
for ref, actual in zip(res_ref, res):
78+
np.testing.assert_allclose(ref, actual, rtol=1e-6)
79+
7480

7581
if __name__ == "__main__":
7682
unittest.main()

0 commit comments

Comments
 (0)