Skip to content
4 changes: 2 additions & 2 deletions python/paddle/distribution/dirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

import paddle
from paddle.base.data_feeder import check_variable_and_dtype
from paddle.base.framework import in_dynamic_or_pir_mode
from paddle.base.layer_helper import LayerHelper
from paddle.distribution import exponential_family
from paddle.framework import in_dynamic_mode


class Dirichlet(exponential_family.ExponentialFamily):
Expand Down Expand Up @@ -156,7 +156,7 @@ def _log_normalizer(self, x):


def _dirichlet(concentration, name=None):
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return paddle._C_ops.dirichlet(concentration)
else:
op_type = 'dirichlet'
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/nn/functional/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import paddle
import paddle.nn.functional as F
from paddle import _C_ops, in_dynamic_mode
from paddle.base.framework import in_dynamic_or_pir_mode
from paddle.base.layer_helper import LayerHelper
from paddle.base.wrapped_decorator import signature_safe_contextmanager

Expand Down Expand Up @@ -225,7 +226,7 @@ def flash_attention(
sdp_func_name = _select_sdp(head_dim)

if sdp_func_name == "flash_attn":
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
(result_attention, result_softmax, _, _) = _C_ops.flash_attn(
query,
key,
Expand Down
3 changes: 3 additions & 0 deletions test/distribution/test_dirichlet_op.py

This comment was marked as off-topic.

This comment was marked as resolved.

Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def setUp(self):
self.sample_shape = (10000, 2)
self.dtype = np.uint16
self.np_dtype = np.float32
self.python_api = paddle.distribution.Dirichlet(self.alpha).sample(
self.sample_shape
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.python_api = paddle.distribution.Dirichlet(self.alpha).sample(
self.sample_shape
)

先复原对 test/distribution/test_dirichlet_op.py 文件的改动,并在 pr 描述里说明一下未适配该单测~


self.inputs = {
'Alpha': np.broadcast_to(self.alpha, self.sample_shape).astype(
Expand Down
111 changes: 111 additions & 0 deletions test/distribution/test_distribution_dirichlet_static.py

This comment was marked as resolved.

Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import scipy.stats
from distribution.config import ATOL, DEVICES, RTOL
from parameterize import TEST_CASE_NAME, parameterize_cls, place
from paddle.pir_utils import test_with_pir_api

import paddle

Expand Down Expand Up @@ -120,3 +121,113 @@ def test_entropy(self):
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)),
)


@place(DEVICES)
@parameterize_cls(
(TEST_CASE_NAME, 'concentration'),
[('test-one-dim', np.random.rand(89) + 5.0)],
)
class TestDirichletPir(unittest.TestCase):
def setUp(self):
with paddle.pir_utils.IrGuard():
self.program = paddle.static.Program()
self.executor = paddle.static.Executor()
with paddle.static.program_guard(self.program):
conc = paddle.static.data(
'conc', self.concentration.shape, self.concentration.dtype
)
self._paddle_diric = paddle.distribution.Dirichlet(conc)
self.feeds = {'conc': self.concentration}

@test_with_pir_api
def test_mean(self):
with paddle.static.program_guard(self.program):
[out] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_diric.mean],
)
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.mean(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@test_with_pir_api
def test_mean(self):
with paddle.static.program_guard(self.program):
[out] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_diric.mean],
)
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.mean(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)),
)
def test_mean(self):
with paddle.pir_utils.IrGuard():
with paddle.static.program_guard(self.program):
[out] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_diric.mean],
)
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.mean(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)),
)

此处不宜使用 @test_with_pir_api,被 @test_with_pir_api 修饰的单测会在旧 ir 模式和 pir 模式下分别运行一次。因为该单测为 pir 单测,所以我们只需要在 Pir 模式下运行就好了。使用 paddle.pir_utils.IrGuard() 作为上下文管理器可以切换到 pir 模式进行组网


@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,需要修改

def test_variance(self):
with paddle.static.program_guard(self.program):
[out] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_diric.variance],
)
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.var(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)),
)

@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,需要修改

def test_prob(self):
with paddle.static.program_guard(self.program):
random_number = np.random.rand(*self.concentration.shape)
random_number = random_number / random_number.sum()
feeds = dict(self.feeds, value=random_number)
value = paddle.static.data(
'value', random_number.shape, random_number.dtype
)
out = self._paddle_diric.prob(value)
[out] = self.executor.run(
self.program, feed=feeds, fetch_list=[out]
)
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.pdf(random_number, self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)),
)

@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,需要修改

def test_log_prob(self):
with paddle.static.program_guard(self.program):
random_number = np.random.rand(*self.concentration.shape)
random_number = random_number / random_number.sum()
feeds = dict(self.feeds, value=random_number)
value = paddle.static.data(
'value', random_number.shape, random_number.dtype
)
out = self._paddle_diric.log_prob(value)
[out] = self.executor.run(
self.program, feed=feeds, fetch_list=[out]
)
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.logpdf(random_number, self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)),
)

@test_with_pir_api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,需要修改

def test_entropy(self):
with paddle.static.program_guard(self.program):
[out] = self.executor.run(
self.program,
feed=self.feeds,
fetch_list=[self._paddle_diric.entropy()],
)
np.testing.assert_allclose(
out,
scipy.stats.dirichlet.entropy(self.concentration),
rtol=RTOL.get(str(self.concentration.dtype)),
atol=ATOL.get(str(self.concentration.dtype)),
)

def test_all(self):
self._test_mean(self.program)
self._test_variance(self.program)
self._test_prob(self.program)
self._test_log_prob(self.program)
self._test_entropy(self.program)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_all(self):
self._test_mean(self.program)
self._test_variance(self.program)
self._test_prob(self.program)
self._test_log_prob(self.program)
self._test_entropy(self.program)

该单测可以删除了,因为当前不存在 _test_mean 等函数

7 changes: 4 additions & 3 deletions test/legacy_test/test_activation_op.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该文件下还有 TestELU 单测,以及被 create_test_act_bf16_class 和 create_test_act_fp16_class 创建的 bf16 和 fp16 的 TestELU 单测遗漏了~ 麻烦一起适配一下吧

Original file line number Diff line number Diff line change
Expand Up @@ -3187,7 +3187,7 @@ def init_shape(self):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)

def get_alpha(self):
return 1.0
Expand Down Expand Up @@ -3218,6 +3218,7 @@ def setUp(self):
def executed_api(self):
self.elu = F.elu

@test_with_pir_api
def test_static_api(self):
with static_guard():
with paddle.static.program_guard(paddle.static.Program()):
Expand Down Expand Up @@ -4896,7 +4897,7 @@ def test_check_grad(self):
create_test_act_fp16_class(TestBRelu, check_pir=True)
create_test_act_fp16_class(TestRelu6)
create_test_act_fp16_class(TestSoftRelu, check_dygraph=False)
create_test_act_fp16_class(TestELU)
create_test_act_fp16_class(TestELU, check_pir=True)
create_test_act_fp16_class(TestCELU)
create_test_act_fp16_class(TestReciprocal, check_pir=True)
create_test_act_fp16_class(TestLog, check_prim=True, check_pir=True)
Expand Down Expand Up @@ -5055,7 +5056,7 @@ def test_check_grad(self):
create_test_act_bf16_class(TestBRelu, check_pir=True)
create_test_act_bf16_class(TestRelu6)
create_test_act_bf16_class(TestSoftRelu, check_dygraph=False)
create_test_act_bf16_class(TestELU)
create_test_act_bf16_class(TestELU, check_pir=True)
create_test_act_bf16_class(TestCELU)
create_test_act_bf16_class(TestReciprocal, check_pir=True)
create_test_act_bf16_class(TestLog, check_prim=True, check_pir=True)
Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_eigvals_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api

np.set_printoptions(threshold=np.inf)

Expand Down Expand Up @@ -271,6 +272,7 @@ def run_dygraph(self, place):
np_outs = np_eigvals(self.batch_input)
self.verify_output(paddle_outs, np_outs)

@test_with_pir_api
def run_static(self, place):
paddle.enable_static()
with paddle.static.program_guard(
Expand Down
34 changes: 26 additions & 8 deletions test/legacy_test/test_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
flash_attn_unpadded,
scaled_dot_product_attention,
)
from paddle.pir_utils import test_with_pir_api

logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")

Expand Down Expand Up @@ -198,15 +199,10 @@ def test_unpadded(self):
fetch_list=[outs],
)

np.testing.assert_allclose(
fetches_result[0], out_, rtol=5e-03, atol=1e-03
)

def test_all(self):
def test_dynamic_all(self):
print(
f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}"
f"Test dynamic case shape {self.shape} dtype {self.dtype} causal {self.causal}"
)
# test dynamic
paddle.disable_static()

query = np.random.random(self.shape)
Expand Down Expand Up @@ -266,9 +262,31 @@ def test_all(self):
q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
)

# test static
@test_with_pir_api
def test_static_all(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要使用 @test_with_pir_api 装饰

print(
f"Test static case shape {self.shape} dtype {self.dtype} causal {self.causal}"
)
paddle.enable_static()

query = np.random.random(self.shape)
key = np.random.random(self.shape)
value = np.random.random(self.shape)

q_ = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
k_ = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
v_ = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)

out_ = attention_naive(q_, k_, v_, self.causal)

out_.backward()

with paddle.static.program_guard(paddle.static.Program()):
qs = paddle.static.data(
name="q", shape=self.shape, dtype=self.dtype
Expand Down