Skip to content

Commit 8e14302

Browse files
committed
add multinomial python api
1 parent c01c4e1 commit 8e14302

File tree

4 files changed

+64
-0
lines changed

4 files changed

+64
-0
lines changed

python/paddle/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@
201201
from .tensor.math import isinf #DEFINE_ALIAS
202202
from .tensor.math import isnan #DEFINE_ALIAS
203203
from .tensor.math import prod #DEFINE_ALIAS
204+
from .tensor.random import multinomial #DEFINE_ALIAS
204205
from .tensor.random import standard_normal
205206
from .tensor.random import normal
206207
from .tensor.random import uniform #DEFINE_ALIAS

python/paddle/fluid/tests/unittests/test_multinomial_op.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,5 +102,25 @@ def init_data(self):
102102
self.attrs = {"num_samples": 10, "replacement": False}
103103
"""
104104

105+
106+
class TestMultinomialApi(unittest.TestCase):
107+
def test_dygraph(self):
108+
paddle.disable_static()
109+
x = paddle.rand([4])
110+
out = paddle.multinomial(x, num_samples=100000, replacement=True)
111+
x_numpy = x.numpy()
112+
paddle.enable_static()
113+
114+
sample_prob = np.unique(
115+
out.numpy(), return_counts=True)[1].astype("float32")
116+
sample_prob /= sample_prob.sum()
117+
118+
prob = x_numpy / x_numpy.sum(axis=-1, keepdims=True)
119+
self.assertTrue(
120+
np.allclose(
121+
sample_prob, prob, rtol=0, atol=0.01),
122+
"sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))
123+
124+
105125
if __name__ == "__main__":
106126
unittest.main()

python/paddle/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@
166166
from .math import isinf #DEFINE_ALIAS
167167
from .math import isnan #DEFINE_ALIAS
168168
from .math import prod #DEFINE_ALIAS
169+
from .random import multinomial #DEFINE_ALIAS
169170
from .random import standard_normal
170171
from .random import normal
171172
from .random import uniform #DEFINE_ALIAS

python/paddle/tensor/random.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
__all__ = [
2525
'bernoulli',
26+
'multinomial',
2627
'standard_normal',
2728
'normal',
2829
'uniform',
@@ -85,6 +86,47 @@ def bernoulli(x, name=None):
8586
return out
8687

8788

89+
def multinomial(x, num_samples=1, replacement=False, name=None):
90+
"""
91+
92+
93+
Examples:
94+
.. code-block:: python
95+
96+
import paddle
97+
98+
paddle.disable_static()
99+
100+
x = paddle.rand([2, 3])
101+
print(x.numpy())
102+
# [[0.11272584 0.3890902 0.7730957 ]
103+
# [0.10351662 0.8510418 0.63806665]]
104+
105+
out = paddle.bernoulli(x)
106+
print(out.numpy())
107+
# [[0. 0. 1.]
108+
# [0. 0. 1.]]
109+
110+
"""
111+
112+
if in_dygraph_mode():
113+
return core.ops.multinomial(x, 'num_samples', num_samples,
114+
'replacement', replacement)
115+
116+
check_variable_and_dtype(x, "x", ["float32", "float64"], "multinomial")
117+
118+
helper = LayerHelper("multinomial", **locals())
119+
out = helper.create_variable_for_type_inference(
120+
dtype=convert_np_dtype_to_dtype_('int64'))
121+
helper.append_op(
122+
type='multinomial',
123+
inputs={"X": x},
124+
outputs={'Out': out},
125+
attrs={'num_samples': num_samples,
126+
'replacement': replacement})
127+
return out
128+
129+
88130
def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None):
89131
"""
90132
This OP returns a Tensor filled with random values sampled from a Gaussian

0 commit comments

Comments
 (0)