diff --git a/README.md b/README.md index c832fa3..4b87673 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid ## 🔥 Latest News!! +* Apr 17, 2025: 👋 We introduce **Wan2.1** [FLF2V](#run-first-last-frame-to-video-generation) with its inference code and weights! * Mar 21, 2025: 👋 We are excited to announce the release of the **Wan2.1** [technical report](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf). We welcome discussions and feedback! * Mar 3, 2025: 👋 **Wan2.1**'s T2V and I2V have been integrated into Diffusers ([T2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanPipeline) | [I2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanImageToVideoPipeline)). Feel free to give it a try! * Feb 27, 2025: 👋 **Wan2.1** has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy! @@ -54,6 +55,13 @@ If your work has improved **Wan2.1** and you would like more people to see it, p - [x] ComfyUI integration - [x] Diffusers integration - [ ] Diffusers + Multi-GPU Inference +- Wan2.1 First-Last-Frame-to-Video + - [x] Multi-GPU Inference code of the 14B model + - [x] Checkpoints of the 14B model + - [x] Gradio demo + - [ ] ComfyUI integration + - [ ] Diffusers integration + - [ ] Diffusers + Multi-GPU Inference ## Quickstart @@ -74,14 +82,17 @@ pip install -r requirements.txt #### Model Download -| Models | Download Link | Notes | -| --------------|-------------------------------------------------------------------------------|-------------------------------| -| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P -| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P -| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P -| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P +| Models | Download Link | Notes | +|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------| +| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P +| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P +| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P +| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P +| FLF2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | Supports 720P -> 💡Note: The 1.3B model is capable of generating videos at 720P 
resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution. +> 💡Note: +> * The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution. +> * For the first-last frame to video generation, we train our model primarily on Chinese text-video pairs. Therefore, we recommend using Chinese prompt to achieve better results. Download models using huggingface-cli: @@ -185,7 +196,7 @@ DASH_API_KEY=your_key python generate.py --task t2v-14B --size 1280*720 --ckpt_ - By default, the Qwen model on HuggingFace is used for this extension. Users can choose Qwen models or other models based on the available GPU memory size. - For text-to-video tasks, you can use models like `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-7B-Instruct` and `Qwen/Qwen2.5-3B-Instruct`. - - For image-to-video tasks, you can use models like `Qwen/Qwen2.5-VL-7B-Instruct` and `Qwen/Qwen2.5-VL-3B-Instruct`. + - For image-to-video or first-last-frame-to-video tasks, you can use models like `Qwen/Qwen2.5-VL-7B-Instruct` and `Qwen/Qwen2.5-VL-3B-Instruct`. - Larger models generally provide better extension results but require more GPU memory. - You can modify the model used for extension with the parameter `--prompt_extend_model` , allowing you to specify either a local model path or a Hugging Face model. For example: @@ -367,6 +378,74 @@ DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashsc ``` +#### Run First-Last-Frame-to-Video Generation + +First-Last-Frame-to-Video is also divided into processes with and without the prompt extension step. Currently, only 720P is supported. The specific parameters and corresponding settings are as follows: + + + + + + + + + + + + + + + + + + + + +
+| Task      | 480P | 720P | Model                 |
+|-----------|------|------|-----------------------|
+| flf2v-14B |  ❌  |  ✔️  | Wan2.1-FLF2V-14B-720P |
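+
+The snippet below is a minimal single-GPU Python sketch of what `generate.py --task flf2v-14B` does internally, in case you want to call the pipeline from your own script; the checkpoint path, prompt, and output file name are placeholders.
+
+```python
+from PIL import Image
+
+import wan
+from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
+from wan.utils.utils import cache_video
+
+# Build the FLF2V pipeline (the same config generate.py selects for --task flf2v-14B).
+pipe = wan.WanFLF2V(
+    config=WAN_CONFIGS['flf2v-14B'],
+    checkpoint_dir='./Wan2.1-FLF2V-14B-720P',
+    device_id=0,
+    rank=0)
+
+# Generate a clip conditioned on the first frame, the last frame, and a text prompt.
+video = pipe.generate(
+    "CG animation style, a small blue bird takes off from the ground, flapping its wings.",
+    Image.open('examples/flf2v_input_first_frame.png').convert('RGB'),
+    Image.open('examples/flf2v_input_last_frame.png').convert('RGB'),
+    max_area=MAX_AREA_CONFIGS['720*1280'],
+    shift=16,            # default noise-schedule shift generate.py uses for flf2v
+    sampling_steps=50,
+    guide_scale=5.5,
+    seed=42,
+    offload_model=True)
+
+# Save the result; this mirrors the cache_video call used by the Gradio demo.
+cache_video(
+    tensor=video[None],
+    save_file='flf2v_example.mp4',
+    fps=16,
+    nrow=1,
+    normalize=True,
+    value_range=(-1, 1))
+```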
+ + +##### (1) Without Prompt Extension + +- Single-GPU inference +```sh +python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." +``` + +> 💡Similar to Image-to-Video, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image. + + +- Multi-GPU inference using FSDP + xDiT USP + +```sh +pip install "xfuser>=0.4.1" +torchrun --nproc_per_node=8 generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." +``` + +##### (2) Using Prompt Extension + + +The process of prompt extension can be referenced [here](#2-using-prompt-extention). + +Run with local prompt extension using `Qwen/Qwen2.5-VL-7B-Instruct`: +``` +python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --use_prompt_extend --prompt_extend_model Qwen/Qwen2.5-VL-7B-Instruct --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." +``` + +Run with remote prompt extension using `dashscope`: +``` +DASH_API_KEY=your_key python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --use_prompt_extend --prompt_extend_method 'dashscope' --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." +``` + + +##### (3) Running local gradio + +```sh +cd gradio +# use 720P model in gradio +DASH_API_KEY=your_key python flf2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_720p ./Wan2.1-FLF2V-14B-720P +``` + + #### Run Text-to-Image Generation Wan2.1 is a unified model for both image and video generation. Since it was trained on both types of data, it can also generate images. 
The command for generating images is similar to video generation, as follows: diff --git a/examples/flf2v_input_first_frame.png b/examples/flf2v_input_first_frame.png new file mode 100644 index 0000000..032cd5c Binary files /dev/null and b/examples/flf2v_input_first_frame.png differ diff --git a/examples/flf2v_input_last_frame.png b/examples/flf2v_input_last_frame.png new file mode 100644 index 0000000..83ac8c5 Binary files /dev/null and b/examples/flf2v_input_last_frame.png differ diff --git a/generate.py b/generate.py index 1b1a9d7..73f273e 100644 --- a/generate.py +++ b/generate.py @@ -33,6 +33,14 @@ EXAMPLE_PROMPT = { "image": "examples/i2v_input.JPG", }, + "flf2v-14B": { + "prompt": + "CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。", + "first_frame": + "examples/flf2v_input_first_frame.png", + "last_frame": + "examples/flf2v_input_last_frame.png", + }, } @@ -50,6 +58,8 @@ def _validate_args(args): args.sample_shift = 5.0 if "i2v" in args.task and args.size in ["832*480", "480*832"]: args.sample_shift = 3.0 + if "flf2v" in args.task: + args.sample_shift = 16 # The default number of frames are 1 for text-to-image tasks and 81 for other tasks. if args.frame_num is None: @@ -167,7 +177,17 @@ def _parse_args(): "--image", type=str, default=None, - help="The image to generate the video from.") + help="[image to video] The image to generate the video from.") + parser.add_argument( + "--first_frame", + type=str, + default=None, + help="[first-last frame to video] The image (first frame) to generate the video from.") + parser.add_argument( + "--last_frame", + type=str, + default=None, + help="[first-last frame to video] The image (last frame) to generate the video from.") parser.add_argument( "--sample_solver", type=str, @@ -248,7 +268,7 @@ def generate(args): if args.use_prompt_extend: if args.prompt_extend_method == "dashscope": prompt_expander = DashScopePromptExpander( - model_name=args.prompt_extend_model, is_vl="i2v" in args.task) + model_name=args.prompt_extend_model, is_vl="i2v" in args.task or "flf2v" in args.task) elif args.prompt_extend_method == "local_qwen": prompt_expander = QwenPromptExpander( model_name=args.prompt_extend_model, @@ -321,7 +341,7 @@ def generate(args): seed=args.base_seed, offload_model=args.offload_model) - else: + elif "i2v" in args.task: if args.prompt is None: args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] if args.image is None: @@ -377,6 +397,66 @@ def generate(args): guide_scale=args.sample_guide_scale, seed=args.base_seed, offload_model=args.offload_model) + else: + if args.prompt is None: + args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] + if args.first_frame is None or args.last_frame is None: + args.first_frame = EXAMPLE_PROMPT[args.task]["first_frame"] + args.last_frame = EXAMPLE_PROMPT[args.task]["last_frame"] + logging.info(f"Input prompt: {args.prompt}") + logging.info(f"Input first frame: {args.first_frame}") + logging.info(f"Input last frame: {args.last_frame}") + first_frame = Image.open(args.first_frame).convert("RGB") + last_frame = Image.open(args.last_frame).convert("RGB") + if args.use_prompt_extend: + logging.info("Extending prompt ...") + if rank == 0: + prompt_output = prompt_expander( + args.prompt, + tar_lang=args.prompt_extend_target_lang, + image=[first_frame, last_frame], + seed=args.base_seed) + if prompt_output.status == False: + logging.info( + f"Extending prompt failed: {prompt_output.message}") + logging.info("Falling back to original prompt.") + input_prompt = args.prompt + else: 
+ input_prompt = prompt_output.prompt + input_prompt = [input_prompt] + else: + input_prompt = [None] + if dist.is_initialized(): + dist.broadcast_object_list(input_prompt, src=0) + args.prompt = input_prompt[0] + logging.info(f"Extended prompt: {args.prompt}") + + logging.info("Creating WanFLF2V pipeline.") + wan_flf2v = wan.WanFLF2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_usp=(args.ulysses_size > 1 or args.ring_size > 1), + t5_cpu=args.t5_cpu, + ) + + logging.info("Generating video ...") + video = wan_flf2v.generate( + args.prompt, + first_frame, + last_frame, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model + ) if rank == 0: if args.save_file is None: diff --git a/gradio/fl2v_14B_singleGPU.py b/gradio/fl2v_14B_singleGPU.py new file mode 100644 index 0000000..476a136 --- /dev/null +++ b/gradio/fl2v_14B_singleGPU.py @@ -0,0 +1,252 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import argparse +import gc +import os.path as osp +import os +import sys +import warnings + +import gradio as gr + +warnings.filterwarnings('ignore') + +# Model +sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) +import wan +from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS +from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander +from wan.utils.utils import cache_video + +# Global Var +prompt_expander = None +wan_flf2v_720P = None + + +# Button Func +def load_model(value): + global wan_flf2v_720P + + if value == '------': + print("No model loaded") + return '------' + + if value == '720P': + if args.ckpt_dir_720p is None: + print("Please specify the checkpoint directory for 720P model") + return '------' + if wan_flf2v_720P is not None: + pass + else: + gc.collect() + + print("load 14B-720P flf2v model...", end='', flush=True) + cfg = WAN_CONFIGS['flf2v-14B'] + wan_flf2v_720P = wan.WanFLF2V( + config=cfg, + checkpoint_dir=args.ckpt_dir_720p, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_usp=False, + ) + print("done", flush=True) + return '720P' + return value + + +def prompt_enc(prompt, img_first, img_last, tar_lang): + print('prompt extend...') + if img_first is None or img_last is None: + print('Please upload the first and last frames') + return prompt + global prompt_expander + prompt_output = prompt_expander( + prompt, image=[img_first, img_last], tar_lang=tar_lang.lower()) + if prompt_output.status == False: + return prompt + else: + return prompt_output.prompt + + +def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps, + guide_scale, shift_scale, seed, n_prompt): + + if resolution == '------': + print( + 'Please specify the resolution ckpt dir or specify the resolution' + ) + return None + + else: + if resolution == '720P': + global wan_flf2v_720P + video = wan_flf2v_720P.generate( + flf2vid_prompt, + flf2vid_image_first, + flf2vid_image_last, + max_area=MAX_AREA_CONFIGS['720*1280'], + shift=shift_scale, + sampling_steps=sd_steps, + guide_scale=guide_scale, + n_prompt=n_prompt, + seed=seed, + offload_model=True) + pass + else: + print( + 'Sorry, currently only 720P is supported.' 
+ ) + return None + + cache_video( + tensor=video[None], + save_file="example.mp4", + fps=16, + nrow=1, + normalize=True, + value_range=(-1, 1)) + + return "example.mp4" + + +# Interface +def gradio_interface(): + with gr.Blocks() as demo: + gr.Markdown(""" +
+                    <div style="text-align: center;">
+                        Wan2.1 (FLF2V-14B)
+                    </div>
+                    <div style="text-align: center;">
+                        Wan: Open and Advanced Large-Scale Video Generative Models.
+                    </div>
+ """) + + with gr.Row(): + with gr.Column(): + resolution = gr.Dropdown( + label='Resolution', + choices=['------', '720P'], + value='------') + flf2vid_image_first = gr.Image( + type="pil", + label="Upload First Frame", + elem_id="image_upload", + ) + flf2vid_image_last = gr.Image( + type="pil", + label="Upload Last Frame", + elem_id="image_upload", + ) + flf2vid_prompt = gr.Textbox( + label="Prompt", + placeholder="Describe the video you want to generate", + ) + tar_lang = gr.Radio( + choices=["ZH", "EN"], + label="Target language of prompt enhance", + value="ZH") + run_p_button = gr.Button(value="Prompt Enhance") + + with gr.Accordion("Advanced Options", open=True): + with gr.Row(): + sd_steps = gr.Slider( + label="Diffusion steps", + minimum=1, + maximum=1000, + value=50, + step=1) + guide_scale = gr.Slider( + label="Guide scale", + minimum=0, + maximum=20, + value=5.0, + step=1) + with gr.Row(): + shift_scale = gr.Slider( + label="Shift scale", + minimum=0, + maximum=20, + value=5.0, + step=1) + seed = gr.Slider( + label="Seed", + minimum=-1, + maximum=2147483647, + step=1, + value=-1) + n_prompt = gr.Textbox( + label="Negative Prompt", + placeholder="Describe the negative prompt you want to add" + ) + + run_flf2v_button = gr.Button("Generate Video") + + with gr.Column(): + result_gallery = gr.Video( + label='Generated Video', interactive=False, height=600) + + resolution.input( + fn=load_model, inputs=[resolution], outputs=[resolution]) + + run_p_button.click( + fn=prompt_enc, + inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang], + outputs=[flf2vid_prompt]) + + run_flf2v_button.click( + fn=flf2v_generation, + inputs=[ + flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps, + guide_scale, shift_scale, seed, n_prompt + ], + outputs=[result_gallery], + ) + + return demo + + +# Main +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate a video from a text prompt or image using Gradio") + parser.add_argument( + "--ckpt_dir_720p", + type=str, + default=None, + help="The path to the checkpoint directory.") + parser.add_argument( + "--prompt_extend_method", + type=str, + default="local_qwen", + choices=["dashscope", "local_qwen"], + help="The prompt extend method to use.") + parser.add_argument( + "--prompt_extend_model", + type=str, + default=None, + help="The prompt extend model to use.") + + args = parser.parse_args() + assert args.ckpt_dir_720p is not None, "Please specify the checkpoint directory." + + return args + + +if __name__ == '__main__': + args = _parse_args() + + print("Step1: Init prompt_expander...", end='', flush=True) + if args.prompt_extend_method == "dashscope": + prompt_expander = DashScopePromptExpander( + model_name=args.prompt_extend_model, is_vl=True) + elif args.prompt_extend_method == "local_qwen": + prompt_expander = QwenPromptExpander( + model_name=args.prompt_extend_model, is_vl=True, device=0) + else: + raise NotImplementedError( + f"Unsupport prompt_extend_method: {args.prompt_extend_method}") + print("done", flush=True) + + demo = gradio_interface() + demo.launch(server_name="0.0.0.0", share=False, server_port=7860) diff --git a/wan/__init__.py b/wan/__init__.py index df36ebe..d6c25f4 100644 --- a/wan/__init__.py +++ b/wan/__init__.py @@ -1,3 +1,4 @@ from . 
import configs, distributed, modules from .image2video import WanI2V from .text2video import WanT2V +from .first_last_frame2video import WanFLF2V diff --git a/wan/configs/__init__.py b/wan/configs/__init__.py index c72d2d0..cccda2f 100644 --- a/wan/configs/__init__.py +++ b/wan/configs/__init__.py @@ -12,11 +12,17 @@ from .wan_t2v_14B import t2v_14B t2i_14B = copy.deepcopy(t2v_14B) t2i_14B.__name__ = 'Config: Wan T2I 14B' +# the config of flf2v_14B is the same as i2v_14B +flf2v_14B = copy.deepcopy(i2v_14B) +flf2v_14B.__name__ = 'Config: Wan FLF2V 14B' +flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt + WAN_CONFIGS = { 't2v-14B': t2v_14B, 't2v-1.3B': t2v_1_3B, 'i2v-14B': i2v_14B, 't2i-14B': t2i_14B, + 'flf2v-14B': flf2v_14B } SIZE_CONFIGS = { @@ -38,5 +44,6 @@ SUPPORTED_SIZES = { 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 't2v-1.3B': ('480*832', '832*480'), 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), + 'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 't2i-14B': tuple(SIZE_CONFIGS.keys()), } diff --git a/wan/configs/wan_i2v_14B.py b/wan/configs/wan_i2v_14B.py index 12e8e20..53bf221 100644 --- a/wan/configs/wan_i2v_14B.py +++ b/wan/configs/wan_i2v_14B.py @@ -8,6 +8,7 @@ from .shared_config import wan_shared_cfg i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') i2v_14B.update(wan_shared_cfg) +i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' i2v_14B.t5_tokenizer = 'google/umt5-xxl' diff --git a/wan/first_last_frame2video.py b/wan/first_last_frame2video.py new file mode 100644 index 0000000..4f300ca --- /dev/null +++ b/wan/first_last_frame2video.py @@ -0,0 +1,370 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial + +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +import torchvision.transforms.functional as TF +from tqdm import tqdm + +from .distributed.fsdp import shard_model +from .modules.clip import CLIPModel +from .modules.model import WanModel +from .modules.t5 import T5EncoderModel +from .modules.vae import WanVAE +from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps) +from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + + +class WanFLF2V: + + def __init__( + self, + config, + checkpoint_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_usp=False, + t5_cpu=False, + init_on_cpu=True, + ): + r""" + Initializes the image-to-video generation model components. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for T5 model + dit_fsdp (`bool`, *optional*, defaults to False): + Enable FSDP sharding for DiT model + use_usp (`bool`, *optional*, defaults to False): + Enable distribution strategy of USP. + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. 
+ init_on_cpu (`bool`, *optional*, defaults to True): + Enable initializing Transformer Model on CPU. Only works without FSDP or USP. + """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.use_usp = use_usp + self.t5_cpu = t5_cpu + + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + shard_fn = partial(shard_model, device_id=device_id) + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), + shard_fn=shard_fn if t5_fsdp else None, + ) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = WanVAE( + vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), + device=self.device) + + self.clip = CLIPModel( + dtype=config.clip_dtype, + device=self.device, + checkpoint_path=os.path.join(checkpoint_dir, + config.clip_checkpoint), + tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) + + logging.info(f"Creating WanModel from {checkpoint_dir}") + self.model = WanModel.from_pretrained(checkpoint_dir) + self.model.eval().requires_grad_(False) + + if t5_fsdp or dit_fsdp or use_usp: + init_on_cpu = False + + if use_usp: + from xfuser.core.distributed import \ + get_sequence_parallel_world_size + + from .distributed.xdit_context_parallel import (usp_attn_forward, + usp_dit_forward) + for block in self.model.blocks: + block.self_attn.forward = types.MethodType( + usp_attn_forward, block.self_attn) + self.model.forward = types.MethodType(usp_dit_forward, self.model) + self.sp_size = get_sequence_parallel_world_size() + else: + self.sp_size = 1 + + if dist.is_initialized(): + dist.barrier() + if dit_fsdp: + self.model = shard_fn(self.model) + else: + if not init_on_cpu: + self.model.to(self.device) + + self.sample_neg_prompt = config.sample_neg_prompt + + def generate(self, + input_prompt, + first_frame, + last_frame, + max_area=720 * 1280, + frame_num=81, + shift=16, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.5, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from input first-last frame and text prompt using diffusion process. + + Args: + input_prompt (`str`): + Text prompt for content generation. + first_frame (PIL.Image.Image): + Input image tensor. Shape: [3, H, W] + last_frame (PIL.Image.Image): + Input image tensor. Shape: [3, H, W] + [NOTE] If the sizes of first_frame and last_frame are mismatched, last_frame will be cropped & resized + to match first_frame. + max_area (`int`, *optional*, defaults to 720*1280): + Maximum pixel area for latent space calculation. Controls video resolution scaling + frame_num (`int`, *optional*, defaults to 81): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. 
creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from max_area) + - W: Frame width from max_area) + """ + first_frame_size = first_frame.size + last_frame_size = last_frame.size + first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(self.device) + last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(self.device) + + F = frame_num + first_frame_h, first_frame_w = first_frame.shape[1:] + aspect_ratio = first_frame_h / first_frame_w + lat_h = round( + np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // + self.patch_size[1] * self.patch_size[1]) + lat_w = round( + np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // + self.patch_size[2] * self.patch_size[2]) + first_frame_h = lat_h * self.vae_stride[1] + first_frame_w = lat_w * self.vae_stride[2] + if first_frame_size != last_frame_size: + # 1. resize + last_frame_resize_ratio = max( + first_frame_size[0] / last_frame_size[0], + first_frame_size[1] / last_frame_size[1] + ) + last_frame_size = [ + round(last_frame_size[0] * last_frame_resize_ratio), + round(last_frame_size[1] * last_frame_resize_ratio), + ] + # 2. center crop + last_frame = TF.center_crop(last_frame, last_frame_size) + + max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( + self.patch_size[1] * self.patch_size[2]) + max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size + + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + noise = torch.randn( + 16, + (F - 1) // 4 + 1, + lat_h, + lat_w, + dtype=torch.float32, + generator=seed_g, + device=self.device) + + msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) + msk[:, 1: -1] = 0 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + + # preprocess + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + context_null = self.text_encoder([n_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context_null = self.text_encoder([n_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + context_null = [t.to(self.device) for t in context_null] + + self.clip.model.to(self.device) + clip_context = self.clip.visual([first_frame[:, None, :, :], last_frame[:, None, :, :]]) + if offload_model: + self.clip.model.cpu() + + y = self.vae.encode([ + torch.concat([ + torch.nn.functional.interpolate( + first_frame[None].cpu(), + size=(first_frame_h, first_frame_w), + mode='bicubic' + ).transpose(0, 1), + torch.zeros(3, F - 2, first_frame_h, first_frame_w), + torch.nn.functional.interpolate( + last_frame[None].cpu(), + size=(first_frame_h, first_frame_w), + mode='bicubic' + ).transpose(0, 1), + ], dim=1).to(self.device) + ])[0] + y = 
torch.concat([msk, y]) + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latent = noise + + arg_c = { + 'context': [context[0]], + 'clip_fea': clip_context, + 'seq_len': max_seq_len, + 'y': [y], + } + + arg_null = { + 'context': context_null, + 'clip_fea': clip_context, + 'seq_len': max_seq_len, + 'y': [y], + } + + if offload_model: + torch.cuda.empty_cache() + + self.model.to(self.device) + for _, t in enumerate(tqdm(timesteps)): + latent_model_input = [latent.to(self.device)] + timestep = [t] + + timestep = torch.stack(timestep).to(self.device) + + noise_pred_cond = self.model( + latent_model_input, t=timestep, **arg_c)[0].to( + torch.device('cpu') if offload_model else self.device) + if offload_model: + torch.cuda.empty_cache() + noise_pred_uncond = self.model( + latent_model_input, t=timestep, **arg_null)[0].to( + torch.device('cpu') if offload_model else self.device) + if offload_model: + torch.cuda.empty_cache() + noise_pred = noise_pred_uncond + guide_scale * ( + noise_pred_cond - noise_pred_uncond) + + latent = latent.to( + torch.device('cpu') if offload_model else self.device) + + temp_x0 = sample_scheduler.step( + noise_pred.unsqueeze(0), + t, + latent.unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latent = temp_x0.squeeze(0) + + x0 = [latent.to(self.device)] + del latent_model_input, timestep + + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + + if self.rank == 0: + videos = self.vae.decode(x0) + + del noise, latent + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos[0] if self.rank == 0 else None diff --git a/wan/image2video.py b/wan/image2video.py index 468f17c..5004f46 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -197,7 +197,7 @@ class WanI2V: seed_g.manual_seed(seed) noise = torch.randn( 16, - 21, + (F - 1) // 4 + 1, lat_h, lat_w, dtype=torch.float32, @@ -239,7 +239,7 @@ class WanI2V: torch.nn.functional.interpolate( img[None].cpu(), size=(h, w), mode='bicubic').transpose( 0, 1), - torch.zeros(3, 80, h, w) + torch.zeros(3, F - 1, h, w) ], dim=1).to(self.device) ])[0] diff --git a/wan/modules/model.py b/wan/modules/model.py index b65021c..7c6bddb 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -11,6 +11,9 @@ from .attention import flash_attention __all__ = ['WanModel'] +T5_CONTEXT_TOKEN_NUMBER = 512 +FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2 + def sinusoidal_embedding_1d(dim, position): # preprocess @@ -203,8 +206,9 @@ class WanI2VCrossAttention(WanSelfAttention): context(Tensor): Shape [B, L2, C] context_lens(Tensor): Shape [B] """ - context_img = 
context[:, :257] - context = context[:, 257:] + image_context_length = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER + context_img = context[:, :image_context_length] + context = context[:, image_context_length:] b, n, d = x.size(0), self.num_heads, self.head_dim # compute query, key, value @@ -269,7 +273,7 @@ class WanAttentionBlock(nn.Module): nn.Linear(ffn_dim, dim)) # modulation - self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5) def forward( self, @@ -328,7 +332,7 @@ class Head(nn.Module): self.head = nn.Linear(dim, out_dim) # modulation - self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5) def forward(self, x, e): r""" @@ -345,15 +349,21 @@ class Head(nn.Module): class MLPProj(torch.nn.Module): - def __init__(self, in_dim, out_dim): + def __init__(self, in_dim, out_dim, flf_pos_emb=False): super().__init__() self.proj = torch.nn.Sequential( torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim), torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim), torch.nn.LayerNorm(out_dim)) + if flf_pos_emb: # NOTE: we only use this for `flf2v` + self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280)) def forward(self, image_embeds): + if hasattr(self, 'emb_pos'): + bs, n, d = image_embeds.shape + image_embeds = image_embeds.view(-1, 2 * n, d) + image_embeds = image_embeds + self.emb_pos clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens @@ -390,7 +400,7 @@ class WanModel(ModelMixin, ConfigMixin): Args: model_type (`str`, *optional*, defaults to 't2v'): - Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) text_len (`int`, *optional*, defaults to 512): @@ -423,7 +433,7 @@ class WanModel(ModelMixin, ConfigMixin): super().__init__() - assert model_type in ['t2v', 'i2v'] + assert model_type in ['t2v', 'i2v', 'flf2v'] self.model_type = model_type self.patch_size = patch_size @@ -473,8 +483,8 @@ class WanModel(ModelMixin, ConfigMixin): ], dim=1) - if model_type == 'i2v': - self.img_emb = MLPProj(1280, dim) + if model_type == 'i2v' or model_type == 'flf2v': + self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v') # initialize weights self.init_weights() @@ -501,7 +511,7 @@ class WanModel(ModelMixin, ConfigMixin): seq_len (`int`): Maximum sequence length for positional encoding clip_fea (Tensor, *optional*): - CLIP image features for image-to-video mode + CLIP image features for image-to-video mode or first-last-frame-to-video mode y (List[Tensor], *optional*): Conditional video inputs for image-to-video mode, same shape as x @@ -509,7 +519,7 @@ class WanModel(ModelMixin, ConfigMixin): List[Tensor]: List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] """ - if self.model_type == 'i2v': + if self.model_type == 'i2v' or self.model_type == 'flf2v': assert clip_fea is not None and y is not None # params device = self.patch_embedding.weight.device @@ -548,7 +558,7 @@ class WanModel(ModelMixin, ConfigMixin): ])) if clip_fea is not None: - context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context_clip = self.img_emb(clip_fea) # bs x 257 (x2) x dim context = torch.concat([context_clip, 
context], dim=1) # arguments diff --git a/wan/utils/prompt_extend.py b/wan/utils/prompt_extend.py index f3981a8..5e3a216 100644 --- a/wan/utils/prompt_extend.py +++ b/wan/utils/prompt_extend.py @@ -7,7 +7,7 @@ import sys import tempfile from dataclasses import dataclass from http import HTTPStatus -from typing import Optional, Union +from typing import Optional, Union, List import dashscope import torch @@ -97,6 +97,59 @@ VL_EN_SYS_PROMPT = \ '''Directly output the rewritten English text.''' +VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES = """你是一位Prompt优化师,旨在参考用户输入的图像的细节内容,把用户输入的Prompt改写为优质Prompt,使其更完整、更具表现力,同时不改变原意。你需要综合用户输入的照片内容和输入的Prompt进行改写,严格参考示例的格式进行改写 +任务要求: +1. 用户会输入两张图片,第一张是视频的第一帧,第二张时视频的最后一帧,你需要综合两个照片的内容进行优化改写 +2. 对于过于简短的用户输入,在不改变原意前提下,合理推断并补充细节,使得画面更加完整好看; +3. 完善用户描述中出现的主体特征(如外貌、表情,数量、种族、姿态等)、画面风格、空间关系、镜头景别; +4. 整体中文输出,保留引号、书名号中原文以及重要的输入信息,不要改写; +5. Prompt应匹配符合用户意图且精准细分的风格描述。如果用户未指定,则根据用户提供的照片的风格,你需要仔细分析照片的风格,并参考风格进行改写。 +6. 如果Prompt是古诗词,应该在生成的Prompt中强调中国古典元素,避免出现西方、现代、外国场景; +7. 你需要强调输入中的运动信息和不同的镜头运镜; +8. 你的输出应当带有自然运动属性,需要根据描述主体目标类别增加这个目标的自然动作,描述尽可能用简单直接的动词; +9. 你需要尽可能的参考图片的细节信息,如人物动作、服装、背景等,强调照片的细节元素; +10. 你需要强调两画面可能出现的潜在变化,如“走进”,“出现”,“变身成”,“镜头左移”,“镜头右移动”,“镜头上移动”, “镜头下移”等等; +11. 无论用户输入那种语言,你都需要输出中文; +12. 改写后的prompt字数控制在80-100字左右; +改写后 prompt 示例: +1. 日系小清新胶片写真,扎着双麻花辫的年轻东亚女孩坐在船边。女孩穿着白色方领泡泡袖连衣裙,裙子上有褶皱和纽扣装饰。她皮肤白皙,五官清秀,眼神略带忧郁,直视镜头。女孩的头发自然垂落,刘海遮住部分额头。她双手扶船,姿态自然放松。背景是模糊的户外场景,隐约可见蓝天、山峦和一些干枯植物。复古胶片质感照片。中景半身坐姿人像。 +2. 二次元厚涂动漫插画,一个猫耳兽耳白人少女手持文件夹,神情略带不满。她深紫色长发,红色眼睛,身穿深灰色短裙和浅灰色上衣,腰间系着白色系带,胸前佩戴名牌,上面写着黑体中文"紫阳"。淡黄色调室内背景,隐约可见一些家具轮廓。少女头顶有一个粉色光圈。线条流畅的日系赛璐璐风格。近景半身略俯视视角。 +3. CG游戏概念数字艺术,一只巨大的鳄鱼张开大嘴,背上长着树木和荆棘。鳄鱼皮肤粗糙,呈灰白色,像是石头或木头的质感。它背上生长着茂盛的树木、灌木和一些荆棘状的突起。鳄鱼嘴巴大张,露出粉红色的舌头和锋利的牙齿。画面背景是黄昏的天空,远处有一些树木。场景整体暗黑阴冷。近景,仰视视角。 +4. 美剧宣传海报风格,身穿黄色防护服的Walter White坐在金属折叠椅上,上方无衬线英文写着"Breaking Bad",周围是成堆的美元和蓝色塑料储物箱。他戴着眼镜目光直视前方,身穿黄色连体防护服,双手放在膝盖上,神态稳重自信。背景是一个废弃的阴暗厂房,窗户透着光线。带有明显颗粒质感纹理。中景,镜头下移。 +请直接输出改写后的文本,不要进行多余的回复。""" + +VL_EN_SYS_PROMPT_FOR_MULTI_IMAGES = \ + '''You are a prompt optimization specialist whose goal is to rewrite the user's input prompts into high-quality English prompts by referring to the details of the user's input images, making them more complete and expressive while maintaining the original meaning. You need to integrate the content of the user's photo with the input prompt for the rewrite, strictly adhering to the formatting of the examples provided.\n''' \ + '''Task Requirements:\n''' \ + '''1. The user will input two images, the first is the first frame of the video, and the second is the last frame of the video. You need to integrate the content of the two photos with the input prompt for the rewrite.\n''' \ + '''2. For overly brief user inputs, reasonably infer and supplement details without changing the original meaning, making the image more complete and visually appealing;\n''' \ + '''3. Improve the characteristics of the main subject in the user's description (such as appearance, expression, quantity, ethnicity, posture, etc.), rendering style, spatial relationships, and camera angles;\n''' \ + '''4. The overall output should be in Chinese, retaining original text in quotes and book titles as well as important input information without rewriting them;\n''' \ + '''5. The prompt should match the user’s intent and provide a precise and detailed style description. If the user has not specified a style, you need to carefully analyze the style of the user's provided photo and use that as a reference for rewriting;\n''' \ + '''6. 
If the prompt is an ancient poem, classical Chinese elements should be emphasized in the generated prompt, avoiding references to Western, modern, or foreign scenes;\n''' \ + '''7. You need to emphasize movement information in the input and different camera angles;\n''' \ + '''8. Your output should convey natural movement attributes, incorporating natural actions related to the described subject category, using simple and direct verbs as much as possible;\n''' \ + '''9. You should reference the detailed information in the image, such as character actions, clothing, backgrounds, and emphasize the details in the photo;\n''' \ + '''10. You need to emphasize potential changes that may occur between the two frames, such as "walking into", "appearing", "turning into", "camera left", "camera right", "camera up", "camera down", etc.;\n''' \ + '''11. Control the rewritten prompt to around 80-100 words.\n''' \ + '''12. No matter what language the user inputs, you must always output in English.\n''' \ + '''Example of the rewritten English prompt:\n''' \ + '''1. A Japanese fresh film-style photo of a young East Asian girl with double braids sitting by the boat. The girl wears a white square collar puff sleeve dress, decorated with pleats and buttons. She has fair skin, delicate features, and slightly melancholic eyes, staring directly at the camera. Her hair falls naturally, with bangs covering part of her forehead. She rests her hands on the boat, appearing natural and relaxed. The background features a blurred outdoor scene, with hints of blue sky, mountains, and some dry plants. The photo has a vintage film texture. A medium shot of a seated portrait.\n''' \ + '''2. An anime illustration in vibrant thick painting style of a white girl with cat ears holding a folder, showing a slightly dissatisfied expression. She has long dark purple hair and red eyes, wearing a dark gray skirt and a light gray top with a white waist tie and a name tag in bold Chinese characters that says "紫阳" (Ziyang). The background has a light yellow indoor tone, with faint outlines of some furniture visible. A pink halo hovers above her head, in a smooth Japanese cel-shading style. A close-up shot from a slightly elevated perspective.\n''' \ + '''3. CG game concept digital art featuring a huge crocodile with its mouth wide open, with trees and thorns growing on its back. The crocodile's skin is rough and grayish-white, resembling stone or wood texture. Its back is lush with trees, shrubs, and thorny protrusions. With its mouth agape, the crocodile reveals a pink tongue and sharp teeth. The background features a dusk sky with some distant trees, giving the overall scene a dark and cold atmosphere. A close-up from a low angle.\n''' \ + '''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There’s a noticeable grainy texture. 
A medium shot with a straight-on close-up of the character.\n''' \ + '''Directly output the rewritten English text.''' + +SYSTEM_PROMPT_TYPES = { + int(b'000', 2): LM_EN_SYS_PROMPT, + int(b'001', 2): LM_ZH_SYS_PROMPT, + int(b'010', 2): VL_EN_SYS_PROMPT, + int(b'011', 2): VL_ZH_SYS_PROMPT, + int(b'110', 2): VL_EN_SYS_PROMPT_FOR_MULTI_IMAGES, + int(b'111', 2): VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES +} + + @dataclass class PromptOutput(object): status: bool @@ -128,12 +181,11 @@ class PromptExpander: def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs): pass - def decide_system_prompt(self, tar_lang="zh"): + def decide_system_prompt(self, tar_lang="zh", multi_images_input=False): zh = tar_lang == "zh" - if zh: - return LM_ZH_SYS_PROMPT if not self.is_vl else VL_ZH_SYS_PROMPT - else: - return LM_EN_SYS_PROMPT if not self.is_vl else VL_EN_SYS_PROMPT + self.is_vl |= multi_images_input + task_type = zh + (self.is_vl << 1) + (multi_images_input << 2) + return SYSTEM_PROMPT_TYPES[task_type] def __call__(self, prompt, @@ -144,7 +196,10 @@ class PromptExpander: *args, **kwargs): if system_prompt is None: - system_prompt = self.decide_system_prompt(tar_lang=tar_lang) + system_prompt = self.decide_system_prompt( + tar_lang=tar_lang, + multi_images_input=isinstance(image, (list, tuple)) and len(image) > 1 + ) if seed < 0: seed = random.randint(0, sys.maxsize) if image is not None and self.is_vl: @@ -234,38 +289,42 @@ class DashScopePromptExpander(PromptExpander): def extend_with_img(self, prompt, system_prompt, - image: Union[Image.Image, str] = None, + image: Union[List[Image.Image], List[str], Image.Image, str] = None, seed=-1, *args, **kwargs): - if isinstance(image, str): - image = Image.open(image).convert('RGB') - w = image.width - h = image.height - area = min(w * h, self.max_image_size) - aspect_ratio = h / w - resized_h = round(math.sqrt(area * aspect_ratio)) - resized_w = round(math.sqrt(area / aspect_ratio)) - image = image.resize((resized_w, resized_h)) - with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: - image.save(f.name) - fname = f.name - image_path = f"file://{f.name}" + + def ensure_image(_image): + if isinstance(_image, str): + _image = Image.open(_image).convert('RGB') + w = _image.width + h = _image.height + area = min(w * h, self.max_image_size) + aspect_ratio = h / w + resized_h = round(math.sqrt(area * aspect_ratio)) + resized_w = round(math.sqrt(area / aspect_ratio)) + _image = _image.resize((resized_w, resized_h)) + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + _image.save(f.name) + image_path = f"file://{f.name}" + return image_path + if not isinstance(image, (list, tuple)): + image = [image] + image_path_list = [ensure_image(_image) for _image in image] + role_content = [ + {"text": prompt}, + *[{"image": image_path} for image_path in image_path_list] + ] + system_content = [{"text": system_prompt}] prompt = f"{prompt}" messages = [ { 'role': 'system', - 'content': [{ - "text": system_prompt - }] + 'content': system_content }, { 'role': 'user', - 'content': [{ - "text": prompt - }, { - "image": image_path - }] + 'content': role_content }, ] response = None @@ -288,7 +347,8 @@ class DashScopePromptExpander(PromptExpander): except Exception as e: exception = e result_prompt = result_prompt.replace('\n', '\\n') - os.remove(fname) + for image_path in image_path_list: + os.remove(image_path.removeprefix('file://')) return PromptOutput( status=status, @@ -399,30 +459,36 @@ class QwenPromptExpander(PromptExpander): def 
extend_with_img(self, prompt, system_prompt, - image: Union[Image.Image, str] = None, + image: Union[List[Image.Image], List[str], Image.Image, str] = None, seed=-1, *args, **kwargs): self.model = self.model.to(self.device) + + if not isinstance(image, (list, tuple)): + image = [image] + + system_content = [{ + "type": "text", + "text": system_prompt + }] + role_content = [ + { + "type": "text", + "text": prompt + }, + *[ + {"image": image_path} for image_path in image + ] + ] + messages = [{ 'role': 'system', - 'content': [{ - "type": "text", - "text": system_prompt - }] + 'content': system_content, }, { "role": "user", - "content": [ - { - "type": "image", - "image": image, - }, - { - "type": "text", - "text": prompt - }, - ], + "content": role_content, }] # Preparation for inference @@ -502,7 +568,8 @@ if __name__ == "__main__": # test case for prompt-image extend ds_model_name = "qwen-vl-max" #qwen_model_name = "./models/Qwen2.5-VL-3B-Instruct/" #VRAM: 9686MiB - qwen_model_name = "./models/Qwen2.5-VL-7B-Instruct-AWQ/" # VRAM: 8492 + # qwen_model_name = "./models/Qwen2.5-VL-7B-Instruct-AWQ/" # VRAM: 8492 + qwen_model_name = "./models/Qwen2.5-VL-7B-Instruct/" image = "./examples/i2v_input.JPG" # test dashscope api why image_path is local directory; skip @@ -543,3 +610,26 @@ if __name__ == "__main__": en_prompt, tar_lang="en", image=image, seed=seed) print("VL qwen vl en result -> en", qwen_result.prompt) # , qwen_result.system_prompt) + # test multi images + image = ["./examples/flf2v_input_first_frame.png", "./examples/flf2v_input_last_frame.png"] + prompt = "无人机拍摄,镜头快速推进,然后拉远至全景俯瞰,展示一个宁静美丽的海港。海港内停满了游艇,水面清澈透蓝。周围是起伏的山丘和错落有致的建筑,整体景色宁静而美丽。" + en_prompt = ("Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic " + "aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts " + "resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced " + "architectural structures, combining to create a tranquil and breathtaking coastal landscape.") + + dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True) + dashscope_result = dashscope_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed) + print("VL dashscope result -> zh", dashscope_result.prompt) + + dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True) + dashscope_result = dashscope_prompt_expander(en_prompt, tar_lang="zh", image=image, seed=seed) + print("VL dashscope en result -> zh", dashscope_result.prompt) + + qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0) + qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed) + print("VL qwen result -> zh", qwen_result.prompt) + + qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0) + qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed) + print("VL qwen en result -> zh", qwen_result.prompt)
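
Note on the prompt-extension change at the end of the patch: `decide_system_prompt` now keys its system prompts by a 3-bit code, where bit 0 is the target language (`zh`), bit 1 is whether a vision-language model is used (`is_vl`), and bit 2 is whether two images (first and last frame) are passed in. Multi-image input forces `is_vl` on, so only the six keys defined in `SYSTEM_PROMPT_TYPES` can occur. Below is a standalone sketch of that lookup for reference (illustrative only, not part of the patch; `system_prompt_key` is a hypothetical helper mirroring the in-repo logic).

```python
# Mirror of the key computation in decide_system_prompt() from wan/utils/prompt_extend.py.
# Multi-image input implies a VL model, so codes 100 and 101 never occur.
def system_prompt_key(zh: bool, is_vl: bool, multi_images_input: bool) -> int:
    is_vl = is_vl or multi_images_input
    return zh + (is_vl << 1) + (multi_images_input << 2)

assert system_prompt_key(zh=True, is_vl=True, multi_images_input=True) == int(b'111', 2)    # VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES
assert system_prompt_key(zh=False, is_vl=True, multi_images_input=True) == int(b'110', 2)   # VL_EN_SYS_PROMPT_FOR_MULTI_IMAGES
assert system_prompt_key(zh=True, is_vl=False, multi_images_input=False) == int(b'001', 2)  # LM_ZH_SYS_PROMPT
```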