@@ -161,6 +161,12 @@ def __init__(self,
161161 weight_attr = None ,
162162 bias_attr = None ):
163163 super (MultiHeadAttention , self ).__init__ ()
164+
165+ assert embed_dim > 0 , ("Expected embed_dim to be greater than 0, "
166+ "but received {}" .format (embed_dim ))
167+ assert num_heads > 0 , ("Expected num_heads to be greater than 0, "
168+ "but received {}" .format (num_heads ))
169+
164170 self .embed_dim = embed_dim
165171 self .kdim = kdim if kdim is not None else embed_dim
166172 self .vdim = vdim if vdim is not None else embed_dim
@@ -501,6 +507,15 @@ def __init__(self,
501507 self ._config .pop ("__class__" , None ) # py3
502508
503509 super (TransformerEncoderLayer , self ).__init__ ()
510+
511+ assert d_model > 0 , ("Expected d_model to be greater than 0, "
512+ "but received {}" .format (d_model ))
513+ assert nhead > 0 , ("Expected nhead to be greater than 0, "
514+ "but received {}" .format (nhead ))
515+ assert dim_feedforward > 0 , (
516+ "Expected dim_feedforward to be greater than 0, "
517+ "but received {}" .format (dim_feedforward ))
518+
504519 attn_dropout = dropout if attn_dropout is None else attn_dropout
505520 act_dropout = dropout if act_dropout is None else act_dropout
506521 self .normalize_before = normalize_before
@@ -797,6 +812,15 @@ def __init__(self,
797812 self ._config .pop ("__class__" , None ) # py3
798813
799814 super (TransformerDecoderLayer , self ).__init__ ()
815+
816+ assert d_model > 0 , ("Expected d_model to be greater than 0, "
817+ "but received {}" .format (d_model ))
818+ assert nhead > 0 , ("Expected nhead to be greater than 0, "
819+ "but received {}" .format (nhead ))
820+ assert dim_feedforward > 0 , (
821+ "Expected dim_feedforward to be greater than 0, "
822+ "but received {}" .format (dim_feedforward ))
823+
800824 attn_dropout = dropout if attn_dropout is None else attn_dropout
801825 act_dropout = dropout if act_dropout is None else act_dropout
802826 self .normalize_before = normalize_before
@@ -1196,6 +1220,14 @@ def __init__(self,
11961220 custom_decoder = None ):
11971221 super (Transformer , self ).__init__ ()
11981222
1223+ assert d_model > 0 , ("Expected d_model to be greater than 0, "
1224+ "but received {}" .format (d_model ))
1225+ assert nhead > 0 , ("Expected nhead to be greater than 0, "
1226+ "but received {}" .format (nhead ))
1227+ assert dim_feedforward > 0 , (
1228+ "Expected dim_feedforward to be greater than 0, "
1229+ "but received {}" .format (dim_feedforward ))
1230+
11991231 if isinstance (bias_attr , (list , tuple )):
12001232 if len (bias_attr ) == 1 :
12011233 encoder_bias_attr = [bias_attr [0 ]] * 2
0 commit comments