# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import datetime
import os
import sys

import imageio
import numpy as np
import torch

import gradio as gr

# Make the repository root importable when this script is run directly.
sys.path.insert(
    0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan import WanVace, WanVaceMP
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS


class FixedSizeQueue:
    """A bounded list that keeps the most recent `max_size` items, newest
    first."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.queue = []

    def add(self, item):
        # Newest item goes to the front; evict the oldest once over capacity.
        self.queue.insert(0, item)
        if len(self.queue) > self.max_size:
            self.queue.pop()

    def get(self):
        return self.queue

    def __repr__(self):
        return str(self.queue)
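
# Usage sketch (illustrative only, not part of the demo flow):
#   q = FixedSizeQueue(max_size=2)
#   q.add('a'); q.add('b'); q.add('c')
#   q.get()  # -> ['c', 'b']  (newest first, oldest item evicted)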


class VACEInference:

    def __init__(self,
                 cfg,
                 skip_load=False,
                 gallery_share=True,
                 gallery_share_limit=5):
        self.cfg = cfg
        self.save_dir = cfg.save_dir
        self.gallery_share = gallery_share
        self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
        if not skip_load:
            # Read the multi-GPU flag from cfg rather than the module-level
            # `args`, so the class also works when constructed elsewhere.
            if not cfg.mp:
                self.pipe = WanVace(
                    config=WAN_CONFIGS[cfg.model_name],
                    checkpoint_dir=cfg.ckpt_dir,
                    device_id=0,
                    rank=0,
                    t5_fsdp=False,
                    dit_fsdp=False,
                    use_usp=False,
                )
            else:
                self.pipe = WanVaceMP(
                    config=WAN_CONFIGS[cfg.model_name],
                    checkpoint_dir=cfg.ckpt_dir,
                    use_usp=True,
                    ulysses_size=cfg.ulysses_size,
                    ring_size=cfg.ring_size)
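
    # cfg is expected to be an argparse.Namespace (such as `args` parsed in
    # the __main__ block below) providing at least: model_name, ckpt_dir,
    # save_dir, mp, ulysses_size, and ring_size.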

    def create_ui(self, *args, **kwargs):
        gr.Markdown("""
                    <div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
                        <a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
                    </div>
                    """)
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                self.src_video = gr.Video(
                    label="src_video",
                    sources=['upload'],
                    value=None,
                    interactive=True)
            with gr.Column(scale=1, min_width=0):
                self.src_mask = gr.Video(
                    label="src_mask",
                    sources=['upload'],
                    value=None,
                    interactive=True)
        #
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.src_ref_image_1 = gr.Image(
                        label='src_ref_image_1',
                        height=200,
                        interactive=True,
                        type='filepath',
                        image_mode='RGB',
                        sources=['upload'],
                        elem_id="src_ref_image_1",
                        format='png')
                    self.src_ref_image_2 = gr.Image(
                        label='src_ref_image_2',
                        height=200,
                        interactive=True,
                        type='filepath',
                        image_mode='RGB',
                        sources=['upload'],
                        elem_id="src_ref_image_2",
                        format='png')
                    self.src_ref_image_3 = gr.Image(
                        label='src_ref_image_3',
                        height=200,
                        interactive=True,
                        type='filepath',
                        image_mode='RGB',
                        sources=['upload'],
                        elem_id="src_ref_image_3",
                        format='png')
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1):
                self.prompt = gr.Textbox(
                    show_label=False,
                    placeholder="positive_prompt_input",
                    elem_id='positive_prompt',
                    container=True,
                    autofocus=True,
                    elem_classes='type_row',
                    visible=True,
                    lines=2)
                # The default negative prompt comes from the loaded model
                # config, so this line requires skip_load=False.
                self.negative_prompt = gr.Textbox(
                    show_label=False,
                    value=self.pipe.config.sample_neg_prompt,
                    placeholder="negative_prompt_input",
                    elem_id='negative_prompt',
                    container=True,
                    autofocus=False,
                    elem_classes='type_row',
                    visible=True,
                    interactive=True,
                    lines=1)
        #
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.shift_scale = gr.Slider(
                        label='shift_scale',
                        minimum=0.0,
                        maximum=100.0,
                        step=1.0,
                        value=16.0,
                        interactive=True)
                    self.sample_steps = gr.Slider(
                        label='sample_steps',
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=25,
                        interactive=True)
                    self.context_scale = gr.Slider(
                        label='context_scale',
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        interactive=True)
                    self.guide_scale = gr.Slider(
                        label='guide_scale',
                        minimum=1,
                        maximum=10,
                        step=0.5,
                        value=5.0,
                        interactive=True)
                    self.infer_seed = gr.Slider(
                        minimum=-1, maximum=10000000, value=2025, label="Seed")
        #
        with gr.Accordion(label="Usable without source video", open=False):
            with gr.Row(equal_height=True):
                self.output_height = gr.Textbox(
                    label='resolutions_height',
                    # value=480,
                    value=720,
                    interactive=True)
                self.output_width = gr.Textbox(
                    label='resolutions_width',
                    # value=832,
                    value=1280,
                    interactive=True)
                self.frame_rate = gr.Textbox(
                    label='frame_rate', value=16, interactive=True)
                self.num_frames = gr.Textbox(
                    label='num_frames', value=81, interactive=True)
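        # Note: generate() looks the resolution up as
        # SIZE_CONFIGS[f"{width}*{height}"], so the pair entered above must
        # match a key in SIZE_CONFIGS (e.g. "1280*720"); other values raise
        # a KeyError.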
        #
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                self.generate_button = gr.Button(
                    value='Run',
                    elem_classes='type_row',
                    elem_id='generate_button',
                    visible=True)
            with gr.Column(scale=1):
                self.refresh_button = gr.Button(value='\U0001f504')  # 🔄
        #
        self.output_gallery = gr.Gallery(
            label="output_gallery",
            value=[],
            interactive=False,
            allow_preview=True,
            preview=True)

    def generate(self, output_gallery, src_video, src_mask, src_ref_image_1,
                 src_ref_image_2, src_ref_image_3, prompt, negative_prompt,
                 shift_scale, sample_steps, context_scale, guide_scale,
                 infer_seed, output_height, output_width, frame_rate,
                 num_frames):
        # Textbox values arrive as strings; coerce them to integers.
        output_height, output_width, frame_rate, num_frames = int(
            output_height), int(output_width), int(frame_rate), int(num_frames)
        src_ref_images = [
            x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3]
            if x is not None
        ]
        src_video, src_mask, src_ref_images = self.pipe.prepare_source(
            [src_video], [src_mask], [src_ref_images],
            num_frames=num_frames,
            image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
            device=self.pipe.device)
        video = self.pipe.generate(
            prompt,
            src_video,
            src_mask,
            src_ref_images,
            size=(output_width, output_height),
            context_scale=context_scale,
            shift=shift_scale,
            sampling_steps=sample_steps,
            guide_scale=guide_scale,
            n_prompt=negative_prompt,
            seed=infer_seed,
            offload_model=True)

        # Timestamped filename, e.g. cur_gallery_20250101-120000.mp4.
        name = '{0:%Y%m%d-%H%M%S}'.format(datetime.datetime.now())
        video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
        # Map frames from [-1, 1] to uint8 [0, 255] and reorder channels-first
        # output to (frames, height, width, channels) for the video writer.
        video_frames = (
            torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) *
            255).cpu().numpy().astype(np.uint8)
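
        # Sanity sketch for the mapping above (illustrative):
        #   -1.0 -> (-1.0 / 2 + 0.5) * 255 = 0
        #    1.0 -> ( 1.0 / 2 + 0.5) * 255 = 255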

        try:
            # macro_block_size=1 stops imageio-ffmpeg from resizing frames to
            # multiples of 16.
            writer = imageio.get_writer(
                video_path,
                fps=frame_rate,
                codec='libx264',
                quality=8,
                macro_block_size=1)
            for frame in video_frames:
                writer.append_data(frame)
            writer.close()
            print(video_path)
        except Exception as e:
            raise gr.Error(f"Video save error: {e}")

        # Return either the shared recent-results gallery or just this clip.
        if self.gallery_share:
            self.gallery_share_data.add(video_path)
            return self.gallery_share_data.get()
        else:
            return [video_path]

    def set_callbacks(self, **kwargs):
        self.gen_inputs = [
            self.output_gallery, self.src_video, self.src_mask,
            self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3,
            self.prompt, self.negative_prompt, self.shift_scale,
            self.sample_steps, self.context_scale, self.guide_scale,
            self.infer_seed, self.output_height, self.output_width,
            self.frame_rate, self.num_frames
        ]
        self.gen_outputs = [self.output_gallery]
        self.generate_button.click(
            self.generate,
            inputs=self.gen_inputs,
            outputs=self.gen_outputs,
            queue=True)
        # Refresh re-reads the shared gallery queue; without sharing it just
        # echoes the current gallery back.
        self.refresh_button.click(
            lambda x: self.gallery_share_data.get()
            if self.gallery_share else x,
            inputs=[self.output_gallery],
            outputs=[self.output_gallery])
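

# Headless-use sketch (illustrative): generate() can also be called without
# the UI, passing plain Python values in the order of self.gen_inputs, e.g.:
#   infer_gr.generate([], 'src.mp4', None, None, None, None, 'a prompt', '',
#                     16.0, 25, 1.0, 5.0, 2025, 720, 1280, 16, 81)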


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Argparser for VACE-WAN Demo:\n')
    parser.add_argument(
        '--server_port',
        dest='server_port',
        help='Port for the Gradio server.',
        type=int,
        default=7860)
    parser.add_argument(
        '--server_name',
        dest='server_name',
        help='Host/IP to bind the Gradio server to.',
        default='0.0.0.0')
    parser.add_argument(
        '--root_path',
        dest='root_path',
        help='Root path for the app (e.g. behind a reverse proxy).',
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='Directory where generated videos are saved.',
        default='cache')
    parser.add_argument(
        "--mp",
        action="store_true",
        help="Use multiple GPUs.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="vace-14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The model name to run.")
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument(
        "--ring_size",
        type=int,
        default=1,
        help="The size of the ring attention parallelism in DiT.")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        # default='models/VACE-Wan2.1-1.3B-Preview',
        default='models/Wan2.1-VACE-14B/',
        help="The path to the checkpoint directory.",
    )
    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Offload unnecessary computations to CPU.",
    )

    args = parser.parse_args()

    # exist_ok=True already covers the pre-existing-directory case.
    os.makedirs(args.save_dir, exist_ok=True)

    with gr.Blocks() as demo:
        infer_gr = VACEInference(
            args, skip_load=False, gallery_share=True, gallery_share_limit=5)
        infer_gr.create_ui()
        infer_gr.set_callbacks()
        # Allow Gradio to serve the video files written under save_dir.
        allowed_paths = [args.save_dir]
        demo.queue(status_update_rate=1).launch(
            server_name=args.server_name,
            server_port=args.server_port,
            root_path=args.root_path,
            allowed_paths=allowed_paths,
            show_error=True,
            debug=True)
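
# Launch sketch (illustrative; the script name and checkpoint path are
# examples):
#   python vace.py --ckpt_dir models/Wan2.1-VACE-14B/ --server_port 7860
# With --mp, the demo builds WanVaceMP for multi-GPU inference, with the
# ulysses/ring parallelism sized by --ulysses_size and --ring_size.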