|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import math |
| 16 | +import os |
16 | 17 | import time |
17 | 18 | from dataclasses import dataclass |
18 | 19 | from typing import Any, Callable, Dict, List, Optional, Union |
|
26 | 27 | from diffusers.models.transformers import FluxTransformer2DModel |
27 | 28 | from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, calculate_shift, retrieve_timesteps |
28 | 29 | from diffusers.utils import BaseOutput, replace_example_docstring |
| 30 | +from habana_frameworks.torch.hpex.kernels import apply_rotary_pos_emb as FusedRoPE |
29 | 31 | from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast |
30 | 32 |
|
31 | 33 | from optimum.utils import logging |
@@ -76,14 +78,55 @@ class GaudiFluxPipelineOutput(BaseOutput): |
76 | 78 | """ |
77 | 79 |
|
78 | 80 |
|
def apply_orig_rope(xq, xk, freqs_cis):
    """Reference (non-fused) rotary position embedding.

    Each consecutive (even, odd) pair along the last dimension of ``xq`` and
    ``xk`` is rotated by the 2x2 coefficients stored in ``freqs_cis``; the
    results are returned in the callers' original shape and dtype.
    """

    def _rotate(t):
        # View the last dim as (..., pairs, 1, 2) so it broadcasts against
        # the (..., pairs, 2, 2) rotation coefficients in freqs_cis.
        pairs = t.float().reshape(*t.shape[:-1], -1, 1, 2)
        rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return rotated.reshape(*t.shape).type_as(t)

    return _rotate(xq), _rotate(xk)
85 | 87 |
|
86 | 88 |
|
def apply_fused_rope(xq, xk, freqs_cis):
    """Rotary position embedding via the fused Habana HPU kernel.

    The fused kernel expects a half-split layout (all "even" components first,
    then all "odd" components), while the Flux tensors are pair-interleaved,
    so the inputs are de-interleaved before the kernel call and re-interleaved
    afterwards.
    """
    # freqs_cis[..., k, 0]: k=0 selects the cos terms, k=1 the sin terms, one
    # per rotation pair; duplicate them to cover the full head dimension.
    cos_half = freqs_cis[..., 0, 0].type_as(xq)
    sin_half = freqs_cis[..., 1, 0].type_as(xq)
    cos_full = torch.cat((cos_half, cos_half), dim=-1)
    sin_full = torch.cat((sin_half, sin_half), dim=-1)

    def _deinterleave(t):
        # (..., [a0, b0, a1, b1, ...]) -> (..., [a0, a1, ..., b0, b1, ...])
        return torch.cat((t[..., 0::2], t[..., 1::2]), dim=-1)

    def _reinterleave(t):
        # Inverse of _deinterleave: zip the two halves back together by
        # swapping the last two axes of a (..., 2, dim // 2) view.
        shape = t.shape
        halves = t.view(*shape[:-1], 2, shape[-1] // 2)
        return halves.transpose(-2, -1).contiguous().view(shape)

    rotated_q = FusedRoPE(_deinterleave(xq), cos_full, sin_full)
    rotated_k = FusedRoPE(_deinterleave(xk), cos_full, sin_full)
    return _reinterleave(rotated_q), _reinterleave(rotated_k)
| 119 | + |
| 120 | + |
def apply_rope(xq, xk, freqs_cis):
    """Apply rotary position embedding, dispatching on ``GAUDI_FLUX_FUSED_ROPE``.

    The fused HPU kernel path is the default (env var unset or any other
    value); setting ``GAUDI_FLUX_FUSED_ROPE`` to ``0`` or ``false`` (any
    case, surrounding whitespace ignored) falls back to the reference
    implementation.
    """
    rope_opt = os.getenv("GAUDI_FLUX_FUSED_ROPE", "")
    # Normalize the flag: the previous exact-match check ("0"/"False") silently
    # ignored common spellings such as "false" or "FALSE".
    if rope_opt.strip().lower() in ("0", "false"):
        return apply_orig_rope(xq, xk, freqs_cis)
    return apply_fused_rope(xq, xk, freqs_cis)
| 128 | + |
| 129 | + |
87 | 130 | class GaudiFluxSingleAttnProcessor2_0: |
88 | 131 | r""" |
89 | 132 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). |
@@ -512,17 +555,23 @@ def __call__( |
512 | 555 | import habana_frameworks.torch as ht |
513 | 556 | import habana_frameworks.torch.core as htcore |
514 | 557 |
|
515 | | - quant_mode = kwargs["quant_mode"] |
| 558 | + quant_mode = kwargs.get("quant_mode", None) |
516 | 559 |
|
517 | 560 | if quant_mode == "quantize-mixed": |
518 | 561 | import copy |
519 | 562 |
|
520 | 563 | transformer_bf16 = copy.deepcopy(self.transformer).to(self._execution_device) |
521 | 564 |
|
522 | | - if quant_mode == "measure" or quant_mode.startswith("quantize"): |
| 565 | + if quant_mode in ("measure", "quantize", "quantize-mixed"): |
523 | 566 | import os |
524 | 567 |
|
525 | 568 | quant_config_path = os.getenv("QUANT_CONFIG") |
| 569 | + if not quant_config_path: |
| 570 | + raise ImportError( |
| 571 | + "Error: QUANT_CONFIG path is not defined. Please define path to quantization configuration JSON file." |
| 572 | + ) |
| 573 | + elif not os.path.isfile(quant_config_path): |
| 574 | + raise ImportError(f"Error: QUANT_CONFIG path '{quant_config_path}' is not valid") |
526 | 575 |
|
527 | 576 | htcore.hpu_set_env() |
528 | 577 |
|
|
0 commit comments