Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions scripts/pag.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,11 +333,11 @@ def pag_pre_hook(module, input, kwargs, output):
# oops we forgot to unhook
return

batch_size, seq_len, inner_dim = output.shape
identity = torch.eye(seq_len).expand(batch_size, -1, -1).to(shared.device)

# get the last to_v output and save it
last_to_v = getattr(module, 'pag_last_to_v', None)

batch_size, seq_len, inner_dim = output.shape
identity = torch.eye(seq_len, dtype=last_to_v.dtype, device=shared.device).expand(batch_size, -1, -1)
if last_to_v is not None:
new_output = torch.einsum('bij,bjk->bik', identity, last_to_v)
return new_output
Expand Down Expand Up @@ -836,4 +836,4 @@ def _remove_child_hooks(
_remove_child_hooks(module, hook_fn_name)

# Remove hooks from the target module
_remove_hooks(module, hook_fn_name)
_remove_hooks(module, hook_fn_name)
19 changes: 10 additions & 9 deletions scripts/t2i_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,15 +407,16 @@ def ready_hijack_forward(self, alpha, width, height, ema_factor, step_start, ste
plot_num = 0
for module in cross_attn_modules:
self.add_field_cross_attn_modules(module, 't2i0_last_attn_map', None)
self.add_field_cross_attn_modules(module, 't2i0_step', torch.tensor([-1]).to(device=shared.device))
self.add_field_cross_attn_modules(module, 't2i0_step_start', torch.tensor([step_start]).to(device=shared.device))
self.add_field_cross_attn_modules(module, 't2i0_step_end', torch.tensor([step_end]).to(device=shared.device))
self.add_field_cross_attn_modules(module, 't2i0_step', int(-1))
self.add_field_cross_attn_modules(module, 't2i0_step_start', int(step_start))
self.add_field_cross_attn_modules(module, 't2i0_step_end', int(step_end))
self.add_field_cross_attn_modules(module, 't2i0_ema', None)
self.add_field_cross_attn_modules(module, 't2i0_ema_factor', torch.tensor([ema_factor]).to(device=shared.device, dtype=torch.float16))
self.add_field_cross_attn_modules(module, 'plot_num', torch.tensor([plot_num]).to(device=shared.device))
self.add_field_cross_attn_modules(module, 't2i0_ema_factor', float(ema_factor))
self.add_field_cross_attn_modules(module, 'plot_num', int(plot_num))
self.add_field_cross_attn_modules(module, 't2i0_to_v_map', None)
self.add_field_cross_attn_modules(module.to_v, 't2i0_parent_module', [module])
self.add_field_cross_attn_modules(module, 't2i0_token_count', torch.tensor(token_count).to(device=shared.device, dtype=torch.int64))
self.add_field_cross_attn_modules(module, 't2i0_token_count', int(token_count))
self.add_field_cross_attn_modules(module, 'gaussian_blur', GaussianBlur(kernel_size=3, sigma=1).to(device=shared.device))
if tokens is not None:
self.add_field_cross_attn_modules(module, 't2i0_tokens', torch.tensor(tokens).to(device=shared.device, dtype=torch.int64))
else:
Expand Down Expand Up @@ -476,9 +477,9 @@ def cross_token_non_maximum_suppression(module, input, kwargs, output):
attention_map = output.view(batch_size, downscale_height, downscale_width, inner_dim)

if token_indices is None:
selected_tokens = torch.tensor(list(range(1, token_count.item())))
selected_tokens = torch.arange(1, token_count, device=output.device)
elif len(token_indices) == 0:
selected_tokens = torch.tensor(list(range(1, token_count.item())))
selected_tokens = torch.arange(1, token_count, device=output.device)
else:
selected_tokens = module.t2i0_tokens

Expand All @@ -490,7 +491,7 @@ def cross_token_non_maximum_suppression(module, input, kwargs, output):

# Extract and process the selected attention maps
# GaussianBlur expects the input [..., C, H, W]
gaussian_blur = GaussianBlur(kernel_size=3, sigma=1)
gaussian_blur = module.gaussian_blur
AC = AC.permute(0, 3, 1, 2)
AC = gaussian_blur(AC) # Applying Gaussian smoothing
AC = AC.permute(0, 2, 3, 1)
Expand Down