examples/image-to-text/README.md (42 additions, 0 deletions)

@@ -204,3 +204,45 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
--use_flash_attention \
--flash_attention_recompute
```

## Multi-HPU inference

To enable multi-card inference, you must set the environment variable `PT_HPU_ENABLE_LAZY_COLLECTIVES=true`, as shown in the commands below.
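
If you prefer not to prefix every command, the variable can equivalently be exported once per shell session:
```bash
export PT_HPU_ENABLE_LAZY_COLLECTIVES=true
```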

### BF16 Inference with FusedSDPA on 8 HPUs

Use the following command to run Llava-v1.6-mistral-7b BF16 inference with FusedSDPA on 8 HPUs:
```bash
PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 \
--use_flash_attention \
--flash_attention_recompute
```

### FP8 Inference with FusedSDPA on 8 HPUs

Use the following commands to run Llava-v1.6-mistral-7b FP8 inference with FusedSDPA on 8 HPUs. FP8 inference is a two-step flow: first measure the tensor quantization statistics, then quantize the model using those measurements.
Here is an example of measuring the tensor quantization statistics for Llava-v1.6-mistral-7b on 8 HPUs:
```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 \
--use_flash_attention \
--flash_attention_recompute
```

Here is an example of quantizing the model based on previous measurements for Llava-v1.6-mistral-7b on 8 HPUs:
```bash
QUANT_CONFIG=./quantization_config/maxabs_quant.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 \
--use_flash_attention \
--flash_attention_recompute
```
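
Note that `gaudi_spawn.py --use_deepspeed --world_size 8` launches one process per HPU, and `run_pipeline.py` picks up its rank from the environment variables set by the launcher (see the `run_pipeline.py` changes below). A minimal sketch of that convention, assuming the standard `RANK`/`WORLD_SIZE`/`LOCAL_RANK` variables:
```python
import os

# Rank information injected by the launcher into each spawned process.
local_rank = int(os.getenv("LOCAL_RANK", "0"))   # device index on this node
world_size = int(os.getenv("WORLD_SIZE", "0"))   # total number of processes
global_rank = int(os.getenv("RANK", "0"))        # unique rank across all nodes

if world_size > 1:
    # Multi-card run: typically only rank 0 logs, to keep the output readable.
    print(f"rank {global_rank}/{world_size} on local device {local_rank}")
```
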
examples/image-to-text/run_pipeline.py (62 additions, 5 deletions)

@@ -36,6 +36,53 @@
logger = logging.getLogger(__name__)


def override_print(enable):
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if force or enable:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def override_logger(logger, enable):
    logger_info = logger.info

    def info(*args, **kwargs):
        force = kwargs.pop("force", False)
        if force or enable:
            logger_info(*args, **kwargs)

    logger.info = info


def initialize_distributed_model(args, model, logger, model_dtype):
    override_print(args.global_rank == 0)
    override_logger(logger, args.global_rank == 0)

    import deepspeed

    logger.info(f"Initializing DeepSpeed with world size: {args.world_size}")
    deepspeed.init_distributed(
        dist_backend="hccl",
        verbose=args.global_rank == 0,
    )
    model.eval()

    ds_inference_kwargs = {"dtype": model_dtype}
    ds_inference_kwargs["tensor_parallel"] = {"tp_size": args.world_size}
    ds_inference_kwargs["enable_cuda_graph"] = args.use_hpu_graphs
    ds_inference_kwargs["injection_policy"] = {}

    model = deepspeed.init_inference(model, **ds_inference_kwargs).module

    return model


def setup_quantization(model, args):
    from neural_compressor.torch.quantization import FP8Config, convert, prepare

@@ -129,6 +176,11 @@ def main():

    # set args.quant_config with env variable if it is set
    args.quant_config = os.getenv("QUANT_CONFIG", "")

    args.local_rank = int(os.getenv("LOCAL_RANK", "0"))
    args.world_size = int(os.getenv("WORLD_SIZE", "0"))
    args.global_rank = int(os.getenv("RANK", "0"))

    os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
    adapt_transformers_to_gaudi()

@@ -187,6 +239,16 @@ def main():
        torch_dtype=model_dtype,
        device="hpu",
    )

    if args.world_size > 1:
        generator.model = initialize_distributed_model(args, generator.model, logger, model_dtype)

    else:
        if args.use_hpu_graphs:
            from habana_frameworks.torch.hpu import wrap_in_hpu_graph

            generator.model = wrap_in_hpu_graph(generator.model)

    generate_kwargs = {
        "lazy_mode": True,
        "hpu_graphs": args.use_hpu_graphs,
@@ -198,11 +260,6 @@
    if args.use_kv_cache:
        generate_kwargs["use_cache"] = args.use_kv_cache

    if args.use_hpu_graphs:
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        generator.model = wrap_in_hpu_graph(generator.model)

    if args.quant_config:
        generator.model = setup_quantization(generator.model, args)
        htcore.hpu_initialize(generator.model)
optimum/habana/transformers/models/clip/modeling_clip.py (2 additions, 2 deletions)

@@ -86,7 +86,7 @@ def forward(
        - add new args use_flash_attention to enable FusedSDPA
        - add new args flash_attention_recompute
        """
        bsz, tgt_len, embed_dim = hidden_states.size()
        bsz, tgt_len, _ = hidden_states.size()
        attn_weights_reshaped = None
        # get query proj
        query_states = self.q_proj(hidden_states) * self.scale
@@ -156,7 +156,7 @@ def forward(

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
        attn_output = attn_output.reshape(bsz, tgt_len, -1)

        attn_output = self.out_proj(attn_output)
