Compare commits

...

9 Commits

Author SHA1 Message Date
Yexiong Lin
e8f3ff1d0e
Merge 36d6d91b90 into c709fcf0e7 2025-05-14 21:39:45 +07:00
Zhen Han
c709fcf0e7
fix vace size (#397) 2025-05-14 22:01:45 +08:00
Ang Wang
18d53feb7a
[feature] Add VACE (#389)
* Add VACE

* Support training with multiple gpus

* Update default args for vace task

* vace block update

* Add vace exmaple jpg

* Fix dist vace fwd hook error

* Update vace exmample

* Update vace args

* Update pipeline name for vace

* vace gradio and Readme

* Update vace snake png

---------

Co-authored-by: hanzhn <han.feng.jason@gmail.com>
2025-05-14 20:44:25 +08:00
Yexiong Lin
36d6d91b90 update text2video.py 2025-03-04 19:54:56 +11:00
Yexiong Lin
1c7b73d13e Add the support for fp8 t5 2025-03-04 19:54:56 +11:00
Yexiong Lin
db54b7c613 Update README.md and text2video.py to offload model and enable using fp8 2025-03-04 19:54:56 +11:00
Yexiong Lin
24007c2c39 support fp8 model 2025-03-04 19:54:56 +11:00
Yexiong Lin
bc2aff711e support fp8 modell 2025-03-04 19:54:56 +11:00
Yexiong Lin
bebb16bb8e 支持Kijai的fp8模型 2025-03-04 19:54:56 +11:00
23 changed files with 1957 additions and 45 deletions

111
README.md
View File

@ -27,6 +27,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* May 14, 2025: 👋 We introduce **Wan2.1** [VACE](https://github.com/ali-vilab/VACE), an all-in-one model for video creation and editing, along with its [inference code](#run-vace), [weights](#model-download), and [technical report](https://arxiv.org/abs/2503.07598)!
* Apr 17, 2025: 👋 We introduce **Wan2.1** [FLF2V](#run-first-last-frame-to-video-generation) with its inference code and weights!
* Mar 21, 2025: 👋 We are excited to announce the release of the **Wan2.1** [technical report](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf). We welcome discussions and feedback!
* Mar 3, 2025: 👋 **Wan2.1**'s T2V and I2V have been integrated into Diffusers ([T2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanPipeline) | [I2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanImageToVideoPipeline)). Feel free to give it a try!
@ -64,7 +65,13 @@ If your work has improved **Wan2.1** and you would like more people to see it, p
- [ ] ComfyUI integration
- [ ] Diffusers integration
- [ ] Diffusers + Multi-GPU Inference
- Wan2.1 VACE
- [x] Multi-GPU Inference code of the 14B and 1.3B models
- [x] Checkpoints of the 14B and 1.3B models
- [x] Gradio demo
- [x] ComfyUI integration
- [ ] Diffusers integration
- [ ] Diffusers + Multi-GPU Inference
## Quickstart
@ -84,13 +91,15 @@ pip install -r requirements.txt
#### Model Download
| Models | Download Link | Notes |
|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
| FLF2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | Supports 720P
| Models | Download Link | Notes |
|--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
| FLF2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | Supports 720P
| VACE-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | Supports 480P
| VACE-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | Supports both 480P and 720P
> 💡Note:
> * The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution.
@ -157,6 +166,14 @@ If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model Tr
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```
You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder.
Additionally, an [FP8 version of the T5 model](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors) is available. To use the FP8 T5 model, update the configuration file:
```
t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
```
> 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift parameter` can be adjusted within the range of 8 to 12 based on the performance.
@ -293,6 +310,17 @@ Similar to Text-to-Video, Image-to-Video is also divided into processes with and
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
To minimize GPU memory usage, you can enable model offloading with `--offload_model True` and use FP8 precision with `--fp8`.
For example, to run **Wan2.1-I2V-14B-480P** on an RTX 4090 GPU:
1. First, download the [FP8 model weights](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors) and place them in the `Wan2.1-I2V-14B-480P` folder.
2. Then, execute the following command:
```
python generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --offload_model True --fp8 --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```
> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
@ -448,6 +476,73 @@ DASH_API_KEY=your_key python flf2v_14B_singleGPU.py --prompt_extend_method 'dash
```
#### Run VACE
[VACE](https://github.com/ali-vilab/VACE) now supports two models (1.3B and 14B) and two main resolutions (480P and 720P).
The input supports any resolution, but to achieve optimal results, the video size should fall within a specific range.
The parameters and configurations for these models are as follows:
<table>
<thead>
<tr>
<th rowspan="2">Task</th>
<th colspan="2">Resolution</th>
<th rowspan="2">Model</th>
</tr>
<tr>
<th>480P(~81x480x832)</th>
<th>720P(~81x720x1280)</th>
</tr>
</thead>
<tbody>
<tr>
<td>VACE</td>
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
<td>Wan2.1-VACE-14B</td>
</tr>
<tr>
<td>VACE</td>
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
<td style="color: red; text-align: center; vertical-align: middle;"></td>
<td>Wan2.1-VACE-1.3B</td>
</tr>
</tbody>
</table>
In VACE, users can input text prompt and optional video, mask, and image for video generation or editing. Detailed instructions for using VACE can be found in the [User Guide](https://github.com/ali-vilab/VACE/blob/main/UserGuide.md).
The execution process is as follows:
##### (1) Preprocessing
User-collected materials needs to be preprocessed into VACE-recognizable inputs, including `src_video`, `src_mask`, `src_ref_images`, and `prompt`.
For R2V (Reference-to-Video Generation), you may skip this preprocessing, but for V2V (Video-to-Video Editing) and MV2V (Masked Video-to-Video Editing) tasks, additional preprocessing is required to obtain video with conditions such as depth, pose or masked regions.
For more details, please refer to [vace_preproccess](https://github.com/ali-vilab/VACE/blob/main/vace/vace_preproccess.py).
##### (2) cli inference
- Single-GPU inference
```sh
python generate.py --task vace-1.3B --size 832*480 --ckpt_dir ./Wan2.1-VACE-1.3B --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
```
- Multi-GPU inference using FSDP + xDiT USP
```sh
torchrun --nproc_per_node=8 generate.py --task vace-14B --size 1280*720 --ckpt_dir ./Wan2.1-VACE-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
```
##### (3) Running local gradio
- Single-GPU inference
```sh
python gradio/vace.py --ckpt_dir ./Wan2.1-VACE-1.3B
```
- Multi-GPU inference using FSDP + xDiT USP
```sh
python gradio/vace.py --mp --ulysses_size 8 --ckpt_dir ./Wan2.1-VACE-14B/
```
#### Run Text-to-Image Generation
Wan2.1 is a unified model for both image and video generation. Since it was trained on both types of data, it can also generate images. The command for generating images is similar to video generation, as follows:

BIN
examples/girl.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 817 KiB

BIN
examples/snake.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 435 KiB

View File

@ -41,6 +41,14 @@ EXAMPLE_PROMPT = {
"last_frame":
"examples/flf2v_input_last_frame.png",
},
"vace-1.3B": {
"src_ref_images": 'examples/girl.png,examples/snake.png',
"prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
},
"vace-14B": {
"src_ref_images": 'examples/girl.png,examples/snake.png',
"prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
}
}
@ -52,15 +60,19 @@ def _validate_args(args):
# The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
if args.sample_steps is None:
args.sample_steps = 40 if "i2v" in args.task else 50
args.sample_steps = 50
if "i2v" in args.task:
args.sample_steps = 40
if args.sample_shift is None:
args.sample_shift = 5.0
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
args.sample_shift = 3.0
if "flf2v" in args.task:
elif "flf2v" in args.task or "vace" in args.task:
args.sample_shift = 16
# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
if args.frame_num is None:
args.frame_num = 1 if "t2i" in args.task else 81
@ -136,11 +148,31 @@ def _parse_args():
action="store_true",
default=False,
help="Whether to use FSDP for DiT.")
parser.add_argument(
"--fp8",
action="store_true",
default=False,
help="Whether to use fp8.")
parser.add_argument(
"--save_file",
type=str,
default=None,
help="The file to save the generated image or video to.")
parser.add_argument(
"--src_video",
type=str,
default=None,
help="The file of the source video. Default None.")
parser.add_argument(
"--src_mask",
type=str,
default=None,
help="The file of the source mask. Default None.")
parser.add_argument(
"--src_ref_images",
type=str,
default=None,
help="The file list of the source reference images. Separated by ','. Default None.")
parser.add_argument(
"--prompt",
type=str,
@ -326,6 +358,7 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
fp8=args.fp8,
)
logging.info(
@ -383,6 +416,7 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
fp8=args.fp8,
)
logging.info("Generating video ...")
@ -397,7 +431,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
elif "flf2v" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.first_frame is None or args.last_frame is None:
@ -457,6 +491,60 @@ def generate(args):
seed=args.base_seed,
offload_model=args.offload_model
)
elif "vace" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
args.src_ref_images = EXAMPLE_PROMPT[args.task].get("src_ref_images", None)
logging.info(f"Input prompt: {args.prompt}")
if args.use_prompt_extend and args.use_prompt_extend != 'plain':
logging.info("Extending prompt ...")
if rank == 0:
prompt = prompt_expander.forward(args.prompt)
logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'")
input_prompt = [prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating VACE pipeline.")
wan_vace = wan.WanVace(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
src_video, src_mask, src_ref_images = wan_vace.prepare_source([args.src_video],
[args.src_mask],
[None if args.src_ref_images is None else args.src_ref_images.split(',')],
args.frame_num, SIZE_CONFIGS[args.size], device)
logging.info(f"Generating video...")
video = wan_vace.generate(
args.prompt,
src_video,
src_mask,
src_ref_images,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
raise ValueError(f"Unkown task type: {args.task}")
if rank == 0:
if args.save_file is None:

295
gradio/vace.py Normal file
View File

@ -0,0 +1,295 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
import sys
import datetime
import imageio
import numpy as np
import torch
import gradio as gr
sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan import WanVace, WanVaceMP
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
class FixedSizeQueue:
def __init__(self, max_size):
self.max_size = max_size
self.queue = []
def add(self, item):
self.queue.insert(0, item)
if len(self.queue) > self.max_size:
self.queue.pop()
def get(self):
return self.queue
def __repr__(self):
return str(self.queue)
class VACEInference:
def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
self.cfg = cfg
self.save_dir = cfg.save_dir
self.gallery_share = gallery_share
self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
if not skip_load:
if not args.mp:
self.pipe = WanVace(
config=WAN_CONFIGS[cfg.model_name],
checkpoint_dir=cfg.ckpt_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
else:
self.pipe = WanVaceMP(
config=WAN_CONFIGS[cfg.model_name],
checkpoint_dir=cfg.ckpt_dir,
use_usp=True,
ulysses_size=cfg.ulysses_size,
ring_size=cfg.ring_size
)
def create_ui(self, *args, **kwargs):
gr.Markdown("""
<div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
<a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
</div>
""")
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1, min_width=0):
self.src_video = gr.Video(
label="src_video",
sources=['upload'],
value=None,
interactive=True)
with gr.Column(scale=1, min_width=0):
self.src_mask = gr.Video(
label="src_mask",
sources=['upload'],
value=None,
interactive=True)
#
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1, min_width=0):
with gr.Row(equal_height=True):
self.src_ref_image_1 = gr.Image(label='src_ref_image_1',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_1",
format='png')
self.src_ref_image_2 = gr.Image(label='src_ref_image_2',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_2",
format='png')
self.src_ref_image_3 = gr.Image(label='src_ref_image_3',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_3",
format='png')
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1):
self.prompt = gr.Textbox(
show_label=False,
placeholder="positive_prompt_input",
elem_id='positive_prompt',
container=True,
autofocus=True,
elem_classes='type_row',
visible=True,
lines=2)
self.negative_prompt = gr.Textbox(
show_label=False,
value=self.pipe.config.sample_neg_prompt,
placeholder="negative_prompt_input",
elem_id='negative_prompt',
container=True,
autofocus=False,
elem_classes='type_row',
visible=True,
interactive=True,
lines=1)
#
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1, min_width=0):
with gr.Row(equal_height=True):
self.shift_scale = gr.Slider(
label='shift_scale',
minimum=0.0,
maximum=100.0,
step=1.0,
value=16.0,
interactive=True)
self.sample_steps = gr.Slider(
label='sample_steps',
minimum=1,
maximum=100,
step=1,
value=25,
interactive=True)
self.context_scale = gr.Slider(
label='context_scale',
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.0,
interactive=True)
self.guide_scale = gr.Slider(
label='guide_scale',
minimum=1,
maximum=10,
step=0.5,
value=5.0,
interactive=True)
self.infer_seed = gr.Slider(minimum=-1,
maximum=10000000,
value=2025,
label="Seed")
#
with gr.Accordion(label="Usable without source video", open=False):
with gr.Row(equal_height=True):
self.output_height = gr.Textbox(
label='resolutions_height',
# value=480,
value=720,
interactive=True)
self.output_width = gr.Textbox(
label='resolutions_width',
# value=832,
value=1280,
interactive=True)
self.frame_rate = gr.Textbox(
label='frame_rate',
value=16,
interactive=True)
self.num_frames = gr.Textbox(
label='num_frames',
value=81,
interactive=True)
#
with gr.Row(equal_height=True):
with gr.Column(scale=5):
self.generate_button = gr.Button(
value='Run',
elem_classes='type_row',
elem_id='generate_button',
visible=True)
with gr.Column(scale=1):
self.refresh_button = gr.Button(value='\U0001f504') # 🔄
#
self.output_gallery = gr.Gallery(
label="output_gallery",
value=[],
interactive=False,
allow_preview=True,
preview=True)
def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames):
output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames)
src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if
x is not None]
src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
[src_mask],
[src_ref_images],
num_frames=num_frames,
image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
device=self.pipe.device)
video = self.pipe.generate(
prompt,
src_video,
src_mask,
src_ref_images,
size=(output_width, output_height),
context_scale=context_scale,
shift=shift_scale,
sampling_steps=sample_steps,
guide_scale=guide_scale,
n_prompt=negative_prompt,
seed=infer_seed,
offload_model=True)
name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now())
video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
try:
writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1)
for frame in video_frames:
writer.append_data(frame)
writer.close()
print(video_path)
except Exception as e:
raise gr.Error(f"Video save error: {e}")
if self.gallery_share:
self.gallery_share_data.add(video_path)
return self.gallery_share_data.get()
else:
return [video_path]
def set_callbacks(self, **kwargs):
self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
self.gen_outputs = [self.output_gallery]
self.generate_button.click(self.generate,
inputs=self.gen_inputs,
outputs=self.gen_outputs,
queue=True)
self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery])
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
parser.add_argument('--root_path', dest='root_path', help='', default=None)
parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",)
parser.add_argument("--model_name", type=str, default="vace-14B", choices=list(WAN_CONFIGS.keys()), help="The model name to run.")
parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.")
parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.")
parser.add_argument(
"--ckpt_dir",
type=str,
# default='models/VACE-Wan2.1-1.3B-Preview',
default='models/Wan2.1-VACE-14B/',
help="The path to the checkpoint directory.",
)
parser.add_argument(
"--offload_to_cpu",
action="store_true",
help="Offloading unnecessary computations to CPU.",
)
args = parser.parse_args()
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir, exist_ok=True)
with gr.Blocks() as demo:
infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
infer_gr.create_ui()
infer_gr.set_callbacks()
allowed_paths = [args.save_dir]
demo.queue(status_update_rate=1).launch(server_name=args.server_name,
server_port=args.server_port,
root_path=args.root_path,
allowed_paths=allowed_paths,
show_error=True, debug=True)

View File

@ -105,9 +105,16 @@ function i2v_14B_720p() {
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
}
function vace_1_3B() {
VACE_1_3B_CKPT_DIR="$MODEL_DIR/VACE-Wan2.1-1.3B-Preview/"
torchrun --nproc_per_node=$GPUS $PY_FILE --ulysses_size $GPUS --task vace-1.3B --size 480*832 --ckpt_dir $VACE_1_3B_CKPT_DIR
}
t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p
vace_1_3B

View File

@ -2,3 +2,4 @@ from . import configs, distributed, modules
from .image2video import WanI2V
from .text2video import WanT2V
from .first_last_frame2video import WanFLF2V
from .vace import WanVace, WanVaceMP

View File

@ -22,7 +22,9 @@ WAN_CONFIGS = {
't2v-1.3B': t2v_1_3B,
'i2v-14B': i2v_14B,
't2i-14B': t2i_14B,
'flf2v-14B': flf2v_14B
'flf2v-14B': flf2v_14B,
'vace-1.3B': t2v_1_3B,
'vace-14B': t2v_14B,
}
SIZE_CONFIGS = {
@ -46,4 +48,6 @@ SUPPORTED_SIZES = {
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2i-14B': tuple(SIZE_CONFIGS.keys()),
'vace-1.3B': ('480*832', '832*480'),
'vace-14B': ('720*1280', '1280*720', '480*832', '832*480')
}

View File

@ -11,12 +11,14 @@ i2v_14B.update(wan_shared_cfg)
i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt
i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# i2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
i2v_14B.t5_tokenizer = 'google/umt5-xxl'
# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float16
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
# i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' # Kijai's fp16 model
i2v_14B.clip_tokenizer = 'xlm-roberta-large'
# vae

View File

@ -10,6 +10,7 @@ t2v_14B.update(wan_shared_cfg)
# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# t2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
t2v_14B.t5_tokenizer = 'google/umt5-xxl'
# vae

View File

@ -10,6 +10,7 @@ t2v_1_3B.update(wan_shared_cfg)
# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
# t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
# vae

View File

@ -63,12 +63,45 @@ def rope_apply(x, grid_sizes, freqs):
return torch.stack(output).float()
def usp_dit_forward_vace(
self,
x,
vace_context,
seq_len,
kwargs
):
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in c
])
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
# Context Parallel
c = torch.chunk(
c, get_sequence_parallel_world_size(),
dim=1)[get_sequence_parallel_rank()]
hints = []
for block in self.vace_blocks:
c, c_skip = block(c, **new_kwargs)
hints.append(c_skip)
return hints
def usp_dit_forward(
self,
x,
t,
context,
seq_len,
vace_context=None,
vace_context_scale=1.0,
clip_fea=None,
y=None,
):
@ -84,7 +117,7 @@ def usp_dit_forward(
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
if self.model_type != 'vace' and y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
@ -114,7 +147,7 @@ def usp_dit_forward(
for u in context
]))
if clip_fea is not None:
if self.model_type != 'vace' and clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
@ -132,6 +165,11 @@ def usp_dit_forward(
x, get_sequence_parallel_world_size(),
dim=1)[get_sequence_parallel_rank()]
if self.model_type == 'vace':
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
kwargs['hints'] = hints
kwargs['context_scale'] = vace_context_scale
for block in self.blocks:
x = block(x, **kwargs)

View File

@ -16,6 +16,10 @@ import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file
from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
@ -39,6 +43,7 @@ class WanI2V:
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
fp8=False,
):
r"""
Initializes the image-to-video generation model components.
@ -62,6 +67,8 @@ class WanI2V:
Whether to place T5 model on CPU. Only works without t5_fsdp.
init_on_cpu (`bool`, *optional*, defaults to True):
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
fp8 (`bool`, *optional*, defaults to False):
Enable 8-bit floating point precision for model parameters.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
@ -73,6 +80,10 @@ class WanI2V:
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
quantization = "fp8_e4m3fn"
else:
quantization = "disabled"
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
@ -80,10 +91,12 @@ class WanI2V:
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None,
quantization=quantization,
)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
@ -96,7 +109,46 @@ class WanI2V:
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
if not fp8:
self.model = WanModel.from_pretrained(checkpoint_dir)
else:
if '480P' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
elif '720P' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-720P_fp8_e4m3fn.safetensors', device="cpu")
dim = state_dict["patch_embedding.weight"].shape[0]
in_channels = state_dict["patch_embedding.weight"].shape[1]
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
model_type = "i2v" if in_channels == 36 else "t2v"
num_heads = 40 if dim == 5120 else 12
num_layers = 40 if dim == 5120 else 30
TRANSFORMER_CONFIG= {
"dim": dim,
"ffn_dim": ffn_dim,
"eps": 1e-06,
"freq_dim": 256,
"in_dim": in_channels,
"model_type": model_type,
"out_dim": 16,
"text_len": 512,
"num_heads": num_heads,
"num_layers": num_layers,
}
with init_empty_weights():
self.model = WanModel(**TRANSFORMER_CONFIG)
base_dtype=torch.bfloat16
dtype=torch.float8_e4m3fn
params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
for name, param in self.model.named_parameters():
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
# dtype_to_use = torch.bfloat16
# print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
del state_dict
self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp:
@ -219,13 +271,15 @@ class WanI2V:
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
@ -242,9 +296,12 @@ class WanI2V:
torch.zeros(3, F - 1, h, w)
],
dim=1).to(self.device)
])[0]
],device=self.device)[0]
y = torch.concat([msk, y])
if offload_model:
self.vae.model.cpu()
@contextmanager
def noop_no_sync():
yield
@ -332,9 +389,11 @@ class WanI2V:
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
# load vae model back to device
self.vae.model.to(self.device)
if self.rank == 0:
videos = self.vae.decode(x0)
videos = self.vae.decode(x0, device=self.device)
del noise, latent
del sample_scheduler

View File

@ -2,11 +2,13 @@ from .attention import flash_attention
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vace_model import VaceWanModel
from .vae import WanVAE
__all__ = [
'WanVAE',
'WanModel',
'VaceWanModel',
'T5Model',
'T5Encoder',
'T5Decoder',

View File

@ -7,6 +7,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from safetensors.torch import load_file
from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer
@ -515,8 +516,13 @@ class CLIPModel:
device=device)
self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
self.model.load_state_dict(
torch.load(checkpoint_path, map_location='cpu'))
if checkpoint_path.endswith('.safetensors'):
state_dict = load_file(checkpoint_path, device='cpu')
self.model.load_state_dict(state_dict)
elif checkpoint_path.endswith('.pth'):
self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
else:
raise ValueError(f'Unsupported checkpoint file format: {checkpoint_path}')
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(

View File

@ -400,7 +400,7 @@ class WanModel(ModelMixin, ConfigMixin):
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
@ -433,7 +433,7 @@ class WanModel(ModelMixin, ConfigMixin):
super().__init__()
assert model_type in ['t2v', 'i2v', 'flf2v']
assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
self.model_type = model_type
self.patch_size = patch_size

View File

@ -9,6 +9,10 @@ import torch.nn.functional as F
from .tokenizers import HuggingfaceTokenizer
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file
__all__ = [
'T5Model',
'T5Encoder',
@ -442,7 +446,7 @@ def _t5(name,
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
# model = model.to(dtype=dtype, device=device)
# init tokenizer
if return_tokenizer:
@ -479,6 +483,7 @@ class T5EncoderModel:
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,
quantization="disabled",
):
self.text_len = text_len
self.dtype = dtype
@ -486,14 +491,31 @@ class T5EncoderModel:
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}')
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
if quantization == "disabled":
# init model
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
elif quantization == "fp8_e4m3fn":
with init_empty_weights():
model = umt5_xxl(
encoder_only=True,
return_tokenizer=False,
dtype=dtype,
device=device).eval().requires_grad_(False)
cast_dtype = torch.float8_e4m3fn
state_dict = load_file(checkpoint_path, device="cpu")
params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}
for name, param in model.named_parameters():
dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype
set_module_tensor_to_device(model, name, device=device, dtype=dtype_to_use, value=state_dict[name])
del state_dict
self.model = model
if shard_fn is not None:
self.model = shard_fn(self.model, sync_module_states=False)

233
wan/modules/vace_model.py Normal file
View File

@ -0,0 +1,233 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
class VaceWanAttentionBlock(WanAttentionBlock):
def __init__(
self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=0
):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
self.block_id = block_id
if block_id == 0:
self.before_proj = nn.Linear(self.dim, self.dim)
nn.init.zeros_(self.before_proj.weight)
nn.init.zeros_(self.before_proj.bias)
self.after_proj = nn.Linear(self.dim, self.dim)
nn.init.zeros_(self.after_proj.weight)
nn.init.zeros_(self.after_proj.bias)
def forward(self, c, x, **kwargs):
if self.block_id == 0:
c = self.before_proj(c) + x
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
return c, c_skip
class BaseWanAttentionBlock(WanAttentionBlock):
def __init__(
self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=None
):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
self.block_id = block_id
def forward(self, x, hints, context_scale=1.0, **kwargs):
x = super().forward(x, **kwargs)
if self.block_id is not None:
x = x + hints[self.block_id] * context_scale
return x
class VaceWanModel(WanModel):
@register_to_config
def __init__(self,
vace_layers=None,
vace_in_dim=None,
model_type='vace',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6):
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
assert 0 in self.vace_layers
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
# blocks
self.blocks = nn.ModuleList([
BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
self.cross_attn_norm, self.eps,
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
for i in range(self.num_layers)
])
# vace blocks
self.vace_blocks = nn.ModuleList([
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
self.cross_attn_norm, self.eps, block_id=i)
for i in self.vace_layers
])
# vace patch embeddings
self.vace_patch_embedding = nn.Conv3d(
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
)
def forward_vace(
self,
x,
vace_context,
seq_len,
kwargs
):
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in c
])
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
hints = []
for block in self.vace_blocks:
c, c_skip = block(c, **new_kwargs)
hints.append(c_skip)
return hints
def forward(
self,
x,
t,
vace_context,
context,
seq_len,
vace_context_scale=1.0,
clip_fea=None,
y=None,
):
r"""
Forward pass through the diffusion model
Args:
x (List[Tensor]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
CLIP image features for image-to-video mode
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
# if self.model_type == 'i2v':
# assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
# if y is not None:
# x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
# if clip_fea is not None:
# context_clip = self.img_emb(clip_fea) # bs x 257 x dim
# context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
kwargs['hints'] = hints
kwargs['context_scale'] = vace_context_scale
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]

View File

@ -644,7 +644,7 @@ class WanVAE:
z_dim=z_dim,
).eval().requires_grad_(False).to(device)
def encode(self, videos):
def encode(self, videos, device=None):
"""
videos: A list of videos each with shape [C, T, H, W].
"""
@ -654,7 +654,7 @@ class WanVAE:
for u in videos
]
def decode(self, zs):
def decode(self, zs, device=None):
with amp.autocast(dtype=self.dtype):
return [
self.model.decode(u.unsqueeze(0),

View File

@ -14,6 +14,10 @@ import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from safetensors.torch import load_file
from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
@ -35,6 +39,8 @@ class WanT2V:
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
fp8=False,
):
r"""
Initializes the Wan text-to-video generation model components.
@ -56,6 +62,8 @@ class WanT2V:
Enable distribution strategy of USP.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
fp8 (`bool`, *optional*, defaults to False):
Enable 8-bit floating point precision for model parameters.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
@ -66,13 +74,19 @@ class WanT2V:
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
quantization = "fp8_e4m3fn"
else:
quantization = "disabled"
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None)
shard_fn=shard_fn if t5_fsdp else None,
quantization=quantization)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
@ -81,9 +95,52 @@ class WanT2V:
device=self.device)
logging.info(f"Creating WanModel from {checkpoint_dir}")
self.model = WanModel.from_pretrained(checkpoint_dir)
if not fp8:
self.model = WanModel.from_pretrained(checkpoint_dir)
else:
if '14B' in checkpoint_dir:
state_dict = load_file(checkpoint_dir+'/Wan2_1-T2V-14B_fp8_e4m3fn.safetensors', device="cpu")
else:
state_dict = load_file(checkpoint_dir+'/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors', device="cpu")
dim = state_dict["patch_embedding.weight"].shape[0]
in_channels = state_dict["patch_embedding.weight"].shape[1]
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
model_type = "i2v" if in_channels == 36 else "t2v"
num_heads = 40 if dim == 5120 else 12
num_layers = 40 if dim == 5120 else 30
TRANSFORMER_CONFIG= {
"dim": dim,
"ffn_dim": ffn_dim,
"eps": 1e-06,
"freq_dim": 256,
"in_dim": in_channels,
"model_type": model_type,
"out_dim": 16,
"text_len": 512,
"num_heads": num_heads,
"num_layers": num_layers,
}
with init_empty_weights():
self.model = WanModel(**TRANSFORMER_CONFIG)
base_dtype=torch.bfloat16
dtype=torch.float8_e4m3fn
params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
for name, param in self.model.named_parameters():
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
# dtype_to_use = torch.bfloat16
# print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
del state_dict
self.model.eval().requires_grad_(False)
if t5_fsdp or dit_fsdp or use_usp:
init_on_cpu = False
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
@ -103,7 +160,8 @@ class WanT2V:
if dit_fsdp:
self.model = shard_fn(self.model)
else:
self.model.to(self.device)
if not init_on_cpu:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
@ -169,13 +227,15 @@ class WanT2V:
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
@ -190,6 +250,9 @@ class WanT2V:
generator=seed_g)
]
if offload_model:
self.vae.model.cpu()
@contextmanager
def noop_no_sync():
yield
@ -226,13 +289,15 @@ class WanT2V:
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
if offload_model:
torch.cuda.empty_cache()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
@ -253,6 +318,9 @@ class WanT2V:
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
# load vae model back to device
self.vae.model.to(self.device)
if self.rank == 0:
videos = self.vae.decode(x0)

View File

@ -1,8 +1,10 @@
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
retrieve_timesteps)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .vace_processor import VaceVideoProcessor
__all__ = [
'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
'VaceVideoProcessor'
]

270
wan/utils/vace_processor.py Normal file
View File

@ -0,0 +1,270 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
class VaceImageProcessor(object):
def __init__(self, downsample=None, seq_len=None):
self.downsample = downsample
self.seq_len = seq_len
def _pillow_convert(self, image, cvt_type='RGB'):
if image.mode != cvt_type:
if image.mode == 'P':
image = image.convert(f'{cvt_type}A')
if image.mode == f'{cvt_type}A':
bg = Image.new(cvt_type,
size=(image.width, image.height),
color=(255, 255, 255))
bg.paste(image, (0, 0), mask=image)
image = bg
else:
image = image.convert(cvt_type)
return image
def _load_image(self, img_path):
if img_path is None or img_path == '':
return None
img = Image.open(img_path)
img = self._pillow_convert(img)
return img
def _resize_crop(self, img, oh, ow, normalize=True):
"""
Resize, center crop, convert to tensor, and normalize.
"""
# resize and crop
iw, ih = img.size
if iw != ow or ih != oh:
# resize
scale = max(ow / iw, oh / ih)
img = img.resize(
(round(scale * iw), round(scale * ih)),
resample=Image.Resampling.LANCZOS
)
assert img.width >= ow and img.height >= oh
# center crop
x1 = (img.width - ow) // 2
y1 = (img.height - oh) // 2
img = img.crop((x1, y1, x1 + ow, y1 + oh))
# normalize
if normalize:
img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
return img
def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
return self._resize_crop(img, oh, ow, normalize)
def load_image(self, data_key, **kwargs):
return self.load_image_batch(data_key, **kwargs)
def load_image_pair(self, data_key, data_key2, **kwargs):
return self.load_image_batch(data_key, data_key2, **kwargs)
def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
seq_len = self.seq_len if seq_len is None else seq_len
imgs = []
for data_key in data_key_batch:
img = self._load_image(data_key)
imgs.append(img)
w, h = imgs[0].size
dh, dw = self.downsample[1:]
# compute output size
scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
oh = int(h * scale) // dh * dh
ow = int(w * scale) // dw * dw
assert (oh // dh) * (ow // dw) <= seq_len
imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
return *imgs, (oh, ow)
class VaceVideoProcessor(object):
def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
self.downsample = downsample
self.min_area = min_area
self.max_area = max_area
self.min_fps = min_fps
self.max_fps = max_fps
self.zero_start = zero_start
self.keep_last = keep_last
self.seq_len = seq_len
assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
def set_area(self, area):
self.min_area = area
self.max_area = area
def set_seq_len(self, seq_len):
self.seq_len = seq_len
@staticmethod
def resize_crop(video: torch.Tensor, oh: int, ow: int):
"""
Resize, center crop and normalize for decord loaded video (torch.Tensor type)
Parameters:
video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
oh - target height (int)
ow - target width (int)
Returns:
The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
Raises:
"""
# permute ([t, h, w, c] -> [t, c, h, w])
video = video.permute(0, 3, 1, 2)
# resize and crop
ih, iw = video.shape[2:]
if ih != oh or iw != ow:
# resize
scale = max(ow / iw, oh / ih)
video = F.interpolate(
video,
size=(round(scale * ih), round(scale * iw)),
mode='bicubic',
antialias=True
)
assert video.size(3) >= ow and video.size(2) >= oh
# center crop
x1 = (video.size(3) - ow) // 2
y1 = (video.size(2) - oh) // 2
video = video[:, :, y1:y1 + oh, x1:x1 + ow]
# permute ([t, c, h, w] -> [c, t, h, w]) and normalize
video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
return video
def _video_preprocess(self, video, oh, ow):
return self.resize_crop(video, oh, ow)
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
target_fps = min(fps, self.max_fps)
duration = frame_timestamps[-1].mean()
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
of = min(
(int(duration * target_fps) - 1) // df + 1,
int(self.seq_len / area_z)
)
# deduce target shape of the [latent video]
target_area_z = min(area_z, int(self.seq_len / of))
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
# sample frame ids
target_duration = of / target_fps
begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
timestamps = np.linspace(begin, begin + target_duration, of)
frame_ids = np.argmax(np.logical_and(
timestamps[:, None] >= frame_timestamps[None, :, 0],
timestamps[:, None] < frame_timestamps[None, :, 1]
), axis=1).tolist()
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
duration = frame_timestamps[-1].mean()
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
of = min(
(len(frame_timestamps) - 1) // df + 1,
int(self.seq_len / area_z)
)
# deduce target shape of the [latent video]
target_area_z = min(area_z, int(self.seq_len / of))
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
# sample frame ids
target_duration = duration
target_fps = of / target_duration
timestamps = np.linspace(0., target_duration, of)
frame_ids = np.argmax(np.logical_and(
timestamps[:, None] >= frame_timestamps[None, :, 0],
timestamps[:, None] <= frame_timestamps[None, :, 1]
), axis=1).tolist()
# print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
if self.keep_last:
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
else:
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
# read video
import decord
decord.bridge.set_bridge('torch')
readers = []
for data_k in data_key_batch:
reader = decord.VideoReader(data_k)
readers.append(reader)
fps = readers[0].get_avg_fps()
length = min([len(r) for r in readers])
frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)
# preprocess video
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
videos = [self._video_preprocess(video, oh, ow) for video in videos]
return *videos, frame_ids, (oh, ow), fps
# return videos if len(videos) > 1 else videos[0]
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_video is None and sub_src_mask is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
for j, ref_img in enumerate(ref_images):
if ref_img is not None and ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
src_ref_images[i][j] = white_canvas
return src_video, src_mask, src_ref_images

718
wan/vace.py Normal file
View File

@ -0,0 +1,718 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import sys
import gc
import math
import time
import random
import types
import logging
import traceback
from contextlib import contextmanager
from functools import partial
from PIL import Image
import torchvision.transforms.functional as TF
import torch
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
from tqdm import tqdm
from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
from .modules.vace_model import VaceWanModel
from .utils.vace_processor import VaceVideoProcessor
class WanVace(WanT2V):
def __init__(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
):
r"""
Initializes the Wan text-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_usp (`bool`, *optional*, defaults to False):
Enable distribution strategy of USP.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
self.model = VaceWanModel.from_pretrained(checkpoint_dir)
self.model.eval().requires_grad_(False)
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
for block in self.model.vace_blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
self.model.forward = types.MethodType(usp_dit_forward, self.model)
self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model)
self.sp_size = get_sequence_parallel_world_size()
else:
self.sp_size = 1
if dist.is_initialized():
dist.barrier()
if dit_fsdp:
self.model = shard_fn(self.model)
else:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=720*1280,
max_area=720*1280,
min_fps=config.sample_fps,
max_fps=config.sample_fps,
zero_start=True,
seq_len=75600,
keep_last=True)
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
vae = self.vae if vae is None else vae
if ref_images is None:
ref_images = [None] * len(frames)
else:
assert len(frames) == len(ref_images)
if masks is None:
latents = vae.encode(frames)
else:
masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = vae.encode(inactive)
reactive = vae.encode(reactive)
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
cat_latents = []
for latent, refs in zip(latents, ref_images):
if refs is not None:
if masks is None:
ref_latent = vae.encode(refs)
else:
ref_latent = vae.encode(refs)
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
assert all([x.shape[1] == 1 for x in ref_latent])
latent = torch.cat([*ref_latent, latent], dim=1)
cat_latents.append(latent)
return cat_latents
def vace_encode_masks(self, masks, ref_images=None, vae_stride=None):
vae_stride = self.vae_stride if vae_stride is None else vae_stride
if ref_images is None:
ref_images = [None] * len(masks)
else:
assert len(masks) == len(ref_images)
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
new_depth = int((depth + 3) // vae_stride[0])
height = 2 * (int(height) // (vae_stride[1] * 2))
width = 2 * (int(width) // (vae_stride[2] * 2))
# reshape
mask = mask[0, :, :, :]
mask = mask.view(
depth, height, vae_stride[1], width, vae_stride[1]
) # depth, height, 8, width, 8
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
mask = mask.reshape(
vae_stride[1] * vae_stride[2], depth, height, width
) # 8*8, depth, height, width
# interpolation
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
if refs is not None:
length = len(refs)
mask_pad = torch.zeros_like(mask[:, :length, :, :])
mask = torch.cat((mask_pad, mask), dim=1)
result_masks.append(mask)
return result_masks
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device):
area = image_size[0] * image_size[1]
self.vid_proc.set_area(area)
if area == 720*1280:
self.vid_proc.set_seq_len(75600)
elif area == 480*832:
self.vid_proc.set_seq_len(32760)
else:
raise NotImplementedError(f'image_size {image_size} is not supported')
image_size = (image_size[1], image_size[0])
image_sizes = []
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask)
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video)
src_video[i] = src_video[i].to(device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(src_video[i].shape[2:])
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None:
ref_img = Image.open(ref_img).convert("RGB")
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
ref_img = white_canvas
src_ref_images[i][j] = ref_img.to(device)
return src_video, src_mask, src_ref_images
def decode_latent(self, zs, ref_images=None, vae=None):
vae = self.vae if vae is None else vae
if ref_images is None:
ref_images = [None] * len(zs)
else:
assert len(zs) == len(ref_images)
trimed_zs = []
for z, refs in zip(zs, ref_images):
if refs is not None:
z = z[:, len(refs):, :, :]
trimed_zs.append(z)
return vae.decode(trimed_zs)
def generate(self,
input_prompt,
input_frames,
input_masks,
input_ref_images,
size=(1280, 720),
frame_num=81,
context_scale=1.0,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (tupele[`int`], *optional*, defaults to (1280,720)):
Controls video resolution, (width,height).
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 40):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from size)
- W: Frame width from size)
"""
# preprocess
# F = frame_num
# target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
# size[1] // self.vae_stride[1],
# size[0] // self.vae_stride[2])
#
# seq_len = math.ceil((target_shape[2] * target_shape[3]) /
# (self.patch_size[1] * self.patch_size[2]) *
# target_shape[1] / self.sp_size) * self.sp_size
if n_prompt == "":
n_prompt = self.sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
# vace context encode
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks)
m0 = self.vace_encode_masks(input_masks, input_ref_images)
z = self.vace_latent(z0, m0)
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
noise = [
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=self.device,
generator=seed_g)
]
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0]
noise_pred_uncond = self.model(
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
x0 = latents
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.decode_latent(x0, input_ref_images)
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
class WanVaceMP(WanVace):
def __init__(
self,
config,
checkpoint_dir,
use_usp=False,
ulysses_size=None,
ring_size=None
):
self.config = config
self.checkpoint_dir = checkpoint_dir
self.use_usp = use_usp
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
self.in_q_list = None
self.out_q = None
self.inference_pids = None
self.ulysses_size = ulysses_size
self.ring_size = ring_size
self.dynamic_load()
self.device = 'cpu' if torch.cuda.is_available() else 'cpu'
self.vid_proc = VaceVideoProcessor(
downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]),
min_area=480 * 832,
max_area=480 * 832,
min_fps=self.config.sample_fps,
max_fps=self.config.sample_fps,
zero_start=True,
seq_len=32760,
keep_last=True)
def dynamic_load(self):
if hasattr(self, 'inference_pids') and self.inference_pids is not None:
return
gpu_infer = os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count()
pmi_rank = int(os.environ['RANK'])
pmi_world_size = int(os.environ['WORLD_SIZE'])
in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)]
out_q = torch.multiprocessing.Manager().Queue()
initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)]
context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False)
all_initialized = False
while not all_initialized:
all_initialized = all(event.is_set() for event in initialized_events)
if not all_initialized:
time.sleep(0.1)
print('Inference model is initialized', flush=True)
self.in_q_list = in_q_list
self.out_q = out_q
self.inference_pids = context.pids()
self.initialized_events = initialized_events
def transfer_data_to_cuda(self, data, device):
if data is None:
return None
else:
if isinstance(data, torch.Tensor):
data = data.to(device)
elif isinstance(data, list):
data = [self.transfer_data_to_cuda(subdata, device) for subdata in data]
elif isinstance(data, dict):
data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()}
return data
def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env):
try:
world_size = pmi_world_size * gpu_infer
rank = pmi_rank * gpu_infer + gpu
print("world_size", world_size, "rank", rank, flush=True)
torch.cuda.set_device(gpu)
dist.init_process_group(
backend='nccl',
init_method='env://',
rank=rank,
world_size=world_size
)
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=self.ring_size or 1,
ulysses_degree=self.ulysses_size or 1
)
num_train_timesteps = self.config.num_train_timesteps
param_dtype = self.config.param_dtype
shard_fn = partial(shard_model, device_id=gpu)
text_encoder = T5EncoderModel(
text_len=self.config.text_len,
dtype=self.config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint),
tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer),
shard_fn=shard_fn if True else None)
text_encoder.model.to(gpu)
vae_stride = self.config.vae_stride
patch_size = self.config.patch_size
vae = WanVAE(
vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint),
device=gpu)
logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
model = VaceWanModel.from_pretrained(self.checkpoint_dir)
model.eval().requires_grad_(False)
if self.use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace)
for block in model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
for block in model.vace_blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
model.forward = types.MethodType(usp_dit_forward, model)
model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
sp_size = get_sequence_parallel_world_size()
else:
sp_size = 1
dist.barrier()
model = shard_fn(model)
sample_neg_prompt = self.config.sample_neg_prompt
torch.cuda.empty_cache()
event = initialized_events[gpu]
in_q = in_q_list[gpu]
event.set()
while True:
item = in_q.get()
input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, \
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
input_frames = self.transfer_data_to_cuda(input_frames, gpu)
input_masks = self.transfer_data_to_cuda(input_masks, gpu)
input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu)
if n_prompt == "":
n_prompt = sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=gpu)
seed_g.manual_seed(seed)
context = text_encoder([input_prompt], gpu)
context_null = text_encoder([n_prompt], gpu)
# vace context encode
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae)
m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride)
z = self.vace_latent(z0, m0)
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
noise = [
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=gpu,
generator=seed_g)
]
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(patch_size[1] * patch_size[2]) *
target_shape[1] / sp_size) * sp_size
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=gpu, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=gpu,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
model.to(gpu)
noise_pred_cond = model(
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[
0]
noise_pred_uncond = model(
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,
**arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
torch.cuda.empty_cache()
x0 = latents
if rank == 0:
videos = self.decode_latent(x0, input_ref_images, vae=vae)
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
if rank == 0:
out_q.put(videos[0].cpu())
except Exception as e:
trace_info = traceback.format_exc()
print(trace_info, flush=True)
print(e, flush=True)
def generate(self,
input_prompt,
input_frames,
input_masks,
input_ref_images,
size=(1280, 720),
frame_num=81,
context_scale=1.0,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale,
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model)
for in_q in self.in_q_list:
in_q.put(input_data)
value_output = self.out_q.get()
return value_output