Skip to content
Merged

Dev #52

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,15 @@ Prompt: "A cute puppy on the moon", Min Rate: 0.5, Max Rate: 10.0
https://arxiv.org/abs/2403.17377
An alternative/complementary method to CFG (Classifier-Free Guidance) that increases sampling quality.

# Update: 20-05-2024
Implemented a new feature called "Saliency-Adaptive Noise Fusion" derived from ["High-fidelity Person-centric Subject-to-Image Synthesis"](https://arxiv.org/abs/2311.10329).

This feature combines the guidance from PAG and CFG in an adaptive way that improves image quality especially at higher guidance scales.

Check out the paper authors' project repository here: https://github.com/CodeGoat24/Face-diffuser

#### Controls
* **Use Saliency-Adaptive Noise Fusion**: Use improved method of combining CFG + PAG.
* **PAG Scale**: Controls the intensity of effect of PAG on the generated image.
* **PAG Start Step**: Step to start using PAG.
* **PAG End Step**: Step to stop using PAG.
Expand Down Expand Up @@ -286,6 +294,15 @@ SD XL
archivePrefix={arXiv},
primaryClass={cs.CV}
}

@misc{wang2024highfidelity,
title={High-fidelity Person-centric Subject-to-Image Synthesis},
author={Yibin Wang and Weizhong Zhang and Jianwei Zheng and Cheng Jin},
year={2024},
eprint={2311.10329},
archivePrefix={arXiv},
primaryClass={cs.CV}
}


- [Hard Prompts Made Easy](https://github.com/YuxinWenRick/hard-prompts-made-easy)
Expand Down
57 changes: 48 additions & 9 deletions scripts/cfg_combiner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import gradio as gr
import logging
import torch
import torchvision.transforms as F
from modules import shared, scripts, devices, patches, script_callbacks
from modules.script_callbacks import CFGDenoiserParams
from modules.processing import StableDiffusionProcessing
Expand Down Expand Up @@ -179,11 +180,29 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
# 2. PAG
pag_x_out = None
pag_scale = None
run_pag = False
if pag_params is not None:
pag_active = pag_params.pag_active
pag_x_out = pag_params.pag_x_out
pag_scale = pag_params.pag_scale

if not pag_active:
pass
# Not within step interval?
elif not pag_params.pag_start_step <= pag_params.step <= pag_params.pag_end_step:
pass
# Scale is zero?
elif pag_scale <= 0:
pass
else:
run_pag = pag_active

# 3. Saliency Map
use_saliency_map = False
if pag_params is not None:
use_saliency_map = pag_params.pag_sanf


### Combine Denoised
for i, conds in enumerate(conds_list):
for cond_index, weight in conds:
Expand All @@ -209,24 +228,44 @@ def new_combine_denoised(x_out, conds_list, uncond, cond_scale):
pass

# 1. Experimental formulation for S-CFG combined with CFG
denoised[i] += (model_delta) * rate * (weight * cfg_scale)
cfg_x = (model_delta) * rate * (weight * cfg_scale)
if not use_saliency_map or not run_pag:
denoised[i] += cfg_x
del rate

# 2. PAG
# PAG is added like CFG
if pag_params is not None:
if not pag_active:
pass
# Not within step interval?
elif not pag_params.pag_start_step <= pag_params.step <= pag_params.pag_end_step:
pass
# Scale is zero?
elif pag_scale <= 0:
if not run_pag:
pass
# do pag
else:
try:
denoised[i] += (x_out[cond_index] - pag_x_out[i]) * (weight * pag_scale)
pag_delta = x_out[cond_index] - pag_x_out[i]
pag_x = pag_delta * (weight * pag_scale)

if not use_saliency_map:
denoised[i] += pag_x

# 3. Saliency Adaptive Noise Fusion arXiv.2311.10329v5
# Smooth the saliency maps
if use_saliency_map:
blur = F.GaussianBlur(kernel_size=3, sigma=1).to(device=shared.device)
omega_rt = blur(torch.abs(cfg_x))
omega_rs = blur(torch.abs(pag_x))
soft_rt = torch.softmax(omega_rt, dim=0)
soft_rs = torch.softmax(omega_rs, dim=0)

m = torch.stack([soft_rt, soft_rs], dim=0) # 2 c h w
_, argmax_indices = torch.max(m, dim=0)

# select from cfg_x or pag_x
m1 = torch.where(argmax_indices == 0, 1, 0)

# hadamard product
sal_cfg = cfg_x * m1 + pag_x * (1 - m1)

denoised[i] += sal_cfg
except Exception as e:
logger.exception("Exception in combine_denoised_pass_conds_list - %s", e)

Expand Down
29 changes: 24 additions & 5 deletions scripts/pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@
primaryClass={cs.CV}
}

Saliency-adaptive noise fusion from arXiv:2311.10329 "High-fidelity Person-centric Subject-to-Image Synthesis"
@misc{wang2024highfidelity,
title={High-fidelity Person-centric Subject-to-Image Synthesis},
author={Yibin Wang and Weizhong Zhang and Jianwei Zheng and Cheng Jin},
year={2024},
eprint={2311.10329},
archivePrefix={arXiv},
primaryClass={cs.CV}
}

Author: v0xie
GitHub URL: https://github.com/v0xie/sd-webui-incantations

Expand Down Expand Up @@ -99,6 +109,7 @@
class PAGStateParams:
def __init__(self):
self.pag_active: bool = False # PAG guidance scale
self.pag_sanf: bool = False # saliency-adaptive noise fusion, handled in cfg_combiner
self.pag_scale: int = -1 # PAG guidance scale
self.pag_start_step: int = 0
self.pag_end_step: int = 150
Expand Down Expand Up @@ -145,6 +156,7 @@ def show(self, is_img2img):
def setup_ui(self, is_img2img) -> list:
with gr.Accordion('Perturbed Attention Guidance', open=False):
active = gr.Checkbox(value=False, default=False, label="Active", elem_id='pag_active')
pag_sanf = gr.Checkbox(value=False, default=False, label="Use Saliency-Adaptive Noise Fusion", elem_id='pag_sanf')
with gr.Row():
pag_scale = gr.Slider(value = 0, minimum = 0, maximum = 20.0, step = 0.5, label="PAG Scale", elem_id = 'pag_scale', info="")
with gr.Row():
Expand All @@ -164,6 +176,7 @@ def setup_ui(self, is_img2img) -> list:
cfg_interval_high = gr.Slider(value = 100, minimum = 0, maximum = 100, step = 0.1, label="CFG Noise Interval High", elem_id = 'cfg_interval_high', info="")

active.do_not_save_to_config = True
pag_sanf.do_not_save_to_config = True
pag_scale.do_not_save_to_config = True
start_step.do_not_save_to_config = True
end_step.do_not_save_to_config = True
Expand All @@ -173,6 +186,7 @@ def setup_ui(self, is_img2img) -> list:
cfg_interval_high.do_not_save_to_config = True
self.infotext_fields = [
(active, lambda d: gr.Checkbox.update(value='PAG Active' in d)),
(pag_sanf, lambda d: gr.Checkbox.update(value='PAG SANF' in d)),
(pag_scale, 'PAG Scale'),
(start_step, 'PAG Start Step'),
(end_step, 'PAG End Step'),
Expand All @@ -183,6 +197,7 @@ def setup_ui(self, is_img2img) -> list:
]
self.paste_field_names = [
'pag_active',
'pag_sanf',
'pag_scale',
'pag_start_step',
'pag_end_step',
Expand All @@ -191,17 +206,18 @@ def setup_ui(self, is_img2img) -> list:
'cfg_interval_low',
'cfg_interval_high',
]
return [active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high]
return [active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, pag_sanf]

def process_batch(self, p: StableDiffusionProcessing, *args, **kwargs):
self.pag_process_batch(p, *args, **kwargs)

def pag_process_batch(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, *args, **kwargs):
def pag_process_batch(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, pag_sanf, *args, **kwargs):
# cleanup previous hooks always
script_callbacks.remove_current_script_callbacks()
self.remove_all_hooks()

active = getattr(p, "pag_active", active)
pag_sanf = getattr(p, "pag_sanf", pag_sanf)
cfg_interval_enable = getattr(p, "cfg_interval_enable", cfg_interval_enable)
if active is False and cfg_interval_enable is False:
return
Expand All @@ -216,6 +232,7 @@ def pag_process_batch(self, p: StableDiffusionProcessing, active, pag_scale, sta
if active:
p.extra_generation_params.update({
"PAG Active": active,
"PAG SANF": pag_sanf,
"PAG Scale": pag_scale,
"PAG Start Step": start_step,
"PAG End Step": end_step,
Expand All @@ -227,9 +244,9 @@ def pag_process_batch(self, p: StableDiffusionProcessing, active, pag_scale, sta
"CFG Interval Low": cfg_interval_low,
"CFG Interval High": cfg_interval_high
})
self.create_hook(p, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high)
self.create_hook(p, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, pag_sanf)

def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, *args, **kwargs):
def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_step, end_step, cfg_interval_enable, cfg_schedule, cfg_interval_low, cfg_interval_high, pag_sanf, *args, **kwargs):
# Create a list of parameters for each concept
pag_params = PAGStateParams()

Expand All @@ -239,6 +256,7 @@ def create_hook(self, p: StableDiffusionProcessing, active, pag_scale, start_ste
p.incant_cfg_params['pag_params'] = pag_params

pag_params.pag_active = active
pag_params.pag_sanf = pag_sanf
pag_params.pag_scale = pag_scale
pag_params.pag_start_step = start_step
pag_params.pag_end_step = end_step
Expand Down Expand Up @@ -359,7 +377,7 @@ def pag_pre_hook(module, input, kwargs, output):
batch_size, seq_len, inner_dim = output.shape
identity = torch.eye(seq_len, dtype=last_to_v.dtype, device=shared.device).expand(batch_size, -1, -1)
if last_to_v is not None:
new_output = torch.einsum('bij,bjk->bik', identity, last_to_v)
new_output = torch.einsum('bij,bjk->bik', identity, last_to_v[:, :seq_len, :])
return new_output
else:
# this is bad
Expand Down Expand Up @@ -499,6 +517,7 @@ def get_xyz_axis_options(self) -> dict:
xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module
extra_axis_options = {
xyz_grid.AxisOption("[PAG] Active", str, pag_apply_override('pag_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
xyz_grid.AxisOption("[PAG] SANF", str, pag_apply_override('pag_sanf', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),
xyz_grid.AxisOption("[PAG] PAG Scale", float, pag_apply_field("pag_scale")),
xyz_grid.AxisOption("[PAG] PAG Start Step", int, pag_apply_field("pag_start_step")),
xyz_grid.AxisOption("[PAG] PAG End Step", int, pag_apply_field("pag_end_step")),
Expand Down