diff --git a/README.md b/README.md index 3a17381..aa6b933 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,78 @@ # sd-webui-incantations -This extension implements multiple novel algorithms that enhance image quality, prompt following, and more. - -## COMPATIBILITY NOTICES: -#### Currently incompatible with stable-diffusion-webui-forge -Use this extension with Forge: https://github.com/pamparamm/sd-perturbed-attention +# Table of Contents +- [What is this?](#what-is-this) +- [Installation](#installation) +- [Compatibility Notice](#compatibility-notice) +- [News](#news) +- [Extension Features](#extension-features) + - [Semantic CFG](#semantic-cfg-s-cfg) + - [Perturbed Attention Guidance](#perturbed-attention-guidance) + - [CFG Scheduler](#cfg-interval--cfg-scheduler) + - [Multi-Concept T2I-Zero](#multi-concept-t2i-zero--attention-regulation) + - [Seek for Incantations](#seek-for-incantations) +- [Tutorial](#tutorial) +- [Other cool extensions](#also-check-out) +- [Credits](#credits) + +## What is this? +### This extension for [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) implements algorithms from state-of-the-art research to achieve **higher-quality** images with *more accurate* prompt adherence. + +All methods are **training-free** and rely only on modifying the text embeddings or attention maps. + + +## Installation +To install the `sd-webui-incantations` extension, follow these steps: + +0. **Ensure you have the latest Automatic1111 stable-diffusion-webui version ≥ 1.9.3 installed** + +1. **Open the "Extensions" tab and navigate to the "Install from URL" section**: + +2. **Paste the repository URL into the "URL for extension's git repository" field**: + ``` + https://github.com/v0xie/sd-webui-incantations.git + ``` + +3. **Press the Install button**: Wait a few seconds for the extension to finish installing. + +4. 
**Restart the Web UI**: + Completely restart your Stable Diffusion Web UI to load the new extension. + +## Compatibility Notice +* Incompatible with **stable-diffusion-webui-forge**: Use this extension with Forge: https://github.com/pamparamm/sd-perturbed-attention * Reported incompatible with Adetailer: https://github.com/v0xie/sd-webui-incantations/issues/21 +* Incompatible with some older webui versions: https://github.com/v0xie/sd-webui-incantations/issues/14 +* May conflict with other extensions which modify the CFGDenoiser + +## News +- 15-05-2024 🔥 - S-CFG, optimizations for PAG and T2I-Zero, and more! https://github.com/v0xie/sd-webui-incantations/pull/37 +- 29-04-2024 🔥 - The implementation of T2I-Zero is fixed and works much more stably now. + +# Extension Features + +--- +## Semantic CFG (S-CFG) +https://arxiv.org/abs/2404.05384 +Dynamically rescale CFG guidance per semantic region to a uniform level to improve image / text alignment. +**Very computationally expensive**: A batch size of 4 with 1024x1024 will max out a 24GB 4090. + +#### Controls +* **SCFG Scale**: Multiplies the correction by a constant factor. Default: 1.0. +* **SCFG R**: A hyperparameter controlling the factor of cross-attention map refinement. Higher values use more memory and computation time. Default: 4. +* **Rate Min**: The minimum rate that the CFG can be scaled by. Default: 0.8. +* **Rate Max**: The maximum rate that the CFG can be scaled by. Default: 3.0. +* **Clamp Rate**: Overrides Rate Max. Clamps the Max Rate to Clamp Rate / CFG. Default: 0.0. +* **Start Step**: Start S-CFG on this step. +* **End Step**: End S-CFG after this step. 
+ +#### Results +Prompt: "A cute puppy on the moon", Min Rate: 0.5, Max Rate: 10.0 +- SD 1.5 +![image](./images/xyz_grid-0006-1-SCFG.jpg) -* May conflict with extensions that modify the CFGDenoiser +#### Also check out the paper authors' official project repository: +- https://github.com/SmilesDZgk/S-CFG +#### [Return to top](#sd-webui-incantations) --- ## Perturbed Attention Guidance @@ -30,7 +95,10 @@ Prompt: "a puppy and a kitten on the moon" #### Also check out the paper authors' official project page: - https://ku-cvlab.github.io/Perturbed-Attention-Guidance/ +#### [Return to top](#sd-webui-incantations) + --- + ## CFG Interval / CFG Scheduler https://arxiv.org/abs/2404.07724 and https://arxiv.org/abs/2404.13040 @@ -62,6 +130,8 @@ Prompt: "A pointillist painting of a raccoon looking at the sea." Prompt: "An epic lithograph of a handsome salaryman carefully pouring coffee from a cup into an overflowing carafe, 4K, directed by Wong Kar Wai" - SD XL ![image](./images/xyz_grid-3380-1-An%20epic%20lithograph%20of%20a%20handsome%20salaryman%20carefully%20pouring%20coffee%20from%20a%20cup%20into%20an%20overflowing%20carafe,%204K,%20directed%20by%20Wong.jpg) + +#### [Return to top](#sd-webui-incantations) --- ## Multi-Concept T2I-Zero / Attention Regulation @@ -98,6 +168,7 @@ SD XL - https://multi-concept-t2i-zero.github.io/ - https://github.com/YaNgZhAnG-V5/attention_regulation +#### [Return to top](#sd-webui-incantations) --- ### Seek for Incantations An incomplete implementation of a "prompt-upsampling" method from https://arxiv.org/abs/2401.06345 @@ -121,6 +192,7 @@ SD XL * Modified Prompt: cinematic 4K photo of a dog riding a bus and eating cake and wearing headphones BREAK - - - - - dog - - bus - - - - - - ![image](./images/xyz_grid-2652-1419902843-cinematic%204K%20photo%20of%20a%20dog%20riding%20a%20bus%20and%20eating%20cake%20and%20wearing%20headphones.png) +#### [Return to top](#sd-webui-incantations) --- ### Issues / Pull Requests are welcome! 
@@ -132,6 +204,8 @@ SD XL [![image](https://cdn-uploads.huggingface.co/production/uploads/6345bd89fe134dfd7a0dba40/TzuZWTiHAc3wTxh3PwGL5.png)](https://youtu.be/lMQ7DIPmrfI) +#### [Return to top](#sd-webui-incantations) + ## Also check out: * **Characteristic Guidance**: Awesome enhancements for sampling at high CFG levels [https://github.com/scraed/CharacteristicGuidanceWebUI](https://github.com/scraed/CharacteristicGuidanceWebUI) @@ -144,6 +218,7 @@ SD XL * **Agent Attention**: Faster image generation and improved image quality with Agent Attention [https://github.com/v0xie/sd-webui-agentattention](https://github.com/v0xie/sd-webui-agentattention) +#### [Return to top](#sd-webui-incantations) --- ### Credits @@ -203,9 +278,19 @@ SD XL primaryClass={cs.CV} } + @misc{shen2024rethinking, + title={Rethinking the Spatial Inconsistency in Classifier-Free Diffusion Guidance}, + author={Dazhong Shen and Guanglu Song and Zeyue Xue and Fu-Yun Wang and Yu Liu}, + year={2024}, + eprint={2404.05384}, + archivePrefix={arXiv}, + primaryClass={cs.CV} + } + -- Hard Prompts Made Easy (https://github.com/YuxinWenRick/hard-prompts-made-easy) +- [Hard Prompts Made Easy](https://github.com/YuxinWenRick/hard-prompts-made-easy) +- [@udon-universe's extension templates](https://github.com/udon-universe/stable-diffusion-webui-extension-templates) -- @udon-universe's extension templates (https://github.com/udon-universe/stable-diffusion-webui-extension-templates) +#### [Return to top](#sd-webui-incantations) --- diff --git a/images/xyz_grid-0006-1-SCFG.jpg b/images/xyz_grid-0006-1-SCFG.jpg new file mode 100644 index 0000000..c59b58a Binary files /dev/null and b/images/xyz_grid-0006-1-SCFG.jpg differ diff --git a/scripts/cfg_combiner.py b/scripts/cfg_combiner.py new file mode 100644 index 0000000..12dbf6a --- /dev/null +++ b/scripts/cfg_combiner.py @@ -0,0 +1,237 @@ +import gradio as gr +import logging +import torch +from modules import shared, scripts, devices, patches, script_callbacks 
+from modules.script_callbacks import CFGDenoiserParams +from modules.processing import StableDiffusionProcessing +from scripts.incantation_base import UIWrapper +from scripts.scfg import scfg_combine_denoised + +logger = logging.getLogger(__name__) + +class CFGCombinerScript(UIWrapper): + """ Some scripts modify the CFGs in ways that are not compatible with each other. + This script will patch the CFG denoiser function to apply CFG in an ordered way. + This script adds a dict named 'incant_cfg_params' to the processing object. + This dict contains the following: + 'denoiser': the denoiser object + 'pag_params': list of PAG parameters + 'scfg_params': the S-CFG parameters + ... + """ + def __init__(self): + pass + + # Extension title in menu UI + def title(self): + return "CFG Combiner" + + # Decide to show menu in txt2img or img2img + def show(self, is_img2img): + return scripts.AlwaysVisible + + # Setup menu ui detail + def setup_ui(self, is_img2img): + self.infotext_fields = [] + self.paste_field_names = [] + return [] + + def before_process(self, p: StableDiffusionProcessing, *args, **kwargs): + logger.debug("CFGCombinerScript before_process") + cfg_dict = { + "denoiser": None, + "pag_params": None, + "scfg_params": None + } + setattr(p, 'incant_cfg_params', cfg_dict) + + def process(self, p: StableDiffusionProcessing, *args, **kwargs): + pass + + def before_process_batch(self, p: StableDiffusionProcessing, *args, **kwargs): + pass + + def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs): + """ Process the batch and hook the CFG denoiser if PAG or S-CFG is active """ + logger.debug("CFGCombinerScript process_batch") + pag_active = p.extra_generation_params.get('PAG Active', False) + cfg_active = p.extra_generation_params.get('CFG Interval Enable', False) + scfg_active = p.extra_generation_params.get('SCFG Active', False) + + if not any([ + pag_active, + cfg_active, + scfg_active + ]): + return + + #logger.debug("CFGCombinerScript 
process_batch: pag_active or scfg_active") + + cfg_denoise_lambda = lambda params: self.on_cfg_denoiser_callback(params, p.incant_cfg_params) + unhook_lambda = lambda: self.unhook_callbacks() + + script_callbacks.on_cfg_denoiser(cfg_denoise_lambda) + script_callbacks.on_script_unloaded(unhook_lambda) + logger.debug('Hooked callbacks') + + def postprocess_batch(self, p: StableDiffusionProcessing, *args, **kwargs): + logger.debug("CFGCombinerScript postprocess_batch") + script_callbacks.remove_current_script_callbacks() + + def unhook_callbacks(self, cfg_dict = None): + if not cfg_dict: + return + self.unpatch_cfg_denoiser(cfg_dict) + + def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, cfg_dict: dict): + """ Callback for when the CFG denoiser is called + Patches the combine_denoised function with a custom one. + """ + if cfg_dict['denoiser'] is None: + cfg_dict['denoiser'] = params.denoiser + else: + self.unpatch_cfg_denoiser(cfg_dict) + self.patch_cfg_denoiser(params.denoiser, cfg_dict) + + def patch_cfg_denoiser(self, denoiser, cfg_dict: dict): + """ Patch the CFG Denoiser combine_denoised function """ + if not cfg_dict: + logger.error("Unable to patch CFG Denoiser, no dict passed as cfg_dict") + return + if not denoiser: + logger.error("Unable to patch CFG Denoiser, denoiser is None") + return + + if getattr(denoiser, 'combine_denoised_patched', False) is False: + try: + setattr(denoiser, 'combine_denoised_original', denoiser.combine_denoised) + # create patch that references the original function + pass_conds_func = lambda *args, **kwargs: combine_denoised_pass_conds_list( + *args, + **kwargs, + original_func = denoiser.combine_denoised_original, + pag_params = cfg_dict['pag_params'], + scfg_params = cfg_dict['scfg_params'] + ) + patched_combine_denoised = patches.patch(__name__, denoiser, "combine_denoised", pass_conds_func) + setattr(denoiser, 'combine_denoised_patched', True) + setattr(denoiser, 'combine_denoised_original', 
patches.original(__name__, denoiser, "combine_denoised")) + except KeyError: + logger.exception("KeyError patching combine_denoised") + pass + except RuntimeError: + logger.exception("RuntimeError patching combine_denoised") + pass + + def unpatch_cfg_denoiser(self, cfg_dict = None): + """ Unpatch the CFG Denoiser combine_denoised function """ + if cfg_dict is None: + return + denoiser = cfg_dict.get('denoiser', None) + if denoiser is None: + return + + setattr(denoiser, 'combine_denoised_patched', False) + try: + patches.undo(__name__, denoiser, "combine_denoised") + except KeyError: + logger.exception("KeyError unhooking combine_denoised") + pass + except RuntimeError: + logger.exception("RuntimeError unhooking combine_denoised") + pass + + cfg_dict['denoiser'] = None + + +def combine_denoised_pass_conds_list(*args, **kwargs): + """ Hijacked function for combine_denoised in CFGDenoiser + Currently relies on the original function not having any kwargs + If any of the params are not None, it will apply the corresponding guidance + The order of guidance is: + 1. CFG and S-CFG are combined multiplicatively + 2. PAG guidance is added to the result + 3. ... + ... + """ + original_func = kwargs.get('original_func', None) + pag_params = kwargs.get('pag_params', None) + scfg_params = kwargs.get('scfg_params', None) + + if pag_params is None and scfg_params is None: + logger.warning("No reason to hijack combine_denoised") + return original_func(*args) + + def new_combine_denoised(x_out, conds_list, uncond, cond_scale): + denoised_uncond = x_out[-uncond.shape[0]:] + denoised = torch.clone(denoised_uncond) + + ### Variables + # 0. Standard CFG Value + cfg_scale = cond_scale + + # 1. CFG Interval + # Overrides cfg_scale if pag_params is not None + if pag_params is not None: + if pag_params.cfg_interval_enable: + cfg_scale = pag_params.cfg_interval_scheduled_value + + # 2. 
PAG + pag_x_out = None + pag_scale = None + if pag_params is not None: + pag_active = pag_params.pag_active + pag_x_out = pag_params.pag_x_out + pag_scale = pag_params.pag_scale + + ### Combine Denoised + for i, conds in enumerate(conds_list): + for cond_index, weight in conds: + + model_delta = x_out[cond_index] - denoised_uncond[i] + + # S-CFG + rate = 1.0 + if scfg_params is not None: + rate = scfg_combine_denoised( + model_delta = model_delta, + cfg_scale = cfg_scale, + scfg_params = scfg_params, + ) + # If rate is not an int, convert to tensor + if rate is None: + logger.error("scfg_combine_denoised returned None, using default rate of 1.0") + rate = 1.0 + elif not isinstance(rate, int) and not isinstance(rate, float): + rate = rate.to(device=shared.device, dtype=model_delta.dtype) + else: + # rate is tensor, probably + pass + + # 1. Experimental formulation for S-CFG combined with CFG + denoised[i] += (model_delta) * rate * (weight * cfg_scale) + del rate + + # 2. PAG + # PAG is added like CFG + if pag_params is not None: + if not pag_active: + pass + # Not within step interval? + elif not pag_params.pag_start_step <= pag_params.step <= pag_params.pag_end_step: + pass + # Scale is zero? 
+ elif pag_scale <= 0: + pass + # do pag + else: + try: + denoised[i] += (x_out[cond_index] - pag_x_out[i]) * (weight * pag_scale) + except Exception as e: + logger.exception("Exception in combine_denoised_pass_conds_list - %s", e) + + #torch.cuda.empty_cache() + devices.torch_gc() + + return denoised + return new_combine_denoised(*args) \ No newline at end of file diff --git a/scripts/incant_utils/module_hooks.py b/scripts/incant_utils/module_hooks.py new file mode 100644 index 0000000..a8506d8 --- /dev/null +++ b/scripts/incant_utils/module_hooks.py @@ -0,0 +1,162 @@ +from typing import Optional, Callable, Dict +from collections import OrderedDict +from warnings import warn +import logging +import torch + + +from modules import shared + + +logger = logging.getLogger(__name__) + + +def modules_add_field(modules, field, value=None): + """ Add a field to a module if it isn't already added. + Args: + modules (list): Module or list of modules to add the field to + field (str): Field name to add + value (any): Value to assign to the field + Returns: + None + + """ + if not isinstance(modules, list): + modules = [modules] + for module in modules: + if not hasattr(module, field): + setattr(module, field, value) + else: + logger.warning(f"Field {field} already exists in module {module}") + + +def modules_remove_field(modules, field): + """ Remove a field from a module if it exists. + Args: + modules (list): Module or list of modules to remove the field from + field (str): Field name to remove + value (any): Value to assign to the field + Returns: + None + + """ + if not isinstance(modules, list): + modules = [modules] + for module in modules: + if hasattr(module, field): + delattr(module, field) + else: + # logger.warning(f"Field {field} does not exist in module {module}") + pass + + +def get_modules(network_layer_name_filter: Optional[str] = None, module_name_filter: Optional[str] = None): + """ Get all modules from the shared.sd_model that match the filters provided. 
If no filters are provided, all modules are returned. + + Args: + network_layer_name_filter (Optional[str], optional): Filters the modules by network layer name. Defaults to None. Example: "attn1" will return all modules that have "attn1" in their network layer name. + module_name_filter (Optional[str], optional): Filters the modules by module class name. Defaults to None. Example: "CrossAttention" will return all modules that have "CrossAttention" in their class name. + + Returns: + list: List of modules that match the filters provided. + """ + try: + m = shared.sd_model + nlm = m.network_layer_mapping + sd_model_modules = nlm.values() + + # Apply filters if they are provided + if network_layer_name_filter is not None: + sd_model_modules = list(filter(lambda m: network_layer_name_filter in m.network_layer_name, sd_model_modules)) + if module_name_filter is not None: + sd_model_modules = list(filter(lambda m: module_name_filter in m.__class__.__name__, sd_model_modules)) + return sd_model_modules + except AttributeError: + logger.exception("AttributeError in get_modules", stack_info=True) + return [] + except Exception: + logger.exception("Exception in get_modules", stack_info=True) + return [] + + +# workaround for torch remove hooks issue +# thank you to @ProGamerGov for this https://github.com/pytorch/pytorch/issues/70455 +def remove_module_forward_hook( + module: torch.nn.Module, hook_fn_name: Optional[str] = None +) -> None: + """ + This function removes all forward hooks in the specified module, without requiring + any hook handles. This lets us clean up & remove any hooks that weren't properly + deleted. + + Warning: Various PyTorch modules and systems make use of hooks, and thus extreme + caution should be exercised when removing all hooks. Users are recommended to give + their hook function a unique name that can be used to safely identify and remove + the target forward hooks. 
+ + Args: + + module (nn.Module): The module instance to remove forward hooks from. + hook_fn_name (str, optional): Optionally only remove specific forward hooks + based on their function's __name__ attribute. + Default: None + """ + + if hook_fn_name is None: + warn("Removing all active hooks can break some PyTorch modules & systems.") + + def _remove_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None: + if hasattr(module, "_forward_hooks"): + if m._forward_hooks != OrderedDict(): + if name is not None: + dict_items = list(m._forward_hooks.items()) + m._forward_hooks = OrderedDict( + [(i, fn) for i, fn in dict_items if fn.__name__ != name] + ) + else: + m._forward_hooks: Dict[int, Callable] = OrderedDict() + + def _remove_child_hooks( + target_module: torch.nn.Module, hook_name: Optional[str] = None + ) -> None: + for name, child in target_module._modules.items(): + if child is not None: + _remove_hooks(child, hook_name) + _remove_child_hooks(child, hook_name) + + # Remove hooks from target submodules + _remove_child_hooks(module, hook_fn_name) + + # Remove hooks from the target module + _remove_hooks(module, hook_fn_name) + + +def module_add_forward_hook(module, hook_fn, hook_type="forward", with_kwargs=False): + """ Adds a forward hook to a module. + + hook_fn should be a function that accepts the following arguments: + forward hook, no kwargs: hook(module, args, output) -> None or modified output + forward hook, with kwargs: hook(module, args, kwargs, output) -> None or modified output + + Args: + module (torch.nn.Module): Module to hook + hook_fn (Callable): Function to call when the hook is triggered + hook_type (str, optional): Type of hook to create. Defaults to "forward". Can be "forward" or "pre_forward". + with_kwargs (bool, optional): Whether the hook function should accept keyword arguments. Defaults to False. 
+ + Returns: + torch.utils.hooks.RemovableHandle: Handle for the hook + """ + if module is None: + raise ValueError("module must be provided") + if not callable(hook_fn): + raise ValueError("hook_fn must be a callable function") + + if hook_type == "forward": + handle = module.register_forward_hook(hook_fn, with_kwargs=with_kwargs) + elif hook_type == "pre_forward": + handle = module.register_forward_pre_hook(hook_fn, with_kwargs=with_kwargs) + else: + raise ValueError(f"Invalid hook type {hook_type}. Must be 'forward' or 'pre_forward'.") + + return handle \ No newline at end of file diff --git a/scripts/incant_utils/prompt_utils.py b/scripts/incant_utils/prompt_utils.py new file mode 100644 index 0000000..393deb7 --- /dev/null +++ b/scripts/incant_utils/prompt_utils.py @@ -0,0 +1,29 @@ +from functools import reduce +from modules import shared +from modules import extra_networks +from modules import prompt_parser +from modules import sd_hijack + +# taken from modules/ui.py +# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/ui.py +def get_token_count(text, steps, is_positive: bool = True): + """ Get token count and max length for a given prompt text. 
""" + try: + text, _ = extra_networks.parse_prompt(text) + + if is_positive: + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + else: + prompt_flat_list = [text] + + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) + + except Exception: + # a parsing error can happen here during typing, and we don't want to bother the user with + # messages related to it in console + prompt_schedules = [[[steps, text]]] + + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) + prompts = [prompt_text for step, prompt_text in flat_prompts] + token_count, max_length = max([sd_hijack.model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) + return token_count, max_length \ No newline at end of file diff --git a/scripts/incantation_base.py b/scripts/incantation_base.py index 62a0769..48c4fa1 100644 --- a/scripts/incantation_base.py +++ b/scripts/incantation_base.py @@ -10,7 +10,10 @@ from scripts.ui_wrapper import UIWrapper from scripts.incant import IncantExtensionScript from scripts.t2i_zero import T2I0ExtensionScript +from scripts.scfg import SCFGExtensionScript from scripts.pag import PAGExtensionScript +from scripts.save_attn_maps import SaveAttentionMapsScript +from scripts.cfg_combiner import CFGCombinerScript logger = logging.getLogger(__name__) logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO)) @@ -29,12 +32,25 @@ def __init__(self, module: UIWrapper, module_idx = 0, num_args = -1, arg_idx = - self.num_args: int = num_args # the length of arg list self.arg_idx: int = arg_idx # where the list of args starts +# main scripts submodules: list[SubmoduleInfo] = [ + SubmoduleInfo(module=SCFGExtensionScript()), SubmoduleInfo(module=PAGExtensionScript()), SubmoduleInfo(module=T2I0ExtensionScript()), SubmoduleInfo(module=IncantExtensionScript()), ] - +# debug scripts +if environ.get("INCANT_DEBUG", default=False) != False: + 
submodules.append(SubmoduleInfo(module=SaveAttentionMapsScript())) +else: + logger.info("Incantation: Debug scripts are disabled. Set INCANT_DEBUG environment variable to enable them.") +# run these after submodules +end_submodules: list[SubmoduleInfo] = [ + SubmoduleInfo(module=CFGCombinerScript()) +] +submodules = submodules + end_submodules + + class IncantBaseExtensionScript(scripts.Script): def __init__(self): pass @@ -115,7 +131,11 @@ def callback_before_ui(): try: for module_info in submodules: module = module_info.module - extra_axis_options = module.get_xyz_axis_options() + try: + extra_axis_options = module.get_xyz_axis_options() + except NotImplementedError: + logger.warning(f"Module {module.title()} does not implement get_xyz_axis_options") + extra_axis_options = {} make_axis_options(extra_axis_options) except: logger.exception("Incantation: Error while making axis options") diff --git a/scripts/pag.py b/scripts/pag.py index 1e96a26..20e067a 100644 --- a/scripts/pag.py +++ b/scripts/pag.py @@ -98,6 +98,7 @@ class PAGStateParams: def __init__(self): + self.pag_active: bool = False # PAG guidance scale self.pag_scale: int = -1 # PAG guidance scale self.pag_start_step: int = 0 self.pag_end_step: int = 150 @@ -105,9 +106,11 @@ def __init__(self): self.cfg_interval_schedule: str = 'Constant' self.cfg_interval_low: float = 0 self.cfg_interval_high: float = 50.0 + self.cfg_interval_scheduled_value: float = 7.0 self.step : int = 0 self.max_sampling_step : int = 1 self.guidance_scale: int = -1 # CFG + self.current_noise_level: float = 100.0 self.x_in = None self.text_cond = None self.image_cond = None @@ -147,17 +150,18 @@ def setup_ui(self, is_img2img) -> list: with gr.Row(): start_step = gr.Slider(value = 0, minimum = 0, maximum = 150, step = 1, label="Start Step", elem_id = 'pag_start_step', info="") end_step = gr.Slider(value = 150, minimum = 0, maximum = 150, step = 1, label="End Step", elem_id = 'pag_end_step', info="") + + with gr.Accordion('CFG 
Scheduler', open=False): + cfg_interval_enable = gr.Checkbox(value=False, default=False, label="Enable CFG Scheduler", elem_id='cfg_interval_enable', info="If enabled, applies CFG only within noise interval with the selected schedule type. PAG must be enabled (scale can be 0). SDXL recommend CFG=15; CFG interval (0.28, 5.42]") with gr.Row(): - cfg_interval_enable = gr.Checkbox(value=False, default=False, label="Enable CFG Scheduler", elem_id='cfg_interval_enable', info="If enabled, applies CFG only within noise interval with the selected schedule type. PAG must be enabled (scale can be 0). SDXL recommend CFG=15; CFG interval (0.28, 5.42]") cfg_schedule = gr.Dropdown( value='Constant', choices= SCHEDULES, label="CFG Schedule Type", elem_id='cfg_interval_schedule', ) - with gr.Row(): - cfg_interval_low = gr.Slider(value = 0, minimum = 0, maximum = 100, step = 0.01, label="CFG Noise Interval Low", elem_id = 'cfg_interval_low', info="") - cfg_interval_high = gr.Slider(value = 100, minimum = 0, maximum = 100, step = 0.01, label="CFG Noise Interval High", elem_id = 'cfg_interval_high', info="") + cfg_interval_low = gr.Slider(value = 0, minimum = 0, maximum = 100, step = 0.1, label="CFG Noise Interval Low", elem_id = 'cfg_interval_low', info="") + cfg_interval_high = gr.Slider(value = 100, minimum = 0, maximum = 100, step = 0.1, label="CFG Noise Interval High", elem_id = 'cfg_interval_high', info="") active.do_not_save_to_config = True pag_scale.do_not_save_to_config = True @@ -198,32 +202,43 @@ def pag_process_batch(self, p: StableDiffusionProcessing, active, pag_scale, sta self.remove_all_hooks() active = getattr(p, "pag_active", active) - if active is False: + cfg_interval_enable = getattr(p, "cfg_interval_enable", cfg_interval_enable) + if active is False and cfg_interval_enable is False: return pag_scale = getattr(p, "pag_scale", pag_scale) start_step = getattr(p, "pag_start_step", start_step) end_step = getattr(p, "pag_end_step", end_step) - cfg_interval_enable = 
getattr(p, "cfg_interval_enable", cfg_interval_enable) cfg_schedule = getattr(p, "cfg_interval_schedule", cfg_schedule) cfg_interval_low = getattr(p, "cfg_interval_low", cfg_interval_low) cfg_interval_high = getattr(p, "cfg_interval_high", cfg_interval_high) - p.extra_generation_params.update({ - "PAG Active": active, - "PAG Scale": pag_scale, - "PAG Start Step": start_step, - "PAG End Step": end_step, - "CFG Interval Enable": cfg_interval_enable, - "CFG Interval Schedule": cfg_schedule, - "CFG Interval Low": cfg_interval_low, - "CFG Interval High": cfg_interval_high - }) + if active: + p.extra_generation_params.update({ + "PAG Active": active, + "PAG Scale": pag_scale, + "PAG Start Step": start_step, + "PAG End Step": end_step, + }) + if cfg_interval_enable: + p.extra_generation_params.update({ + "CFG Interval Enable": cfg_interval_enable, + "CFG Interval Schedule": cfg_schedule, + "CFG Interval Low": cfg_interval_low, + "CFG Interval High": cfg_interval_high + }) self.create_hook(p, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high) def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, *args, **kwargs): # Create a list of parameters for each concept pag_params = PAGStateParams() + + # Add to p's incant_cfg_params + if not hasattr(p, 'incant_cfg_params'): + logger.error("No incant_cfg_params found in p") + p.incant_cfg_params['pag_params'] = pag_params + + pag_params.pag_active = active pag_params.pag_scale = pag_scale pag_params.pag_start_step = start_step pag_params.pag_end_step = end_step @@ -233,6 +248,7 @@ def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_ste pag_params.guidance_scale = p.cfg_scale pag_params.batch_size = p.batch_size pag_params.denoiser = None + pag_params.cfg_interval_scheduled_value = p.cfg_scale if pag_params.cfg_interval_enable: # Refer to 
3.1 Practice in the paper @@ -256,7 +272,8 @@ def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_ste #after_cfg_lambda = lambda x: self.cfg_after_cfg_callback(x, params) unhook_lambda = lambda _: self.unhook_callbacks(pag_params) - self.ready_hijack_forward(pag_params.crossattn_modules, pag_scale) + if pag_params.pag_active: + self.ready_hijack_forward(pag_params.crossattn_modules, pag_scale) logger.debug('Hooked callbacks') script_callbacks.on_cfg_denoiser(cfg_denoise_lambda) @@ -264,6 +281,8 @@ def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_ste #script_callbacks.on_cfg_after_cfg(after_cfg_lambda) script_callbacks.on_script_unloaded(unhook_lambda) + + def postprocess_batch(self, p, *args, **kwargs): self.pag_postprocess_batch(p, *args, **kwargs) @@ -287,6 +306,7 @@ def remove_all_hooks(self): def unhook_callbacks(self, pag_params: PAGStateParams): global handles + return if pag_params is None: logger.error("PAG params is None") @@ -388,29 +408,28 @@ def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, pag_params: PAGSta pag_params.step = params.sampling_step - # patch combine_denoised - if pag_params.denoiser is None: - pag_params.denoiser = params.denoiser - if getattr(params.denoiser, 'combine_denoised_patched', False) is False: - try: - setattr(params.denoiser, 'combine_denoised_original', params.denoiser.combine_denoised) - # create patch that references the original function - pass_conds_func = lambda *args, **kwargs: combine_denoised_pass_conds_list( - *args, - **kwargs, - original_func = params.denoiser.combine_denoised_original, - pag_params = pag_params) - pag_params.patched_combine_denoised = patches.patch(__name__, params.denoiser, "combine_denoised", pass_conds_func) - setattr(params.denoiser, 'combine_denoised_patched', True) - setattr(params.denoiser, 'combine_denoised_original', patches.original(__name__, params.denoiser, "combine_denoised")) - except KeyError: - 
logger.exception("KeyError patching combine_denoised") - pass - except RuntimeError: - logger.exception("RuntimeError patching combine_denoised") - pass + # CFG Interval + # TODO: set rho based on sdxl or sd1.5 + pag_params.current_noise_level = calculate_noise_level( + i = pag_params.step, + N = pag_params.max_sampling_step, + ) - # Run only within interval + if pag_params.cfg_interval_enable: + if pag_params.cfg_interval_schedule != 'Constant': + # Calculate noise interval + start = pag_params.cfg_interval_low + end = pag_params.cfg_interval_high + begin_range = start if start <= end else end + end_range = end if start <= end else start + # Scheduled CFG Value + scheduled_cfg_scale = cfg_scheduler(pag_params.cfg_interval_schedule, pag_params.step, pag_params.max_sampling_step, pag_params.guidance_scale) + + pag_params.cfg_interval_scheduled_value = scheduled_cfg_scale if begin_range <= pag_params.current_noise_level <= end_range else 1.0 + + # Run PAG only if active and within interval + if not pag_params.pag_active or pag_params.pag_scale <= 0: + return if not pag_params.pag_start_step <= params.sampling_step <= pag_params.pag_end_step or pag_params.pag_scale <= 0: return @@ -439,6 +458,9 @@ def on_cfg_denoised_callback(self, params: CFGDenoisedParams, pag_params: PAGSta """ # Run only within interval + # Run PAG only if active and within interval + if not pag_params.pag_active or pag_params.pag_scale <= 0: + return if not pag_params.pag_start_step <= params.sampling_step <= pag_params.pag_end_step or pag_params.pag_scale <= 0: return @@ -506,6 +528,8 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale): # Calculate CFG Scale cfg_scale = cond_scale + new_params.cfg_interval_scheduled_value = cfg_scale + if new_params.cfg_interval_enable: if new_params.cfg_interval_schedule != 'Constant': # Calculate noise interval @@ -517,6 +541,15 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale): scheduled_cfg_scale = 
class SaveAttentionMapsScript(UIWrapper):
    """ Debug helper that records attention-module outputs while sampling and plots them.

    ``hook_modules`` installs forward hooks that accumulate module outputs in a
    ``savemaps_batch`` field on every matched module. ``postprocess_batch``
    renders one image per (saved step, batch item, prompt token) through
    ``plot_tools.plot_attention_map`` and removes the hooks and fields again.
    """

    def __init__(self):
        self.infotext_fields: list = []
        self.paste_field_names: list = []

    def title(self) -> str:
        return "Save Attention Maps"

    def setup_ui(self, is_img2img) -> list:
        """ Build the accordion UI and return the components passed to the processing callbacks.

        NOTE(review): the 'Export Folder' textbox is created but is not part of
        the returned list, so its value never reaches the callbacks; they use
        the fixed 'attention_maps' sub-folder instead.
        """
        with gr.Accordion('Save Attention Maps', open=False):
            # `value` / `minimum` / `maximum` are the documented Gradio parameter names.
            active = gr.Checkbox(label='Active', value=False)
            export_folder = gr.Textbox(label='Export Folder', value='attention_maps', info='Folder to save attention maps to as a subdirectory of the outputs.')
            module_name_filter = gr.Textbox(label='Module Names', value='input_blocks_5_1_transformer_blocks_0_attn2', info='Module name to save attention maps for. If the substring is found in the module name, the attention maps will be saved for that module.')
            class_name_filter = gr.Textbox(label='Class Name Filter', value='CrossAttention', info='Filters eligible modules by the class name.')
            save_every_n_step = gr.Slider(label='Save Every N Step', value=0, minimum=0, maximum=100, step=1, info='Save attention maps every N steps. 0 to save last step.')
            print_modules = gr.Button(value='Print Modules To Console')
            print_modules.click(self.print_modules, inputs=[module_name_filter, class_name_filter])
        for component in (active, export_folder, module_name_filter, class_name_filter, save_every_n_step):
            component.do_not_save_to_config = True
        self.infotext_fields = []
        self.paste_field_names = []
        return [active, module_name_filter, class_name_filter, save_every_n_step]

    def before_process_batch(self, p: StableDiffusionProcessing, active, module_name_filter, class_name_filter, save_every_n_step, *args, **kwargs):
        """ Reset any previous hooks, then (if active) install the capture hooks for this batch. """
        # Always unhook the modules first
        module_list = self.get_modules_by_filter(module_name_filter, class_name_filter)
        self.unhook_modules(module_list, copy.deepcopy(module_field_map))

        if not active:
            return

        # Make sure the output folder exists
        outpath_samples = p.outpath_samples
        if not outpath_samples:
            logger.warning("No output path found. Skipping saving attention maps.")
            return
        output_folder_path = os.path.join(outpath_samples, 'attention_maps')
        if not os.path.isdir(output_folder_path):
            logger.info(f"Creating directory: {output_folder_path}")
            os.makedirs(output_folder_path, exist_ok=True)

        # Only mark the batch as "capturing" once we know the maps can be written;
        # postprocess_batch uses this attribute to decide whether to plot.
        token_count, _ = prompt_utils.get_token_count(p.prompt, p.steps, True)
        setattr(p, 'savemaps_token_count', token_count)

        if save_every_n_step > 0:
            save_steps = list(range(save_every_n_step, max(p.steps + 1, 0), save_every_n_step))
        else:
            # 0 means "only the last step"
            save_steps = [p.steps]

        # Create fields in module
        value_map = copy.deepcopy(module_field_map)
        value_map['savemaps_save_steps'] = torch.tensor(save_steps).to(device=shared.device, dtype=torch.int32)
        value_map['savemaps_step'] = torch.tensor([0]).to(device=shared.device, dtype=torch.int32)
        self.hook_modules(module_list, value_map)
        self.create_save_hook(module_list)

    def process(self, p, *args, **kwargs):
        pass

    def before_process(self, p: StableDiffusionProcessing, active, module_name_filter, class_name_filter, save_every_n_step, *args, **kwargs):
        """ Remove hooks/fields left over from an earlier run. """
        module_list = self.get_modules_by_filter(module_name_filter, class_name_filter)
        self.unhook_modules(module_list, copy.deepcopy(module_field_map))

    def process_batch(self, p, *args, **kwargs):
        pass

    def postprocess_batch(self, p: StableDiffusionProcessing, active, module_name_filter, class_name_filter, save_every_n_step, *args, **kwargs):
        """ Plot every captured map to '<outpath_samples>/attention_maps', then unhook the modules. """
        module_list = self.get_modules_by_filter(module_name_filter, class_name_filter)

        if getattr(p, 'savemaps_token_count', None) is None:
            self.unhook_modules(module_list, copy.deepcopy(module_field_map))
            return

        save_image_path = os.path.join(p.outpath_samples, 'attention_maps')

        # Index 0 is skipped; indices 1..token_count are plotted
        # (presumably index 0 is the start-of-text token — TODO confirm).
        token_indices = list(range(1, p.savemaps_token_count + 1))
        gaussian_blur = GaussianBlur(kernel_size=3, sigma=1)

        for module in module_list:
            if getattr(module, 'savemaps_batch', None) is None:
                logger.error(f"No attention maps found for module: {module.network_layer_name}")
                continue

            attn_maps = module.savemaps_batch  # (attn_map num, 2 * batch_num, height * width, sequence_len)
            map_count, _, hw, _ = attn_maps.shape
            downscale_h = round((hw * (p.height / p.width)) ** 0.5)

            attn_maps = attn_maps[:, :, :, token_indices]

            # Unflatten to (h, w) BEFORE blurring: GaussianBlur filters the last two
            # dims, so blurring the flat (2 * batch, h * w) layout would smear values
            # across batch items and across the flattened index instead of spatially.
            attn_maps = rearrange(attn_maps, 'n (m b) (h w) t -> (n m b) t h w', m=2, h=downscale_h)
            attn_maps = gaussian_blur(attn_maps)
            # Average the two halves of the doubled batch dim -> (attn_map num, batch_num, token_idx, height, width)
            attn_maps = rearrange(attn_maps, '(n m b) t h w -> n m b t h w', n=map_count, m=2).mean(dim=1)

            _, batch_num, token_num, _, _ = attn_maps.shape
            for attn_map_idx in range(map_count):
                savestep_num = int(module.savemaps_save_steps[attn_map_idx])
                for batch_idx in range(batch_num):
                    for token_idx in range(token_num):
                        attn_map = attn_maps[attn_map_idx, batch_idx, token_idx]
                        out_file_name = f'{module.network_layer_name}_token{token_idx+1:04}_step{savestep_num:04}_attnmap_{attn_map_idx:04}_batch{batch_idx:04}.png'
                        save_path = os.path.join(save_image_path, out_file_name)
                        plot_tools.plot_attention_map(
                            attention_map=attn_map,
                            title=f"{module.network_layer_name}\nToken {token_idx+1}, Step {savestep_num}",
                            save_path=save_path,
                            plot_type="default"
                        )

        self.unhook_modules(module_list, copy.deepcopy(module_field_map))

    def unhook_callbacks(self) -> None:
        pass

    def get_xyz_axis_options(self) -> dict:
        return {}

    def get_infotext_fields(self) -> list:
        return self.infotext_fields

    def create_save_hook(self, module_list):
        pass

    def hook_modules(self, module_list: list, value_map: dict):
        """ Add the bookkeeping fields from `value_map` and the capture hooks to every module. """
        def savemaps_hook(module, input, kwargs, output):
            """ Hook to save attention maps every N steps, or the last step if N is 0.
            Appends the captured output, with a new leading dim, to the module field
            'savemaps_batch'.
            """
            module.savemaps_step += 1
            if module.savemaps_step not in module.savemaps_save_steps:
                return

            to_v_map = getattr(module, 'savemaps_to_v_map', None)
            if to_v_map is None:
                attn_map = output.detach().clone()
            else:
                # multiply into text embeddings
                attn_map = (to_v_map @ output.transpose(1, 2)).transpose(1, 2)
            attn_map = attn_map.unsqueeze(0)

            if module.savemaps_batch is None:
                module.savemaps_batch = attn_map
            else:
                module.savemaps_batch = torch.cat([module.savemaps_batch, attn_map], dim=0)

        def savemaps_to_v_hook(module, input, kwargs, output):
            # Hand the to_v output to the owning attention module for savemaps_hook.
            module.savemaps_parent_module[0].savemaps_to_v_map = output

        for module in module_list:
            for key_name, default_value in value_map.items():
                module_hooks.modules_add_field(module, key_name, default_value)
            module_hooks.module_add_forward_hook(module, savemaps_hook, 'forward', with_kwargs=True)
            if hasattr(module, 'to_v'):
                module_hooks.modules_add_field(module.to_v, 'savemaps_parent_module', [module])
                module_hooks.module_add_forward_hook(module.to_v, savemaps_to_v_hook, 'forward', with_kwargs=True)

    def unhook_modules(self, module_list: list, value_map: dict):
        """ Remove the fields named in `value_map` and the hooks added by `hook_modules`. """
        for module in module_list:
            for key_name in value_map:
                module_hooks.modules_remove_field(module, key_name)
            module_hooks.remove_module_forward_hook(module, 'savemaps_hook')
            if hasattr(module, 'to_v'):
                module_hooks.modules_remove_field(module.to_v, 'savemaps_parent_module')
                module_hooks.remove_module_forward_hook(module.to_v, 'savemaps_to_v_hook')

    def print_modules(self, module_name_filter, class_name_filter):
        """ Log the name and class of every module matched by the two filters. """
        logger.info("Module name filter: '%s', Class name filter: '%s'", module_name_filter, class_name_filter)
        modules = self.get_modules_by_filter(module_name_filter, class_name_filter)
        module_names = "\n".join(f"{m.network_layer_name}: {m.__class__.__name__}" for m in modules)
        logger.info("Modules found:\n----------\n%s\n----------\n", module_names)

    def get_modules_by_filter(self, module_name_filter, class_name_filter):
        """ Return the modules matching both filters; an empty filter string means 'no filter'. """
        found_modules = module_hooks.get_modules(module_name_filter or None, class_name_filter or None)
        if len(found_modules) == 0:
            logger.warning(
                "No modules found with module name filter: %s and class name filter: %s",
                module_name_filter, class_name_filter,
            )
        return found_modules
+import torch +from torch.nn import functional as F +from torchvision.transforms import GaussianBlur + +from warnings import warn +from typing import Callable, Dict, Optional +from collections import OrderedDict +import torch + +from scripts.incant_utils import module_hooks + +# from pytorch_memlab import LineProfiler, MemReporter +# reporter = MemReporter() + +logger = logging.getLogger(__name__) +logger.setLevel(environ.get("SD_WEBUI_LOG_LEVEL", logging.INFO)) + +incantations_debug = environ.get("INCANTAIONS_DEBUG", False) + + +""" +An unofficial implementation of "Rethinking the Spatial Inconsistency in Classifier-Free Diffusion Guidancee" for Automatic1111 WebUI. + +This builds upon the code provided in the official S-CFG repository: https://github.com/SmilesDZgk/S-CFG + + +@inproceedings{shen2024rethinking, + title={Rethinking the Spatial Inconsistency in Classifier-Free Diffusion Guidancee}, + author={Shen, Dazhong and Song, Guanglu and Xue, Zeyue and Wang, Fu-Yun and Liu, Yu}, + booktitle={Proceedings of The IEEE/CVF Computer Vision and Pattern Recognition Conference (CVPR)}, + year={2024} +} + +Parts of the code are based on Diffusers under the Apache License 2.0: +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class SCFGStateParams:
    """ Mutable per-generation state shared between the S-CFG hooks and callbacks.

    Every attribute starts from a fixed default; the hook-creation code
    overwrites the user-facing ones before sampling begins.
    """

    def __init__(self):
        # --- values overwritten from the UI / processing object ---
        self.scfg_scale: float = 0.8
        self.rate_min = 0.8
        self.rate_max = 3.0
        self.rate_clamp = 15.0
        self.R = 4
        self.start_step = 0
        self.end_step = 150
        self.gaussian_smoothing = None

        # --- sampling bookkeeping (-1 means "not set yet") ---
        self.max_sampling_steps = -1
        self.current_step = 0
        self.height = -1
        self.width = -1

        # Running extremes of the unscaled rate map; seeded so the first
        # observed value always replaces them.
        self.statistics = dict(
            min_rate=float('inf'),
            max_rate=float('-inf'),
        )

        # --- per-step results and handles, filled in lazily ---
        self.mask_t = None
        self.mask_fore = None
        self.denoiser = None
        self.all_crossattn_modules = None
        self.patched_combined_denoised = None
Greater R uses more memory.") + with gr.Row(): + scfg_rate_min = gr.Slider(value = 0.8, minimum = 0, maximum = 30.0, step = 0.1, label="Min Rate", elem_id = 'scfg_rate_min', info="") + scfg_rate_max = gr.Slider(value = 3.0, minimum = 0, maximum = 30.0, step = 0.1, label="Max Rate", elem_id = 'scfg_rate_max', info="") + scfg_rate_clamp = gr.Slider(value = 0.0, minimum = 0, maximum = 30.0, step = 0.1, label="Clamp Rate", elem_id = 'scfg_rate_clamp', info="If > 0, clamp max rate to Clamp Rate / CFG Scale. Overrides max rate.") + with gr.Row(): + start_step = gr.Slider(value = 0, minimum = 0, maximum = 150, step = 1, label="Start Step", elem_id = 'scfg_start_step', info="") + end_step = gr.Slider(value = 150, minimum = 0, maximum = 150, step = 1, label="End Step", elem_id = 'scfg_end_step', info="") + + active.do_not_save_to_config = True + scfg_scale.do_not_save_to_config = True + scfg_rate_min.do_not_save_to_config = True + scfg_rate_max.do_not_save_to_config = True + scfg_rate_clamp.do_not_save_to_config = True + scfg_r.do_not_save_to_config = True + start_step.do_not_save_to_config = True + end_step.do_not_save_to_config = True + + self.infotext_fields = [ + (active, lambda d: gr.Checkbox.update(value='SCFG Active' in d)), + (scfg_scale, 'SCFG Scale'), + (scfg_rate_min, 'SCFG Rate Min'), + (scfg_rate_max, 'SCFG Rate Max'), + (scfg_rate_clamp, 'SCFG Rate Clamp'), + (start_step, 'SCFG Start Step'), + (end_step, 'SCFG End Step'), + (scfg_r, 'SCFG R'), + ] + self.paste_field_names = [ + 'scfg_active', + 'scfg_scale', + 'scfg_rate_min', + 'scfg_rate_max', + 'scfg_rate_clamp', + 'scfg_start_step', + 'scfg_end_step', + 'scfg_r', + ] + return [active, scfg_scale, scfg_rate_min, scfg_rate_max, scfg_rate_clamp, start_step, end_step, scfg_r] + + def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs): + self.pag_process_batch(p, *args, **kwargs) + + def pag_process_batch(self, p: StableDiffusionProcessing, active, scfg_scale, scfg_rate_min, scfg_rate_max, 
scfg_rate_clamp, start_step, end_step, scfg_r, *args, **kwargs): + # cleanup previous hooks always + script_callbacks.remove_current_script_callbacks() + self.remove_all_hooks() + + active = getattr(p, "scfg_active", active) + if active is False: + return + scfg_scale = getattr(p, "scfg_scale", scfg_scale) + scfg_rate_min = getattr(p, "scfg_rate_min", scfg_rate_min) + scfg_rate_max = getattr(p, "scfg_rate_max", scfg_rate_max) + scfg_rate_clamp = getattr(p, "scfg_rate_clamp", scfg_rate_clamp) + start_step = getattr(p, "scfg_start_step", start_step) + end_step = getattr(p, "scfg_end_step", end_step) + scfg_r = getattr(p, "scfg_r", scfg_r) + + p.extra_generation_params.update({ + "SCFG Active": active, + "SCFG Scale": scfg_scale, + "SCFG Rate Min": scfg_rate_min, + "SCFG Rate Max": scfg_rate_max, + "SCFG Rate Clamp": scfg_rate_clamp, + "SCFG Start Step": start_step, + "SCFG End Step": end_step, + "SCFG R": scfg_r, + }) + self.create_hook(p, active, scfg_scale, scfg_rate_min, scfg_rate_max, scfg_rate_clamp, start_step, end_step, scfg_r) + + def create_hook(self, p: StableDiffusionProcessing, active, scfg_scale, scfg_rate_min, scfg_rate_max, scfg_rate_clamp, start_step, end_step, scfg_r): + # Create a list of parameters for each concept + scfg_params = SCFGStateParams() + + # Add to p + if not hasattr(p, 'incant_cfg_params'): + logger.error("No incant_cfg_params found in p") + p.incant_cfg_params['scfg_params'] = scfg_params + + scfg_params.denoiser = None + scfg_params.all_crossattn_modules = self.get_all_crossattn_modules() + scfg_params.max_sampling_steps = p.steps + scfg_params.scfg_scale = scfg_scale + scfg_params.rate_min = scfg_rate_min + scfg_params.rate_max = scfg_rate_max + scfg_params.rate_clamp = scfg_rate_clamp + scfg_params.start_step = start_step + scfg_params.end_step = end_step + scfg_params.R = scfg_r + scfg_params.height = p.height + scfg_params.width = p.width + kernel_size = 3 + sigma=0.5 + scfg_params.gaussian_smoothing = 
GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(shared.device) + + + # Use lambda to call the callback function with the parameters to avoid global variables + #cfg_denoise_lambda = lambda callback_params: self.on_cfg_denoiser_callback(callback_params, scfg_params) + cfg_denoised_lambda = lambda callback_params: self.on_cfg_denoised_callback(callback_params, scfg_params) + unhook_lambda = lambda _: self.unhook_callbacks(scfg_params) + + self.ready_hijack_forward(scfg_params.all_crossattn_modules) + + logger.debug('Hooked callbacks') + #script_callbacks.on_cfg_denoiser(cfg_denoise_lambda) + script_callbacks.on_cfg_denoised(cfg_denoised_lambda) + script_callbacks.on_script_unloaded(unhook_lambda) + + def postprocess_batch(self, p, *args, **kwargs): + self.scfg_postprocess_batch(p, *args, **kwargs) + + def scfg_postprocess_batch(self, p, active, *args, **kwargs): + script_callbacks.remove_current_script_callbacks() + + logger.debug('Removed script callbacks') + active = getattr(p, "scfg_active", active) + if active is False: + return + + if hasattr(p, 'incant_cfg_params') and 'scfg_params' in p.incant_cfg_params: + stats = p.incant_cfg_params['scfg_params'].statistics + logger.debug('SCFG Statistics: %s', stats) + + + self.remove_all_hooks() + + def remove_all_hooks(self): + all_crossattn_modules = self.get_all_crossattn_modules() + for module in all_crossattn_modules: + self.remove_field_cross_attn_modules(module, 'scfg_last_to_q_map') + self.remove_field_cross_attn_modules(module, 'scfg_last_to_k_map') + if hasattr(module, 'to_q'): + handle_scfg_to_q = _remove_all_forward_hooks(module.to_q, 'scfg_to_q_hook') + self.remove_field_cross_attn_modules(module.to_q, 'scfg_parent_module') + if hasattr(module, 'to_k'): + handle_scfg_to_k = _remove_all_forward_hooks(module.to_k, 'scfg_to_k_hook') + self.remove_field_cross_attn_modules(module.to_k, 'scfg_parent_module') + + def unhook_callbacks(self, scfg_params: SCFGStateParams): + pass + + def 
ready_hijack_forward(self, all_crossattn_modules): + """ Create hooks in the forward pass of the cross attention modules + Copies the output of the to_v module to the parent module + """ + + def scfg_self_attn_hook(module, input, kwargs, output): + # scfg_q_map = output.detach().clone() + scfg_q_map = prepare_attn_map(output, module.scfg_heads) + attn_scores = get_attention_scores(scfg_q_map, scfg_q_map) + setattr(module.scfg_parent_module[0], 'scfg_last_qv_map', attn_scores) + + def scfg_cross_attn_hook(module, input, kwargs, output): + scfg_q_map = prepare_attn_map(module.scfg_parent_module[0].scfg_last_to_q_map, module.scfg_heads) + scfg_k_map = prepare_attn_map(output, module.scfg_heads) + #scfg_k_map = output.detach().clone() + attn_scores = get_attention_scores(scfg_q_map, scfg_k_map) + setattr(module.scfg_parent_module[0], 'scfg_last_qv_map', attn_scores) + # del module.parent_module[0].scfg_last_to_q_map + + def scfg_to_q_hook(module, input, kwargs, output): + setattr(module.scfg_parent_module[0], 'scfg_last_to_q_map', output) + + def scfg_to_k_hook(module, input, kwargs, output): + setattr(module.scfg_parent_module[0], 'scfg_last_to_k_map', output) + + for module in all_crossattn_modules: + if not hasattr(module, 'to_q') or not hasattr(module, 'to_k'): + logger.error("CrossAttention module '%s' does not have to_q or to_k", module.network_layer_name) + continue + + # to_q + self.add_field_cross_attn_modules(module.to_q, 'scfg_parent_module', [module]) + self.add_field_cross_attn_modules(module, 'scfg_last_to_q_map', None) + handle_scfg_to_q = module_hooks.module_add_forward_hook( + module.to_q, + scfg_to_q_hook, + with_kwargs=True + ) + + # to_k + self.add_field_cross_attn_modules(module.to_k, 'scfg_parent_module', [module]) + if module.network_layer_name.endswith('attn2'): # cross attn + self.add_field_cross_attn_modules(module, 'scfg_last_to_k_map', None) + handle_scfg_to_k = module_hooks.module_add_forward_hook( + module.to_k, + scfg_to_k_hook, + 
with_kwargs=True + ) + + def get_all_crossattn_modules(self): + """ + Get ALL attention modules + """ + modules = module_hooks.get_modules( + module_name_filter='CrossAttention' + ) + return modules + + def add_field_cross_attn_modules(self, module, field, value): + """ Add a field to a module if it doesn't exist """ + module_hooks.modules_add_field(module, field, value) + + def remove_field_cross_attn_modules(self, module, field): + """ Remove a field from a module if it exists """ + module_hooks.modules_remove_field(module, field) + + def on_cfg_denoiser_callback(self, params: CFGDenoiserParams, scfg_params: SCFGStateParams): + # always unhook + self.unhook_callbacks(scfg_params) + + def on_cfg_denoised_callback(self, params: CFGDenoisedParams, scfg_params: SCFGStateParams): + """ Callback function for the CFGDenoisedParams + Refer to pg.22 A.2 of the PAG paper for how CFG and PAG combine + + """ + scfg_params.current_step = params.sampling_step + + # Run only within interval + if not scfg_params.start_step <= params.sampling_step <= scfg_params.end_step: + return + + if scfg_params.scfg_scale <= 0: + return + + # S-CFG + R = scfg_params.R + max_latent_size = [params.x.shape[-2] // R, params.x.shape[-1] // R] + + #with LineProfiler(get_mask) as lp: + ca_mask, fore_mask = get_mask(scfg_params.all_crossattn_modules, + scfg_params, + r = scfg_params.R, + latent_size = max_latent_size, + ) + #lp.print_stats() + + # todo parameterize this + mask_t = F.interpolate(ca_mask, scale_factor=R, mode='nearest') + mask_fore = F.interpolate(fore_mask, scale_factor=R, mode='nearest') + scfg_params.mask_t = mask_t + scfg_params.mask_fore = mask_fore + + + def get_xyz_axis_options(self) -> dict: + xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module + extra_axis_options = { + xyz_grid.AxisOption("[SCFG] Active", str, scfg_apply_override('scfg_active', boolean=True), 
def scfg_combine_denoised(model_delta, cfg_scale, scfg_params: "SCFGStateParams"):
    """ The inner loop of the S-CFG denoiser
    Arguments:
        model_delta: torch.Tensor - defined by `x_out[cond_index] - denoised_uncond[i]`
        cfg_scale: float - guidance scale
        scfg_params: SCFGStateParams - the state parameters for the S-CFG denoiser

    Returns:
        float or torch.Tensor - 1.0 if not within the step interval or the S-CFG
        scale is <= 0, else the smoothed rate map with its leading dim squeezed.

    Side effects:
        Updates ``scfg_params.statistics`` with the min/max of the rate map
        before scaling and clamping.
    """
    if not scfg_params.start_step <= scfg_params.current_step <= scfg_params.end_step:
        return 1.0

    scfg_scale = scfg_params.scfg_scale
    if scfg_scale <= 0:
        return 1.0

    mask_t = scfg_params.mask_t
    mask_fore = scfg_params.mask_fore
    min_rate = scfg_params.rate_min
    max_rate = scfg_params.rate_max
    rate_clamp = scfg_params.rate_clamp

    # Norm over dim 1 of the delta after adding a leading dim
    model_delta_norm = model_delta.unsqueeze(0).norm(dim=1, keepdim=True)
    target_size = model_delta_norm.shape[2:]

    def eps(dtype):
        return torch.finfo(dtype).eps

    # rescale maps if necessary; both masks are compared on BOTH spatial dims
    # (mask_fore used to be checked on height only, so a width-only mismatch
    # slipped through and failed at the multiplication below)
    if mask_t.shape[2:] != target_size:
        logger.debug('Rescaling mask_t from %s to %s', mask_t.shape[2:], target_size)
        mask_t = F.interpolate(mask_t, size=target_size, mode='bilinear')
    if mask_fore.shape[2:] != target_size:
        logger.debug('Rescaling mask_fore from %s to %s', mask_fore.shape[2:], target_size)
        mask_fore = F.interpolate(mask_fore, size=target_size, mode='bilinear')

    # Mask-weighted mean of the delta norm, per mask channel
    mask_t_area = mask_t.sum([2, 3])
    delta_mask_norms = (model_delta_norm * mask_t).sum([2, 3]) / (mask_t_area + eps(mask_t.dtype))
    fore_norms = (model_delta_norm * mask_fore).sum([2, 3]) / (mask_fore.sum([2, 3]) + eps(mask_fore.dtype))

    # Channels whose mask is empty contribute a rate of 0
    tmp_mask = (mask_t_area > 0).float()
    rate = fore_norms * tmp_mask / (delta_mask_norms + eps(delta_mask_norms.dtype))  # b 257
    rate = (rate.unsqueeze(-1).unsqueeze(-1) * mask_t).sum(dim=1, keepdim=True)  # b 1, 64 64

    # free the intermediates before the remaining allocations
    del model_delta_norm, mask_t_area, delta_mask_norms, fore_norms, tmp_mask

    # unscaled min/max rate (one reduction + host transfer each)
    statistics = scfg_params.statistics
    statistics["min_rate"] = min(statistics["min_rate"], rate.min().item())
    statistics["max_rate"] = max(statistics["max_rate"], rate.max().item())

    # should this go before or after the gaussian blur, or before/after the rate
    rate = rate * scfg_scale
    rate = torch.clamp(rate, min=min_rate, max=max_rate)

    if rate_clamp > 0:
        rate = torch.clamp_max(rate, rate_clamp / cfg_scale)

    # Gaussian smoothing; pad by 1 so the smoothing module (built with a 3x3
    # kernel by the hook-creation code) returns the input size
    rate = F.pad(rate, (1, 1, 1, 1), mode='reflect')
    rate = scfg_params.gaussian_smoothing(rate)

    return rate.squeeze(0)
class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a
    1d, 2d or 3d tensor. Filtering is performed separately for each channel
    in the input using a depthwise convolution.
    Arguments:
        channels (int, sequence): Number of channels of the input tensors. Output will
            have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel.
        sigma (float, sequence): Standard deviation of the gaussian kernel.
        dim (int, optional): The number of dimensions of the data.
            Default value is 2 (spatial).
    Raises:
        RuntimeError: if `dim` is not 1, 2 or 3.
    """
    def __init__(self, channels, kernel_size, sigma, dim=2):
        super().__init__()

        # Validate before doing any work; same exception type and message as before.
        conv_by_dim = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}
        if dim not in conv_by_dim:
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
            )

        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim

        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        # indexing='ij' is the behaviour torch.meshgrid had by default; passing it
        # explicitly avoids the deprecation warning without changing the grids.
        meshgrids = torch.meshgrid(
            [torch.arange(size, dtype=torch.float32) for size in kernel_size],
            indexing='ij',
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            # NOTE(review): the exponent is ((x - mean) / (2 * std)) ** 2 rather than the
            # textbook ((x - mean) / std) ** 2 / 2; left untouched so the weights do not change.
            kernel = kernel * (
                1 / (std * math.sqrt(2 * math.pi))
                * torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
            )

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels
        self.conv = conv_by_dim[dim]

    def forward(self, input):
        """
        Apply gaussian filter to input.
        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output. No padding is applied here,
                so each filtered dim shrinks by kernel_size - 1.
        """
        # Cast the stored weight to the input dtype so the convolution dtypes match.
        return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups)
+ Arguments: + attn_modules: List of attention modules + scfg_params: SCFGStateParams + r: int - + latent_size: tuple + + """ + height = scfg_params.height + width = scfg_params.width + max_dims = height * width + latent_size = latent_size[-2:] + module_attn_sizes = set() + + key_corss = f"r{r}_cross" + key_self = f"r{r}_self" + + + # The maximum value of the sizes of attention map to aggregate + max_r = r + max_sizes = r + + # The current number of attention map resolutions aggregated + attnmap_r = 0 + + r_r = 1 + new_ca = 0 + new_fore=0 + a_n=0 + # corresponds to diffusers pipe.unet.config.sample_size + # sample_size = 64 + # get a layer wise mapping + attention_store_proxy = {"r2_cross": [], "r4_cross": [], "r8_cross": [], "r16_cross": [], + "r2_self": [], "r4_self": [], "r8_self": [], "r16_self": []} + for module in attn_modules: + module_type = 'cross' if 'attn2' in module.network_layer_name else 'self' + + to_q_map = getattr(module, 'scfg_last_to_q_map', None) + to_k_map = getattr(module, 'scfg_last_to_k_map', None) + # self-attn + if to_k_map is None: + to_k_map = to_q_map + + to_q_map = prepare_attn_map(to_q_map, module.heads) + to_k_map = prepare_attn_map(to_k_map, module.heads) + + module_attn_size = to_q_map.size(1) + module_attn_sizes.add(module_attn_size) + downscale_h = int((module_attn_size * (height / width)) ** 0.5) + downscale_w = module_attn_size // downscale_h + module_key = f"r{module_attn_size}_{module_type}" + + attn_probs = get_attention_scores(to_q_map, to_k_map, to_q_map.dtype) + + if module_type == 'self': + del module.scfg_last_to_q_map + else: + del module.scfg_last_to_q_map, module.scfg_last_to_k_map + + if module_key not in attention_store_proxy: + attention_store_proxy[module_key] = [] + try: + attention_store_proxy[module_key].append(attn_probs) + except KeyError: + continue + + module_attn_sizes = sorted(list(module_attn_sizes)) + attention_maps = attention_store_proxy + + curr_r = module_attn_sizes.pop(0) + while curr_r != None 
and attnmap_r < max_sizes: + key_corss = f"r{curr_r}_cross" + key_self = f"r{curr_r}_self" + + if key_self not in attention_maps.keys() or key_corss not in attention_maps.keys(): + next_r = module_attn_sizes.pop(0) + attnmap_r += 1 + curr_r = next_r + continue + if len(attention_maps[key_self]) == 0 or len(attention_maps[key_corss]) == 0: + curr_r = module_attn_sizes.pop(0) + attnmap_r += 1 + curr_r = next_r + continue + + sa = torch.stack(attention_maps[key_self], dim=1) + ca = torch.stack(attention_maps[key_corss], dim=1) + attn_num = sa.size(1) + sa = rearrange(sa, 'b n h w -> (b n) h w') + ca = rearrange(ca, 'b n h w -> (b n) h w') + + curr = 0 # b hw c=hw + curr +=sa + + # 4.1.2 Self-Attentiion + ssgc_sa = curr + ssgc_n = max_r + + # summation from r=2 to R, we set ssgc_sa to curr which would be sa^1 + # major memory hog + # active_bytes peak from 3.41G to 4.04G + # reserved_bytes peak from 3.70G to 4.64G + # optimization 1: active 4.03G -> 3.72G = 0.31G, reserved 4.64G -> 4.16G = 0.48G + for r_value in range(1, ssgc_n): + r_pow = r_value + 1 + curr @= sa # optimization 1 +# curr = torch.linalg.matrix_power(sa, r_pow) # sa^r + ssgc_sa += curr + + ssgc_sa/=ssgc_n + sa = ssgc_sa + + ########smoothing ca + ca = sa@ca # b hw c + + hw = ca.size(1) + + downscale_h = round((hw * (height / width)) ** 0.5) + + ca = rearrange(ca, 'b (h w) c -> b c h w', h=downscale_h ) + + # Scale the attention map to the expected size + max_size = latent_size + scale_factor = [ + max_size[0] / ca.shape[-2], + max_size[1] / ca.shape[-1] + ] + mode = 'bilinear' #'nearest' # + ca = F.interpolate(ca, scale_factor=scale_factor, mode=mode) # b 77 32 32 + + #####Gaussian Smoothing + #kernel_size = 3 + #sigma = 0.5 + #smoothing = GaussianSmoothing(channels=1, kernel_size=kernel_size, sigma=sigma, dim=2).to(ca.device) + smoothing = scfg_params.gaussian_smoothing + channel = ca.size(1) + ca= rearrange(ca, ' b c h w -> (b c) h w' ).unsqueeze(1) + ca = F.pad(ca, (1, 1, 1, 1), mode='reflect') + ca 
= smoothing(ca.float()).squeeze(1) + ca = rearrange(ca, ' (b c) h w -> b c h w' , c= channel) + + ca_norm = ca/(ca.mean(dim=[2,3], keepdim=True)+torch.finfo(ca.dtype).eps) ### spatial normlization + + new_ca+=rearrange(ca_norm, '(b n) c h w -> b n c h w', n=attn_num).sum(1) + + fore_ca = torch.stack([ca[:,0],ca[:,1:].sum(dim=1)], dim=1) + froe_ca_norm = fore_ca/fore_ca.mean(dim=[2,3], keepdim=True) ### spatial normlization + new_fore += rearrange(froe_ca_norm, '(b n) c h w -> b n c h w', n=attn_num).sum(1) + a_n+=attn_num + + if len(module_attn_sizes) > 0: + curr_r = module_attn_sizes.pop(0) + else: + curr_r = None + attnmap_r += 1 + # r_r *= 2 + + # optimization 2: memory savings: 3.09G - 2.47G = 0.62G + del ca_norm, froe_ca_norm, fore_ca + + # no memory savings + del attention_maps + del sa, ca, ssgc_sa, ssgc_n, curr + + # variables used from above: + # new_ca, new_fore, a_n + new_ca = new_ca/a_n + new_fore = new_fore/a_n + _,new_ca = new_ca.chunk(2, dim=0) #[1] + fore_ca, _ = new_fore.chunk(2, dim=0) + + max_ca, inds = torch.max(new_ca[:,:], dim=1) + max_ca = max_ca.unsqueeze(1) # + ca_mask = (new_ca==max_ca).float() # b 77/10 16 16 + + max_fore, inds = torch.max(fore_ca[:,:], dim=1) + max_fore = max_fore.unsqueeze(1) # + fore_mask = (fore_ca==max_fore).float() # b 77/10 16 16 + fore_mask = 1.0-fore_mask[:,:1] # b 1 16 16 + + # no memory savings + del new_ca, new_fore, a_n, max_ca, max_fore, inds + + return [ ca_mask, fore_mask] + + +def prepare_attn_map(to_k_map, heads): + to_k_map = head_to_batch_dim(to_k_map, heads) + to_k_map = average_over_head_dim(to_k_map, heads) + to_k_map = torch.stack([to_k_map[0], to_k_map[0]], dim=0) + return to_k_map + + +def get_attention_scores(to_q_map, to_k_map, dtype): + """ Calculate the attention scores for the given query and key maps + Arguments: + to_q_map: torch.Tensor - query map + to_k_map: torch.Tensor - key map + dtype: torch.dtype - data type of the tensor + Returns: + torch.Tensor - attention scores + """ + # based 
on diffusers models/attention.py "get_attention_scores" + # use in place operations vs. softmax to save memory: https://stackoverflow.com/questions/53732209/torch-in-place-operations-to-save-memory-softmax + # 512x: 2.65G -> 2.47G + # attn_probs = attn_scores.softmax(dim=-1).to(device=shared.device, dtype=to_q_map.dtype) + + attn_probs = to_q_map @ to_k_map.transpose(-1, -2) + + # avoid nan by converting to float32 and subtracting max + attn_probs = attn_probs.to(dtype=torch.float32) # + attn_probs -= torch.max(attn_probs) + + torch.exp(attn_probs, out = attn_probs) + summed = attn_probs.sum(dim=-1, keepdim=True, dtype=torch.float32) + attn_probs /= summed + + attn_probs = attn_probs.to(dtype=dtype) + + return attn_probs \ No newline at end of file