diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py
index 2dfeaddc9a..d239334cf8 100644
--- a/torchtune/modules/attention.py
+++ b/torchtune/modules/attention.py
@@ -69,9 +69,8 @@ class MultiHeadAttention(nn.Module):
         max_seq_len (int): maximum sequence length supported by the model.
             This is needed to compute the RoPE Cache. Default: 4096.
         is_causal (bool): sets the default mask to causal when no mask is provided
-        attn_dropout (float): dropout value passed onto the
-            scaled_dot_product_attention function. This argument is ignored if the
-            self.training is False. Default value is 0.0.
+        attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function.
+            Default value is 0.0.
 
     Raises:
         ValueError: If ``num_heads % num_kv_heads != 0``
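
For context on the docstring being changed: below is a minimal sketch (not torchtune's actual MultiHeadAttention, and TinyAttention is a hypothetical name) of the common pattern the docstring describes, where an attn_dropout value is forwarded to torch.nn.functional.scaled_dot_product_attention and gated on self.training so that eval-mode runs stay deterministic.

    import torch
    import torch.nn.functional as F

    class TinyAttention(torch.nn.Module):
        """Illustrative sketch only; assumes the training-gated dropout pattern."""

        def __init__(self, attn_dropout: float = 0.0, is_causal: bool = True):
            super().__init__()
            self.attn_dropout = attn_dropout
            self.is_causal = is_causal

        def forward(self, q, k, v):
            # scaled_dot_product_attention applies dropout at the given
            # probability whenever dropout_p > 0, so callers typically gate it
            # on self.training rather than relying on the function itself.
            return F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_dropout if self.training else 0.0,
                is_causal=self.is_causal,
            )

    # Usage: shapes are (batch, num_heads, seq_len, head_dim).
    attn = TinyAttention(attn_dropout=0.1)
    q = k = v = torch.randn(2, 8, 16, 64)
    out_train = attn(q, k, v)  # training mode: dropout is active
    attn.eval()
    out_eval = attn(q, k, v)   # eval mode: dropout_p resolves to 0.0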