-
Notifications
You must be signed in to change notification settings - Fork 5.9k
【PIR API adaptor No.58, 62, 64, 70】Migrate some ops into pir #59230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
111150c
e8cbd34
2b54a58
473af57
bf4f8e2
d4919d4
658fe16
ea2667e
534c928
737a127
8fca8f9
33253f2
9711cc9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -103,6 +103,9 @@ def setUp(self): | |||||||
| self.sample_shape = (10000, 2) | ||||||||
| self.dtype = np.uint16 | ||||||||
| self.np_dtype = np.float32 | ||||||||
| self.python_api = paddle.distribution.Dirichlet(self.alpha).sample( | ||||||||
| self.sample_shape | ||||||||
| ) | ||||||||
|
||||||||
| self.python_api = paddle.distribution.Dirichlet(self.alpha).sample( | |
| self.sample_shape | |
| ) |
先复原对 test/distribution/test_dirichlet_op.py 文件的改动,并在 pr 描述里说明一下未适配该单测~
This comment was marked as resolved.
Sorry, something went wrong. |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import scipy.stats | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from distribution.config import ATOL, DEVICES, RTOL | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from parameterize import TEST_CASE_NAME, parameterize_cls, place | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from paddle.pir_utils import test_with_pir_api | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import paddle | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -120,3 +121,113 @@ def test_entropy(self): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rtol=RTOL.get(str(self.concentration.dtype)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| atol=ATOL.get(str(self.concentration.dtype)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @place(DEVICES) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @parameterize_cls( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (TEST_CASE_NAME, 'concentration'), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [('test-one-dim', np.random.rand(89) + 5.0)], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class TestDirichletPir(unittest.TestCase): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def setUp(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with paddle.pir_utils.IrGuard(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.program = paddle.static.Program() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.executor = paddle.static.Executor() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with paddle.static.program_guard(self.program): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| conc = paddle.static.data( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 'conc', self.concentration.shape, self.concentration.dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._paddle_diric = paddle.distribution.Dirichlet(conc) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.feeds = {'conc': self.concentration} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @test_with_pir_api | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_mean(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with paddle.static.program_guard(self.program): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [out] = self.executor.run( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.program, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| feed=self.feeds, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fetch_list=[self._paddle_diric.mean], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.testing.assert_allclose( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scipy.stats.dirichlet.mean(self.concentration), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rtol=RTOL.get(str(self.concentration.dtype)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| atol=ATOL.get(str(self.concentration.dtype)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @test_with_pir_api | |
| def test_mean(self): | |
| with paddle.static.program_guard(self.program): | |
| [out] = self.executor.run( | |
| self.program, | |
| feed=self.feeds, | |
| fetch_list=[self._paddle_diric.mean], | |
| ) | |
| np.testing.assert_allclose( | |
| out, | |
| scipy.stats.dirichlet.mean(self.concentration), | |
| rtol=RTOL.get(str(self.concentration.dtype)), | |
| atol=ATOL.get(str(self.concentration.dtype)), | |
| ) | |
| def test_mean(self): | |
| with paddle.pir_utils.IrGuard(): | |
| with paddle.static.program_guard(self.program): | |
| [out] = self.executor.run( | |
| self.program, | |
| feed=self.feeds, | |
| fetch_list=[self._paddle_diric.mean], | |
| ) | |
| np.testing.assert_allclose( | |
| out, | |
| scipy.stats.dirichlet.mean(self.concentration), | |
| rtol=RTOL.get(str(self.concentration.dtype)), | |
| atol=ATOL.get(str(self.concentration.dtype)), | |
| ) |
此处不宜使用 @test_with_pir_api,被 @test_with_pir_api 修饰的单测会在旧 ir 模式和 pir 模式下分别运行一次。因为该单测为 pir 单测,所以我们只需要在 Pir 模式下运行就好了。使用 paddle.pir_utils.IrGuard() 作为上下文管理器可以切换到 pir 模式进行组网
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,需要修改
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,需要修改
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,需要修改
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上,需要修改
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| def test_all(self): | |
| self._test_mean(self.program) | |
| self._test_variance(self.program) | |
| self._test_prob(self.program) | |
| self._test_log_prob(self.program) | |
| self._test_entropy(self.program) |
该单测可以删除了,因为当前不存在 _test_mean 等函数
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 该文件下还有 TestELU 单测,以及被 create_test_act_bf16_class 和 create_test_act_fp16_class 创建的 bf16 和 fp16 的 TestELU 单测遗漏了~ 麻烦一起适配一下吧 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,6 +28,7 @@ | |
| flash_attn_unpadded, | ||
| scaled_dot_product_attention, | ||
| ) | ||
| from paddle.pir_utils import test_with_pir_api | ||
|
|
||
| logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s") | ||
|
|
||
|
|
@@ -198,15 +199,10 @@ def test_unpadded(self): | |
| fetch_list=[outs], | ||
| ) | ||
|
|
||
| np.testing.assert_allclose( | ||
| fetches_result[0], out_, rtol=5e-03, atol=1e-03 | ||
| ) | ||
|
|
||
| def test_all(self): | ||
| def test_dynamic_all(self): | ||
| print( | ||
| f"Test case shape {self.shape} dtype {self.dtype} causal {self.causal}" | ||
| f"Test dynamic case shape {self.shape} dtype {self.dtype} causal {self.causal}" | ||
| ) | ||
| # test dynamic | ||
| paddle.disable_static() | ||
|
|
||
| query = np.random.random(self.shape) | ||
|
|
@@ -266,9 +262,31 @@ def test_all(self): | |
| q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03 | ||
| ) | ||
|
|
||
| # test static | ||
| @test_with_pir_api | ||
| def test_static_all(self): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 需要使用 |
||
| print( | ||
| f"Test static case shape {self.shape} dtype {self.dtype} causal {self.causal}" | ||
| ) | ||
| paddle.enable_static() | ||
|
|
||
| query = np.random.random(self.shape) | ||
| key = np.random.random(self.shape) | ||
| value = np.random.random(self.shape) | ||
|
|
||
| q_ = paddle.to_tensor( | ||
| query, place=self.place, dtype=self.dtype, stop_gradient=False | ||
| ) | ||
| k_ = paddle.to_tensor( | ||
| key, place=self.place, dtype=self.dtype, stop_gradient=False | ||
| ) | ||
| v_ = paddle.to_tensor( | ||
| value, place=self.place, dtype=self.dtype, stop_gradient=False | ||
| ) | ||
|
|
||
| out_ = attention_naive(q_, k_, v_, self.causal) | ||
|
|
||
| out_.backward() | ||
|
|
||
| with paddle.static.program_guard(paddle.static.Program()): | ||
| qs = paddle.static.data( | ||
| name="q", shape=self.shape, dtype=self.dtype | ||
|
|
||
This comment was marked as off-topic.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.
This comment was marked as resolved.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.