Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions paddle/fluid/operators/multinomial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,18 @@ class MultinomialOp : public framework::OperatorWithKernel {

auto x_dim = ctx->GetInputDim("X");
int64_t x_rank = x_dim.size();
PADDLE_ENFORCE_EQ(
x_rank > 0 && x_rank <= 2, true,
Copy link
Contributor

@zhiqiu zhiqiu Oct 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use PADDLE_ENFORCE_GT and PADDLE_ENFORCE_LE instead, do not combine two checks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

platform::errors::PreconditionNotMet(
"Input probability distribution should be 1 or 2 dimension"));
std::vector<int64_t> out_dims(x_rank);
for (int64_t i = 0; i < x_rank - 1; i++) {
out_dims[i] = x_dim[i];
}

int64_t num_samples = ctx->Attrs().Get<int>("num_samples");
PADDLE_ENFORCE_GT(num_samples, 0, platform::errors::OutOfRange(
"Number of samples should be > 0"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Number of samples should be > 0"));
"The number of samples should be > 0, but got %d.", num_samples ));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

out_dims[x_rank - 1] = num_samples;

ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/operators/multinomial_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/multinomial_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
Expand All @@ -31,6 +32,14 @@ __global__ void NormalizeProbability(T* norm_probs, const T* in_data,
T* sum_rows) {
int id = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
PADDLE_ENFORCE(in_data[id] >= 0.0,
"The input of multinomial distribution should be >= 0");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same above, print the actual data.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

PADDLE_ENFORCE(
!std::isinf(static_cast<double>(in_data[id])) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please do not combine several logical expressions in one ENFORCE.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

!std::isnan(static_cast<double>(in_data[id])),
"The input of multinomial distribution shoud not be infinity or NaN");
PADDLE_ENFORCE(sum_rows[blockIdx.y] > 0.0,
"The sum of input should not be 0");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean >0 here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's > 0 — and here ">0" means the same as "not 0", because values < 0 have already been forbidden by the earlier check.
I have changed the description from "not be 0" to ">0".

norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
}

Expand Down
36 changes: 36 additions & 0 deletions python/paddle/fluid/tests/unittests/test_multinomial_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import unittest
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from op_test import OpTest
import numpy as np


class TestMultinomialOp(OpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "multinomial"
self.init_data()
self.inputs = {"X": self.input_np}
Expand Down Expand Up @@ -175,5 +177,39 @@ def test_alias(self):
paddle.tensor.random.multinomial(x, num_samples=10, replacement=True)


class TestMultinomialError(unittest.TestCase):
    """Check that invalid arguments to paddle.multinomial raise EnforceNotMet."""

    def setUp(self):
        # These error paths are exercised in dynamic-graph (imperative) mode.
        paddle.disable_static()

    def test_num_sample(self):
        # num_samples must be a positive integer.
        with self.assertRaises(core.EnforceNotMet):
            prob = paddle.rand([4])
            paddle.multinomial(prob, num_samples=-2)

    def test_replacement_False(self):
        # Without replacement, num_samples cannot exceed the number of
        # (non-zero) categories in the distribution.
        with self.assertRaises(core.EnforceNotMet):
            prob = paddle.rand([4])
            paddle.multinomial(prob, num_samples=5, replacement=False)

    def test_input_probs_dim(self):
        # The probability tensor must be 1-D or 2-D; 3-D is rejected.
        with self.assertRaises(core.EnforceNotMet):
            prob = paddle.rand([2, 3, 3])
            paddle.multinomial(prob)

        # A 0-D (scalar) input is rejected as well.
        with self.assertRaises(core.EnforceNotMet):
            scalar_np = np.random.random([])
            scalar = paddle.to_tensor(scalar_np)
            paddle.multinomial(scalar)


# Allow the test module to be run directly from the command line.
if __name__ == "__main__":
    unittest.main()
54 changes: 25 additions & 29 deletions python/paddle/tensor/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,17 @@ def bernoulli(x, name=None):
Examples:
.. code-block:: python

import paddle

paddle.disable_static()
import paddle

x = paddle.rand([2, 3])
print(x.numpy())
# [[0.11272584 0.3890902 0.7730957 ]
# [0.10351662 0.8510418 0.63806665]]
x = paddle.rand([2, 3])
print(x.numpy())
# [[0.11272584 0.3890902 0.7730957 ]
# [0.10351662 0.8510418 0.63806665]]

out = paddle.bernoulli(x)
print(out.numpy())
# [[0. 0. 1.]
# [0. 0. 1.]]
out = paddle.bernoulli(x)
print(out.numpy())
# [[0. 0. 1.]
# [0. 0. 1.]]

"""

Expand Down Expand Up @@ -108,28 +106,26 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
Examples:
.. code-block:: python

import paddle

paddle.disable_static()
import paddle

x = paddle.rand([2,4])
print(x.numpy())
# [[0.7713825 0.4055941 0.433339 0.70706886]
# [0.9223313 0.8519825 0.04574518 0.16560672]]
x = paddle.rand([2,4])
print(x.numpy())
# [[0.7713825 0.4055941 0.433339 0.70706886]
# [0.9223313 0.8519825 0.04574518 0.16560672]]

out1 = paddle.multinomial(x, num_samples=5, replacement=True)
print(out1.numpy())
# [[3 3 1 1 0]
# [0 0 0 0 1]]
out1 = paddle.multinomial(x, num_samples=5, replacement=True)
print(out1.numpy())
# [[3, 3, 1, 1, 0]
# [0, 0, 0, 0, 1]]

# out2 = paddle.multinomial(x, num_samples=5)
# OutOfRangeError: When replacement is False, number of samples
# should be less than non-zero categories
# out2 = paddle.multinomial(x, num_samples=5)
# OutOfRangeError: When replacement is False, number of samples
# should be less than non-zero categories

out3 = paddle.multinomial(x, num_samples=3)
print(out3.numpy())
# [[0 2 3]
# [0 1 3]]
out3 = paddle.multinomial(x, num_samples=3)
print(out3.numpy())
# [[0, 2, 3]
# [0, 1, 3]]

"""

Expand Down