Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions paddle/fluid/operators/multinomial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,25 @@ class MultinomialOp : public framework::OperatorWithKernel {

auto x_dim = ctx->GetInputDim("X");
int64_t x_rank = x_dim.size();
PADDLE_ENFORCE_GT(x_rank, 0, platform::errors::PreconditionNotMet(
"Input probability distribution should be "
"1 or 2 dimension, but got %d",
x_rank));
PADDLE_ENFORCE_LE(x_rank, 2, platform::errors::PreconditionNotMet(
"Input probability distribution should be "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Input probability distribution should be "
"The number of dimensions of the input probability distribution should be <= 2, but got %d."

Similar for the others.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"1 or 2 dimension, but got %d",
x_rank));

std::vector<int64_t> out_dims(x_rank);
for (int64_t i = 0; i < x_rank - 1; i++) {
out_dims[i] = x_dim[i];
}

int64_t num_samples = ctx->Attrs().Get<int>("num_samples");
PADDLE_ENFORCE_GT(
num_samples, 0,
platform::errors::OutOfRange(
"The number of samples should be > 0, but got %d", num_samples));
out_dims[x_rank - 1] = num_samples;

ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/operators/multinomial_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/multinomial_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
Expand All @@ -31,6 +32,16 @@ __global__ void NormalizeProbability(T* norm_probs, const T* in_data,
T* sum_rows) {
int id = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
PADDLE_ENFORCE(
in_data[id] >= 0.0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议报错信息统一加句点,PR里有的加了,有的没加

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"The input of multinomial distribution should be >= 0, but got %f",
in_data[id]);
PADDLE_ENFORCE(in_data[id] != INFINITY,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any special reason that checking INF/NaN is added here? Otherwise, I think it is not really necessary. Because the property that a number is not NAN or INF should be satisfied almost everywhere, and if we check it everywhere, it may slow down the system.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have removed checking INF/NaN

"The input of multinomial distribution shoud not be infinity");
PADDLE_ENFORCE(in_data[id] != NAN,
"The input of multinomial distribution shoud not be NaN");
PADDLE_ENFORCE(sum_rows[blockIdx.y] > 0.0,
"The sum of input should not be 0");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean >0 here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It's > 0, and >0 has the same meaning with not be 0 here. Because <0 has been forbidden before.
I have change the description from not be 0 to >0.

norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
}

Expand Down
17 changes: 11 additions & 6 deletions paddle/fluid/operators/multinomial_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@ void MultinomialFunctor(int64_t* out_data, const T* in_data,
prob_value = in_data[i * num_categories + j];
PADDLE_ENFORCE_GE(
prob_value, 0.0,
platform::errors::OutOfRange("The input of multinomial distribution "
"should be >= 0, but got %f",
prob_value));
PADDLE_ENFORCE_EQ(
std::isinf(static_cast<double>(prob_value)), false,
platform::errors::OutOfRange(
"The input of multinomial distribution should be >= 0"));
PADDLE_ENFORCE_EQ((std::isinf(static_cast<double>(prob_value)) ||
std::isnan(static_cast<double>(prob_value))),
false, platform::errors::OutOfRange(
"The input of multinomial distribution "
"shoud not be infinity or NaN"));
"The input of multinomial distribution shoud not be infinity"));
PADDLE_ENFORCE_EQ(
std::isnan(static_cast<double>(prob_value)), false,
platform::errors::OutOfRange(
"The input of multinomial distribution shoud not be NaN"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


probs_sum += prob_value;
if (prob_value == 0) {
num_zeros += 1;
Expand Down
184 changes: 93 additions & 91 deletions python/paddle/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,48 +662,50 @@ class Categorical(Distribution):

Args:
logits(list|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Examples:
.. code-block:: python

import paddle
from paddle.distribution import Categorical
import paddle
from paddle.distribution import Categorical

x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
y = paddle.rand([6])
print(y.numpy())
# [0.6365463 , 0.7278677 , 0.90260243,
# 0.5226815 , 0.35837543, 0.13981032]
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
y = paddle.rand([6])
print(y.numpy())
# [0.6365463 , 0.7278677 , 0.90260243,
# 0.5226815 , 0.35837543, 0.13981032]

cat = Categorical(x)
cat2 = Categorical(y)
cat = Categorical(x)
cat2 = Categorical(y)

cat.sample([2,3])
# [[5, 1, 1],
# [0, 1, 2]]
cat.sample([2,3])
# [[5, 1, 1],
# [0, 1, 2]]

cat.entropy()
# [1.71887]
cat.entropy()
# [1.71887]

cat.kl_divergence(cat2)
# [0.0278455]
cat.kl_divergence(cat2)
# [0.0278455]

value = paddle.to_tensor([2,1,3])
cat.probs(value)
# [0.341613 0.342648 0.03123]
value = paddle.to_tensor([2,1,3])
cat.probs(value)
# [0.341613 0.342648 0.03123]

cat.log_prob(value)
# [-1.07408 -1.07105 -3.46638]
cat.log_prob(value)
# [-1.07408 -1.07105 -3.46638]

"""

def __init__(self, logits, name=None):
"""
Args:
logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32 or float64.
logits(list|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
"""
if not in_dygraph_mode():
check_type(logits, 'logits', (np.ndarray, tensor.Variable, list),
Expand All @@ -727,27 +729,27 @@ def sample(self, shape):
"""Generate samples of the specified shape.

Args:
shape (list): Shape of the generated samples.
shape (list): Shape of the generated samples.

Returns:
Tensor: A tensor with prepended dimensions shape.
Tensor: A tensor with prepended dimensions shape.

Examples:
.. code-block:: python
.. code-block:: python

import paddle
from paddle.distribution import Categorical
import paddle
from paddle.distribution import Categorical

x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]

cat = Categorical(x)
cat = Categorical(x)

cat.sample([2,3])
# [[5, 1, 1],
# [0, 1, 2]]
cat.sample([2,3])
# [[5, 1, 1],
# [0, 1, 2]]

"""
name = self.name + '_sample'
Expand Down Expand Up @@ -775,28 +777,28 @@ def kl_divergence(self, other):
other (Categorical): instance of Categorical. The data type is float32.

Returns:
Variable: kl-divergence between two Categorical distributions.
Tensor: kl-divergence between two Categorical distributions.

Examples:
.. code-block:: python
.. code-block:: python

import paddle
from paddle.distribution import Categorical
import paddle
from paddle.distribution import Categorical

x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
y = paddle.rand([6])
print(y.numpy())
# [0.6365463 , 0.7278677 , 0.90260243,
# 0.5226815 , 0.35837543, 0.13981032]
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
y = paddle.rand([6])
print(y.numpy())
# [0.6365463 , 0.7278677 , 0.90260243,
# 0.5226815 , 0.35837543, 0.13981032]

cat = Categorical(x)
cat2 = Categorical(y)
cat = Categorical(x)
cat2 = Categorical(y)

cat.kl_divergence(cat2)
# [0.0278455]
cat.kl_divergence(cat2)
# [0.0278455]

"""
name = self.name + '_kl_divergence'
Expand All @@ -823,23 +825,23 @@ def entropy(self):
"""Shannon entropy in nats.

Returns:
Variable: Shannon entropy of Categorical distribution. The data type is float32.
Tensor: Shannon entropy of Categorical distribution. The data type is float32.

Examples:
.. code-block:: python
.. code-block:: python

import paddle
from paddle.distribution import Categorical
import paddle
from paddle.distribution import Categorical

x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]

cat = Categorical(x)
cat = Categorical(x)

cat.entropy()
# [1.71887]
cat.entropy()
# [1.71887]

"""
name = self.name + '_entropy'
Expand All @@ -864,27 +866,27 @@ def probs(self, value):
with ``logits. That is, ``value[:-1] = logits[:-1]``.

Args:
value (Tensor): The input tensor represents the selected category index.
value (Tensor): The input tensor represents the selected category index.

Returns:
Tensor: probability according to the category index.
Tensor: probability according to the category index.

Examples:
.. code-block:: python
.. code-block:: python

import paddle
from paddle.distribution import Categorical
import paddle
from paddle.distribution import Categorical

x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]

cat = Categorical(x)
cat = Categorical(x)

value = paddle.to_tensor([2,1,3])
cat.probs(value)
# [0.341613 0.342648 0.03123]
value = paddle.to_tensor([2,1,3])
cat.probs(value)
# [0.341613 0.342648 0.03123]

"""
name = self.name + '_probs'
Expand Down Expand Up @@ -929,28 +931,28 @@ def log_prob(self, value):
"""Log probabilities of the given category. Refer to ``probs`` method.

Args:
value (Tensor): The input tensor represents the selected category index.
value (Tensor): The input tensor represents the selected category index.

Returns:
Tensor: Log probability.
Tensor: Log probability.

Examples:
.. code-block:: python
.. code-block:: python

import paddle
from paddle.distribution import Categorical
import paddle
from paddle.distribution import Categorical

x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
# 0.09053693, 0.30820143, 0.19095989]
x = paddle.rand([6])
print(x.numpy())
# [0.32564053, 0.99334985, 0.99034804,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better add paddle.manual_seed(xx) here, otherwise, users cannot get the same random output as your sample code.
Same for all the other examples.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

# 0.09053693, 0.30820143, 0.19095989]

cat = Categorical(x)
cat = Categorical(x)

value = paddle.to_tensor([2,1,3])
value = paddle.to_tensor([2,1,3])

cat.log_prob(value)
# [-1.07408 -1.07105 -3.46638]
cat.log_prob(value)
# [-1.07408 -1.07105 -3.46638]

"""
name = self.name + '_log_prob'
Expand Down
Loading