Skip to content

Commit 8fca8f9

Browse files
committed
fix bug
1 parent 737a127 commit 8fca8f9

File tree

2 files changed

+79
-86
lines changed

2 files changed

+79
-86
lines changed

test/distribution/test_dirichlet_op.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,6 @@ def setUp(self):
103103
self.sample_shape = (10000, 2)
104104
self.dtype = np.uint16
105105
self.np_dtype = np.float32
106-
self.python_api = paddle.distribution.Dirichlet(self.alpha).sample(
107-
self.sample_shape
108-
)
109106

110107
self.inputs = {
111108
'Alpha': np.broadcast_to(self.alpha, self.sample_shape).astype(

test/distribution/test_distribution_dirichlet_static.py

Lines changed: 79 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import scipy.stats
1919
from distribution.config import ATOL, DEVICES, RTOL
2020
from parameterize import TEST_CASE_NAME, parameterize_cls, place
21-
from paddle.pir_utils import test_with_pir_api
2221

2322
import paddle
2423

@@ -140,94 +139,91 @@ def setUp(self):
140139
self._paddle_diric = paddle.distribution.Dirichlet(conc)
141140
self.feeds = {'conc': self.concentration}
142141

143-
@test_with_pir_api
144142
def test_mean(self):
145-
with paddle.static.program_guard(self.program):
146-
[out] = self.executor.run(
147-
self.program,
148-
feed=self.feeds,
149-
fetch_list=[self._paddle_diric.mean],
150-
)
151-
np.testing.assert_allclose(
152-
out,
153-
scipy.stats.dirichlet.mean(self.concentration),
154-
rtol=RTOL.get(str(self.concentration.dtype)),
155-
atol=ATOL.get(str(self.concentration.dtype)),
156-
)
143+
with paddle.pir_utils.IrGuard():
144+
with paddle.static.program_guard(self.program):
145+
[out] = self.executor.run(
146+
self.program,
147+
feed=self.feeds,
148+
fetch_list=[self._paddle_diric.mean],
149+
)
150+
np.testing.assert_allclose(
151+
out,
152+
scipy.stats.dirichlet.mean(self.concentration),
153+
rtol=RTOL.get(str(self.concentration.dtype)),
154+
atol=ATOL.get(str(self.concentration.dtype)),
155+
)
157156

158-
@test_with_pir_api
159157
def test_variance(self):
160-
with paddle.static.program_guard(self.program):
161-
[out] = self.executor.run(
162-
self.program,
163-
feed=self.feeds,
164-
fetch_list=[self._paddle_diric.variance],
165-
)
166-
np.testing.assert_allclose(
167-
out,
168-
scipy.stats.dirichlet.var(self.concentration),
169-
rtol=RTOL.get(str(self.concentration.dtype)),
170-
atol=ATOL.get(str(self.concentration.dtype)),
171-
)
158+
with paddle.pir_utils.IrGuard():
159+
with paddle.static.program_guard(self.program):
160+
[out] = self.executor.run(
161+
self.program,
162+
feed=self.feeds,
163+
fetch_list=[self._paddle_diric.variance],
164+
)
165+
np.testing.assert_allclose(
166+
out,
167+
scipy.stats.dirichlet.var(self.concentration),
168+
rtol=RTOL.get(str(self.concentration.dtype)),
169+
atol=ATOL.get(str(self.concentration.dtype)),
170+
)
172171

173-
@test_with_pir_api
174172
def test_prob(self):
175-
with paddle.static.program_guard(self.program):
176-
random_number = np.random.rand(*self.concentration.shape)
177-
random_number = random_number / random_number.sum()
178-
feeds = dict(self.feeds, value=random_number)
179-
value = paddle.static.data(
180-
'value', random_number.shape, random_number.dtype
181-
)
182-
out = self._paddle_diric.prob(value)
183-
[out] = self.executor.run(
184-
self.program, feed=feeds, fetch_list=[out]
185-
)
186-
np.testing.assert_allclose(
187-
out,
188-
scipy.stats.dirichlet.pdf(random_number, self.concentration),
189-
rtol=RTOL.get(str(self.concentration.dtype)),
190-
atol=ATOL.get(str(self.concentration.dtype)),
191-
)
173+
with paddle.pir_utils.IrGuard():
174+
with paddle.static.program_guard(self.program):
175+
random_number = np.random.rand(*self.concentration.shape)
176+
random_number = random_number / random_number.sum()
177+
feeds = dict(self.feeds, value=random_number)
178+
value = paddle.static.data(
179+
'value', random_number.shape, random_number.dtype
180+
)
181+
out = self._paddle_diric.prob(value)
182+
[out] = self.executor.run(
183+
self.program, feed=feeds, fetch_list=[out]
184+
)
185+
np.testing.assert_allclose(
186+
out,
187+
scipy.stats.dirichlet.pdf(
188+
random_number, self.concentration
189+
),
190+
rtol=RTOL.get(str(self.concentration.dtype)),
191+
atol=ATOL.get(str(self.concentration.dtype)),
192+
)
192193

193-
@test_with_pir_api
194194
def test_log_prob(self):
195-
with paddle.static.program_guard(self.program):
196-
random_number = np.random.rand(*self.concentration.shape)
197-
random_number = random_number / random_number.sum()
198-
feeds = dict(self.feeds, value=random_number)
199-
value = paddle.static.data(
200-
'value', random_number.shape, random_number.dtype
201-
)
202-
out = self._paddle_diric.log_prob(value)
203-
[out] = self.executor.run(
204-
self.program, feed=feeds, fetch_list=[out]
205-
)
206-
np.testing.assert_allclose(
207-
out,
208-
scipy.stats.dirichlet.logpdf(random_number, self.concentration),
209-
rtol=RTOL.get(str(self.concentration.dtype)),
210-
atol=ATOL.get(str(self.concentration.dtype)),
211-
)
195+
with paddle.pir_utils.IrGuard():
196+
with paddle.static.program_guard(self.program):
197+
random_number = np.random.rand(*self.concentration.shape)
198+
random_number = random_number / random_number.sum()
199+
feeds = dict(self.feeds, value=random_number)
200+
value = paddle.static.data(
201+
'value', random_number.shape, random_number.dtype
202+
)
203+
out = self._paddle_diric.log_prob(value)
204+
[out] = self.executor.run(
205+
self.program, feed=feeds, fetch_list=[out]
206+
)
207+
np.testing.assert_allclose(
208+
out,
209+
scipy.stats.dirichlet.logpdf(
210+
random_number, self.concentration
211+
),
212+
rtol=RTOL.get(str(self.concentration.dtype)),
213+
atol=ATOL.get(str(self.concentration.dtype)),
214+
)
212215

213-
@test_with_pir_api
214216
def test_entropy(self):
215-
with paddle.static.program_guard(self.program):
216-
[out] = self.executor.run(
217-
self.program,
218-
feed=self.feeds,
219-
fetch_list=[self._paddle_diric.entropy()],
220-
)
221-
np.testing.assert_allclose(
222-
out,
223-
scipy.stats.dirichlet.entropy(self.concentration),
224-
rtol=RTOL.get(str(self.concentration.dtype)),
225-
atol=ATOL.get(str(self.concentration.dtype)),
226-
)
227-
228-
def test_all(self):
229-
self._test_mean(self.program)
230-
self._test_variance(self.program)
231-
self._test_prob(self.program)
232-
self._test_log_prob(self.program)
233-
self._test_entropy(self.program)
217+
with paddle.pir_utils.IrGuard():
218+
with paddle.static.program_guard(self.program):
219+
[out] = self.executor.run(
220+
self.program,
221+
feed=self.feeds,
222+
fetch_list=[self._paddle_diric.entropy()],
223+
)
224+
np.testing.assert_allclose(
225+
out,
226+
scipy.stats.dirichlet.entropy(self.concentration),
227+
rtol=RTOL.get(str(self.concentration.dtype)),
228+
atol=ATOL.get(str(self.concentration.dtype)),
229+
)

0 commit comments

Comments
 (0)