demo_gradio.py: 66 changes (47 additions, 19 deletions)
@@ -37,7 +37,6 @@

# for win desktop probably use --server 127.0.0.1 --inbrowser
# For linux server probably use --server 127.0.0.1 or do not use any cmd flags
-
print(args)

free_mem_gb = get_cuda_free_memory_gb(gpu)
@@ -100,7 +99,7 @@


@torch.no_grad()
-def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf):
+def worker(input_image, end_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf):
    total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))
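
A quick sanity check of the section arithmetic above, assuming 30 fps output and 4 pixel frames per latent frame (which is what the divisor implies):

```python
# e.g. a 5-second request with the demo's default latent_window_size of 9
total_second_length = 5.0
latent_window_size = 9
total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)  # 150 / 36
total_latent_sections = int(max(round(total_latent_sections), 1))
print(total_latent_sections)  # 4 sections, generated back-to-front
```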

@@ -133,48 +132,65 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
        llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
        llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)

-        # Processing input image
-
-        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Image processing ...'))))
+        # Processing input image (start frame)
+        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Processing start frame ...'))))

        H, W, C = input_image.shape
        height, width = find_nearest_bucket(H, W, resolution=640)
        input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)

-        Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}.png'))
+        Image.fromarray(input_image_np).save(os.path.join(outputs_folder, f'{job_id}_start.png'))

        input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
        input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
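        # (the two lines above map uint8 [H, W, C] in [0, 255] to float [-1, 1]
        # and reshape to [1, C, 1, H, W]: batch, channels, time, height, width)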

+        # Processing end image (if provided)
+        has_end_image = end_image is not None
+        if has_end_image:
+            stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Processing end frame ...'))))
+
+            H_end, W_end, C_end = end_image.shape
+            end_image_np = resize_and_center_crop(end_image, target_width=width, target_height=height)
+
+            Image.fromarray(end_image_np).save(os.path.join(outputs_folder, f'{job_id}_end.png'))
+
+            end_image_pt = torch.from_numpy(end_image_np).float() / 127.5 - 1
+            end_image_pt = end_image_pt.permute(2, 0, 1)[None, :, None]
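            # (the end frame is resized with the bucket dimensions already chosen
            # for the start frame, so both frames encode to latents of equal shape)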

        # VAE encoding

        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))

        if not high_vram:
            load_model_as_complete(vae, target_device=gpu)

        start_latent = vae_encode(input_image_pt, vae)

+        if has_end_image:
+            end_latent = vae_encode(end_image_pt, vae)

        # CLIP Vision

        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))

        if not high_vram:
            load_model_as_complete(image_encoder, target_device=gpu)

        image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state

+        if has_end_image:
+            end_image_encoder_output = hf_clip_vision_encode(end_image_np, feature_extractor, image_encoder)
+            end_image_encoder_last_hidden_state = end_image_encoder_output.last_hidden_state
+            # Combine the two image embeddings by averaging (a weighted blend is a possible alternative)
+            image_encoder_last_hidden_state = (image_encoder_last_hidden_state + end_image_encoder_last_hidden_state) / 2

        # Dtype

        llama_vec = llama_vec.to(transformer.dtype)
        llama_vec_n = llama_vec_n.to(transformer.dtype)
        clip_l_pooler = clip_l_pooler.to(transformer.dtype)
        clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
        image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)

        # Sampling

        stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))

        rnd = torch.Generator("cpu").manual_seed(seed)
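
When an end frame is supplied, the hunk above conditions CLIP Vision on a plain average of the start- and end-frame embeddings. A weighted blend is a natural variant; here is a minimal sketch, where the `end_weight` knob and the tensor shape are assumptions for illustration, not part of this PR:

```python
import torch

# Stand-ins for the two last_hidden_state tensors (shape assumed)
start_emb = torch.randn(1, 257, 1152)
end_emb = torch.randn(1, 257, 1152)

end_weight = 0.5  # 0.5 reproduces the plain average used by the PR
blended = (1 - end_weight) * start_emb + end_weight * end_emb
```

Setting `end_weight` below 0.5 would favor the start frame's appearance while still steering generation toward the end frame.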
@@ -184,7 +200,8 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
        history_pixels = None
        total_generated_latent_frames = 0

-        latent_paddings = reversed(range(total_latent_sections))
+        # Convert the iterator to a list so it can be indexed (latent_paddings[0] below)

> Review thread on the comment above:
> • "English comment would be great"
> • "中文评论才是精华" (translation: "Chinese comments are the real essence")
> • "佩服😁" (translation: "impressive 😁")

+        latent_paddings = list(reversed(range(total_latent_sections)))

        if total_latent_sections > 4:
            # In theory the latent_paddings should follow the above sequence, but it seems that duplicating some
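
The `list(...)` wrapper matters because `reversed(...)` returns a one-shot iterator, which cannot be indexed the way `latent_paddings[0]` is used in the loop below. A standalone sketch:

```python
latent_paddings = reversed(range(3))
try:
    latent_paddings[0]
except TypeError as e:
    print(e)  # ... object is not subscriptable

latent_paddings = list(reversed(range(3)))
print(latent_paddings[0], latent_paddings)  # 2 [2, 1, 0]
```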
@@ -195,13 +212,14 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind

        for latent_padding in latent_paddings:
            is_last_section = latent_padding == 0
+            is_first_section = latent_padding == latent_paddings[0]
            latent_padding_size = latent_padding * latent_window_size

            if stream.input_queue.top() == 'end':
                stream.output_queue.push(('end', None))
                return

-            print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}')
+            print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}, is_first_section = {is_first_section}')

            indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
            clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
@@ -210,6 +228,11 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
            clean_latents_pre = start_latent.to(history_latents)
            clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
            clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

+            # Use end image latent for the first section if provided
+            if has_end_image and is_first_section:
+                clean_latents_post = end_latent.to(history_latents)
+                clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

            if not high_vram:
                unload_complete_models()
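
Because FramePack samples sections in reverse temporal order, the first section processed is the tail of the video; seeding `clean_latents_post` with the end-frame latent there pins the final frame, while `clean_latents_pre` keeps pinning the start frame in every section. A shape-level sketch, with sizes assumed from the surrounding code (where `history_latents` is a zero-initialized `[1, 16, 1 + 2 + 16, H//8, W//8]` buffer):

```python
import torch

start_latent = torch.zeros(1, 16, 1, 64, 64)              # vae_encode(start frame)
end_latent = torch.zeros(1, 16, 1, 64, 64)                # vae_encode(end frame)
history_latents = torch.zeros(1, 16, 1 + 2 + 16, 64, 64)  # empty history

clean_latents_pre = start_latent
clean_latents_post, _, _ = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)

# First (final-in-time) section with an end frame: swap the zero "post" slot
# for the end-frame latent so generation is anchored at both ends.
clean_latents_post = end_latent
clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
print(clean_latents.shape)  # torch.Size([1, 16, 2, 64, 64])
```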
@@ -315,15 +338,15 @@ def callback(d):
    return


-def process(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf):
+def process(input_image, end_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf):
    global stream
    assert input_image is not None, 'No input image!'

    yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)

    stream = AsyncStream()

-    async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf)
+    async_run(worker, input_image, end_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf)

    output_filename = None

@@ -360,7 +383,12 @@ def end_process():
    gr.Markdown('# FramePack')
    with gr.Row():
        with gr.Column():
-            input_image = gr.Image(sources='upload', type="numpy", label="Image", height=320)
+            with gr.Row():
+                with gr.Column():
+                    input_image = gr.Image(sources='upload', type="numpy", label="Start Frame", height=320)
+                with gr.Column():
+                    end_image = gr.Image(sources='upload', type="numpy", label="End Frame (Optional)", height=320)

            prompt = gr.Textbox(label="Prompt", value='')
            example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Quick List', samples_per_page=1000, components=[prompt])
            example_quick_prompts.click(lambda x: x[0], inputs=[example_quick_prompts], outputs=prompt, show_progress=False, queue=False)
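
A standalone sketch of the side-by-side upload layout introduced above (assuming a Gradio version compatible with the demo; component names shortened for illustration):

```python
import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            start = gr.Image(sources='upload', type="numpy", label="Start Frame", height=320)
        with gr.Column():
            end = gr.Image(sources='upload', type="numpy", label="End Frame (Optional)", height=320)

demo.launch()
```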
@@ -384,19 +412,19 @@ def end_process():
                rs = gr.Slider(label="CFG Re-Scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01, visible=False)  # Should not change

                gpu_memory_preservation = gr.Slider(label="GPU Inference Preserved Memory (GB) (larger means slower)", minimum=6, maximum=128, value=6, step=0.1, info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed.")

                mp4_crf = gr.Slider(label="MP4 Compression", minimum=0, maximum=100, value=16, step=1, info="Lower means better quality. 0 is uncompressed. Change to 16 if you get black outputs.")

        with gr.Column():
            preview_image = gr.Image(label="Next Latents", height=200, visible=False)
            result_video = gr.Video(label="Finished Frames", autoplay=True, show_share_button=False, height=512, loop=True)
-            gr.Markdown('Note that the ending actions will be generated before the starting actions due to the inverted sampling. If the starting action is not in the video, you just need to wait, and it will be generated later.')
+            gr.Markdown('When using only a start frame, the ending actions will be generated before the starting actions due to the inverted sampling. If using both start and end frames, the model will try to create a smooth transition between them.')
            progress_desc = gr.Markdown('', elem_classes='no-generating-animation')
            progress_bar = gr.HTML('', elem_classes='no-generating-animation')

    gr.HTML('<div style="text-align:center; margin-top:20px;">Share your results and find ideas at the <a href="https://x.com/search?q=framepack&f=live" target="_blank">FramePack Twitter (X) thread</a></div>')

-    ips = [input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf]
+    ips = [input_image, end_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf]
    start_button.click(fn=process, inputs=ips, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button, end_button])
    end_button.click(fn=end_process)
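
Since Gradio passes component values to the handler positionally, the order of `ips` must match the parameter order of `process` (which forwards them, in the same order, to `worker`). Omitting `end_image` from any one of the three places would silently shift every later argument. A toy illustration of the failure mode (no Gradio required):

```python
def handler(start_frame, end_frame, prompt):
    return f'{prompt}: {start_frame} -> {end_frame}'

values = ['start.png', 'end.png', 'pan left']   # stands in for ips, in order
print(handler(*values))                         # pan left: start.png -> end.png

# If the list and the signature fall out of sync, arguments shift:
print(handler(*values[1:], 'oops'))             # oops: end.png -> pan left
```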
