python/paddle/nn/functional/flash_attention.py (3 changes: 2 additions & 1 deletion)

@@ -15,6 +15,7 @@
import paddle
import paddle.nn.functional as F
from paddle import _C_ops, in_dynamic_mode
+ from paddle.base.framework import in_dynamic_or_pir_mode
from paddle.base.layer_helper import LayerHelper
from paddle.base.wrapped_decorator import signature_safe_contextmanager

@@ -221,7 +222,7 @@ def flash_attention(
sdp_func_name = _select_sdp(head_dim)

if sdp_func_name == "flash_attn":
- if in_dynamic_mode():
+ if in_dynamic_or_pir_mode():
(
result_attention,
result_softmax,
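Note: this change widens the fast-path guard so the C++ flash-attention kernel is used under the new PIR program representation as well as in dynamic (eager) mode. A minimal sketch of how such a mode guard behaves; the branch messages are illustrative, not Paddle's real control flow:

```python
import paddle
from paddle.base.framework import in_dynamic_or_pir_mode

def describe_dispatch():
    # True in eager mode, and also while building a PIR program; only the
    # legacy (old-IR) static graph takes the LayerHelper branch. Which branch
    # paddle.enable_static() lands in depends on the Paddle version and
    # whether the PIR API flag is on.
    if in_dynamic_or_pir_mode():
        return "dynamic/PIR: call the C++ op (_C_ops.flash_attn) directly"
    return "legacy static graph: build the op via LayerHelper"

print(describe_dispatch())  # eager mode -> "dynamic/PIR: ..."
paddle.enable_static()
print(describe_dispatch())
```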
test/distribution/test_dirichlet_op.py (15 changes: 12 additions & 3 deletions)

@@ -33,13 +33,16 @@ def setUp(self):
self.op_type = "dirichlet"
self.alpha = np.array((1.0, 2.0))
Contributor:

python_api needs to be a callable object whose invocation returns the result. For example, here self.python_api = paddle.distribution.Dirichlet would, when called, merely instantiate a Dirichlet object; whereas with self.python_api = paddle.distribution.Dirichlet(...).sample, calling self.python_api actually returns a result. The output of this unit test is compared against self._hypothesis_testing, so self.python_api needs to be set to match the kind of computation that self._hypothesis_testing checks.
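To make the distinction concrete, here is a small illustration of the two bindings being contrasted (the shapes and values are made up for the example):

```python
import paddle

alpha = paddle.to_tensor([1.0, 2.0])

# Stores a Tensor (the already-computed samples), not something callable:
# .sample(...) runs immediately, so there is nothing left for the test
# harness to invoke later.
already_a_result = paddle.distribution.Dirichlet(alpha).sample((100,))

# Stores the bound method itself; the harness can call it to get a result.
sample_api = paddle.distribution.Dirichlet(alpha).sample
result = sample_api((100,))  # shape [100, 2], one simplex point per row

print(callable(already_a_result), callable(sample_api))  # False True
```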

Contributor:

Let's skip adapting this unit test for now; you can add an explanation in the PR description. I haven't yet thought of a good way to add this self.python_api.

self.sample_shape = (100000, 2)
+ self.python_api = paddle.distribution.Dirichlet(self.alpha).sample(
+     self.sample_shape
+ )

self.inputs = {'Alpha': np.broadcast_to(self.alpha, self.sample_shape)}
self.attrs = {}
self.outputs = {'Out': np.zeros(self.sample_shape)}

def test_check_output(self):
- self.check_output_customized(self._hypothesis_testing)
+ self.check_output_customized(self._hypothesis_testing, check_pir=True)

def _hypothesis_testing(self, outs):
self.assertEqual(outs[0].shape, self.sample_shape)
@@ -63,6 +66,9 @@ def setUp(self):
self.alpha = np.array((1.0, 2.0))
self.sample_shape = (100000, 2)
self.dtype = np.float16
+ self.python_api = paddle.distribution.Dirichlet(self.alpha).sample(
+     self.sample_shape
+ )

self.inputs = {
'Alpha': np.broadcast_to(self.alpha, self.sample_shape).astype(
@@ -73,7 +79,7 @@
self.outputs = {'Out': np.zeros(self.sample_shape).astype(self.dtype)}

def test_check_output(self):
- self.check_output_customized(self._hypothesis_testing)
+ self.check_output_customized(self._hypothesis_testing, check_pir=True)

def _hypothesis_testing(self, outs):
self.assertEqual(outs[0].shape, self.sample_shape)
@@ -103,6 +109,9 @@ def setUp(self):
self.sample_shape = (10000, 2)
self.dtype = np.uint16
self.np_dtype = np.float32
+ self.python_api = paddle.distribution.Dirichlet(self.alpha).sample(
+     self.sample_shape
+ )
Contributor:

Suggested change (remove these lines):

- self.python_api = paddle.distribution.Dirichlet(self.alpha).sample(
-     self.sample_shape
- )

First revert the changes to test/distribution/test_dirichlet_op.py, and note in the PR description that this unit test has not been adapted.


self.inputs = {
'Alpha': np.broadcast_to(self.alpha, self.sample_shape).astype(
@@ -119,7 +128,7 @@

def test_check_output(self):
self.check_output_with_place_customized(
-     self._hypothesis_testing, place=core.CUDAPlace(0)
+     self._hypothesis_testing, place=core.CUDAPlace(0), check_pir=True
)

def _hypothesis_testing(self, outs):
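The _hypothesis_testing checker shown truncated above validates the sampler statistically rather than elementwise. As a hedged sketch of the idea (not the test file's actual code): for a two-component Dirichlet(a1, a2), the first component is marginally Beta(a1, a2), so a one-sample Kolmogorov-Smirnov test against that Beta CDF is enough to sanity-check the samples:

```python
import numpy as np
import scipy.stats

def hypothesis_check_sketch(samples, alpha=(1.0, 2.0), significance=0.01):
    # samples: array of shape (n, 2) drawn from Dirichlet(alpha).
    # Marginally, samples[:, 0] ~ Beta(alpha[0], alpha[1]); KS-test it.
    _, p_value = scipy.stats.kstest(
        samples[:, 0], scipy.stats.beta(alpha[0], alpha[1]).cdf
    )
    return p_value > significance  # True -> no evidence the sampler is wrong
```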
test/distribution/test_distribution_dirichlet_static.py (6 changes: 6 additions & 0 deletions)

@@ -20,6 +20,7 @@
from parameterize import TEST_CASE_NAME, parameterize_cls, place

import paddle
+ from paddle.pir_utils import test_with_pir_api

np.random.seed(2022)
paddle.enable_static()
@@ -41,6 +42,7 @@ def setUp(self):
self._paddle_diric = paddle.distribution.Dirichlet(conc)
self.feeds = {'conc': self.concentration}

+ @test_with_pir_api
def test_mean(self):
with paddle.static.program_guard(self.program):
[out] = self.executor.run(
@@ -55,6 +57,7 @@ def test_mean(self):
atol=ATOL.get(str(self.concentration.dtype)),
)

+ @test_with_pir_api
Contributor:

Same as above; this needs to be changed.

def test_variance(self):
with paddle.static.program_guard(self.program):
[out] = self.executor.run(
@@ -69,6 +72,7 @@ def test_variance(self):
atol=ATOL.get(str(self.concentration.dtype)),
)

+ @test_with_pir_api
Contributor:

Same as above; this needs to be changed.

def test_prob(self):
with paddle.static.program_guard(self.program):
random_number = np.random.rand(*self.concentration.shape)
@@ -88,6 +92,7 @@ def test_prob(self):
atol=ATOL.get(str(self.concentration.dtype)),
)

+ @test_with_pir_api
Contributor:

Same as above; this needs to be changed.

def test_log_prob(self):
with paddle.static.program_guard(self.program):
random_number = np.random.rand(*self.concentration.shape)
@@ -107,6 +112,7 @@ def test_log_prob(self):
atol=ATOL.get(str(self.concentration.dtype)),
)

+ @test_with_pir_api
Contributor:

Same as above; this needs to be changed.

def test_entropy(self):
with paddle.static.program_guard(self.program):
[out] = self.executor.run(
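The @test_with_pir_api decorator added throughout this file runs a static-graph test under both the legacy IR and the new PIR program representation. As a rough mental model only (a hypothetical sketch, not the actual implementation in paddle.pir_utils):

```python
import functools
import paddle

def test_with_pir_api_sketch(func):
    # Hypothetical sketch: execute the wrapped test twice, once as a legacy
    # static-graph program and once with PIR switched on.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        func(*args, **kwargs)              # legacy IR pass
        with paddle.pir_utils.IrGuard():   # assumes IrGuard toggles PIR mode
            func(*args, **kwargs)          # PIR pass
    return wrapper
```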
test/legacy_test/test_activation_op.py (7 changes: 4 additions & 3 deletions)

Contributor:

This file also contains the TestELU unit test, and the bf16 and fp16 TestELU tests created by create_test_act_bf16_class and create_test_act_fp16_class have been missed. Please adapt them as well.
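For readers unfamiliar with these helpers: create_test_act_fp16_class and create_test_act_bf16_class are factories in test_activation_op.py that synthesize a low-precision variant of an existing activation test class. A hypothetical sketch of the pattern (names and details simplified, not the file's exact code):

```python
import numpy as np

def create_test_act_fp16_class_sketch(parent, check_pir=False):
    # Derive a subclass that reruns the parent activation test in float16,
    # forwarding flags such as check_pir to the output checker.
    class TestActFp16(parent):
        def init_dtype(self):
            self.dtype = np.float16

        def test_check_output(self):
            self.check_output(atol=1e-3, check_pir=check_pir)

    TestActFp16.__name__ = parent.__name__ + "Fp16"
    globals()[TestActFp16.__name__] = TestActFp16
```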

@@ -3170,7 +3170,7 @@ def init_shape(self):
def test_check_grad(self):
if self.dtype == np.float16:
return
- self.check_grad(['X'], 'Out')
+ self.check_grad(['X'], 'Out', check_pir=True)

def get_alpha(self):
return 1.0
@@ -3201,6 +3201,7 @@ def setUp(self):
def executed_api(self):
self.elu = F.elu

+ @test_with_pir_api
def test_static_api(self):
with static_guard():
with paddle.static.program_guard(paddle.static.Program()):
@@ -4876,7 +4877,7 @@ def test_check_grad(self):
create_test_act_fp16_class(TestBRelu, check_pir=True)
create_test_act_fp16_class(TestRelu6)
create_test_act_fp16_class(TestSoftRelu, check_dygraph=False)
- create_test_act_fp16_class(TestELU)
+ create_test_act_fp16_class(TestELU, check_pir=True)
create_test_act_fp16_class(TestCELU)
create_test_act_fp16_class(TestReciprocal, check_pir=True)
create_test_act_fp16_class(TestLog, check_prim=True, check_pir=True)
@@ -5031,7 +5032,7 @@ def test_check_grad(self):
create_test_act_bf16_class(TestBRelu, check_pir=True)
create_test_act_bf16_class(TestRelu6)
create_test_act_bf16_class(TestSoftRelu, check_dygraph=False)
- create_test_act_bf16_class(TestELU)
+ create_test_act_bf16_class(TestELU, check_pir=True)
create_test_act_bf16_class(TestCELU)
create_test_act_bf16_class(TestReciprocal, check_pir=True)
create_test_act_bf16_class(TestLog, check_prim=True, check_pir=True)
test/legacy_test/test_eigvals_op.py (2 changes: 2 additions & 0 deletions)

@@ -19,6 +19,7 @@

import paddle
from paddle.base import core
+ from paddle.pir_utils import test_with_pir_api

np.set_printoptions(threshold=np.inf)

@@ -271,6 +272,7 @@ def run_dygraph(self, place):
np_outs = np_eigvals(self.batch_input)
self.verify_output(paddle_outs, np_outs)

+ @test_with_pir_api
def run_static(self, place):
paddle.enable_static()
with paddle.static.program_guard(
test/legacy_test/test_flash_attention.py (34 changes: 26 additions & 8 deletions)

@@ -27,6 +27,7 @@
flash_attn_unpadded,
scaled_dot_product_attention,
)
+ from paddle.pir_utils import test_with_pir_api


def get_cuda_version():
@@ -187,15 +188,10 @@ def test_unpadded(self):
fetch_list=[outs],
)

- np.testing.assert_allclose(
-     fetches_result[0], out_, rtol=5e-03, atol=1e-03
- )

- def test_all(self):
+ def test_dynamic_all(self):
print(
f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}"
f"Test dynamic case shape {self.shape} dtype {self.dtype} causal {self.causal}"
)
- # test dynamic
paddle.disable_static()

query = np.random.random(self.shape)
@@ -255,9 +251,31 @@ def test_all(self):
q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
)

- # test static
+ @test_with_pir_api
+ def test_static_all(self):
Contributor:

This needs to be decorated with @test_with_pir_api.

+ print(
+     f"Test static case shape {self.shape} dtype {self.dtype} causal {self.causal}"
+ )
+ paddle.enable_static()

+ query = np.random.random(self.shape)
+ key = np.random.random(self.shape)
+ value = np.random.random(self.shape)

+ q_ = paddle.to_tensor(
+     query, place=self.place, dtype=self.dtype, stop_gradient=False
+ )
+ k_ = paddle.to_tensor(
+     key, place=self.place, dtype=self.dtype, stop_gradient=False
+ )
+ v_ = paddle.to_tensor(
+     value, place=self.place, dtype=self.dtype, stop_gradient=False
+ )

+ out_ = attention_naive(q_, k_, v_, self.causal)

+ out_.backward()

with paddle.static.program_guard(paddle.static.Program()):
qs = paddle.static.data(
name="q", shape=self.shape, dtype=self.dtype
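attention_naive, used above as the reference, presumably computes plain scaled dot-product attention with dense matmuls. A hedged sketch under that assumption (layout assumed to be [batch, seq_len, num_heads, head_dim], matching flash_attention's documented layout):

```python
import paddle

def attention_naive_sketch(q, k, v, causal=False):
    # Hypothetical reference implementation used to validate flash_attn:
    # softmax(Q K^T / sqrt(d)) V, computed with dense matmuls.
    qt = paddle.transpose(q, [0, 2, 1, 3])  # -> [b, h, s, d]
    kt = paddle.transpose(k, [0, 2, 1, 3])
    vt = paddle.transpose(v, [0, 2, 1, 3])
    scale = 1.0 / (q.shape[-1] ** 0.5)
    scores = paddle.matmul(qt, kt, transpose_y=True) * scale  # [b, h, s, s]
    if causal:
        seq = scores.shape[-1]
        # Mask out future positions with -inf above the diagonal.
        mask = paddle.triu(
            paddle.full([seq, seq], float('-inf'), dtype=scores.dtype),
            diagonal=1,
        )
        scores = scores + mask
    probs = paddle.nn.functional.softmax(scores, axis=-1)
    out = paddle.matmul(probs, vt)           # [b, h, s, d]
    return paddle.transpose(out, [0, 2, 1, 3])  # back to [b, s, h, d]
```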