Skip to content

Commit 521dd7e

Browse files
authored
Add error message when parameter is set to 0 (#33859)
1 parent a0a9079 commit 521dd7e

1 file changed

Lines changed: 32 additions & 0 deletions

File tree

python/paddle/nn/layer/transformer.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,12 @@ def __init__(self,
161161
weight_attr=None,
162162
bias_attr=None):
163163
super(MultiHeadAttention, self).__init__()
164+
165+
assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
166+
"but recieved {}".format(embed_dim))
167+
assert num_heads > 0, ("Expected num_heads to be greater than 0, "
168+
"but recieved {}".format(num_heads))
169+
164170
self.embed_dim = embed_dim
165171
self.kdim = kdim if kdim is not None else embed_dim
166172
self.vdim = vdim if vdim is not None else embed_dim
@@ -501,6 +507,15 @@ def __init__(self,
501507
self._config.pop("__class__", None) # py3
502508

503509
super(TransformerEncoderLayer, self).__init__()
510+
511+
assert d_model > 0, ("Expected d_model to be greater than 0, "
512+
"but recieved {}".format(d_model))
513+
assert nhead > 0, ("Expected nhead to be greater than 0, "
514+
"but recieved {}".format(nhead))
515+
assert dim_feedforward > 0, (
516+
"Expected dim_feedforward to be greater than 0, "
517+
"but recieved {}".format(dim_feedforward))
518+
504519
attn_dropout = dropout if attn_dropout is None else attn_dropout
505520
act_dropout = dropout if act_dropout is None else act_dropout
506521
self.normalize_before = normalize_before
@@ -797,6 +812,15 @@ def __init__(self,
797812
self._config.pop("__class__", None) # py3
798813

799814
super(TransformerDecoderLayer, self).__init__()
815+
816+
assert d_model > 0, ("Expected d_model to be greater than 0, "
817+
"but recieved {}".format(d_model))
818+
assert nhead > 0, ("Expected nhead to be greater than 0, "
819+
"but recieved {}".format(nhead))
820+
assert dim_feedforward > 0, (
821+
"Expected dim_feedforward to be greater than 0, "
822+
"but recieved {}".format(dim_feedforward))
823+
800824
attn_dropout = dropout if attn_dropout is None else attn_dropout
801825
act_dropout = dropout if act_dropout is None else act_dropout
802826
self.normalize_before = normalize_before
@@ -1196,6 +1220,14 @@ def __init__(self,
11961220
custom_decoder=None):
11971221
super(Transformer, self).__init__()
11981222

1223+
assert d_model > 0, ("Expected d_model to be greater than 0, "
1224+
"but recieved {}".format(d_model))
1225+
assert nhead > 0, ("Expected nhead to be greater than 0, "
1226+
"but recieved {}".format(nhead))
1227+
assert dim_feedforward > 0, (
1228+
"Expected dim_feedforward to be greater than 0, "
1229+
"but recieved {}".format(dim_feedforward))
1230+
11991231
if isinstance(bias_attr, (list, tuple)):
12001232
if len(bias_attr) == 1:
12011233
encoder_bias_attr = [bias_attr[0]] * 2

0 commit comments

Comments
 (0)