74 changes: 50 additions & 24 deletions demo_gradio.py
@@ -20,7 +20,8 @@
from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, state_dict_weighted_merge, state_dict_offset_merge, generate_timestamp
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
import diffusers_helper.memory as memory_helper
from diffusers_helper.memory import cpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete, get_available_gpus, set_gpu_device
from diffusers_helper.thread_utils import AsyncStream, async_run
from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
from transformers import SiglipImageProcessor, SiglipVisionModel
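Note on the import change above: switching from a direct import of gpu to import diffusers_helper.memory as memory_helper is what makes the runtime GPU selection work. A from-import copies the attribute's value at import time (now None), so a later set_gpu_device() call would not be visible in this file; attribute access through the module alias always reads the current value. A minimal sketch of the difference, using a hypothetical toy_config module rather than the real one:

# Illustrative only, not part of this PR. Assume a module toy_config that
# defines a module-level attribute and a setter, analogous to memory.py:
#
#     device = None
#
#     def set_device(name):
#         global device
#         device = name
#
from toy_config import device    # copies the current value (None) at import time
import toy_config                # keeps a reference to the module object

toy_config.set_device("cuda:1")
print(device)                    # None   - the from-import froze the old binding
print(toy_config.device)         # cuda:1 - attribute lookup sees the update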
@@ -40,10 +41,39 @@

print(args)

free_mem_gb = get_cuda_free_memory_gb(gpu)
available_gpus = get_available_gpus()

if not available_gpus:
    print("Error: No CUDA-enabled GPUs found. Exiting.")
    exit()

print("Available GPUs:")
for gpu_info in available_gpus:
    print(f" [{gpu_info['index']}] {gpu_info['name']} - Free: {gpu_info['free_memory']:.2f} GB / Total: {gpu_info['total_memory']:.2f} GB")

selected_gpu_index = -1
while True:
    try:
        user_input = input(f"Select GPU index [0-{len(available_gpus) - 1}]: ")
        selected_gpu_index = int(user_input)
        if 0 <= selected_gpu_index < len(available_gpus):
            break
        else:
            print(f"Invalid index. Please enter a number between 0 and {len(available_gpus) - 1}.")
    except ValueError:
        print("Invalid input. Please enter a number.")
    except EOFError:
        print("\nNo GPU selected. Exiting.")
        exit()

set_gpu_device(selected_gpu_index)
print(f"Using GPU: [{selected_gpu_index}] {available_gpus[selected_gpu_index]['name']}")


free_mem_gb = get_cuda_free_memory_gb(memory_helper.gpu)
high_vram = free_mem_gb > 60

print(f'Free VRAM {free_mem_gb} GB')
print(f'Free VRAM {free_mem_gb:.2f} GB on {memory_helper.gpu}')
print(f'High-VRAM Mode: {high_vram}')

text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
@@ -84,14 +114,14 @@

if not high_vram:
    # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster
    DynamicSwapInstaller.install_model(transformer, device=gpu)
    DynamicSwapInstaller.install_model(text_encoder, device=gpu)
    DynamicSwapInstaller.install_model(transformer, device=memory_helper.gpu)
    DynamicSwapInstaller.install_model(text_encoder, device=memory_helper.gpu)
else:
    text_encoder.to(gpu)
    text_encoder_2.to(gpu)
    image_encoder.to(gpu)
    vae.to(gpu)
    transformer.to(gpu)
    text_encoder.to(memory_helper.gpu)
    text_encoder_2.to(memory_helper.gpu)
    image_encoder.to(memory_helper.gpu)
    vae.to(memory_helper.gpu)
    transformer.to(memory_helper.gpu)

stream = AsyncStream()

@@ -120,8 +150,8 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))

if not high_vram:
fake_diffusers_current_device(text_encoder, gpu) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
load_model_as_complete(text_encoder_2, target_device=gpu)
fake_diffusers_current_device(text_encoder, memory_helper.gpu) # since we only encode one text - that is one model move and one encode, offload is same time consumption since it is also one load and one encode.
load_model_as_complete(text_encoder_2, target_device=memory_helper.gpu)

llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)

@@ -151,7 +181,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))

if not high_vram:
load_model_as_complete(vae, target_device=gpu)
load_model_as_complete(vae, target_device=memory_helper.gpu)

start_latent = vae_encode(input_image_pt, vae)

@@ -160,7 +190,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))

if not high_vram:
load_model_as_complete(image_encoder, target_device=gpu)
load_model_as_complete(image_encoder, target_device=memory_helper.gpu)

image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
@@ -213,7 +243,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind

if not high_vram:
unload_complete_models()
move_model_to_device_with_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
move_model_to_device_with_memory_preservation(transformer, target_device=memory_helper.gpu, preserved_memory_gb=gpu_memory_preservation)

if use_teacache:
transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
@@ -256,7 +286,7 @@ def callback(d):
negative_prompt_embeds=llama_vec_n,
negative_prompt_embeds_mask=llama_attention_mask_n,
negative_prompt_poolers=clip_l_pooler_n,
device=gpu,
device=memory_helper.gpu,
dtype=torch.bfloat16,
image_embeddings=image_encoder_last_hidden_state,
latent_indices=latent_indices,
@@ -276,8 +306,8 @@ def callback(d):
history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)

if not high_vram:
offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
load_model_as_complete(vae, target_device=gpu)
offload_model_from_device_for_memory_preservation(transformer, target_device=memory_helper.gpu, preserved_memory_gb=8)
load_model_as_complete(vae, target_device=memory_helper.gpu)

real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]

@@ -305,11 +335,7 @@ def callback(d):
break
except:
traceback.print_exc()

if not high_vram:
unload_complete_models(
text_encoder, text_encoder_2, image_encoder, vae, transformer
)
pass

stream.output_queue.push(('end', None))
return
@@ -366,7 +392,7 @@ def end_process():
example_quick_prompts.click(lambda x: x[0], inputs=[example_quick_prompts], outputs=prompt, show_progress=False, queue=False)

with gr.Row():
start_button = gr.Button(value="Start Generation")
start_button = gr.Button(value="Start Generation", interactive=True)
end_button = gr.Button(value="End Generation", interactive=False)

with gr.Group():
38 changes: 37 additions & 1 deletion diffusers_helper/memory.py
@@ -5,10 +5,46 @@


cpu = torch.device('cpu')
gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
gpu = None
gpu_complete_modules = []


def get_available_gpus():
    """Returns a list of available CUDA GPUs with their info."""
    gpus = []
    if not torch.cuda.is_available():
        return gpus

    num_gpus = torch.cuda.device_count()
    for i in range(num_gpus):
        device = torch.device(f'cuda:{i}')
        props = torch.cuda.get_device_properties(device)
        total_memory_gb = props.total_memory / (1024 ** 3)
        try:
            free_memory_gb = get_cuda_free_memory_gb(device)
        except RuntimeError:
            free_memory_gb = 0

        gpus.append({
            'index': i,
            'name': props.name,
            'total_memory': round(total_memory_gb, 2),
            'free_memory': round(free_memory_gb, 2)
        })
    return gpus

def set_gpu_device(index: int):
    """Sets the global GPU device based on the selected index."""
    global gpu
    if torch.cuda.is_available() and 0 <= index < torch.cuda.device_count():
        gpu = torch.device(f'cuda:{index}')
        print(f"Global GPU device set to: {gpu}")
    else:
        print(f"Error: Invalid GPU index {index} or CUDA not available.")
        gpu = cpu
        print("Falling back to CPU.")

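For reference, a minimal usage sketch of the two helpers above, mirroring the flow in demo_gradio.py (the tensor allocation at the end is illustrative only):

import torch
import diffusers_helper.memory as memory_helper

gpus = memory_helper.get_available_gpus()
for info in gpus:
    # each entry is a dict: {'index': int, 'name': str, 'total_memory': GB, 'free_memory': GB}
    print(f"[{info['index']}] {info['name']}: {info['free_memory']} / {info['total_memory']} GB free")

if gpus:
    memory_helper.set_gpu_device(gpus[0]['index'])  # rebinds memory_helper.gpu
    x = torch.zeros(1, device=memory_helper.gpu)    # later allocations follow the selection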

class DynamicSwapInstaller:
@staticmethod
def _install_module(module: torch.nn.Module, **kwargs):