7 changes: 7 additions & 0 deletions examples/text-generation/run_generation.py
@@ -460,6 +460,13 @@ def generate(size=None, reduce_recompile=False):
max_length=args.max_input_tokens,
truncation=True,
)

def compute_valid_sequence_lengths_tensor(input_tokens):
attn_mask = input_tokens["attention_mask"]
return torch.sum(attn_mask, dim=1)

valid_sequence_lengths = compute_valid_sequence_lengths_tensor(input_tokens).to(args.device)
generation_config.valid_sequence_lengths = valid_sequence_lengths
else:
input_tokens = tokenizer.batch_encode_plus(input_sentences, return_tensors="pt", padding=True)
encode_duration = time.perf_counter() - encode_t0
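For intuition: the helper added above just counts the non-padding tokens in each batch row. batch_encode_plus with padding="max_length" marks real tokens with 1 and padding with 0 in the attention mask, so a row-wise sum recovers every sequence's true length. A minimal self-contained sketch (tensor values are illustrative, not from this PR):

import torch

# Two left-padded rows, as batch_encode_plus(padding="max_length") would produce.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],  # true length 3
    [1, 1, 1, 1, 1],  # true length 5
])

# Equivalent to compute_valid_sequence_lengths_tensor above.
valid_sequence_lengths = torch.sum(attention_mask, dim=1)
print(valid_sequence_lengths)  # tensor([3, 5])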
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
@@ -590,6 +590,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
generation_config.trust_remote_code = args.trust_remote_code
generation_config.valid_sequence_lengths = None

return generation_config

1 change: 1 addition & 0 deletions optimum/habana/transformers/generation/configuration_utils.py
@@ -55,3 +55,4 @@ def __init__(self, **kwargs):
self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None)
self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None)
self.use_fused_rope = kwargs.get("use_fused_rope", None)
self.valid_sequence_lengths = kwargs.get("valid_sequence_lengths", None)
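Together with the default added in utils.py above, this keeps the field safe to read unconditionally: kwargs.get returns None unless a caller supplies a tensor. A quick sketch of that behavior, assuming this hunk belongs to GaudiGenerationConfig (the import path and constructor usage are assumptions, not shown in this diff):

import torch
from optimum.habana.transformers.generation import GaudiGenerationConfig

config = GaudiGenerationConfig()
assert config.valid_sequence_lengths is None  # default via kwargs.get

config = GaudiGenerationConfig(valid_sequence_lengths=torch.tensor([3, 5]))
assert torch.equal(config.valid_sequence_lengths, torch.tensor([3, 5]))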
2 changes: 2 additions & 0 deletions optimum/habana/transformers/generation/utils.py
@@ -1216,6 +1216,8 @@ def generate(
True if generation_config.flash_attention_fast_softmax else False
)
model_kwargs["num_virtual_tokens"] = num_virtual_tokens
if generation_config.valid_sequence_lengths is not None:
model_kwargs["valid_sequence_lengths"] = generation_config.valid_sequence_lengths

if not self.config.is_encoder_decoder:
calculated_max_length = input_ids.shape[1] + num_virtual_tokens
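The guard above means a model only receives the kwarg when the caller actually set it, so architectures that don't accept valid_sequence_lengths are unaffected. From model_kwargs the tensor reaches prepare_inputs_for_generation on every decoding step and, per the final llama hunk below, is republished into the attention forward. A toy sketch of the guard with a faked config object (illustrative only):

import torch
from types import SimpleNamespace

generation_config = SimpleNamespace(valid_sequence_lengths=torch.tensor([3, 5]))

model_kwargs = {}
if generation_config.valid_sequence_lengths is not None:
    model_kwargs["valid_sequence_lengths"] = generation_config.valid_sequence_lengths

# generate() later threads model_kwargs into prepare_inputs_for_generation,
# which copies the entry through (see the last hunk in this diff).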
94 changes: 77 additions & 17 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -359,8 +359,33 @@ def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_
self.enable_recompute = enable_recompute
self.flash_attention_fp8 = flash_attention_fp8

def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)
def forward(
self,
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side="left",
):
return self._hpu_kernel_fsdpa.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side,
)
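Note the widened signature: the wrapper now forwards recompute_mode (replacing the ht.sdp_kernel(enable_recompute=...) context manager that the pre_attn_forward hunk below removes), valid_sequence_lengths (so the kernel can disregard padded key/value positions), and padding_side straight to the fused HPU kernel. A hedged sketch of the causal first-token call site, mirroring the hunk further down (the per-argument comments are interpretation, not documented kernel semantics):

attn_output = self.fused_scaled_dot_product_attention(
    query_states,               # [batch, heads, seq_len, head_dim]
    key_states,
    value_states,
    None,                       # attn_mask: unused on the causal path
    0.0,                        # dropout_p
    True,                       # is_causal
    None,                       # scale: None selects the kernel default
    softmax_mode,               # "fast" or "None" (set in pre_attn_forward)
    flash_attention_recompute,  # recompute_mode
    valid_sequence_lengths,     # per-row unpadded lengths
    "left",                     # padding_side: inputs are left-padded
)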


class Matmul(torch.nn.Module):
@@ -506,6 +531,7 @@ def pre_attn_forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
**kwargs,
@@ -636,30 +662,54 @@ def pre_attn_forward(
past_key_value = None

if use_flash_attention and FusedSDPA is not None:
import habana_frameworks.torch.hpu as ht

softmax_mode = "fast" if flash_attention_fast_softmax else "None"

if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, "None"
)
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
softmax_mode,
use_recompute,
None,
"None",
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
)
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
None,
0.0,
True,
None,
softmax_mode,
flash_attention_recompute,
valid_sequence_lengths,
"left",
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
softmax_mode,
flash_attention_recompute,
None,
"None",
)

else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
@@ -855,6 +905,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
**kwargs,
@@ -888,6 +939,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
**kwargs,
@@ -923,6 +975,7 @@ def pre_attn(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -943,6 +996,7 @@ def pre_attn(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)
@@ -1036,6 +1090,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
@@ -1175,6 +1230,7 @@ def forward(
flash_attention_recompute,
flash_attention_causal_mask,
flash_attention_fast_softmax,
valid_sequence_lengths,
None,
)
else:
@@ -1194,6 +1250,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)
@@ -1272,6 +1329,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
@@ -1304,6 +1362,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
lazy_mode=lazy_mode,
num_virtual_tokens=num_virtual_tokens,
@@ -1427,6 +1486,7 @@ def prepare_inputs_for_generation(
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
"flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
"valid_sequence_lengths": kwargs.get("valid_sequence_lengths"),
"cache_idx": kwargs.get("cache_idx"),
"lazy_mode": kwargs.get("lazy_mode"),
"num_virtual_tokens": kwargs.get("num_virtual_tokens"),
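Putting the pieces together, the intended end-to-end flow looks roughly like this. A sketch assuming the run_generation.py setup above: tokenizer, model, input_sentences, and args come from that script, and the flag names follow the pattern visible in setup_generation_config:

import torch

# Pad every row to a static length so HPU graphs see fixed shapes.
input_tokens = tokenizer.batch_encode_plus(
    input_sentences,
    return_tensors="pt",
    padding="max_length",
    max_length=args.max_input_tokens,
    truncation=True,
)

# Record each row's true length before padding obscures it.
generation_config.valid_sequence_lengths = torch.sum(
    input_tokens["attention_mask"], dim=1
).to(args.device)

# On the first-token causal path in modeling_llama.py, these lengths are
# handed to the fused kernel in place of an explicit attention mask.
generation_config.use_flash_attention = True
generation_config.flash_attention_causal_mask = True

outputs = model.generate(
    **input_tokens.to(args.device),
    generation_config=generation_config,
)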