mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-06-02 05:17:25 +00:00
[feature] Add VACE (#389)
* Add VACE * Support training with multiple gpus * Update default args for vace task * vace block update * Add vace exmaple jpg * Fix dist vace fwd hook error * Update vace exmample * Update vace args * Update pipeline name for vace * vace gradio and Readme * Update vace snake png --------- Co-authored-by: hanzhn <han.feng.jason@gmail.com>
This commit is contained in:
parent
204f899b64
commit
18d53feb7a
92
README.md
92
README.md
@ -27,6 +27,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
|
||||
|
||||
## 🔥 Latest News!!
|
||||
|
||||
* May 14, 2025: 👋 We introduce **Wan2.1** [VACE](https://github.com/ali-vilab/VACE), an all-in-one model for video creation and editing, along with its [inference code](#run-vace), [weights](#model-download), and [technical report](https://arxiv.org/abs/2503.07598)!
|
||||
* Apr 17, 2025: 👋 We introduce **Wan2.1** [FLF2V](#run-first-last-frame-to-video-generation) with its inference code and weights!
|
||||
* Mar 21, 2025: 👋 We are excited to announce the release of the **Wan2.1** [technical report](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf). We welcome discussions and feedback!
|
||||
* Mar 3, 2025: 👋 **Wan2.1**'s T2V and I2V have been integrated into Diffusers ([T2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanPipeline) | [I2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanImageToVideoPipeline)). Feel free to give it a try!
|
||||
@ -64,7 +65,13 @@ If your work has improved **Wan2.1** and you would like more people to see it, p
|
||||
- [ ] ComfyUI integration
|
||||
- [ ] Diffusers integration
|
||||
- [ ] Diffusers + Multi-GPU Inference
|
||||
|
||||
- Wan2.1 VACE
|
||||
- [x] Multi-GPU Inference code of the 14B and 1.3B models
|
||||
- [x] Checkpoints of the 14B and 1.3B models
|
||||
- [x] Gradio demo
|
||||
- [x] ComfyUI integration
|
||||
- [ ] Diffusers integration
|
||||
- [ ] Diffusers + Multi-GPU Inference
|
||||
|
||||
## Quickstart
|
||||
|
||||
@ -84,13 +91,15 @@ pip install -r requirements.txt
|
||||
|
||||
#### Model Download
|
||||
|
||||
| Models | Download Link | Notes |
|
||||
|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
|
||||
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
|
||||
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
|
||||
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
|
||||
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
|
||||
| FLF2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | Supports 720P
|
||||
| Models | Download Link | Notes |
|
||||
|--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
|
||||
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
|
||||
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
|
||||
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
|
||||
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
|
||||
| FLF2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | Supports 720P
|
||||
| VACE-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | Supports 480P
|
||||
| VACE-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | Supports both 480P and 720P
|
||||
|
||||
> 💡Note:
|
||||
> * The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution.
|
||||
@ -448,6 +457,73 @@ DASH_API_KEY=your_key python flf2v_14B_singleGPU.py --prompt_extend_method 'dash
|
||||
```
|
||||
|
||||
|
||||
#### Run VACE
|
||||
|
||||
[VACE](https://github.com/ali-vilab/VACE) now supports two models (1.3B and 14B) and two main resolutions (480P and 720P).
|
||||
The input supports any resolution, but to achieve optimal results, the video size should fall within a specific range.
|
||||
The parameters and configurations for these models are as follows:
|
||||
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th rowspan="2">Task</th>
|
||||
<th colspan="2">Resolution</th>
|
||||
<th rowspan="2">Model</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<th>480P(~81x480x832)</th>
|
||||
<th>720P(~81x720x1280)</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td>VACE</td>
|
||||
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
|
||||
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
|
||||
<td>Wan2.1-VACE-14B</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>VACE</td>
|
||||
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
|
||||
<td style="color: red; text-align: center; vertical-align: middle;">❌</td>
|
||||
<td>Wan2.1-VACE-1.3B</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
In VACE, users can input text prompt and optional video, mask, and image for video generation or editing. Detailed instructions for using VACE can be found in the [User Guide](https://github.com/ali-vilab/VACE/blob/main/UserGuide.md).
|
||||
The execution process is as follows:
|
||||
|
||||
##### (1) Preprocessing
|
||||
|
||||
User-collected materials needs to be preprocessed into VACE-recognizable inputs, including `src_video`, `src_mask`, `src_ref_images`, and `prompt`.
|
||||
For R2V (Reference-to-Video Generation), you may skip this preprocessing, but for V2V (Video-to-Video Editing) and MV2V (Masked Video-to-Video Editing) tasks, additional preprocessing is required to obtain video with conditions such as depth, pose or masked regions.
|
||||
For more details, please refer to [vace_preproccess](https://github.com/ali-vilab/VACE/blob/main/vace/vace_preproccess.py).
|
||||
|
||||
##### (2) cli inference
|
||||
|
||||
- Single-GPU inference
|
||||
```sh
|
||||
python generate.py --task vace-1.3B --size 832*480 --ckpt_dir ./Wan2.1-VACE-1.3B --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
|
||||
```
|
||||
|
||||
- Multi-GPU inference using FSDP + xDiT USP
|
||||
|
||||
```sh
|
||||
torchrun --nproc_per_node=8 generate.py --task vace-14B --size 1280*720 --ckpt_dir ./Wan2.1-VACE-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
|
||||
```
|
||||
|
||||
##### (3) Running local gradio
|
||||
- Single-GPU inference
|
||||
```sh
|
||||
python gradio/vace.py --ckpt_dir ./Wan2.1-VACE-1.3B
|
||||
```
|
||||
|
||||
- Multi-GPU inference using FSDP + xDiT USP
|
||||
```sh
|
||||
python gradio/vace.py --mp --ulysses_size 8 --ckpt_dir ./Wan2.1-VACE-14B/
|
||||
```
|
||||
|
||||
#### Run Text-to-Image Generation
|
||||
|
||||
Wan2.1 is a unified model for both image and video generation. Since it was trained on both types of data, it can also generate images. The command for generating images is similar to video generation, as follows:
|
||||
|
BIN
examples/girl.png
Normal file
BIN
examples/girl.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 817 KiB |
BIN
examples/snake.png
Normal file
BIN
examples/snake.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 435 KiB |
87
generate.py
87
generate.py
@ -41,6 +41,14 @@ EXAMPLE_PROMPT = {
|
||||
"last_frame":
|
||||
"examples/flf2v_input_last_frame.png",
|
||||
},
|
||||
"vace-1.3B": {
|
||||
"src_ref_images": 'examples/girl.png,examples/snake.png',
|
||||
"prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
|
||||
},
|
||||
"vace-14B": {
|
||||
"src_ref_images": 'examples/girl.png,examples/snake.png',
|
||||
"prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -52,15 +60,19 @@ def _validate_args(args):
|
||||
|
||||
# The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
|
||||
if args.sample_steps is None:
|
||||
args.sample_steps = 40 if "i2v" in args.task else 50
|
||||
args.sample_steps = 50
|
||||
if "i2v" in args.task:
|
||||
args.sample_steps = 40
|
||||
|
||||
|
||||
if args.sample_shift is None:
|
||||
args.sample_shift = 5.0
|
||||
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
|
||||
args.sample_shift = 3.0
|
||||
if "flf2v" in args.task:
|
||||
elif "flf2v" in args.task or "vace" in args.task:
|
||||
args.sample_shift = 16
|
||||
|
||||
|
||||
# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
|
||||
if args.frame_num is None:
|
||||
args.frame_num = 1 if "t2i" in args.task else 81
|
||||
@ -141,6 +153,21 @@ def _parse_args():
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file to save the generated image or video to.")
|
||||
parser.add_argument(
|
||||
"--src_video",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file of the source video. Default None.")
|
||||
parser.add_argument(
|
||||
"--src_mask",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file of the source mask. Default None.")
|
||||
parser.add_argument(
|
||||
"--src_ref_images",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file list of the source reference images. Separated by ','. Default None.")
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
@ -397,7 +424,7 @@ def generate(args):
|
||||
guide_scale=args.sample_guide_scale,
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model)
|
||||
else:
|
||||
elif "flf2v" in args.task:
|
||||
if args.prompt is None:
|
||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||
if args.first_frame is None or args.last_frame is None:
|
||||
@ -457,6 +484,60 @@ def generate(args):
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model
|
||||
)
|
||||
elif "vace" in args.task:
|
||||
if args.prompt is None:
|
||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||
args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
|
||||
args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
|
||||
args.src_ref_images = EXAMPLE_PROMPT[args.task].get("src_ref_images", None)
|
||||
|
||||
logging.info(f"Input prompt: {args.prompt}")
|
||||
if args.use_prompt_extend and args.use_prompt_extend != 'plain':
|
||||
logging.info("Extending prompt ...")
|
||||
if rank == 0:
|
||||
prompt = prompt_expander.forward(args.prompt)
|
||||
logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'")
|
||||
input_prompt = [prompt]
|
||||
else:
|
||||
input_prompt = [None]
|
||||
if dist.is_initialized():
|
||||
dist.broadcast_object_list(input_prompt, src=0)
|
||||
args.prompt = input_prompt[0]
|
||||
logging.info(f"Extended prompt: {args.prompt}")
|
||||
|
||||
logging.info("Creating VACE pipeline.")
|
||||
wan_vace = wan.WanVace(
|
||||
config=cfg,
|
||||
checkpoint_dir=args.ckpt_dir,
|
||||
device_id=device,
|
||||
rank=rank,
|
||||
t5_fsdp=args.t5_fsdp,
|
||||
dit_fsdp=args.dit_fsdp,
|
||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||
t5_cpu=args.t5_cpu,
|
||||
)
|
||||
|
||||
src_video, src_mask, src_ref_images = wan_vace.prepare_source([args.src_video],
|
||||
[args.src_mask],
|
||||
[None if args.src_ref_images is None else args.src_ref_images.split(',')],
|
||||
args.frame_num, SIZE_CONFIGS[args.size], device)
|
||||
|
||||
logging.info(f"Generating video...")
|
||||
video = wan_vace.generate(
|
||||
args.prompt,
|
||||
src_video,
|
||||
src_mask,
|
||||
src_ref_images,
|
||||
size=SIZE_CONFIGS[args.size],
|
||||
frame_num=args.frame_num,
|
||||
shift=args.sample_shift,
|
||||
sample_solver=args.sample_solver,
|
||||
sampling_steps=args.sample_steps,
|
||||
guide_scale=args.sample_guide_scale,
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model)
|
||||
else:
|
||||
raise ValueError(f"Unkown task type: {args.task}")
|
||||
|
||||
if rank == 0:
|
||||
if args.save_file is None:
|
||||
|
295
gradio/vace.py
Normal file
295
gradio/vace.py
Normal file
@ -0,0 +1,295 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import datetime
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
import gradio as gr
|
||||
|
||||
sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
|
||||
import wan
|
||||
from wan import WanVace, WanVaceMP
|
||||
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
|
||||
|
||||
|
||||
class FixedSizeQueue:
|
||||
def __init__(self, max_size):
|
||||
self.max_size = max_size
|
||||
self.queue = []
|
||||
def add(self, item):
|
||||
self.queue.insert(0, item)
|
||||
if len(self.queue) > self.max_size:
|
||||
self.queue.pop()
|
||||
def get(self):
|
||||
return self.queue
|
||||
def __repr__(self):
|
||||
return str(self.queue)
|
||||
|
||||
|
||||
class VACEInference:
|
||||
def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
|
||||
self.cfg = cfg
|
||||
self.save_dir = cfg.save_dir
|
||||
self.gallery_share = gallery_share
|
||||
self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
|
||||
if not skip_load:
|
||||
if not args.mp:
|
||||
self.pipe = WanVace(
|
||||
config=WAN_CONFIGS[cfg.model_name],
|
||||
checkpoint_dir=cfg.ckpt_dir,
|
||||
device_id=0,
|
||||
rank=0,
|
||||
t5_fsdp=False,
|
||||
dit_fsdp=False,
|
||||
use_usp=False,
|
||||
)
|
||||
else:
|
||||
self.pipe = WanVaceMP(
|
||||
config=WAN_CONFIGS[cfg.model_name],
|
||||
checkpoint_dir=cfg.ckpt_dir,
|
||||
use_usp=True,
|
||||
ulysses_size=cfg.ulysses_size,
|
||||
ring_size=cfg.ring_size
|
||||
)
|
||||
|
||||
|
||||
def create_ui(self, *args, **kwargs):
|
||||
gr.Markdown("""
|
||||
<div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
|
||||
<a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
|
||||
</div>
|
||||
""")
|
||||
with gr.Row(variant='panel', equal_height=True):
|
||||
with gr.Column(scale=1, min_width=0):
|
||||
self.src_video = gr.Video(
|
||||
label="src_video",
|
||||
sources=['upload'],
|
||||
value=None,
|
||||
interactive=True)
|
||||
with gr.Column(scale=1, min_width=0):
|
||||
self.src_mask = gr.Video(
|
||||
label="src_mask",
|
||||
sources=['upload'],
|
||||
value=None,
|
||||
interactive=True)
|
||||
#
|
||||
with gr.Row(variant='panel', equal_height=True):
|
||||
with gr.Column(scale=1, min_width=0):
|
||||
with gr.Row(equal_height=True):
|
||||
self.src_ref_image_1 = gr.Image(label='src_ref_image_1',
|
||||
height=200,
|
||||
interactive=True,
|
||||
type='filepath',
|
||||
image_mode='RGB',
|
||||
sources=['upload'],
|
||||
elem_id="src_ref_image_1",
|
||||
format='png')
|
||||
self.src_ref_image_2 = gr.Image(label='src_ref_image_2',
|
||||
height=200,
|
||||
interactive=True,
|
||||
type='filepath',
|
||||
image_mode='RGB',
|
||||
sources=['upload'],
|
||||
elem_id="src_ref_image_2",
|
||||
format='png')
|
||||
self.src_ref_image_3 = gr.Image(label='src_ref_image_3',
|
||||
height=200,
|
||||
interactive=True,
|
||||
type='filepath',
|
||||
image_mode='RGB',
|
||||
sources=['upload'],
|
||||
elem_id="src_ref_image_3",
|
||||
format='png')
|
||||
with gr.Row(variant='panel', equal_height=True):
|
||||
with gr.Column(scale=1):
|
||||
self.prompt = gr.Textbox(
|
||||
show_label=False,
|
||||
placeholder="positive_prompt_input",
|
||||
elem_id='positive_prompt',
|
||||
container=True,
|
||||
autofocus=True,
|
||||
elem_classes='type_row',
|
||||
visible=True,
|
||||
lines=2)
|
||||
self.negative_prompt = gr.Textbox(
|
||||
show_label=False,
|
||||
value=self.pipe.config.sample_neg_prompt,
|
||||
placeholder="negative_prompt_input",
|
||||
elem_id='negative_prompt',
|
||||
container=True,
|
||||
autofocus=False,
|
||||
elem_classes='type_row',
|
||||
visible=True,
|
||||
interactive=True,
|
||||
lines=1)
|
||||
#
|
||||
with gr.Row(variant='panel', equal_height=True):
|
||||
with gr.Column(scale=1, min_width=0):
|
||||
with gr.Row(equal_height=True):
|
||||
self.shift_scale = gr.Slider(
|
||||
label='shift_scale',
|
||||
minimum=0.0,
|
||||
maximum=100.0,
|
||||
step=1.0,
|
||||
value=16.0,
|
||||
interactive=True)
|
||||
self.sample_steps = gr.Slider(
|
||||
label='sample_steps',
|
||||
minimum=1,
|
||||
maximum=100,
|
||||
step=1,
|
||||
value=25,
|
||||
interactive=True)
|
||||
self.context_scale = gr.Slider(
|
||||
label='context_scale',
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
step=0.1,
|
||||
value=1.0,
|
||||
interactive=True)
|
||||
self.guide_scale = gr.Slider(
|
||||
label='guide_scale',
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
step=0.5,
|
||||
value=5.0,
|
||||
interactive=True)
|
||||
self.infer_seed = gr.Slider(minimum=-1,
|
||||
maximum=10000000,
|
||||
value=2025,
|
||||
label="Seed")
|
||||
#
|
||||
with gr.Accordion(label="Usable without source video", open=False):
|
||||
with gr.Row(equal_height=True):
|
||||
self.output_height = gr.Textbox(
|
||||
label='resolutions_height',
|
||||
# value=480,
|
||||
value=720,
|
||||
interactive=True)
|
||||
self.output_width = gr.Textbox(
|
||||
label='resolutions_width',
|
||||
# value=832,
|
||||
value=1280,
|
||||
interactive=True)
|
||||
self.frame_rate = gr.Textbox(
|
||||
label='frame_rate',
|
||||
value=16,
|
||||
interactive=True)
|
||||
self.num_frames = gr.Textbox(
|
||||
label='num_frames',
|
||||
value=81,
|
||||
interactive=True)
|
||||
#
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column(scale=5):
|
||||
self.generate_button = gr.Button(
|
||||
value='Run',
|
||||
elem_classes='type_row',
|
||||
elem_id='generate_button',
|
||||
visible=True)
|
||||
with gr.Column(scale=1):
|
||||
self.refresh_button = gr.Button(value='\U0001f504') # 🔄
|
||||
#
|
||||
self.output_gallery = gr.Gallery(
|
||||
label="output_gallery",
|
||||
value=[],
|
||||
interactive=False,
|
||||
allow_preview=True,
|
||||
preview=True)
|
||||
|
||||
|
||||
def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames):
|
||||
output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames)
|
||||
src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if
|
||||
x is not None]
|
||||
src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
|
||||
[src_mask],
|
||||
[src_ref_images],
|
||||
num_frames=num_frames,
|
||||
image_size=SIZE_CONFIGS[f"{output_height}*{output_width}"],
|
||||
device=self.pipe.device)
|
||||
video = self.pipe.generate(
|
||||
prompt,
|
||||
src_video,
|
||||
src_mask,
|
||||
src_ref_images,
|
||||
size=(output_width, output_height),
|
||||
context_scale=context_scale,
|
||||
shift=shift_scale,
|
||||
sampling_steps=sample_steps,
|
||||
guide_scale=guide_scale,
|
||||
n_prompt=negative_prompt,
|
||||
seed=infer_seed,
|
||||
offload_model=True)
|
||||
|
||||
name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now())
|
||||
video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
|
||||
video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
|
||||
|
||||
try:
|
||||
writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1)
|
||||
for frame in video_frames:
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
print(video_path)
|
||||
except Exception as e:
|
||||
raise gr.Error(f"Video save error: {e}")
|
||||
|
||||
if self.gallery_share:
|
||||
self.gallery_share_data.add(video_path)
|
||||
return self.gallery_share_data.get()
|
||||
else:
|
||||
return [video_path]
|
||||
|
||||
def set_callbacks(self, **kwargs):
|
||||
self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
|
||||
self.gen_outputs = [self.output_gallery]
|
||||
self.generate_button.click(self.generate,
|
||||
inputs=self.gen_inputs,
|
||||
outputs=self.gen_outputs,
|
||||
queue=True)
|
||||
self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
|
||||
parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
|
||||
parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
|
||||
parser.add_argument('--root_path', dest='root_path', help='', default=None)
|
||||
parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
|
||||
parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",)
|
||||
parser.add_argument("--model_name", type=str, default="vace-14B", choices=list(WAN_CONFIGS.keys()), help="The model name to run.")
|
||||
parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.")
|
||||
parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.")
|
||||
parser.add_argument(
|
||||
"--ckpt_dir",
|
||||
type=str,
|
||||
# default='models/VACE-Wan2.1-1.3B-Preview',
|
||||
default='models/Wan2.1-VACE-14B/',
|
||||
help="The path to the checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_to_cpu",
|
||||
action="store_true",
|
||||
help="Offloading unnecessary computations to CPU.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.save_dir):
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
|
||||
infer_gr.create_ui()
|
||||
infer_gr.set_callbacks()
|
||||
allowed_paths = [args.save_dir]
|
||||
demo.queue(status_update_rate=1).launch(server_name=args.server_name,
|
||||
server_port=args.server_port,
|
||||
root_path=args.root_path,
|
||||
allowed_paths=allowed_paths,
|
||||
show_error=True, debug=True)
|
@ -105,9 +105,16 @@ function i2v_14B_720p() {
|
||||
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
|
||||
}
|
||||
|
||||
function vace_1_3B() {
|
||||
VACE_1_3B_CKPT_DIR="$MODEL_DIR/VACE-Wan2.1-1.3B-Preview/"
|
||||
torchrun --nproc_per_node=$GPUS $PY_FILE --ulysses_size $GPUS --task vace-1.3B --size 480*832 --ckpt_dir $VACE_1_3B_CKPT_DIR
|
||||
|
||||
}
|
||||
|
||||
|
||||
t2i_14B
|
||||
t2v_1_3B
|
||||
t2v_14B
|
||||
i2v_14B_480p
|
||||
i2v_14B_720p
|
||||
vace_1_3B
|
||||
|
@ -2,3 +2,4 @@ from . import configs, distributed, modules
|
||||
from .image2video import WanI2V
|
||||
from .text2video import WanT2V
|
||||
from .first_last_frame2video import WanFLF2V
|
||||
from .vace import WanVace, WanVaceMP
|
||||
|
@ -22,7 +22,9 @@ WAN_CONFIGS = {
|
||||
't2v-1.3B': t2v_1_3B,
|
||||
'i2v-14B': i2v_14B,
|
||||
't2i-14B': t2i_14B,
|
||||
'flf2v-14B': flf2v_14B
|
||||
'flf2v-14B': flf2v_14B,
|
||||
'vace-1.3B': t2v_1_3B,
|
||||
'vace-14B': t2v_14B,
|
||||
}
|
||||
|
||||
SIZE_CONFIGS = {
|
||||
@ -46,4 +48,6 @@ SUPPORTED_SIZES = {
|
||||
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
||||
'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
||||
't2i-14B': tuple(SIZE_CONFIGS.keys()),
|
||||
'vace-1.3B': ('480*832', '832*480'),
|
||||
'vace-14B': ('720*1280', '1280*720', '480*832', '832*480')
|
||||
}
|
||||
|
@ -63,12 +63,45 @@ def rope_apply(x, grid_sizes, freqs):
|
||||
return torch.stack(output).float()
|
||||
|
||||
|
||||
def usp_dit_forward_vace(
|
||||
self,
|
||||
x,
|
||||
vace_context,
|
||||
seq_len,
|
||||
kwargs
|
||||
):
|
||||
# embeddings
|
||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
c = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
# arguments
|
||||
new_kwargs = dict(x=x)
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
# Context Parallel
|
||||
c = torch.chunk(
|
||||
c, get_sequence_parallel_world_size(),
|
||||
dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
hints = []
|
||||
for block in self.vace_blocks:
|
||||
c, c_skip = block(c, **new_kwargs)
|
||||
hints.append(c_skip)
|
||||
return hints
|
||||
|
||||
|
||||
def usp_dit_forward(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
seq_len,
|
||||
vace_context=None,
|
||||
vace_context_scale=1.0,
|
||||
clip_fea=None,
|
||||
y=None,
|
||||
):
|
||||
@ -84,7 +117,7 @@ def usp_dit_forward(
|
||||
if self.freqs.device != device:
|
||||
self.freqs = self.freqs.to(device)
|
||||
|
||||
if y is not None:
|
||||
if self.model_type != 'vace' and y is not None:
|
||||
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
||||
|
||||
# embeddings
|
||||
@ -114,7 +147,7 @@ def usp_dit_forward(
|
||||
for u in context
|
||||
]))
|
||||
|
||||
if clip_fea is not None:
|
||||
if self.model_type != 'vace' and clip_fea is not None:
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
|
||||
@ -132,6 +165,11 @@ def usp_dit_forward(
|
||||
x, get_sequence_parallel_world_size(),
|
||||
dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
if self.model_type == 'vace':
|
||||
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
|
||||
kwargs['hints'] = hints
|
||||
kwargs['context_scale'] = vace_context_scale
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, **kwargs)
|
||||
|
||||
|
@ -2,11 +2,13 @@ from .attention import flash_attention
|
||||
from .model import WanModel
|
||||
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
from .vace_model import VaceWanModel
|
||||
from .vae import WanVAE
|
||||
|
||||
__all__ = [
|
||||
'WanVAE',
|
||||
'WanModel',
|
||||
'VaceWanModel',
|
||||
'T5Model',
|
||||
'T5Encoder',
|
||||
'T5Decoder',
|
||||
|
@ -400,7 +400,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
|
||||
Args:
|
||||
model_type (`str`, *optional*, defaults to 't2v'):
|
||||
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
|
||||
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
|
||||
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
||||
text_len (`int`, *optional*, defaults to 512):
|
||||
@ -433,7 +433,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
|
||||
super().__init__()
|
||||
|
||||
assert model_type in ['t2v', 'i2v', 'flf2v']
|
||||
assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
|
||||
self.model_type = model_type
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
233
wan/modules/vace_model.py
Normal file
233
wan/modules/vace_model.py
Normal file
@ -0,0 +1,233 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn as nn
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
|
||||
|
||||
|
||||
class VaceWanAttentionBlock(WanAttentionBlock):
|
||||
def __init__(
|
||||
self,
|
||||
cross_attn_type,
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=False,
|
||||
eps=1e-6,
|
||||
block_id=0
|
||||
):
|
||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = nn.Linear(self.dim, self.dim)
|
||||
nn.init.zeros_(self.before_proj.weight)
|
||||
nn.init.zeros_(self.before_proj.bias)
|
||||
self.after_proj = nn.Linear(self.dim, self.dim)
|
||||
nn.init.zeros_(self.after_proj.weight)
|
||||
nn.init.zeros_(self.after_proj.bias)
|
||||
|
||||
def forward(self, c, x, **kwargs):
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
|
||||
c = super().forward(c, **kwargs)
|
||||
c_skip = self.after_proj(c)
|
||||
return c, c_skip
|
||||
|
||||
|
||||
class BaseWanAttentionBlock(WanAttentionBlock):
|
||||
def __init__(
|
||||
self,
|
||||
cross_attn_type,
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=False,
|
||||
eps=1e-6,
|
||||
block_id=None
|
||||
):
|
||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
|
||||
self.block_id = block_id
|
||||
|
||||
def forward(self, x, hints, context_scale=1.0, **kwargs):
|
||||
x = super().forward(x, **kwargs)
|
||||
if self.block_id is not None:
|
||||
x = x + hints[self.block_id] * context_scale
|
||||
return x
|
||||
|
||||
|
||||
class VaceWanModel(WanModel):
|
||||
@register_to_config
|
||||
def __init__(self,
|
||||
vace_layers=None,
|
||||
vace_in_dim=None,
|
||||
model_type='vace',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6):
|
||||
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
|
||||
num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
|
||||
|
||||
self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
|
||||
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
|
||||
|
||||
assert 0 in self.vace_layers
|
||||
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
|
||||
|
||||
# blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
||||
self.cross_attn_norm, self.eps,
|
||||
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
|
||||
for i in range(self.num_layers)
|
||||
])
|
||||
|
||||
# vace blocks
|
||||
self.vace_blocks = nn.ModuleList([
|
||||
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
||||
self.cross_attn_norm, self.eps, block_id=i)
|
||||
for i in self.vace_layers
|
||||
])
|
||||
|
||||
# vace patch embeddings
|
||||
self.vace_patch_embedding = nn.Conv3d(
|
||||
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
|
||||
)
|
||||
|
||||
def forward_vace(
|
||||
self,
|
||||
x,
|
||||
vace_context,
|
||||
seq_len,
|
||||
kwargs
|
||||
):
|
||||
# embeddings
|
||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
c = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
# arguments
|
||||
new_kwargs = dict(x=x)
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
hints = []
|
||||
for block in self.vace_blocks:
|
||||
c, c_skip = block(c, **new_kwargs)
|
||||
hints.append(c_skip)
|
||||
return hints
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
vace_context,
|
||||
context,
|
||||
seq_len,
|
||||
vace_context_scale=1.0,
|
||||
clip_fea=None,
|
||||
y=None,
|
||||
):
|
||||
r"""
|
||||
Forward pass through the diffusion model
|
||||
|
||||
Args:
|
||||
x (List[Tensor]):
|
||||
List of input video tensors, each with shape [C_in, F, H, W]
|
||||
t (Tensor):
|
||||
Diffusion timesteps tensor of shape [B]
|
||||
context (List[Tensor]):
|
||||
List of text embeddings each with shape [L, C]
|
||||
seq_len (`int`):
|
||||
Maximum sequence length for positional encoding
|
||||
clip_fea (Tensor, *optional*):
|
||||
CLIP image features for image-to-video mode
|
||||
y (List[Tensor], *optional*):
|
||||
Conditional video inputs for image-to-video mode, same shape as x
|
||||
|
||||
Returns:
|
||||
List[Tensor]:
|
||||
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
# if self.model_type == 'i2v':
|
||||
# assert clip_fea is not None and y is not None
|
||||
# params
|
||||
device = self.patch_embedding.weight.device
|
||||
if self.freqs.device != device:
|
||||
self.freqs = self.freqs.to(device)
|
||||
|
||||
# if y is not None:
|
||||
# x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
||||
|
||||
# embeddings
|
||||
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
||||
grid_sizes = torch.stack(
|
||||
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
||||
x = [u.flatten(2).transpose(1, 2) for u in x]
|
||||
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
||||
assert seq_lens.max() <= seq_len
|
||||
x = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in x
|
||||
])
|
||||
|
||||
# time embeddings
|
||||
with amp.autocast(dtype=torch.float32):
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
||||
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
||||
|
||||
# context
|
||||
context_lens = None
|
||||
context = self.text_embedding(
|
||||
torch.stack([
|
||||
torch.cat(
|
||||
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
||||
for u in context
|
||||
]))
|
||||
|
||||
# if clip_fea is not None:
|
||||
# context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
# context = torch.concat([context_clip, context], dim=1)
|
||||
|
||||
# arguments
|
||||
kwargs = dict(
|
||||
e=e0,
|
||||
seq_lens=seq_lens,
|
||||
grid_sizes=grid_sizes,
|
||||
freqs=self.freqs,
|
||||
context=context,
|
||||
context_lens=context_lens)
|
||||
|
||||
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
|
||||
kwargs['hints'] = hints
|
||||
kwargs['context_scale'] = vace_context_scale
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, **kwargs)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return [u.float() for u in x]
|
@ -1,8 +1,10 @@
|
||||
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
|
||||
retrieve_timesteps)
|
||||
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||
from .vace_processor import VaceVideoProcessor
|
||||
|
||||
__all__ = [
|
||||
'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
|
||||
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
|
||||
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
|
||||
'VaceVideoProcessor'
|
||||
]
|
||||
|
270
wan/utils/vace_processor.py
Normal file
270
wan/utils/vace_processor.py
Normal file
@ -0,0 +1,270 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
|
||||
class VaceImageProcessor(object):
|
||||
def __init__(self, downsample=None, seq_len=None):
|
||||
self.downsample = downsample
|
||||
self.seq_len = seq_len
|
||||
|
||||
def _pillow_convert(self, image, cvt_type='RGB'):
|
||||
if image.mode != cvt_type:
|
||||
if image.mode == 'P':
|
||||
image = image.convert(f'{cvt_type}A')
|
||||
if image.mode == f'{cvt_type}A':
|
||||
bg = Image.new(cvt_type,
|
||||
size=(image.width, image.height),
|
||||
color=(255, 255, 255))
|
||||
bg.paste(image, (0, 0), mask=image)
|
||||
image = bg
|
||||
else:
|
||||
image = image.convert(cvt_type)
|
||||
return image
|
||||
|
||||
def _load_image(self, img_path):
|
||||
if img_path is None or img_path == '':
|
||||
return None
|
||||
img = Image.open(img_path)
|
||||
img = self._pillow_convert(img)
|
||||
return img
|
||||
|
||||
def _resize_crop(self, img, oh, ow, normalize=True):
|
||||
"""
|
||||
Resize, center crop, convert to tensor, and normalize.
|
||||
"""
|
||||
# resize and crop
|
||||
iw, ih = img.size
|
||||
if iw != ow or ih != oh:
|
||||
# resize
|
||||
scale = max(ow / iw, oh / ih)
|
||||
img = img.resize(
|
||||
(round(scale * iw), round(scale * ih)),
|
||||
resample=Image.Resampling.LANCZOS
|
||||
)
|
||||
assert img.width >= ow and img.height >= oh
|
||||
|
||||
# center crop
|
||||
x1 = (img.width - ow) // 2
|
||||
y1 = (img.height - oh) // 2
|
||||
img = img.crop((x1, y1, x1 + ow, y1 + oh))
|
||||
|
||||
# normalize
|
||||
if normalize:
|
||||
img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
|
||||
return img
|
||||
|
||||
def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
|
||||
return self._resize_crop(img, oh, ow, normalize)
|
||||
|
||||
def load_image(self, data_key, **kwargs):
|
||||
return self.load_image_batch(data_key, **kwargs)
|
||||
|
||||
def load_image_pair(self, data_key, data_key2, **kwargs):
|
||||
return self.load_image_batch(data_key, data_key2, **kwargs)
|
||||
|
||||
def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
|
||||
seq_len = self.seq_len if seq_len is None else seq_len
|
||||
imgs = []
|
||||
for data_key in data_key_batch:
|
||||
img = self._load_image(data_key)
|
||||
imgs.append(img)
|
||||
w, h = imgs[0].size
|
||||
dh, dw = self.downsample[1:]
|
||||
|
||||
# compute output size
|
||||
scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
|
||||
oh = int(h * scale) // dh * dh
|
||||
ow = int(w * scale) // dw * dw
|
||||
assert (oh // dh) * (ow // dw) <= seq_len
|
||||
imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
|
||||
return *imgs, (oh, ow)
|
||||
|
||||
|
||||
class VaceVideoProcessor(object):
|
||||
def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
|
||||
self.downsample = downsample
|
||||
self.min_area = min_area
|
||||
self.max_area = max_area
|
||||
self.min_fps = min_fps
|
||||
self.max_fps = max_fps
|
||||
self.zero_start = zero_start
|
||||
self.keep_last = keep_last
|
||||
self.seq_len = seq_len
|
||||
assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
|
||||
|
||||
def set_area(self, area):
|
||||
self.min_area = area
|
||||
self.max_area = area
|
||||
|
||||
def set_seq_len(self, seq_len):
|
||||
self.seq_len = seq_len
|
||||
|
||||
@staticmethod
|
||||
def resize_crop(video: torch.Tensor, oh: int, ow: int):
|
||||
"""
|
||||
Resize, center crop and normalize for decord loaded video (torch.Tensor type)
|
||||
|
||||
Parameters:
|
||||
video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
|
||||
oh - target height (int)
|
||||
ow - target width (int)
|
||||
|
||||
Returns:
|
||||
The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
|
||||
|
||||
Raises:
|
||||
"""
|
||||
# permute ([t, h, w, c] -> [t, c, h, w])
|
||||
video = video.permute(0, 3, 1, 2)
|
||||
|
||||
# resize and crop
|
||||
ih, iw = video.shape[2:]
|
||||
if ih != oh or iw != ow:
|
||||
# resize
|
||||
scale = max(ow / iw, oh / ih)
|
||||
video = F.interpolate(
|
||||
video,
|
||||
size=(round(scale * ih), round(scale * iw)),
|
||||
mode='bicubic',
|
||||
antialias=True
|
||||
)
|
||||
assert video.size(3) >= ow and video.size(2) >= oh
|
||||
|
||||
# center crop
|
||||
x1 = (video.size(3) - ow) // 2
|
||||
y1 = (video.size(2) - oh) // 2
|
||||
video = video[:, :, y1:y1 + oh, x1:x1 + ow]
|
||||
|
||||
# permute ([t, c, h, w] -> [c, t, h, w]) and normalize
|
||||
video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
|
||||
return video
|
||||
|
||||
def _video_preprocess(self, video, oh, ow):
|
||||
return self.resize_crop(video, oh, ow)
|
||||
|
||||
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
|
||||
target_fps = min(fps, self.max_fps)
|
||||
duration = frame_timestamps[-1].mean()
|
||||
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
||||
h, w = y2 - y1, x2 - x1
|
||||
ratio = h / w
|
||||
df, dh, dw = self.downsample
|
||||
|
||||
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
|
||||
of = min(
|
||||
(int(duration * target_fps) - 1) // df + 1,
|
||||
int(self.seq_len / area_z)
|
||||
)
|
||||
|
||||
# deduce target shape of the [latent video]
|
||||
target_area_z = min(area_z, int(self.seq_len / of))
|
||||
oh = round(np.sqrt(target_area_z * ratio))
|
||||
ow = int(target_area_z / oh)
|
||||
of = (of - 1) * df + 1
|
||||
oh *= dh
|
||||
ow *= dw
|
||||
|
||||
# sample frame ids
|
||||
target_duration = of / target_fps
|
||||
begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
|
||||
timestamps = np.linspace(begin, begin + target_duration, of)
|
||||
frame_ids = np.argmax(np.logical_and(
|
||||
timestamps[:, None] >= frame_timestamps[None, :, 0],
|
||||
timestamps[:, None] < frame_timestamps[None, :, 1]
|
||||
), axis=1).tolist()
|
||||
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
||||
|
||||
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
|
||||
duration = frame_timestamps[-1].mean()
|
||||
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
||||
h, w = y2 - y1, x2 - x1
|
||||
ratio = h / w
|
||||
df, dh, dw = self.downsample
|
||||
|
||||
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
|
||||
of = min(
|
||||
(len(frame_timestamps) - 1) // df + 1,
|
||||
int(self.seq_len / area_z)
|
||||
)
|
||||
|
||||
# deduce target shape of the [latent video]
|
||||
target_area_z = min(area_z, int(self.seq_len / of))
|
||||
oh = round(np.sqrt(target_area_z * ratio))
|
||||
ow = int(target_area_z / oh)
|
||||
of = (of - 1) * df + 1
|
||||
oh *= dh
|
||||
ow *= dw
|
||||
|
||||
# sample frame ids
|
||||
target_duration = duration
|
||||
target_fps = of / target_duration
|
||||
timestamps = np.linspace(0., target_duration, of)
|
||||
frame_ids = np.argmax(np.logical_and(
|
||||
timestamps[:, None] >= frame_timestamps[None, :, 0],
|
||||
timestamps[:, None] <= frame_timestamps[None, :, 1]
|
||||
), axis=1).tolist()
|
||||
# print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
|
||||
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
||||
|
||||
|
||||
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
|
||||
if self.keep_last:
|
||||
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
|
||||
else:
|
||||
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)
|
||||
|
||||
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
|
||||
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
|
||||
|
||||
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
|
||||
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
|
||||
|
||||
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
|
||||
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
|
||||
# read video
|
||||
import decord
|
||||
decord.bridge.set_bridge('torch')
|
||||
readers = []
|
||||
for data_k in data_key_batch:
|
||||
reader = decord.VideoReader(data_k)
|
||||
readers.append(reader)
|
||||
|
||||
fps = readers[0].get_avg_fps()
|
||||
length = min([len(r) for r in readers])
|
||||
frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
|
||||
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
|
||||
h, w = readers[0].next().shape[:2]
|
||||
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)
|
||||
|
||||
# preprocess video
|
||||
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
|
||||
videos = [self._video_preprocess(video, oh, ow) for video in videos]
|
||||
return *videos, frame_ids, (oh, ow), fps
|
||||
# return videos if len(videos) > 1 else videos[0]
|
||||
|
||||
|
||||
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
|
||||
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
|
||||
if sub_src_video is None and sub_src_mask is None:
|
||||
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
|
||||
src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
|
||||
for i, ref_images in enumerate(src_ref_images):
|
||||
if ref_images is not None:
|
||||
for j, ref_img in enumerate(ref_images):
|
||||
if ref_img is not None and ref_img.shape[-2:] != image_size:
|
||||
canvas_height, canvas_width = image_size
|
||||
ref_height, ref_width = ref_img.shape[-2:]
|
||||
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
|
||||
scale = min(canvas_height / ref_height, canvas_width / ref_width)
|
||||
new_height = int(ref_height * scale)
|
||||
new_width = int(ref_width * scale)
|
||||
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
|
||||
top = (canvas_height - new_height) // 2
|
||||
left = (canvas_width - new_width) // 2
|
||||
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
|
||||
src_ref_images[i][j] = white_canvas
|
||||
return src_video, src_mask, src_ref_images
|
717
wan/vace.py
Normal file
717
wan/vace.py
Normal file
@ -0,0 +1,717 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import os
|
||||
import sys
|
||||
import gc
|
||||
import math
|
||||
import time
|
||||
import random
|
||||
import types
|
||||
import logging
|
||||
import traceback
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
|
||||
from PIL import Image
|
||||
import torchvision.transforms.functional as TF
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.cuda.amp as amp
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from tqdm import tqdm
|
||||
|
||||
from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
|
||||
get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
|
||||
from .modules.vace_model import VaceWanModel
|
||||
from .utils.vace_processor import VaceVideoProcessor
|
||||
|
||||
|
||||
class WanVace(WanT2V):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
checkpoint_dir,
|
||||
device_id=0,
|
||||
rank=0,
|
||||
t5_fsdp=False,
|
||||
dit_fsdp=False,
|
||||
use_usp=False,
|
||||
t5_cpu=False,
|
||||
):
|
||||
r"""
|
||||
Initializes the Wan text-to-video generation model components.
|
||||
|
||||
Args:
|
||||
config (EasyDict):
|
||||
Object containing model parameters initialized from config.py
|
||||
checkpoint_dir (`str`):
|
||||
Path to directory containing model checkpoints
|
||||
device_id (`int`, *optional*, defaults to 0):
|
||||
Id of target GPU device
|
||||
rank (`int`, *optional*, defaults to 0):
|
||||
Process rank for distributed training
|
||||
t5_fsdp (`bool`, *optional*, defaults to False):
|
||||
Enable FSDP sharding for T5 model
|
||||
dit_fsdp (`bool`, *optional*, defaults to False):
|
||||
Enable FSDP sharding for DiT model
|
||||
use_usp (`bool`, *optional*, defaults to False):
|
||||
Enable distribution strategy of USP.
|
||||
t5_cpu (`bool`, *optional*, defaults to False):
|
||||
Whether to place T5 model on CPU. Only works without t5_fsdp.
|
||||
"""
|
||||
self.device = torch.device(f"cuda:{device_id}")
|
||||
self.config = config
|
||||
self.rank = rank
|
||||
self.t5_cpu = t5_cpu
|
||||
|
||||
self.num_train_timesteps = config.num_train_timesteps
|
||||
self.param_dtype = config.param_dtype
|
||||
|
||||
shard_fn = partial(shard_model, device_id=device_id)
|
||||
self.text_encoder = T5EncoderModel(
|
||||
text_len=config.text_len,
|
||||
dtype=config.t5_dtype,
|
||||
device=torch.device('cpu'),
|
||||
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
|
||||
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
||||
shard_fn=shard_fn if t5_fsdp else None)
|
||||
|
||||
self.vae_stride = config.vae_stride
|
||||
self.patch_size = config.patch_size
|
||||
self.vae = WanVAE(
|
||||
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
|
||||
device=self.device)
|
||||
|
||||
logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
|
||||
self.model = VaceWanModel.from_pretrained(checkpoint_dir)
|
||||
self.model.eval().requires_grad_(False)
|
||||
|
||||
if use_usp:
|
||||
from xfuser.core.distributed import \
|
||||
get_sequence_parallel_world_size
|
||||
|
||||
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
||||
usp_dit_forward,
|
||||
usp_dit_forward_vace)
|
||||
for block in self.model.blocks:
|
||||
block.self_attn.forward = types.MethodType(
|
||||
usp_attn_forward, block.self_attn)
|
||||
for block in self.model.vace_blocks:
|
||||
block.self_attn.forward = types.MethodType(
|
||||
usp_attn_forward, block.self_attn)
|
||||
self.model.forward = types.MethodType(usp_dit_forward, self.model)
|
||||
self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model)
|
||||
self.sp_size = get_sequence_parallel_world_size()
|
||||
else:
|
||||
self.sp_size = 1
|
||||
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
if dit_fsdp:
|
||||
self.model = shard_fn(self.model)
|
||||
else:
|
||||
self.model.to(self.device)
|
||||
|
||||
self.sample_neg_prompt = config.sample_neg_prompt
|
||||
|
||||
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
|
||||
min_area=720*1280,
|
||||
max_area=720*1280,
|
||||
min_fps=config.sample_fps,
|
||||
max_fps=config.sample_fps,
|
||||
zero_start=True,
|
||||
seq_len=75600,
|
||||
keep_last=True)
|
||||
|
||||
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
|
||||
vae = self.vae if vae is None else vae
|
||||
if ref_images is None:
|
||||
ref_images = [None] * len(frames)
|
||||
else:
|
||||
assert len(frames) == len(ref_images)
|
||||
|
||||
if masks is None:
|
||||
latents = vae.encode(frames)
|
||||
else:
|
||||
masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
|
||||
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
|
||||
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
|
||||
inactive = vae.encode(inactive)
|
||||
reactive = vae.encode(reactive)
|
||||
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
|
||||
|
||||
cat_latents = []
|
||||
for latent, refs in zip(latents, ref_images):
|
||||
if refs is not None:
|
||||
if masks is None:
|
||||
ref_latent = vae.encode(refs)
|
||||
else:
|
||||
ref_latent = vae.encode(refs)
|
||||
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
|
||||
assert all([x.shape[1] == 1 for x in ref_latent])
|
||||
latent = torch.cat([*ref_latent, latent], dim=1)
|
||||
cat_latents.append(latent)
|
||||
return cat_latents
|
||||
|
||||
def vace_encode_masks(self, masks, ref_images=None, vae_stride=None):
|
||||
vae_stride = self.vae_stride if vae_stride is None else vae_stride
|
||||
if ref_images is None:
|
||||
ref_images = [None] * len(masks)
|
||||
else:
|
||||
assert len(masks) == len(ref_images)
|
||||
|
||||
result_masks = []
|
||||
for mask, refs in zip(masks, ref_images):
|
||||
c, depth, height, width = mask.shape
|
||||
new_depth = int((depth + 3) // vae_stride[0])
|
||||
height = 2 * (int(height) // (vae_stride[1] * 2))
|
||||
width = 2 * (int(width) // (vae_stride[2] * 2))
|
||||
|
||||
# reshape
|
||||
mask = mask[0, :, :, :]
|
||||
mask = mask.view(
|
||||
depth, height, vae_stride[1], width, vae_stride[1]
|
||||
) # depth, height, 8, width, 8
|
||||
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
|
||||
mask = mask.reshape(
|
||||
vae_stride[1] * vae_stride[2], depth, height, width
|
||||
) # 8*8, depth, height, width
|
||||
|
||||
# interpolation
|
||||
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
|
||||
|
||||
if refs is not None:
|
||||
length = len(refs)
|
||||
mask_pad = torch.zeros_like(mask[:, :length, :, :])
|
||||
mask = torch.cat((mask_pad, mask), dim=1)
|
||||
result_masks.append(mask)
|
||||
return result_masks
|
||||
|
||||
def vace_latent(self, z, m):
|
||||
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
|
||||
|
||||
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device):
|
||||
area = image_size[0] * image_size[1]
|
||||
self.vid_proc.set_area(area)
|
||||
if area == 720*1280:
|
||||
self.vid_proc.set_seq_len(75600)
|
||||
elif area == 480*832:
|
||||
self.vid_proc.set_seq_len(32760)
|
||||
else:
|
||||
raise NotImplementedError(f'image_size {image_size} is not supported')
|
||||
|
||||
image_sizes = []
|
||||
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
|
||||
if sub_src_mask is not None and sub_src_video is not None:
|
||||
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask)
|
||||
src_video[i] = src_video[i].to(device)
|
||||
src_mask[i] = src_mask[i].to(device)
|
||||
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
|
||||
image_sizes.append(src_video[i].shape[2:])
|
||||
elif sub_src_video is None:
|
||||
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
|
||||
src_mask[i] = torch.ones_like(src_video[i], device=device)
|
||||
image_sizes.append(image_size)
|
||||
else:
|
||||
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video)
|
||||
src_video[i] = src_video[i].to(device)
|
||||
src_mask[i] = torch.ones_like(src_video[i], device=device)
|
||||
image_sizes.append(src_video[i].shape[2:])
|
||||
|
||||
for i, ref_images in enumerate(src_ref_images):
|
||||
if ref_images is not None:
|
||||
image_size = image_sizes[i]
|
||||
for j, ref_img in enumerate(ref_images):
|
||||
if ref_img is not None:
|
||||
ref_img = Image.open(ref_img).convert("RGB")
|
||||
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
|
||||
if ref_img.shape[-2:] != image_size:
|
||||
canvas_height, canvas_width = image_size
|
||||
ref_height, ref_width = ref_img.shape[-2:]
|
||||
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
|
||||
scale = min(canvas_height / ref_height, canvas_width / ref_width)
|
||||
new_height = int(ref_height * scale)
|
||||
new_width = int(ref_width * scale)
|
||||
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
|
||||
top = (canvas_height - new_height) // 2
|
||||
left = (canvas_width - new_width) // 2
|
||||
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
|
||||
ref_img = white_canvas
|
||||
src_ref_images[i][j] = ref_img.to(device)
|
||||
return src_video, src_mask, src_ref_images
|
||||
|
||||
def decode_latent(self, zs, ref_images=None, vae=None):
|
||||
vae = self.vae if vae is None else vae
|
||||
if ref_images is None:
|
||||
ref_images = [None] * len(zs)
|
||||
else:
|
||||
assert len(zs) == len(ref_images)
|
||||
|
||||
trimed_zs = []
|
||||
for z, refs in zip(zs, ref_images):
|
||||
if refs is not None:
|
||||
z = z[:, len(refs):, :, :]
|
||||
trimed_zs.append(z)
|
||||
|
||||
return vae.decode(trimed_zs)
|
||||
|
||||
|
||||
|
||||
def generate(self,
|
||||
input_prompt,
|
||||
input_frames,
|
||||
input_masks,
|
||||
input_ref_images,
|
||||
size=(1280, 720),
|
||||
frame_num=81,
|
||||
context_scale=1.0,
|
||||
shift=5.0,
|
||||
sample_solver='unipc',
|
||||
sampling_steps=50,
|
||||
guide_scale=5.0,
|
||||
n_prompt="",
|
||||
seed=-1,
|
||||
offload_model=True):
|
||||
r"""
|
||||
Generates video frames from text prompt using diffusion process.
|
||||
|
||||
Args:
|
||||
input_prompt (`str`):
|
||||
Text prompt for content generation
|
||||
size (tupele[`int`], *optional*, defaults to (1280,720)):
|
||||
Controls video resolution, (width,height).
|
||||
frame_num (`int`, *optional*, defaults to 81):
|
||||
How many frames to sample from a video. The number should be 4n+1
|
||||
shift (`float`, *optional*, defaults to 5.0):
|
||||
Noise schedule shift parameter. Affects temporal dynamics
|
||||
sample_solver (`str`, *optional*, defaults to 'unipc'):
|
||||
Solver used to sample the video.
|
||||
sampling_steps (`int`, *optional*, defaults to 40):
|
||||
Number of diffusion sampling steps. Higher values improve quality but slow generation
|
||||
guide_scale (`float`, *optional*, defaults 5.0):
|
||||
Classifier-free guidance scale. Controls prompt adherence vs. creativity
|
||||
n_prompt (`str`, *optional*, defaults to ""):
|
||||
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
|
||||
seed (`int`, *optional*, defaults to -1):
|
||||
Random seed for noise generation. If -1, use random seed.
|
||||
offload_model (`bool`, *optional*, defaults to True):
|
||||
If True, offloads models to CPU during generation to save VRAM
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
Generated video frames tensor. Dimensions: (C, N H, W) where:
|
||||
- C: Color channels (3 for RGB)
|
||||
- N: Number of frames (81)
|
||||
- H: Frame height (from size)
|
||||
- W: Frame width from size)
|
||||
"""
|
||||
# preprocess
|
||||
# F = frame_num
|
||||
# target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
||||
# size[1] // self.vae_stride[1],
|
||||
# size[0] // self.vae_stride[2])
|
||||
#
|
||||
# seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
||||
# (self.patch_size[1] * self.patch_size[2]) *
|
||||
# target_shape[1] / self.sp_size) * self.sp_size
|
||||
|
||||
if n_prompt == "":
|
||||
n_prompt = self.sample_neg_prompt
|
||||
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
||||
seed_g = torch.Generator(device=self.device)
|
||||
seed_g.manual_seed(seed)
|
||||
|
||||
if not self.t5_cpu:
|
||||
self.text_encoder.model.to(self.device)
|
||||
context = self.text_encoder([input_prompt], self.device)
|
||||
context_null = self.text_encoder([n_prompt], self.device)
|
||||
if offload_model:
|
||||
self.text_encoder.model.cpu()
|
||||
else:
|
||||
context = self.text_encoder([input_prompt], torch.device('cpu'))
|
||||
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
||||
context = [t.to(self.device) for t in context]
|
||||
context_null = [t.to(self.device) for t in context_null]
|
||||
|
||||
# vace context encode
|
||||
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks)
|
||||
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
||||
z = self.vace_latent(z0, m0)
|
||||
|
||||
target_shape = list(z0[0].shape)
|
||||
target_shape[0] = int(target_shape[0] / 2)
|
||||
noise = [
|
||||
torch.randn(
|
||||
target_shape[0],
|
||||
target_shape[1],
|
||||
target_shape[2],
|
||||
target_shape[3],
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
generator=seed_g)
|
||||
]
|
||||
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
||||
(self.patch_size[1] * self.patch_size[2]) *
|
||||
target_shape[1] / self.sp_size) * self.sp_size
|
||||
|
||||
@contextmanager
|
||||
def noop_no_sync():
|
||||
yield
|
||||
|
||||
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
|
||||
|
||||
# evaluation mode
|
||||
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
|
||||
|
||||
if sample_solver == 'unipc':
|
||||
sample_scheduler = FlowUniPCMultistepScheduler(
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
shift=1,
|
||||
use_dynamic_shifting=False)
|
||||
sample_scheduler.set_timesteps(
|
||||
sampling_steps, device=self.device, shift=shift)
|
||||
timesteps = sample_scheduler.timesteps
|
||||
elif sample_solver == 'dpm++':
|
||||
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
shift=1,
|
||||
use_dynamic_shifting=False)
|
||||
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
||||
timesteps, _ = retrieve_timesteps(
|
||||
sample_scheduler,
|
||||
device=self.device,
|
||||
sigmas=sampling_sigmas)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported solver.")
|
||||
|
||||
# sample videos
|
||||
latents = noise
|
||||
|
||||
arg_c = {'context': context, 'seq_len': seq_len}
|
||||
arg_null = {'context': context_null, 'seq_len': seq_len}
|
||||
|
||||
for _, t in enumerate(tqdm(timesteps)):
|
||||
latent_model_input = latents
|
||||
timestep = [t]
|
||||
|
||||
timestep = torch.stack(timestep)
|
||||
|
||||
self.model.to(self.device)
|
||||
noise_pred_cond = self.model(
|
||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0]
|
||||
noise_pred_uncond = self.model(
|
||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0]
|
||||
|
||||
noise_pred = noise_pred_uncond + guide_scale * (
|
||||
noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
temp_x0 = sample_scheduler.step(
|
||||
noise_pred.unsqueeze(0),
|
||||
t,
|
||||
latents[0].unsqueeze(0),
|
||||
return_dict=False,
|
||||
generator=seed_g)[0]
|
||||
latents = [temp_x0.squeeze(0)]
|
||||
|
||||
x0 = latents
|
||||
if offload_model:
|
||||
self.model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
if self.rank == 0:
|
||||
videos = self.decode_latent(x0, input_ref_images)
|
||||
|
||||
del noise, latents
|
||||
del sample_scheduler
|
||||
if offload_model:
|
||||
gc.collect()
|
||||
torch.cuda.synchronize()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
return videos[0] if self.rank == 0 else None
|
||||
|
||||
|
||||
class WanVaceMP(WanVace):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
checkpoint_dir,
|
||||
use_usp=False,
|
||||
ulysses_size=None,
|
||||
ring_size=None
|
||||
):
|
||||
self.config = config
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
self.use_usp = use_usp
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = '12345'
|
||||
os.environ['RANK'] = '0'
|
||||
os.environ['WORLD_SIZE'] = '1'
|
||||
self.in_q_list = None
|
||||
self.out_q = None
|
||||
self.inference_pids = None
|
||||
self.ulysses_size = ulysses_size
|
||||
self.ring_size = ring_size
|
||||
self.dynamic_load()
|
||||
|
||||
self.device = 'cpu' if torch.cuda.is_available() else 'cpu'
|
||||
self.vid_proc = VaceVideoProcessor(
|
||||
downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]),
|
||||
min_area=480 * 832,
|
||||
max_area=480 * 832,
|
||||
min_fps=self.config.sample_fps,
|
||||
max_fps=self.config.sample_fps,
|
||||
zero_start=True,
|
||||
seq_len=32760,
|
||||
keep_last=True)
|
||||
|
||||
|
||||
def dynamic_load(self):
|
||||
if hasattr(self, 'inference_pids') and self.inference_pids is not None:
|
||||
return
|
||||
gpu_infer = os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count()
|
||||
pmi_rank = int(os.environ['RANK'])
|
||||
pmi_world_size = int(os.environ['WORLD_SIZE'])
|
||||
in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)]
|
||||
out_q = torch.multiprocessing.Manager().Queue()
|
||||
initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)]
|
||||
context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False)
|
||||
all_initialized = False
|
||||
while not all_initialized:
|
||||
all_initialized = all(event.is_set() for event in initialized_events)
|
||||
if not all_initialized:
|
||||
time.sleep(0.1)
|
||||
print('Inference model is initialized', flush=True)
|
||||
self.in_q_list = in_q_list
|
||||
self.out_q = out_q
|
||||
self.inference_pids = context.pids()
|
||||
self.initialized_events = initialized_events
|
||||
|
||||
def transfer_data_to_cuda(self, data, device):
|
||||
if data is None:
|
||||
return None
|
||||
else:
|
||||
if isinstance(data, torch.Tensor):
|
||||
data = data.to(device)
|
||||
elif isinstance(data, list):
|
||||
data = [self.transfer_data_to_cuda(subdata, device) for subdata in data]
|
||||
elif isinstance(data, dict):
|
||||
data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()}
|
||||
return data
|
||||
|
||||
def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env):
|
||||
try:
|
||||
world_size = pmi_world_size * gpu_infer
|
||||
rank = pmi_rank * gpu_infer + gpu
|
||||
print("world_size", world_size, "rank", rank, flush=True)
|
||||
|
||||
torch.cuda.set_device(gpu)
|
||||
dist.init_process_group(
|
||||
backend='nccl',
|
||||
init_method='env://',
|
||||
rank=rank,
|
||||
world_size=world_size
|
||||
)
|
||||
|
||||
from xfuser.core.distributed import (initialize_model_parallel,
|
||||
init_distributed_environment)
|
||||
init_distributed_environment(
|
||||
rank=dist.get_rank(), world_size=dist.get_world_size())
|
||||
|
||||
initialize_model_parallel(
|
||||
sequence_parallel_degree=dist.get_world_size(),
|
||||
ring_degree=self.ring_size or 1,
|
||||
ulysses_degree=self.ulysses_size or 1
|
||||
)
|
||||
|
||||
num_train_timesteps = self.config.num_train_timesteps
|
||||
param_dtype = self.config.param_dtype
|
||||
shard_fn = partial(shard_model, device_id=gpu)
|
||||
text_encoder = T5EncoderModel(
|
||||
text_len=self.config.text_len,
|
||||
dtype=self.config.t5_dtype,
|
||||
device=torch.device('cpu'),
|
||||
checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint),
|
||||
tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer),
|
||||
shard_fn=shard_fn if True else None)
|
||||
text_encoder.model.to(gpu)
|
||||
vae_stride = self.config.vae_stride
|
||||
patch_size = self.config.patch_size
|
||||
vae = WanVAE(
|
||||
vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint),
|
||||
device=gpu)
|
||||
logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
|
||||
model = VaceWanModel.from_pretrained(self.checkpoint_dir)
|
||||
model.eval().requires_grad_(False)
|
||||
|
||||
if self.use_usp:
|
||||
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
||||
usp_dit_forward,
|
||||
usp_dit_forward_vace)
|
||||
for block in model.blocks:
|
||||
block.self_attn.forward = types.MethodType(
|
||||
usp_attn_forward, block.self_attn)
|
||||
for block in model.vace_blocks:
|
||||
block.self_attn.forward = types.MethodType(
|
||||
usp_attn_forward, block.self_attn)
|
||||
model.forward = types.MethodType(usp_dit_forward, model)
|
||||
model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
|
||||
sp_size = get_sequence_parallel_world_size()
|
||||
else:
|
||||
sp_size = 1
|
||||
|
||||
dist.barrier()
|
||||
model = shard_fn(model)
|
||||
sample_neg_prompt = self.config.sample_neg_prompt
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
event = initialized_events[gpu]
|
||||
in_q = in_q_list[gpu]
|
||||
event.set()
|
||||
|
||||
while True:
|
||||
item = in_q.get()
|
||||
input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, \
|
||||
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
|
||||
input_frames = self.transfer_data_to_cuda(input_frames, gpu)
|
||||
input_masks = self.transfer_data_to_cuda(input_masks, gpu)
|
||||
input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu)
|
||||
|
||||
if n_prompt == "":
|
||||
n_prompt = sample_neg_prompt
|
||||
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
||||
seed_g = torch.Generator(device=gpu)
|
||||
seed_g.manual_seed(seed)
|
||||
|
||||
context = text_encoder([input_prompt], gpu)
|
||||
context_null = text_encoder([n_prompt], gpu)
|
||||
|
||||
# vace context encode
|
||||
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae)
|
||||
m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride)
|
||||
z = self.vace_latent(z0, m0)
|
||||
|
||||
target_shape = list(z0[0].shape)
|
||||
target_shape[0] = int(target_shape[0] / 2)
|
||||
noise = [
|
||||
torch.randn(
|
||||
target_shape[0],
|
||||
target_shape[1],
|
||||
target_shape[2],
|
||||
target_shape[3],
|
||||
dtype=torch.float32,
|
||||
device=gpu,
|
||||
generator=seed_g)
|
||||
]
|
||||
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
||||
(patch_size[1] * patch_size[2]) *
|
||||
target_shape[1] / sp_size) * sp_size
|
||||
|
||||
@contextmanager
|
||||
def noop_no_sync():
|
||||
yield
|
||||
|
||||
no_sync = getattr(model, 'no_sync', noop_no_sync)
|
||||
|
||||
# evaluation mode
|
||||
with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
|
||||
|
||||
if sample_solver == 'unipc':
|
||||
sample_scheduler = FlowUniPCMultistepScheduler(
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
shift=1,
|
||||
use_dynamic_shifting=False)
|
||||
sample_scheduler.set_timesteps(
|
||||
sampling_steps, device=gpu, shift=shift)
|
||||
timesteps = sample_scheduler.timesteps
|
||||
elif sample_solver == 'dpm++':
|
||||
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
shift=1,
|
||||
use_dynamic_shifting=False)
|
||||
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
||||
timesteps, _ = retrieve_timesteps(
|
||||
sample_scheduler,
|
||||
device=gpu,
|
||||
sigmas=sampling_sigmas)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported solver.")
|
||||
|
||||
# sample videos
|
||||
latents = noise
|
||||
|
||||
arg_c = {'context': context, 'seq_len': seq_len}
|
||||
arg_null = {'context': context_null, 'seq_len': seq_len}
|
||||
|
||||
for _, t in enumerate(tqdm(timesteps)):
|
||||
latent_model_input = latents
|
||||
timestep = [t]
|
||||
|
||||
timestep = torch.stack(timestep)
|
||||
|
||||
model.to(gpu)
|
||||
noise_pred_cond = model(
|
||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[
|
||||
0]
|
||||
noise_pred_uncond = model(
|
||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,
|
||||
**arg_null)[0]
|
||||
|
||||
noise_pred = noise_pred_uncond + guide_scale * (
|
||||
noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
temp_x0 = sample_scheduler.step(
|
||||
noise_pred.unsqueeze(0),
|
||||
t,
|
||||
latents[0].unsqueeze(0),
|
||||
return_dict=False,
|
||||
generator=seed_g)[0]
|
||||
latents = [temp_x0.squeeze(0)]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
x0 = latents
|
||||
if rank == 0:
|
||||
videos = self.decode_latent(x0, input_ref_images, vae=vae)
|
||||
|
||||
del noise, latents
|
||||
del sample_scheduler
|
||||
if offload_model:
|
||||
gc.collect()
|
||||
torch.cuda.synchronize()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
if rank == 0:
|
||||
out_q.put(videos[0].cpu())
|
||||
|
||||
except Exception as e:
|
||||
trace_info = traceback.format_exc()
|
||||
print(trace_info, flush=True)
|
||||
print(e, flush=True)
|
||||
|
||||
|
||||
|
||||
def generate(self,
|
||||
input_prompt,
|
||||
input_frames,
|
||||
input_masks,
|
||||
input_ref_images,
|
||||
size=(1280, 720),
|
||||
frame_num=81,
|
||||
context_scale=1.0,
|
||||
shift=5.0,
|
||||
sample_solver='unipc',
|
||||
sampling_steps=50,
|
||||
guide_scale=5.0,
|
||||
n_prompt="",
|
||||
seed=-1,
|
||||
offload_model=True):
|
||||
|
||||
input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale,
|
||||
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model)
|
||||
for in_q in self.in_q_list:
|
||||
in_q.put(input_data)
|
||||
value_output = self.out_q.get()
|
||||
|
||||
return value_output
|
Loading…
Reference in New Issue
Block a user