diff --git a/python/paddle/distribution/dirichlet.py b/python/paddle/distribution/dirichlet.py
index cf578c9d0dd5c9..e3a2ec562a2b1c 100644
--- a/python/paddle/distribution/dirichlet.py
+++ b/python/paddle/distribution/dirichlet.py
@@ -14,9 +14,9 @@
 
 import paddle
 from paddle.base.data_feeder import check_variable_and_dtype
+from paddle.base.framework import in_dynamic_or_pir_mode
 from paddle.base.layer_helper import LayerHelper
 from paddle.distribution import exponential_family
-from paddle.framework import in_dynamic_mode
 
 
 class Dirichlet(exponential_family.ExponentialFamily):
@@ -156,7 +156,7 @@ def _log_normalizer(self, x):
 
 
 def _dirichlet(concentration, name=None):
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return paddle._C_ops.dirichlet(concentration)
     else:
         op_type = 'dirichlet'
diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 67407fb8cd380e..8cd3f46e7124fb 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -15,6 +15,7 @@
 import paddle
 import paddle.nn.functional as F
 from paddle import _C_ops, in_dynamic_mode
+from paddle.base.framework import in_dynamic_or_pir_mode
 from paddle.base.layer_helper import LayerHelper
 from paddle.base.wrapped_decorator import signature_safe_contextmanager
 
@@ -225,7 +226,7 @@ def flash_attention(
     sdp_func_name = _select_sdp(head_dim)
 
     if sdp_func_name == "flash_attn":
-        if in_dynamic_mode():
+        if in_dynamic_or_pir_mode():
             (result_attention, result_softmax, _, _) = _C_ops.flash_attn(
                 query,
                 key,
diff --git a/test/distribution/test_distribution_dirichlet_static.py b/test/distribution/test_distribution_dirichlet_static.py
index a4ba2249ee707a..a66d7f8afaaf16 100644
--- a/test/distribution/test_distribution_dirichlet_static.py
+++ b/test/distribution/test_distribution_dirichlet_static.py
@@ -120,3 +120,110 @@ def test_entropy(self):
             rtol=RTOL.get(str(self.concentration.dtype)),
             atol=ATOL.get(str(self.concentration.dtype)),
         )
+
+
+@place(DEVICES)
+@parameterize_cls(
+    (TEST_CASE_NAME, 'concentration'),
+    [('test-one-dim', np.random.rand(89) + 5.0)],
+)
+class TestDirichletPir(unittest.TestCase):
+    def setUp(self):
+        with paddle.pir_utils.IrGuard():
+            self.program = paddle.static.Program()
+            self.executor = paddle.static.Executor()
+            with paddle.static.program_guard(self.program):
+                conc = paddle.static.data(
+                    'conc', self.concentration.shape, self.concentration.dtype
+                )
+                self._paddle_diric = paddle.distribution.Dirichlet(conc)
+                self.feeds = {'conc': self.concentration}
+
+    def test_mean(self):
+        with paddle.pir_utils.IrGuard():
+            with paddle.static.program_guard(self.program):
+                [out] = self.executor.run(
+                    self.program,
+                    feed=self.feeds,
+                    fetch_list=[self._paddle_diric.mean],
+                )
+                np.testing.assert_allclose(
+                    out,
+                    scipy.stats.dirichlet.mean(self.concentration),
+                    rtol=RTOL.get(str(self.concentration.dtype)),
+                    atol=ATOL.get(str(self.concentration.dtype)),
+                )
+
+    def test_variance(self):
+        with paddle.pir_utils.IrGuard():
+            with paddle.static.program_guard(self.program):
+                [out] = self.executor.run(
+                    self.program,
+                    feed=self.feeds,
+                    fetch_list=[self._paddle_diric.variance],
+                )
+                np.testing.assert_allclose(
+                    out,
+                    scipy.stats.dirichlet.var(self.concentration),
+                    rtol=RTOL.get(str(self.concentration.dtype)),
+                    atol=ATOL.get(str(self.concentration.dtype)),
+                )
+
+    def test_prob(self):
+        with paddle.pir_utils.IrGuard():
+            with paddle.static.program_guard(self.program):
+                random_number = np.random.rand(*self.concentration.shape)
+                random_number = random_number / random_number.sum()
+                feeds = dict(self.feeds, value=random_number)
+                value = paddle.static.data(
+                    'value', random_number.shape, random_number.dtype
+                )
+                out = self._paddle_diric.prob(value)
+                [out] = self.executor.run(
+                    self.program, feed=feeds, fetch_list=[out]
+                )
+                np.testing.assert_allclose(
+                    out,
+                    scipy.stats.dirichlet.pdf(
+                        random_number, self.concentration
+                    ),
+                    rtol=RTOL.get(str(self.concentration.dtype)),
+                    atol=ATOL.get(str(self.concentration.dtype)),
+                )
+
+    def test_log_prob(self):
+        with paddle.pir_utils.IrGuard():
+            with paddle.static.program_guard(self.program):
+                random_number = np.random.rand(*self.concentration.shape)
+                random_number = random_number / random_number.sum()
+                feeds = dict(self.feeds, value=random_number)
+                value = paddle.static.data(
+                    'value', random_number.shape, random_number.dtype
+                )
+                out = self._paddle_diric.log_prob(value)
+                [out] = self.executor.run(
+                    self.program, feed=feeds, fetch_list=[out]
+                )
+                np.testing.assert_allclose(
+                    out,
+                    scipy.stats.dirichlet.logpdf(
+                        random_number, self.concentration
+                    ),
+                    rtol=RTOL.get(str(self.concentration.dtype)),
+                    atol=ATOL.get(str(self.concentration.dtype)),
+                )
+
+    def test_entropy(self):
+        with paddle.pir_utils.IrGuard():
+            with paddle.static.program_guard(self.program):
+                [out] = self.executor.run(
+                    self.program,
+                    feed=self.feeds,
+                    fetch_list=[self._paddle_diric.entropy()],
+                )
+                np.testing.assert_allclose(
+                    out,
+                    scipy.stats.dirichlet.entropy(self.concentration),
+                    rtol=RTOL.get(str(self.concentration.dtype)),
+                    atol=ATOL.get(str(self.concentration.dtype)),
+                )
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 6b5c24b7101aed..d06bb133437509 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -3209,7 +3209,7 @@ def init_shape(self):
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        self.check_grad(['X'], 'Out')
+        self.check_grad(['X'], 'Out', check_pir=True)
 
     def get_alpha(self):
         return 1.0
@@ -3240,6 +3240,7 @@ def setUp(self):
     def executed_api(self):
         self.elu = F.elu
 
+    @test_with_pir_api
     def test_static_api(self):
         with static_guard():
             with paddle.static.program_guard(paddle.static.Program()):
@@ -4975,7 +4976,7 @@ def test_check_grad(self):
 create_test_act_fp16_class(TestBRelu, check_pir=True)
 create_test_act_fp16_class(TestRelu6)
 create_test_act_fp16_class(TestSoftRelu, check_dygraph=False)
-create_test_act_fp16_class(TestELU)
+create_test_act_fp16_class(TestELU, check_pir=True)
 create_test_act_fp16_class(TestCELU, check_pir=True)
 create_test_act_fp16_class(TestReciprocal, check_pir=True)
 create_test_act_fp16_class(TestLog, check_prim=True, check_pir=True)
@@ -5144,7 +5145,7 @@ def test_check_grad(self):
 create_test_act_bf16_class(TestBRelu, check_pir=True)
 create_test_act_bf16_class(TestRelu6)
 create_test_act_bf16_class(TestSoftRelu, check_dygraph=False)
-create_test_act_bf16_class(TestELU)
+create_test_act_bf16_class(TestELU, check_pir=True)
 create_test_act_bf16_class(TestCELU, check_pir=True)
 create_test_act_bf16_class(TestReciprocal, check_pir=True)
 create_test_act_bf16_class(TestLog, check_prim=True, check_pir=True)
diff --git a/test/legacy_test/test_eigvals_op.py b/test/legacy_test/test_eigvals_op.py
index c54a4070be3a44..6e28cc736d418f 100644
--- a/test/legacy_test/test_eigvals_op.py
+++ b/test/legacy_test/test_eigvals_op.py
@@ -19,6 +19,7 @@
 
 import paddle
 from paddle.base import core
+from paddle.pir_utils import test_with_pir_api
 
 np.set_printoptions(threshold=np.inf)
 
@@ -271,6 +272,7 @@ def run_dygraph(self, place):
         np_outs = np_eigvals(self.batch_input)
         self.verify_output(paddle_outs, np_outs)
 
+    @test_with_pir_api
     def run_static(self, place):
         paddle.enable_static()
         with paddle.static.program_guard(
diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
index b74d4a87af20e2..f9ce2c3fb8213b 100644
--- a/test/legacy_test/test_flash_attention.py
+++ b/test/legacy_test/test_flash_attention.py
@@ -28,6 +28,7 @@
     flash_attn_unpadded,
     scaled_dot_product_attention,
 )
+from paddle.pir_utils import test_with_pir_api
 
 logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
 
@@ -198,15 +199,10 @@ def test_unpadded(self):
                 fetch_list=[outs],
             )
 
-            np.testing.assert_allclose(
-                fetches_result[0], out_, rtol=5e-03, atol=1e-03
-            )
-
-    def test_all(self):
+    def test_dynamic_all(self):
         print(
-            f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}"
+            f"Test dynamic case shape {self.shape} dtype {self.dtype} causal {self.causal}"
         )
-        # test dynamic
         paddle.disable_static()
 
         query = np.random.random(self.shape)
@@ -266,9 +262,31 @@ def test_all(self):
             q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
         )
 
-        # test static
+    @test_with_pir_api
+    def test_static_all(self):
+        print(
+            f"Test static case shape {self.shape} dtype {self.dtype} causal {self.causal}"
+        )
         paddle.enable_static()
 
+        query = np.random.random(self.shape)
+        key = np.random.random(self.shape)
+        value = np.random.random(self.shape)
+
+        q_ = paddle.to_tensor(
+            query, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+        k_ = paddle.to_tensor(
+            key, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+        v_ = paddle.to_tensor(
+            value, place=self.place, dtype=self.dtype, stop_gradient=False
+        )
+
+        out_ = attention_naive(q_, k_, v_, self.causal)
+
+        out_.backward()
+
         with paddle.static.program_guard(paddle.static.Program()):
             qs = paddle.static.data(
                 name="q", shape=self.shape, dtype=self.dtype