Compare commits

...

5 Commits

Author SHA1 Message Date
Knoka
55fdfe155d
Merge c6c5675a06 into c709fcf0e7 2025-05-16 11:24:39 +08:00
Zhen Han
c709fcf0e7
fix vace size (#397) 2025-05-14 22:01:45 +08:00
Ang Wang
18d53feb7a
[feature] Add VACE (#389)
* Add VACE

* Support training with multiple gpus

* Update default args for vace task

* vace block update

* Add vace exmaple jpg

* Fix dist vace fwd hook error

* Update vace exmample

* Update vace args

* Update pipeline name for vace

* vace gradio and Readme

* Update vace snake png

---------

Co-authored-by: hanzhn <han.feng.jason@gmail.com>
2025-05-14 20:44:25 +08:00
Knoka
c6c5675a06
Add FastAPI
Added the service deployment code of FastAPI, including key authentication, task submission, task details viewing, and task canceling
2025-04-09 15:52:55 +08:00
knoka812
0961b7b888 Added the service deployment code of FastAPI, including key authentication, task submission, task details viewing, and task canceling 2025-04-09 15:37:10 +08:00
19 changed files with 3004 additions and 17 deletions

151
I2V-FastAPI文档.md Normal file
View File

@ -0,0 +1,151 @@
# 图像到视频生成服务API文档
## 一、功能概述
基于Wan2.1-I2V-14B-480P模型实现图像到视频生成核心功能包括
1. **异步任务队列**支持多任务排队和并发控制最大2个并行任务
2. **智能分辨率适配**
- 支持自动计算最佳分辨率(保持原图比例)
- 支持手动指定分辨率480x832/832x480
3. **资源管理**
- 显存优化bfloat16精度
- 生成文件自动清理默认1小时
4. **安全认证**基于API Key的Bearer Token验证
5. **任务控制**:支持任务提交/状态查询/取消操作
技术栈:
- FastAPI框架
- CUDA加速
- 异步任务处理
- Diffusers推理库
---
## 二、接口说明
### 1. 提交生成任务
**POST /video/submit**
```json
{
"model": "Wan2.1-I2V-14B-480P",
"prompt": "A dancing cat in the style of Van Gogh",
"image_url": "https://example.com/input.jpg",
"image_size": "auto",
"num_frames": 81,
"guidance_scale": 3.0,
"infer_steps": 30
}
```
**响应示例**
```json
{
"requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```
### 2. 查询任务状态
**POST /video/status**
```json
{
"requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```
**响应示例**
```json
{
"status": "Succeed",
"results": {
"videos": [{"url": "http://localhost:8088/videos/abcd1234.mp4"}],
"timings": {"inference": 90},
"seed": 123456
}
}
```
### 3. 取消任务
**POST /video/cancel**
```json
{
"requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```
**响应示例**
```json
{
"status": "Succeed"
}
```
---
## 三、Postman使用指南
### 1. 基础配置
- 服务器地址:`http://ip地址:8088`
- 认证方式Bearer Token
- Token值需替换为有效API Key
### 2. 提交任务
1. 选择POST方法URL填写`/video/submit`
2. Headers添加
```text
Authorization: Bearer YOUR_API_KEY
Content-Type: application/json
```
3. Body示例图像生成视频
```json
{
"prompt": "Sunset scene with mountains",
"image_url": "https://example.com/mountain.jpg",
"image_size": "auto",
"num_frames": 50
}
```
### 3. 特殊处理
- **图像下载失败**返回400错误包含具体原因如URL无效/超时)
- **显存不足**返回500错误并提示降低分辨率
---
## 四、参数规范
| 参数名 | 允许值范围 | 必填 | 说明 |
|------------------|-------------------------------|------|------------------------------------------|
| image_url | 有效HTTP/HTTPS URL | 是 | 输入图像地址 |
| prompt | 10-500字符 | 是 | 视频内容描述 |
| image_size | "480x832", "832x480", "auto" | 是 | auto模式自动适配原图比例 |
| num_frames | 24-120 | 是 | 视频总帧数 |
| guidance_scale | 1.0-20.0 | 是 | 文本引导强度 |
| infer_steps | 20-100 | 是 | 推理步数 |
| seed | 0-2147483647 | 否 | 随机种子 |
---
## 五、状态码说明
| 状态码 | 含义 |
|--------|-----------------------------------|
| 202 | 任务已接受 |
| 400 | 图像下载失败/参数错误 |
| 401 | 认证失败 |
| 404 | 任务不存在 |
| 422 | 参数校验失败 |
| 500 | 服务端错误(显存不足/模型异常等) |
---
## 六、特殊功能说明
1. **智能分辨率适配**
- 当`image_size="auto"`时,自动计算符合模型要求的最优分辨率
- 保持原始图像宽高比最大像素面积不超过399,360约640x624
2. **图像预处理**
- 自动转换为RGB模式
- 根据目标分辨率进行等比缩放
**重要提示**输入图像URL需保证公开可访问私有资源需提供有效鉴权
**提示** :访问`http://服务器地址:8088/docs`可查看交互式API文档支持在线测试所有接口

View File

@ -27,6 +27,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
## 🔥 Latest News!!
* May 14, 2025: 👋 We introduce **Wan2.1** [VACE](https://github.com/ali-vilab/VACE), an all-in-one model for video creation and editing, along with its [inference code](#run-vace), [weights](#model-download), and [technical report](https://arxiv.org/abs/2503.07598)!
* Apr 17, 2025: 👋 We introduce **Wan2.1** [FLF2V](#run-first-last-frame-to-video-generation) with its inference code and weights!
* Mar 21, 2025: 👋 We are excited to announce the release of the **Wan2.1** [technical report](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf). We welcome discussions and feedback!
* Mar 3, 2025: 👋 **Wan2.1**'s T2V and I2V have been integrated into Diffusers ([T2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanPipeline) | [I2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanImageToVideoPipeline)). Feel free to give it a try!
@ -64,7 +65,13 @@ If your work has improved **Wan2.1** and you would like more people to see it, p
- [ ] ComfyUI integration
- [ ] Diffusers integration
- [ ] Diffusers + Multi-GPU Inference
- Wan2.1 VACE
- [x] Multi-GPU Inference code of the 14B and 1.3B models
- [x] Checkpoints of the 14B and 1.3B models
- [x] Gradio demo
- [x] ComfyUI integration
- [ ] Diffusers integration
- [ ] Diffusers + Multi-GPU Inference
## Quickstart
@ -84,13 +91,15 @@ pip install -r requirements.txt
#### Model Download
| Models | Download Link | Notes |
|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
| FLF2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | Supports 720P
| Models | Download Link | Notes |
|--------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
| T2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | Supports both 480P and 720P
| I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | Supports 720P
| I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | Supports 480P
| T2V-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | Supports 480P
| FLF2V-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | Supports 720P
| VACE-1.3B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | Supports 480P
| VACE-14B | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | Supports both 480P and 720P
> 💡Note:
> * The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution.
@ -448,6 +457,73 @@ DASH_API_KEY=your_key python flf2v_14B_singleGPU.py --prompt_extend_method 'dash
```
#### Run VACE
[VACE](https://github.com/ali-vilab/VACE) now supports two models (1.3B and 14B) and two main resolutions (480P and 720P).
The input supports any resolution, but to achieve optimal results, the video size should fall within a specific range.
The parameters and configurations for these models are as follows:
<table>
<thead>
<tr>
<th rowspan="2">Task</th>
<th colspan="2">Resolution</th>
<th rowspan="2">Model</th>
</tr>
<tr>
<th>480P(~81x480x832)</th>
<th>720P(~81x720x1280)</th>
</tr>
</thead>
<tbody>
<tr>
<td>VACE</td>
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
<td>Wan2.1-VACE-14B</td>
</tr>
<tr>
<td>VACE</td>
<td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
<td style="color: red; text-align: center; vertical-align: middle;"></td>
<td>Wan2.1-VACE-1.3B</td>
</tr>
</tbody>
</table>
In VACE, users can input text prompt and optional video, mask, and image for video generation or editing. Detailed instructions for using VACE can be found in the [User Guide](https://github.com/ali-vilab/VACE/blob/main/UserGuide.md).
The execution process is as follows:
##### (1) Preprocessing
User-collected materials needs to be preprocessed into VACE-recognizable inputs, including `src_video`, `src_mask`, `src_ref_images`, and `prompt`.
For R2V (Reference-to-Video Generation), you may skip this preprocessing, but for V2V (Video-to-Video Editing) and MV2V (Masked Video-to-Video Editing) tasks, additional preprocessing is required to obtain video with conditions such as depth, pose or masked regions.
For more details, please refer to [vace_preproccess](https://github.com/ali-vilab/VACE/blob/main/vace/vace_preproccess.py).
##### (2) cli inference
- Single-GPU inference
```sh
python generate.py --task vace-1.3B --size 832*480 --ckpt_dir ./Wan2.1-VACE-1.3B --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
```
- Multi-GPU inference using FSDP + xDiT USP
```sh
torchrun --nproc_per_node=8 generate.py --task vace-14B --size 1280*720 --ckpt_dir ./Wan2.1-VACE-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
```
##### (3) Running local gradio
- Single-GPU inference
```sh
python gradio/vace.py --ckpt_dir ./Wan2.1-VACE-1.3B
```
- Multi-GPU inference using FSDP + xDiT USP
```sh
python gradio/vace.py --mp --ulysses_size 8 --ckpt_dir ./Wan2.1-VACE-14B/
```
#### Run Text-to-Image Generation
Wan2.1 is a unified model for both image and video generation. Since it was trained on both types of data, it can also generate images. The command for generating images is similar to video generation, as follows:

133
T2V-FastAPI文档.md Normal file
View File

@ -0,0 +1,133 @@
# 视频生成服务API文档
## 一、功能概述
本服务基于Wan2.1-T2V-1.3B模型实现文本到视频生成,包含以下核心功能:
1. **异步任务队列**支持多任务排队和并发控制最大2个并行任务
2. **资源管理**
- 显存优化使用bfloat16精度
- 生成视频自动清理默认1小时后删除
3. **安全认证**基于API Key的Bearer Token验证
4. **任务控制**:支持任务提交/状态查询/取消操作
技术栈:
- FastAPI框架
- CUDA加速
- 异步任务处理
- Diffusers推理库
---
## 二、接口说明
### 1. 提交生成任务
**POST /video/submit**
```json
{
"model": "Wan2.1-T2V-1.3B",
"prompt": "A beautiful sunset over the mountains",
"image_size": "480x832",
"num_frames": 81,
"guidance_scale": 5.0,
"infer_steps": 50
}
```
**响应示例**
```json
{
"requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```
### 2. 查询任务状态
**POST /video/status**
```json
{
"requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```
**响应示例**
```json
{
"status": "Succeed",
"results": {
"videos": [{"url": "http://localhost:8088/videos/abcd1234.mp4"}],
"timings": {"inference": 120}
}
}
```
### 3. 取消任务
**POST /video/cancel**
```json
{
"requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```
**响应示例**
```json
{
"status": "Succeed"
}
```
---
## 三、Postman使用指南
### 1. 基础配置
- 服务器地址:`http://ip地址:8088`
- 认证方式Bearer Token
- Token值需替换为有效API Key
### 2. 提交任务
1. 选择POST方法输入URL`/video/submit`
2. Headers添加
```text
Authorization: Bearer YOUR_API_KEY
Content-Type: application/json
```
3. Body选择raw/JSON格式输入请求参数
### 3. 查询状态
1. 新建请求URL填写`/video/status`
2. 使用相同认证头
3. Body中携带requestId
### 4. 取消任务
1. 新建DELETE请求URL填写`/video/cancel`
2. Body携带需要取消的requestId
### 注意事项
1. 所有接口必须携带有效API Key
2. 视频生成耗时约2-5分钟根据参数配置
3. 生成视频默认保留1小时
---
## 四、参数规范
| 参数名 | 允许值范围 | 必填 | 说明 |
|------------------|-------------------------------|------|--------------------------|
| prompt | 10-500字符 | 是 | 视频内容描述 |
| image_size | "480x832" 或 "832x480" | 是 | 分辨率 |
| num_frames | 24-120 | 是 | 视频总帧数 |
| guidance_scale | 1.0-20.0 | 是 | 文本引导强度 |
| infer_steps | 20-100 | 是 | 推理步数 |
| seed | 0-2147483647 | 否 | 随机种子 |
---
## 五、状态码说明
| 状态码 | 含义 |
|--------|--------------------------|
| 202 | 任务已接受 |
| 401 | 认证失败 |
| 404 | 任务不存在 |
| 422 | 参数校验失败 |
| 500 | 服务端错误(显存不足等) |
**提示**建议使用Swagger文档进行接口测试访问`http://服务器地址:8088/docs`可查看自动生成的API文档界面

BIN
examples/girl.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 817 KiB

BIN
examples/snake.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 435 KiB

View File

@ -41,6 +41,14 @@ EXAMPLE_PROMPT = {
"last_frame":
"examples/flf2v_input_last_frame.png",
},
"vace-1.3B": {
"src_ref_images": 'examples/girl.png,examples/snake.png',
"prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
},
"vace-14B": {
"src_ref_images": 'examples/girl.png,examples/snake.png',
"prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
}
}
@ -52,15 +60,19 @@ def _validate_args(args):
# The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
if args.sample_steps is None:
args.sample_steps = 40 if "i2v" in args.task else 50
args.sample_steps = 50
if "i2v" in args.task:
args.sample_steps = 40
if args.sample_shift is None:
args.sample_shift = 5.0
if "i2v" in args.task and args.size in ["832*480", "480*832"]:
args.sample_shift = 3.0
if "flf2v" in args.task:
elif "flf2v" in args.task or "vace" in args.task:
args.sample_shift = 16
# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
if args.frame_num is None:
args.frame_num = 1 if "t2i" in args.task else 81
@ -141,6 +153,21 @@ def _parse_args():
type=str,
default=None,
help="The file to save the generated image or video to.")
parser.add_argument(
"--src_video",
type=str,
default=None,
help="The file of the source video. Default None.")
parser.add_argument(
"--src_mask",
type=str,
default=None,
help="The file of the source mask. Default None.")
parser.add_argument(
"--src_ref_images",
type=str,
default=None,
help="The file list of the source reference images. Separated by ','. Default None.")
parser.add_argument(
"--prompt",
type=str,
@ -397,7 +424,7 @@ def generate(args):
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
elif "flf2v" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
if args.first_frame is None or args.last_frame is None:
@ -457,6 +484,60 @@ def generate(args):
seed=args.base_seed,
offload_model=args.offload_model
)
elif "vace" in args.task:
if args.prompt is None:
args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
args.src_ref_images = EXAMPLE_PROMPT[args.task].get("src_ref_images", None)
logging.info(f"Input prompt: {args.prompt}")
if args.use_prompt_extend and args.use_prompt_extend != 'plain':
logging.info("Extending prompt ...")
if rank == 0:
prompt = prompt_expander.forward(args.prompt)
logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'")
input_prompt = [prompt]
else:
input_prompt = [None]
if dist.is_initialized():
dist.broadcast_object_list(input_prompt, src=0)
args.prompt = input_prompt[0]
logging.info(f"Extended prompt: {args.prompt}")
logging.info("Creating VACE pipeline.")
wan_vace = wan.WanVace(
config=cfg,
checkpoint_dir=args.ckpt_dir,
device_id=device,
rank=rank,
t5_fsdp=args.t5_fsdp,
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
src_video, src_mask, src_ref_images = wan_vace.prepare_source([args.src_video],
[args.src_mask],
[None if args.src_ref_images is None else args.src_ref_images.split(',')],
args.frame_num, SIZE_CONFIGS[args.size], device)
logging.info(f"Generating video...")
video = wan_vace.generate(
args.prompt,
src_video,
src_mask,
src_ref_images,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=args.sample_steps,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)
else:
raise ValueError(f"Unkown task type: {args.task}")
if rank == 0:
if args.save_file is None:

295
gradio/vace.py Normal file
View File

@ -0,0 +1,295 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
import sys
import datetime
import imageio
import numpy as np
import torch
import gradio as gr
sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan import WanVace, WanVaceMP
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS
class FixedSizeQueue:
def __init__(self, max_size):
self.max_size = max_size
self.queue = []
def add(self, item):
self.queue.insert(0, item)
if len(self.queue) > self.max_size:
self.queue.pop()
def get(self):
return self.queue
def __repr__(self):
return str(self.queue)
class VACEInference:
def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
self.cfg = cfg
self.save_dir = cfg.save_dir
self.gallery_share = gallery_share
self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
if not skip_load:
if not args.mp:
self.pipe = WanVace(
config=WAN_CONFIGS[cfg.model_name],
checkpoint_dir=cfg.ckpt_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
)
else:
self.pipe = WanVaceMP(
config=WAN_CONFIGS[cfg.model_name],
checkpoint_dir=cfg.ckpt_dir,
use_usp=True,
ulysses_size=cfg.ulysses_size,
ring_size=cfg.ring_size
)
def create_ui(self, *args, **kwargs):
gr.Markdown("""
<div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
<a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
</div>
""")
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1, min_width=0):
self.src_video = gr.Video(
label="src_video",
sources=['upload'],
value=None,
interactive=True)
with gr.Column(scale=1, min_width=0):
self.src_mask = gr.Video(
label="src_mask",
sources=['upload'],
value=None,
interactive=True)
#
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1, min_width=0):
with gr.Row(equal_height=True):
self.src_ref_image_1 = gr.Image(label='src_ref_image_1',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_1",
format='png')
self.src_ref_image_2 = gr.Image(label='src_ref_image_2',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_2",
format='png')
self.src_ref_image_3 = gr.Image(label='src_ref_image_3',
height=200,
interactive=True,
type='filepath',
image_mode='RGB',
sources=['upload'],
elem_id="src_ref_image_3",
format='png')
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1):
self.prompt = gr.Textbox(
show_label=False,
placeholder="positive_prompt_input",
elem_id='positive_prompt',
container=True,
autofocus=True,
elem_classes='type_row',
visible=True,
lines=2)
self.negative_prompt = gr.Textbox(
show_label=False,
value=self.pipe.config.sample_neg_prompt,
placeholder="negative_prompt_input",
elem_id='negative_prompt',
container=True,
autofocus=False,
elem_classes='type_row',
visible=True,
interactive=True,
lines=1)
#
with gr.Row(variant='panel', equal_height=True):
with gr.Column(scale=1, min_width=0):
with gr.Row(equal_height=True):
self.shift_scale = gr.Slider(
label='shift_scale',
minimum=0.0,
maximum=100.0,
step=1.0,
value=16.0,
interactive=True)
self.sample_steps = gr.Slider(
label='sample_steps',
minimum=1,
maximum=100,
step=1,
value=25,
interactive=True)
self.context_scale = gr.Slider(
label='context_scale',
minimum=0.0,
maximum=2.0,
step=0.1,
value=1.0,
interactive=True)
self.guide_scale = gr.Slider(
label='guide_scale',
minimum=1,
maximum=10,
step=0.5,
value=5.0,
interactive=True)
self.infer_seed = gr.Slider(minimum=-1,
maximum=10000000,
value=2025,
label="Seed")
#
with gr.Accordion(label="Usable without source video", open=False):
with gr.Row(equal_height=True):
self.output_height = gr.Textbox(
label='resolutions_height',
# value=480,
value=720,
interactive=True)
self.output_width = gr.Textbox(
label='resolutions_width',
# value=832,
value=1280,
interactive=True)
self.frame_rate = gr.Textbox(
label='frame_rate',
value=16,
interactive=True)
self.num_frames = gr.Textbox(
label='num_frames',
value=81,
interactive=True)
#
with gr.Row(equal_height=True):
with gr.Column(scale=5):
self.generate_button = gr.Button(
value='Run',
elem_classes='type_row',
elem_id='generate_button',
visible=True)
with gr.Column(scale=1):
self.refresh_button = gr.Button(value='\U0001f504') # 🔄
#
self.output_gallery = gr.Gallery(
label="output_gallery",
value=[],
interactive=False,
allow_preview=True,
preview=True)
def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames):
output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames)
src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if
x is not None]
src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
[src_mask],
[src_ref_images],
num_frames=num_frames,
image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
device=self.pipe.device)
video = self.pipe.generate(
prompt,
src_video,
src_mask,
src_ref_images,
size=(output_width, output_height),
context_scale=context_scale,
shift=shift_scale,
sampling_steps=sample_steps,
guide_scale=guide_scale,
n_prompt=negative_prompt,
seed=infer_seed,
offload_model=True)
name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now())
video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)
try:
writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1)
for frame in video_frames:
writer.append_data(frame)
writer.close()
print(video_path)
except Exception as e:
raise gr.Error(f"Video save error: {e}")
if self.gallery_share:
self.gallery_share_data.add(video_path)
return self.gallery_share_data.get()
else:
return [video_path]
def set_callbacks(self, **kwargs):
self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
self.gen_outputs = [self.output_gallery]
self.generate_button.click(self.generate,
inputs=self.gen_inputs,
outputs=self.gen_outputs,
queue=True)
self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery])
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
parser.add_argument('--root_path', dest='root_path', help='', default=None)
parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",)
parser.add_argument("--model_name", type=str, default="vace-14B", choices=list(WAN_CONFIGS.keys()), help="The model name to run.")
parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.")
parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.")
parser.add_argument(
"--ckpt_dir",
type=str,
# default='models/VACE-Wan2.1-1.3B-Preview',
default='models/Wan2.1-VACE-14B/',
help="The path to the checkpoint directory.",
)
parser.add_argument(
"--offload_to_cpu",
action="store_true",
help="Offloading unnecessary computations to CPU.",
)
args = parser.parse_args()
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir, exist_ok=True)
with gr.Blocks() as demo:
infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
infer_gr.create_ui()
infer_gr.set_callbacks()
allowed_paths = [args.save_dir]
demo.queue(status_update_rate=1).launch(server_name=args.server_name,
server_port=args.server_port,
root_path=args.root_path,
allowed_paths=allowed_paths,
show_error=True, debug=True)

526
i2v_api.py Normal file
View File

@ -0,0 +1,526 @@
import os
import torch
import uuid
import time
import asyncio
import numpy as np
from threading import Lock
from typing import Optional, Dict, List
from fastapi import FastAPI, HTTPException, status, Depends
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, field_validator, ValidationError
from diffusers.utils import export_to_video, load_image
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from transformers import CLIPVisionModel
from PIL import Image
import requests
from io import BytesIO
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from contextlib import asynccontextmanager
from requests.exceptions import RequestException
# 创建存储目录
os.makedirs("generated_videos", exist_ok=True)
os.makedirs("temp_images", exist_ok=True)
# ======================
# 生命周期管理
# ======================
@asynccontextmanager
async def lifespan(app: FastAPI):
"""资源管理器"""
try:
# 初始化认证系统
app.state.valid_api_keys = {
"密钥"
}
# 初始化模型
model_id = "./Wan2.1-I2V-14B-480P-Diffusers"
# 加载图像编码器
image_encoder = CLIPVisionModel.from_pretrained(
model_id,
subfolder="image_encoder",
torch_dtype=torch.float32
)
# 加载VAE
vae = AutoencoderKLWan.from_pretrained(
model_id,
subfolder="vae",
torch_dtype=torch.float32
)
# 配置调度器
scheduler = UniPCMultistepScheduler(
prediction_type='flow_prediction',
use_flow_sigmas=True,
num_train_timesteps=1000,
flow_shift=3.0
)
# 创建管道
app.state.pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
).to("cuda")
app.state.pipe.scheduler = scheduler
# 初始化任务系统
app.state.tasks: Dict[str, dict] = {}
app.state.pending_queue: List[str] = []
app.state.model_lock = Lock()
app.state.task_lock = Lock()
app.state.base_url = "ip地址+端口"
app.state.semaphore = asyncio.Semaphore(2) # 并发限制
# 启动后台处理器
asyncio.create_task(task_processor())
print("✅ 系统初始化完成")
yield
finally:
# 资源清理
if hasattr(app.state, 'pipe'):
del app.state.pipe
torch.cuda.empty_cache()
print("♻️ 资源已释放")
# ======================
# FastAPI应用
# ======================
app = FastAPI(lifespan=lifespan)
app.mount("/videos", StaticFiles(directory="generated_videos"), name="videos")
# 认证模块
security = HTTPBearer(auto_error=False)
# ======================
# 数据模型--查询参数模型
# ======================
class VideoSubmitRequest(BaseModel):
model: str = Field(
default="Wan2.1-I2V-14B-480P",
description="模型版本"
)
prompt: str = Field(
...,
min_length=10,
max_length=500,
description="视频描述提示词10-500个字符"
)
image_url: str = Field(
...,
description="输入图像URL需支持HTTP/HTTPS协议"
)
image_size: str = Field(
default="auto",
description="输出分辨率格式宽x高 或 auto自动计算"
)
negative_prompt: Optional[str] = Field(
default=None,
max_length=500,
description="排除不需要的内容"
)
seed: Optional[int] = Field(
default=None,
ge=0,
le=2147483647,
description="随机数种子范围0-2147483647"
)
num_frames: int = Field(
default=81,
ge=24,
le=120,
description="视频帧数24-89帧"
)
guidance_scale: float = Field(
default=3.0,
ge=1.0,
le=20.0,
description="引导系数1.0-20.0"
)
infer_steps: int = Field(
default=30,
ge=20,
le=100,
description="推理步数20-100步"
)
@field_validator('image_size')
def validate_image_size(cls, v):
allowed_sizes = {"480x832", "832x480", "auto"}
if v not in allowed_sizes:
raise ValueError(f"支持的分辨率: {', '.join(allowed_sizes)}")
return v
class VideoStatusRequest(BaseModel):
requestId: str = Field(
...,
min_length=32,
max_length=32,
description="32位任务ID"
)
class VideoStatusResponse(BaseModel):
status: str = Field(..., description="任务状态: Succeed, InQueue, InProgress, Failed,Cancelled")
reason: Optional[str] = Field(None, description="失败原因")
results: Optional[dict] = Field(None, description="生成结果")
queue_position: Optional[int] = Field(None, description="队列位置")
class VideoCancelRequest(BaseModel):
requestId: str = Field(
...,
min_length=32,
max_length=32,
description="32位任务ID"
)
# ======================
# 核心逻辑
# ======================
async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""统一认证验证"""
if not credentials:
raise HTTPException(
status_code=401,
detail={"status": "Failed", "reason": "缺少认证头"},
headers={"WWW-Authenticate": "Bearer"}
)
if credentials.scheme != "Bearer":
raise HTTPException(
status_code=401,
detail={"status": "Failed", "reason": "无效的认证方案"},
headers={"WWW-Authenticate": "Bearer"}
)
if credentials.credentials not in app.state.valid_api_keys:
raise HTTPException(
status_code=401,
detail={"status": "Failed", "reason": "无效的API密钥"},
headers={"WWW-Authenticate": "Bearer"}
)
return True
async def task_processor():
"""任务处理器"""
while True:
async with app.state.semaphore:
task_id = await get_next_task()
if task_id:
await process_task(task_id)
else:
await asyncio.sleep(0.5)
async def get_next_task():
"""获取下一个任务"""
with app.state.task_lock:
return app.state.pending_queue.pop(0) if app.state.pending_queue else None
async def process_task(task_id: str):
"""处理单个任务"""
task = app.state.tasks.get(task_id)
if not task:
return
try:
# 更新任务状态
task['status'] = 'InProgress'
task['started_at'] = int(time.time())
print(task['request'].image_url)
# 下载输入图像
image = await download_image(task['request'].image_url)
image_path = f"temp_images/{task_id}.jpg"
image.save(image_path)
# 生成视频
video_path = await generate_video(task['request'], task_id, image)
# 生成下载链接
download_url = f"{app.state.base_url}/videos/{os.path.basename(video_path)}"
# 更新任务状态
task.update({
'status': 'Succeed',
'download_url': download_url,
'completed_at': int(time.time())
})
# 安排清理
asyncio.create_task(cleanup_files([image_path, video_path]))
except Exception as e:
handle_task_error(task, e)
def handle_task_error(task: dict, error: Exception):
"""错误处理(包含详细错误信息)"""
error_msg = str(error)
# 1. 显存不足错误
if isinstance(error, torch.cuda.OutOfMemoryError):
error_msg = "显存不足,请降低分辨率"
# 2. 网络请求相关错误
elif isinstance(error, (RequestException, HTTPException)):
# 从异常中提取具体信息
if isinstance(error, HTTPException):
# 如果是 HTTPException获取其 detail 字段
error_detail = getattr(error, "detail", "")
error_msg = f"图像下载失败: {error_detail}"
elif isinstance(error, Timeout):
error_msg = "图像下载超时,请检查网络"
elif isinstance(error, ConnectionError):
error_msg = "无法连接到服务器,请检查 URL"
elif isinstance(error, HTTPError):
# requests 的 HTTPError例如 4xx/5xx 状态码)
status_code = error.response.status_code
error_msg = f"服务器返回错误状态码: {status_code}"
else:
# 其他 RequestException 错误
error_msg = f"图像下载失败: {str(error)}"
# 3. 其他未知错误
else:
error_msg = f"未知错误: {str(error)}"
# 更新任务状态
task.update({
'status': 'Failed',
'reason': error_msg,
'completed_at': int(time.time())
})
# ======================
# 视频生成逻辑
# ======================
async def download_image(url: str) -> Image.Image:
"""异步下载图像(包含详细错误信息)"""
loop = asyncio.get_event_loop()
try:
response = await loop.run_in_executor(
None,
lambda: requests.get(url) # 将 timeout 传递给 requests.get
)
# 如果状态码非 200主动抛出 HTTPException
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=f"服务器返回状态码 {response.status_code}"
)
return Image.open(BytesIO(response.content)).convert("RGB")
except RequestException as e:
# 将原始 requests 错误信息抛出
raise HTTPException(
status_code=500,
detail=f"请求失败: {str(e)}"
)
async def generate_video(request: VideoSubmitRequest, task_id: str, image: Image.Image):
"""异步生成入口"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
sync_generate_video,
request,
task_id,
image
)
def sync_generate_video(request: VideoSubmitRequest, task_id: str, image: Image.Image):
"""同步生成核心"""
with app.state.model_lock:
try:
# 解析分辨率
mod_value = 16 # 模型要求的模数
print(request.image_size)
print('--------------------------------')
if request.image_size == "auto":
# 原版自动计算逻辑
aspect_ratio = image.height / image.width
print(image.height,image.width)
max_area = 399360 # 模型基础分辨率
# 计算理想尺寸
height = round(np.sqrt(max_area * aspect_ratio))
width = round(np.sqrt(max_area / aspect_ratio))
# 应用模数调整
height = height // mod_value * mod_value
width = width // mod_value * mod_value
resized_image = image.resize((width, height))
else:
width_str, height_str = request.image_size.split('x')
width = int(width_str)
height = int(height_str)
mod_value = 16
# 调整图像尺寸
resized_image = image.resize((width, height))
# 设置随机种子
generator = None
# 修改点1: 使用属性访问seed
if request.seed is not None:
generator = torch.Generator(device="cuda")
generator.manual_seed(request.seed) # 修改点2
print(f"🔮 使用随机种子: {request.seed}")
print(resized_image)
print(height,width)
# 执行推理
output = app.state.pipe(
image=resized_image,
prompt=request.prompt,
negative_prompt=request.negative_prompt,
height=height,
width=width,
num_frames=request.num_frames,
guidance_scale=request.guidance_scale,
num_inference_steps=request.infer_steps,
generator=generator
).frames[0]
# 导出视频
video_id = uuid.uuid4().hex
output_path = f"generated_videos/{video_id}.mp4"
export_to_video(output, output_path, fps=16)
return output_path
except Exception as e:
raise RuntimeError(f"视频生成失败: {str(e)}") from e
# ======================
# API端点
# ======================
@app.post("/video/submit",
response_model=dict,
status_code=status.HTTP_202_ACCEPTED,
tags=["视频生成"])
async def submit_task(
request: VideoSubmitRequest,
auth: bool = Depends(verify_auth)
):
"""提交生成任务"""
# 参数验证
if request.image_url is None:
raise HTTPException(
status_code=422,
detail={"status": "Failed", "reason": "需要图像URL参数"}
)
# 创建任务记录
task_id = uuid.uuid4().hex
with app.state.task_lock:
app.state.tasks[task_id] = {
"request": request,
"status": "InQueue",
"created_at": int(time.time())
}
app.state.pending_queue.append(task_id)
return {"requestId": task_id}
@app.post("/video/status",
response_model=VideoStatusResponse,
tags=["视频生成"])
async def get_status(
request: VideoStatusRequest,
auth: bool = Depends(verify_auth)
):
"""查询任务状态"""
task = app.state.tasks.get(request.requestId)
if not task:
raise HTTPException(
status_code=404,
detail={"status": "Failed", "reason": "无效的任务ID"}
)
# 计算队列位置(仅当在队列中时)
queue_pos = 0
if task['status'] == "InQueue" and request.requestId in app.state.pending_queue:
queue_pos = app.state.pending_queue.index(request.requestId) + 1
response = {
"status": task['status'],
"reason": task.get('reason'),
"queue_position": queue_pos if task['status'] == "InQueue" else None # 非排队状态返回null
}
# 成功状态的特殊处理
if task['status'] == "Succeed":
response["results"] = {
"videos": [{"url": task['download_url']}],
"timings": {
"inference": task['completed_at'] - task['started_at']
},
"seed": task['request'].seed
}
# 取消状态的补充信息
elif task['status'] == "Cancelled":
response["reason"] = task.get('reason', "用户主动取消") # 确保原因字段存在
return response
@app.post("/video/cancel",
response_model=dict,
tags=["视频生成"])
async def cancel_task(
request: VideoCancelRequest,
auth: bool = Depends(verify_auth)
):
"""取消排队中的生成任务"""
task_id = request.requestId
with app.state.task_lock:
task = app.state.tasks.get(task_id)
# 检查任务是否存在
if not task:
raise HTTPException(
status_code=404,
detail={"status": "Failed", "reason": "无效的任务ID"}
)
current_status = task['status']
# 仅允许取消排队中的任务
if current_status != "InQueue":
raise HTTPException(
status_code=400,
detail={"status": "Failed", "reason": f"仅允许取消排队任务,当前状态: {current_status}"}
)
# 从队列移除
try:
app.state.pending_queue.remove(task_id)
except ValueError:
pass # 可能已被处理
# 更新任务状态
task.update({
"status": "Cancelled",
"reason": "用户主动取消",
"completed_at": int(time.time())
})
return {"status": "Succeed"}
async def cleanup_files(paths: List[str], delay: int = 3600):
"""定时清理文件"""
await asyncio.sleep(delay)
for path in paths:
try:
if os.path.exists(path):
os.remove(path)
except Exception as e:
print(f"清理失败 {path}: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8088)

450
t2v-api.py Normal file
View File

@ -0,0 +1,450 @@
import os
import torch
import uuid
import time
import asyncio
from enum import Enum
from threading import Lock
from typing import Optional, Dict, List
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, status, Depends
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, field_validator, ValidationError
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse
# 创建视频存储目录
os.makedirs("generated_videos", exist_ok=True)
# 生命周期管理器
@asynccontextmanager
async def lifespan(app: FastAPI):
"""管理应用生命周期"""
# 初始化模型和资源
try:
# 初始化认证密钥
app.state.valid_api_keys = {
"密钥"
}
# 初始化视频生成模型
model_id = "./Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(
model_id,
subfolder="vae",
torch_dtype=torch.float32
)
scheduler = UniPCMultistepScheduler(
prediction_type='flow_prediction',
use_flow_sigmas=True,
num_train_timesteps=1000,
flow_shift=3.0
)
app.state.pipe = WanPipeline.from_pretrained(
model_id,
vae=vae,
torch_dtype=torch.bfloat16
).to("cuda")
app.state.pipe.scheduler = scheduler
# 初始化任务系统
app.state.tasks: Dict[str, dict] = {}
app.state.pending_queue: List[str] = []
app.state.model_lock = Lock()
app.state.task_lock = Lock()
app.state.base_url = "ip地址+端口"
app.state.max_concurrent = 2
app.state.semaphore = asyncio.Semaphore(app.state.max_concurrent)
# 启动后台任务处理器
asyncio.create_task(task_processor())
print("✅ 应用初始化完成")
yield
finally:
# 清理资源
if hasattr(app.state, 'pipe'):
del app.state.pipe
torch.cuda.empty_cache()
print("♻️ 已释放模型资源")
# 创建FastAPI应用
app = FastAPI(lifespan=lifespan)
app.mount("/videos", StaticFiles(directory="generated_videos"), name="videos")
# 认证模块
security = HTTPBearer(auto_error=False)
# ======================
# 数据模型--查询参数模型
# ======================
class VideoSubmitRequest(BaseModel):
model: str = Field(default="Wan2.1-T2V-1.3B",description="使用的模型版本")
prompt: str = Field(
...,
min_length=10,
max_length=500,
description="视频描述提示词10-500个字符"
)
image_size: str = Field(
...,
description="视频分辨率仅支持480x832或832x480"
)
negative_prompt: Optional[str] = Field(
default=None,
max_length=500,
description="排除不需要的内容"
)
seed: Optional[int] = Field(
default=None,
ge=0,
le=2147483647,
description="随机数种子范围0-2147483647"
)
num_frames: int = Field(
default=81,
ge=24,
le=120,
description="视频帧数24-120帧"
)
guidance_scale: float = Field(
default=5.0,
ge=1.0,
le=20.0,
description="引导系数1.0-20.0"
)
infer_steps: int = Field(
default=50,
ge=20,
le=100,
description="推理步数20-100步"
)
@field_validator('image_size', mode='before')
@classmethod
def validate_image_size(cls, value):
allowed_sizes = {"480x832", "832x480"}
if value not in allowed_sizes:
raise ValueError(f"仅支持以下分辨率: {', '.join(allowed_sizes)}")
return value
class VideoStatusRequest(BaseModel):
requestId: str = Field(
...,
min_length=32,
max_length=32,
description="32位任务ID"
)
class VideoSubmitResponse(BaseModel):
requestId: str
class VideoStatusResponse(BaseModel):
status: str = Field(..., description="任务状态: Succeed, InQueue, InProgress, Failed,Cancelled")
reason: Optional[str] = Field(None, description="失败原因")
results: Optional[dict] = Field(None, description="生成结果")
queue_position: Optional[int] = Field(None, description="队列位置")
class VideoCancelRequest(BaseModel):
requestId: str = Field(
...,
min_length=32,
max_length=32,
description="32位任务ID"
)
# # 自定义HTTP异常处理器
# @app.exception_handler(HTTPException)
# async def http_exception_handler(request, exc):
# return JSONResponse(
# status_code=exc.status_code,
# content=exc.detail, # 直接返回detail内容不再包装在detail字段
# headers=exc.headers
# )
# ======================
# 后台任务处理
# ======================
async def task_processor():
"""处理任务队列"""
while True:
async with app.state.semaphore:
task_id = await get_next_task()
if task_id:
await process_task(task_id)
else:
await asyncio.sleep(0.5)
async def get_next_task():
"""获取下一个待处理任务"""
with app.state.task_lock:
if app.state.pending_queue:
return app.state.pending_queue.pop(0)
return None
async def process_task(task_id: str):
"""处理单个任务"""
task = app.state.tasks.get(task_id)
if not task:
return
try:
# 更新任务状态
task['status'] = 'InProgress'
task['started_at'] = int(time.time())
# 执行视频生成
video_path = await generate_video(task['request'], task_id)
# 生成下载链接
download_url = f"{app.state.base_url}/videos/{os.path.basename(video_path)}"
# 更新任务状态
task.update({
'status': 'Succeed',
'download_url': download_url,
'completed_at': int(time.time())
})
# 安排自动清理
asyncio.create_task(auto_cleanup(video_path))
except Exception as e:
handle_task_error(task, e)
def handle_task_error(task: dict, error: Exception):
"""统一处理任务错误"""
error_msg = str(error)
if isinstance(error, torch.cuda.OutOfMemoryError):
error_msg = "显存不足,请降低分辨率或减少帧数"
elif isinstance(error, ValidationError):
error_msg = "参数校验失败: " + str(error)
task.update({
'status': 'Failed',
'reason': error_msg,
'completed_at': int(time.time())
})
# ======================
# 视频生成核心逻辑
# ======================
async def generate_video(request: dict, task_id: str) -> str:
"""异步执行视频生成"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
sync_generate_video,
request,
task_id
)
def sync_generate_video(request: dict, task_id: str) -> str:
"""同步生成视频"""
with app.state.model_lock:
try:
generator = None
if request.get('seed') is not None:
generator = torch.Generator(device="cuda")
generator.manual_seed(request['seed'])
print(f"🔮 使用随机种子: {request['seed']}")
# 执行模型推理
result = app.state.pipe(
prompt=request['prompt'],
negative_prompt=request['negative_prompt'],
height=request['height'],
width=request['width'],
num_frames=request['num_frames'],
guidance_scale=request['guidance_scale'],
num_inference_steps=request['infer_steps'],
generator=generator
)
# 导出视频文件
video_id = uuid.uuid4().hex
output_path = f"generated_videos/{video_id}.mp4"
export_to_video(result.frames[0], output_path, fps=16)
return output_path
except Exception as e:
raise RuntimeError(f"视频生成失败: {str(e)}") from e
# ======================
# API端点
# ======================
async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""认证验证"""
if not credentials:
raise HTTPException(
status_code=401,
detail={"status": "Failed", "reason": "缺少认证头"},
headers={"WWW-Authenticate": "Bearer"}
)
if credentials.scheme != "Bearer":
raise HTTPException(
status_code=401,
detail={"status": "Failed", "reason": "无效的认证方案"},
headers={"WWW-Authenticate": "Bearer"}
)
if credentials.credentials not in app.state.valid_api_keys:
raise HTTPException(
status_code=401,
detail={"status": "Failed", "reason": "无效的API密钥"},
headers={"WWW-Authenticate": "Bearer"}
)
return True
@app.post("/video/submit",
response_model=VideoSubmitResponse,
status_code=status.HTTP_202_ACCEPTED,
tags=["视频生成"],
summary="提交视频生成请求")
async def submit_video_task(
request: VideoSubmitRequest,
auth: bool = Depends(verify_auth)
):
"""提交新的视频生成任务"""
try:
# 解析分辨率参数
width, height = map(int, request.image_size.split('x'))
# 创建任务记录
task_id = uuid.uuid4().hex
task_data = {
'request': {
'prompt': request.prompt,
'negative_prompt': request.negative_prompt,
'width': width,
'height': height,
'num_frames': request.num_frames,
'guidance_scale': request.guidance_scale,
'infer_steps': request.infer_steps,
'seed': request.seed
},
'status': 'InQueue',
'created_at': int(time.time())
}
# 加入任务队列
with app.state.task_lock:
app.state.tasks[task_id] = task_data
app.state.pending_queue.append(task_id)
return {"requestId": task_id}
except ValidationError as e:
raise HTTPException(
status_code=422,
detail={"status": "Failed", "reason": str(e)}
)
@app.post("/video/status",
response_model=VideoStatusResponse,
tags=["视频生成"],
summary="查询任务状态")
async def get_video_status(
request: VideoStatusRequest,
auth: bool = Depends(verify_auth)
):
"""查询任务状态"""
task = app.state.tasks.get(request.requestId)
if not task:
raise HTTPException(
status_code=404,
detail={"status": "Failed", "reason": "无效的任务ID"}
)
# 计算队列位置(仅当在队列中时)
queue_pos = 0
if task['status'] == "InQueue" and request.requestId in app.state.pending_queue:
queue_pos = app.state.pending_queue.index(request.requestId) + 1
response = {
"status": task['status'],
"reason": task.get('reason'),
"queue_position": queue_pos if task['status'] == "InQueue" else None # 非排队状态返回null
}
# 成功状态的特殊处理
if task['status'] == "Succeed":
response["results"] = {
"videos": [{"url": task['download_url']}],
"timings": {
"inference": task['completed_at'] - task['started_at']
},
"seed": task['request']['seed']
}
# 取消状态的补充信息
elif task['status'] == "Cancelled":
response["reason"] = task.get('reason', "用户主动取消") # 确保原因字段存在
return response
@app.post("/video/cancel",
response_model=dict,
tags=["视频生成"])
async def cancel_task(
request: VideoCancelRequest,
auth: bool = Depends(verify_auth)
):
"""取消排队中的生成任务"""
task_id = request.requestId
with app.state.task_lock:
task = app.state.tasks.get(task_id)
# 检查任务是否存在
if not task:
raise HTTPException(
status_code=404,
detail={"status": "Failed", "reason": "无效的任务ID"}
)
current_status = task['status']
# 仅允许取消排队中的任务
if current_status != "InQueue":
raise HTTPException(
status_code=400,
detail={"status": "Failed", "reason": f"仅允许取消排队任务,当前状态: {current_status}"}
)
# 从队列移除
try:
app.state.pending_queue.remove(task_id)
except ValueError:
pass # 可能已被处理
# 更新任务状态
task.update({
"status": "Cancelled",
"reason": "用户主动取消",
"completed_at": int(time.time())
})
return {"status": "Succeed"}
# ======================
# 工具函数
# ======================
async def auto_cleanup(file_path: str, delay: int = 3600):
"""自动清理生成的视频文件"""
await asyncio.sleep(delay)
try:
if os.path.exists(file_path):
os.remove(file_path)
print(f"已清理文件: {file_path}")
except Exception as e:
print(f"文件清理失败: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8088)

View File

@ -105,9 +105,16 @@ function i2v_14B_720p() {
torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
}
function vace_1_3B() {
VACE_1_3B_CKPT_DIR="$MODEL_DIR/VACE-Wan2.1-1.3B-Preview/"
torchrun --nproc_per_node=$GPUS $PY_FILE --ulysses_size $GPUS --task vace-1.3B --size 480*832 --ckpt_dir $VACE_1_3B_CKPT_DIR
}
t2i_14B
t2v_1_3B
t2v_14B
i2v_14B_480p
i2v_14B_720p
vace_1_3B

View File

@ -2,3 +2,4 @@ from . import configs, distributed, modules
from .image2video import WanI2V
from .text2video import WanT2V
from .first_last_frame2video import WanFLF2V
from .vace import WanVace, WanVaceMP

View File

@ -22,7 +22,9 @@ WAN_CONFIGS = {
't2v-1.3B': t2v_1_3B,
'i2v-14B': i2v_14B,
't2i-14B': t2i_14B,
'flf2v-14B': flf2v_14B
'flf2v-14B': flf2v_14B,
'vace-1.3B': t2v_1_3B,
'vace-14B': t2v_14B,
}
SIZE_CONFIGS = {
@ -46,4 +48,6 @@ SUPPORTED_SIZES = {
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
't2i-14B': tuple(SIZE_CONFIGS.keys()),
'vace-1.3B': ('480*832', '832*480'),
'vace-14B': ('720*1280', '1280*720', '480*832', '832*480')
}

View File

@ -63,12 +63,45 @@ def rope_apply(x, grid_sizes, freqs):
return torch.stack(output).float()
def usp_dit_forward_vace(
self,
x,
vace_context,
seq_len,
kwargs
):
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in c
])
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
# Context Parallel
c = torch.chunk(
c, get_sequence_parallel_world_size(),
dim=1)[get_sequence_parallel_rank()]
hints = []
for block in self.vace_blocks:
c, c_skip = block(c, **new_kwargs)
hints.append(c_skip)
return hints
def usp_dit_forward(
self,
x,
t,
context,
seq_len,
vace_context=None,
vace_context_scale=1.0,
clip_fea=None,
y=None,
):
@ -84,7 +117,7 @@ def usp_dit_forward(
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
if self.model_type != 'vace' and y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
@ -114,7 +147,7 @@ def usp_dit_forward(
for u in context
]))
if clip_fea is not None:
if self.model_type != 'vace' and clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
@ -132,6 +165,11 @@ def usp_dit_forward(
x, get_sequence_parallel_world_size(),
dim=1)[get_sequence_parallel_rank()]
if self.model_type == 'vace':
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
kwargs['hints'] = hints
kwargs['context_scale'] = vace_context_scale
for block in self.blocks:
x = block(x, **kwargs)

View File

@ -2,11 +2,13 @@ from .attention import flash_attention
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vace_model import VaceWanModel
from .vae import WanVAE
__all__ = [
'WanVAE',
'WanModel',
'VaceWanModel',
'T5Model',
'T5Encoder',
'T5Decoder',

View File

@ -400,7 +400,7 @@ class WanModel(ModelMixin, ConfigMixin):
Args:
model_type (`str`, *optional*, defaults to 't2v'):
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
text_len (`int`, *optional*, defaults to 512):
@ -433,7 +433,7 @@ class WanModel(ModelMixin, ConfigMixin):
super().__init__()
assert model_type in ['t2v', 'i2v', 'flf2v']
assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
self.model_type = model_type
self.patch_size = patch_size

233
wan/modules/vace_model.py Normal file
View File

@ -0,0 +1,233 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config
from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
class VaceWanAttentionBlock(WanAttentionBlock):
def __init__(
self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=0
):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
self.block_id = block_id
if block_id == 0:
self.before_proj = nn.Linear(self.dim, self.dim)
nn.init.zeros_(self.before_proj.weight)
nn.init.zeros_(self.before_proj.bias)
self.after_proj = nn.Linear(self.dim, self.dim)
nn.init.zeros_(self.after_proj.weight)
nn.init.zeros_(self.after_proj.bias)
def forward(self, c, x, **kwargs):
if self.block_id == 0:
c = self.before_proj(c) + x
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
return c, c_skip
class BaseWanAttentionBlock(WanAttentionBlock):
def __init__(
self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6,
block_id=None
):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
self.block_id = block_id
def forward(self, x, hints, context_scale=1.0, **kwargs):
x = super().forward(x, **kwargs)
if self.block_id is not None:
x = x + hints[self.block_id] * context_scale
return x
class VaceWanModel(WanModel):
@register_to_config
def __init__(self,
vace_layers=None,
vace_in_dim=None,
model_type='vace',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6):
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
assert 0 in self.vace_layers
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
# blocks
self.blocks = nn.ModuleList([
BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
self.cross_attn_norm, self.eps,
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
for i in range(self.num_layers)
])
# vace blocks
self.vace_blocks = nn.ModuleList([
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
self.cross_attn_norm, self.eps, block_id=i)
for i in self.vace_layers
])
# vace patch embeddings
self.vace_patch_embedding = nn.Conv3d(
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
)
def forward_vace(
self,
x,
vace_context,
seq_len,
kwargs
):
# embeddings
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
c = [u.flatten(2).transpose(1, 2) for u in c]
c = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in c
])
# arguments
new_kwargs = dict(x=x)
new_kwargs.update(kwargs)
hints = []
for block in self.vace_blocks:
c, c_skip = block(c, **new_kwargs)
hints.append(c_skip)
return hints
def forward(
self,
x,
t,
vace_context,
context,
seq_len,
vace_context_scale=1.0,
clip_fea=None,
y=None,
):
r"""
Forward pass through the diffusion model
Args:
x (List[Tensor]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
CLIP image features for image-to-video mode
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
# if self.model_type == 'i2v':
# assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
# if y is not None:
# x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
# if clip_fea is not None:
# context_clip = self.img_emb(clip_fea) # bs x 257 x dim
# context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
kwargs['hints'] = hints
kwargs['context_scale'] = vace_context_scale
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]

View File

@ -1,8 +1,10 @@
from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
retrieve_timesteps)
from .fm_solvers_unipc import FlowUniPCMultistepScheduler
from .vace_processor import VaceVideoProcessor
__all__ = [
'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
'VaceVideoProcessor'
]

270
wan/utils/vace_processor.py Normal file
View File

@ -0,0 +1,270 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
class VaceImageProcessor(object):
def __init__(self, downsample=None, seq_len=None):
self.downsample = downsample
self.seq_len = seq_len
def _pillow_convert(self, image, cvt_type='RGB'):
if image.mode != cvt_type:
if image.mode == 'P':
image = image.convert(f'{cvt_type}A')
if image.mode == f'{cvt_type}A':
bg = Image.new(cvt_type,
size=(image.width, image.height),
color=(255, 255, 255))
bg.paste(image, (0, 0), mask=image)
image = bg
else:
image = image.convert(cvt_type)
return image
def _load_image(self, img_path):
if img_path is None or img_path == '':
return None
img = Image.open(img_path)
img = self._pillow_convert(img)
return img
def _resize_crop(self, img, oh, ow, normalize=True):
"""
Resize, center crop, convert to tensor, and normalize.
"""
# resize and crop
iw, ih = img.size
if iw != ow or ih != oh:
# resize
scale = max(ow / iw, oh / ih)
img = img.resize(
(round(scale * iw), round(scale * ih)),
resample=Image.Resampling.LANCZOS
)
assert img.width >= ow and img.height >= oh
# center crop
x1 = (img.width - ow) // 2
y1 = (img.height - oh) // 2
img = img.crop((x1, y1, x1 + ow, y1 + oh))
# normalize
if normalize:
img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
return img
def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
return self._resize_crop(img, oh, ow, normalize)
def load_image(self, data_key, **kwargs):
return self.load_image_batch(data_key, **kwargs)
def load_image_pair(self, data_key, data_key2, **kwargs):
return self.load_image_batch(data_key, data_key2, **kwargs)
def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
seq_len = self.seq_len if seq_len is None else seq_len
imgs = []
for data_key in data_key_batch:
img = self._load_image(data_key)
imgs.append(img)
w, h = imgs[0].size
dh, dw = self.downsample[1:]
# compute output size
scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
oh = int(h * scale) // dh * dh
ow = int(w * scale) // dw * dw
assert (oh // dh) * (ow // dw) <= seq_len
imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
return *imgs, (oh, ow)
class VaceVideoProcessor(object):
def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
self.downsample = downsample
self.min_area = min_area
self.max_area = max_area
self.min_fps = min_fps
self.max_fps = max_fps
self.zero_start = zero_start
self.keep_last = keep_last
self.seq_len = seq_len
assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
def set_area(self, area):
self.min_area = area
self.max_area = area
def set_seq_len(self, seq_len):
self.seq_len = seq_len
@staticmethod
def resize_crop(video: torch.Tensor, oh: int, ow: int):
"""
Resize, center crop and normalize for decord loaded video (torch.Tensor type)
Parameters:
video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
oh - target height (int)
ow - target width (int)
Returns:
The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
Raises:
"""
# permute ([t, h, w, c] -> [t, c, h, w])
video = video.permute(0, 3, 1, 2)
# resize and crop
ih, iw = video.shape[2:]
if ih != oh or iw != ow:
# resize
scale = max(ow / iw, oh / ih)
video = F.interpolate(
video,
size=(round(scale * ih), round(scale * iw)),
mode='bicubic',
antialias=True
)
assert video.size(3) >= ow and video.size(2) >= oh
# center crop
x1 = (video.size(3) - ow) // 2
y1 = (video.size(2) - oh) // 2
video = video[:, :, y1:y1 + oh, x1:x1 + ow]
# permute ([t, c, h, w] -> [c, t, h, w]) and normalize
video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
return video
def _video_preprocess(self, video, oh, ow):
return self.resize_crop(video, oh, ow)
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
target_fps = min(fps, self.max_fps)
duration = frame_timestamps[-1].mean()
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
of = min(
(int(duration * target_fps) - 1) // df + 1,
int(self.seq_len / area_z)
)
# deduce target shape of the [latent video]
target_area_z = min(area_z, int(self.seq_len / of))
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
# sample frame ids
target_duration = of / target_fps
begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
timestamps = np.linspace(begin, begin + target_duration, of)
frame_ids = np.argmax(np.logical_and(
timestamps[:, None] >= frame_timestamps[None, :, 0],
timestamps[:, None] < frame_timestamps[None, :, 1]
), axis=1).tolist()
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
duration = frame_timestamps[-1].mean()
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
of = min(
(len(frame_timestamps) - 1) // df + 1,
int(self.seq_len / area_z)
)
# deduce target shape of the [latent video]
target_area_z = min(area_z, int(self.seq_len / of))
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
# sample frame ids
target_duration = duration
target_fps = of / target_duration
timestamps = np.linspace(0., target_duration, of)
frame_ids = np.argmax(np.logical_and(
timestamps[:, None] >= frame_timestamps[None, :, 0],
timestamps[:, None] <= frame_timestamps[None, :, 1]
), axis=1).tolist()
# print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
if self.keep_last:
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
else:
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
# read video
import decord
decord.bridge.set_bridge('torch')
readers = []
for data_k in data_key_batch:
reader = decord.VideoReader(data_k)
readers.append(reader)
fps = readers[0].get_avg_fps()
length = min([len(r) for r in readers])
frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)
# preprocess video
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
videos = [self._video_preprocess(video, oh, ow) for video in videos]
return *videos, frame_ids, (oh, ow), fps
# return videos if len(videos) > 1 else videos[0]
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_video is None and sub_src_mask is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
for j, ref_img in enumerate(ref_images):
if ref_img is not None and ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
src_ref_images[i][j] = white_canvas
return src_video, src_mask, src_ref_images

718
wan/vace.py Normal file
View File

@ -0,0 +1,718 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import sys
import gc
import math
import time
import random
import types
import logging
import traceback
from contextlib import contextmanager
from functools import partial
from PIL import Image
import torchvision.transforms.functional as TF
import torch
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.multiprocessing as mp
from tqdm import tqdm
from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
from .modules.vace_model import VaceWanModel
from .utils.vace_processor import VaceVideoProcessor
class WanVace(WanT2V):
def __init__(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
):
r"""
Initializes the Wan text-to-video generation model components.
Args:
config (EasyDict):
Object containing model parameters initialized from config.py
checkpoint_dir (`str`):
Path to directory containing model checkpoints
device_id (`int`, *optional*, defaults to 0):
Id of target GPU device
rank (`int`, *optional*, defaults to 0):
Process rank for distributed training
t5_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for T5 model
dit_fsdp (`bool`, *optional*, defaults to False):
Enable FSDP sharding for DiT model
use_usp (`bool`, *optional*, defaults to False):
Enable distribution strategy of USP.
t5_cpu (`bool`, *optional*, defaults to False):
Whether to place T5 model on CPU. Only works without t5_fsdp.
"""
self.device = torch.device(f"cuda:{device_id}")
self.config = config
self.rank = rank
self.t5_cpu = t5_cpu
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
shard_fn = partial(shard_model, device_id=device_id)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
shard_fn=shard_fn if t5_fsdp else None)
self.vae_stride = config.vae_stride
self.patch_size = config.patch_size
self.vae = WanVAE(
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
device=self.device)
logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
self.model = VaceWanModel.from_pretrained(checkpoint_dir)
self.model.eval().requires_grad_(False)
if use_usp:
from xfuser.core.distributed import \
get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace)
for block in self.model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
for block in self.model.vace_blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
self.model.forward = types.MethodType(usp_dit_forward, self.model)
self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model)
self.sp_size = get_sequence_parallel_world_size()
else:
self.sp_size = 1
if dist.is_initialized():
dist.barrier()
if dit_fsdp:
self.model = shard_fn(self.model)
else:
self.model.to(self.device)
self.sample_neg_prompt = config.sample_neg_prompt
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
min_area=720*1280,
max_area=720*1280,
min_fps=config.sample_fps,
max_fps=config.sample_fps,
zero_start=True,
seq_len=75600,
keep_last=True)
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
vae = self.vae if vae is None else vae
if ref_images is None:
ref_images = [None] * len(frames)
else:
assert len(frames) == len(ref_images)
if masks is None:
latents = vae.encode(frames)
else:
masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = vae.encode(inactive)
reactive = vae.encode(reactive)
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
cat_latents = []
for latent, refs in zip(latents, ref_images):
if refs is not None:
if masks is None:
ref_latent = vae.encode(refs)
else:
ref_latent = vae.encode(refs)
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
assert all([x.shape[1] == 1 for x in ref_latent])
latent = torch.cat([*ref_latent, latent], dim=1)
cat_latents.append(latent)
return cat_latents
def vace_encode_masks(self, masks, ref_images=None, vae_stride=None):
vae_stride = self.vae_stride if vae_stride is None else vae_stride
if ref_images is None:
ref_images = [None] * len(masks)
else:
assert len(masks) == len(ref_images)
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
new_depth = int((depth + 3) // vae_stride[0])
height = 2 * (int(height) // (vae_stride[1] * 2))
width = 2 * (int(width) // (vae_stride[2] * 2))
# reshape
mask = mask[0, :, :, :]
mask = mask.view(
depth, height, vae_stride[1], width, vae_stride[1]
) # depth, height, 8, width, 8
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
mask = mask.reshape(
vae_stride[1] * vae_stride[2], depth, height, width
) # 8*8, depth, height, width
# interpolation
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
if refs is not None:
length = len(refs)
mask_pad = torch.zeros_like(mask[:, :length, :, :])
mask = torch.cat((mask_pad, mask), dim=1)
result_masks.append(mask)
return result_masks
def vace_latent(self, z, m):
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device):
area = image_size[0] * image_size[1]
self.vid_proc.set_area(area)
if area == 720*1280:
self.vid_proc.set_seq_len(75600)
elif area == 480*832:
self.vid_proc.set_seq_len(32760)
else:
raise NotImplementedError(f'image_size {image_size} is not supported')
image_size = (image_size[1], image_size[0])
image_sizes = []
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask)
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video)
src_video[i] = src_video[i].to(device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(src_video[i].shape[2:])
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None:
ref_img = Image.open(ref_img).convert("RGB")
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
ref_img = white_canvas
src_ref_images[i][j] = ref_img.to(device)
return src_video, src_mask, src_ref_images
def decode_latent(self, zs, ref_images=None, vae=None):
vae = self.vae if vae is None else vae
if ref_images is None:
ref_images = [None] * len(zs)
else:
assert len(zs) == len(ref_images)
trimed_zs = []
for z, refs in zip(zs, ref_images):
if refs is not None:
z = z[:, len(refs):, :, :]
trimed_zs.append(z)
return vae.decode(trimed_zs)
def generate(self,
input_prompt,
input_frames,
input_masks,
input_ref_images,
size=(1280, 720),
frame_num=81,
context_scale=1.0,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (tupele[`int`], *optional*, defaults to (1280,720)):
Controls video resolution, (width,height).
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 40):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from size)
- W: Frame width from size)
"""
# preprocess
# F = frame_num
# target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
# size[1] // self.vae_stride[1],
# size[0] // self.vae_stride[2])
#
# seq_len = math.ceil((target_shape[2] * target_shape[3]) /
# (self.patch_size[1] * self.patch_size[2]) *
# target_shape[1] / self.sp_size) * self.sp_size
if n_prompt == "":
n_prompt = self.sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
# vace context encode
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks)
m0 = self.vace_encode_masks(input_masks, input_ref_images)
z = self.vace_latent(z0, m0)
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
noise = [
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=self.device,
generator=seed_g)
]
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[0]
noise_pred_uncond = self.model(
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,**arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
x0 = latents
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.decode_latent(x0, input_ref_images)
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
class WanVaceMP(WanVace):
def __init__(
self,
config,
checkpoint_dir,
use_usp=False,
ulysses_size=None,
ring_size=None
):
self.config = config
self.checkpoint_dir = checkpoint_dir
self.use_usp = use_usp
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
self.in_q_list = None
self.out_q = None
self.inference_pids = None
self.ulysses_size = ulysses_size
self.ring_size = ring_size
self.dynamic_load()
self.device = 'cpu' if torch.cuda.is_available() else 'cpu'
self.vid_proc = VaceVideoProcessor(
downsample=tuple([x * y for x, y in zip(config.vae_stride, config.patch_size)]),
min_area=480 * 832,
max_area=480 * 832,
min_fps=self.config.sample_fps,
max_fps=self.config.sample_fps,
zero_start=True,
seq_len=32760,
keep_last=True)
def dynamic_load(self):
if hasattr(self, 'inference_pids') and self.inference_pids is not None:
return
gpu_infer = os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count()
pmi_rank = int(os.environ['RANK'])
pmi_world_size = int(os.environ['WORLD_SIZE'])
in_q_list = [torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)]
out_q = torch.multiprocessing.Manager().Queue()
initialized_events = [torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)]
context = mp.spawn(self.mp_worker, nprocs=gpu_infer, args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, self), join=False)
all_initialized = False
while not all_initialized:
all_initialized = all(event.is_set() for event in initialized_events)
if not all_initialized:
time.sleep(0.1)
print('Inference model is initialized', flush=True)
self.in_q_list = in_q_list
self.out_q = out_q
self.inference_pids = context.pids()
self.initialized_events = initialized_events
def transfer_data_to_cuda(self, data, device):
if data is None:
return None
else:
if isinstance(data, torch.Tensor):
data = data.to(device)
elif isinstance(data, list):
data = [self.transfer_data_to_cuda(subdata, device) for subdata in data]
elif isinstance(data, dict):
data = {key: self.transfer_data_to_cuda(val, device) for key, val in data.items()}
return data
def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q, initialized_events, work_env):
try:
world_size = pmi_world_size * gpu_infer
rank = pmi_rank * gpu_infer + gpu
print("world_size", world_size, "rank", rank, flush=True)
torch.cuda.set_device(gpu)
dist.init_process_group(
backend='nccl',
init_method='env://',
rank=rank,
world_size=world_size
)
from xfuser.core.distributed import (initialize_model_parallel,
init_distributed_environment)
init_distributed_environment(
rank=dist.get_rank(), world_size=dist.get_world_size())
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
ring_degree=self.ring_size or 1,
ulysses_degree=self.ulysses_size or 1
)
num_train_timesteps = self.config.num_train_timesteps
param_dtype = self.config.param_dtype
shard_fn = partial(shard_model, device_id=gpu)
text_encoder = T5EncoderModel(
text_len=self.config.text_len,
dtype=self.config.t5_dtype,
device=torch.device('cpu'),
checkpoint_path=os.path.join(self.checkpoint_dir, self.config.t5_checkpoint),
tokenizer_path=os.path.join(self.checkpoint_dir, self.config.t5_tokenizer),
shard_fn=shard_fn if True else None)
text_encoder.model.to(gpu)
vae_stride = self.config.vae_stride
patch_size = self.config.patch_size
vae = WanVAE(
vae_pth=os.path.join(self.checkpoint_dir, self.config.vae_checkpoint),
device=gpu)
logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
model = VaceWanModel.from_pretrained(self.checkpoint_dir)
model.eval().requires_grad_(False)
if self.use_usp:
from xfuser.core.distributed import get_sequence_parallel_world_size
from .distributed.xdit_context_parallel import (usp_attn_forward,
usp_dit_forward,
usp_dit_forward_vace)
for block in model.blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
for block in model.vace_blocks:
block.self_attn.forward = types.MethodType(
usp_attn_forward, block.self_attn)
model.forward = types.MethodType(usp_dit_forward, model)
model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
sp_size = get_sequence_parallel_world_size()
else:
sp_size = 1
dist.barrier()
model = shard_fn(model)
sample_neg_prompt = self.config.sample_neg_prompt
torch.cuda.empty_cache()
event = initialized_events[gpu]
in_q = in_q_list[gpu]
event.set()
while True:
item = in_q.get()
input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale, \
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model = item
input_frames = self.transfer_data_to_cuda(input_frames, gpu)
input_masks = self.transfer_data_to_cuda(input_masks, gpu)
input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu)
if n_prompt == "":
n_prompt = sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=gpu)
seed_g.manual_seed(seed)
context = text_encoder([input_prompt], gpu)
context_null = text_encoder([n_prompt], gpu)
# vace context encode
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, vae=vae)
m0 = self.vace_encode_masks(input_masks, input_ref_images, vae_stride=vae_stride)
z = self.vace_latent(z0, m0)
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
noise = [
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=gpu,
generator=seed_g)
]
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(patch_size[1] * patch_size[2]) *
target_shape[1] / sp_size) * sp_size
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=gpu, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=gpu,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
model.to(gpu)
noise_pred_cond = model(
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale, **arg_c)[
0]
noise_pred_uncond = model(
latent_model_input, t=timestep, vace_context=z, vace_context_scale=context_scale,
**arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
torch.cuda.empty_cache()
x0 = latents
if rank == 0:
videos = self.decode_latent(x0, input_ref_images, vae=vae)
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
if rank == 0:
out_q.put(videos[0].cpu())
except Exception as e:
trace_info = traceback.format_exc()
print(trace_info, flush=True)
print(e, flush=True)
def generate(self,
input_prompt,
input_frames,
input_masks,
input_ref_images,
size=(1280, 720),
frame_num=81,
context_scale=1.0,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
input_data = (input_prompt, input_frames, input_masks, input_ref_images, size, frame_num, context_scale,
shift, sample_solver, sampling_steps, guide_scale, n_prompt, seed, offload_model)
for in_q in self.in_q_list:
in_q.put(input_data)
value_output = self.out_q.get()
return value_output