Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-06-17 21:07:41 +00:00)

Compare commits: 9 commits, ccd3ad1edd ... e8f3ff1d0e

Commits: e8f3ff1d0e, c709fcf0e7, 18d53feb7a, 36d6d91b90, 1c7b73d13e, db54b7c613, 24007c2c39, bc2aff711e, bebb16bb8e

README.md (99 lines changed)
@ -27,6 +27,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid

## 🔥 Latest News!!

* May 14, 2025: 👋 We introduce **Wan2.1** [VACE](https://github.com/ali-vilab/VACE), an all-in-one model for video creation and editing, along with its [inference code](#run-vace), [weights](#model-download), and [technical report](https://arxiv.org/abs/2503.07598)!
* Apr 17, 2025: 👋 We introduce **Wan2.1** [FLF2V](#run-first-last-frame-to-video-generation) with its inference code and weights!
* Mar 21, 2025: 👋 We are excited to announce the release of the **Wan2.1** [technical report](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf). We welcome discussions and feedback!
* Mar 3, 2025: 👋 **Wan2.1**'s T2V and I2V have been integrated into Diffusers ([T2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanPipeline) | [I2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanImageToVideoPipeline)). Feel free to give it a try; a minimal usage sketch is shown after this list!
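The following is a minimal, hedged sketch of the Diffusers route mentioned in the Mar 3 news item, following the pattern in the linked pipeline docs. The repository id `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`, the resolution, and the frame rate are assumptions drawn from common usage, not from this diff; consult the Diffusers documentation for authoritative parameters.

```python
# Hedged sketch: text-to-video with the Diffusers WanPipeline.
# Assumes a recent diffusers release and a Diffusers-format checkpoint id.
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed Diffusers-format repo id
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

frames = pipe(
    prompt="Two anthropomorphic cats in comfy boxing gear fight on a spotlighted stage.",
    height=480, width=832, num_frames=81, guidance_scale=6.0,
).frames[0]
export_to_video(frames, "t2v_1_3b_sample.mp4", fps=16)
```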
@ -64,7 +65,13 @@ If your work has improved **Wan2.1** and you would like more people to see it, p
    - [ ] ComfyUI integration
    - [ ] Diffusers integration
    - [ ] Diffusers + Multi-GPU Inference
- Wan2.1 VACE
    - [x] Multi-GPU Inference code of the 14B and 1.3B models
    - [x] Checkpoints of the 14B and 1.3B models
    - [x] Gradio demo
    - [x] ComfyUI integration
    - [ ] Diffusers integration
    - [ ] Diffusers + Multi-GPU Inference

## Quickstart
@ -85,12 +92,14 @@ pip install -r requirements.txt

#### Model Download

| Models       | Download Link | Notes |
|--------------|---------------|-------|
| T2V-14B      | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P |
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P |
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P |
| T2V-1.3B     | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P |
| FLF2V-14B    | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | Supports 720P |
| VACE-1.3B    | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | Supports 480P |
| VACE-14B     | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | Supports both 480P and 720P |

> 💡Note:
> * The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution.
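The weights in the table above can be fetched with the `huggingface-cli` / `modelscope` command-line tools or programmatically; the sketch below uses `huggingface_hub`, which is assumed to be installed (`pip install huggingface_hub`), with the VACE-1.3B repository chosen purely as an example.

```python
# Hedged sketch: download one of the checkpoints listed above into the folder
# layout that generate.py expects for --ckpt_dir.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Wan-AI/Wan2.1-VACE-1.3B",   # any repo id from the table above
    local_dir="./Wan2.1-VACE-1.3B",      # matches --ckpt_dir in the commands below
)
```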
@ -157,6 +166,14 @@ If you encounter OOM (Out-of-Memory) issues, you can use the `--offload_model Tr
python generate.py --task t2v-1.3B --size 832*480 --ckpt_dir ./Wan2.1-T2V-1.3B --offload_model True --t5_cpu --sample_shift 8 --sample_guide_scale 6 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```

You can also use the `--fp8` option to enable FP8 precision for reduced memory usage. Make sure to download the [FP8 model weight](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors) and place it in the `Wan2.1-T2V-1.3B` folder.

Additionally, an [FP8 version of the T5 model](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/umt5-xxl-enc-fp8_e4m3fn.safetensors) is available. To use the FP8 T5 model, update the configuration file:

```
t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
```

> 💡Note: If you are using the `T2V-1.3B` model, we recommend setting the parameter `--sample_guide_scale 6`. The `--sample_shift` parameter can be adjusted within the range of 8 to 12 depending on performance.
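For reference, the pipelines in this change set pick the T5 quantization mode purely from the configured checkpoint filename; the sketch below mirrors that dispatch (see `wan/text2video.py` and `wan/image2video.py` in this diff) and is illustrative rather than a complete setup.

```python
# Minimal sketch of the FP8 T5 dispatch used by this change set: the
# quantization mode is derived from the t5_checkpoint name in the config.
from wan.configs import WAN_CONFIGS

cfg = WAN_CONFIGS['t2v-1.3B']
cfg.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors'  # FP8 T5 weight placed in the checkpoint folder

quantization = (
    "fp8_e4m3fn"
    if cfg.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors'
    else "disabled"
)
print(quantization)  # -> fp8_e4m3fn; passed on to T5EncoderModel(..., quantization=quantization)
```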
@ -293,6 +310,17 @@ Similar to Text-to-Video, Image-to-Video is also divided into processes with and
python generate.py --task i2v-14B --size 1280*720 --ckpt_dir ./Wan2.1-I2V-14B-720P --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```

To minimize GPU memory usage, you can enable model offloading with `--offload_model True` and use FP8 precision with `--fp8`.

For example, to run **Wan2.1-I2V-14B-480P** on an RTX 4090 GPU:

1. First, download the [FP8 model weights](https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors) and place them in the `Wan2.1-I2V-14B-480P` folder.
2. Then, execute the following command:

```
python generate.py --task i2v-14B --size 832*480 --ckpt_dir ./Wan2.1-I2V-14B-480P --offload_model True --fp8 --image examples/i2v_input.JPG --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
```

> 💡For the Image-to-Video task, the `size` parameter represents the area of the generated video, with the aspect ratio following that of the original input image.
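To make the note above concrete, here is a small, hedged sketch of how an area budget plus the input image's aspect ratio determines an output resolution. The real pipeline derives the final shape from its VAE stride and patch size; the multiple-of-16 rounding and the helper name below are illustrative assumptions.

```python
# Illustrative sketch (not the pipeline's exact code): treat `--size` as an
# area budget and keep the input image's aspect ratio.
from PIL import Image

def approx_output_resolution(image_path: str, size: str = "832*480", multiple: int = 16):
    w_budget, h_budget = (int(v) for v in size.split("*"))
    max_area = w_budget * h_budget                 # `--size` acts as an area budget
    img_w, img_h = Image.open(image_path).size
    aspect = img_h / img_w                         # preserve the input aspect ratio
    out_h = int((max_area * aspect) ** 0.5) // multiple * multiple
    out_w = int((max_area / aspect) ** 0.5) // multiple * multiple
    return out_w, out_h

print(approx_output_resolution("examples/i2v_input.JPG"))  # e.g. (624, 624) for a square input
```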
@ -448,6 +476,73 @@ DASH_API_KEY=your_key python flf2v_14B_singleGPU.py --prompt_extend_method 'dash
```


#### Run VACE

[VACE](https://github.com/ali-vilab/VACE) now supports two models (1.3B and 14B) and two main resolutions (480P and 720P).
The input supports any resolution, but to achieve optimal results, the video size should fall within a specific range.
The parameters and configurations for these models are as follows:
| Task | 480P (~81x480x832) | 720P (~81x720x1280) | Model            |
|------|:------------------:|:-------------------:|------------------|
| VACE | ✔️                 | ✔️                  | Wan2.1-VACE-14B  |
| VACE | ✔️                 | ❌                  | Wan2.1-VACE-1.3B |
In VACE, users can provide a text prompt and optional video, mask, and image inputs for video generation or editing. Detailed instructions for using VACE can be found in the [User Guide](https://github.com/ali-vilab/VACE/blob/main/UserGuide.md).
The execution process is as follows:

##### (1) Preprocessing

User-collected materials need to be preprocessed into VACE-recognizable inputs, including `src_video`, `src_mask`, `src_ref_images`, and `prompt`.
For R2V (Reference-to-Video Generation) you may skip this preprocessing, but for V2V (Video-to-Video Editing) and MV2V (Masked Video-to-Video Editing) tasks, additional preprocessing is required to obtain videos with conditions such as depth, pose, or masked regions.
For more details, please refer to [vace_preproccess](https://github.com/ali-vilab/VACE/blob/main/vace/vace_preproccess.py).
##### (2) CLI inference

- Single-GPU inference
```sh
python generate.py --task vace-1.3B --size 832*480 --ckpt_dir ./Wan2.1-VACE-1.3B --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
```

- Multi-GPU inference using FSDP + xDiT USP

```sh
torchrun --nproc_per_node=8 generate.py --task vace-14B --size 1280*720 --ckpt_dir ./Wan2.1-VACE-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
```
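For programmatic use, the sketch below strings together the same calls that `generate.py` makes for a VACE run in this change set (`WanVace`, `prepare_source`, `generate`). Defaults such as the shift value, step count, and the English placeholder prompt are assumptions filled in by hand; writing the resulting tensor to a video file is left out.

```python
# Hedged sketch of a single-GPU VACE run through the Python API used by
# generate.py in this diff; parameter values are indicative only.
import torch
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS

cfg = WAN_CONFIGS['vace-1.3B']
pipe = wan.WanVace(
    config=cfg,
    checkpoint_dir='./Wan2.1-VACE-1.3B',
    device_id=0,
    rank=0,
    t5_fsdp=False, dit_fsdp=False, use_usp=False, t5_cpu=False,
)

# No source video or mask here, only two reference images (an R2V-style run).
src_video, src_mask, src_ref_images = pipe.prepare_source(
    [None], [None], [['examples/girl.png', 'examples/snake.png']],
    81, SIZE_CONFIGS['832*480'], torch.device('cuda:0'))

video = pipe.generate(
    "A cheerful girl in a bright red festival dress plays with her friendly cartoon snake.",  # placeholder prompt
    src_video, src_mask, src_ref_images,
    size=SIZE_CONFIGS['832*480'], frame_num=81,
    shift=16, sampling_steps=25, guide_scale=5.0, seed=2025,
    offload_model=True)
# `video` holds the decoded frames; save them with your preferred video writer.
```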
##### (3) Running the local Gradio demo

- Single-GPU inference
```sh
python gradio/vace.py --ckpt_dir ./Wan2.1-VACE-1.3B
```

- Multi-GPU inference using FSDP + xDiT USP
```sh
python gradio/vace.py --mp --ulysses_size 8 --ckpt_dir ./Wan2.1-VACE-14B/
```
#### Run Text-to-Image Generation

Wan2.1 is a unified model for both image and video generation. Since it was trained on both types of data, it can also generate images. The command for generating images is similar to video generation, as follows:
examples/girl.png (new binary file, 817 KiB)
examples/snake.png (new binary file, 435 KiB)
generate.py (94 lines changed)
@ -41,6 +41,14 @@ EXAMPLE_PROMPT = {
|
||||
"last_frame":
|
||||
"examples/flf2v_input_last_frame.png",
|
||||
},
|
||||
"vace-1.3B": {
|
||||
"src_ref_images": 'examples/girl.png,examples/snake.png',
|
||||
"prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
|
||||
},
|
||||
"vace-14B": {
|
||||
"src_ref_images": 'examples/girl.png,examples/snake.png',
|
||||
"prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -52,15 +60,19 @@ def _validate_args(args):
|
||||
|
||||
# The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
|
||||
if args.sample_steps is None:
|
||||
args.sample_steps = 40 if "i2v" in args.task else 50
|
||||
args.sample_steps = 50
|
||||
if "i2v" in args.task:
|
||||
args.sample_steps = 40
|
||||
|
||||
|
||||
if args.sample_shift is None:
|
||||
args.sample_shift = 5.0
|
||||
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
|
||||
args.sample_shift = 3.0
|
||||
if "flf2v" in args.task:
|
||||
elif "flf2v" in args.task or "vace" in args.task:
|
||||
args.sample_shift = 16
|
||||
|
||||
|
||||
# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
|
||||
if args.frame_num is None:
|
||||
args.frame_num = 1 if "t2i" in args.task else 81
|
||||
@ -136,11 +148,31 @@ def _parse_args():
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to use FSDP for DiT.")
|
||||
parser.add_argument(
|
||||
"--fp8",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to use fp8.")
|
||||
parser.add_argument(
|
||||
"--save_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file to save the generated image or video to.")
|
||||
parser.add_argument(
|
||||
"--src_video",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file of the source video. Default None.")
|
||||
parser.add_argument(
|
||||
"--src_mask",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file of the source mask. Default None.")
|
||||
parser.add_argument(
|
||||
"--src_ref_images",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file list of the source reference images. Separated by ','. Default None.")
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
@ -326,6 +358,7 @@ def generate(args):
|
||||
dit_fsdp=args.dit_fsdp,
|
||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||
t5_cpu=args.t5_cpu,
|
||||
fp8=args.fp8,
|
||||
)
|
||||
|
||||
logging.info(
|
||||
@ -383,6 +416,7 @@ def generate(args):
|
||||
dit_fsdp=args.dit_fsdp,
|
||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||
t5_cpu=args.t5_cpu,
|
||||
fp8=args.fp8,
|
||||
)
|
||||
|
||||
logging.info("Generating video ...")
|
||||
@ -397,7 +431,7 @@ def generate(args):
|
||||
guide_scale=args.sample_guide_scale,
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model)
|
||||
else:
|
||||
elif "flf2v" in args.task:
|
||||
if args.prompt is None:
|
||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||
if args.first_frame is None or args.last_frame is None:
|
||||
@ -457,6 +491,60 @@ def generate(args):
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model
|
||||
)
|
||||
elif "vace" in args.task:
|
||||
if args.prompt is None:
|
||||
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
|
||||
args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
|
||||
args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
|
||||
args.src_ref_images = EXAMPLE_PROMPT[args.task].get("src_ref_images", None)
|
||||
|
||||
logging.info(f"Input prompt: {args.prompt}")
|
||||
if args.use_prompt_extend and args.use_prompt_extend != 'plain':
|
||||
logging.info("Extending prompt ...")
|
||||
if rank == 0:
|
||||
prompt = prompt_expander.forward(args.prompt)
|
||||
logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'")
|
||||
input_prompt = [prompt]
|
||||
else:
|
||||
input_prompt = [None]
|
||||
if dist.is_initialized():
|
||||
dist.broadcast_object_list(input_prompt, src=0)
|
||||
args.prompt = input_prompt[0]
|
||||
logging.info(f"Extended prompt: {args.prompt}")
|
||||
|
||||
logging.info("Creating VACE pipeline.")
|
||||
wan_vace = wan.WanVace(
|
||||
config=cfg,
|
||||
checkpoint_dir=args.ckpt_dir,
|
||||
device_id=device,
|
||||
rank=rank,
|
||||
t5_fsdp=args.t5_fsdp,
|
||||
dit_fsdp=args.dit_fsdp,
|
||||
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||
t5_cpu=args.t5_cpu,
|
||||
)
|
||||
|
||||
src_video, src_mask, src_ref_images = wan_vace.prepare_source([args.src_video],
|
||||
[args.src_mask],
|
||||
[None if args.src_ref_images is None else args.src_ref_images.split(',')],
|
||||
args.frame_num, SIZE_CONFIGS[args.size], device)
|
||||
|
||||
logging.info(f"Generating video...")
|
||||
video = wan_vace.generate(
|
||||
args.prompt,
|
||||
src_video,
|
||||
src_mask,
|
||||
src_ref_images,
|
||||
size=SIZE_CONFIGS[args.size],
|
||||
frame_num=args.frame_num,
|
||||
shift=args.sample_shift,
|
||||
sample_solver=args.sample_solver,
|
||||
sampling_steps=args.sample_steps,
|
||||
guide_scale=args.sample_guide_scale,
|
||||
seed=args.base_seed,
|
||||
offload_model=args.offload_model)
|
||||
else:
|
||||
raise ValueError(f"Unkown task type: {args.task}")
|
||||
|
||||
if rank == 0:
|
||||
if args.save_file is None:
|
||||
|
gradio/vace.py (new file, 295 lines)
@ -0,0 +1,295 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import datetime
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
import gradio as gr
|
||||
|
||||
sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
|
||||
import wan
|
||||
from wan import WanVace, WanVaceMP
|
||||
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
|
||||
|
||||
|
||||
class FixedSizeQueue:
|
||||
def __init__(self, max_size):
|
||||
self.max_size = max_size
|
||||
self.queue = []
|
||||
def add(self, item):
|
||||
self.queue.insert(0, item)
|
||||
if len(self.queue) > self.max_size:
|
||||
self.queue.pop()
|
||||
def get(self):
|
||||
return self.queue
|
||||
def __repr__(self):
|
||||
return str(self.queue)
|
||||
|
||||
|
||||
class VACEInference:
|
||||
def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
|
||||
self.cfg = cfg
|
||||
self.save_dir = cfg.save_dir
|
||||
self.gallery_share = gallery_share
|
||||
self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
|
||||
if not skip_load:
|
||||
if not cfg.mp:  # cfg is the parsed argument namespace passed into the constructor
|
||||
self.pipe = WanVace(
|
||||
config=WAN_CONFIGS[cfg.model_name],
|
||||
checkpoint_dir=cfg.ckpt_dir,
|
||||
device_id=0,
|
||||
rank=0,
|
||||
t5_fsdp=False,
|
||||
dit_fsdp=False,
|
||||
use_usp=False,
|
||||
)
|
||||
else:
|
||||
self.pipe = WanVaceMP(
|
||||
config=WAN_CONFIGS[cfg.model_name],
|
||||
checkpoint_dir=cfg.ckpt_dir,
|
||||
use_usp=True,
|
||||
ulysses_size=cfg.ulysses_size,
|
||||
ring_size=cfg.ring_size
|
||||
)
|
||||
|
||||
|
||||
def create_ui(self, *args, **kwargs):
|
||||
gr.Markdown("""
|
||||
<div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
|
||||
<a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
|
||||
</div>
|
||||
""")
|
||||
with gr.Row(variant='panel', equal_height=True):
|
||||
with gr.Column(scale=1, min_width=0):
|
||||
self.src_video = gr.Video(
|
||||
label="src_video",
|
||||
sources=['upload'],
|
||||
value=None,
|
||||
interactive=True)
|
||||
with gr.Column(scale=1, min_width=0):
|
||||
self.src_mask = gr.Video(
|
||||
label="src_mask",
|
||||
sources=['upload'],
|
||||
value=None,
|
||||
interactive=True)
|
||||
#
|
||||
with gr.Row(variant='panel', equal_height=True):
|
||||
with gr.Column(scale=1, min_width=0):
|
||||
with gr.Row(equal_height=True):
|
||||
self.src_ref_image_1 = gr.Image(label='src_ref_image_1',
|
||||
height=200,
|
||||
interactive=True,
|
||||
type='filepath',
|
||||
image_mode='RGB',
|
||||
sources=['upload'],
|
||||
elem_id="src_ref_image_1",
|
||||
format='png')
|
||||
self.src_ref_image_2 = gr.Image(label='src_ref_image_2',
|
||||
height=200,
|
||||
interactive=True,
|
||||
type='filepath',
|
||||
image_mode='RGB',
|
||||
sources=['upload'],
|
||||
elem_id="src_ref_image_2",
|
||||
format='png')
|
||||
self.src_ref_image_3 = gr.Image(label='src_ref_image_3',
|
||||
height=200,
|
||||
interactive=True,
|
||||
type='filepath',
|
||||
image_mode='RGB',
|
||||
sources=['upload'],
|
||||
elem_id="src_ref_image_3",
|
||||
format='png')
|
||||
with gr.Row(variant='panel', equal_height=True):
|
||||
with gr.Column(scale=1):
|
||||
self.prompt = gr.Textbox(
|
||||
show_label=False,
|
||||
placeholder="positive_prompt_input",
|
||||
elem_id='positive_prompt',
|
||||
container=True,
|
||||
autofocus=True,
|
||||
elem_classes='type_row',
|
||||
visible=True,
|
||||
lines=2)
|
||||
self.negative_prompt = gr.Textbox(
|
||||
show_label=False,
|
||||
value=self.pipe.config.sample_neg_prompt,
|
||||
placeholder="negative_prompt_input",
|
||||
elem_id='negative_prompt',
|
||||
container=True,
|
||||
autofocus=False,
|
||||
elem_classes='type_row',
|
||||
visible=True,
|
||||
interactive=True,
|
||||
lines=1)
|
||||
#
|
||||
with gr.Row(variant='panel', equal_height=True):
|
||||
with gr.Column(scale=1, min_width=0):
|
||||
with gr.Row(equal_height=True):
|
||||
self.shift_scale = gr.Slider(
|
||||
label='shift_scale',
|
||||
minimum=0.0,
|
||||
maximum=100.0,
|
||||
step=1.0,
|
||||
value=16.0,
|
||||
interactive=True)
|
||||
self.sample_steps = gr.Slider(
|
||||
label='sample_steps',
|
||||
minimum=1,
|
||||
maximum=100,
|
||||
step=1,
|
||||
value=25,
|
||||
interactive=True)
|
||||
self.context_scale = gr.Slider(
|
||||
label='context_scale',
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
step=0.1,
|
||||
value=1.0,
|
||||
interactive=True)
|
||||
self.guide_scale = gr.Slider(
|
||||
label='guide_scale',
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
step=0.5,
|
||||
value=5.0,
|
||||
interactive=True)
|
||||
self.infer_seed = gr.Slider(minimum=-1,
|
||||
maximum=10000000,
|
||||
value=2025,
|
||||
label="Seed")
|
||||
#
|
||||
with gr.Accordion(label="Usable without source video", open=False):
|
||||
with gr.Row(equal_height=True):
|
||||
self.output_height = gr.Textbox(
|
||||
label='resolutions_height',
|
||||
# value=480,
|
||||
value=720,
|
||||
interactive=True)
|
||||
self.output_width = gr.Textbox(
|
||||
label='resolutions_width',
|
||||
# value=832,
|
||||
value=1280,
|
||||
interactive=True)
|
||||
self.frame_rate = gr.Textbox(
|
||||
label='frame_rate',
|
||||
value=16,
|
||||
interactive=True)
|
||||
self.num_frames = gr.Textbox(
|
||||
label='num_frames',
|
||||
value=81,
|
||||
interactive=True)
|
||||
#
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column(scale=5):
|
||||
self.generate_button = gr.Button(
|
||||
value='Run',
|
||||
elem_classes='type_row',
|
||||
elem_id='generate_button',
|
||||
visible=True)
|
||||
with gr.Column(scale=1):
|
||||
self.refresh_button = gr.Button(value='\U0001f504') # 🔄
|
||||
#
|
||||
self.output_gallery = gr.Gallery(
|
||||
label="output_gallery",
|
||||
value=[],
|
||||
interactive=False,
|
||||
allow_preview=True,
|
||||
preview=True)
|
||||
|
||||
|
||||
def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames):
|
||||
output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames)
|
||||
src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if
|
||||
x is not None]
|
||||
src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
|
||||
[src_mask],
|
||||
[src_ref_images],
|
||||
num_frames=num_frames,
|
||||
image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
|
||||
device=self.pipe.device)
|
||||
video = self.pipe.generate(
|
||||
prompt,
|
||||
src_video,
|
||||
src_mask,
|
||||
src_ref_images,
|
||||
size=(output_width, output_height),
|
||||
context_scale=context_scale,
|
||||
shift=shift_scale,
|
||||
sampling_steps=sample_steps,
|
||||
guide_scale=guide_scale,
|
||||
n_prompt=negative_prompt,
|
||||
seed=infer_seed,
|
||||
offload_model=True)
|
||||
|
||||
name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now())
|
||||
video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
|
||||
video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
|
||||
|
||||
try:
|
||||
writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1)
|
||||
for frame in video_frames:
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
print(video_path)
|
||||
except Exception as e:
|
||||
raise gr.Error(f"Video save error: {e}")
|
||||
|
||||
if self.gallery_share:
|
||||
self.gallery_share_data.add(video_path)
|
||||
return self.gallery_share_data.get()
|
||||
else:
|
||||
return [video_path]
|
||||
|
||||
def set_callbacks(self, **kwargs):
|
||||
self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
|
||||
self.gen_outputs = [self.output_gallery]
|
||||
self.generate_button.click(self.generate,
|
||||
inputs=self.gen_inputs,
|
||||
outputs=self.gen_outputs,
|
||||
queue=True)
|
||||
self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
|
||||
parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
|
||||
parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
|
||||
parser.add_argument('--root_path', dest='root_path', help='', default=None)
|
||||
parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
|
||||
parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",)
|
||||
parser.add_argument("--model_name", type=str, default="vace-14B", choices=list(WAN_CONFIGS.keys()), help="The model name to run.")
|
||||
parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.")
|
||||
parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.")
|
||||
parser.add_argument(
|
||||
"--ckpt_dir",
|
||||
type=str,
|
||||
# default='models/VACE-Wan2.1-1.3B-Preview',
|
||||
default='models/Wan2.1-VACE-14B/',
|
||||
help="The path to the checkpoint directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_to_cpu",
|
||||
action="store_true",
|
||||
help="Offloading unnecessary computations to CPU.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.exists(args.save_dir):
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
|
||||
infer_gr.create_ui()
|
||||
infer_gr.set_callbacks()
|
||||
allowed_paths = [args.save_dir]
|
||||
demo.queue(status_update_rate=1).launch(server_name=args.server_name,
|
||||
server_port=args.server_port,
|
||||
root_path=args.root_path,
|
||||
allowed_paths=allowed_paths,
|
||||
show_error=True, debug=True)
|
@ -105,9 +105,16 @@ function i2v_14B_720p() {
    torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
}

function vace_1_3B() {
    VACE_1_3B_CKPT_DIR="$MODEL_DIR/VACE-Wan2.1-1.3B-Preview/"
    torchrun --nproc_per_node=$GPUS $PY_FILE --ulysses_size $GPUS --task vace-1.3B --size 480*832 --ckpt_dir $VACE_1_3B_CKPT_DIR
}


t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p
vace_1_3B
@ -2,3 +2,4 @@ from . import configs, distributed, modules
from .image2video import WanI2V
from .text2video import WanT2V
from .first_last_frame2video import WanFLF2V
from .vace import WanVace, WanVaceMP
@ -22,7 +22,9 @@ WAN_CONFIGS = {
    't2v-1.3B': t2v_1_3B,
    'i2v-14B': i2v_14B,
    't2i-14B': t2i_14B,
    'flf2v-14B': flf2v_14B,
    'vace-1.3B': t2v_1_3B,
    'vace-14B': t2v_14B,
}
SIZE_CONFIGS = {

@ -46,4 +48,6 @@ SUPPORTED_SIZES = {
    'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    't2i-14B': tuple(SIZE_CONFIGS.keys()),
    'vace-1.3B': ('480*832', '832*480'),
    'vace-14B': ('720*1280', '1280*720', '480*832', '832*480')
}
@ -11,12 +11,14 @@ i2v_14B.update(wan_shared_cfg)
|
||||
i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt
|
||||
|
||||
i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
|
||||
# i2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
|
||||
i2v_14B.t5_tokenizer = 'google/umt5-xxl'
|
||||
|
||||
# clip
|
||||
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
|
||||
i2v_14B.clip_dtype = torch.float16
|
||||
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
|
||||
# i2v_14B.clip_checkpoint = 'open-clip-xlm-roberta-large-vit-huge-14_fp16.safetensors' # Kijai's fp16 model
|
||||
i2v_14B.clip_tokenizer = 'xlm-roberta-large'
|
||||
|
||||
# vae
|
||||
|
@ -10,6 +10,7 @@ t2v_14B.update(wan_shared_cfg)
|
||||
|
||||
# t5
|
||||
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
|
||||
# t2v_14B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
|
||||
t2v_14B.t5_tokenizer = 'google/umt5-xxl'
|
||||
|
||||
# vae
|
||||
|
@ -10,6 +10,7 @@ t2v_1_3B.update(wan_shared_cfg)
|
||||
|
||||
# t5
|
||||
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
|
||||
# t2v_1_3B.t5_checkpoint = 'umt5-xxl-enc-fp8_e4m3fn.safetensors' # fp8 model
|
||||
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
|
||||
|
||||
# vae
|
||||
|
@ -63,12 +63,45 @@ def rope_apply(x, grid_sizes, freqs):
|
||||
return torch.stack(output).float()
|
||||
|
||||
|
||||
def usp_dit_forward_vace(
|
||||
self,
|
||||
x,
|
||||
vace_context,
|
||||
seq_len,
|
||||
kwargs
|
||||
):
|
||||
# embeddings
|
||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
c = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
# arguments
|
||||
new_kwargs = dict(x=x)
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
# Context Parallel
|
||||
c = torch.chunk(
|
||||
c, get_sequence_parallel_world_size(),
|
||||
dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
hints = []
|
||||
for block in self.vace_blocks:
|
||||
c, c_skip = block(c, **new_kwargs)
|
||||
hints.append(c_skip)
|
||||
return hints
|
||||
|
||||
|
||||
def usp_dit_forward(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
context,
|
||||
seq_len,
|
||||
vace_context=None,
|
||||
vace_context_scale=1.0,
|
||||
clip_fea=None,
|
||||
y=None,
|
||||
):
|
||||
@ -84,7 +117,7 @@ def usp_dit_forward(
|
||||
if self.freqs.device != device:
|
||||
self.freqs = self.freqs.to(device)
|
||||
|
||||
if y is not None:
|
||||
if self.model_type != 'vace' and y is not None:
|
||||
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
||||
|
||||
# embeddings
|
||||
@ -114,7 +147,7 @@ def usp_dit_forward(
|
||||
for u in context
|
||||
]))
|
||||
|
||||
if clip_fea is not None:
|
||||
if self.model_type != 'vace' and clip_fea is not None:
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
|
||||
@ -132,6 +165,11 @@ def usp_dit_forward(
|
||||
x, get_sequence_parallel_world_size(),
|
||||
dim=1)[get_sequence_parallel_rank()]
|
||||
|
||||
if self.model_type == 'vace':
|
||||
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
|
||||
kwargs['hints'] = hints
|
||||
kwargs['context_scale'] = vace_context_scale
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, **kwargs)
|
||||
|
||||
|
@ -16,6 +16,10 @@ import torch.distributed as dist
|
||||
import torchvision.transforms.functional as TF
|
||||
from tqdm import tqdm
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from .distributed.fsdp import shard_model
|
||||
from .modules.clip import CLIPModel
|
||||
from .modules.model import WanModel
|
||||
@ -39,6 +43,7 @@ class WanI2V:
|
||||
use_usp=False,
|
||||
t5_cpu=False,
|
||||
init_on_cpu=True,
|
||||
fp8=False,
|
||||
):
|
||||
r"""
|
||||
Initializes the image-to-video generation model components.
|
||||
@ -62,6 +67,8 @@ class WanI2V:
|
||||
Whether to place T5 model on CPU. Only works without t5_fsdp.
|
||||
init_on_cpu (`bool`, *optional*, defaults to True):
|
||||
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
|
||||
fp8 (`bool`, *optional*, defaults to False):
|
||||
Enable 8-bit floating point precision for model parameters.
|
||||
"""
|
||||
self.device = torch.device(f"cuda:{device_id}")
|
||||
self.config = config
|
||||
@ -73,6 +80,10 @@ class WanI2V:
|
||||
self.param_dtype = config.param_dtype
|
||||
|
||||
shard_fn = partial(shard_model, device_id=device_id)
|
||||
if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
|
||||
quantization = "fp8_e4m3fn"
|
||||
else:
|
||||
quantization = "disabled"
|
||||
self.text_encoder = T5EncoderModel(
|
||||
text_len=config.text_len,
|
||||
dtype=config.t5_dtype,
|
||||
@ -80,10 +91,12 @@ class WanI2V:
|
||||
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
|
||||
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
||||
shard_fn=shard_fn if t5_fsdp else None,
|
||||
quantization=quantization,
|
||||
)
|
||||
|
||||
self.vae_stride = config.vae_stride
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.vae = WanVAE(
|
||||
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
|
||||
device=self.device)
|
||||
@ -96,7 +109,46 @@ class WanI2V:
|
||||
tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
|
||||
|
||||
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
||||
if not fp8:
|
||||
self.model = WanModel.from_pretrained(checkpoint_dir)
|
||||
else:
|
||||
if '480P' in checkpoint_dir:
|
||||
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-480P_fp8_e4m3fn.safetensors', device="cpu")
|
||||
elif '720P' in checkpoint_dir:
|
||||
state_dict = load_file(checkpoint_dir+'/Wan2_1-I2V-14B-720P_fp8_e4m3fn.safetensors', device="cpu")
|
||||
dim = state_dict["patch_embedding.weight"].shape[0]
|
||||
in_channels = state_dict["patch_embedding.weight"].shape[1]
|
||||
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
|
||||
model_type = "i2v" if in_channels == 36 else "t2v"
|
||||
num_heads = 40 if dim == 5120 else 12
|
||||
num_layers = 40 if dim == 5120 else 30
|
||||
TRANSFORMER_CONFIG= {
|
||||
"dim": dim,
|
||||
"ffn_dim": ffn_dim,
|
||||
"eps": 1e-06,
|
||||
"freq_dim": 256,
|
||||
"in_dim": in_channels,
|
||||
"model_type": model_type,
|
||||
"out_dim": 16,
|
||||
"text_len": 512,
|
||||
"num_heads": num_heads,
|
||||
"num_layers": num_layers,
|
||||
}
|
||||
|
||||
with init_empty_weights():
|
||||
self.model = WanModel(**TRANSFORMER_CONFIG)
|
||||
|
||||
base_dtype=torch.bfloat16
|
||||
dtype=torch.float8_e4m3fn
|
||||
params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
|
||||
for name, param in self.model.named_parameters():
|
||||
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
|
||||
# dtype_to_use = torch.bfloat16
|
||||
# print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
|
||||
set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
|
||||
|
||||
del state_dict
|
||||
|
||||
self.model.eval().requires_grad_(False)
|
||||
|
||||
if t5_fsdp or dit_fsdp or use_usp:
|
||||
@ -219,11 +271,13 @@ class WanI2V:
|
||||
# preprocess
|
||||
if not self.t5_cpu:
|
||||
self.text_encoder.model.to(self.device)
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
|
||||
context = self.text_encoder([input_prompt], self.device)
|
||||
context_null = self.text_encoder([n_prompt], self.device)
|
||||
if offload_model:
|
||||
self.text_encoder.model.cpu()
|
||||
else:
|
||||
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
|
||||
context = self.text_encoder([input_prompt], torch.device('cpu'))
|
||||
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
||||
context = [t.to(self.device) for t in context]
|
||||
@ -242,9 +296,12 @@ class WanI2V:
|
||||
torch.zeros(3, F - 1, h, w)
|
||||
],
|
||||
dim=1).to(self.device)
|
||||
])[0]
|
||||
],device=self.device)[0]
|
||||
y = torch.concat([msk, y])
|
||||
|
||||
if offload_model:
|
||||
self.vae.model.cpu()
|
||||
|
||||
@contextmanager
|
||||
def noop_no_sync():
|
||||
yield
|
||||
@ -332,9 +389,11 @@ class WanI2V:
|
||||
if offload_model:
|
||||
self.model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
# load vae model back to device
|
||||
self.vae.model.to(self.device)
|
||||
|
||||
if self.rank == 0:
|
||||
videos = self.vae.decode(x0)
|
||||
videos = self.vae.decode(x0, device=self.device)
|
||||
|
||||
del noise, latent
|
||||
del sample_scheduler
|
||||
|
@ -2,11 +2,13 @@ from .attention import flash_attention
|
||||
from .model import WanModel
|
||||
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
from .vace_model import VaceWanModel
|
||||
from .vae import WanVAE
|
||||
|
||||
__all__ = [
|
||||
'WanVAE',
|
||||
'WanModel',
|
||||
'VaceWanModel',
|
||||
'T5Model',
|
||||
'T5Encoder',
|
||||
'T5Decoder',
|
||||
|
@ -7,6 +7,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as T
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from .attention import flash_attention
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
@ -515,8 +516,13 @@ class CLIPModel:
|
||||
device=device)
|
||||
self.model = self.model.eval().requires_grad_(False)
|
||||
logging.info(f'loading {checkpoint_path}')
|
||||
self.model.load_state_dict(
|
||||
torch.load(checkpoint_path, map_location='cpu'))
|
||||
if checkpoint_path.endswith('.safetensors'):
|
||||
state_dict = load_file(checkpoint_path, device='cpu')
|
||||
self.model.load_state_dict(state_dict)
|
||||
elif checkpoint_path.endswith('.pth'):
|
||||
self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
|
||||
else:
|
||||
raise ValueError(f'Unsupported checkpoint file format: {checkpoint_path}')
|
||||
|
||||
# init tokenizer
|
||||
self.tokenizer = HuggingfaceTokenizer(
|
||||
|
@ -400,7 +400,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
|
||||
Args:
|
||||
model_type (`str`, *optional*, defaults to 't2v'):
|
||||
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
|
||||
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
|
||||
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
||||
text_len (`int`, *optional*, defaults to 512):
|
||||
@ -433,7 +433,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
|
||||
super().__init__()
|
||||
|
||||
assert model_type in ['t2v', 'i2v', 'flf2v']
|
||||
assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
|
||||
self.model_type = model_type
|
||||
|
||||
self.patch_size = patch_size
|
||||
|
@ -9,6 +9,10 @@ import torch.nn.functional as F
|
||||
|
||||
from .tokenizers import HuggingfaceTokenizer
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from safetensors.torch import load_file
|
||||
|
||||
__all__ = [
|
||||
'T5Model',
|
||||
'T5Encoder',
|
||||
@ -442,7 +446,7 @@ def _t5(name,
|
||||
model = model_cls(**kwargs)
|
||||
|
||||
# set device
|
||||
model = model.to(dtype=dtype, device=device)
|
||||
# model = model.to(dtype=dtype, device=device)
|
||||
|
||||
# init tokenizer
|
||||
if return_tokenizer:
|
||||
@ -479,6 +483,7 @@ class T5EncoderModel:
|
||||
checkpoint_path=None,
|
||||
tokenizer_path=None,
|
||||
shard_fn=None,
|
||||
quantization="disabled",
|
||||
):
|
||||
self.text_len = text_len
|
||||
self.dtype = dtype
|
||||
@ -486,14 +491,31 @@ class T5EncoderModel:
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.tokenizer_path = tokenizer_path
|
||||
|
||||
|
||||
logging.info(f'loading {checkpoint_path}')
|
||||
if quantization == "disabled":
|
||||
# init model
|
||||
model = umt5_xxl(
|
||||
encoder_only=True,
|
||||
return_tokenizer=False,
|
||||
dtype=dtype,
|
||||
device=device).eval().requires_grad_(False)
|
||||
logging.info(f'loading {checkpoint_path}')
|
||||
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
|
||||
elif quantization == "fp8_e4m3fn":
|
||||
with init_empty_weights():
|
||||
model = umt5_xxl(
|
||||
encoder_only=True,
|
||||
return_tokenizer=False,
|
||||
dtype=dtype,
|
||||
device=device).eval().requires_grad_(False)
|
||||
cast_dtype = torch.float8_e4m3fn
|
||||
state_dict = load_file(checkpoint_path, device="cpu")
|
||||
params_to_keep = {'norm', 'pos_embedding', 'token_embedding'}
|
||||
for name, param in model.named_parameters():
|
||||
dtype_to_use = dtype if any(keyword in name for keyword in params_to_keep) else cast_dtype
|
||||
set_module_tensor_to_device(model, name, device=device, dtype=dtype_to_use, value=state_dict[name])
|
||||
del state_dict
|
||||
|
||||
self.model = model
|
||||
if shard_fn is not None:
|
||||
self.model = shard_fn(self.model, sync_module_states=False)
|
||||
|
wan/modules/vace_model.py (new file, 233 lines)
@ -0,0 +1,233 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn as nn
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
|
||||
|
||||
|
||||
class VaceWanAttentionBlock(WanAttentionBlock):
|
||||
def __init__(
|
||||
self,
|
||||
cross_attn_type,
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=False,
|
||||
eps=1e-6,
|
||||
block_id=0
|
||||
):
|
||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = nn.Linear(self.dim, self.dim)
|
||||
nn.init.zeros_(self.before_proj.weight)
|
||||
nn.init.zeros_(self.before_proj.bias)
|
||||
self.after_proj = nn.Linear(self.dim, self.dim)
|
||||
nn.init.zeros_(self.after_proj.weight)
|
||||
nn.init.zeros_(self.after_proj.bias)
|
||||
|
||||
def forward(self, c, x, **kwargs):
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
|
||||
c = super().forward(c, **kwargs)
|
||||
c_skip = self.after_proj(c)
|
||||
return c, c_skip
|
||||
|
||||
|
||||
class BaseWanAttentionBlock(WanAttentionBlock):
|
||||
def __init__(
|
||||
self,
|
||||
cross_attn_type,
|
||||
dim,
|
||||
ffn_dim,
|
||||
num_heads,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=False,
|
||||
eps=1e-6,
|
||||
block_id=None
|
||||
):
|
||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
|
||||
self.block_id = block_id
|
||||
|
||||
def forward(self, x, hints, context_scale=1.0, **kwargs):
|
||||
x = super().forward(x, **kwargs)
|
||||
if self.block_id is not None:
|
||||
x = x + hints[self.block_id] * context_scale
|
||||
return x
|
||||
|
||||
|
||||
class VaceWanModel(WanModel):
|
||||
@register_to_config
|
||||
def __init__(self,
|
||||
vace_layers=None,
|
||||
vace_in_dim=None,
|
||||
model_type='vace',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6):
|
||||
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
|
||||
num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
|
||||
|
||||
self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
|
||||
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
|
||||
|
||||
assert 0 in self.vace_layers
|
||||
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
|
||||
|
||||
# blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
||||
self.cross_attn_norm, self.eps,
|
||||
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
|
||||
for i in range(self.num_layers)
|
||||
])
|
||||
|
||||
# vace blocks
|
||||
self.vace_blocks = nn.ModuleList([
|
||||
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
||||
self.cross_attn_norm, self.eps, block_id=i)
|
||||
for i in self.vace_layers
|
||||
])
|
||||
|
||||
# vace patch embeddings
|
||||
self.vace_patch_embedding = nn.Conv3d(
|
||||
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
|
||||
)
|
||||
|
||||
def forward_vace(
|
||||
self,
|
||||
x,
|
||||
vace_context,
|
||||
seq_len,
|
||||
kwargs
|
||||
):
|
||||
# embeddings
|
||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
c = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
# arguments
|
||||
new_kwargs = dict(x=x)
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
hints = []
|
||||
for block in self.vace_blocks:
|
||||
c, c_skip = block(c, **new_kwargs)
|
||||
hints.append(c_skip)
|
||||
return hints
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
t,
|
||||
vace_context,
|
||||
context,
|
||||
seq_len,
|
||||
vace_context_scale=1.0,
|
||||
clip_fea=None,
|
||||
y=None,
|
||||
):
|
||||
r"""
|
||||
Forward pass through the diffusion model
|
||||
|
||||
Args:
|
||||
x (List[Tensor]):
|
||||
List of input video tensors, each with shape [C_in, F, H, W]
|
||||
t (Tensor):
|
||||
Diffusion timesteps tensor of shape [B]
|
||||
context (List[Tensor]):
|
||||
List of text embeddings each with shape [L, C]
|
||||
seq_len (`int`):
|
||||
Maximum sequence length for positional encoding
|
||||
clip_fea (Tensor, *optional*):
|
||||
CLIP image features for image-to-video mode
|
||||
y (List[Tensor], *optional*):
|
||||
Conditional video inputs for image-to-video mode, same shape as x
|
||||
|
||||
Returns:
|
||||
List[Tensor]:
|
||||
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
||||
"""
|
||||
# if self.model_type == 'i2v':
|
||||
# assert clip_fea is not None and y is not None
|
||||
# params
|
||||
device = self.patch_embedding.weight.device
|
||||
if self.freqs.device != device:
|
||||
self.freqs = self.freqs.to(device)
|
||||
|
||||
# if y is not None:
|
||||
# x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
||||
|
||||
# embeddings
|
||||
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
||||
grid_sizes = torch.stack(
|
||||
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
||||
x = [u.flatten(2).transpose(1, 2) for u in x]
|
||||
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
||||
assert seq_lens.max() <= seq_len
|
||||
x = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||
dim=1) for u in x
|
||||
])
|
||||
|
||||
# time embeddings
|
||||
with amp.autocast(dtype=torch.float32):
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
||||
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
||||
|
||||
# context
|
||||
context_lens = None
|
||||
context = self.text_embedding(
|
||||
torch.stack([
|
||||
torch.cat(
|
||||
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
||||
for u in context
|
||||
]))
|
||||
|
||||
# if clip_fea is not None:
|
||||
# context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
# context = torch.concat([context_clip, context], dim=1)
|
||||
|
||||
# arguments
|
||||
kwargs = dict(
|
||||
e=e0,
|
||||
seq_lens=seq_lens,
|
||||
grid_sizes=grid_sizes,
|
||||
freqs=self.freqs,
|
||||
context=context,
|
||||
context_lens=context_lens)
|
||||
|
||||
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
|
||||
kwargs['hints'] = hints
|
||||
kwargs['context_scale'] = vace_context_scale
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, **kwargs)
|
||||
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return [u.float() for u in x]
|
@ -644,7 +644,7 @@ class WanVAE:
|
||||
z_dim=z_dim,
|
||||
).eval().requires_grad_(False).to(device)
|
||||
|
||||
def encode(self, videos):
|
||||
def encode(self, videos, device=None):
|
||||
"""
|
||||
videos: A list of videos each with shape [C, T, H, W].
|
||||
"""
|
||||
@ -654,7 +654,7 @@ class WanVAE:
|
||||
for u in videos
|
||||
]
|
||||
|
||||
def decode(self, zs):
|
||||
def decode(self, zs, device=None):
|
||||
with amp.autocast(dtype=self.dtype):
|
||||
return [
|
||||
self.model.decode(u.unsqueeze(0),
|
||||
|
@ -14,6 +14,10 @@ import torch.cuda.amp as amp
|
||||
import torch.distributed as dist
|
||||
from tqdm import tqdm
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from .distributed.fsdp import shard_model
|
||||
from .modules.model import WanModel
|
||||
from .modules.t5 import T5EncoderModel
|
||||
@ -35,6 +39,8 @@ class WanT2V:
|
||||
dit_fsdp=False,
|
||||
use_usp=False,
|
||||
t5_cpu=False,
|
||||
init_on_cpu=True,
|
||||
fp8=False,
|
||||
):
|
||||
r"""
|
||||
Initializes the Wan text-to-video generation model components.
|
||||
@ -56,6 +62,8 @@ class WanT2V:
|
||||
Enable distribution strategy of USP.
|
||||
t5_cpu (`bool`, *optional*, defaults to False):
|
||||
Whether to place T5 model on CPU. Only works without t5_fsdp.
|
||||
fp8 (`bool`, *optional*, defaults to False):
|
||||
Enable 8-bit floating point precision for model parameters.
|
||||
"""
|
||||
self.device = torch.device(f"cuda:{device_id}")
|
||||
self.config = config
|
||||
@ -66,13 +74,19 @@ class WanT2V:
|
||||
self.param_dtype = config.param_dtype
|
||||
|
||||
shard_fn = partial(shard_model, device_id=device_id)
|
||||
if config.t5_checkpoint == 'umt5-xxl-enc-fp8_e4m3fn.safetensors':
|
||||
quantization = "fp8_e4m3fn"
|
||||
else:
|
||||
quantization = "disabled"
|
||||
|
||||
self.text_encoder = T5EncoderModel(
|
||||
text_len=config.text_len,
|
||||
dtype=config.t5_dtype,
|
||||
device=torch.device('cpu'),
|
||||
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
|
||||
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
||||
shard_fn=shard_fn if t5_fsdp else None)
|
||||
shard_fn=shard_fn if t5_fsdp else None,
|
||||
quantization=quantization)
|
||||
|
||||
self.vae_stride = config.vae_stride
|
||||
self.patch_size = config.patch_size
|
||||
@ -81,9 +95,52 @@ class WanT2V:
|
||||
device=self.device)
|
||||
|
||||
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
||||
if not fp8:
|
||||
self.model = WanModel.from_pretrained(checkpoint_dir)
|
||||
else:
|
||||
if '14B' in checkpoint_dir:
|
||||
state_dict = load_file(checkpoint_dir+'/Wan2_1-T2V-14B_fp8_e4m3fn.safetensors', device="cpu")
|
||||
else:
|
||||
state_dict = load_file(checkpoint_dir+'/Wan2_1-T2V-1_3B_fp8_e4m3fn.safetensors', device="cpu")
|
||||
|
||||
dim = state_dict["patch_embedding.weight"].shape[0]
|
||||
in_channels = state_dict["patch_embedding.weight"].shape[1]
|
||||
ffn_dim = state_dict["blocks.0.ffn.0.bias"].shape[0]
|
||||
model_type = "i2v" if in_channels == 36 else "t2v"
|
||||
num_heads = 40 if dim == 5120 else 12
|
||||
num_layers = 40 if dim == 5120 else 30
|
||||
TRANSFORMER_CONFIG= {
|
||||
"dim": dim,
|
||||
"ffn_dim": ffn_dim,
|
||||
"eps": 1e-06,
|
||||
"freq_dim": 256,
|
||||
"in_dim": in_channels,
|
||||
"model_type": model_type,
|
||||
"out_dim": 16,
|
||||
"text_len": 512,
|
||||
"num_heads": num_heads,
|
||||
"num_layers": num_layers,
|
||||
}
|
||||
|
||||
with init_empty_weights():
|
||||
self.model = WanModel(**TRANSFORMER_CONFIG)
|
||||
|
||||
base_dtype=torch.bfloat16
|
||||
dtype=torch.float8_e4m3fn
|
||||
params_to_keep = {"norm", "head", "bias", "time_in", "vector_in", "patch_embedding", "time_", "img_emb", "modulation"}
|
||||
for name, param in self.model.named_parameters():
|
||||
dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else dtype
|
||||
# dtype_to_use = torch.bfloat16
|
||||
# print("Assigning Parameter name: ", name, " with dtype: ", dtype_to_use)
|
||||
set_module_tensor_to_device(self.model, name, device='cpu', dtype=dtype_to_use, value=state_dict[name])
|
||||
|
||||
del state_dict
|
||||
|
||||
self.model.eval().requires_grad_(False)
|
||||
|
||||
if t5_fsdp or dit_fsdp or use_usp:
|
||||
init_on_cpu = False
|
||||
|
||||
if use_usp:
|
||||
from xfuser.core.distributed import \
|
||||
get_sequence_parallel_world_size
|
||||
@ -103,6 +160,7 @@ class WanT2V:
|
||||
if dit_fsdp:
|
||||
self.model = shard_fn(self.model)
|
||||
else:
|
||||
if not init_on_cpu:
|
||||
self.model.to(self.device)
|
||||
|
||||
self.sample_neg_prompt = config.sample_neg_prompt
|
||||
@ -169,11 +227,13 @@ class WanT2V:
|
||||
|
||||
if not self.t5_cpu:
|
||||
self.text_encoder.model.to(self.device)
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
|
||||
context = self.text_encoder([input_prompt], self.device)
|
||||
context_null = self.text_encoder([n_prompt], self.device)
|
||||
if offload_model:
|
||||
self.text_encoder.model.cpu()
|
||||
else:
|
||||
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
|
||||
context = self.text_encoder([input_prompt], torch.device('cpu'))
|
||||
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
||||
context = [t.to(self.device) for t in context]
|
||||
@ -190,6 +250,9 @@ class WanT2V:
|
||||
generator=seed_g)
|
||||
]
|
||||
|
||||
if offload_model:
|
||||
self.vae.model.cpu()
|
||||
|
||||
@contextmanager
|
||||
def noop_no_sync():
|
||||
yield
|
||||
@ -226,13 +289,15 @@ class WanT2V:
|
||||
arg_c = {'context': context, 'seq_len': seq_len}
|
||||
arg_null = {'context': context_null, 'seq_len': seq_len}
|
||||
|
||||
if offload_model:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
self.model.to(self.device)
|
||||
for _, t in enumerate(tqdm(timesteps)):
|
||||
latent_model_input = latents
|
||||
timestep = [t]
|
||||
|
||||
timestep = torch.stack(timestep)
|
||||
|
||||
self.model.to(self.device)
|
||||
noise_pred_cond = self.model(
|
||||
latent_model_input, t=timestep, **arg_c)[0]
|
||||
noise_pred_uncond = self.model(
|
||||
@ -253,6 +318,9 @@ class WanT2V:
|
||||
if offload_model:
|
||||
self.model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
# load vae model back to device
|
||||
self.vae.model.to(self.device)
|
||||
|
||||
if self.rank == 0:
|
||||
videos = self.vae.decode(x0)
|
||||
|
||||
|
@@ -1,8 +1,10 @@
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
                         retrieve_timesteps)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .vace_processor import VaceVideoProcessor

__all__ = [
    'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
    'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
    'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
    'VaceVideoProcessor'
]
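With this export in place, the processor can be pulled straight from `wan.utils`. A minimal sketch (the constructor values mirror the 480P configuration used elsewhere in this PR and are illustrative):

```python
from wan.utils import VaceVideoProcessor

vid_proc = VaceVideoProcessor(
    downsample=(4, 16, 16),   # vae_stride (4, 8, 8) x patch_size (1, 2, 2)
    min_area=480 * 832,
    max_area=480 * 832,
    min_fps=16,
    max_fps=16,
    zero_start=True,
    seq_len=32760,
    keep_last=True)
```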
270
wan/utils/vace_processor.py
Normal file
@@ -0,0 +1,270 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as TF
|
||||
|
||||
|
||||
class VaceImageProcessor(object):
|
||||
def __init__(self, downsample=None, seq_len=None):
|
||||
self.downsample = downsample
|
||||
self.seq_len = seq_len
|
||||
|
||||
def _pillow_convert(self, image, cvt_type='RGB'):
|
||||
if image.mode != cvt_type:
|
||||
if image.mode == 'P':
|
||||
image = image.convert(f'{cvt_type}A')
|
||||
if image.mode == f'{cvt_type}A':
|
||||
bg = Image.new(cvt_type,
|
||||
size=(image.width, image.height),
|
||||
color=(255, 255, 255))
|
||||
bg.paste(image, (0, 0), mask=image)
|
||||
image = bg
|
||||
else:
|
||||
image = image.convert(cvt_type)
|
||||
return image
|
||||
|
||||
def _load_image(self, img_path):
|
||||
if img_path is None or img_path == '':
|
||||
return None
|
||||
img = Image.open(img_path)
|
||||
img = self._pillow_convert(img)
|
||||
return img
|
||||
|
||||
def _resize_crop(self, img, oh, ow, normalize=True):
|
||||
"""
|
||||
Resize, center crop, convert to tensor, and normalize.
|
||||
"""
|
||||
# resize and crop
|
||||
iw, ih = img.size
|
||||
if iw != ow or ih != oh:
|
||||
# resize
|
||||
scale = max(ow / iw, oh / ih)
|
||||
img = img.resize(
|
||||
(round(scale * iw), round(scale * ih)),
|
||||
resample=Image.Resampling.LANCZOS
|
||||
)
|
||||
assert img.width >= ow and img.height >= oh
|
||||
|
||||
# center crop
|
||||
x1 = (img.width - ow) // 2
|
||||
y1 = (img.height - oh) // 2
|
||||
img = img.crop((x1, y1, x1 + ow, y1 + oh))
|
||||
|
||||
# normalize
|
||||
if normalize:
|
||||
img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
|
||||
return img
|
||||
|
||||
def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
|
||||
return self._resize_crop(img, oh, ow, normalize)
|
||||
|
||||
def load_image(self, data_key, **kwargs):
|
||||
return self.load_image_batch(data_key, **kwargs)
|
||||
|
||||
def load_image_pair(self, data_key, data_key2, **kwargs):
|
||||
return self.load_image_batch(data_key, data_key2, **kwargs)
|
||||
|
||||
def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
|
||||
seq_len = self.seq_len if seq_len is None else seq_len
|
||||
imgs = []
|
||||
for data_key in data_key_batch:
|
||||
img = self._load_image(data_key)
|
||||
imgs.append(img)
|
||||
w, h = imgs[0].size
|
||||
dh, dw = self.downsample[1:]
|
||||
|
||||
# compute output size
|
||||
scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
|
||||
oh = int(h * scale) // dh * dh
|
||||
ow = int(w * scale) // dw * dw
|
||||
assert (oh // dh) * (ow // dw) <= seq_len
|
||||
imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
|
||||
return *imgs, (oh, ow)
|
||||
|
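A minimal usage sketch of the image path (the file name and the 480P settings are assumptions for illustration, not taken from this diff):

```python
from wan.utils.vace_processor import VaceImageProcessor

# downsample = vae_stride * patch_size = (4, 16, 16); seq_len matches the 480P token budget.
img_proc = VaceImageProcessor(downsample=(4, 16, 16), seq_len=32760)
ref, (oh, ow) = img_proc.load_image('ref.png')   # 'ref.png' is a hypothetical path
# ref: float tensor in [-1, 1], shape (3, 1, oh, ow), with (oh // 16) * (ow // 16) <= 32760
```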
||||
|
||||
class VaceVideoProcessor(object):
|
||||
def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
|
||||
self.downsample = downsample
|
||||
self.min_area = min_area
|
||||
self.max_area = max_area
|
||||
self.min_fps = min_fps
|
||||
self.max_fps = max_fps
|
||||
self.zero_start = zero_start
|
||||
self.keep_last = keep_last
|
||||
self.seq_len = seq_len
|
||||
assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
|
||||
|
||||
def set_area(self, area):
|
||||
self.min_area = area
|
||||
self.max_area = area
|
||||
|
||||
def set_seq_len(self, seq_len):
|
||||
self.seq_len = seq_len
|
||||
|
||||
@staticmethod
|
||||
def resize_crop(video: torch.Tensor, oh: int, ow: int):
|
||||
"""
|
||||
Resize, center crop and normalize a decord-loaded video (torch.Tensor).
|
||||
|
||||
Parameters:
|
||||
video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
|
||||
oh - target height (int)
|
||||
ow - target width (int)
|
||||
|
||||
Returns:
|
||||
The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
|
||||
|
||||
|
||||
"""
|
||||
# permute ([t, h, w, c] -> [t, c, h, w])
|
||||
video = video.permute(0, 3, 1, 2)
|
||||
|
||||
# resize and crop
|
||||
ih, iw = video.shape[2:]
|
||||
if ih != oh or iw != ow:
|
||||
# resize
|
||||
scale = max(ow / iw, oh / ih)
|
||||
video = F.interpolate(
|
||||
video,
|
||||
size=(round(scale * ih), round(scale * iw)),
|
||||
mode='bicubic',
|
||||
antialias=True
|
||||
)
|
||||
assert video.size(3) >= ow and video.size(2) >= oh
|
||||
|
||||
# center crop
|
||||
x1 = (video.size(3) - ow) // 2
|
||||
y1 = (video.size(2) - oh) // 2
|
||||
video = video[:, :, y1:y1 + oh, x1:x1 + ow]
|
||||
|
||||
# permute ([t, c, h, w] -> [c, t, h, w]) and normalize
|
||||
video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
|
||||
return video
|
||||
|
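A self-contained sketch of `resize_crop` on a dummy clip (no decord required; shapes are illustrative):

```python
import torch
from wan.utils.vace_processor import VaceVideoProcessor

fake_clip = torch.randint(0, 256, (9, 270, 480, 3)).float()   # (T, H, W, C) layout, like decord's get_batch output
out = VaceVideoProcessor.resize_crop(fake_clip, oh=256, ow=448)
print(out.shape)   # torch.Size([3, 9, 256, 448]), values roughly in [-1, 1]
```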
||||
def _video_preprocess(self, video, oh, ow):
|
||||
return self.resize_crop(video, oh, ow)
|
||||
|
||||
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
|
||||
target_fps = min(fps, self.max_fps)
|
||||
duration = frame_timestamps[-1].mean()
|
||||
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
||||
h, w = y2 - y1, x2 - x1
|
||||
ratio = h / w
|
||||
df, dh, dw = self.downsample
|
||||
|
||||
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
|
||||
of = min(
|
||||
(int(duration * target_fps) - 1) // df + 1,
|
||||
int(self.seq_len / area_z)
|
||||
)
|
||||
|
||||
# deduce target shape of the [latent video]
|
||||
target_area_z = min(area_z, int(self.seq_len / of))
|
||||
oh = round(np.sqrt(target_area_z * ratio))
|
||||
ow = int(target_area_z / oh)
|
||||
of = (of - 1) * df + 1
|
||||
oh *= dh
|
||||
ow *= dw
|
||||
|
||||
# sample frame ids
|
||||
target_duration = of / target_fps
|
||||
begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
|
||||
timestamps = np.linspace(begin, begin + target_duration, of)
|
||||
frame_ids = np.argmax(np.logical_and(
|
||||
timestamps[:, None] >= frame_timestamps[None, :, 0],
|
||||
timestamps[:, None] < frame_timestamps[None, :, 1]
|
||||
), axis=1).tolist()
|
||||
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
||||
|
||||
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
|
||||
duration = frame_timestamps[-1].mean()
|
||||
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
||||
h, w = y2 - y1, x2 - x1
|
||||
ratio = h / w
|
||||
df, dh, dw = self.downsample
|
||||
|
||||
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
|
||||
of = min(
|
||||
(len(frame_timestamps) - 1) // df + 1,
|
||||
int(self.seq_len / area_z)
|
||||
)
|
||||
|
||||
# deduce target shape of the [latent video]
|
||||
target_area_z = min(area_z, int(self.seq_len / of))
|
||||
oh = round(np.sqrt(target_area_z * ratio))
|
||||
ow = int(target_area_z / oh)
|
||||
of = (of - 1) * df + 1
|
||||
oh *= dh
|
||||
ow *= dw
|
||||
|
||||
# sample frame ids
|
||||
target_duration = duration
|
||||
target_fps = of / target_duration
|
||||
timestamps = np.linspace(0., target_duration, of)
|
||||
frame_ids = np.argmax(np.logical_and(
|
||||
timestamps[:, None] >= frame_timestamps[None, :, 0],
|
||||
timestamps[:, None] <= frame_timestamps[None, :, 1]
|
||||
), axis=1).tolist()
|
||||
# print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
|
||||
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
||||
|
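A worked instance of the budget arithmetic above (assumed inputs, not taken from this diff): an 81-frame, 16 fps, 480x832 clip with downsample=(4, 16, 16), seq_len=32760 and max_area=480*832.

```python
import numpy as np

df, dh, dw = 4, 16, 16
seq_len, h, w, n_frames = 32760, 480, 832, 81

area_z = min(seq_len, (480 * 832) / (dh * dw), (h // dh) * (w // dw))  # 1560 tokens per latent frame
of = min((n_frames - 1) // df + 1, int(seq_len / area_z))              # 21 latent frames
target_area_z = min(area_z, int(seq_len / of))                         # 1560
oh = round(np.sqrt(target_area_z * h / w)) * dh                        # 480
ow = int(target_area_z / (oh // dh)) * dw                              # 832
of = (of - 1) * df + 1                                                 # 81 pixel frames
# The whole clip fits the 32760-token budget exactly at its native 480x832 resolution.
```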
||||
|
||||
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
|
||||
if self.keep_last:
|
||||
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
|
||||
else:
|
||||
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)
|
||||
|
||||
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
|
||||
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
|
||||
|
||||
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
|
||||
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
|
||||
|
||||
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
|
||||
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
|
||||
# read video
|
||||
import decord
|
||||
decord.bridge.set_bridge('torch')
|
||||
readers = []
|
||||
for data_k in data_key_batch:
|
||||
reader = decord.VideoReader(data_k)
|
||||
readers.append(reader)
|
||||
|
||||
fps = readers[0].get_avg_fps()
|
||||
length = min([len(r) for r in readers])
|
||||
frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
|
||||
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
|
||||
h, w = readers[0].next().shape[:2]
|
||||
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)
|
||||
|
||||
# preprocess video
|
||||
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
|
||||
videos = [self._video_preprocess(video, oh, ow) for video in videos]
|
||||
return *videos, frame_ids, (oh, ow), fps
|
||||
# return videos if len(videos) > 1 else videos[0]
|
||||
|
||||
|
||||
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
|
||||
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
|
||||
if sub_src_video is None and sub_src_mask is None:
|
||||
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
|
||||
src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
|
||||
for i, ref_images in enumerate(src_ref_images):
|
||||
if ref_images is not None:
|
||||
for j, ref_img in enumerate(ref_images):
|
||||
if ref_img is not None and ref_img.shape[-2:] != image_size:
|
||||
canvas_height, canvas_width = image_size
|
||||
ref_height, ref_width = ref_img.shape[-2:]
|
||||
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
|
||||
scale = min(canvas_height / ref_height, canvas_width / ref_width)
|
||||
new_height = int(ref_height * scale)
|
||||
new_width = int(ref_width * scale)
|
||||
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
|
||||
top = (canvas_height - new_height) // 2
|
||||
left = (canvas_width - new_width) // 2
|
||||
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
|
||||
src_ref_images[i][j] = white_canvas
|
||||
return src_video, src_mask, src_ref_images
|
718
wan/vace.py
Normal file
@@ -0,0 +1,718 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import os
|
||||
import sys
|
||||
import gc
|
||||
import math
|
||||
import time
|
||||
import random
|
||||
import types
|
||||
import logging
|
||||
import traceback
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
|
||||
from PIL import Image
|
||||
import torchvision.transforms.functional as TF
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.cuda.amp as amp
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from tqdm import tqdm
|
||||
|
||||
from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
|
||||
get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
|
||||
from .modules.vace_model import VaceWanModel
|
||||
from .utils.vace_processor import VaceVideoProcessor
|
||||
|
||||
|
||||
class WanVace(WanT2V):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
checkpoint_dir,
|
||||
device_id=0,
|
||||
rank=0,
|
||||
t5_fsdp=False,
|
||||
dit_fsdp=False,
|
||||
use_usp=False,
|
||||
t5_cpu=False,
|
||||
):
|
||||
r"""
|
||||
Initializes the Wan VACE video generation model components.
|
||||
|
||||
Args:
|
||||
config (EasyDict):
|
||||
Object containing model parameters initialized from config.py
|
||||
checkpoint_dir (`str`):
|
||||
Path to directory containing model checkpoints
|
||||
device_id (`int`, *optional*, defaults to 0):
|
||||
Id of target GPU device
|
||||
rank (`int`, *optional*, defaults to 0):
|
||||
Process rank for distributed training
|
||||
t5_fsdp (`bool`, *optional*, defaults to False):
|
||||
Enable FSDP sharding for T5 model
|
||||
dit_fsdp (`bool`, *optional*, defaults to False):
|
||||
Enable FSDP sharding for DiT model
|
||||
use_usp (`bool`, *optional*, defaults to False):
|
||||
Enable distribution strategy of USP.
|
||||
t5_cpu (`bool`, *optional*, defaults to False):
|
||||
Whether to place T5 model on CPU. Only works without t5_fsdp.
|
||||
"""
|
||||
self.device = torch.device(f"cuda:{device_id}")
|
||||
self.config = config
|
||||
self.rank = rank
|
||||
self.t5_cpu = t5_cpu
|
||||
|
||||
self.num_train_timesteps = config.num_train_timesteps
|
||||
self.param_dtype = config.param_dtype
|
||||
|
||||
shard_fn = partial(shard_model, device_id=device_id)
|
||||
self.text_encoder = T5EncoderModel(
|
||||
text_len=config.text_len,
|
||||
dtype=config.t5_dtype,
|
||||
device=torch.device('cpu'),
|
||||
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
|
||||
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
||||
shard_fn=shard_fn if t5_fsdp else None)
|
||||
|
||||
self.vae_stride = config.vae_stride
|
||||
self.patch_size = config.patch_size
|
||||
self.vae = WanVAE(
|
||||
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
|
||||
device=self.device)
|
||||
|
||||
logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
|
||||
self.model = VaceWanModel.from_pretrained(checkpoint_dir)
|
||||
self.model.eval().requires_grad_(False)
|
||||
|
||||
if use_usp:
|
||||
from xfuser.core.distributed import \
|
||||
get_sequence_parallel_world_size
|
||||
|
||||
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
||||
usp_dit_forward,
|
||||
usp_dit_forward_vace)
|
||||
for block in self.model.blocks:
|
||||
block.self_attn.forward = types.MethodType(
|
||||
usp_attn_forward, block.self_attn)
|
||||
for block in self.model.vace_blocks:
|
||||
block.self_attn.forward = types.MethodType(
|
||||
usp_attn_forward, block.self_attn)
|
||||
self.model.forward = types.MethodType(usp_dit_forward, self.model)
|
||||
self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model)
|
||||
self.sp_size = get_sequence_parallel_world_size()
|
||||
else:
|
||||
self.sp_size = 1
|
||||
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
if dit_fsdp:
|
||||
self.model = shard_fn(self.model)
|
||||
else:
|
||||
self.model.to(self.device)
|
||||
|
||||
self.sample_neg_prompt = config.sample_neg_prompt
|
||||
|
||||
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
|
||||
min_area=720*1280,
|
||||
max_area=720*1280,
|
||||
min_fps=config.sample_fps,
|
||||
max_fps=config.sample_fps,
|
||||
zero_start=True,
|
||||
seq_len=75600,
|
||||
keep_last=True)
|
||||
|
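A hedged construction sketch for single-GPU use (the checkpoint path and config key are assumptions, not taken from this diff):

```python
from wan.configs import WAN_CONFIGS   # assumed to contain a VACE entry, e.g. 'vace-1.3B'
from wan.vace import WanVace

cfg = WAN_CONFIGS['vace-1.3B']
wan_vace = WanVace(config=cfg, checkpoint_dir='./Wan2.1-VACE-1.3B',
                   device_id=0, rank=0, t5_cpu=True)
```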
||||
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
|
||||
vae = self.vae if vae is None else vae
|
||||
if ref_images is None:
|
||||
ref_images = [None] * len(frames)
|
||||
else:
|
||||
assert len(frames) == len(ref_images)
|
||||
|
||||
if masks is None:
|
||||
latents = vae.encode(frames)
|
||||
else:
|
||||
masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
|
||||
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
|
||||
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
|
||||
inactive = vae.encode(inactive)
|
||||
reactive = vae.encode(reactive)
|
||||
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
|
||||
|
||||
cat_latents = []
|
||||
for latent, refs in zip(latents, ref_images):
|
||||
if refs is not None:
|
||||
if masks is None:
|
||||
ref_latent = vae.encode(refs)
|
||||
else:
|
||||
ref_latent = vae.encode(refs)
|
||||
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
|
||||
assert all([x.shape[1] == 1 for x in ref_latent])
|
||||
latent = torch.cat([*ref_latent, latent], dim=1)
|
||||
cat_latents.append(latent)
|
||||
return cat_latents
|
||||
|
||||
def vace_encode_masks(self, masks, ref_images=None, vae_stride=None):
|
||||
vae_stride = self.vae_stride if vae_stride is None else vae_stride
|
||||
if ref_images is None:
|
||||
ref_images = [None] * len(masks)
|
||||
else:
|
||||
assert len(masks) == len(ref_images)
|
||||
|
||||
result_masks = []
|
||||
for mask, refs in zip(masks, ref_images):
|
||||
c, depth, height, width = mask.shape
|
||||
new_depth = int((depth + 3) // vae_stride[0])
|
||||
height = 2 * (int(height) // (vae_stride[1] * 2))
|
||||
width = 2 * (int(width) // (vae_stride[2] * 2))
|
||||
|
||||
# reshape
|
||||
mask = mask[0, :, :, :]
|
||||
mask = mask.view(
|
||||
depth, height, vae_stride[1], width, vae_stride[1]
|
||||
) # depth, height, 8, width, 8
|
||||
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
|
||||
mask = mask.reshape(
|
||||
vae_stride[1] * vae_stride[2], depth, height, width
|
||||
) # 8*8, depth, height, width
|
||||
|
||||
# interpolation
|
||||
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
|
||||
|
||||
if refs is not None:
|
||||
length = len(refs)
|
||||
mask_pad = torch.zeros_like(mask[:, :length, :, :])
|
||||
mask = torch.cat((mask_pad, mask), dim=1)
|
||||
result_masks.append(mask)
|
||||
return result_masks
|
||||
|
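A shape walk-through of the mask encoding above (assumed: a single 81-frame 480x832 mask, vae_stride=(4, 8, 8), one reference image):

```python
# input mask                      : (1, 81, 480, 832)
# view + permute + reshape        : (64, 81, 60, 104)   # each 8x8 spatial block becomes a channel
# nearest-exact interpolation     : (64, 21, 60, 104)   # (81 + 3) // 4 = 21 latent frames
# zero-pad for one reference image: (64, 22, 60, 104)
# vace_latent() then concatenates this with the 32-channel frame latent into a
# 96-channel VACE context of shape (96, 22, 60, 104).
```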
||||
def vace_latent(self, z, m):
|
||||
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
|
||||
|
||||
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device):
|
||||
area = image_size[0] * image_size[1]
|
||||
self.vid_proc.set_area(area)
|
||||
if area == 720*1280:
|
||||
self.vid_proc.set_seq_len(75600)
|
||||
elif area == 480*832:
|
||||
self.vid_proc.set_seq_len(32760)
|
||||
else:
|
||||
raise NotImplementedError(f'image_size {image_size} is not supported')
|
||||
|
||||
image_size = (image_size[1], image_size[0])
|
||||
image_sizes = []
|
||||
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
|
||||
if sub_src_mask is not None and sub_src_video is not None:
|
||||
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask)
|
||||
src_video[i] = src_video[i].to(device)
|
||||
src_mask[i] = src_mask[i].to(device)
|
||||
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
|
||||
image_sizes.append(src_video[i].shape[2:])
|
||||
elif sub_src_video is None:
|
||||
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
|
||||
src_mask[i] = torch.ones_like(src_video[i], device=device)
|
||||
image_sizes.append(image_size)
|
||||
else:
|
||||
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video)
|
||||
src_video[i] = src_video[i].to(device)
|
||||
src_mask[i] = torch.ones_like(src_video[i], device=device)
|
||||
image_sizes.append(src_video[i].shape[2:])
|
||||
|
||||
for i, ref_images in enumerate(src_ref_images):
|
||||
if ref_images is not None:
|
||||
image_size = image_sizes[i]
|
||||
for j, ref_img in enumerate(ref_images):
|
||||
if ref_img is not None:
|
||||
ref_img = Image.open(ref_img).convert("RGB")
|
||||
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
|
||||
if ref_img.shape[-2:] != image_size:
|
||||
canvas_height, canvas_width = image_size
|
||||
ref_height, ref_width = ref_img.shape[-2:]
|
||||
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
|
||||
scale = min(canvas_height / ref_height, canvas_width / ref_width)
|
||||
new_height = int(ref_height * scale)
|
||||
new_width = int(ref_width * scale)
|
||||
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
|
||||
top = (canvas_height - new_height) // 2
|
||||
left = (canvas_width - new_width) // 2
|
||||
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
|
||||
ref_img = white_canvas
|
||||
src_ref_images[i][j] = ref_img.to(device)
|
||||
return src_video, src_mask, src_ref_images
|
||||
|
||||
def decode_latent(self, zs, ref_images=None, vae=None):
|
||||
vae = self.vae if vae is None else vae
|
||||
if ref_images is None:
|
||||
ref_images = [None] * len(zs)
|
||||
else:
|
||||
assert len(zs) == len(ref_images)
|
||||
|
||||
trimed_zs = []
|
||||
for z, refs in zip(zs, ref_images):
|
||||
if refs is not None:
|
||||
z = z[:, len(refs):, :, :]
|
||||
trimed_zs.append(z)
|
||||
|
||||
return vae.decode(trimed_zs)
|
||||
|
||||
|
||||
|
||||
def generate(self,
|
||||
input_prompt,
|
||||
input_frames,
|
||||
input_masks,
|
||||
input_ref_images,
|
||||
size=(1280, 720),
|
||||
frame_num=81,
|
||||
context_scale=1.0,
|
||||
shift=5.0,
|
||||
sample_solver='unipc',
|
||||
sampling_steps=50,
|
||||
guide_scale=5.0,
|
||||
n_prompt="",
|
||||
seed=-1,
|
||||
offload_model=True):
|
||||
r"""
|
||||
Generates video frames from text prompt using diffusion process.
|
||||
|
||||
Args:
|
||||
input_prompt (`str`):
|
||||
Text prompt for content generation
|
||||
size (tuple[`int`], *optional*, defaults to (1280,720)):
|
||||
Controls video resolution, (width,height).
|
||||
frame_num (`int`, *optional*, defaults to 81):
|
||||
How many frames to sample from a video. The number should be 4n+1
|
||||
shift (`float`, *optional*, defaults to 5.0):
|
||||
Noise schedule shift parameter. Affects temporal dynamics
|
||||
sample_solver (`str`, *optional*, defaults to 'unipc'):
|
||||
Solver used to sample the video.
|
||||
sampling_steps (`int`, *optional*, defaults to 50):
|
||||
Number of diffusion sampling steps. Higher values improve quality but slow generation
|
||||
guide_scale (`float`, *optional*, defaults 5.0):
|
||||
Classifier-free guidance scale. Controls prompt adherence vs. creativity
|
||||
n_prompt (`str`, *optional*, defaults to ""):
|
||||
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
|
||||
seed (`int`, *optional*, defaults to -1):
|
||||
Random seed for noise generation. If -1, use random seed.
|
||||
offload_model (`bool`, *optional*, defaults to True):
|
||||
If True, offloads models to CPU during generation to save VRAM
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
Generated video frames tensor. Dimensions: (C, N, H, W) where:
|
||||
- C: Color channels (3 for RGB)
|
||||
- N: Number of frames (81)
|
||||
- H: Frame height (from size)
|
||||
- W: Frame width (from size)
|
||||
"""
|
||||
# preprocess
|
||||
# F = frame_num
|
||||
# target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
||||
# size[1] // self.vae_stride[1],
|
||||
# size[0] // self.vae_stride[2])
|
||||
#
|
||||
# seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
||||
# (self.patch_size[1] * self.patch_size[2]) *
|
||||
# target_shape[1] / self.sp_size) * self.sp_size
|
||||
|
||||
if n_prompt == "":
|
||||
n_prompt = self.sample_neg_prompt
|
||||
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
||||
seed_g = torch.Generator(device=self.device)
|
||||
seed_g.manual_seed(seed)
|
||||
|
||||
if not self.t5_cpu:
|
||||
self.text_encoder.model.to(self.device)
|
||||
context = self.text_encoder([input_prompt], self.device)
|
||||
context_null = self.text_encoder([n_prompt], self.device)
|
||||
if offload_model:
|
||||
self.text_encoder.model.cpu()
|
||||
else:
|
||||
context = self.text_encoder([input_prompt], torch.device('cpu'))
|
||||
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
||||
context = [t.to(self.device) for t in context]
|
||||
context_null = [t.to(self.device) for t in context_null]
|
||||
|
||||
# vace context encode
|
||||
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks)
|
||||
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
||||
z = self.vace_latent(z0, m0)
|
||||
|
||||
target_shape = list(z0[0].shape)
|
||||
target_shape[0] = int(target_shape[0] / 2)
|
||||
noise = [
|
||||
torch.randn(
|
||||
target_shape[0],
|
||||
target_shape[1],
|
||||
target_shape[2],
|
||||
target_shape[3],
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
generator=seed_g)
|
||||
]
|
||||
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
||||
(self.patch_size[1] * self.patch_size[2]) *
|
||||
target_shape[1] / self.sp_size) * self.sp_size
|
||||
|
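For concreteness, the sequence-length bookkeeping above works out as follows (assuming no reference frames and sp_size == 1; the numbers are illustrative):

```python
# 480x832, 81 frames : z0[0].shape == (32, 21, 60, 104) -> target_shape = [16, 21, 60, 104]
#                      seq_len = ceil(60 * 104 / (2 * 2) * 21) = 1560 * 21 = 32760
# 720x1280, 81 frames: target_shape = [16, 21, 90, 160]
#                      seq_len = ceil(90 * 160 / (2 * 2) * 21) = 3600 * 21 = 75600
# These match the seq_len values the VaceVideoProcessor is configured with above.
```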
||||
@contextmanager
|
||||
def noop_no_sync():
|
||||
yield
|
||||
|
||||
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
|
||||
|
||||
# evaluation mode
|
||||
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
|
||||
|
||||
if sample_solver == 'unipc':
|
||||
sample_scheduler = FlowUniPCMultistepScheduler(
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
shift=1,
|
||||
use_dynamic_shifting=False)
|
||||
sample_scheduler.set_timesteps(
|
||||
sampling_steps, device=self.device, shift=shift)
|
||||
timesteps = sample_scheduler.timesteps
|
||||
elif sample_solver == 'dpm++':
|
||||
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
||||
num_train_timesteps=self.num_train_timesteps,
|
||||
shift=1,
|
||||
use_dynamic_shifting=False)
|
||||
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
||||
timesteps, _ = retrieve_timesteps(
|
||||
sample_scheduler,
|
||||
device=self.device,
|
||||
sigmas=sampling_sigmas)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported solver.")
|
||||
|
||||
# sample videos
|
||||
latents = noise
|
||||
|
||||
arg_c = {'context': context, 'seq_len': seq_len}
|
||||
arg_null = {'context': context_null, 'seq_len': seq_len}
|
||||
|
||||
for _, t in enumerate(tqdm(timesteps)):
|
||||
latent_model_input = latents
|
||||
timestep = [t]
|
||||
|
||||
timestep = torch.stack(timestep)
|
||||
|
||||
self.model.to(self.device)
|
||||
noise_pred_cond = self.model(
|
||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0]
|
||||
noise_pred_uncond = self.model(
|
||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0]
|
||||
|
||||
noise_pred = noise_pred_uncond + guide_scale * (
|
||||
noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
temp_x0 = sample_scheduler.step(
|
||||
noise_pred.unsqueeze(0),
|
||||
t,
|
||||
latents[0].unsqueeze(0),
|
||||
return_dict=False,
|
||||
generator=seed_g)[0]
|
||||
latents = [temp_x0.squeeze(0)]
|
||||
|
||||
x0 = latents
|
||||
if offload_model:
|
||||
self.model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
if self.rank == 0:
|
||||
videos = self.decode_latent(x0, input_ref_images)
|
||||
|
||||
del noise, latents
|
||||
del sample_scheduler
|
||||
if offload_model:
|
||||
gc.collect()
|
||||
torch.cuda.synchronize()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
return videos[0] if self.rank == 0 else None
|
||||
|
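An end-to-end sketch of the call sequence, continuing from the `wan_vace` instance sketched earlier (the prompt, file path and (width, height) size tuple are illustrative assumptions):

```python
src_video, src_mask, src_ref_images = wan_vace.prepare_source(
    [None], [None], [['ref.png']],            # no driving video/mask, one reference image
    num_frames=81, image_size=(832, 480),     # (width, height) for the 480P branch
    device=wan_vace.device)
video = wan_vace.generate(
    'a corgi surfing at sunset', src_video, src_mask, src_ref_images,
    size=(832, 480), frame_num=81, sampling_steps=25, offload_model=True)
# video: (C, N, H, W) tensor on rank 0, None on other ranks
```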
||||
|
||||
class WanVaceMP(WanVace):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
checkpoint_dir,
|
||||
use_usp=False,
|
||||
ulysses_size=None,
|
||||
ring_size=None
|
||||
):
|
||||
self.config = config
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
self.use_usp = use_usp
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
os.environ['MASTER_PORT'] = '12345'
|
||||
os.environ['RANK'] = '0'
|
||||
os.environ['WORLD_SIZE'] = '1'
|
||||
self.in_q_list = None
|
||||
self.out_q = None
|
||||
self.inference_pids = None
|
||||
self.ulysses_size = ulysses_size
|
||||
self.ring_size = ring_size
|
||||
self.dynamic_load()
|
||||
|
||||
self.device = 'cpu'  # the front-end process stays on CPU; the per-GPU workers own the CUDA devices
|
||||
self.vid_proc = VaceVideoProcessor(
|
||||
downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]),
|
||||
min_area=480 * 832,
|
||||
max_area=480 * 832,
|
||||
min_fps=self.config.sample_fps,
|
||||
max_fps=self.config.sample_fps,
|
||||
zero_start=True,
|
||||
seq_len=32760,
|
||||
keep_last=True)
|
||||
|
||||
|
||||
def dynamic_load(self):
|
||||
if hasattr(self, 'inference_pids') and self.inference_pids is not None:
|
||||
return
|
||||
gpu_infer = int(os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count())
|
||||
pmi_rank = int(os.environ['RANK'])
|
||||
pmi_world_size = int(os.environ['WORLD_SIZE'])
|
||||
in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)]
|
||||
out_q = torch.multiprocessing.Manager().Queue()
|
||||
initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)]
|
||||
context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False)
|
||||
all_initialized = False
|
||||
while not all_initialized:
|
||||
all_initialized = all(event.is_set() for event in initialized_events)
|
||||
if not all_initialized:
|
||||
time.sleep(0.1)
|
||||
print('Inference model is initialized', flush=True)
|
||||
self.in_q_list = in_q_list
|
||||
self.out_q = out_q
|
||||
self.inference_pids = context.pids()
|
||||
self.initialized_events = initialized_events
|
||||
|
||||
def transfer_data_to_cuda(self, data, device):
|
||||
if data is None:
|
||||
return None
|
||||
else:
|
||||
if isinstance(data, torch.Tensor):
|
||||
data = data.to(device)
|
||||
elif isinstance(data, list):
|
||||
data = [self.transfer_data_to_cuda(subdata, device) for subdata in data]
|
||||
elif isinstance(data, dict):
|
||||
data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()}
|
||||
return data
|
||||
|
||||
def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env):
|
||||
try:
|
||||
world_size = pmi_world_size * gpu_infer
|
||||
rank = pmi_rank * gpu_infer + gpu
|
||||
print("world_size", world_size, "rank", rank, flush=True)
|
||||
|
||||
torch.cuda.set_device(gpu)
|
||||
dist.init_process_group(
|
||||
backend='nccl',
|
||||
init_method='env://',
|
||||
rank=rank,
|
||||
world_size=world_size
|
||||
)
|
||||
|
||||
from xfuser.core.distributed import (initialize_model_parallel,
|
||||
init_distributed_environment)
|
||||
init_distributed_environment(
|
||||
rank=dist.get_rank(), world_size=dist.get_world_size())
|
||||
|
||||
initialize_model_parallel(
|
||||
sequence_parallel_degree=dist.get_world_size(),
|
||||
ring_degree=self.ring_size or 1,
|
||||
ulysses_degree=self.ulysses_size or 1
|
||||
)
|
||||
|
||||
num_train_timesteps = self.config.num_train_timesteps
|
||||
param_dtype = self.config.param_dtype
|
||||
shard_fn = partial(shard_model, device_id=gpu)
|
||||
text_encoder = T5EncoderModel(
|
||||
text_len=self.config.text_len,
|
||||
dtype=self.config.t5_dtype,
|
||||
device=torch.device('cpu'),
|
||||
checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint),
|
||||
tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer),
|
||||
shard_fn=shard_fn)  # T5 is always FSDP-sharded in the multi-GPU worker
|
||||
text_encoder.model.to(gpu)
|
||||
vae_stride = self.config.vae_stride
|
||||
patch_size = self.config.patch_size
|
||||
vae = WanVAE(
|
||||
vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint),
|
||||
device=gpu)
|
||||
logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
|
||||
model = VaceWanModel.from_pretrained(self.checkpoint_dir)
|
||||
model.eval().requires_grad_(False)
|
||||
|
||||
if self.use_usp:
|
||||
from xfuser.core.distributed import get_sequence_parallel_world_size
|
||||
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
||||
usp_dit_forward,
|
||||
usp_dit_forward_vace)
|
||||
for block in model.blocks:
|
||||
block.self_attn.forward = types.MethodType(
|
||||
usp_attn_forward, block.self_attn)
|
||||
for block in model.vace_blocks:
|
||||
block.self_attn.forward = types.MethodType(
|
||||
usp_attn_forward, block.self_attn)
|
||||
model.forward = types.MethodType(usp_dit_forward, model)
|
||||
model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
|
||||
sp_size = get_sequence_parallel_world_size()
|
||||
else:
|
||||
sp_size = 1
|
||||
|
||||
dist.barrier()
|
||||
model = shard_fn(model)
|
||||
sample_neg_prompt = self.config.sample_neg_prompt
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
event = initialized_events[gpu]
|
||||
in_q = in_q_list[gpu]
|
||||
event.set()
|
||||
|
||||
while True:
|
||||
item = in_q.get()
|
||||
input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, \
|
||||
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
|
||||
input_frames = self.transfer_data_to_cuda(input_frames, gpu)
|
||||
input_masks = self.transfer_data_to_cuda(input_masks, gpu)
|
||||
input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu)
|
||||
|
||||
if n_prompt == "":
|
||||
n_prompt = sample_neg_prompt
|
||||
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
||||
seed_g = torch.Generator(device=gpu)
|
||||
seed_g.manual_seed(seed)
|
||||
|
||||
context = text_encoder([input_prompt], gpu)
|
||||
context_null = text_encoder([n_prompt], gpu)
|
||||
|
||||
# vace context encode
|
||||
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae)
|
||||
m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride)
|
||||
z = self.vace_latent(z0, m0)
|
||||
|
||||
target_shape = list(z0[0].shape)
|
||||
target_shape[0] = int(target_shape[0] / 2)
|
||||
noise = [
|
||||
torch.randn(
|
||||
target_shape[0],
|
||||
target_shape[1],
|
||||
target_shape[2],
|
||||
target_shape[3],
|
||||
dtype=torch.float32,
|
||||
device=gpu,
|
||||
generator=seed_g)
|
||||
]
|
||||
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
||||
(patch_size[1] * patch_size[2]) *
|
||||
target_shape[1] / sp_size) * sp_size
|
||||
|
||||
@contextmanager
|
||||
def noop_no_sync():
|
||||
yield
|
||||
|
||||
no_sync = getattr(model, 'no_sync', noop_no_sync)
|
||||
|
||||
# evaluation mode
|
||||
with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
|
||||
|
||||
if sample_solver == 'unipc':
|
||||
sample_scheduler = FlowUniPCMultistepScheduler(
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
shift=1,
|
||||
use_dynamic_shifting=False)
|
||||
sample_scheduler.set_timesteps(
|
||||
sampling_steps, device=gpu, shift=shift)
|
||||
timesteps = sample_scheduler.timesteps
|
||||
elif sample_solver == 'dpm++':
|
||||
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
||||
num_train_timesteps=num_train_timesteps,
|
||||
shift=1,
|
||||
use_dynamic_shifting=False)
|
||||
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
||||
timesteps, _ = retrieve_timesteps(
|
||||
sample_scheduler,
|
||||
device=gpu,
|
||||
sigmas=sampling_sigmas)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported solver.")
|
||||
|
||||
# sample videos
|
||||
latents = noise
|
||||
|
||||
arg_c = {'context': context, 'seq_len': seq_len}
|
||||
arg_null = {'context': context_null, 'seq_len': seq_len}
|
||||
|
||||
for _, t in enumerate(tqdm(timesteps)):
|
||||
latent_model_input = latents
|
||||
timestep = [t]
|
||||
|
||||
timestep = torch.stack(timestep)
|
||||
|
||||
model.to(gpu)
|
||||
noise_pred_cond = model(
|
||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[
|
||||
0]
|
||||
noise_pred_uncond = model(
|
||||
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,
|
||||
**arg_null)[0]
|
||||
|
||||
noise_pred = noise_pred_uncond + guide_scale * (
|
||||
noise_pred_cond - noise_pred_uncond)
|
||||
|
||||
temp_x0 = sample_scheduler.step(
|
||||
noise_pred.unsqueeze(0),
|
||||
t,
|
||||
latents[0].unsqueeze(0),
|
||||
return_dict=False,
|
||||
generator=seed_g)[0]
|
||||
latents = [temp_x0.squeeze(0)]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
x0 = latents
|
||||
if rank == 0:
|
||||
videos = self.decode_latent(x0, input_ref_images, vae=vae)
|
||||
|
||||
del noise, latents
|
||||
del sample_scheduler
|
||||
if offload_model:
|
||||
gc.collect()
|
||||
torch.cuda.synchronize()
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
if rank == 0:
|
||||
out_q.put(videos[0].cpu())
|
||||
|
||||
except Exception as e:
|
||||
trace_info = traceback.format_exc()
|
||||
print(trace_info, flush=True)
|
||||
print(e, flush=True)
|
||||
|
||||
|
||||
|
||||
def generate(self,
|
||||
input_prompt,
|
||||
input_frames,
|
||||
input_masks,
|
||||
input_ref_images,
|
||||
size=(1280, 720),
|
||||
frame_num=81,
|
||||
context_scale=1.0,
|
||||
shift=5.0,
|
||||
sample_solver='unipc',
|
||||
sampling_steps=50,
|
||||
guide_scale=5.0,
|
||||
n_prompt="",
|
||||
seed=-1,
|
||||
offload_model=True):
|
||||
|
||||
input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale,
|
||||
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model)
|
||||
for in_q in self.in_q_list:
|
||||
in_q.put(input_data)
|
||||
value_output = self.out_q.get()
|
||||
|
||||
return value_output
|
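For multi-GPU use, `WanVaceMP` exposes the same `generate` signature; the front end enqueues the request and returns the decoded video produced by the rank-0 worker. A hedged sketch (config key, checkpoint path and parallel sizes are assumptions):

```python
from wan.configs import WAN_CONFIGS
from wan.vace import WanVaceMP

mp_pipe = WanVaceMP(WAN_CONFIGS['vace-1.3B'], './Wan2.1-VACE-1.3B',
                    use_usp=True, ulysses_size=2, ring_size=1)
# src_video, src_mask, src_ref_images prepared via mp_pipe.prepare_source(...) as shown earlier
video = mp_pipe.generate('a corgi surfing at sunset',
                         src_video, src_mask, src_ref_images,
                         size=(832, 480), frame_num=81)
```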