Commit 82c9264

dsocek and splotnikv committed
Add fused RoPE with selection
Signed-off-by: Daniel Socek <[email protected]>
Co-authored-by: Sergey Plotnikov <[email protected]>
1 parent 2e896a2 commit 82c9264

1 file changed

Lines changed: 52 additions & 3 deletions

File tree

optimum/habana/diffusers/pipelines/flux/pipeline_flux.py

@@ -13,6 +13,7 @@
 # limitations under the License.

 import math
+import os
 import time
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -26,6 +27,7 @@
 from diffusers.models.transformers import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps
 from diffusers.utils import BaseOutput, replace_example_docstring
+from habana_frameworks.torch.hpex.kernels import apply_rotary_pos_emb as FusedRoPE
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

 from optimum.utils import logging
@@ -76,14 +78,55 @@ class GaudiFluxPipelineOutput(BaseOutput):
     """


-def apply_rope(xq, xk, freqs_cis):
+def apply_orig_rope(xq, xk, freqs_cis):
     xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


+def apply_fused_rope(xq, xk, freqs_cis):
+    cos = freqs_cis[..., 0, 0].type_as(xq)
+    sin = freqs_cis[..., 1, 0].type_as(xq)
+    cos_ = torch.cat((cos, cos), dim=-1)
+    sin_ = torch.cat((sin, sin), dim=-1)
+
+    xq0 = xq[..., 0::2]
+    xq1 = xq[..., 1::2]
+    xq_ = torch.cat((xq0, xq1), dim=-1)
+
+    xk0 = xk[..., 0::2]
+    xk1 = xk[..., 1::2]
+    xk_ = torch.cat((xk0, xk1), dim=-1)
+
+    xq_out_ = FusedRoPE(xq_, cos_, sin_)
+    xk_out_ = FusedRoPE(xk_, cos_, sin_)
+
+    sh = xq_out_.shape
+    xq_out = xq_out_.view(*sh[:-1], 2, sh[-1] // 2)
+    dims = list(range(xq_out.ndimension()))
+    dims[-1], dims[-2] = dims[-2], dims[-1]
+    xq_out = xq_out.permute(*dims).contiguous().view(sh)
+
+    sh = xk_out_.shape
+    xk_out = xk_out_.view(*sh[:-1], 2, sh[-1] // 2)
+    dims = list(range(xk_out.ndimension()))
+    dims[-1], dims[-2] = dims[-2], dims[-1]
+    xk_out = xk_out.permute(*dims).contiguous().view(sh)
+
+    return xq_out, xk_out
+
+
+def apply_rope(xq, xk, freqs_cis):
+    rope_opt = os.getenv("GAUDI_FLUX_FUSED_ROPE")
+
+    if rope_opt in ["0", "False"]:
+        return apply_orig_rope(xq, xk, freqs_cis)
+    else:
+        return apply_fused_rope(xq, xk, freqs_cis)
+
+
 class GaudiFluxSingleAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
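The layout shuffling around the kernel call is the substance of this hunk: Flux's `freqs_cis` carries one 2x2 rotation matrix per (even, odd) channel pair, and `apply_orig_rope` rotates interleaved pairs, whereas the fused kernel consumes the "rotate-half" layout in which the first and second halves of the last dimension are rotated against each other. `apply_fused_rope` therefore deinterleaves q/k before the kernel call and re-interleaves the result afterwards. The sketch below checks on CPU that the two layouts agree; `rotate_half_rope` is a hypothetical stand-in for `FusedRoPE`, assuming the HPU kernel computes standard rotate-half RoPE (the commit does not document the kernel's semantics), and all shapes are toy values.

import torch


def apply_orig_rope(xq, xk, freqs_cis):
    # Copied from this commit: rotates interleaved (even, odd) channel pairs.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


def rotate_half_rope(x, cos, sin):
    # Hypothetical stand-in for FusedRoPE: x * cos + rotate_half(x) * sin.
    x1, x2 = x.chunk(2, dim=-1)
    return x * cos + torch.cat((-x2, x1), dim=-1) * sin


B, H, S, D = 1, 2, 4, 8  # toy shapes; D (head_dim) must be even
xq, xk = torch.randn(B, H, S, D), torch.randn(B, H, S, D)

# Build freqs_cis the way Flux's rope embedding does: one 2x2 rotation
# matrix [[cos, -sin], [sin, cos]] per position and channel pair.
theta = torch.rand(S, D // 2)
cos, sin = torch.cos(theta), torch.sin(theta)
freqs_cis = torch.stack(
    (torch.stack((cos, -sin), dim=-1), torch.stack((sin, cos), dim=-1)), dim=-2
)  # (S, D/2, 2, 2)

ref_q, _ = apply_orig_rope(xq, xk, freqs_cis)

# Emulate apply_fused_rope: deinterleave, rotate-half RoPE, re-interleave.
cos_ = torch.cat((freqs_cis[..., 0, 0], freqs_cis[..., 0, 0]), dim=-1)
sin_ = torch.cat((freqs_cis[..., 1, 0], freqs_cis[..., 1, 0]), dim=-1)
xq_ = torch.cat((xq[..., 0::2], xq[..., 1::2]), dim=-1)
out = rotate_half_rope(xq_, cos_, sin_)
sh = out.shape
out = out.view(*sh[:-1], 2, sh[-1] // 2).transpose(-1, -2).contiguous().view(sh)

print(torch.allclose(out, ref_q, atol=1e-6))  # expected: True

Note also the dispatch policy at the bottom of the hunk: the fused path is the default, and only an explicit `GAUDI_FLUX_FUSED_ROPE` value of "0" or "False" falls back to `apply_orig_rope`.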
@@ -512,17 +555,23 @@ def __call__(
         import habana_frameworks.torch as ht
         import habana_frameworks.torch.core as htcore

-        quant_mode = kwargs["quant_mode"]
+        quant_mode = kwargs.get("quant_mode", None)

         if quant_mode == "quantize-mixed":
             import copy

             transformer_bf16 = copy.deepcopy(self.transformer).to(self._execution_device)

-        if quant_mode == "measure" or quant_mode.startswith("quantize"):
+        if quant_mode in ("measure", "quantize", "quantize-mixed"):
             import os

             quant_config_path = os.getenv("QUANT_CONFIG")
+            if not quant_config_path:
+                raise ImportError(
+                    "Error: QUANT_CONFIG path is not defined. Please define path to quantization configuration JSON file."
+                )
+            elif not os.path.isfile(quant_config_path):
+                raise ImportError(f"Error: QUANT_CONFIG path '{quant_config_path}' is not valid")

             htcore.hpu_set_env()
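The remaining changes harden quantization setup. `quant_mode` becomes an optional call kwarg: the old `kwargs["quant_mode"]` raised KeyError when the caller passed nothing, and `quant_mode.startswith("quantize")` broke with AttributeError whenever the value was None. A missing or nonexistent `QUANT_CONFIG` file now fails fast with an explicit ImportError instead of erroring deep inside the quantization runtime. A usage sketch follows; the import path, checkpoint name, and config path are illustrative assumptions, while the two environment variables and the `quant_mode` kwarg come from this commit.

import os

# Assumed setup, not part of the commit: QUANT_CONFIG must name an existing
# JSON file, or the pipeline now raises ImportError before htcore.hpu_set_env().
os.environ["QUANT_CONFIG"] = "/path/to/quant_config.json"  # illustrative path
os.environ["GAUDI_FLUX_FUSED_ROPE"] = "0"  # opt out of the fused RoPE kernel

# Hypothetical driver; the import path and checkpoint are assumptions.
from optimum.habana.diffusers import GaudiFluxPipeline

pipe = GaudiFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")
# quant_mode is now optional; omit it to skip measurement/quantization entirely.
image = pipe("an astronaut riding a horse", quant_mode="measure").images[0]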
