38 changes: 20 additions & 18 deletions demo_gradio.py
@@ -2,6 +2,9 @@

import os

# Set the MPS fallback environment variable so operations not implemented on MPS fall back to the CPU
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download')))

import gradio as gr
@@ -20,7 +23,7 @@
from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, state_dict_weighted_merge, state_dict_offset_merge, generate_timestamp
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
from diffusers_helper.memory import cpu, gpu, mps, get_cuda_free_memory_gb, get_mps_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
from diffusers_helper.thread_utils import AsyncStream, async_run
from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
from transformers import SiglipImageProcessor, SiglipVisionModel
@@ -40,7 +43,7 @@

print(args)

free_mem_gb = get_cuda_free_memory_gb(gpu)
free_mem_gb = get_mps_free_memory_gb(mps) if torch.backends.mps.is_available() else get_cuda_free_memory_gb(gpu)
high_vram = free_mem_gb > 60

print(f'Free VRAM {free_mem_gb} GB')
@@ -84,14 +87,14 @@

if not high_vram:
# DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster
DynamicSwapInstaller.install_model(transformer, device=gpu)
DynamicSwapInstaller.install_model(text_encoder, device=gpu)
DynamicSwapInstaller.install_model(transformer, device=mps if torch.backends.mps.is_available() else gpu)
DynamicSwapInstaller.install_model(text_encoder, device=mps if torch.backends.mps.is_available() else gpu)
else:
text_encoder.to(gpu)
text_encoder_2.to(gpu)
image_encoder.to(gpu)
vae.to(gpu)
transformer.to(gpu)
text_encoder.to(mps if torch.backends.mps.is_available() else gpu)
text_encoder_2.to(mps if torch.backends.mps.is_available() else gpu)
image_encoder.to(mps if torch.backends.mps.is_available() else gpu)
vae.to(mps if torch.backends.mps.is_available() else gpu)
transformer.to(mps if torch.backends.mps.is_available() else gpu)

stream = AsyncStream()

@@ -120,8 +123,8 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))

if not high_vram:
fake_diffusers_current_device(text_encoder, gpu) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
load_model_as_complete(text_encoder_2, target_device=gpu)
fake_diffusers_current_device(text_encoder, mps if torch.backends.mps.is_available() else gpu) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
load_model_as_complete(text_encoder_2, target_device=mps if torch.backends.mps.is_available() else gpu)

llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

@@ -151,7 +154,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))

if not high_vram:
load_model_as_complete(vae, target_device=gpu)
load_model_as_complete(vae, target_device=mps if torch.backends.mps.is_available() else gpu)

start_latent = vae_encode(input_image_pt, vae)

@@ -160,7 +163,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))

if not high_vram:
load_model_as_complete(image_encoder, target_device=gpu)
load_model_as_complete(image_encoder, target_device=mps if torch.backends.mps.is_available() else gpu)

image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
@@ -213,7 +216,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind

if not high_vram:
unload_complete_models()
move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
move_model_to_device_with_memory_preservation(transformer, target_device=mps if torch.backends.mps.is_available() else gpu, preserved_memory_gb=gpu_memory_preservation)

if use_teacache:
transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
@@ -256,7 +259,7 @@ def callback(d):
negative_prompt_embeds=llama_vec_n,
negative_prompt_embeds_mask=llama_attention_mask_n,
negative_prompt_poolers=clip_l_pooler_n,
device=gpu,
device=mps if torch.backends.mps.is_available() else gpu,
dtype=torch.bfloat16,
image_embeddings=image_encoder_last_hidden_state,
latent_indices=latent_indices,
@@ -276,8 +279,8 @@ def callback(d):
history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)

if not high_vram:
offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
load_model_as_complete(vae, target_device=gpu)
offload_model_from_device_for_memory_preservation(transformer, target_device=mps if torch.backends.mps.is_available() else gpu, preserved_memory_gb=8)
load_model_as_complete(vae, target_device=mps if torch.backends.mps.is_available() else gpu)

real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]

@@ -395,7 +398,6 @@ def end_process():
start_button.click(fn=process, inputs=ips, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button, end_button])
end_button.click(fn=end_process)


block.launch(
server_name=args.server,
server_port=args.port,
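Editorial note on the demo_gradio.py changes above, offered as a hedged sketch rather than part of the diff: the expression mps if torch.backends.mps.is_available() else gpu is repeated at every call site. Resolving the device once would keep the call sites identical to the CUDA-only version. This assumes only the cpu, gpu, and mps objects already exported by diffusers_helper.memory.

import torch
from diffusers_helper.memory import cpu, gpu, mps

# Hypothetical consolidation, not in the PR: pick the compute device once.
if torch.backends.mps.is_available():
    device = mps
elif torch.cuda.is_available():
    device = gpu
else:
    device = cpu

# Call sites would then read, for example:
#   text_encoder.to(device)
#   load_model_as_complete(vae, target_device=device)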
46 changes: 38 additions & 8 deletions diffusers_helper/memory.py
@@ -1,11 +1,12 @@
# By lllyasviel


import torch


# Device management
cpu = torch.device('cpu')
gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
mps = torch.device('mps') if torch.backends.mps.is_available() else None
gpu = torch.device(f'cuda:{torch.cuda.current_device()}') if torch.cuda.is_available() else None
gpu_complete_modules = []


@@ -81,35 +82,61 @@ def get_cuda_free_memory_gb(device=None):
return bytes_total_available / (1024 ** 3)


def get_mps_free_memory_gb(device=None):
if device is None:
device = mps

return (torch.mps.recommended_max_memory() - torch.mps.driver_allocated_memory()) / (1024 ** 3)


def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')

for m in model.modules():
if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
if target_device == gpu and get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
torch.cuda.empty_cache()
return


if target_device == mps and get_mps_free_memory_gb(target_device) <= preserved_memory_gb:
torch.mps.empty_cache()
return
if hasattr(m, 'weight'):
m.to(device=target_device)

model.to(device=target_device)
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.synchronize()
if hasattr(torch, 'mps') and torch.mps.is_available():
torch.mps.synchronize()


model.to(device=target_device)
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.mps.empty_cache()

Reviewer: torch.mps.empty_cache should be called when MPS is available.
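A minimal sketch of the guarded cleanup this comment asks for (editorial, not part of the diff): flush each backend's cache only when that backend is actually available.

model.to(device=target_device)
if torch.cuda.is_available():
    torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()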

return


def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')

for m in model.modules():
if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
if target_device == gpu and get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
torch.cuda.empty_cache()
return

if target_device == mps and get_mps_free_memory_gb(target_device) >= preserved_memory_gb:
torch.mps.empty_cache()
return

if hasattr(m, 'weight'):
m.to(device=cpu)

model.to(device=cpu)
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
return


@@ -119,7 +146,10 @@ def unload_complete_models(*args):
print(f'Unloaded {m.__class__.__name__} as complete.')

gpu_complete_modules.clear()
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
return


83 changes: 70 additions & 13 deletions diffusers_helper/models/hunyuan_video_packed.py
@@ -1,8 +1,10 @@
from tkinter import W

Reviewer: This breaks compatibility with my Homebrew-installed Python. Also, I don't think the import is being used; I didn't run into any issues when I removed the line.

Author: Yep! My bad - I can remove this.

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import einops
import torch.nn as nn
import torch.nn.functional as F
import einops
import numpy as np

from diffusers.loaders import FromOriginalModelMixin
@@ -17,17 +19,17 @@
from diffusers_helper.dit_common import LayerNorm
from diffusers_helper.utils import zero_module


enabled_backends = []

if torch.backends.cuda.flash_sdp_enabled():
enabled_backends.append("flash")
if torch.backends.cuda.math_sdp_enabled():
enabled_backends.append("math")
if torch.backends.cuda.mem_efficient_sdp_enabled():
enabled_backends.append("mem_efficient")
if torch.backends.cuda.cudnn_sdp_enabled():
enabled_backends.append("cudnn")
if torch.cuda.is_available():
if torch.backends.cuda.flash_sdp_enabled():
enabled_backends.append("flash")
if torch.backends.cuda.math_sdp_enabled():
enabled_backends.append("math")
if torch.backends.cuda.mem_efficient_sdp_enabled():
enabled_backends.append("mem_efficient")
if torch.backends.cuda.cudnn_sdp_enabled():
enabled_backends.append("cudnn")

print("Currently enabled native sdp backends:", enabled_backends)

@@ -60,6 +62,37 @@

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

class AvgPool3dMPS(nn.Module):
def __init__(self, kernel_size, stride=None, padding=0):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
self.kernel_size = kernel_size
self.stride = stride or kernel_size
self.padding = padding

# register a buffer that holds the kernel shape only
self.register_buffer("ones_kernel", None, persistent=False)

def forward(self, x):
B, C, D, H, W = x.shape
kD, kH, kW = self.kernel_size
kernel_shape = (C, 1, kD, kH, kW)

# lazily initialize or resize if needed
if (self.ones_kernel is None or self.ones_kernel.shape != kernel_shape or self.ones_kernel.device != x.device):
kernel = torch.ones(kernel_shape, dtype=x.dtype, device=x.device) / (kD * kH * kW)
self.ones_kernel = kernel
else:
kernel = self.ones_kernel

return F.conv3d(
x, kernel, bias=None,
stride=self.stride,
padding=self.padding,
groups=C
)


def pad_for_3d_conv(x, kernel_size):
b, c, t, h, w = x.shape
@@ -76,7 +109,7 @@ def center_down_sample_3d(x, kernel_size):
# xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
# xc = xp[cp]
# return xc
return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
return AvgPool3dMPS(kernel_size, stride=kernel_size)(x) if torch.backends.mps.is_available() else torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)


def get_cu_seqlens(text_mask, img_len):
@@ -104,6 +137,30 @@ def apply_rotary_emb_transposed(x, freqs_cis):
out = out.to(x)
return out

def chunked_attention_bfloat16(q, k, v, chunk_size=64):
B, H, T_q, D = q.shape
T_kv = k.shape[2]
output_chunks = []

for start in range(0, T_q, chunk_size):
end = min(start + chunk_size, T_q)
q_chunk = q[:, :, start:end, :]

attn_scores = torch.matmul(q_chunk, k.transpose(-2, -1)) / (D ** 0.5)
attn_probs = torch.softmax(attn_scores.float(), dim=-1).to(torch.bfloat16) # force softmax to fp32, then back
attn_out = torch.matmul(attn_probs, v)

output_chunks.append(attn_out)

return torch.cat(output_chunks, dim=2)

def mps_attn_varlen_func(q, k, v, chunk_size=128):
return chunked_attention_bfloat16(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
chunk_size=chunk_size
).transpose(1, 2)

def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
@@ -119,7 +176,7 @@ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seq
x = xformers_attn_func(q, k, v)
return x

x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
x = mps_attn_varlen_func(q, k, v, chunk_size=64) if torch.backends.mps.is_available() else torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
return x

batch_size = q.shape[0]
@@ -169,7 +226,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, i
key = torch.cat([key, encoder_key], dim=1)
value = torch.cat([value, encoder_value], dim=1)

hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
hidden_states = mps_attn_varlen_func(query, key, value, chunk_size=64) if torch.backends.mps.is_available() else attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
hidden_states = hidden_states.flatten(-2)

txt_length = encoder_hidden_states.shape[1]
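An editorial sketch (not part of the diff) of a quick CPU-side sanity check for the chunked attention path added above, assuming chunked_attention_bfloat16 as defined in this file is in scope. It compares against PyTorch's reference attention; the deviation should be small but nonzero because the chunked path rounds through bfloat16.

import torch
import torch.nn.functional as F

# Small random problem in the (batch, heads, tokens, head_dim) layout the helper expects.
B, H, T, D = 1, 4, 100, 64
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

ref = F.scaled_dot_product_attention(q, k, v)
out = chunked_attention_bfloat16(q.bfloat16(), k.bfloat16(), v.bfloat16(), chunk_size=64)

print((ref - out.float()).abs().max())  # expect an error on the order of 1e-2 from bfloat16 rounding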
20 changes: 14 additions & 6 deletions diffusers_helper/utils.py
@@ -318,12 +318,20 @@ def add_tensors_with_padding(tensor1, tensor2):


def print_free_mem():
torch.cuda.empty_cache()
free_mem, total_mem = torch.cuda.mem_get_info(0)
free_mem_mb = free_mem / (1024 ** 2)
total_mem_mb = total_mem / (1024 ** 2)
print(f"Free memory: {free_mem_mb:.2f} MB")
print(f"Total memory: {total_mem_mb:.2f} MB")
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_mem, total_mem = torch.cuda.mem_get_info(0)
free_mem_mb = free_mem / (1024 ** 2)
total_mem_mb = total_mem / (1024 ** 2)
print(f"Free memory: {free_mem_mb:.2f} MB")
print(f"Total memory: {total_mem_mb:.2f} MB")

if torch.backends.mps.is_available():
torch.mps.empty_cache()
free_mem_mb = (torch.mps.recommended_max_memory() - torch.mps.driver_allocated_memory()) / (1024 ** 2)
total_mem_mb = torch.mps.recommended_max_memory() / (1024 ** 2)
print(f"Free memory: {free_mem_mb:.2f} MB")
print(f"Total memory: {total_mem_mb:.2f} MB")
return

