Skip to content

Commit 7d4385b

Browse files
astachowicz-habana authored and ugolowic committed
Add option to use bf16 in PT sdp (#5) (huggingface#1514)
Co-authored-by: Urszula Golowicz <[email protected]>
1 parent 267ace3 commit 7d4385b

4 files changed

Lines changed: 22 additions & 0 deletions

File tree

examples/stable-diffusion/text_to_image_generation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ def main():
228228
),
229229
)
230230
parser.add_argument("--bf16", action="store_true", help="Whether to perform generation in bf16 precision.")
231+
parser.add_argument(
232+
"--sdp_on_bf16", action="store_true", help="Allow pyTorch to use reduced precision in the SDPA math backend"
233+
)
231234
parser.add_argument(
232235
"--ldm3d", action="store_true", help="Use LDM3D to generate an image and a depth map from a given text prompt."
233236
)
@@ -344,6 +347,7 @@ def main():
344347
"use_habana": args.use_habana,
345348
"use_hpu_graphs": args.use_hpu_graphs,
346349
"gaudi_config": args.gaudi_config_name,
350+
"sdp_on_bf16": args.sdp_on_bf16,
347351
}
348352

349353
if scheduler is not None:

optimum/habana/diffusers/pipelines/pipeline_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ class GaudiDiffusionPipeline(DiffusionPipeline):
113113
bf16_full_eval (bool, defaults to `False`):
114114
Whether to use full bfloat16 evaluation instead of 32-bit.
115115
This will be faster and save memory compared to fp32/mixed precision but can harm generated images.
116+
sdp_on_bf16 (bool, defaults to `False`):
117+
Whether to allow PyTorch to use reduced precision in the SDPA math backend.
116118
"""
117119

118120
def __init__(
@@ -121,9 +123,13 @@ def __init__(
121123
use_hpu_graphs: bool = False,
122124
gaudi_config: Union[str, GaudiConfig] = None,
123125
bf16_full_eval: bool = False,
126+
sdp_on_bf16: bool = False,
124127
):
125128
DiffusionPipeline.__init__(self)
126129

130+
if sdp_on_bf16:
131+
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
132+
127133
self.use_habana = use_habana
128134
if self.use_habana:
129135
self.use_hpu_graphs = use_hpu_graphs

optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ class GaudiStableDiffusionPipeline(GaudiDiffusionPipeline, StableDiffusionPipeli
131131
bf16_full_eval (bool, defaults to `False`):
132132
Whether to use full bfloat16 evaluation instead of 32-bit.
133133
This will be faster and save memory compared to fp32/mixed precision but can harm generated images.
134+
sdp_on_bf16 (bool, defaults to `False`):
135+
Whether to allow PyTorch to use reduced precision in the SDPA math backend.
134136
"""
135137

136138
def __init__(
@@ -148,13 +150,15 @@ def __init__(
148150
use_hpu_graphs: bool = False,
149151
gaudi_config: Union[str, GaudiConfig] = None,
150152
bf16_full_eval: bool = False,
153+
sdp_on_bf16: bool = False,
151154
):
152155
GaudiDiffusionPipeline.__init__(
153156
self,
154157
use_habana,
155158
use_hpu_graphs,
156159
gaudi_config,
157160
bf16_full_eval,
161+
sdp_on_bf16,
158162
)
159163

160164
# Workaround for Synapse 1.11 for full bf16

optimum/habana/transformers/training_args.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,11 @@ class GaudiTrainingArguments(TrainingArguments):
305305
},
306306
)
307307

308+
sdp_on_bf16: bool = field(
309+
default=False,
310+
metadata={"help": "Allow pyTorch to use reduced precision in the SDPA math backend"},
311+
)
312+
308313
fp8: Optional[bool] = field(
309314
default=False,
310315
metadata={"help": "Whether to use fp8 for training."},
@@ -847,6 +852,9 @@ def _setup_devices(self) -> "torch.device":
847852
):
848853
gaudi_config.declare_autocast_bf16_fp32_ops()
849854

855+
if self.sdp_on_bf16:
856+
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
857+
850858
logger.info("PyTorch: setting up devices")
851859
if not is_accelerate_available():
852860
raise ImportError(

0 commit comments

Comments (0)