@@ -237,7 +237,7 @@ def adjust_steps_if_invalid(self, p, num_steps):
237237
238238 return num_steps
239239
240- def sample_img2img (self , p , x , noise , conditioning , unconditional_conditioning , steps = None , image_conditioning = None ):
240+ def sample_img2img (self , p , x , noise , conditioning , unconditional_conditioning , steps = None , image_conditioning = None , mimic_scale = None , threshold_enable = False ):
241241 steps , t_enc = setup_img2img_steps (p , steps )
242242 steps = self .adjust_steps_if_invalid (p , steps )
243243 self .initialize (p )
@@ -259,7 +259,7 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
259259
260260 return samples
261261
262- def sample (self , p , x , conditioning , unconditional_conditioning , steps = None , image_conditioning = None ):
262+ def sample (self , p , x , conditioning , unconditional_conditioning , steps = None , image_conditioning = None , mimic_scale = None , threshold_enable = False ):
263263 self .initialize (p )
264264
265265 self .init_latent = None
@@ -288,7 +288,39 @@ def __init__(self, model):
288288 self .init_latent = None
289289 self .step = 0
290290
291- def forward (self , x , sigma , uncond , cond , cond_scale , image_cond ):
291+ def _dynthresh (self , cond , uncond , cond_scale , conds_list , mimic_scale ):
292+ # uncond shape is (batch, 4, height, width)
293+ conds_per_batch = cond .shape [0 ] / uncond .shape [0 ]
294+ assert conds_per_batch == int (conds_per_batch ), "Expected # of conds per batch to be constant across batches"
295+ cond_stacked = cond .reshape ((- 1 , int (conds_per_batch )) + uncond .shape [1 :])
296+ diff = cond_stacked - uncond .unsqueeze (1 )
297+ # conds_list shape is (batch, cond, 2)
298+ weights = torch .tensor (conds_list ).select (2 , 1 )
299+ weights = weights .reshape (* weights .shape , 1 , 1 , 1 ).to (diff .device )
300+ diff_weighted = (diff * weights ).sum (1 )
301+ dynthresh_target = uncond + diff_weighted * mimic_scale
302+
303+ dt_flattened = dynthresh_target .flatten (2 )
304+ dt_means = dt_flattened .mean (dim = 2 ).unsqueeze (2 )
305+ dt_recentered = dt_flattened - dt_means
306+ dt_max = dt_recentered .abs ().max (dim = 2 ).values .unsqueeze (2 )
307+
308+ ut = uncond + diff_weighted * cond_scale
309+ ut_flattened = ut .flatten (2 )
310+ ut_means = ut_flattened .mean (dim = 2 ).unsqueeze (2 )
311+ ut_centered = ut_flattened - ut_means
312+
313+ ut_q = torch .quantile (ut_centered .abs (), opts .dynamic_threshold_percentile , dim = 2 ).unsqueeze (2 )
314+ s = torch .maximum (ut_q , dt_max )
315+ t_clamped = ut_centered .clamp (- s , s )
316+ t_normalized = t_clamped / s
317+ t_renormalized = t_normalized * dt_max
318+
319+ uncentered = t_renormalized + ut_means
320+ unflattened = uncentered .unflatten (2 , dynthresh_target .shape [2 :])
321+ return unflattened
322+
323+ def forward (self , x , sigma , uncond , cond , cond_scale , image_cond , mimic_scale , threshold_enable ):
292324 if state .interrupted or state .skipped :
293325 raise InterruptedException
294326
@@ -330,11 +362,14 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
330362 x_out [- uncond .shape [0 ]:] = self .inner_model (x_in [- uncond .shape [0 ]:], sigma_in [- uncond .shape [0 ]:], cond = {"c_crossattn" : [uncond ], "c_concat" : [image_cond_in [- uncond .shape [0 ]:]]})
331363
332364 denoised_uncond = x_out [- uncond .shape [0 ]:]
333- denoised = torch .clone (denoised_uncond )
365+ if threshold_enable :
366+ denoised = self ._dynthresh (x_out [:- uncond .shape [0 ]], denoised_uncond , cond_scale , conds_list , mimic_scale )
367+ else :
368+ denoised = torch .clone (denoised_uncond )
334369
335- for i , conds in enumerate (conds_list ):
336- for cond_index , weight in conds :
337- denoised [i ] += (x_out [cond_index ] - denoised_uncond [i ]) * (weight * cond_scale )
370+ for i , conds in enumerate (conds_list ):
371+ for cond_index , weight in conds :
372+ denoised [i ] += (x_out [cond_index ] - denoised_uncond [i ]) * (weight * cond_scale )
338373
339374 if self .mask is not None :
340375 denoised = self .init_latent * self .mask + self .nmask * denoised
@@ -444,7 +479,7 @@ def initialize(self, p):
444479
445480 return extra_params_kwargs
446481
447- def sample_img2img (self , p , x , noise , conditioning , unconditional_conditioning , steps = None , image_conditioning = None ):
482+ def sample_img2img (self , p , x , noise , conditioning , unconditional_conditioning , steps = None , image_conditioning = None , mimic_scale = None , threshold_enable = False ):
448483 steps , t_enc = setup_img2img_steps (p , steps )
449484
450485 if p .sampler_noise_scheduler_override :
@@ -477,12 +512,14 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
477512 'cond' : conditioning ,
478513 'image_cond' : image_conditioning ,
479514 'uncond' : unconditional_conditioning ,
480- 'cond_scale' : p .cfg_scale
515+ 'cond_scale' : p .cfg_scale ,
516+ 'mimic_scale' : mimic_scale ,
517+ 'threshold_enable' : threshold_enable ,
481518 }, disable = False , callback = self .callback_state , ** extra_params_kwargs ))
482519
483520 return samples
484521
485- def sample (self , p , x , conditioning , unconditional_conditioning , steps = None , image_conditioning = None ):
522+ def sample (self , p , x , conditioning , unconditional_conditioning , steps = None , image_conditioning = None , mimic_scale = None , threshold_enable = False ):
486523 steps = steps or p .steps
487524
488525 if p .sampler_noise_scheduler_override :
@@ -508,7 +545,9 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
508545 'cond' : conditioning ,
509546 'image_cond' : image_conditioning ,
510547 'uncond' : unconditional_conditioning ,
511- 'cond_scale' : p .cfg_scale
548+ 'cond_scale' : p .cfg_scale ,
549+ 'mimic_scale' : mimic_scale ,
550+ 'threshold_enable' : threshold_enable ,
512551 }, disable = False , callback = self .callback_state , ** extra_params_kwargs ))
513552
514553 return samples
0 commit comments