diff --git a/scripts/pag.py b/scripts/pag.py
index bacffe9..1e96a26 100644
--- a/scripts/pag.py
+++ b/scripts/pag.py
@@ -333,11 +333,11 @@ def pag_pre_hook(module, input, kwargs, output):
         # oops we forgot to unhook
         return
 
-    batch_size, seq_len, inner_dim = output.shape
-    identity = torch.eye(seq_len).expand(batch_size, -1, -1).to(shared.device)
-
     # get the last to_v output and save it
     last_to_v = getattr(module, 'pag_last_to_v', None)
+
+    batch_size, seq_len, inner_dim = output.shape
+    identity = torch.eye(seq_len, dtype=last_to_v.dtype, device=shared.device).expand(batch_size, -1, -1)
     if last_to_v is not None:
         new_output = torch.einsum('bij,bjk->bik', identity, last_to_v)
         return new_output
@@ -836,4 +836,4 @@ def _remove_child_hooks(
     _remove_child_hooks(module, hook_fn_name)
 
     # Remove hooks from the target module
-    _remove_hooks(module, hook_fn_name)
\ No newline at end of file
+    _remove_hooks(module, hook_fn_name)
diff --git a/scripts/t2i_zero.py b/scripts/t2i_zero.py
index d768916..35ed728 100644
--- a/scripts/t2i_zero.py
+++ b/scripts/t2i_zero.py
@@ -407,15 +407,16 @@ def ready_hijack_forward(self, alpha, width, height, ema_factor, step_start, ste
         plot_num = 0
         for module in cross_attn_modules:
             self.add_field_cross_attn_modules(module, 't2i0_last_attn_map', None)
-            self.add_field_cross_attn_modules(module, 't2i0_step', torch.tensor([-1]).to(device=shared.device))
-            self.add_field_cross_attn_modules(module, 't2i0_step_start', torch.tensor([step_start]).to(device=shared.device))
-            self.add_field_cross_attn_modules(module, 't2i0_step_end', torch.tensor([step_end]).to(device=shared.device))
+            self.add_field_cross_attn_modules(module, 't2i0_step', int(-1))
+            self.add_field_cross_attn_modules(module, 't2i0_step_start', int(step_start))
+            self.add_field_cross_attn_modules(module, 't2i0_step_end', int(step_end))
             self.add_field_cross_attn_modules(module, 't2i0_ema', None)
-            self.add_field_cross_attn_modules(module, 't2i0_ema_factor', torch.tensor([ema_factor]).to(device=shared.device, dtype=torch.float16))
-            self.add_field_cross_attn_modules(module, 'plot_num', torch.tensor([plot_num]).to(device=shared.device))
+            self.add_field_cross_attn_modules(module, 't2i0_ema_factor', float(ema_factor))
+            self.add_field_cross_attn_modules(module, 'plot_num', int(plot_num))
             self.add_field_cross_attn_modules(module, 't2i0_to_v_map', None)
             self.add_field_cross_attn_modules(module.to_v, 't2i0_parent_module', [module])
-            self.add_field_cross_attn_modules(module, 't2i0_token_count', torch.tensor(token_count).to(device=shared.device, dtype=torch.int64))
+            self.add_field_cross_attn_modules(module, 't2i0_token_count', int(token_count))
+            self.add_field_cross_attn_modules(module, 'gaussian_blur', GaussianBlur(kernel_size=3, sigma=1).to(device=shared.device))
             if tokens is not None:
                 self.add_field_cross_attn_modules(module, 't2i0_tokens', torch.tensor(tokens).to(device=shared.device, dtype=torch.int64))
             else:
@@ -476,9 +477,9 @@ def cross_token_non_maximum_suppression(module, input, kwargs, output):
     attention_map = output.view(batch_size, downscale_height, downscale_width, inner_dim)
 
     if token_indices is None:
-        selected_tokens = torch.tensor(list(range(1, token_count.item())))
+        selected_tokens = torch.arange(1, token_count, device=output.device)
     elif len(token_indices) == 0:
-        selected_tokens = torch.tensor(list(range(1, token_count.item())))
+        selected_tokens = torch.arange(1, token_count, device=output.device)
     else:
         selected_tokens = module.t2i0_tokens
 
@@ -490,7 +491,7 @@ def cross_token_non_maximum_suppression(module, input, kwargs, output):
 
     # Extract and process the selected attention maps
    # GaussianBlur expects the input [..., C, H, W]
-    gaussian_blur = GaussianBlur(kernel_size=3, sigma=1)
+    gaussian_blur = module.gaussian_blur
     AC = AC.permute(0, 3, 1, 2)
     AC = gaussian_blur(AC) # Applying Gaussian smoothing
     AC = AC.permute(0, 2, 3, 1)
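
For reference, a minimal standalone sketch of the pattern the last two hunks move to: the torchvision GaussianBlur transform is built once when the cross-attention fields are attached and the cached instance is reused inside the hook, rather than being re-instantiated on every call. The ready_module/smoothing_hook names and the stand-in module below are illustrative assumptions; only the module.gaussian_blur field, the GaussianBlur(kernel_size=3, sigma=1) construction, and the permute/blur/permute sequence come from the diff.

    # Illustrative sketch only -- not code from this repository.
    import torch
    from torchvision.transforms import GaussianBlur

    class FakeCrossAttn(torch.nn.Module):
        """Stand-in for a cross-attention module that hooks attach fields to."""
        def forward(self, x):
            return x

    def ready_module(module, device="cpu"):
        # Setup pass (analogous to ready_hijack_forward): cache the transform on the module once.
        module.gaussian_blur = GaussianBlur(kernel_size=3, sigma=1).to(device)

    def smoothing_hook(module, attention_map):
        # Per-step hook: reuse the cached transform; GaussianBlur expects [..., C, H, W].
        gaussian_blur = module.gaussian_blur
        ac = attention_map.permute(0, 3, 1, 2)   # [B, H, W, C] -> [B, C, H, W]
        ac = gaussian_blur(ac)                   # Gaussian smoothing
        return ac.permute(0, 2, 3, 1)            # back to [B, H, W, C]

    module = FakeCrossAttn()
    ready_module(module)
    smoothed = smoothing_hook(module, torch.rand(1, 16, 16, 8))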