Commit a0c7bd0

Merge pull request #4 from dtan3847/dynthresh1
Try dynamic thresholding 3962
2 parents: 32246ca + b7826ac · commit a0c7bd0

7 files changed: +107 −25 lines

modules/img2img.py
Lines changed: 3 additions & 1 deletion

@@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args):
             processed_image.save(os.path.join(output_dir, filename))


-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, mimic_scale: float, threshold_enable: bool, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
     is_inpaint = mode == 1
     is_batch = mode == 2

@@ -117,6 +117,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
         n_iter=n_iter,
         steps=steps,
         cfg_scale=cfg_scale,
+        mimic_scale=mimic_scale,
+        threshold_enable=threshold_enable,
         width=width,
         height=height,
         restore_faces=restore_faces,
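The img2img entrypoint picks up two new UI values, mimic_scale and threshold_enable, and forwards them unchanged into the processing object. A hypothetical, simplified sketch of that forwarding pattern (the real img2img() takes far more parameters; names here are illustrative only):

def img2img_sketch(cfg_scale: float, mimic_scale: float, threshold_enable: bool, **other_ui_values):
    # Collect the values exactly as the real entrypoint does, as keyword arguments
    # for the processing object.
    processing_kwargs = dict(
        cfg_scale=cfg_scale,                # ordinary CFG scale, unchanged
        mimic_scale=mimic_scale,            # new: the CFG scale whose value range is "mimicked" when clamping
        threshold_enable=threshold_enable,  # new: toggles dynamic thresholding on/off
        **other_ui_values,
    )
    return processing_kwargs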

modules/processing.py
Lines changed: 13 additions & 4 deletions

@@ -80,7 +80,7 @@ class StableDiffusionProcessing():
     """
     The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
     """
-    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
+    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, mimic_scale: float = 7.5, threshold_enable: bool = False, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
         if sampler_index is not None:
             print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)

@@ -101,6 +101,8 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom
         self.n_iter: int = n_iter
         self.steps: int = steps
         self.cfg_scale: float = cfg_scale
+        self.mimic_scale: float = mimic_scale
+        self.threshold_enable: float = threshold_enable
         self.width: int = width
         self.height: int = height
         self.restore_faces: bool = restore_faces

@@ -129,6 +131,9 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom
            self.seed_resize_from_h = 0
            self.seed_resize_from_w = 0

+        if not threshold_enable:
+            self.mimic_scale = 0
+
        self.scripts = None
        self.script_args = None
        self.all_prompts = None

@@ -250,6 +255,7 @@ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="",
        self.height = p.height
        self.sampler_name = p.sampler_name
        self.cfg_scale = p.cfg_scale
+        self.mimic_scale = p.mimic_scale
        self.steps = p.steps
        self.batch_size = p.batch_size
        self.restore_faces = p.restore_faces

@@ -298,6 +304,7 @@ def js(self):
            "height": self.height,
            "sampler_name": self.sampler_name,
            "cfg_scale": self.cfg_scale,
+            "mimic_scale": self.mimic_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,
            "restore_faces": self.restore_faces,

@@ -443,6 +450,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
        "Clip skip": None if clip_skip <= 1 else clip_skip,
        "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
+        "Mimic CFG scale": None if not p.threshold_enable else p.mimic_scale,
+        "Threshold percentile": None if not p.threshold_enable else opts.dynamic_threshold_percentile,
    }

    generation_params.update(p.extra_generation_params)

@@ -699,11 +708,11 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs

        if not self.enable_hr:
            x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
-            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x), mimic_scale=self.mimic_scale, threshold_enable=self.threshold_enable)
            return samples

        x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
-        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
+        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height), mimic_scale=self.mimic_scale, threshold_enable=self.threshold_enable)

        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]

@@ -905,7 +914,7 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
            x *= self.initial_noise_multiplier

-        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
+        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning, mimic_scale=self.mimic_scale, threshold_enable=self.threshold_enable)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask
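When thresholding is enabled, the two new entries are written into the generation-parameters infotext next to the existing "CFG scale" key. A rough standalone sketch of how that fragment renders (illustrative values only; the real create_infotext emits many more keys and its own formatting):

# Illustrative only; mirrors the dict entries added above, not the full create_infotext().
generation_params = {
    "CFG scale": 11.0,               # existing key
    "Mimic CFG scale": 7.5,          # p.mimic_scale, emitted only when p.threshold_enable
    "Threshold percentile": 0.999,   # opts.dynamic_threshold_percentile
}
print(", ".join(f"{k}: {v}" for k, v in generation_params.items() if v is not None))
# CFG scale: 11.0, Mimic CFG scale: 7.5, Threshold percentile: 0.999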

modules/sd_samplers.py
Lines changed: 50 additions & 11 deletions

@@ -237,7 +237,7 @@ def adjust_steps_if_invalid(self, p, num_steps):

        return num_steps

-    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None, mimic_scale=None, threshold_enable=False):
        steps, t_enc = setup_img2img_steps(p, steps)
        steps = self.adjust_steps_if_invalid(p, steps)
        self.initialize(p)

@@ -259,7 +259,7 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,

        return samples

-    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None, mimic_scale=None, threshold_enable=False):
        self.initialize(p)

        self.init_latent = None

@@ -288,7 +288,39 @@ def __init__(self, model):
        self.init_latent = None
        self.step = 0

-    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
+    def _dynthresh(self, cond, uncond, cond_scale, conds_list, mimic_scale):
+        # uncond shape is (batch, 4, height, width)
+        conds_per_batch = cond.shape[0] / uncond.shape[0]
+        assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches"
+        cond_stacked = cond.reshape((-1, int(conds_per_batch)) + uncond.shape[1:])
+        diff = cond_stacked - uncond.unsqueeze(1)
+        # conds_list shape is (batch, cond, 2)
+        weights = torch.tensor(conds_list).select(2, 1)
+        weights = weights.reshape(*weights.shape, 1, 1, 1).to(diff.device)
+        diff_weighted = (diff * weights).sum(1)
+        dynthresh_target = uncond + diff_weighted * mimic_scale
+
+        dt_flattened = dynthresh_target.flatten(2)
+        dt_means = dt_flattened.mean(dim=2).unsqueeze(2)
+        dt_recentered = dt_flattened - dt_means
+        dt_max = dt_recentered.abs().max(dim=2).values.unsqueeze(2)
+
+        ut = uncond + diff_weighted * cond_scale
+        ut_flattened = ut.flatten(2)
+        ut_means = ut_flattened.mean(dim=2).unsqueeze(2)
+        ut_centered = ut_flattened - ut_means
+
+        ut_q = torch.quantile(ut_centered.abs(), opts.dynamic_threshold_percentile, dim=2).unsqueeze(2)
+        s = torch.maximum(ut_q, dt_max)
+        t_clamped = ut_centered.clamp(-s, s)
+        t_normalized = t_clamped / s
+        t_renormalized = t_normalized * dt_max
+
+        uncentered = t_renormalized + ut_means
+        unflattened = uncentered.unflatten(2, dynthresh_target.shape[2:])
+        return unflattened
+
+    def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, mimic_scale, threshold_enable):
        if state.interrupted or state.skipped:
            raise InterruptedException

@@ -330,11 +362,14 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})

        denoised_uncond = x_out[-uncond.shape[0]:]
-        denoised = torch.clone(denoised_uncond)
+        if threshold_enable:
+            denoised = self._dynthresh(x_out[:-uncond.shape[0]], denoised_uncond, cond_scale, conds_list, mimic_scale)
+        else:
+            denoised = torch.clone(denoised_uncond)

-        for i, conds in enumerate(conds_list):
-            for cond_index, weight in conds:
-                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
+            for i, conds in enumerate(conds_list):
+                for cond_index, weight in conds:
+                    denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

@@ -444,7 +479,7 @@ def initialize(self, p):

        return extra_params_kwargs

-    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None, mimic_scale=None, threshold_enable=False):
        steps, t_enc = setup_img2img_steps(p, steps)

        if p.sampler_noise_scheduler_override:

@@ -477,12 +512,14 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
-            'cond_scale': p.cfg_scale
+            'cond_scale': p.cfg_scale,
+            'mimic_scale': mimic_scale,
+            'threshold_enable': threshold_enable,
        }, disable=False, callback=self.callback_state, **extra_params_kwargs))

        return samples

-    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None, mimic_scale=None, threshold_enable=False):
        steps = steps or p.steps

        if p.sampler_noise_scheduler_override:

@@ -508,7 +545,9 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
-            'cond_scale': p.cfg_scale
+            'cond_scale': p.cfg_scale,
+            'mimic_scale': mimic_scale,
+            'threshold_enable': threshold_enable,
        }, disable=False, callback=self.callback_state, **extra_params_kwargs))

        return samples
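For reference, the clamping logic in _dynthresh can be reproduced outside the web UI. The sketch below is a simplified, standalone illustration (not part of the commit): it assumes a single conditioning per sample with weight 1.0, so the conds_list weighting is dropped, and the percentile is passed in directly instead of being read from opts.dynamic_threshold_percentile. Per sample and per latent channel, it builds the CFG result at both mimic_scale and cfg_scale, clamps the stronger result's outliers at the chosen percentile (never below the mimic result's own peak), and rescales it back into the mimic result's range before restoring the mean.

import torch

def dynthresh_sketch(cond, uncond, cfg_scale, mimic_scale, percentile=0.999):
    # Guidance direction for a single conditioning with weight 1.0.
    diff = cond - uncond                                   # (batch, 4, h, w)
    dt = uncond + diff * mimic_scale                       # what the tamer "mimic" scale would produce
    ut = uncond + diff * cfg_scale                         # the actual, stronger CFG result

    # Work per sample and per latent channel, over the spatial positions.
    dt_flat = dt.flatten(2)                                # (batch, 4, h*w)
    dt_means = dt_flat.mean(dim=2, keepdim=True)
    dt_max = (dt_flat - dt_means).abs().max(dim=2, keepdim=True).values

    ut_flat = ut.flatten(2)
    ut_means = ut_flat.mean(dim=2, keepdim=True)
    ut_centered = ut_flat - ut_means

    # Clamp the strong result's outliers at the chosen percentile (never below the
    # mimic result's own peak), then rescale into the mimic result's range.
    q = torch.quantile(ut_centered.abs(), percentile, dim=2, keepdim=True)
    s = torch.maximum(q, dt_max)
    out = ut_centered.clamp(-s, s) / s * dt_max + ut_means
    return out.unflatten(2, dt.shape[2:])

# Toy usage with random latents.
cond = torch.randn(2, 4, 8, 8)
uncond = torch.randn(2, 4, 8, 8)
print(dynthresh_sketch(cond, uncond, cfg_scale=15.0, mimic_scale=7.5).shape)  # torch.Size([2, 4, 8, 8])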

modules/shared.py
Lines changed: 1 addition & 0 deletions

@@ -369,6 +369,7 @@ def list_samplers():
    "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
    'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
    "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
+    "dynamic_threshold_percentile": OptionInfo(0.999, "For latent fix, the top percentile of latents to clamp (ex: 0.999 means the top 0.1% is clamped)", gr.Slider, {"minimum": 0.9, "maximum": 1.0, "step": 0.0005})
 }))

 options_templates.update(options_section(('interrogate', "Interrogate Options"), {
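A quick standalone illustration (not web UI code) of what the percentile means: with the default 0.999, roughly the largest 0.1% of absolute, mean-centered values per sample and channel are pulled down to the cutoff.

import torch

x = torch.randn(4 * 64 * 64)                 # stand-in for one channel of centered latents
cutoff = torch.quantile(x.abs(), 0.999)      # the value below which 99.9% of magnitudes fall
clamped = x.clamp(-cutoff, cutoff)
print(cutoff.item(), (x.abs() > cutoff).float().mean().item())  # ~0.001 of values exceeded the cutoff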

modules/txt2img.py
Lines changed: 3 additions & 1 deletion

@@ -8,7 +8,7 @@
 from modules.ui import plaintext_to_html


-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
+def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, mimic_scale: float, threshold_enable: bool, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
     p = StableDiffusionProcessingTxt2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,

@@ -27,6 +27,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
         n_iter=n_iter,
         steps=steps,
         cfg_scale=cfg_scale,
+        mimic_scale=mimic_scale,
+        threshold_enable=threshold_enable,
         width=width,
         height=height,
         restore_faces=restore_faces,
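For a script-style caller, the new knobs are just two more keyword arguments on the processing object. A hypothetical, trimmed-down usage sketch (assumes the web UI's modules are importable and a checkpoint is already loaded; a real call generally needs more of the usual arguments):

from modules import shared
from modules.processing import StableDiffusionProcessingTxt2Img, process_images

p = StableDiffusionProcessingTxt2Img(
    sd_model=shared.sd_model,
    sampler_name="Euler a",
    prompt="a photo of a cat",
    steps=20,
    cfg_scale=15.0,          # strong guidance
    mimic_scale=7.5,         # clamp latents toward what CFG 7.5 would allow
    threshold_enable=True,   # without this, mimic_scale is zeroed and ignored
    width=512,
    height=512,
)
processed = process_images(p)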
