# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
"""Gradio demo for Wan2.1 image-to-video (I2V-14B) generation.

One I2V pipeline (480P or 720P) is kept loaded at a time. The UI lets the
user pick the resolution, optionally enhance the prompt with a prompt
expander, and render a clip which is written to ``example.mp4``.
"""
import argparse
import gc
import os.path as osp
import os
import sys
import warnings

import gradio as gr

warnings.filterwarnings('ignore')

# Model
# Make the directory two levels above this file importable so the local
# ``wan`` package can be found.
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video

# Global Var
_NO_MODEL = '------'  # dropdown sentinel meaning "no resolution selected"

prompt_expander = None
wan_i2v_480P = None
wan_i2v_720P = None
# Resolution preloaded by __main__; used as the dropdown's initial value.
resolution = _NO_MODEL
args = None  # set from _parse_args() in __main__


# Button Func
def _build_pipeline(value, ckpt_dir):
    """Instantiate a single-GPU ``wan.WanI2V`` pipeline for ``value``.

    Args:
        value: '720P' or '480P'; selects the ``i2v720p`` flag.
        ckpt_dir: Checkpoint directory passed to ``wan.WanI2V``.

    Returns:
        The constructed pipeline.
    """
    print(f"load 14B-{value} i2v model...", end='', flush=True)
    model = wan.WanI2V(
        config=WAN_CONFIGS['i2v-14B'],
        checkpoint_dir=ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        i2v720p=(value == '720P'),
    )
    print("done", flush=True)
    # NOTE: mmgp ``offload.profile(...)`` memory profiling used to be wired in
    # here (it was commented out); re-add it if VRAM becomes a constraint.
    return model


def load_i2v_model(value):
    """Load the I2V pipeline for ``value``, releasing the other one.

    At most one 14B pipeline is resident: loading 720P drops the 480P
    pipeline and vice versa. Nothing is reloaded if the requested pipeline
    is already loaded.

    Args:
        value: '720P', '480P' or '------'.

    Returns:
        The resolution that is now active, or '------' if nothing was loaded
        (no selection, or no checkpoint directory configured). Gradio writes
        this back into the resolution dropdown.
    """
    global wan_i2v_480P, wan_i2v_720P

    if value == '720P':
        ckpt_dir = args.ckpt_dir_720p
    elif value == '480P':
        ckpt_dir = args.ckpt_dir_480p
    else:
        print("No model loaded")
        return _NO_MODEL

    if not ckpt_dir:
        print(f"Please specify the checkpoint directory for {value} model")
        return _NO_MODEL

    if value == '720P':
        if wan_i2v_720P is None:
            # Free the other pipeline first so both 14B models never coexist.
            wan_i2v_480P = None
            gc.collect()
            wan_i2v_720P = _build_pipeline(value, ckpt_dir)
    else:
        if wan_i2v_480P is None:
            wan_i2v_720P = None
            gc.collect()
            wan_i2v_480P = _build_pipeline(value, ckpt_dir)
    return value


def prompt_enc(prompt, img, tar_lang):
    """Rewrite ``prompt`` using the global ``prompt_expander``.

    Args:
        prompt: User prompt text.
        img: Conditioning image from the UI, or None.
        tar_lang: Target language label from the UI ("CH" / "EN"); passed
            lower-cased to the expander.

    Returns:
        The expanded prompt, or the original prompt when no image was
        uploaded or the expander reports a failed status.
    """
    print('prompt extend...')
    if img is None:
        print('Please upload an image')
        return prompt
    prompt_output = prompt_expander(
        prompt, image=img, tar_lang=tar_lang.lower())
    if not prompt_output.status:
        return prompt
    return prompt_output.prompt


def i2v_generation(img2vid_prompt, img2vid_image, res, sd_steps,
                   guide_scale, shift_scale, seed, n_prompt):
    """Generate a video from the uploaded image with the ``res`` pipeline.

    Args:
        img2vid_prompt: Positive prompt.
        img2vid_image: Uploaded conditioning image (PIL) or None.
        res: Dropdown value: '720P', '480P' or '------'.
        sd_steps: Number of sampling steps.
        guide_scale: Classifier-free guidance scale.
        shift_scale: Sampling shift, taken from the "Shift scale" slider.
        seed: Random seed (the UI allows -1).
        n_prompt: Negative prompt.

    Returns:
        "example.mp4" on success, or None if no resolution is selected, its
        pipeline is not loaded, or no image was uploaded.
    """
    if res == '720P':
        model, max_area = wan_i2v_720P, MAX_AREA_CONFIGS['720*1280']
    elif res == '480P':
        model, max_area = wan_i2v_480P, MAX_AREA_CONFIGS['480*832']
    else:
        model, max_area = None, None

    if model is None:
        print(
            'Please specify at least one resolution ckpt dir or specify the resolution'
        )
        return None
    if img2vid_image is None:
        print('Please upload an image')
        return None

    video = model.generate(
        img2vid_prompt,
        img2vid_image,
        max_area=max_area,
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=False)

    cache_video(
        tensor=video[None],
        save_file="example.mp4",
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    return "example.mp4"


# Interface
def gradio_interface():
    """Build the Gradio Blocks UI.

    The resolution dropdown starts at the global ``resolution`` (the pipeline
    preloaded in ``__main__``), so a video can be generated without first
    re-selecting it.

    Returns:
        The ``gr.Blocks`` demo (not yet launched).
    """
    initial_res = resolution if resolution in ('720P', '480P') else _NO_MODEL
    # 3.0 was the value previously hard-coded for 480P; keep it as that
    # resolution's starting point while letting the slider take effect.
    initial_shift = 3.0 if initial_res == '480P' else 5.0

    with gr.Blocks() as demo:
        gr.Markdown("""
Wan2.1 (I2V-14B)
Wan: Open and Advanced Large-Scale Video Generative Models.
""")

        with gr.Row():
            with gr.Column():
                resolution_dd = gr.Dropdown(
                    label='Resolution',
                    choices=['------', '720P', '480P'],
                    value=initial_res)

                img2vid_image = gr.Image(
                    type="pil",
                    label="Upload Input Image",
                    elem_id="image_upload",
                )
                img2vid_prompt = gr.Textbox(
                    label="Prompt",
                    value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.",
                    placeholder="Describe the video you want to generate",
                )
                tar_lang = gr.Radio(
                    choices=["CH", "EN"],
                    label="Target language of prompt enhance",
                    value="CH")
                run_p_button = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    with gr.Row():
                        sd_steps = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_scale = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_scale = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=10,
                            value=initial_shift,
                            step=1)
                        seed = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    n_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                run_i2v_button = gr.Button("Generate Video")

            with gr.Column():
                result_gallery = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        resolution_dd.input(
            fn=load_i2v_model, inputs=[resolution_dd], outputs=[resolution_dd])

        run_p_button.click(
            fn=prompt_enc,
            inputs=[img2vid_prompt, img2vid_image, tar_lang],
            outputs=[img2vid_prompt])

        run_i2v_button.click(
            fn=i2v_generation,
            inputs=[
                img2vid_prompt, img2vid_image, resolution_dd, sd_steps,
                guide_scale, shift_scale, seed, n_prompt
            ],
            outputs=[result_gallery],
        )

    return demo


# Main
def _parse_args():
    """Parse CLI options.

    Both checkpoint directories default to ``../ckpts`` but, unlike before,
    can now be overridden on the command line.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt or image using Gradio")
    parser.add_argument(
        "--ckpt_dir_720p",
        type=str,
        default="../ckpts",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--ckpt_dir_480p",
        type=str,
        default="../ckpts",
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")

    args = parser.parse_args()
    # Use parser.error rather than assert: asserts are stripped under -O.
    if not args.ckpt_dir_720p and not args.ckpt_dir_480p:
        parser.error("Please specify at least one checkpoint directory.")

    return args


if __name__ == '__main__':
    args = _parse_args()

    # Preload the 480P pipeline; the dropdown starts at this value.
    resolution = load_i2v_model('480P')

    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=True)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=True, device=0)
    else:
        raise NotImplementedError(
            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)