Skip to content

Commit dd63939

Browse files
authored
Enable DeepSpeed for image-to-text example (#1455)
1 parent 03fa6dd commit dd63939

3 files changed

Lines changed: 106 additions & 7 deletions

File tree

examples/image-to-text/README.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,45 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
204204
--use_flash_attention \
205205
--flash_attention_recompute
206206
```
207+
208+
## Multi-HPU inference
209+
210+
To enable multi-card inference, you must set the environment variable `PT_HPU_ENABLE_LAZY_COLLECTIVES=true`.
211+
212+
### BF16 Inference with FusedSDPA on 8 HPUs
213+
214+
Use the following commands to run Llava-v1.6-mistral-7b BF16 inference with FusedSDPA on 8 HPUs:
215+
```bash
216+
PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
217+
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
218+
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
219+
--use_hpu_graphs \
220+
--bf16 \
221+
--use_flash_attention \
222+
--flash_attention_recompute
223+
```
224+
225+
### FP8 Inference with FusedSDPA on 8 HPUs
226+
227+
Use the following commands to run Llava-v1.6-mistral-7b FP8 inference with FusedSDPA on 8 HPUs.
228+
Here is an example of measuring the tensor quantization statistics on Llava-v1.6-mistral-7b on 8 HPUs:
229+
```bash
230+
QUANT_CONFIG=./quantization_config/maxabs_measure.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
231+
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
232+
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
233+
--use_hpu_graphs \
234+
--bf16 \
235+
--use_flash_attention \
236+
--flash_attention_recompute
237+
```
238+
239+
Here is an example of quantizing the model based on previous measurements for Llava-v1.6-mistral-7b on 8 HPUs:
240+
```bash
241+
QUANT_CONFIG=./quantization_config/maxabs_quant.json PT_HPU_ENABLE_LAZY_COLLECTIVES=true python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \
242+
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
243+
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
244+
--use_hpu_graphs \
245+
--bf16 \
246+
--use_flash_attention \
247+
--flash_attention_recompute
248+
```

examples/image-to-text/run_pipeline.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,53 @@
3636
logger = logging.getLogger(__name__)
3737

3838

39+
def override_print(enable):
    """Monkey-patch ``builtins.print`` for rank-aware logging.

    After this call, ``print`` emits output only when ``enable`` is true
    (intended: true on rank 0 only). Any caller can bypass the filter by
    passing ``force=True``; that keyword is consumed and never forwarded
    to the real ``print``.
    """
    import builtins as __builtin__

    original_print = __builtin__.print

    # Keep the replacement named ``print`` so the patched builtin looks
    # identical to the real one (same __name__).
    def print(*args, **kwargs):
        if kwargs.pop("force", False) or enable:
            original_print(*args, **kwargs)

    __builtin__.print = print
50+
51+
52+
def override_logger(logger, enable):
    """Wrap ``logger.info`` with a rank-aware filter.

    Messages are forwarded to the original ``info`` only when ``enable``
    is true; a ``force=True`` keyword argument bypasses the filter and is
    stripped before forwarding.
    """
    original_info = logger.info

    def info(*args, **kwargs):
        if kwargs.pop("force", False) or enable:
            original_info(*args, **kwargs)

    logger.info = info
61+
62+
63+
def initialize_distributed_model(args, model, logger, model_dtype):
    """Shard ``model`` across devices with DeepSpeed tensor parallelism.

    Silences ``print``/``logger.info`` on every rank except the global
    rank 0 (``force=True`` still bypasses the filter), initializes the
    distributed backend, and returns the DeepSpeed inference module
    wrapping ``model``.

    NOTE(review): assumes ``args`` carries ``global_rank``, ``world_size``
    and ``use_hpu_graphs`` — set by the caller from the launcher's
    environment variables.
    """
    is_main_rank = args.global_rank == 0
    override_print(is_main_rank)
    override_logger(logger, is_main_rank)

    # Imported lazily so single-card runs never require deepspeed.
    import deepspeed

    logger.info(f"Initializing DeepSpeed with world size: {args.world_size}")
    deepspeed.init_distributed(
        dist_backend="hccl",
        verbose=is_main_rank,
    )
    model.eval()

    ds_inference_kwargs = {
        "dtype": model_dtype,
        "tensor_parallel": {"tp_size": args.world_size},
        "enable_cuda_graph": args.use_hpu_graphs,
        "injection_policy": {},
    }
    return deepspeed.init_inference(model, **ds_inference_kwargs).module
84+
85+
3986
def setup_quantization(model, args):
4087
from neural_compressor.torch.quantization import FP8Config, convert, prepare
4188

@@ -129,6 +176,11 @@ def main():
129176

130177
# set args.quant_config with env variable if it is set
131178
args.quant_config = os.getenv("QUANT_CONFIG", "")
179+
180+
args.local_rank = int(os.getenv("LOCAL_RANK", "0"))
181+
args.world_size = int(os.getenv("WORLD_SIZE", "0"))
182+
args.global_rank = int(os.getenv("RANK", "0"))
183+
132184
os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
133185
adapt_transformers_to_gaudi()
134186

@@ -187,6 +239,16 @@ def main():
187239
torch_dtype=model_dtype,
188240
device="hpu",
189241
)
242+
243+
if args.world_size > 1:
244+
generator.model = initialize_distributed_model(args, generator.model, logger, model_dtype)
245+
246+
else:
247+
if args.use_hpu_graphs:
248+
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
249+
250+
generator.model = wrap_in_hpu_graph(generator.model)
251+
190252
generate_kwargs = {
191253
"lazy_mode": True,
192254
"hpu_graphs": args.use_hpu_graphs,
@@ -198,11 +260,6 @@ def main():
198260
if args.use_kv_cache:
199261
generate_kwargs["use_cache"] = args.use_kv_cache
200262

201-
if args.use_hpu_graphs:
202-
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
203-
204-
generator.model = wrap_in_hpu_graph(generator.model)
205-
206263
if args.quant_config:
207264
generator.model = setup_quantization(generator.model, args)
208265
htcore.hpu_initialize(generator.model)

optimum/habana/transformers/models/clip/modeling_clip.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def forward(
8686
- add new args use_flash_attention to enable FusedSDPA
8787
- add new args flash_attention_recompute
8888
"""
89-
bsz, tgt_len, embed_dim = hidden_states.size()
89+
bsz, tgt_len, _ = hidden_states.size()
9090
attn_weights_reshaped = None
9191
# get query proj
9292
query_states = self.q_proj(hidden_states) * self.scale
@@ -156,7 +156,7 @@ def forward(
156156

157157
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
158158
attn_output = attn_output.transpose(1, 2)
159-
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
159+
attn_output = attn_output.reshape(bsz, tgt_len, -1)
160160

161161
attn_output = self.out_proj(attn_output)
162162

0 commit comments

Comments
 (0)