diff --git a/README.md b/README.md index aa6b933..2f200a8 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,15 @@ Prompt: "A cute puppy on the moon", Min Rate: 0.5, Max Rate: 10.0 https://arxiv.org/abs/2403.17377 An alternative/complementary method to CFG (Classifier-Free Guidance) that increases sampling quality. +# Update: 20-05-2024 +Implemented a new feature called "Saliency-Adaptive Noise Fusion" derived from ["High-fidelity Person-centric Subject-to-Image Synthesis"](https://arxiv.org/abs/2311.10329). + +This feature combines the guidance from PAG and CFG in an adaptive way that improves image quality especially at higher guidance scales. + +Check out the paper authors' project repository here: https://github.com/CodeGoat24/Face-diffuser + #### Controls +* **Use Saliency-Adaptive Noise Fusion**: Use improved method of combining CFG + PAG. * **PAG Scale**: Controls the intensity of effect of PAG on the generated image. * **PAG Start Step**: Step to start using PAG. * **PAG End Step**: Step to stop using PAG. @@ -286,6 +294,15 @@ SD XL archivePrefix={arXiv}, primaryClass={cs.CV} } + + @misc{wang2024highfidelity, + title={High-fidelity Person-centric Subject-to-Image Synthesis}, + author={Yibin Wang and Weizhong Zhang and Jianwei Zheng and Cheng Jin}, + year={2024}, + eprint={2311.10329}, + archivePrefix={arXiv}, + primaryClass={cs.CV} + } - [Hard Prompts Made Easy](https://github.com/YuxinWenRick/hard-prompts-made-easy) diff --git a/scripts/cfg_combiner.py b/scripts/cfg_combiner.py index 12dbf6a..bbf3d4f 100644 --- a/scripts/cfg_combiner.py +++ b/scripts/cfg_combiner.py @@ -1,6 +1,7 @@ import gradio as gr import logging import torch +import torchvision.transforms as F from modules import shared, scripts, devices, patches, script_callbacks from modules.script_callbacks import CFGDenoiserParams from modules.processing import StableDiffusionProcessing @@ -179,11 +180,29 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale): # 2. PAG pag_x_out = None pag_scale = None + run_pag = False if pag_params is not None: pag_active = pag_params.pag_active pag_x_out = pag_params.pag_x_out pag_scale = pag_params.pag_scale + if not pag_active: + pass + # Not within step interval? + elif not pag_params.pag_start_step <= pag_params.step <= pag_params.pag_end_step: + pass + # Scale is zero? + elif pag_scale <= 0: + pass + else: + run_pag = pag_active + + # 3. Saliency Map + use_saliency_map = False + if pag_params is not None: + use_saliency_map = pag_params.pag_sanf + + ### Combine Denoised for i, conds in enumerate(conds_list): for cond_index, weight in conds: @@ -209,24 +228,44 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale): pass # 1. Experimental formulation for S-CFG combined with CFG - denoised[i] += (model_delta) * rate * (weight * cfg_scale) + cfg_x = (model_delta) * rate * (weight * cfg_scale) + if not use_saliency_map or not run_pag: + denoised[i] += cfg_x del rate # 2. PAG # PAG is added like CFG if pag_params is not None: - if not pag_active: - pass - # Not within step interval? - elif not pag_params.pag_start_step <= pag_params.step <= pag_params.pag_end_step: - pass - # Scale is zero? - elif pag_scale <= 0: + if not run_pag: pass # do pag else: try: - denoised[i] += (x_out[cond_index] - pag_x_out[i]) * (weight * pag_scale) + pag_delta = x_out[cond_index] - pag_x_out[i] + pag_x = pag_delta * (weight * pag_scale) + + if not use_saliency_map: + denoised[i] += pag_x + + # 3. Saliency Adaptive Noise Fusion arXiv.2311.10329v5 + # Smooth the saliency maps + if use_saliency_map: + blur = F.GaussianBlur(kernel_size=3, sigma=1).to(device=shared.device) + omega_rt = blur(torch.abs(cfg_x)) + omega_rs = blur(torch.abs(pag_x)) + soft_rt = torch.softmax(omega_rt, dim=0) + soft_rs = torch.softmax(omega_rs, dim=0) + + m = torch.stack([soft_rt, soft_rs], dim=0) # 2 c h w + _, argmax_indices = torch.max(m, dim=0) + + # select from cfg_x or pag_x + m1 = torch.where(argmax_indices == 0, 1, 0) + + # hadamard product + sal_cfg = cfg_x * m1 + pag_x * (1 - m1) + + denoised[i] += sal_cfg except Exception as e: logger.exception("Exception in combine_denoised_pass_conds_list - %s", e) diff --git a/scripts/pag.py b/scripts/pag.py index 20e067a..c34ee7d 100644 --- a/scripts/pag.py +++ b/scripts/pag.py @@ -66,6 +66,16 @@ primaryClass={cs.CV} } +Saliency-adaptive noise fusion from arXiv:2311.10329 "High-fidelity Person-centric Subject-to-Image Synthesis" +@misc{wang2024highfidelity, + title={High-fidelity Person-centric Subject-to-Image Synthesis}, + author={Yibin Wang and Weizhong Zhang and Jianwei Zheng and Cheng Jin}, + year={2024}, + eprint={2311.10329}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} + Author: v0xie GitHub URL: https://github.com/v0xie/sd-webui-incantations @@ -99,6 +109,7 @@ class PAGStateParams: def __init__(self): self.pag_active: bool = False # PAG guidance scale + self.pag_sanf: bool = False # saliency-adaptive noise fusion, handled in cfg_combiner self.pag_scale: int = -1 # PAG guidance scale self.pag_start_step: int = 0 self.pag_end_step: int = 150 @@ -145,6 +156,7 @@ def show(self, is_img2img): def setup_ui(self, is_img2img) -> list: with gr.Accordion('Perturbed Attention Guidance', open=False): active = gr.Checkbox(value=False, default=False, label="Active", elem_id='pag_active') + pag_sanf = gr.Checkbox(value=False, default=False, label="Use Saliency-Adaptive Noise Fusion", elem_id='pag_sanf') with gr.Row(): pag_scale = gr.Slider(value = 0, minimum = 0, maximum = 20.0, step = 0.5, label="PAG Scale", elem_id = 'pag_scale', info="") with gr.Row(): @@ -164,6 +176,7 @@ def setup_ui(self, is_img2img) -> list: cfg_interval_high = gr.Slider(value = 100, minimum = 0, maximum = 100, step = 0.1, label="CFG Noise Interval High", elem_id = 'cfg_interval_high', info="") active.do_not_save_to_config = True + pag_sanf.do_not_save_to_config = True pag_scale.do_not_save_to_config = True start_step.do_not_save_to_config = True end_step.do_not_save_to_config = True @@ -173,6 +186,7 @@ def setup_ui(self, is_img2img) -> list: cfg_interval_high.do_not_save_to_config = True self.infotext_fields = [ (active, lambda d: gr.Checkbox.update(value='PAG Active' in d)), + (pag_sanf, lambda d: gr.Checkbox.update(value='PAG SANF' in d)), (pag_scale, 'PAG Scale'), (start_step, 'PAG Start Step'), (end_step, 'PAG End Step'), @@ -183,6 +197,7 @@ def setup_ui(self, is_img2img) -> list: ] self.paste_field_names = [ 'pag_active', + 'pag_sanf', 'pag_scale', 'pag_start_step', 'pag_end_step', @@ -191,17 +206,18 @@ def setup_ui(self, is_img2img) -> list: 'cfg_interval_low', 'cfg_interval_high', ] - return [active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high] + return [active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, pag_sanf] def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs): self.pag_process_batch(p, *args, **kwargs) - def pag_process_batch(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, *args, **kwargs): + def pag_process_batch(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, pag_sanf, *args, **kwargs): # cleanup previous hooks always script_callbacks.remove_current_script_callbacks() self.remove_all_hooks() active = getattr(p, "pag_active", active) + pag_sanf = getattr(p, "pag_sanf", pag_sanf) cfg_interval_enable = getattr(p, "cfg_interval_enable", cfg_interval_enable) if active is False and cfg_interval_enable is False: return @@ -216,6 +232,7 @@ def pag_process_batch(self, p: StableDiffusionProcessing, active, pag_scale, sta if active: p.extra_generation_params.update({ "PAG Active": active, + "PAG SANF": pag_sanf, "PAG Scale": pag_scale, "PAG Start Step": start_step, "PAG End Step": end_step, @@ -227,9 +244,9 @@ def pag_process_batch(self, p: StableDiffusionProcessing, active, pag_scale, sta "CFG Interval Low": cfg_interval_low, "CFG Interval High": cfg_interval_high }) - self.create_hook(p, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high) + self.create_hook(p, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, pag_sanf) - def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, *args, **kwargs): + def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, pag_sanf, *args, **kwargs): # Create a list of parameters for each concept pag_params = PAGStateParams() @@ -239,6 +256,7 @@ def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_ste p.incant_cfg_params['pag_params'] = pag_params pag_params.pag_active = active + pag_params.pag_sanf = pag_sanf pag_params.pag_scale = pag_scale pag_params.pag_start_step = start_step pag_params.pag_end_step = end_step @@ -359,7 +377,7 @@ def pag_pre_hook(module, input, kwargs, output): batch_size, seq_len, inner_dim = output.shape identity = torch.eye(seq_len, dtype=last_to_v.dtype, device=shared.device).expand(batch_size, -1, -1) if last_to_v is not None: - new_output = torch.einsum('bij,bjk->bik', identity, last_to_v) + new_output = torch.einsum('bij,bjk->bik', identity, last_to_v[:, :seq_len, :]) return new_output else: # this is bad @@ -499,6 +517,7 @@ def get_xyz_axis_options(self) -> dict: xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module extra_axis_options = { xyz_grid.AxisOption("[PAG] Active", str, pag_apply_override('pag_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)), + xyz_grid.AxisOption("[PAG] SANF", str, pag_apply_override('pag_sanf', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)), xyz_grid.AxisOption("[PAG] PAG Scale", float, pag_apply_field("pag_scale")), xyz_grid.AxisOption("[PAG] PAG Start Step", int, pag_apply_field("pag_start_step")), xyz_grid.AxisOption("[PAG] PAG End Step", int, pag_apply_field("pag_end_step")),