-
Notifications
You must be signed in to change notification settings - Fork 5.9k
【PIR API adaptor No.58, 62, 64, 70】Migrate some ops into pir #59230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
111150c
e8cbd34
2b54a58
473af57
bf4f8e2
d4919d4
658fe16
ea2667e
534c928
737a127
8fca8f9
33253f2
9711cc9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -33,13 +33,16 @@ def setUp(self): | |||||||
| self.op_type = "dirichlet" | ||||||||
| self.alpha = np.array((1.0, 2.0)) | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. python_api 需要是一个 callable 对象，其调用的结果是返回结果。比方说这里的
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这个单测的适配先跳过吧，可以在 pr 描述里补充说明一下~ 我暂时还没有想到好的办法去添加这个 self.python_api |
||||||||
| self.sample_shape = (100000, 2) | ||||||||
| self.python_api = paddle.distribution.Dirichlet(self.alpha).sample( | ||||||||
| self.sample_shape | ||||||||
| ) | ||||||||
|
|
||||||||
| self.inputs = {'Alpha': np.broadcast_to(self.alpha, self.sample_shape)} | ||||||||
| self.attrs = {} | ||||||||
| self.outputs = {'Out': np.zeros(self.sample_shape)} | ||||||||
|
|
||||||||
| def test_check_output(self): | ||||||||
| self.check_output_customized(self._hypothesis_testing) | ||||||||
| self.check_output_customized(self._hypothesis_testing, check_pir=True) | ||||||||
|
|
||||||||
| def _hypothesis_testing(self, outs): | ||||||||
| self.assertEqual(outs[0].shape, self.sample_shape) | ||||||||
|
|
@@ -63,6 +66,9 @@ def setUp(self): | |||||||
| self.alpha = np.array((1.0, 2.0)) | ||||||||
| self.sample_shape = (100000, 2) | ||||||||
| self.dtype = np.float16 | ||||||||
| self.python_api = paddle.distribution.Dirichlet(self.alpha).sample( | ||||||||
| self.sample_shape | ||||||||
| ) | ||||||||
|
|
||||||||
| self.inputs = { | ||||||||
| 'Alpha': np.broadcast_to(self.alpha, self.sample_shape).astype( | ||||||||
|
|
@@ -73,7 +79,7 @@ def setUp(self): | |||||||
| self.outputs = {'Out': np.zeros(self.sample_shape).astype(self.dtype)} | ||||||||
|
|
||||||||
| def test_check_output(self): | ||||||||
| self.check_output_customized(self._hypothesis_testing) | ||||||||
| self.check_output_customized(self._hypothesis_testing, check_pir=True) | ||||||||
|
|
||||||||
| def _hypothesis_testing(self, outs): | ||||||||
| self.assertEqual(outs[0].shape, self.sample_shape) | ||||||||
|
|
@@ -103,6 +109,9 @@ def setUp(self): | |||||||
| self.sample_shape = (10000, 2) | ||||||||
| self.dtype = np.uint16 | ||||||||
| self.np_dtype = np.float32 | ||||||||
| self.python_api = paddle.distribution.Dirichlet(self.alpha).sample( | ||||||||
| self.sample_shape | ||||||||
| ) | ||||||||
|
||||||||
| self.python_api = paddle.distribution.Dirichlet(self.alpha).sample( | |
| self.sample_shape | |
| ) |
先复原对 test/distribution/test_dirichlet_op.py 文件的改动,并在 pr 描述里说明一下未适配该单测~
This comment was marked as resolved.
Sorry, something went wrong. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ | |
| from parameterize import TEST_CASE_NAME, parameterize_cls, place | ||
|
|
||
| import paddle | ||
| from paddle.pir_utils import test_with_pir_api | ||
|
|
||
| np.random.seed(2022) | ||
| paddle.enable_static() | ||
|
|
@@ -41,6 +42,7 @@ def setUp(self): | |
| self._paddle_diric = paddle.distribution.Dirichlet(conc) | ||
| self.feeds = {'conc': self.concentration} | ||
|
|
||
| @test_with_pir_api | ||
| def test_mean(self): | ||
| with paddle.static.program_guard(self.program): | ||
| [out] = self.executor.run( | ||
|
|
@@ -55,6 +57,7 @@ def test_mean(self): | |
| atol=ATOL.get(str(self.concentration.dtype)), | ||
| ) | ||
|
|
||
| @test_with_pir_api | ||
|
||
| def test_variance(self): | ||
| with paddle.static.program_guard(self.program): | ||
| [out] = self.executor.run( | ||
|
|
@@ -69,6 +72,7 @@ def test_variance(self): | |
| atol=ATOL.get(str(self.concentration.dtype)), | ||
| ) | ||
|
|
||
| @test_with_pir_api | ||
|
||
| def test_prob(self): | ||
| with paddle.static.program_guard(self.program): | ||
| random_number = np.random.rand(*self.concentration.shape) | ||
|
|
@@ -88,6 +92,7 @@ def test_prob(self): | |
| atol=ATOL.get(str(self.concentration.dtype)), | ||
| ) | ||
|
|
||
| @test_with_pir_api | ||
|
||
| def test_log_prob(self): | ||
| with paddle.static.program_guard(self.program): | ||
| random_number = np.random.rand(*self.concentration.shape) | ||
|
|
@@ -107,6 +112,7 @@ def test_log_prob(self): | |
| atol=ATOL.get(str(self.concentration.dtype)), | ||
| ) | ||
|
|
||
| @test_with_pir_api | ||
|
||
| def test_entropy(self): | ||
| with paddle.static.program_guard(self.program): | ||
| [out] = self.executor.run( | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 该文件下还有 TestELU 单测，以及被 create_test_act_bf16_class 和 create_test_act_fp16_class 创建的 bf16 和 fp16 的 TestELU 单测遗漏了~ 麻烦一起适配一下吧 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,7 @@ | |
| flash_attn_unpadded, | ||
| scaled_dot_product_attention, | ||
| ) | ||
| from paddle.pir_utils import test_with_pir_api | ||
|
|
||
|
|
||
| def get_cuda_version(): | ||
|
|
@@ -187,15 +188,10 @@ def test_unpadded(self): | |
| fetch_list=[outs], | ||
| ) | ||
|
|
||
| np.testing.assert_allclose( | ||
| fetches_result[0], out_, rtol=5e-03, atol=1e-03 | ||
| ) | ||
|
|
||
| def test_all(self): | ||
| def test_dynamic_all(self): | ||
| print( | ||
| f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}" | ||
| f"Test dynamic case shape {self.shape} dtype {self.dtype} causal {self.causal}" | ||
| ) | ||
| # test dynamic | ||
| paddle.disable_static() | ||
|
|
||
| query = np.random.random(self.shape) | ||
|
|
@@ -255,9 +251,31 @@ def test_all(self): | |
| q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03 | ||
| ) | ||
|
|
||
| # test static | ||
| @test_with_pir_api | ||
| def test_static_all(self): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 需要使用 |
||
| print( | ||
| f"Test static case shape {self.shape} dtype {self.dtype} causal {self.causal}" | ||
| ) | ||
| paddle.enable_static() | ||
|
|
||
| query = np.random.random(self.shape) | ||
| key = np.random.random(self.shape) | ||
| value = np.random.random(self.shape) | ||
|
|
||
| q_ = paddle.to_tensor( | ||
| query, place=self.place, dtype=self.dtype, stop_gradient=False | ||
| ) | ||
| k_ = paddle.to_tensor( | ||
| key, place=self.place, dtype=self.dtype, stop_gradient=False | ||
| ) | ||
| v_ = paddle.to_tensor( | ||
| value, place=self.place, dtype=self.dtype, stop_gradient=False | ||
| ) | ||
|
|
||
| out_ = attention_naive(q_, k_, v_, self.causal) | ||
|
|
||
| out_.backward() | ||
|
|
||
| with paddle.static.program_guard(paddle.static.Program()): | ||
| qs = paddle.static.data( | ||
| name="q", shape=self.shape, dtype=self.dtype | ||
|
|
||
This comment was marked as off-topic.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.
This comment was marked as resolved.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.