diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 00e6c13849..e9fa5456cf 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -460,6 +460,13 @@ def generate(size=None, reduce_recompile=False):
                     max_length=args.max_input_tokens,
                     truncation=True,
                 )
+
+                def compute_valid_sequence_lengths_tensor(input_tokens):
+                    attn_mask = input_tokens["attention_mask"]
+                    return torch.sum(attn_mask, dim=1)
+
+                valid_sequence_lengths = compute_valid_sequence_lengths_tensor(input_tokens).to(args.device)
+                generation_config.valid_sequence_lengths = valid_sequence_lengths
             else:
                 input_tokens = tokenizer.batch_encode_plus(input_sentences, return_tensors="pt", padding=True)
             encode_duration = time.perf_counter() - encode_t0
diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index 61a8aa3338..df37b10a30 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -590,6 +590,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
     generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
     generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
     generation_config.trust_remote_code = args.trust_remote_code
+    generation_config.valid_sequence_lengths = None
 
     return generation_config
 
diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py
index ce38a07ed9..ec04f139c9 100644
--- a/optimum/habana/transformers/generation/configuration_utils.py
+++ b/optimum/habana/transformers/generation/configuration_utils.py
@@ -55,3 +55,4 @@ def __init__(self, **kwargs):
         self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None)
         self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None)
         self.use_fused_rope = kwargs.get("use_fused_rope", None)
+        self.valid_sequence_lengths = kwargs.get("valid_sequence_lengths", None)
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index a76ea59e87..db989cfe8a 100644
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -1216,6 +1216,8 @@ def generate(
                 True if generation_config.flash_attention_fast_softmax else False
             )
             model_kwargs["num_virtual_tokens"] = num_virtual_tokens
+            if generation_config.valid_sequence_lengths is not None:
+                model_kwargs["valid_sequence_lengths"] = generation_config.valid_sequence_lengths
 
             if not self.config.is_encoder_decoder:
                 calculated_max_length = input_ids.shape[1] + num_virtual_tokens
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 21bafda4b2..55da544464 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -359,8 +359,33 @@ def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_
         self.enable_recompute = enable_recompute
         self.flash_attention_fp8 = flash_attention_fp8
 
-    def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
-        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        softmax_mode,
+        recompute_mode,
+        valid_sequence_lengths,
+        padding_side="left",
+    ):
+        return self._hpu_kernel_fsdpa.apply(
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            softmax_mode,
+            recompute_mode,
+            valid_sequence_lengths,
+            padding_side,
+        )
 
 
 class Matmul(torch.nn.Module):
@@ -506,6 +531,7 @@ def pre_attn_forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: Optional[torch.Tensor] = None,
         cache_idx: int = None,
         num_virtual_tokens: int = None,
         **kwargs,
@@ -636,30 +662,54 @@ def pre_attn_forward(
            past_key_value = None
 
        if use_flash_attention and FusedSDPA is not None:
-            import habana_frameworks.torch.hpu as ht
-
            softmax_mode = "fast" if flash_attention_fast_softmax else "None"
 
            if q_len == 1:
                # next token
                use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
-                with ht.sdp_kernel(enable_recompute=use_recompute):
-                    attn_output = self.fused_scaled_dot_product_attention(
-                        query_states, key_states, value_states, attention_mask, 0.0, False, None, "None"
-                    )
+                attn_output = self.fused_scaled_dot_product_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attention_mask,
+                    0.0,
+                    False,
+                    None,
+                    softmax_mode,
+                    use_recompute,
+                    None,
+                    "None",
+                )
            else:
                # first token
                if flash_attention_causal_mask:
-                    # causal masking on first token requires inputs to be of the same length
-                    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
-                        attn_output = self.fused_scaled_dot_product_attention(
-                            query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
-                        )
+                    attn_output = self.fused_scaled_dot_product_attention(
+                        query_states,
+                        key_states,
+                        value_states,
+                        None,
+                        0.0,
+                        True,
+                        None,
+                        softmax_mode,
+                        flash_attention_recompute,
+                        valid_sequence_lengths,
+                        "left",
+                    )
                else:
-                    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
-                        attn_output = self.fused_scaled_dot_product_attention(
-                            query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
-                        )
+                    attn_output = self.fused_scaled_dot_product_attention(
+                        query_states,
+                        key_states,
+                        value_states,
+                        attention_mask,
+                        0.0,
+                        False,
+                        None,
+                        softmax_mode,
+                        flash_attention_recompute,
+                        None,
+                        "None",
+                    )
 
        else:
            query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
@@ -855,6 +905,7 @@ def forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: Optional[torch.Tensor] = None,
         cache_idx: int = None,
         num_virtual_tokens: int = None,
         **kwargs,
@@ -888,6 +939,7 @@ def forward(
             flash_attention_recompute=flash_attention_recompute,
             flash_attention_causal_mask=flash_attention_causal_mask,
             flash_attention_fast_softmax=flash_attention_fast_softmax,
+            valid_sequence_lengths=valid_sequence_lengths,
             cache_idx=cache_idx,
             num_virtual_tokens=num_virtual_tokens,
             **kwargs,
@@ -923,6 +975,7 @@ def pre_attn(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: Optional[torch.Tensor] = None,
         cache_idx: int = None,
         num_virtual_tokens: int = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -943,6 +996,7 @@ def pre_attn(
             flash_attention_recompute=flash_attention_recompute,
             flash_attention_causal_mask=flash_attention_causal_mask,
             flash_attention_fast_softmax=flash_attention_fast_softmax,
+            valid_sequence_lengths=valid_sequence_lengths,
             cache_idx=cache_idx,
             num_virtual_tokens=num_virtual_tokens,
         )
@@ -1036,6 +1090,7 @@ def forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: torch.Tensor = None,
         cache_idx: int = None,
         lazy_mode: Optional[bool] = True,
         num_virtual_tokens: int = None,
@@ -1175,6 +1230,7 @@ def forward(
                     flash_attention_recompute,
                     flash_attention_causal_mask,
                     flash_attention_fast_softmax,
+                    valid_sequence_lengths,
                     None,
                 )
             else:
@@ -1194,6 +1250,7 @@ def forward(
                     flash_attention_recompute=flash_attention_recompute,
                     flash_attention_causal_mask=flash_attention_causal_mask,
                     flash_attention_fast_softmax=flash_attention_fast_softmax,
+                    valid_sequence_lengths=valid_sequence_lengths,
                     cache_idx=cache_idx,
                     num_virtual_tokens=num_virtual_tokens,
                 )
@@ -1272,6 +1329,7 @@ def forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: torch.Tensor = None,
         cache_idx: int = None,
         lazy_mode: Optional[bool] = True,
         num_virtual_tokens: int = None,
@@ -1304,6 +1362,7 @@ def forward(
             flash_attention_recompute=flash_attention_recompute,
             flash_attention_causal_mask=flash_attention_causal_mask,
             flash_attention_fast_softmax=flash_attention_fast_softmax,
+            valid_sequence_lengths=valid_sequence_lengths,
             cache_idx=cache_idx,
             lazy_mode=lazy_mode,
             num_virtual_tokens=num_virtual_tokens,
@@ -1427,6 +1486,7 @@ def prepare_inputs_for_generation(
                 "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
                 "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
                 "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
+                "valid_sequence_lengths": kwargs.get("valid_sequence_lengths"),
                 "cache_idx": kwargs.get("cache_idx"),
                 "lazy_mode": kwargs.get("lazy_mode"),
                 "num_virtual_tokens": kwargs.get("num_virtual_tokens"),
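
For reference, a minimal, self-contained sketch of how the `valid_sequence_lengths` value introduced above is produced on the caller side: it is the per-sample count of non-padded tokens in the attention mask, computed the same way as `compute_valid_sequence_lengths_tensor()` in `run_generation.py`, and then attached to the generation config. The batch contents and the device string below are illustrative assumptions, not part of this diff.

```python
import torch

# Illustrative left-padded batch (as produced by batch_encode_plus with
# padding="max_length"); attention_mask is 1 for real tokens, 0 for padding.
input_tokens = {
    "input_ids": torch.tensor([[0, 0, 5, 6, 7], [1, 2, 3, 4, 5]]),
    "attention_mask": torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]]),
}

# Per-sample number of valid (non-padded) tokens, as in the diff.
valid_sequence_lengths = torch.sum(input_tokens["attention_mask"], dim=1)
print(valid_sequence_lengths)  # tensor([3, 5])

# In run_generation.py this tensor is then moved to args.device (e.g. "hpu") and
# stored on the config, from where generate() forwards it to the fused SDPA kernel
# for the first-token (prompt) pass:
# generation_config.valid_sequence_lengths = valid_sequence_lengths.to(args.device)
```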