add [first-last frame to video] feature

澎鹏 2025-04-17 10:42:10 +08:00
parent 679ccc6c68
commit ae4c0c9aa5
12 changed files with 961 additions and 72 deletions

View File

@ -27,6 +27,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* Apr 17, 2025: 👋 We introduce **Wan2.1** [FLF2V](#run-first-last-frame-to-video-generation) with its inference code and weights!
* Mar 21, 2025: 👋 We are excited to announce the release of the **Wan2.1** [technical report](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf). We welcome discussions and feedback!
* Mar 3, 2025: 👋 **Wan2.1**'s T2V and I2V have been integrated into Diffusers ([T2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanPipeline) | [I2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanImageToVideoPipeline)). Feel free to give it a try!
* Feb 27, 2025: 👋 **Wan2.1** has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
@ -54,6 +55,13 @@ If your work has improved **Wan2.1** and you would like more people to see it, p
- [x] ComfyUI integration
- [x] Diffusers integration
- [ ] Diffusers + Multi-GPU Inference
- Wan2.1 First-Last-Frame-to-Video
- [x] Multi-GPU Inference code of the 14B model
- [x] Checkpoints of the 14B model
- [x] Gradio demo
- [ ] ComfyUI integration
- [ ] Diffusers integration
- [ ] Diffusers + Multi-GPU Inference
## Quickstart
@ -74,14 +82,15 @@ pip install -r requirements.txt
#### Model Download
| Models | Download Link | Notes |
| --------------|-------------------------------------------------------------------------------|-------------------------------|
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
| Models | Download Link | Notes |
|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
| FLF2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B) | Supports 720P
> 💡Note: The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution.
> 💡Note: The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution. For first-last-frame-to-video generation, the model is trained primarily on Chinese text-video pairs, so we recommend using Chinese prompts for better results.
Download models using huggingface-cli:
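For example, to fetch the new FLF2V checkpoint (a sketch: the repo id follows the table above, and the local directory is chosen to match the `--ckpt_dir` used in the commands below):
```sh
pip install "huggingface_hub[cli]"
huggingface-cli download Wan-AI/Wan2.1-FLF2V-14B --local-dir ./Wan2.1-FLF2V-14B-720P
```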
@ -185,7 +194,7 @@ DASH_API_KEY=your_key python generate.py --task t2v-14B --size 1280*720 --ckpt_
- By default, the Qwen model on HuggingFace is used for this extension. Users can choose Qwen models or other models based on the available GPU memory size.
- For text-to-video tasks, you can use models like `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-7B-Instruct` and `Qwen/Qwen2.5-3B-Instruct`.
- For image-to-video tasks, you can use models like `Qwen/Qwen2.5-VL-7B-Instruct` and `Qwen/Qwen2.5-VL-3B-Instruct`.
- For image-to-video or first-last-frame-to-video tasks, you can use models like `Qwen/Qwen2.5-VL-7B-Instruct` and `Qwen/Qwen2.5-VL-3B-Instruct`.
- Larger models generally provide better extension results but require more GPU memory.
- You can modify the model used for extension with the parameter `--prompt_extend_model` , allowing you to specify either a local model path or a Hugging Face model. For example:
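For instance, a sketch for the first-last-frame task that passes a locally downloaded Qwen2.5-VL model by path (paths are illustrative; a Hub id such as `Qwen/Qwen2.5-VL-7B-Instruct` works the same way):
```sh
python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P \
  --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png \
  --use_prompt_extend --prompt_extend_model ./models/Qwen2.5-VL-7B-Instruct \
  --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings."
```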
@ -367,6 +376,74 @@ DASH_API_KEY=your_key python i2v_14B_singleGPU.py --prompt_extend_method 'dashsc
```
#### Run First-Last-Frame-to-Video Generation
First-Last-Frame-to-Video generation can likewise be run with or without the prompt extension step. Currently, only 720P is supported. The specific parameters and corresponding settings are as follows:
<table>
<thead>
<tr>
<th rowspan="2">Task</th>
<th colspan="2">Resolution</th>
<th rowspan="2">Model</th>
</tr>
<tr>
<th>480P</th>
<th>720P</th>
</tr>
</thead>
<tbody>
<tr>
<td>flf2v-14B</td>
<td style="color: green;"></td>
<td style="color: green;">✔️</td>
<td>Wan2.1-FLF2V-14B-720P</td>
</tr>
</tbody>
</table>
##### (1) Without Prompt Extension
- Single-GPU inference
```sh
python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
```
> 💡Similar to Image-to-Video, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
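To make this concrete, here is a minimal sketch of how the generated resolution follows from the `--size` area budget and the first frame's aspect ratio, mirroring the latent-size computation in `wan/first_last_frame2video.py` (the spatial VAE stride of 8 and patch size of 2 are assumed defaults):
```python
import math

def output_hw(max_area, frame_h, frame_w, vae_stride_hw=(8, 8), patch_hw=(2, 2)):
    # Assumed defaults: spatial VAE stride 8 and DiT patch size 2 (not read from a config here).
    aspect_ratio = frame_h / frame_w
    lat_h = round(math.sqrt(max_area * aspect_ratio) // vae_stride_hw[0] //
                  patch_hw[0] * patch_hw[0])
    lat_w = round(math.sqrt(max_area / aspect_ratio) // vae_stride_hw[1] //
                  patch_hw[1] * patch_hw[1])
    return lat_h * vae_stride_hw[0], lat_w * vae_stride_hw[1]

# --size 1280*720 is an area budget of 1280*720 pixels; a 1024x768 first frame then yields:
print(output_hw(1280 * 720, frame_h=768, frame_w=1024))  # (816, 1104)
```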
- Multi-GPU inference using FSDP + xDiT USP
```sh
pip install "xfuser>=0.4.1"
torchrun --nproc_per_node=8 generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --dit_fsdp --t5_fsdp --ulysses_size 8 --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
```
##### (2) Using Prompt Extension
The process of prompt extension can be referenced [here](#2-using-prompt-extention).
Run with local prompt extension using `Qwen/Qwen2.5-VL-7B-Instruct`:
```sh
python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --use_prompt_extend --prompt_extend_model Qwen/Qwen2.5-VL-7B-Instruct --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
```
Run with remote prompt extension using `dashscope`:
```sh
DASH_API_KEY=your_key python generate.py --task flf2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-FLF2V-14B-720P --first_frame examples/flf2v_input_first_frame.png --last_frame examples/flf2v_input_last_frame.png --use_prompt_extend --prompt_extend_method 'dashscope' --prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
```
##### (3) Running local gradio
```sh
cd gradio
# use 720P model in gradio
DASH_API_KEY=your_key python flf2v_14B_singleGPU.py --prompt_extend_method 'dashscope' --ckpt_dir_720p ./Wan2.1-FLF2V-14B-720P
```
#### Run Text-to-Image Generation
Wan2.1 is a unified model for both image and video generation. Since it was trained on both types of data, it can also generate images. The command for generating images is similar to video generation, as follows:
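For instance, a sketch reusing the T2V-14B checkpoint (the task name and size follow the configs added in this commit; the prompt and checkpoint path are illustrative):
```sh
python generate.py --task t2i-14B --size 1280*720 --ckpt_dir ./Wan2.1-T2V-14B --prompt "A snow-covered mountain village at dusk, warm lanterns along a stone path"
```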

Binary file not shown (added; 1.7 MiB)

Binary file not shown (added; 1.4 MiB)

View File

@ -33,6 +33,14 @@ EXAMPLE_PROMPT = {
"image":
"examples/i2v_input.JPG",
},
"flf2v-14B": {
"prompt":
"CG动画风格一只蓝色的小鸟从地面起飞煽动翅膀。小鸟羽毛细腻胸前有独特的花纹背景是蓝天白云阳光明媚。镜跟随小鸟向上移动展现出小鸟飞翔的姿态和天空的广阔。近景仰视视角。",
"first_frame":
"examples/flf2v_input_first_frame.png",
"last_frame":
"examples/flf2v_input_last_frame.png",
},
}
@ -50,6 +58,8 @@ def _validate_args(args):
args.sample_shift = 5.0
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
args.sample_shift = 3.0
if "flf2v" in args.task:
args.sample_shift = 16
# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
if args.frame_num is None:
@ -167,7 +177,17 @@ def _parse_args():
"--image",
type=str,
default=None,
help="The image to generate the video from.")
help="[image to video] The image to generate the video from.")
parser.add_argument(
"--first_frame",
type=str,
default=None,
help="[first-last frame to video] The image (first frame) to generate the video from.")
parser.add_argument(
"--last_frame",
type=str,
default=None,
help="[first-last frame to video] The image (last frame) to generate the video from.")
parser.add_argument(
"--sample_solver",
type=str,
@ -248,7 +268,7 @@ def generate(args):
if args.use_prompt_extend:
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl="i2v" in args.task)
model_name=args.prompt_extend_model, is_vl="i2v" in args.task or "flf2v" in args.task)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model,
@ -321,7 +341,7 @@ def generate(args):
seed=args.base_seed,
offload_model=args.offload_model)
else:
elif "i2v" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.image is None:
@ -377,6 +397,66 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.first_frame is None or args.last_frame is None:
args.first_frame = EXAMPLE_PROMPT[args.task]["first_frame"]
args.last_frame = EXAMPLE_PROMPT[args.task]["last_frame"]
logging.info(f"Input prompt: {args.prompt}")
logging.info(f"Input first frame: {args.first_frame}")
logging.info(f"Input last frame: {args.last_frame}")
first_frame = Image.open(args.first_frame).convert("RGB")
last_frame = Image.open(args.last_frame).convert("RGB")
if args.use_prompt_extend:
logging.info("Extending prompt ...")
if rank == 0:
prompt_output = prompt_expander(
args.prompt,
tar_lang=args.prompt_extend_target_lang,
image=[first_frame, last_frame],
seed=args.base_seed)
if prompt_output.status == False:
logging.info(
f"Extending prompt failed: {prompt_output.message}")
logging.info("Falling back to original prompt.")
input_prompt = args.prompt
else:
input_prompt = prompt_output.prompt
input_prompt = [input_prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating WanFLF2V pipeline.")
wan_flf2v = wan.WanFLF2V(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
logging.info("Generating video ...")
video = wan_flf2v.generate(
args.prompt,
first_frame,
last_frame,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model
)
if rank == 0:
if args.save_file is None:

View File

@ -0,0 +1,252 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import gc
import os.path as osp
import os
import sys
import warnings
import gradio as gr
warnings.filterwarnings('ignore')
# Model
sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video
# Global Var
prompt_expander = None
wan_flf2v_720P = None
# Button Func
def load_model(value):
global wan_flf2v_720P
if value == '------':
print("No model loaded")
return '------'
if value == '720P':
if args.ckpt_dir_720p is None:
print("Please specify the checkpoint directory for 720P model")
return '------'
if wan_flf2v_720P is not None:
pass
else:
gc.collect()
print("load 14B-720P flf2v model...", end='', flush=True)
cfg = WAN_CONFIGS['flf2v-14B']
wan_flf2v_720P = wan.WanFLF2V(
config=cfg,
checkpoint_dir=args.ckpt_dir_720p,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
print("done", flush=True)
return '720P'
return value
def prompt_enc(prompt, img_first, img_last, tar_lang):
print('prompt extend...')
if img_first is None or img_last is None:
print('Please upload the first and last frames')
return prompt
global prompt_expander
prompt_output = prompt_expander(
prompt, image=[img_first, img_last], tar_lang=tar_lang.lower())
if prompt_output.status == False:
return prompt
else:
return prompt_output.prompt
def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps,
guide_scale, shift_scale, seed, n_prompt):
if resolution == '------':
print(
'Please specify the resolution ckpt dir or specify the resolution'
)
return None
else:
if resolution == '720P':
global wan_flf2v_720P
video = wan_flf2v_720P.generate(
flf2vid_prompt,
flf2vid_image_first,
flf2vid_image_last,
max_area=MAX_AREA_CONFIGS['720*1280'],
shift=shift_scale,
sampling_steps=sd_steps,
guide_scale=guide_scale,
n_prompt=n_prompt,
seed=seed,
offload_model=True)
pass
else:
print(
'Sorry, currently only 720P is supported.'
)
return None
cache_video(
tensor=video[None],
save_file="example.mp4",
fps=16,
nrow=1,
normalize=True,
value_range=(-1, 1))
return "example.mp4"
# Interface
def gradio_interface():
with gr.Blocks() as demo:
gr.Markdown("""
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (FLF2V-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
""")
with gr.Row():
with gr.Column():
resolution = gr.Dropdown(
label='Resolution',
choices=['------', '720P'],
value='------')
flf2vid_image_first = gr.Image(
type="pil",
label="Upload First Frame",
elem_id="image_upload",
)
flf2vid_image_last = gr.Image(
type="pil",
label="Upload Last Frame",
elem_id="image_upload",
)
flf2vid_prompt = gr.Textbox(
label="Prompt",
placeholder="Describe the video you want to generate",
)
tar_lang = gr.Radio(
choices=["ZH", "EN"],
label="Target language of prompt enhance",
value="ZH")
run_p_button = gr.Button(value="Prompt Enhance")
with gr.Accordion("Advanced Options", open=True):
with gr.Row():
sd_steps = gr.Slider(
label="Diffusion steps",
minimum=1,
maximum=1000,
value=50,
step=1)
guide_scale = gr.Slider(
label="Guide scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
with gr.Row():
shift_scale = gr.Slider(
label="Shift scale",
minimum=0,
maximum=20,
value=5.0,
step=1)
seed = gr.Slider(
label="Seed",
minimum=-1,
maximum=2147483647,
step=1,
value=-1)
n_prompt = gr.Textbox(
label="Negative Prompt",
placeholder="Describe the negative prompt you want to add"
)
run_flf2v_button = gr.Button("Generate Video")
with gr.Column():
result_gallery = gr.Video(
label='Generated Video', interactive=False, height=600)
resolution.input(
fn=load_model, inputs=[resolution], outputs=[resolution])
run_p_button.click(
fn=prompt_enc,
inputs=[flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, tar_lang],
outputs=[flf2vid_prompt])
run_flf2v_button.click(
fn=flf2v_generation,
inputs=[
flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, resolution, sd_steps,
guide_scale, shift_scale, seed, n_prompt
],
outputs=[result_gallery],
)
return demo
# Main
def _parse_args():
parser = argparse.ArgumentParser(
description="Generate a video from a text prompt or image using Gradio")
parser.add_argument(
"--ckpt_dir_720p",
type=str,
default=None,
help="The path to the checkpoint directory.")
parser.add_argument(
"--prompt_extend_method",
type=str,
default="local_qwen",
choices=["dashscope", "local_qwen"],
help="The prompt extend method to use.")
parser.add_argument(
"--prompt_extend_model",
type=str,
default=None,
help="The prompt extend model to use.")
args = parser.parse_args()
assert args.ckpt_dir_720p is not None, "Please specify the checkpoint directory."
return args
if __name__ == '__main__':
args = _parse_args()
print("Step1: Init prompt_expander...", end='', flush=True)
if args.prompt_extend_method == "dashscope":
prompt_expander = DashScopePromptExpander(
model_name=args.prompt_extend_model, is_vl=True)
elif args.prompt_extend_method == "local_qwen":
prompt_expander = QwenPromptExpander(
model_name=args.prompt_extend_model, is_vl=True, device=0)
else:
raise NotImplementedError(
f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
print("done", flush=True)
demo = gradio_interface()
demo.launch(server_name="0.0.0.0", share=False, server_port=7860)

View File

@ -1,3 +1,4 @@
from . import configs, distributed, modules
from .image2video import WanI2V
from .text2video import WanT2V
from .first_last_frame2video import WanFLF2V

View File

@ -12,11 +12,17 @@ from .wan_t2v_14B import t2v_14B
t2i_14B = copy.deepcopy(t2v_14B)
t2i_14B.__name__ = 'Config: Wan T2I 14B'
# the config of flf2v_14B is the same as i2v_14B
flf2v_14B = copy.deepcopy(i2v_14B)
flf2v_14B.__name__ = 'Config: Wan FLF2V 14B'
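# Prepend "镜头切换" ("shot cut / scene change") to the shared negative prompt for FLF2V.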
flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt
WAN_CONFIGS = {
't2v-14B': t2v_14B,
't2v-1.3B': t2v_1_3B,
'i2v-14B': i2v_14B,
't2i-14B': t2i_14B,
'flf2v-14B': flf2v_14B
}
SIZE_CONFIGS = {
@ -38,5 +44,6 @@ SUPPORTED_SIZES = {
't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2v-1.3B': ('480*832', '832*480'),
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2i-14B': tuple(SIZE_CONFIGS.keys()),
}

View File

@ -8,6 +8,7 @@ from .shared_config import wan_shared_cfg
i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
i2v_14B.update(wan_shared_cfg)
i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt
i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_14B.t5_tokenizer = 'google/umt5-xxl'

View File

@ -0,0 +1,370 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class WanFLF2V:
def __init__(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
):
r"""
Initializes the first-last-frame-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_usp (`bool`, *optional*, defaults to False):
Enable distribution strategy of USP.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
self.rank = rank
self.use_usp = use_usp
self.t5_cpu = t5_cpu
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None,
)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
self.clip = CLIPModel(
dtype=config.clip_dtype,
device=self.device,
checkpoint_path=os.path.join(checkpoint_dir,
config.clip_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp:
init_on_cpu = False
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
self.model.forward = types.MethodType(usp_dit_forward, self.model)
self.sp_size = get_sequence_parallel_world_size()
else:
self.sp_size = 1
if dist.is_initialized():
dist.barrier()
if dit_fsdp:
self.model = shard_fn(self.model)
else:
if not init_on_cpu:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
def generate(self,
input_prompt,
first_frame,
last_frame,
max_area=720 * 1280,
frame_num=81,
shift=16,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.5,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from the input first and last frames and a text prompt, using a diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation.
first_frame (PIL.Image.Image):
First frame of the video; converted internally to a [3, H, W] tensor.
last_frame (PIL.Image.Image):
Last frame of the video; converted internally to a [3, H, W] tensor.
[NOTE] If the sizes of first_frame and last_frame are mismatched, last_frame will be cropped & resized
to match first_frame.
max_area (`int`, *optional*, defaults to 720*1280):
Maximum pixel area for latent space calculation. Controls video resolution scaling
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 16):
Noise schedule shift parameter. Affects temporal dynamics
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults to 5.5):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from max_area)
- W: Frame width (from max_area)
"""
first_frame_size = first_frame.size
last_frame_size = last_frame.size
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(self.device)
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(self.device)
F = frame_num
first_frame_h, first_frame_w = first_frame.shape[1:]
aspect_ratio = first_frame_h / first_frame_w
lat_h = round(
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1])
lat_w = round(
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2])
first_frame_h = lat_h * self.vae_stride[1]
first_frame_w = lat_w * self.vae_stride[2]
if first_frame_size != last_frame_size:
# 1. resize
last_frame_resize_ratio = max(
first_frame_size[0] / last_frame_size[0],
first_frame_size[1] / last_frame_size[1]
)
last_frame_size = [
round(last_frame_size[0] * last_frame_resize_ratio),
round(last_frame_size[1] * last_frame_resize_ratio),
]
# 2. center crop
last_frame = TF.center_crop(last_frame, last_frame_size)
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
self.patch_size[1] * self.patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
16,
(F - 1) // 4 + 1,
lat_h,
lat_w,
dtype=torch.float32,
generator=seed_g,
device=self.device)
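# Conditioning mask over the 81 video frames: 1 = frame is provided (first and last), 0 = frame to be generated.
# The first-frame entry is repeated 4x to line up with the VAE's 4-frame temporal stride, then the mask is
# reshaped to (4, latent_t, lat_h, lat_w) to match the latent layout.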
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
msk[:, 1: -1] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
self.clip.model.to(self.device)
clip_context = self.clip.visual([first_frame[:, None, :, :], last_frame[:, None, :, :]])
if offload_model:
self.clip.model.cpu()
y = self.vae.encode([
torch.concat([
torch.nn.functional.interpolate(
first_frame[None].cpu(),
size=(first_frame_h, first_frame_w),
mode='bicubic'
).transpose(0, 1),
torch.zeros(3, F - 2, first_frame_h, first_frame_w),
torch.nn.functional.interpolate(
last_frame[None].cpu(),
size=(first_frame_h, first_frame_w),
mode='bicubic'
).transpose(0, 1),
], dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latent = noise
arg_c = {
'context': [context[0]],
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
}
arg_null = {
'context': context_null,
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
}
if offload_model:
torch.cuda.empty_cache()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)]
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
latent = latent.to(
torch.device('cpu') if offload_model else self.device)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latent.unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
x0 = [latent.to(self.device)]
del latent_model_input, timestep
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latent
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None

View File

@ -197,7 +197,7 @@ class WanI2V:
seed_g.manual_seed(seed)
noise = torch.randn(
16,
21,
(F - 1) // 4 + 1,
lat_h,
lat_w,
dtype=torch.float32,
@ -239,7 +239,7 @@ class WanI2V:
torch.nn.functional.interpolate(
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
0, 1),
torch.zeros(3, 80, h, w)
torch.zeros(3, F - 1, h, w)
],
dim=1).to(self.device)
])[0]

View File

@ -11,6 +11,9 @@ from .attention import flash_attention
__all__ = ['WanModel']
T5_CONTEXT_TOKEN_NUMBER = 512
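# Each conditioning frame contributes 257 CLIP vision tokens (likely 256 patch tokens + 1 class token);
# FLF2V conditions on two frames, hence 257 * 2.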
FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2
def sinusoidal_embedding_1d(dim, position):
# preprocess
@ -203,8 +206,9 @@ class WanI2VCrossAttention(WanSelfAttention):
context(Tensor): Shape [B, L2, C]
context_lens(Tensor): Shape [B]
"""
context_img = context[:, :257]
context = context[:, 257:]
image_context_length = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER
context_img = context[:, :image_context_length]
context = context[:, image_context_length:]
b, n, d = x.size(0), self.num_heads, self.head_dim
# compute query, key, value
@ -269,7 +273,7 @@ class WanAttentionBlock(nn.Module):
nn.Linear(ffn_dim, dim))
# modulation
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)
def forward(
self,
@ -328,7 +332,7 @@ class Head(nn.Module):
self.head = nn.Linear(dim, out_dim)
# modulation
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5)
def forward(self, x, e):
r"""
@ -345,15 +349,21 @@ class Head(nn.Module):
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim):
def __init__(self, in_dim, out_dim, flf_pos_emb=False):
super().__init__()
self.proj = torch.nn.Sequential(
torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
torch.nn.LayerNorm(out_dim))
if flf_pos_emb: # NOTE: we only use this for `flf2v`
self.emb_pos = nn.Parameter(torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
def forward(self, image_embeds):
if hasattr(self, 'emb_pos'):
bs, n, d = image_embeds.shape
image_embeds = image_embeds.view(-1, 2 * n, d)
image_embeds = image_embeds + self.emb_pos
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
@ -390,7 +400,7 @@ class WanModel(ModelMixin, ConfigMixin):
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
@ -423,7 +433,7 @@ class WanModel(ModelMixin, ConfigMixin):
super().__init__()
assert model_type in ['t2v', 'i2v']
assert model_type in ['t2v', 'i2v', 'flf2v']
self.model_type = model_type
self.patch_size = patch_size
@ -473,8 +483,8 @@ class WanModel(ModelMixin, ConfigMixin):
],
dim=1)
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim)
if model_type == 'i2v' or model_type == 'flf2v':
self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
# initialize weights
self.init_weights()
@ -501,7 +511,7 @@ class WanModel(ModelMixin, ConfigMixin):
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
CLIP image features for image-to-video mode
CLIP image features for image-to-video mode or first-last-frame-to-video mode
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
@ -509,7 +519,7 @@ class WanModel(ModelMixin, ConfigMixin):
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
if self.model_type == 'i2v':
if self.model_type == 'i2v' or self.model_type == 'flf2v':
assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
@ -548,7 +558,7 @@ class WanModel(ModelMixin, ConfigMixin):
]))
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context_clip = self.img_emb(clip_fea) # bs x 257 (x2) x dim
context = torch.concat([context_clip, context], dim=1)
# arguments

View File

@ -7,7 +7,7 @@ import sys
import tempfile
from dataclasses import dataclass
from http import HTTPStatus
from typing import Optional, Union
from typing import Optional, Union, List
import dashscope
import torch
@ -97,6 +97,59 @@ VL_EN_SYS_PROMPT = \
'''Directly output the rewritten English text.'''
VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES = """你是一位Prompt优化师旨在参考用户输入的图像的细节内容把用户输入的Prompt改写为优质Prompt使其更完整、更具表现力同时不改变原意。你需要综合用户输入的照片内容和输入的Prompt进行改写严格参考示例的格式进行改写
任务要求
1. 用户会输入两张图片第一张是视频的第一帧第二张时视频的最后一帧你需要综合两个照片的内容进行优化改写
2. 对于过于简短的用户输入在不改变原意前提下合理推断并补充细节使得画面更加完整好看
3. 完善用户描述中出现的主体特征如外貌表情数量种族姿态等画面风格空间关系镜头景别
4. 整体中文输出保留引号书名号中原文以及重要的输入信息不要改写
5. Prompt应匹配符合用户意图且精准细分的风格描述如果用户未指定则根据用户提供的照片的风格你需要仔细分析照片的风格并参考风格进行改写
6. 如果Prompt是古诗词应该在生成的Prompt中强调中国古典元素避免出现西方现代外国场景
7. 你需要强调输入中的运动信息和不同的镜头运镜
8. 你的输出应当带有自然运动属性需要根据描述主体目标类别增加这个目标的自然动作描述尽可能用简单直接的动词
9. 你需要尽可能的参考图片的细节信息如人物动作服装背景等强调照片的细节元素
10. 你需要强调两画面可能出现的潜在变化走进出现变身成镜头左移镜头右移动镜头上移动 镜头下移等等
11. 无论用户输入那种语言你都需要输出中文
12. 改写后的prompt字数控制在80-100字左右
改写后 prompt 示例
1. 日系小清新胶片写真扎着双麻花辫的年轻东亚女孩坐在船边女孩穿着白色方领泡泡袖连衣裙裙子上有褶皱和纽扣装饰她皮肤白皙五官清秀眼神略带忧郁直视镜头女孩的头发自然垂落刘海遮住部分额头她双手扶船姿态自然放松背景是模糊的户外场景隐约可见蓝天山峦和一些干枯植物复古胶片质感照片中景半身坐姿人像
2. 二次元厚涂动漫插画一个猫耳兽耳白人少女手持文件夹神情略带不满她深紫色长发红色眼睛身穿深灰色短裙和浅灰色上衣腰间系着白色系带胸前佩戴名牌上面写着黑体中文"紫阳"淡黄色调室内背景隐约可见一些家具轮廓少女头顶有一个粉色光圈线条流畅的日系赛璐璐风格近景半身略俯视视角
3. CG游戏概念数字艺术一只巨大的鳄鱼张开大嘴背上长着树木和荆棘鳄鱼皮肤粗糙呈灰白色像是石头或木头的质感它背上生长着茂盛的树木灌木和一些荆棘状的突起鳄鱼嘴巴大张露出粉红色的舌头和锋利的牙齿画面背景是黄昏的天空远处有一些树木场景整体暗黑阴冷近景仰视视角
4. 美剧宣传海报风格身穿黄色防护服的Walter White坐在金属折叠椅上上方无衬线英文写着"Breaking Bad"周围是成堆的美元和蓝色塑料储物箱他戴着眼镜目光直视前方身穿黄色连体防护服双手放在膝盖上神态稳重自信背景是一个废弃的阴暗厂房窗户透着光线带有明显颗粒质感纹理中景镜头下移
请直接输出改写后的文本不要进行多余的回复"""
VL_EN_SYS_PROMPT_FOR_MULTI_IMAGES = \
'''You are a prompt optimization specialist whose goal is to rewrite the user's input prompts into high-quality English prompts by referring to the details of the user's input images, making them more complete and expressive while maintaining the original meaning. You need to integrate the content of the user's photo with the input prompt for the rewrite, strictly adhering to the formatting of the examples provided.\n''' \
'''Task Requirements:\n''' \
'''1. The user will input two images, the first is the first frame of the video, and the second is the last frame of the video. You need to integrate the content of the two photos with the input prompt for the rewrite.\n''' \
'''2. For overly brief user inputs, reasonably infer and supplement details without changing the original meaning, making the image more complete and visually appealing;\n''' \
'''3. Improve the characteristics of the main subject in the user's description (such as appearance, expression, quantity, ethnicity, posture, etc.), rendering style, spatial relationships, and camera angles;\n''' \
'''4. The overall output should be in Chinese, retaining original text in quotes and book titles as well as important input information without rewriting them;\n''' \
'''5. The prompt should match the user's intent and provide a precise and detailed style description. If the user has not specified a style, you need to carefully analyze the style of the user's provided photo and use that as a reference for rewriting;\n''' \
'''6. If the prompt is an ancient poem, classical Chinese elements should be emphasized in the generated prompt, avoiding references to Western, modern, or foreign scenes;\n''' \
'''7. You need to emphasize movement information in the input and different camera angles;\n''' \
'''8. Your output should convey natural movement attributes, incorporating natural actions related to the described subject category, using simple and direct verbs as much as possible;\n''' \
'''9. You should reference the detailed information in the image, such as character actions, clothing, backgrounds, and emphasize the details in the photo;\n''' \
'''10. You need to emphasize potential changes that may occur between the two frames, such as "walking into", "appearing", "turning into", "camera left", "camera right", "camera up", "camera down", etc.;\n''' \
'''11. Control the rewritten prompt to around 80-100 words.\n''' \
'''12. No matter what language the user inputs, you must always output in English.\n''' \
'''Example of the rewritten English prompt:\n''' \
'''1. A Japanese fresh film-style photo of a young East Asian girl with double braids sitting by the boat. The girl wears a white square collar puff sleeve dress, decorated with pleats and buttons. She has fair skin, delicate features, and slightly melancholic eyes, staring directly at the camera. Her hair falls naturally, with bangs covering part of her forehead. She rests her hands on the boat, appearing natural and relaxed. The background features a blurred outdoor scene, with hints of blue sky, mountains, and some dry plants. The photo has a vintage film texture. A medium shot of a seated portrait.\n''' \
'''2. An anime illustration in vibrant thick painting style of a white girl with cat ears holding a folder, showing a slightly dissatisfied expression. She has long dark purple hair and red eyes, wearing a dark gray skirt and a light gray top with a white waist tie and a name tag in bold Chinese characters that says "紫阳" (Ziyang). The background has a light yellow indoor tone, with faint outlines of some furniture visible. A pink halo hovers above her head, in a smooth Japanese cel-shading style. A close-up shot from a slightly elevated perspective.\n''' \
'''3. CG game concept digital art featuring a huge crocodile with its mouth wide open, with trees and thorns growing on its back. The crocodile's skin is rough and grayish-white, resembling stone or wood texture. Its back is lush with trees, shrubs, and thorny protrusions. With its mouth agape, the crocodile reveals a pink tongue and sharp teeth. The background features a dusk sky with some distant trees, giving the overall scene a dark and cold atmosphere. A close-up from a low angle.\n''' \
'''4. In the style of an American drama promotional poster, Walter White sits in a metal folding chair wearing a yellow protective suit, with the words "Breaking Bad" written in sans-serif English above him, surrounded by piles of dollar bills and blue plastic storage boxes. He wears glasses, staring forward, dressed in a yellow jumpsuit, with his hands resting on his knees, exuding a calm and confident demeanor. The background shows an abandoned, dim factory with light filtering through the windows. There's a noticeable grainy texture. A medium shot with a straight-on close-up of the character.\n''' \
'''Directly output the rewritten English text.'''
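# The lookup key is a 3-bit flag assembled in PromptExpander.decide_system_prompt below:
# bit0 = Chinese target language, bit1 = vision-language (image) input, bit2 = multi-image (first & last frame) input.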
SYSTEM_PROMPT_TYPES = {
int(b'000', 2): LM_EN_SYS_PROMPT,
int(b'001', 2): LM_ZH_SYS_PROMPT,
int(b'010', 2): VL_EN_SYS_PROMPT,
int(b'011', 2): VL_ZH_SYS_PROMPT,
int(b'110', 2): VL_EN_SYS_PROMPT_FOR_MULTI_IMAGES,
int(b'111', 2): VL_ZH_SYS_PROMPT_FOR_MULTI_IMAGES
}
@dataclass
class PromptOutput(object):
status: bool
@ -128,12 +181,12 @@ class PromptExpander:
def extend(self, prompt, system_prompt, seed=-1, *args, **kwargs):
pass
def decide_system_prompt(self, tar_lang="zh"):
def decide_system_prompt(self, tar_lang="zh", multi_images_input=False):
zh = tar_lang == "zh"
if zh:
return LM_ZH_SYS_PROMPT if not self.is_vl else VL_ZH_SYS_PROMPT
else:
return LM_EN_SYS_PROMPT if not self.is_vl else VL_EN_SYS_PROMPT
self.is_vl |= multi_images_input
task_type = zh + (self.is_vl << 1) + (multi_images_input << 2)
print(task_type)
return SYSTEM_PROMPT_TYPES[task_type]
def __call__(self,
prompt,
@ -144,7 +197,10 @@ class PromptExpander:
*args,
**kwargs):
if system_prompt is None:
system_prompt = self.decide_system_prompt(tar_lang=tar_lang)
system_prompt = self.decide_system_prompt(
tar_lang=tar_lang,
multi_images_input=isinstance(image, (list, tuple)) and len(image) > 1
)
if seed < 0:
seed = random.randint(0, sys.maxsize)
if image is not None and self.is_vl:
@ -234,38 +290,42 @@ class DashScopePromptExpander(PromptExpander):
def extend_with_img(self,
prompt,
system_prompt,
image: Union[Image.Image, str] = None,
image: Union[List[Image.Image], List[str], Image.Image, str] = None,
seed=-1,
*args,
**kwargs):
if isinstance(image, str):
image = Image.open(image).convert('RGB')
w = image.width
h = image.height
area = min(w * h, self.max_image_size)
aspect_ratio = h / w
resized_h = round(math.sqrt(area * aspect_ratio))
resized_w = round(math.sqrt(area / aspect_ratio))
image = image.resize((resized_w, resized_h))
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
image.save(f.name)
fname = f.name
image_path = f"file://{f.name}"
def ensure_image(_image):
if isinstance(_image, str):
_image = Image.open(_image).convert('RGB')
w = _image.width
h = _image.height
area = min(w * h, self.max_image_size)
aspect_ratio = h / w
resized_h = round(math.sqrt(area * aspect_ratio))
resized_w = round(math.sqrt(area / aspect_ratio))
_image = _image.resize((resized_w, resized_h))
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
_image.save(f.name)
image_path = f"file://{f.name}"
return image_path
if not isinstance(image, (list, tuple)):
image = [image]
image_path_list = [ensure_image(_image) for _image in image]
role_content = [
{"text": prompt},
*[{"image": image_path} for image_path in image_path_list]
]
system_content = [{"text": system_prompt}]
prompt = f"{prompt}"
messages = [
{
'role': 'system',
'content': [{
"text": system_prompt
}]
'content': system_content
},
{
'role': 'user',
'content': [{
"text": prompt
}, {
"image": image_path
}]
'content': role_content
},
]
response = None
@ -288,7 +348,8 @@ class DashScopePromptExpander(PromptExpander):
except Exception as e:
exception = e
result_prompt = result_prompt.replace('\n', '\\n')
os.remove(fname)
for image_path in image_path_list:
os.remove(image_path.removeprefix('file://'))
return PromptOutput(
status=status,
@ -399,30 +460,36 @@ class QwenPromptExpander(PromptExpander):
def extend_with_img(self,
prompt,
system_prompt,
image: Union[Image.Image, str] = None,
image: Union[List[Image.Image], List[str], Image.Image, str] = None,
seed=-1,
*args,
**kwargs):
self.model = self.model.to(self.device)
if not isinstance(image, (list, tuple)):
image = [image]
system_content = [{
"type": "text",
"text": system_prompt
}]
role_content = [
{
"type": "text",
"text": prompt
},
*[
{"image": image_path} for image_path in image
]
]
messages = [{
'role': 'system',
'content': [{
"type": "text",
"text": system_prompt
}]
'content': system_content,
}, {
"role":
"user",
"content": [
{
"type": "image",
"image": image,
},
{
"type": "text",
"text": prompt
},
],
"content": role_content,
}]
# Preparation for inference
@ -459,7 +526,7 @@ class QwenPromptExpander(PromptExpander):
if __name__ == "__main__":
import os; os.environ['DASH_API_KEY'] = "sk-6c78f7206ff846a1b4851987cb507ec3"
seed = 100
prompt = "夏日海滩度假风格,一只戴着墨镜的白色猫咪坐在冲浪板上。猫咪毛发蓬松,表情悠闲,直视镜头。背景是模糊的海滩景色,海水清澈,远处有绿色的山丘和蓝天白云。猫咪的姿态自然放松,仿佛在享受海风和阳光。近景特写,强调猫咪的细节和海滩的清新氛围。"
en_prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
@ -502,7 +569,8 @@ if __name__ == "__main__":
# test case for prompt-image extend
ds_model_name = "qwen-vl-max"
#qwen_model_name = "./models/Qwen2.5-VL-3B-Instruct/" #VRAM: 9686MiB
qwen_model_name = "./models/Qwen2.5-VL-7B-Instruct-AWQ/" # VRAM: 8492
# qwen_model_name = "./models/Qwen2.5-VL-7B-Instruct-AWQ/" # VRAM: 8492
qwen_model_name = "./models/Qwen2.5-VL-7B-Instruct/"
image = "./examples/i2v_input.JPG"
# test dashscope api why image_path is local directory; skip
@ -543,3 +611,26 @@ if __name__ == "__main__":
en_prompt, tar_lang="en", image=image, seed=seed)
print("VL qwen vl en result -> en",
qwen_result.prompt) # , qwen_result.system_prompt)
# test multi images
image = ["./examples/flf2v_input_first_frame.png", "./examples/flf2v_input_last_frame.png"]
prompt = "无人机拍摄,镜头快速推进,然后拉远至全景俯瞰,展示一个宁静美丽的海港。海港内停满了游艇,水面清澈透蓝。周围是起伏的山丘和错落有致的建筑,整体景色宁静而美丽。"
en_prompt = ("Shot from a drone perspective, the camera rapidly zooms in before pulling back to reveal a panoramic "
"aerial view of a serene and picturesque harbor. The tranquil bay is dotted with numerous yachts "
"resting on crystal-clear blue waters. Surrounding the harbor are rolling hills and well-spaced "
"architectural structures, combining to create a tranquil and breathtaking coastal landscape.")
dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
dashscope_result = dashscope_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
print("VL dashscope result -> zh", dashscope_result.prompt)
dashscope_prompt_expander = DashScopePromptExpander(model_name=ds_model_name, is_vl=True)
dashscope_result = dashscope_prompt_expander(en_prompt, tar_lang="zh", image=image, seed=seed)
print("VL dashscope en result -> zh", dashscope_result.prompt)
qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
print("VL qwen result -> zh", qwen_result.prompt)
qwen_prompt_expander = QwenPromptExpander(model_name=qwen_model_name, is_vl=True, device=0)
qwen_result = qwen_prompt_expander(prompt, tar_lang="zh", image=image, seed=seed)
print("VL qwen en result -> zh", qwen_result.prompt)