1 change: 1 addition & 0 deletions examples/text-generation/README.md
@@ -133,6 +133,7 @@ Here are a few settings you may be interested in:
- `--attn_softmax_bf16` to run attention softmax layer in bfloat16 precision provided that the model (such as Llama) supports it
- `--trim_logits` to calculate logits only for the last token in the first time step provided that the model (such as Llama) supports it
- `--attn_batch_split` specifies the number of smaller batches into which attention and MLP processing are split to improve parallelization. By default, no splitting is performed (value is 1). Splitting is enabled only for prompt processing. This configuration is most effective for batch sizes (BS) > 125 and tensor parallelism (TP) >= 2, with a recommended value of '3' splits. This feature is thoroughly tested with Llama 2 70B but may be useful for other models as well.
- `--decode_attn_batch_split` specifies the number of smaller batches into which attention and MLP processing are split to improve parallelization during decoding. By default, no splitting is performed (value is 1). Splitting is enabled only for the decode phase (see the illustrative command after this list).
- `--dynamo_specialize_float` enables specialization for float inputs by setting `specialize_float=True` in the `torch._dynamo` configuration. This option is applicable only when using `torch.compile` and can enhance performance, particularly in models utilizing FP8 quantization.
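
As an illustration of how the prompt- and decode-phase splitting flags can be combined, here is a sketch of a multi-card invocation. The model name, world size, and accompanying flags are placeholders rather than a validated configuration; only `--attn_batch_split` and `--decode_attn_batch_split` are the options documented above.

```bash
# Illustrative only: model, world size, and the other flags are placeholders.
python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
  --model_name_or_path meta-llama/Llama-2-70b-hf \
  --use_hpu_graphs \
  --use_kv_cache \
  --bf16 \
  --batch_size 126 \
  --max_new_tokens 128 \
  --attn_batch_split 3 \
  --decode_attn_batch_split 2
```

Note that both flags take effect only in multi-card runs; on a single card they are reset to 1 (see the warning added in `utils.py` below).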

For example, you can reproduce the results presented in [this blog post](https://huggingface.co/blog/habana-gaudi-2-bloom) with the following command:
6 changes: 6 additions & 0 deletions examples/text-generation/run_generation.py
@@ -425,6 +425,12 @@ def setup_parser(parser):
type=int,
help="Specify the batch size split for attention and mlp layers. 1 for no split. This is enabled only for prompt.",
)
parser.add_argument(
"--decode_attn_batch_split",
default=1,
type=int,
help="Specify the batch size split for attention and mlp layers. 1 for no split. This is enabled only for decode.",
)
parser.add_argument(
"--regional_compile",
action="store_true",
6 changes: 5 additions & 1 deletion examples/text-generation/utils.py
@@ -747,6 +747,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.trust_remote_code = args.trust_remote_code
generation_config.valid_sequence_lengths = None
generation_config.attn_batch_split = args.attn_batch_split
generation_config.decode_attn_batch_split = args.decode_attn_batch_split

return generation_config

@@ -771,8 +772,11 @@ def exclude_hpu_graph_configs(args):
def initialize_model(args, logger):
setup_distributed(args)
if not args.world_size > 0 and args.attn_batch_split > 1:
logger.warning("Disabling attention batch splitting as it's unnecessary for single-card execution")
logger.warning("Disabling attention batch splitting for prompt as it's unnecessary for single-card execution")
args.attn_batch_split = 1
if not args.world_size > 0 and args.decode_attn_batch_split > 1:
logger.warning("Disabling attention batch splitting for decode as it's unnecessary for single-card execution")
args.decode_attn_batch_split = 1
if exclude_hpu_graph_configs(args):
args.limit_hpu_graphs = False
override_prints(args.global_rank == 0 or args.verbose_workers, logger)
@@ -41,6 +41,8 @@ class GaudiGenerationConfig(GenerationConfig):
Whether to use fast softmax with reduced precision if use Habana flash attention.
attn_batch_split (`int`, *optional*):
Specify the batch size split for attention and mlp layers. 1 for no split. This is enabled only for prompt.
decode_attn_batch_split (`int`, *optional*):
Specify the batch size split for attention and mlp layers for decode. 1 for no split.
logits_bf16 (`bool`, *optional*):
Keep logits in bf16.
"""
@@ -65,4 +67,5 @@ def __init__(self, **kwargs):
self.use_fused_rope = kwargs.get("use_fused_rope", None)
self.valid_sequence_lengths = kwargs.get("valid_sequence_lengths", None)
self.attn_batch_split = kwargs.get("attn_batch_split", 1)
self.decode_attn_batch_split = kwargs.get("decode_attn_batch_split", 1)
self.logits_bf16 = kwargs.get("logits_bf16", None)
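
As a minimal sketch of setting these fields programmatically (the import path is assumed from where this class is defined in the repo, and the values are arbitrary):

```python
# Sketch: enable prompt- and decode-phase attention/MLP batch splitting
# on the generation config. Import path and values are assumptions.
from optimum.habana.transformers.generation import GaudiGenerationConfig

generation_config = GaudiGenerationConfig(
    max_new_tokens=128,
    use_cache=True,
    attn_batch_split=3,         # split the prompt phase into 3 sub-batches
    decode_attn_batch_split=2,  # split the decode phase into 2 sub-batches
)
# Fields not passed default to 1, i.e. no splitting (see __init__ above).
```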
22 changes: 20 additions & 2 deletions optimum/habana/transformers/generation/utils.py
@@ -573,6 +573,17 @@ def _pad_past_key_values(self, model_kwargs):
# Mark step if lazy mode is enabled
if lazy_mode:
self.htcore_generation.mark_step()
# For Non-MQA models with decode_attn_batch_split > 1, past_key_values is a list (layers) of lists (splits) of [k, v] tensors
elif not is_mqa_model and model_kwargs.get("decode_attn_batch_split", 1) > 1:
for i, layer in enumerate(past_key_values): # Iterate over layers
for j, split_kv_caches in enumerate(layer): # Iterate over split kv_caches
for k, k_or_v in enumerate(split_kv_caches): # Iterate over k and v
if torch.is_tensor(k_or_v) and k_or_v.shape[-2] == kv_cache_len_pad_amount:
# tensor(batch_size/num_splits, n_heads, kv_cache_len, head_dim)
past_key_values[i][j][k] = torch.nn.functional.pad(k_or_v, (0, 0, 0, pad_amount))
# Mark step if lazy mode is enabled
if lazy_mode:
self.htcore_generation.mark_step()
# For Non-MQA models, the past_key_values is a list of lists (k and v)
else:
for i, layer in enumerate(past_key_values): # Iterate over layers
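
As a standalone sketch of the nested [layers][splits][k, v] layout that the new branch pads, all shapes and sizes below are made-up example values, not taken from any model:

```python
# Toy illustration of padding a nested KV cache laid out as
# past_key_values[layer][split][k_or_v]; every dimension is arbitrary.
import torch
import torch.nn.functional as F

num_layers, num_splits = 2, 2
sub_batch, n_heads, kv_len, head_dim = 2, 4, 8, 16
pad_amount = 4  # extra positions appended along the sequence axis

past_key_values = [
    [
        [torch.zeros(sub_batch, n_heads, kv_len, head_dim) for _ in range(2)]  # [k, v]
        for _ in range(num_splits)
    ]
    for _ in range(num_layers)
]

# Same padding pattern as the branch above: extend the second-to-last
# (sequence) dimension of every cached tensor.
for layer in past_key_values:
    for split in layer:
        for idx, k_or_v in enumerate(split):
            split[idx] = F.pad(k_or_v, (0, 0, 0, pad_amount))

assert past_key_values[0][0][0].shape[-2] == kv_len + pad_amount
```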
@@ -1606,9 +1617,12 @@ def generate(
# prepare for allocate kv cache
model_kwargs["reuse_cache"] = generation_config.reuse_cache

# prepare for attention batch splitting
# prepare for attention batch splitting for prompt
model_kwargs["attn_batch_split"] = generation_config.attn_batch_split

# prepare for attention batch splitting for decode
model_kwargs["decode_attn_batch_split"] = generation_config.decode_attn_batch_split

# Keep logits in bf16
model_kwargs["logits_bf16"] = kwargs.get("logits_bf16")

@@ -2941,9 +2955,13 @@ def _sample(
if "inputs_embeds" in model_inputs
else None
)
if model_kwargs["decode_attn_batch_split"] > 1:
output_past_key_values_shape = outputs.past_key_values[0][0][0].shape
else:
output_past_key_values_shape = outputs.past_key_values[0][0].shape
do_padding = (
key_to_check is not None
and outputs.past_key_values[0][0].shape[2] == model_inputs[key_to_check].shape[1]
and output_past_key_values_shape[2] == model_inputs[key_to_check].shape[1]
and generation_config.max_new_tokens > 1
)

51 changes: 39 additions & 12 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -945,6 +945,7 @@ def forward(
cache_idx: int = None,
num_virtual_tokens: int = None,
attn_batch_split: int = 1,
decode_attn_batch_split: int = 1,
prev_layer_residual: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -960,8 +961,11 @@ def forward(
- add new arg flash_attention_causal_mask
- add new arg flash_attention_fast_softmax
"""
if attn_batch_split > 1 and past_key_value is None:
if (attn_batch_split > 1 and past_key_value is None) or (decode_attn_batch_split > 1 and past_key_value is not None):
# Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split
if past_key_value is not None:
attn_batch_split = decode_attn_batch_split

batch_size = attention_mask.size(0)
base_split_size = batch_size // attn_batch_split
remainder = batch_size % attn_batch_split
@@ -984,11 +988,11 @@ def forward(
split_hidden_states[i] = self.post_mlp(hidden_states[i], prev_layer_residual[i])

residual[i] = split_hidden_states[i]
split_hidden_states[i], self_attn_weights, present_key_value = self.pre_attn(
split_hidden_states[i], self_attn_weights, inter_present_key_value = self.pre_attn(
hidden_states=split_hidden_states[i],
attention_mask=sub_attention_mask[i],
position_ids=sub_position_ids[i],
past_key_value=past_key_value,
past_key_value=past_key_value[i] if past_key_value is not None else past_key_value,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
@@ -1006,10 +1010,17 @@
)
self.self_attn.attention_all_reduce(split_hidden_states[i])
if use_cache:
split_present_key_values.append(present_key_value)
split_present_key_values.append(inter_present_key_value)

self_attn_weights = torch.cat(split_attn_weights, dim=0) if split_attn_weights else None
present_key_value = [torch.cat(tensors, dim=0) for tensors in zip(*split_present_key_values)]
if decode_attn_batch_split > 1:
# Instead of concatenating, keep them as a list of lists
# [[k1, v1], [k2, v2]]
present_key_value = split_present_key_values
else:
# Concatenate along the batch dimension to form the final present_key_value
# [k, v] where k and v have batch dimension = sum of all splits
present_key_value = [torch.cat(tensors, dim=0) for tensors in zip(*split_present_key_values)]

int_residual_splits = []
for i in range(attn_batch_split):
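
For clarity, here is a standalone sketch of the two `present_key_value` layouts produced by the branch above; all shapes are made up for illustration only:

```python
# Illustrative shapes only: two sub-batches of size 3 and 2, 8 heads,
# 16 cached positions, head_dim 64.
import torch

split_present_key_values = [
    [torch.zeros(3, 8, 16, 64), torch.zeros(3, 8, 16, 64)],  # [k1, v1]
    [torch.zeros(2, 8, 16, 64), torch.zeros(2, 8, 16, 64)],  # [k2, v2]
]

# Prompt path (attn_batch_split > 1): concatenate along the batch dimension,
# producing a single [k, v] pair that covers the whole batch.
prompt_kv = [torch.cat(tensors, dim=0) for tensors in zip(*split_present_key_values)]
assert prompt_kv[0].shape[0] == 5

# Decode path (decode_attn_batch_split > 1): keep the per-split pairs as-is,
# i.e. [[k1, v1], [k2, v2]], so later decode steps can address each split.
decode_kv = split_present_key_values
assert decode_kv[0][0].shape[0] == 3
```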
@@ -1054,7 +1065,7 @@ def forward(
if use_cache:
outputs += (present_key_value,)
# Store the residual splits to add them in the beginning of the next layer
if attn_batch_split > 1 and past_key_value is None:
if (attn_batch_split > 1 and past_key_value is None) or (decode_attn_batch_split > 1 and past_key_value is not None):
outputs += (int_residual_splits,)

return outputs
@@ -1133,6 +1144,7 @@ def post_mlp(self, hidden_states, residual):
return hidden_states



class GaudiLlamaModel(LlamaModel):
"""
Copied from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L909
@@ -1197,6 +1209,7 @@ def forward(
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
attn_batch_split: int = 1,
decode_attn_batch_split: int = 1,
**kwargs,
) -> BaseModelOutputWithPast:
"""
@@ -1246,8 +1259,10 @@ def forward(
past_seen_tokens = past_key_values[0][0][2]
else:
# HPU uses legacy cache path (use_new_cache = False)
past_seen_tokens = past_key_values[0][0].shape[2]

if decode_attn_batch_split > 1:
past_seen_tokens = past_key_values[0][0][0].shape[2]
else:
past_seen_tokens = past_key_values[0][0].shape[2]
if ignore_cache_position is False:
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -1284,8 +1299,9 @@ def forward(
htcore.mark_step()

split_prompt = False
prev_layer_residual = None
if attn_batch_split > 1 and past_key_values is None:
if (attn_batch_split > 1 and past_key_values is None) or (decode_attn_batch_split > 1 and past_key_values is not None):
if past_key_values is not None:
attn_batch_split = decode_attn_batch_split
# Calculate split sizes to handle cases where batch size is not divisible by attn_batch_split
batch_size = hidden_states.size(0)
base_split_size = batch_size // attn_batch_split
@@ -1295,6 +1311,8 @@ def forward(
hidden_states_split = torch.split(hidden_states, split_sizes, dim=0)
split_prompt = True

prev_layer_residual = None

for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
if (
lazy_mode
@@ -1306,10 +1324,10 @@
# Calling the layer with positional arguments
# This is a workaround for an issue with DeepSpeed where
# it cannot handle keyword arguments and throws a RuntimError
use_prev_layer_residual = attn_batch_split > 1 and past_key_values is None
past_key_value = None if past_key_values is None else past_key_values[layer_idx]
use_prev_layer_residual = (attn_batch_split > 1 and past_key_value is None) or (decode_attn_batch_split > 1 and past_key_value is not None)
layer_prev_layer_residual = prev_layer_residual if use_prev_layer_residual else None
layer_hidden_states = hidden_states_split if split_prompt else hidden_states
past_key_value = None if past_key_values is None else past_key_values[layer_idx]
layer_outputs = decoder_layer(
layer_hidden_states,
causal_mask,
@@ -1330,6 +1348,7 @@ def forward(
cache_idx,
num_virtual_tokens,
attn_batch_split,
decode_attn_batch_split,
layer_prev_layer_residual,
)
if use_prev_layer_residual:
@@ -1345,6 +1364,11 @@ def forward(

hidden_states = self.norm(hidden_states)

if lazy_mode and decode_attn_batch_split > 1 and torch.distributed.get_world_size() > 1:
# To reduce pressure on the NIC, put a barrier here so that all processes
# finish computation before moving to the next step.
# Recommended for Llama 405B during decoding with batch splitting.
torch.distributed.barrier()
next_cache = next_decoder_cache if use_cache else None
if not use_new_cache and isinstance(next_cache, Cache):
next_cache = next_cache.to_legacy_cache()
@@ -1409,6 +1433,7 @@ def forward(
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
attn_batch_split: int = 1,
decode_attn_batch_split: int = 1,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
if self.generation_config.use_fused_rope is False:
@@ -1436,6 +1461,7 @@ def forward(
lazy_mode=lazy_mode,
num_virtual_tokens=num_virtual_tokens,
attn_batch_split=attn_batch_split,
decode_attn_batch_split=decode_attn_batch_split,
**kwargs,
)

@@ -1563,6 +1589,7 @@ def prepare_inputs_for_generation(
"lazy_mode": kwargs.get("lazy_mode"),
"num_virtual_tokens": kwargs.get("num_virtual_tokens"),
"attn_batch_split": kwargs.get("attn_batch_split"),
"decode_attn_batch_split": kwargs.get("decode_attn_batch_split"),
}
)
return model_inputs