Skip to content

Commit 975bd88

Browse files
authored
Fix error message of multinomial op (#27946)
* fix multinomial doc * fix multinomial error message * little doc change * fix Categorical class doc * optimize format of error message * fix CPU Kernel error message format * fix isinf and isnan error in WindowsOPENBLAS CI * delete inf and nan * add manual_seed in sample code * little error message change * change error message to InvalidArgument * add full point for error message and add manual_seed in CPU environment
1 parent b6eff44 commit 975bd88

File tree

6 files changed

+213
-138
lines changed

6 files changed

+213
-138
lines changed

paddle/fluid/operators/multinomial_op.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,27 @@ class MultinomialOp : public framework::OperatorWithKernel {
5353

5454
auto x_dim = ctx->GetInputDim("X");
5555
int64_t x_rank = x_dim.size();
56+
PADDLE_ENFORCE_GT(x_rank, 0,
57+
platform::errors::InvalidArgument(
58+
"The number of dimensions of the input probability "
59+
"distribution should be > 0, but got %d.",
60+
x_rank));
61+
PADDLE_ENFORCE_LE(x_rank, 2,
62+
platform::errors::InvalidArgument(
63+
"The number of dimensions of the input probability "
64+
"distribution should be <= 2, but got %d.",
65+
x_rank));
66+
5667
std::vector<int64_t> out_dims(x_rank);
5768
for (int64_t i = 0; i < x_rank - 1; i++) {
5869
out_dims[i] = x_dim[i];
5970
}
6071

6172
int64_t num_samples = ctx->Attrs().Get<int>("num_samples");
73+
PADDLE_ENFORCE_GT(
74+
num_samples, 0,
75+
platform::errors::InvalidArgument(
76+
"The number of samples should be > 0, but got %d.", num_samples));
6277
out_dims[x_rank - 1] = num_samples;
6378

6479
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));

paddle/fluid/operators/multinomial_op.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License. */
2121
#include "paddle/fluid/framework/op_registry.h"
2222
#include "paddle/fluid/framework/operator.h"
2323
#include "paddle/fluid/operators/multinomial_op.h"
24+
#include "paddle/fluid/platform/enforce.h"
2425
#include "paddle/fluid/platform/transform.h"
2526

2627
namespace paddle {
@@ -31,6 +32,14 @@ __global__ void NormalizeProbability(T* norm_probs, const T* in_data,
3132
T* sum_rows) {
3233
int id = threadIdx.x + blockIdx.x * blockDim.x +
3334
blockIdx.y * gridDim.x * blockDim.x;
35+
PADDLE_ENFORCE(
36+
in_data[id] >= 0.0,
37+
"The input of multinomial distribution should be >= 0, but got %f.",
38+
in_data[id]);
39+
PADDLE_ENFORCE(sum_rows[blockIdx.y] > 0.0,
40+
"The sum of one multinomial distribution probability should "
41+
"be > 0, but got %f.",
42+
sum_rows[blockIdx.y]);
3443
norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
3544
}
3645

paddle/fluid/operators/multinomial_op.h

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,28 +44,29 @@ void MultinomialFunctor(int64_t* out_data, const T* in_data,
4444
int64_t num_zeros = 0;
4545
for (int64_t j = 0; j < num_categories; j++) {
4646
prob_value = in_data[i * num_categories + j];
47-
PADDLE_ENFORCE_GE(
48-
prob_value, 0.0,
49-
platform::errors::OutOfRange(
50-
"The input of multinomial distribution should be >= 0"));
51-
PADDLE_ENFORCE_EQ((std::isinf(static_cast<double>(prob_value)) ||
52-
std::isnan(static_cast<double>(prob_value))),
53-
false, platform::errors::OutOfRange(
54-
"The input of multinomial distribution "
55-
"shoud not be infinity or NaN"));
47+
PADDLE_ENFORCE_GE(prob_value, 0.0,
48+
platform::errors::InvalidArgument(
49+
"The input of multinomial distribution "
50+
"should be >= 0, but got %f.",
51+
prob_value));
52+
5653
probs_sum += prob_value;
5754
if (prob_value == 0) {
5855
num_zeros += 1;
5956
}
6057
cumulative_probs[j] = probs_sum;
6158
}
62-
PADDLE_ENFORCE_GT(probs_sum, 0.0, platform::errors::OutOfRange(
63-
"The sum of input should not be 0"));
59+
PADDLE_ENFORCE_GT(probs_sum, 0.0,
60+
platform::errors::InvalidArgument(
61+
"The sum of one multinomial distribution "
62+
"probability should be > 0, but got %f.",
63+
probs_sum));
6464
PADDLE_ENFORCE_EQ(
6565
(replacement || (num_categories - num_zeros >= num_samples)), true,
66-
platform::errors::OutOfRange("When replacement is False, number of "
67-
"samples should be less than non-zero "
68-
"categories"));
66+
platform::errors::InvalidArgument(
67+
"When replacement is False, number of "
68+
"samples should be less than non-zero "
69+
"categories."));
6970

7071
for (int64_t j = 0; j < num_categories; j++) {
7172
cumulative_probs[j] /= probs_sum;

python/paddle/distribution.py

Lines changed: 105 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -662,48 +662,54 @@ class Categorical(Distribution):
662662
663663
Args:
664664
logits(list|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64.
665+
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
665666
666667
Examples:
667668
.. code-block:: python
668669
669-
import paddle
670-
from paddle.distribution import Categorical
670+
import paddle
671+
from paddle.distribution import Categorical
671672
672-
x = paddle.rand([6])
673-
print(x.numpy())
674-
# [0.32564053, 0.99334985, 0.99034804,
675-
# 0.09053693, 0.30820143, 0.19095989]
676-
y = paddle.rand([6])
677-
print(y.numpy())
678-
# [0.6365463 , 0.7278677 , 0.90260243,
679-
# 0.5226815 , 0.35837543, 0.13981032]
673+
paddle.manual_seed(100) # on CPU device
674+
x = paddle.rand([6])
675+
print(x.numpy())
676+
# [0.5535528 0.20714243 0.01162981
677+
# 0.51577556 0.36369765 0.2609165 ]
680678
681-
cat = Categorical(x)
682-
cat2 = Categorical(y)
679+
paddle.manual_seed(200) # on CPU device
680+
y = paddle.rand([6])
681+
print(y.numpy())
682+
# [0.77663314 0.90824795 0.15685187
683+
# 0.04279523 0.34468332 0.7955718 ]
683684
684-
cat.sample([2,3])
685-
# [[5, 1, 1],
686-
# [0, 1, 2]]
685+
cat = Categorical(x)
686+
cat2 = Categorical(y)
687687
688-
cat.entropy()
689-
# [1.71887]
688+
paddle.manual_seed(1000) # on CPU device
689+
cat.sample([2,3])
690+
# [[0, 0, 5],
691+
# [3, 4, 5]]
690692
691-
cat.kl_divergence(cat2)
692-
# [0.0278455]
693+
cat.entropy()
694+
# [1.77528]
693695
694-
value = paddle.to_tensor([2,1,3])
695-
cat.probs(value)
696-
# [0.341613 0.342648 0.03123]
696+
cat.kl_divergence(cat2)
697+
# [0.071952]
697698
698-
cat.log_prob(value)
699-
# [-1.07408 -1.07105 -3.46638]
699+
value = paddle.to_tensor([2,1,3])
700+
cat.probs(value)
701+
# [0.00608027 0.108298 0.269656]
702+
703+
cat.log_prob(value)
704+
# [-5.10271 -2.22287 -1.31061]
700705
701706
"""
702707

703708
def __init__(self, logits, name=None):
704709
"""
705710
Args:
706-
logits(list|numpy.ndarray|Variable): The logits input of categorical distribution. The data type is float32 or float64.
711+
logits(list|numpy.ndarray|Tensor): The logits input of categorical distribution. The data type is float32 or float64.
712+
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
707713
"""
708714
if not in_dygraph_mode():
709715
check_type(logits, 'logits', (np.ndarray, tensor.Variable, list),
@@ -727,27 +733,29 @@ def sample(self, shape):
727733
"""Generate samples of the specified shape.
728734
729735
Args:
730-
shape (list): Shape of the generated samples.
736+
shape (list): Shape of the generated samples.
731737
732738
Returns:
733-
Tensor: A tensor with prepended dimensions shape.
739+
Tensor: A tensor with prepended dimensions shape.
734740
735741
Examples:
736-
.. code-block:: python
742+
.. code-block:: python
737743
738-
import paddle
739-
from paddle.distribution import Categorical
744+
import paddle
745+
from paddle.distribution import Categorical
740746
741-
x = paddle.rand([6])
742-
print(x.numpy())
743-
# [0.32564053, 0.99334985, 0.99034804,
744-
# 0.09053693, 0.30820143, 0.19095989]
747+
paddle.manual_seed(100) # on CPU device
748+
x = paddle.rand([6])
749+
print(x.numpy())
750+
# [0.5535528 0.20714243 0.01162981
751+
# 0.51577556 0.36369765 0.2609165 ]
745752
746-
cat = Categorical(x)
753+
cat = Categorical(x)
747754
748-
cat.sample([2,3])
749-
# [[5, 1, 1],
750-
# [0, 1, 2]]
755+
paddle.manual_seed(1000) # on CPU device
756+
cat.sample([2,3])
757+
# [[0, 0, 5],
758+
# [3, 4, 5]]
751759
752760
"""
753761
name = self.name + '_sample'
@@ -775,28 +783,31 @@ def kl_divergence(self, other):
775783
other (Categorical): instance of Categorical. The data type is float32.
776784
777785
Returns:
778-
Variable: kl-divergence between two Categorical distributions.
786+
Tensor: kl-divergence between two Categorical distributions.
779787
780788
Examples:
781-
.. code-block:: python
789+
.. code-block:: python
782790
783-
import paddle
784-
from paddle.distribution import Categorical
791+
import paddle
792+
from paddle.distribution import Categorical
793+
794+
paddle.manual_seed(100) # on CPU device
795+
x = paddle.rand([6])
796+
print(x.numpy())
797+
# [0.5535528 0.20714243 0.01162981
798+
# 0.51577556 0.36369765 0.2609165 ]
785799
786-
x = paddle.rand([6])
787-
print(x.numpy())
788-
# [0.32564053, 0.99334985, 0.99034804,
789-
# 0.09053693, 0.30820143, 0.19095989]
790-
y = paddle.rand([6])
791-
print(y.numpy())
792-
# [0.6365463 , 0.7278677 , 0.90260243,
793-
# 0.5226815 , 0.35837543, 0.13981032]
800+
paddle.manual_seed(200) # on CPU device
801+
y = paddle.rand([6])
802+
print(y.numpy())
803+
# [0.77663314 0.90824795 0.15685187
804+
# 0.04279523 0.34468332 0.7955718 ]
794805
795-
cat = Categorical(x)
796-
cat2 = Categorical(y)
806+
cat = Categorical(x)
807+
cat2 = Categorical(y)
797808
798-
cat.kl_divergence(cat2)
799-
# [0.0278455]
809+
cat.kl_divergence(cat2)
810+
# [0.071952]
800811
801812
"""
802813
name = self.name + '_kl_divergence'
@@ -823,23 +834,24 @@ def entropy(self):
823834
"""Shannon entropy in nats.
824835
825836
Returns:
826-
Variable: Shannon entropy of Categorical distribution. The data type is float32.
837+
Tensor: Shannon entropy of Categorical distribution. The data type is float32.
827838
828839
Examples:
829-
.. code-block:: python
840+
.. code-block:: python
830841
831-
import paddle
832-
from paddle.distribution import Categorical
842+
import paddle
843+
from paddle.distribution import Categorical
833844
834-
x = paddle.rand([6])
835-
print(x.numpy())
836-
# [0.32564053, 0.99334985, 0.99034804,
837-
# 0.09053693, 0.30820143, 0.19095989]
845+
paddle.manual_seed(100) # on CPU device
846+
x = paddle.rand([6])
847+
print(x.numpy())
848+
# [0.5535528 0.20714243 0.01162981
849+
# 0.51577556 0.36369765 0.2609165 ]
838850
839-
cat = Categorical(x)
851+
cat = Categorical(x)
840852
841-
cat.entropy()
842-
# [1.71887]
853+
cat.entropy()
854+
# [1.77528]
843855
844856
"""
845857
name = self.name + '_entropy'
@@ -864,27 +876,28 @@ def probs(self, value):
864876
with ``logits. That is, ``value[:-1] = logits[:-1]``.
865877
866878
Args:
867-
value (Tensor): The input tensor represents the selected category index.
879+
value (Tensor): The input tensor represents the selected category index.
868880
869881
Returns:
870-
Tensor: probability according to the category index.
882+
Tensor: probability according to the category index.
871883
872884
Examples:
873-
.. code-block:: python
885+
.. code-block:: python
874886
875-
import paddle
876-
from paddle.distribution import Categorical
887+
import paddle
888+
from paddle.distribution import Categorical
877889
878-
x = paddle.rand([6])
879-
print(x.numpy())
880-
# [0.32564053, 0.99334985, 0.99034804,
881-
# 0.09053693, 0.30820143, 0.19095989]
890+
paddle.manual_seed(100) # on CPU device
891+
x = paddle.rand([6])
892+
print(x.numpy())
893+
# [0.5535528 0.20714243 0.01162981
894+
# 0.51577556 0.36369765 0.2609165 ]
882895
883-
cat = Categorical(x)
896+
cat = Categorical(x)
884897
885-
value = paddle.to_tensor([2,1,3])
886-
cat.probs(value)
887-
# [0.341613 0.342648 0.03123]
898+
value = paddle.to_tensor([2,1,3])
899+
cat.probs(value)
900+
# [0.00608027 0.108298 0.269656]
888901
889902
"""
890903
name = self.name + '_probs'
@@ -929,28 +942,28 @@ def log_prob(self, value):
929942
"""Log probabilities of the given category. Refer to ``probs`` method.
930943
931944
Args:
932-
value (Tensor): The input tensor represents the selected category index.
945+
value (Tensor): The input tensor represents the selected category index.
933946
934947
Returns:
935-
Tensor: Log probability.
948+
Tensor: Log probability.
936949
937950
Examples:
938-
.. code-block:: python
939-
940-
import paddle
941-
from paddle.distribution import Categorical
951+
.. code-block:: python
942952
943-
x = paddle.rand([6])
944-
print(x.numpy())
945-
# [0.32564053, 0.99334985, 0.99034804,
946-
# 0.09053693, 0.30820143, 0.19095989]
953+
import paddle
954+
from paddle.distribution import Categorical
947955
948-
cat = Categorical(x)
956+
paddle.manual_seed(100) # on CPU device
957+
x = paddle.rand([6])
958+
print(x.numpy())
959+
# [0.5535528 0.20714243 0.01162981
960+
# 0.51577556 0.36369765 0.2609165 ]
949961
950-
value = paddle.to_tensor([2,1,3])
962+
cat = Categorical(x)
951963
952-
cat.log_prob(value)
953-
# [-1.07408 -1.07105 -3.46638]
964+
value = paddle.to_tensor([2,1,3])
965+
cat.log_prob(value)
966+
# [-5.10271 -2.22287 -1.31061]
954967
955968
"""
956969
name = self.name + '_log_prob'

0 commit comments

Comments
 (0)