Merged
Changes from 7 commits
9 changes: 8 additions & 1 deletion examples/text-generation/run_generation.py
@@ -28,9 +28,9 @@
from pathlib import Path

import torch
from utils import adjust_batch, count_hpu_graphs, finalize_quantization, initialize_model

from optimum.habana.utils import get_hpu_memory_stats
from utils import adjust_batch, count_hpu_graphs, finalize_quantization, initialize_model


logging.basicConfig(
@@ -417,6 +417,13 @@ def generate(size=None, reduce_recompile=False):
max_length=args.max_input_tokens,
truncation=True,
)

def compute_valid_sequence_lengths_tensor(input_tokens):
attn_mask = input_tokens["attention_mask"]
return torch.sum(attn_mask, dim=1)

valid_sequence_lengths = compute_valid_sequence_lengths_tensor(input_tokens).to(args.device)
generation_config.valid_sequence_lengths = valid_sequence_lengths
else:
input_tokens = tokenizer.batch_encode_plus(input_sentences, return_tensors="pt", padding=True)
encode_duration = time.perf_counter() - encode_t0
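For illustration, the new helper just sums each row of the attention mask, so a left-padded batch yields one valid length per prompt. A minimal sketch with made-up values (not part of the diff):

import torch

attention_mask = torch.tensor(
    [
        [0, 0, 1, 1, 1],  # prompt of 3 real tokens, left-padded to 5
        [1, 1, 1, 1, 1],  # prompt of 5 real tokens
    ]
)
valid_sequence_lengths = torch.sum(attention_mask, dim=1)  # tensor([3, 5])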
2 changes: 1 addition & 1 deletion examples/text-generation/run_lm_eval.py
@@ -32,9 +32,9 @@

# Local imports
from run_generation import setup_parser
from utils import finalize_quantization, initialize_model

from optimum.habana.utils import get_hpu_memory_stats
from utils import finalize_quantization, initialize_model


os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
@@ -573,6 +573,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
generation_config.trust_remote_code = args.trust_remote_code
setattr(generation_config, 'valid_sequence_lengths', None)
Contributor
@hsubramony @ssarkar2 Isn't this where the default value is getting set? Why do we get a "no attribute" error?

regisss (Collaborator) commented on Oct 7, 2024
valid_sequence_lengths should be declared here: https://github.com/huggingface/optimum-habana/blob/main/optimum/habana/transformers/generation/configuration_utils.py
And then:

generation_config.valid_sequence_lengths = None
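A minimal sketch of the declaration the reviewer suggests, assuming GaudiGenerationConfig in configuration_utils.py keeps its existing pattern of reading Gaudi-specific flags from kwargs (attribute names other than valid_sequence_lengths are illustrative, not a verbatim excerpt):

from transformers.generation import GenerationConfig

class GaudiGenerationConfig(GenerationConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Existing Gaudi-specific flags are declared in the same way, e.g.:
        self.use_flash_attention = kwargs.get("use_flash_attention", None)
        # Declaring the new attribute with a default of None would avoid the
        # AttributeError when generate() reads it from a config that never set it.
        self.valid_sequence_lengths = kwargs.get("valid_sequence_lengths", None)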


return generation_config

2 changes: 2 additions & 0 deletions optimum/habana/transformers/generation/utils.py
@@ -1082,6 +1082,8 @@ def generate(
True if generation_config.flash_attention_fast_softmax else False
)
model_kwargs["num_virtual_tokens"] = num_virtual_tokens
if hasattr(generation_config, "valid_sequence_lengths"):
libinta marked this conversation as resolved (outdated).
model_kwargs["valid_sequence_lengths"] = generation_config.valid_sequence_lengths

if not self.config.is_encoder_decoder:
calculated_max_length = input_ids.shape[-1] + num_virtual_tokens
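For context, a hedged end-to-end sketch of how the value reaches the model: the script sets it on the generation config (as in run_generation.py above), the hasattr() guard copies it into model_kwargs, and prepare_inputs_for_generation() (modeling_llama.py below) forwards it to the attention layers. Variable names mirror the diffs above; the exact generate() kwargs are illustrative:

valid_sequence_lengths = torch.sum(input_tokens["attention_mask"], dim=1).to(args.device)
generation_config.valid_sequence_lengths = valid_sequence_lengths

outputs = model.generate(
    **input_tokens,
    generation_config=generation_config,
    lazy_mode=True,                  # illustrative Gaudi-specific kwargs
    hpu_graphs=args.use_hpu_graphs,
)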
93 changes: 77 additions & 16 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -351,8 +351,33 @@ def __init__(self, fusedSDPA):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA

def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)
def forward(
self,
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side="left",
):
return self._hpu_kernel_fsdpa.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side,
)


class Matmul(torch.nn.Module):
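The extended wrapper simply forwards the extra arguments to the underlying FusedSDPA kernel. A hedged sketch of a first-token (prompt) call, mirroring the causal-mask branch further down; the import path and tensor shapes are assumptions, not part of this diff:

from habana_frameworks.torch.hpex.kernels import FusedSDPA

fused_sdpa = ModuleFusedSDPA(FusedSDPA)
attn_output = fused_sdpa(
    query_states,             # [batch, num_heads, q_len, head_dim]
    key_states,
    value_states,
    None,                     # attn_mask is dropped when causal masking is used
    0.0,                      # dropout_p
    True,                     # is_causal: prompt processing
    None,                     # scale: kernel default
    "fast",                   # softmax_mode
    flash_attention_recompute,
    valid_sequence_lengths,   # per-prompt lengths computed from the attention mask
    "left",                   # padding side of the prompt batch
)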
@@ -488,6 +513,7 @@ def pre_attn_forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
**kwargs,
@@ -618,30 +644,55 @@ def pre_attn_forward(
past_key_value = None

if use_flash_attention and FusedSDPA is not None:
import habana_frameworks.torch.hpu as ht

softmax_mode = "fast" if flash_attention_fast_softmax else "None"

if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, "None"
)
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
softmax_mode,
use_recompute,
None,
"None",
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
)
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
None,
0.0,
True,
None,
softmax_mode,
flash_attention_recompute,
valid_sequence_lengths,
"left",
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
softmax_mode,
flash_attention_recompute,
None,
"None",
)

else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
Expand Down Expand Up @@ -837,6 +888,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
**kwargs,
@@ -870,6 +922,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
**kwargs,
@@ -905,6 +958,7 @@ def pre_attn(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -925,6 +979,7 @@
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)
@@ -1018,6 +1073,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: torch.Tensor = None,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
@@ -1157,6 +1213,7 @@ def forward(
flash_attention_recompute,
flash_attention_causal_mask,
flash_attention_fast_softmax,
valid_sequence_lengths,
None,
)
else:
@@ -1176,6 +1233,7 @@
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)
@@ -1253,6 +1311,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: torch.Tensor = None,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
@@ -1285,6 +1344,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
lazy_mode=lazy_mode,
num_virtual_tokens=num_virtual_tokens,
@@ -1394,6 +1454,7 @@ def prepare_inputs_for_generation(
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
"flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
"valid_sequence_lengths": kwargs.get("valid_sequence_lengths"),
"cache_idx": kwargs.get("cache_idx"),
"lazy_mode": kwargs.get("lazy_mode"),
"num_virtual_tokens": kwargs.get("num_virtual_tokens"),