Mac compatibility! #170
base: main
Changes from 12 commits
c339bc0
205343c
5f06f73
ea1fad0
59972e9
39cf8c1
749071e
4285606
173e66e
c0df480
dd87aaa
763c9f5
2a5edc0
```diff
@@ -1,8 +1,10 @@
+from tkinter import W
+
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
-import einops
 import torch.nn as nn
 import torch.nn.functional as F
+import einops
 import numpy as np

 from diffusers.loaders import FromOriginalModelMixin
```
```diff
@@ -17,17 +19,17 @@
 from diffusers_helper.dit_common import LayerNorm
 from diffusers_helper.utils import zero_module


 enabled_backends = []

-if torch.backends.cuda.flash_sdp_enabled():
-    enabled_backends.append("flash")
-if torch.backends.cuda.math_sdp_enabled():
-    enabled_backends.append("math")
-if torch.backends.cuda.mem_efficient_sdp_enabled():
-    enabled_backends.append("mem_efficient")
-if torch.backends.cuda.cudnn_sdp_enabled():
-    enabled_backends.append("cudnn")
+if torch.cuda.is_available():
+    if torch.backends.cuda.flash_sdp_enabled():
+        enabled_backends.append("flash")
+    if torch.backends.cuda.math_sdp_enabled():
+        enabled_backends.append("math")
+    if torch.backends.cuda.mem_efficient_sdp_enabled():
+        enabled_backends.append("mem_efficient")
+    if torch.backends.cuda.cudnn_sdp_enabled():
+        enabled_backends.append("cudnn")

 print("Currently enabled native sdp backends:", enabled_backends)
```
```diff
@@ -60,6 +62,37 @@

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+class AvgPool3dMPS(nn.Module):
+    def __init__(self, kernel_size, stride=None, padding=0):
+        super().__init__()
+        if isinstance(kernel_size, int):
+            kernel_size = (kernel_size, kernel_size, kernel_size)
+        self.kernel_size = kernel_size
+        self.stride = stride or kernel_size
+        self.padding = padding
+
+        # register a buffer that holds the kernel shape only
+        self.register_buffer("ones_kernel", None, persistent=False)
+
+    def forward(self, x):
+        B, C, D, H, W = x.shape
+        kD, kH, kW = self.kernel_size
+        kernel_shape = (C, 1, kD, kH, kW)
+
+        # lazily initialize or resize if needed
+        if (self.ones_kernel is None or self.ones_kernel.shape != kernel_shape or self.ones_kernel.device != x.device):
+            kernel = torch.ones(kernel_shape, dtype=x.dtype, device=x.device) / (kD * kH * kW)
+            self.ones_kernel = kernel
+        else:
+            kernel = self.ones_kernel
+
+        return F.conv3d(
+            x, kernel, bias=None,
+            stride=self.stride,
+            padding=self.padding,
+            groups=C
+        )
+
+
 def pad_for_3d_conv(x, kernel_size):
     b, c, t, h, w = x.shape
```
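For reviewers without an Apple Silicon machine, the trick behind `AvgPool3dMPS` (average pooling expressed as a grouped 3D convolution with a constant kernel) can be sanity-checked on CPU. A minimal sketch, not part of this PR; the shapes and variable names below are made up for illustration:

```python
import torch
import torch.nn.functional as F

# (B, C, D, H, W) input and a 2x2x2 pooling window -- arbitrary example values
x = torch.randn(1, 4, 8, 16, 16)
k = (2, 2, 2)

# depthwise conv with a constant 1/(kD*kH*kW) kernel is equivalent to average pooling
ones = torch.ones(x.shape[1], 1, *k, dtype=x.dtype) / (k[0] * k[1] * k[2])
pooled_conv = F.conv3d(x, ones, bias=None, stride=k, groups=x.shape[1])
pooled_ref = F.avg_pool3d(x, k, stride=k)

print(torch.allclose(pooled_conv, pooled_ref, atol=1e-6))  # expected: True
```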
```diff
@@ -76,7 +109,7 @@ def center_down_sample_3d(x, kernel_size):
     # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
     # xc = xp[cp]
     # return xc
-    return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
+    return AvgPool3dMPS(kernel_size, stride=kernel_size)(x) if torch.backends.mps.is_available() else torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)


 def get_cu_seqlens(text_mask, img_len):
```
```diff
@@ -104,6 +137,30 @@ def apply_rotary_emb_transposed(x, freqs_cis):
     out = out.to(x)
     return out

+
+def chunked_attention_bfloat16(q, k, v, chunk_size=64):
+    B, H, T_q, D = q.shape
+    T_kv = k.shape[2]
+    output_chunks = []
+
+    for start in range(0, T_q, chunk_size):
+        end = min(start + chunk_size, T_q)
+        q_chunk = q[:, :, start:end, :]
+
+        attn_scores = torch.matmul(q_chunk, k.transpose(-2, -1)) / (D ** 0.5)
+        attn_probs = torch.softmax(attn_scores.float(), dim=-1).to(torch.bfloat16)  # force softmax to fp32, then back
+        attn_out = torch.matmul(attn_probs, v)
+
+        output_chunks.append(attn_out)
+
+    return torch.cat(output_chunks, dim=2)
+
+
+def mps_attn_varlen_func(q, k, v, chunk_size=128):
+    return chunked_attention_bfloat16(
+        q.transpose(1, 2),
+        k.transpose(1, 2),
+        v.transpose(1, 2),
+        chunk_size=chunk_size
+    ).transpose(1, 2)
+

 def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
     if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
```
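The chunking math itself can also be checked without an MPS device. The sketch below re-implements the loop in float32 on CPU and compares it against `scaled_dot_product_attention`; it is illustrative only, with a made-up helper name, and it deliberately skips the bfloat16 round-trip that the PR's `chunked_attention_bfloat16` performs (so on MPS the outputs will differ slightly from SDPA):

```python
import torch
import torch.nn.functional as F

def chunked_attention_fp32(q, k, v, chunk_size=64):
    # q, k, v: (B, H, T, D); process the query dimension in chunks to bound peak memory
    D = q.shape[-1]
    outs = []
    for start in range(0, q.shape[2], chunk_size):
        q_chunk = q[:, :, start:start + chunk_size, :]
        scores = torch.matmul(q_chunk, k.transpose(-2, -1)) / (D ** 0.5)
        outs.append(torch.matmul(torch.softmax(scores, dim=-1), v))
    return torch.cat(outs, dim=2)

q, k, v = (torch.randn(2, 3, 200, 32) for _ in range(3))
ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(chunked_attention_fp32(q, k, v), ref, atol=1e-4))  # expected: True
```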
```diff
@@ -119,7 +176,7 @@ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
             x = xformers_attn_func(q, k, v)
             return x

-        x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
+        x = mps_attn_varlen_func(q, k, v, chunk_size=64) if torch.backends.mps.is_available() else torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
         return x

     batch_size = q.shape[0]
```
```diff
@@ -169,7 +226,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, i
         key = torch.cat([key, encoder_key], dim=1)
         value = torch.cat([value, encoder_value], dim=1)

-        hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
+        hidden_states = mps_attn_varlen_func(query, key, value, chunk_size=64) if torch.backends.mps.is_available() else attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
         hidden_states = hidden_states.flatten(-2)

         txt_length = encoder_hidden_states.shape[1]
```
Review comment: `torch.mps.empty_cache()` should be called when MPS is available.
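One possible way to act on that suggestion, sketched under the assumption that the cache flush lives in a small helper (`empty_device_cache` is a hypothetical name, not code from this PR):

```python
import torch

def empty_device_cache():
    # Hypothetical helper: flush whichever backend's cache is actually present,
    # so the same call works on MPS, CUDA, or CPU-only machines.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
```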