torchtune/modules/attention.py (5 changes: 2 additions, 3 deletions)
@@ -69,9 +69,8 @@ class MultiHeadAttention(nn.Module):
         max_seq_len (int): maximum sequence length supported by the model.
             This is needed to compute the RoPE Cache. Default: 4096.
         is_causal (bool): sets the default mask to causal when no mask is provided
-        attn_dropout (float): dropout value passed onto the
-            scaled_dot_product_attention function. This argument is ignored if the
-            self.training is False. Default value is 0.0.
+        attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function.
+            Default value is 0.0.
 
     Raises:
         ValueError: If ``num_heads % num_kv_heads != 0``
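For context, the `attn_dropout` value documented above is forwarded to PyTorch's `torch.nn.functional.scaled_dot_product_attention` as its `dropout_p` argument, and the sentence this diff removes described the conventional pattern of gating that dropout on `self.training`. Below is a minimal, self-contained sketch of that pattern. It is illustrative only, not torchtune's actual implementation: the module name `SketchAttention` and its single fused QKV projection are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchAttention(nn.Module):
    """Hypothetical attention module showing the attn_dropout convention."""

    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.attn_dropout = attn_dropout
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)  # fused q/k/v projection
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (b, s, d) -> (b, num_heads, s, head_dim)
        q, k, v = (
            t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
            for t in (q, k, v)
        )
        # Dropout is conventionally gated on self.training so that eval runs
        # are deterministic; this is the behavior the removed sentence described.
        y = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=True,
        )
        # (b, num_heads, s, head_dim) -> (b, s, d)
        y = y.transpose(1, 2).reshape(b, s, -1)
        return self.out(y)


attn = SketchAttention(embed_dim=64, num_heads=4, attn_dropout=0.1)
x = torch.randn(2, 16, 64)
attn.train()  # dropout_p resolves to 0.1
_ = attn(x)
attn.eval()   # dropout_p resolves to 0.0
_ = attn(x)
```

With a module written this way, calling `.eval()` makes `dropout_p` resolve to 0.0; the updated docstring simply leaves that behavioral detail out rather than restating it per argument.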