mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-03 22:04:21 +00:00 
			
		
		
		
	* isort the code * format the code * Add yapf config file * Remove torch cuda memory profiler
		
			
				
	
	
		
			350 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			350 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# -*- coding: utf-8 -*-
 | 
						|
# Copyright (c) Alibaba, Inc. and its affiliates.
 | 
						|
 | 
						|
import argparse
 | 
						|
import datetime
 | 
						|
import os
 | 
						|
import sys
 | 
						|
 | 
						|
import imageio
 | 
						|
import numpy as np
 | 
						|
import torch
 | 
						|
 | 
						|
import gradio as gr
 | 
						|
 | 
						|
sys.path.insert(
 | 
						|
    0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
 | 
						|
import wan
 | 
						|
from wan import WanVace, WanVaceMP
 | 
						|
from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
 | 
						|
 | 
						|
 | 
						|
class FixedSizeQueue:
    """Bounded, most-recent-first list of items.

    Every new item goes to the front of the list; when the list becomes
    longer than ``max_size`` a single entry is evicted from the back
    (the oldest one).
    """

    def __init__(self, max_size):
        # Capacity threshold, compared against the length after each add.
        self.max_size = max_size
        # Backing list; index 0 always holds the newest item.
        self.queue = []

    def add(self, item):
        """Prepend ``item`` and drop the oldest entry if over capacity."""
        # Slice assignment prepends in place, so references previously
        # handed out by get() observe the change.
        self.queue[:0] = [item]
        over_capacity = len(self.queue) > self.max_size
        if over_capacity:
            del self.queue[-1]

    def get(self):
        """Return the backing list itself (newest first), not a copy."""
        return self.queue

    def __repr__(self):
        # A list's str() and repr() are the same text.
        return repr(self.queue)
 | 
						|
 | 
						|
 | 
						|
class VACEInference:
    """Gradio front-end around a Wan VACE generation pipeline.

    Builds the demo UI (``create_ui``), wires its events
    (``set_callbacks``) and runs generation (``generate``), writing each
    result as an mp4 into ``cfg.save_dir``.

    Args:
        cfg: Parsed CLI namespace. Attributes read here: ``save_dir``,
            ``model_name``, ``ckpt_dir``, ``mp`` (treated as False when
            absent) and, when ``mp`` is set, ``ulysses_size`` and
            ``ring_size``.
        skip_load: If True, no pipeline is constructed and ``self.pipe``
            is ``None``. The UI can still be built, but ``generate``
            raises ``gr.Error``.
        gallery_share: If True, ``generate`` returns the most recent
            ``gallery_share_limit`` videos produced through this instance
            (newest first). Otherwise it returns only the video it just made.
        gallery_share_limit: Capacity of the shared gallery.
    """

    def __init__(self,
                 cfg,
                 skip_load=False,
                 gallery_share=True,
                 gallery_share_limit=5):
        self.cfg = cfg
        self.save_dir = cfg.save_dir
        self.gallery_share = gallery_share
        self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
        # Always define the attribute so create_ui/generate can detect an
        # unloaded model instead of failing with AttributeError.
        self.pipe = None
        if skip_load:
            return
        # Read the multi-GPU flag from cfg rather than the script-level
        # ``args`` global, so the class also works when imported elsewhere.
        if not getattr(cfg, 'mp', False):
            self.pipe = WanVace(
                config=WAN_CONFIGS[cfg.model_name],
                checkpoint_dir=cfg.ckpt_dir,
                device_id=0,
                rank=0,
                t5_fsdp=False,
                dit_fsdp=False,
                use_usp=False,
            )
        else:
            self.pipe = WanVaceMP(
                config=WAN_CONFIGS[cfg.model_name],
                checkpoint_dir=cfg.ckpt_dir,
                use_usp=True,
                ulysses_size=cfg.ulysses_size,
                ring_size=cfg.ring_size)

    def _default_negative_prompt(self):
        """Return the initial value for the negative-prompt textbox."""
        pipe = getattr(self, 'pipe', None)
        # Use the loaded pipeline's config when there is one. Otherwise
        # (skip_load=True) use the WAN_CONFIGS entry the pipeline would have
        # been built from. NOTE(review): assumes pipe.config is that same
        # entry.
        config = (pipe.config
                  if pipe is not None else WAN_CONFIGS[self.cfg.model_name])
        return config.sample_neg_prompt

    @staticmethod
    def _make_ref_image(name):
        """Create one optional reference-image upload labelled ``name``."""
        return gr.Image(
            label=name,
            height=200,
            interactive=True,
            type='filepath',
            image_mode='RGB',
            sources=['upload'],
            elem_id=name,
            format='png')

    def create_ui(self, *args, **kwargs):
        """Lay out the demo's Gradio components.

        Must be called inside an active ``gr.Blocks()`` context. Every
        component is stored on ``self`` so that ``set_callbacks`` can wire
        it. Extra positional/keyword arguments are accepted and ignored.
        """
        gr.Markdown("""
                    <div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
                        <a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
                    </div>
                    """)
        # Source video and its mask.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                self.src_video = gr.Video(
                    label="src_video",
                    sources=['upload'],
                    value=None,
                    interactive=True)
            with gr.Column(scale=1, min_width=0):
                self.src_mask = gr.Video(
                    label="src_mask",
                    sources=['upload'],
                    value=None,
                    interactive=True)
        # Up to three optional reference images.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.src_ref_image_1 = self._make_ref_image(
                        'src_ref_image_1')
                    self.src_ref_image_2 = self._make_ref_image(
                        'src_ref_image_2')
                    self.src_ref_image_3 = self._make_ref_image(
                        'src_ref_image_3')
        # Positive and negative prompts.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1):
                self.prompt = gr.Textbox(
                    show_label=False,
                    placeholder="positive_prompt_input",
                    elem_id='positive_prompt',
                    container=True,
                    autofocus=True,
                    elem_classes='type_row',
                    visible=True,
                    lines=2)
                self.negative_prompt = gr.Textbox(
                    show_label=False,
                    value=self._default_negative_prompt(),
                    placeholder="negative_prompt_input",
                    elem_id='negative_prompt',
                    container=True,
                    autofocus=False,
                    elem_classes='type_row',
                    visible=True,
                    interactive=True,
                    lines=1)
        # Sampling controls.
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.shift_scale = gr.Slider(
                        label='shift_scale',
                        minimum=0.0,
                        maximum=100.0,
                        step=1.0,
                        value=16.0,
                        interactive=True)
                    self.sample_steps = gr.Slider(
                        label='sample_steps',
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=25,
                        interactive=True)
                    self.context_scale = gr.Slider(
                        label='context_scale',
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        interactive=True)
                    self.guide_scale = gr.Slider(
                        label='guide_scale',
                        minimum=1,
                        maximum=10,
                        step=0.5,
                        value=5.0,
                        interactive=True)
                    self.infer_seed = gr.Slider(
                        minimum=-1, maximum=10000000, value=2025, label="Seed")
        # Output geometry / timing; free-text boxes parsed as ints in
        # generate().
        with gr.Accordion(label="Usable without source video", open=False):
            with gr.Row(equal_height=True):
                self.output_height = gr.Textbox(
                    label='resolutions_height',
                    # value=480,
                    value=720,
                    interactive=True)
                self.output_width = gr.Textbox(
                    label='resolutions_width',
                    # value=832,
                    value=1280,
                    interactive=True)
                self.frame_rate = gr.Textbox(
                    label='frame_rate', value=16, interactive=True)
                self.num_frames = gr.Textbox(
                    label='num_frames', value=81, interactive=True)
        # Run / refresh buttons.
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                self.generate_button = gr.Button(
                    value='Run',
                    elem_classes='type_row',
                    elem_id='generate_button',
                    visible=True)
            with gr.Column(scale=1):
                self.refresh_button = gr.Button(value='\U0001f504')  # 🔄
        # Results.
        self.output_gallery = gr.Gallery(
            label="output_gallery",
            value=[],
            interactive=False,
            allow_preview=True,
            preview=True)

    @staticmethod
    def _to_uint8_frames(video):
        """Convert a generated video tensor to a stack of uint8 frames.

        NOTE(review): assumes ``video`` is (C, T, H, W) with values in
        [-1, 1] (inferred from the /2 + 0.5 rescale and the permute below).
        Under that assumption the result is a (T, H, W, C) uint8 array.
        """
        return (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(
            1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)

    def _save_video(self, video, frame_rate):
        """Encode ``video`` as an mp4 in ``self.save_dir``; return its path.

        Raises:
            gr.Error: if encoding or writing the file fails.
        """
        # Portable strftime codes only. The former '%-H' was a glibc
        # extension that yields an unpadded, ambiguous hour and fails on
        # Windows. Microseconds keep names unique within the same second.
        name = datetime.datetime.now().strftime('%Y%m%d-%H%M%S-%f')
        video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
        video_frames = self._to_uint8_frames(video)
        try:
            # The context manager closes the writer even if append_data
            # raises partway through.
            with imageio.get_writer(
                    video_path,
                    fps=frame_rate,
                    codec='libx264',
                    quality=8,
                    macro_block_size=1) as writer:
                for frame in video_frames:
                    writer.append_data(frame)
        except Exception as e:
            raise gr.Error(f"Video save error: {e}") from e
        print(video_path)
        return video_path

    def generate(self, output_gallery, src_video, src_mask, src_ref_image_1,
                 src_ref_image_2, src_ref_image_3, prompt, negative_prompt,
                 shift_scale, sample_steps, context_scale, guide_scale,
                 infer_seed, output_height, output_width, frame_rate,
                 num_frames):
        """Run one VACE generation and return the gallery contents.

        Arguments arrive in the order of ``self.gen_inputs``.
        ``output_gallery`` is accepted for that signature but not used.
        Reference images that are ``None`` are skipped.

        Returns:
            A list of mp4 paths: the shared gallery (newest first) when
            ``gallery_share`` is set, otherwise just the new video.

        Raises:
            gr.Error: the model is not loaded, a size/rate/frame field is
                not an integer, the resolution is not in ``SIZE_CONFIGS``,
                or writing the video fails.
        """
        if self.pipe is None:
            raise gr.Error('Model is not loaded (VACEInference was created '
                           'with skip_load=True).')
        # The textboxes deliver free text; report bad input in the UI
        # instead of surfacing a raw ValueError.
        try:
            output_height = int(output_height)
            output_width = int(output_width)
            frame_rate = int(frame_rate)
            num_frames = int(num_frames)
        except (TypeError, ValueError) as e:
            raise gr.Error('resolutions_height, resolutions_width, '
                           f'frame_rate and num_frames must be integers: {e}'
                           ) from e
        size_key = f"{output_width}*{output_height}"
        if size_key not in SIZE_CONFIGS:
            raise gr.Error(f"Unsupported resolution {size_key}; supported: "
                           f"{', '.join(sorted(SIZE_CONFIGS))}")

        src_ref_images = [
            x for x in (src_ref_image_1, src_ref_image_2, src_ref_image_3)
            if x is not None
        ]
        src_video, src_mask, src_ref_images = self.pipe.prepare_source(
            [src_video], [src_mask], [src_ref_images],
            num_frames=num_frames,
            image_size=SIZE_CONFIGS[size_key],
            device=self.pipe.device)

        video = self.pipe.generate(
            prompt,
            src_video,
            src_mask,
            src_ref_images,
            size=(output_width, output_height),
            context_scale=context_scale,
            shift=shift_scale,
            # Defensive casts: step counts and seeds must be integral.
            sampling_steps=int(sample_steps),
            guide_scale=guide_scale,
            n_prompt=negative_prompt,
            seed=int(infer_seed),
            # NOTE(review): the CLI's --offload_to_cpu flag is not read
            # anywhere in this file; offloading is always requested here.
            offload_model=True)

        video_path = self._save_video(video, frame_rate)

        if self.gallery_share:
            self.gallery_share_data.add(video_path)
            return self.gallery_share_data.get()
        return [video_path]

    def set_callbacks(self, **kwargs):
        """Wire the Run and refresh buttons to their handlers.

        Must be called after ``create_ui``. Extra keyword arguments are
        accepted and ignored.
        """
        # Order must match the positional parameters of generate().
        self.gen_inputs = [
            self.output_gallery, self.src_video, self.src_mask,
            self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3,
            self.prompt, self.negative_prompt, self.shift_scale,
            self.sample_steps, self.context_scale, self.guide_scale,
            self.infer_seed, self.output_height, self.output_width,
            self.frame_rate, self.num_frames
        ]
        self.gen_outputs = [self.output_gallery]
        self.generate_button.click(
            self.generate,
            inputs=self.gen_inputs,
            outputs=self.gen_outputs,
            queue=True)
        # Refresh shows the shared gallery when sharing is on; otherwise it
        # echoes the current gallery value back unchanged.
        self.refresh_button.click(
            lambda x: self.gallery_share_data.get()
            if self.gallery_share else x,
            inputs=[self.output_gallery],
            outputs=[self.output_gallery])
 | 
						|
 | 
						|
 | 
						|
if __name__ == '__main__':
    # ``args`` is intentionally kept at module scope (not wrapped in a
    # main() function): code elsewhere in this file may read the global
    # ``args``.
    parser = argparse.ArgumentParser(
        description='Argparser for VACE-WAN Demo:\n')
    parser.add_argument(
        '--server_port',
        dest='server_port',
        help='Port the Gradio server listens on.',
        type=int,
        default=7860)
    parser.add_argument(
        '--server_name',
        dest='server_name',
        # The default '0.0.0.0' binds to all network interfaces.
        help='Host/interface the Gradio server binds to.',
        default='0.0.0.0')
    parser.add_argument(
        '--root_path',
        dest='root_path',
        help='root_path forwarded to Gradio launch() (e.g. when served '
        'behind a reverse proxy under a sub-path).',
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='Directory where generated videos are written.',
        default='cache')
    parser.add_argument(
        "--mp",
        action="store_true",
        help="Use Multi-GPUs",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="vace-14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The model name to run.")
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument(
        "--ring_size",
        type=int,
        default=1,
        help="The size of the ring attention parallelism in DiT.")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        # default='models/VACE-Wan2.1-1.3B-Preview',
        default='models/Wan2.1-VACE-14B/',
        help="The path to the checkpoint directory.",
    )
    # NOTE(review): this flag is parsed but not read anywhere in this file;
    # generation always requests model offloading.
    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Offloading unnecessary computations to CPU.",
    )

    args = parser.parse_args()

    # exist_ok=True already tolerates an existing directory, so no separate
    # existence check is needed.
    os.makedirs(args.save_dir, exist_ok=True)

    # Build the whole UI inside the Blocks context, then launch once the
    # context has closed (the documented Gradio pattern).
    with gr.Blocks() as demo:
        infer_gr = VACEInference(
            args, skip_load=False, gallery_share=True, gallery_share_limit=5)
        infer_gr.create_ui()
        infer_gr.set_callbacks()

    # Generated videos live in save_dir, so Gradio must be allowed to serve
    # files from it.
    allowed_paths = [args.save_dir]
    demo.queue(status_update_rate=1).launch(
        server_name=args.server_name,
        server_port=args.server_port,
        root_path=args.root_path,
        allowed_paths=allowed_paths,
        show_error=True,
        debug=True)
 |