mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-20 14:12:04 +00:00

Compare commits: 50973b0c3b ... 55fdfe155d

5 commits: 55fdfe155d, c709fcf0e7, 18d53feb7a, c6c5675a06, 0961b7b888

I2V-FastAPI文档.md (new file, 151 lines)

# Image-to-Video Generation Service API Documentation

## I. Feature Overview

Image-to-video generation based on the Wan2.1-I2V-14B-480P model. Core features:

1. **Asynchronous task queue**: multiple tasks can be queued, with concurrency control (at most 2 parallel tasks)
2. **Smart resolution adaptation**:
   - Automatically computes the best resolution while preserving the input image's aspect ratio
   - Also accepts a manually specified resolution (480x832 / 832x480)
3. **Resource management**:
   - VRAM optimization (bfloat16 precision)
   - Automatic cleanup of generated files (after 1 hour by default)
4. **Authentication**: Bearer Token validation based on an API key
5. **Task control**: task submission, status query, and cancellation

Tech stack:

- FastAPI framework
- CUDA acceleration
- Asynchronous task processing
- Diffusers inference library

---

## II. API Endpoints

### 1. Submit a generation task

**POST /video/submit**

```json
{
  "model": "Wan2.1-I2V-14B-480P",
  "prompt": "A dancing cat in the style of Van Gogh",
  "image_url": "https://example.com/input.jpg",
  "image_size": "auto",
  "num_frames": 81,
  "guidance_scale": 3.0,
  "infer_steps": 30
}
```

**Example response**:

```json
{
  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```

### 2. Query task status

**POST /video/status**

```json
{
  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```

**Example response**:

```json
{
  "status": "Succeed",
  "results": {
    "videos": [{"url": "http://localhost:8088/videos/abcd1234.mp4"}],
    "timings": {"inference": 90},
    "seed": 123456
  }
}
```

### 3. Cancel a task

**POST /video/cancel**

```json
{
  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```

**Example response**:

```json
{
  "status": "Succeed"
}
```

---

## III. Postman Guide

### 1. Basic configuration

- Server address: `http://<server-ip>:8088`
- Authentication: Bearer Token
- Token value: replace with a valid API key

### 2. Submitting a task

1. Choose the POST method and set the URL to `/video/submit`
2. Add the headers:

```text
Authorization: Bearer YOUR_API_KEY
Content-Type: application/json
```

3. Example body (image-to-video):

```json
{
  "prompt": "Sunset scene with mountains",
  "image_url": "https://example.com/mountain.jpg",
  "image_size": "auto",
  "num_frames": 50
}
```

### 3. Special cases

- **Image download failure**: returns a 400 error with the specific reason (e.g. invalid URL, timeout)
- **Out of VRAM**: returns a 500 error and suggests lowering the resolution
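
For quick testing outside Postman, a minimal Python client sketch covering the submit-and-poll flow is shown below. It assumes the service is reachable at `http://localhost:8088` and that `YOUR_API_KEY` has been added to the service's key set; both are placeholders to adjust for your deployment.

```python
import time
import requests

BASE_URL = "http://localhost:8088"                    # assumed server address
HEADERS = {"Authorization": "Bearer YOUR_API_KEY"}    # replace with a valid API key

# Submit an image-to-video task.
payload = {
    "prompt": "Sunset scene with mountains",
    "image_url": "https://example.com/mountain.jpg",
    "image_size": "auto",
    "num_frames": 81,
}
resp = requests.post(f"{BASE_URL}/video/submit", json=payload, headers=HEADERS)
resp.raise_for_status()
request_id = resp.json()["requestId"]

# Poll /video/status until the task reaches a terminal state.
while True:
    status = requests.post(f"{BASE_URL}/video/status",
                           json={"requestId": request_id},
                           headers=HEADERS).json()
    if status["status"] in ("Succeed", "Failed", "Cancelled"):
        break
    time.sleep(5)

if status["status"] == "Succeed":
    print("video url:", status["results"]["videos"][0]["url"])
else:
    print("task ended:", status["status"], status.get("reason"))
```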
---

## IV. Parameter Reference

| Parameter        | Allowed values                 | Required | Description                                   |
|------------------|--------------------------------|----------|-----------------------------------------------|
| image_url        | Valid HTTP/HTTPS URL           | Yes      | Input image address                           |
| prompt           | 10-500 characters              | Yes      | Video content description                     |
| image_size       | "480x832", "832x480", "auto"   | Yes      | "auto" adapts to the input image aspect ratio |
| num_frames       | 24-120                         | Yes      | Total number of frames                        |
| guidance_scale   | 1.0-20.0                       | Yes      | Text guidance strength                        |
| infer_steps      | 20-100                         | Yes      | Number of inference steps                     |
| seed             | 0-2147483647                   | No       | Random seed                                   |

---

## V. Status Codes

| Code | Meaning                                         |
|------|-------------------------------------------------|
| 202  | Task accepted                                   |
| 400  | Image download failed / invalid parameters      |
| 401  | Authentication failed                           |
| 404  | Task not found                                  |
| 422  | Parameter validation failed                     |
| 500  | Server error (out of VRAM, model failure, etc.) |
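
For orientation, authentication and lookup failures are raised inside the service as FastAPI `HTTPException`s whose `detail` carries a status and reason (see `i2v_api.py`). With FastAPI's default exception handler, the body for an unknown task ID therefore looks roughly like the following; the envelope changes if a custom handler is installed, as in the commented-out handler in `t2v-api.py`:

```json
{
  "detail": {
    "status": "Failed",
    "reason": "Invalid task ID"
  }
}
```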

---

## VI. Special Features

1. **Smart resolution adaptation**:
   - With `image_size="auto"`, the service computes the best resolution that satisfies the model's constraints
   - The original aspect ratio is preserved and the pixel area is capped at 399,360 (about 640x624)

2. **Image preprocessing**:
   - The input is converted to RGB
   - The image is rescaled proportionally to the target resolution
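
To make the `auto` rule concrete, the snippet below reproduces the same arithmetic the service uses (see `sync_generate_video` in `i2v_api.py`): preserve the aspect ratio, cap the area at 399,360 pixels, and snap both sides down to a multiple of 16.

```python
import math

def auto_size(width, height, max_area=399_360, mod=16):
    """Output size for a width x height input under the 'auto' rule."""
    aspect = height / width
    out_h = int(round(math.sqrt(max_area * aspect))) // mod * mod
    out_w = int(round(math.sqrt(max_area / aspect))) // mod * mod
    return out_w, out_h

print(auto_size(1024, 768))  # (720, 544): 391,680 pixels, close to the 0.75 input ratio
print(auto_size(768, 1024))  # (544, 720)
```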

**Important**: the input image URL must be publicly accessible; for private resources, provide valid credentials.

**Tip**: the interactive API documentation at `http://<server-address>:8088/docs` supports online testing of every endpoint.

README.md (80 changed lines)

@@ -27,6 +27,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 ## 🔥 Latest News!!
 
+* May 14, 2025: 👋 We introduce **Wan2.1** [VACE](https://github.com/ali-vilab/VACE), an all-in-one model for video creation and editing, along with its [inference code](#run-vace), [weights](#model-download), and [technical report](https://arxiv.org/abs/2503.07598)!
 * Apr 17, 2025: 👋 We introduce **Wan2.1** [FLF2V](#run-first-last-frame-to-video-generation) with its inference code and weights!
 * Mar 21, 2025: 👋 We are excited to announce the release of the **Wan2.1** [technical report](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf). We welcome discussions and feedback!
 * Mar 3, 2025: 👋 **Wan2.1**'s T2V and I2V have been integrated into Diffusers ([T2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanPipeline) | [I2V](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan#diffusers.WanImageToVideoPipeline)). Feel free to give it a try!

@@ -64,7 +65,13 @@ If your work has improved **Wan2.1** and you would like more people to see it, p
     - [ ] ComfyUI integration
     - [ ] Diffusers integration
     - [ ] Diffusers + Multi-GPU Inference
+- Wan2.1 VACE
+    - [x] Multi-GPU Inference code of the 14B and 1.3B models
+    - [x] Checkpoints of the 14B and 1.3B models
+    - [x] Gradio demo
+    - [x] ComfyUI integration
+    - [ ] Diffusers integration
+    - [ ] Diffusers + Multi-GPU Inference
 
 ## Quickstart

@@ -85,12 +92,14 @@ pip install -r requirements.txt
 #### Model Download
 
 | Models       | Download Link                                                                                                                                         | Notes                         |
-|--------------|-------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
+|--------------|---------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------|
 | T2V-14B      | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)              | Supports both 480P and 720P
 | I2V-14B-720P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)    | Supports 720P
 | I2V-14B-480P | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)    | Supports 480P
 | T2V-1.3B     | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)            | Supports 480P
 | FLF2V-14B    | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)| Supports 720P
+| VACE-1.3B    | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)          | Supports 480P
+| VACE-14B     | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) 🤖 [ModelScope](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)            | Supports both 480P and 720P
 
 > 💡Note:
 > * The 1.3B model is capable of generating videos at 720P resolution. However, due to limited training at this resolution, the results are generally less stable compared to 480P. For optimal performance, we recommend using 480P resolution.

@@ -448,6 +457,73 @@ DASH_API_KEY=your_key python flf2v_14B_singleGPU.py --prompt_extend_method 'dash
 ```
 
+
+#### Run VACE
+
+[VACE](https://github.com/ali-vilab/VACE) now supports two models (1.3B and 14B) and two main resolutions (480P and 720P).
+The input supports any resolution, but to achieve optimal results, the video size should fall within a specific range.
+The parameters and configurations for these models are as follows:
+
+<table>
+    <thead>
+        <tr>
+            <th rowspan="2">Task</th>
+            <th colspan="2">Resolution</th>
+            <th rowspan="2">Model</th>
+        </tr>
+        <tr>
+            <th>480P(~81x480x832)</th>
+            <th>720P(~81x720x1280)</th>
+        </tr>
+    </thead>
+    <tbody>
+        <tr>
+            <td>VACE</td>
+            <td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
+            <td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
+            <td>Wan2.1-VACE-14B</td>
+        </tr>
+        <tr>
+            <td>VACE</td>
+            <td style="color: green; text-align: center; vertical-align: middle;">✔️</td>
+            <td style="color: red; text-align: center; vertical-align: middle;">❌</td>
+            <td>Wan2.1-VACE-1.3B</td>
+        </tr>
+    </tbody>
+</table>
+
+In VACE, users can input a text prompt and an optional video, mask, and image for video generation or editing. Detailed instructions for using VACE can be found in the [User Guide](https://github.com/ali-vilab/VACE/blob/main/UserGuide.md).
+The execution process is as follows:
+
+##### (1) Preprocessing
+
+User-collected materials need to be preprocessed into VACE-recognizable inputs, namely `src_video`, `src_mask`, `src_ref_images`, and `prompt`.
+For R2V (Reference-to-Video Generation) you may skip this preprocessing, but for V2V (Video-to-Video Editing) and MV2V (Masked Video-to-Video Editing) tasks, additional preprocessing is required to obtain video with conditions such as depth, pose, or masked regions.
+For more details, please refer to [vace_preproccess](https://github.com/ali-vilab/VACE/blob/main/vace/vace_preproccess.py).
+
+##### (2) CLI inference
+
+- Single-GPU inference
+```sh
+python generate.py --task vace-1.3B --size 832*480 --ckpt_dir ./Wan2.1-VACE-1.3B --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
+```
+
+- Multi-GPU inference using FSDP + xDiT USP
+```sh
+torchrun --nproc_per_node=8 generate.py --task vace-14B --size 1280*720 --ckpt_dir ./Wan2.1-VACE-14B --dit_fsdp --t5_fsdp --ulysses_size 8 --src_ref_images examples/girl.png,examples/snake.png --prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
+```
+
+##### (3) Running local gradio
+
+- Single-GPU inference
+```sh
+python gradio/vace.py --ckpt_dir ./Wan2.1-VACE-1.3B
+```
+
+- Multi-GPU inference using FSDP + xDiT USP
+```sh
+python gradio/vace.py --mp --ulysses_size 8 --ckpt_dir ./Wan2.1-VACE-14B/
+```
+
 #### Run Text-to-Image Generation
 
 Wan2.1 is a unified model for both image and video generation. Since it was trained on both types of data, it can also generate images. The command for generating images is similar to video generation, as follows:

T2V-FastAPI文档.md (new file, 133 lines)

# Video Generation Service API Documentation

## I. Feature Overview

This service provides text-to-video generation based on the Wan2.1-T2V-1.3B model. Core features:

1. **Asynchronous task queue**: multiple tasks can be queued, with concurrency control (at most 2 parallel tasks)
2. **Resource management**:
   - VRAM optimization (bfloat16 precision)
   - Automatic cleanup of generated videos (deleted after 1 hour by default)
3. **Authentication**: Bearer Token validation based on an API key
4. **Task control**: task submission, status query, and cancellation

Tech stack:

- FastAPI framework
- CUDA acceleration
- Asynchronous task processing
- Diffusers inference library

---

## II. API Endpoints

### 1. Submit a generation task

**POST /video/submit**

```json
{
  "model": "Wan2.1-T2V-1.3B",
  "prompt": "A beautiful sunset over the mountains",
  "image_size": "480x832",
  "num_frames": 81,
  "guidance_scale": 5.0,
  "infer_steps": 50
}
```

**Example response**:

```json
{
  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```

### 2. Query task status

**POST /video/status**

```json
{
  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```

**Example response**:

```json
{
  "status": "Succeed",
  "results": {
    "videos": [{"url": "http://localhost:8088/videos/abcd1234.mp4"}],
    "timings": {"inference": 120}
  }
}
```

### 3. Cancel a task

**POST /video/cancel**

```json
{
  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```

**Example response**:

```json
{
  "status": "Succeed"
}
```

---

## III. Postman Guide

### 1. Basic configuration

- Server address: `http://<server-ip>:8088`
- Authentication: Bearer Token
- Token value: replace with a valid API key

### 2. Submitting a task

1. Choose the POST method and set the URL to `/video/submit`
2. Add the headers:

```text
Authorization: Bearer YOUR_API_KEY
Content-Type: application/json
```

3. Select raw/JSON for the body and enter the request parameters

### 3. Querying status

1. Create a new request with the URL `/video/status`
2. Use the same authentication headers
3. Include the requestId in the body

### 4. Cancelling a task

1. Create a new POST request with the URL `/video/cancel` (the endpoint accepts POST, not DELETE)
2. Include the requestId of the task to cancel in the body

### Notes

1. Every endpoint requires a valid API key
2. Video generation takes roughly 2-5 minutes, depending on the parameters
3. Generated videos are kept for 1 hour by default
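
The same flow can also be exercised from the command line; a curl sketch is shown below. The server address and `YOUR_API_KEY` are placeholders, as above.

```sh
# Submit a text-to-video task
curl -X POST http://localhost:8088/video/submit \
  -H "Authorization: Bearer YOUR_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{"prompt": "A beautiful sunset over the mountains", "image_size": "480x832", "num_frames": 81}'

# Query the task status with the returned requestId
curl -X POST http://localhost:8088/video/status \
  -H "Authorization: Bearer YOUR_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{"requestId": "<requestId from the submit response>"}'
```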
---

## IV. Parameter Reference

| Parameter        | Allowed values          | Required | Description                |
|------------------|-------------------------|----------|----------------------------|
| prompt           | 10-500 characters       | Yes      | Video content description  |
| image_size       | "480x832" or "832x480"  | Yes      | Resolution                 |
| num_frames       | 24-120                  | Yes      | Total number of frames     |
| guidance_scale   | 1.0-20.0                | Yes      | Text guidance strength     |
| infer_steps      | 20-100                  | Yes      | Number of inference steps  |
| seed             | 0-2147483647            | No       | Random seed                |

---

## V. Status Codes

| Code | Meaning                           |
|------|-----------------------------------|
| 202  | Task accepted                     |
| 401  | Authentication failed             |
| 404  | Task not found                    |
| 422  | Parameter validation failed       |
| 500  | Server error (e.g. out of VRAM)   |

**Tip**: the Swagger UI at `http://<server-address>:8088/docs` provides an auto-generated, interactive API reference for testing the endpoints.

examples/girl.png (new binary file, 817 KiB; not shown)

examples/snake.png (new binary file, 435 KiB; not shown)

generate.py (87 changed lines)

@@ -41,6 +41,14 @@ EXAMPLE_PROMPT = {
         "last_frame":
             "examples/flf2v_input_last_frame.png",
     },
+    "vace-1.3B": {
+        "src_ref_images": 'examples/girl.png,examples/snake.png',
+        "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
+    },
+    "vace-14B": {
+        "src_ref_images": 'examples/girl.png,examples/snake.png',
+        "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
+    }
 }

@@ -52,15 +60,19 @@ def _validate_args(args):

     # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
     if args.sample_steps is None:
-        args.sample_steps = 40 if "i2v" in args.task else 50
+        args.sample_steps = 50
+        if "i2v" in args.task:
+            args.sample_steps = 40

     if args.sample_shift is None:
         args.sample_shift = 5.0
         if "i2v" in args.task and args.size in ["832*480", "480*832"]:
             args.sample_shift = 3.0
-        if "flf2v" in args.task:
+        elif "flf2v" in args.task or "vace" in args.task:
             args.sample_shift = 16

     # The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
     if args.frame_num is None:
         args.frame_num = 1 if "t2i" in args.task else 81

@@ -141,6 +153,21 @@ def _parse_args():
         type=str,
         default=None,
         help="The file to save the generated image or video to.")
+    parser.add_argument(
+        "--src_video",
+        type=str,
+        default=None,
+        help="The file of the source video. Default None.")
+    parser.add_argument(
+        "--src_mask",
+        type=str,
+        default=None,
+        help="The file of the source mask. Default None.")
+    parser.add_argument(
+        "--src_ref_images",
+        type=str,
+        default=None,
+        help="The file list of the source reference images. Separated by ','. Default None.")
     parser.add_argument(
         "--prompt",
         type=str,

@@ -397,7 +424,7 @@ def generate(args):
             guide_scale=args.sample_guide_scale,
             seed=args.base_seed,
             offload_model=args.offload_model)
-    else:
+    elif "flf2v" in args.task:
         if args.prompt is None:
             args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
         if args.first_frame is None or args.last_frame is None:

@@ -457,6 +484,60 @@ def generate(args):
             seed=args.base_seed,
             offload_model=args.offload_model
         )
+    elif "vace" in args.task:
+        if args.prompt is None:
+            args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
+            args.src_video = EXAMPLE_PROMPT[args.task].get("src_video", None)
+            args.src_mask = EXAMPLE_PROMPT[args.task].get("src_mask", None)
+            args.src_ref_images = EXAMPLE_PROMPT[args.task].get("src_ref_images", None)
+
+        logging.info(f"Input prompt: {args.prompt}")
+        if args.use_prompt_extend and args.use_prompt_extend != 'plain':
+            logging.info("Extending prompt ...")
+            if rank == 0:
+                prompt = prompt_expander.forward(args.prompt)
+                logging.info(f"Prompt extended from '{args.prompt}' to '{prompt}'")
+                input_prompt = [prompt]
+            else:
+                input_prompt = [None]
+            if dist.is_initialized():
+                dist.broadcast_object_list(input_prompt, src=0)
+            args.prompt = input_prompt[0]
+            logging.info(f"Extended prompt: {args.prompt}")
+
+        logging.info("Creating VACE pipeline.")
+        wan_vace = wan.WanVace(
+            config=cfg,
+            checkpoint_dir=args.ckpt_dir,
+            device_id=device,
+            rank=rank,
+            t5_fsdp=args.t5_fsdp,
+            dit_fsdp=args.dit_fsdp,
+            use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
+            t5_cpu=args.t5_cpu,
+        )
+
+        src_video, src_mask, src_ref_images = wan_vace.prepare_source(
+            [args.src_video],
+            [args.src_mask],
+            [None if args.src_ref_images is None else args.src_ref_images.split(',')],
+            args.frame_num, SIZE_CONFIGS[args.size], device)
+
+        logging.info("Generating video...")
+        video = wan_vace.generate(
+            args.prompt,
+            src_video,
+            src_mask,
+            src_ref_images,
+            size=SIZE_CONFIGS[args.size],
+            frame_num=args.frame_num,
+            shift=args.sample_shift,
+            sample_solver=args.sample_solver,
+            sampling_steps=args.sample_steps,
+            guide_scale=args.sample_guide_scale,
+            seed=args.base_seed,
+            offload_model=args.offload_model)
+    else:
+        raise ValueError(f"Unknown task type: {args.task}")

     if rank == 0:
         if args.save_file is None:
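
With the new `--src_video` and `--src_mask` arguments, VACE video-to-video style runs can also be driven from this CLI. A hypothetical invocation is sketched below; the input paths and prompt are placeholders, and the preprocessing step described in the README is assumed to have produced them.

```sh
python generate.py --task vace-1.3B --size 832*480 --ckpt_dir ./Wan2.1-VACE-1.3B \
  --src_video processed/src_video.mp4 --src_mask processed/src_mask.mp4 \
  --prompt "A short description of the edited clip"
```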

gradio/vace.py (new file, 295 lines)

# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import argparse
import os
import sys
import datetime
import imageio
import numpy as np
import torch
import gradio as gr

sys.path.insert(0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
import wan
from wan import WanVace, WanVaceMP
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS


class FixedSizeQueue:
    def __init__(self, max_size):
        self.max_size = max_size
        self.queue = []

    def add(self, item):
        self.queue.insert(0, item)
        if len(self.queue) > self.max_size:
            self.queue.pop()

    def get(self):
        return self.queue

    def __repr__(self):
        return str(self.queue)


class VACEInference:
    def __init__(self, cfg, skip_load=False, gallery_share=True, gallery_share_limit=5):
        self.cfg = cfg
        self.save_dir = cfg.save_dir
        self.gallery_share = gallery_share
        self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
        if not skip_load:
            if not cfg.mp:
                self.pipe = WanVace(
                    config=WAN_CONFIGS[cfg.model_name],
                    checkpoint_dir=cfg.ckpt_dir,
                    device_id=0,
                    rank=0,
                    t5_fsdp=False,
                    dit_fsdp=False,
                    use_usp=False,
                )
            else:
                self.pipe = WanVaceMP(
                    config=WAN_CONFIGS[cfg.model_name],
                    checkpoint_dir=cfg.ckpt_dir,
                    use_usp=True,
                    ulysses_size=cfg.ulysses_size,
                    ring_size=cfg.ring_size
                )

    def create_ui(self, *args, **kwargs):
        gr.Markdown("""
                    <div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
                        <a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
                    </div>
                    """)
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                self.src_video = gr.Video(label="src_video", sources=['upload'], value=None, interactive=True)
            with gr.Column(scale=1, min_width=0):
                self.src_mask = gr.Video(label="src_mask", sources=['upload'], value=None, interactive=True)
        #
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.src_ref_image_1 = gr.Image(label='src_ref_image_1', height=200, interactive=True,
                                                    type='filepath', image_mode='RGB', sources=['upload'],
                                                    elem_id="src_ref_image_1", format='png')
                    self.src_ref_image_2 = gr.Image(label='src_ref_image_2', height=200, interactive=True,
                                                    type='filepath', image_mode='RGB', sources=['upload'],
                                                    elem_id="src_ref_image_2", format='png')
                    self.src_ref_image_3 = gr.Image(label='src_ref_image_3', height=200, interactive=True,
                                                    type='filepath', image_mode='RGB', sources=['upload'],
                                                    elem_id="src_ref_image_3", format='png')
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1):
                self.prompt = gr.Textbox(show_label=False, placeholder="positive_prompt_input",
                                         elem_id='positive_prompt', container=True, autofocus=True,
                                         elem_classes='type_row', visible=True, lines=2)
                self.negative_prompt = gr.Textbox(show_label=False, value=self.pipe.config.sample_neg_prompt,
                                                  placeholder="negative_prompt_input", elem_id='negative_prompt',
                                                  container=True, autofocus=False, elem_classes='type_row',
                                                  visible=True, interactive=True, lines=1)
        #
        with gr.Row(variant='panel', equal_height=True):
            with gr.Column(scale=1, min_width=0):
                with gr.Row(equal_height=True):
                    self.shift_scale = gr.Slider(label='shift_scale', minimum=0.0, maximum=100.0, step=1.0, value=16.0, interactive=True)
                    self.sample_steps = gr.Slider(label='sample_steps', minimum=1, maximum=100, step=1, value=25, interactive=True)
                    self.context_scale = gr.Slider(label='context_scale', minimum=0.0, maximum=2.0, step=0.1, value=1.0, interactive=True)
                    self.guide_scale = gr.Slider(label='guide_scale', minimum=1, maximum=10, step=0.5, value=5.0, interactive=True)
                    self.infer_seed = gr.Slider(minimum=-1, maximum=10000000, value=2025, label="Seed")
        #
        with gr.Accordion(label="Usable without source video", open=False):
            with gr.Row(equal_height=True):
                self.output_height = gr.Textbox(
                    label='resolutions_height',
                    # value=480,
                    value=720,
                    interactive=True)
                self.output_width = gr.Textbox(
                    label='resolutions_width',
                    # value=832,
                    value=1280,
                    interactive=True)
                self.frame_rate = gr.Textbox(label='frame_rate', value=16, interactive=True)
                self.num_frames = gr.Textbox(label='num_frames', value=81, interactive=True)
        #
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                self.generate_button = gr.Button(value='Run', elem_classes='type_row', elem_id='generate_button', visible=True)
            with gr.Column(scale=1):
                self.refresh_button = gr.Button(value='\U0001f504')  # 🔄
        #
        self.output_gallery = gr.Gallery(label="output_gallery", value=[], interactive=False, allow_preview=True, preview=True)

    def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, src_ref_image_2, src_ref_image_3, prompt, negative_prompt, shift_scale, sample_steps, context_scale, guide_scale, infer_seed, output_height, output_width, frame_rate, num_frames):
        output_height, output_width, frame_rate, num_frames = int(output_height), int(output_width), int(frame_rate), int(num_frames)
        src_ref_images = [x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] if x is not None]
        src_video, src_mask, src_ref_images = self.pipe.prepare_source([src_video],
                                                                       [src_mask],
                                                                       [src_ref_images],
                                                                       num_frames=num_frames,
                                                                       image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
                                                                       device=self.pipe.device)
        video = self.pipe.generate(
            prompt,
            src_video,
            src_mask,
            src_ref_images,
            size=(output_width, output_height),
            context_scale=context_scale,
            shift=shift_scale,
            sampling_steps=sample_steps,
            guide_scale=guide_scale,
            n_prompt=negative_prompt,
            seed=infer_seed,
            offload_model=True)

        name = '{0:%Y%m%d-%H%M%S}'.format(datetime.datetime.now())
        video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
        video_frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 255).cpu().numpy().astype(np.uint8)

        try:
            writer = imageio.get_writer(video_path, fps=frame_rate, codec='libx264', quality=8, macro_block_size=1)
            for frame in video_frames:
                writer.append_data(frame)
            writer.close()
            print(video_path)
        except Exception as e:
            raise gr.Error(f"Video save error: {e}")

        if self.gallery_share:
            self.gallery_share_data.add(video_path)
            return self.gallery_share_data.get()
        else:
            return [video_path]

    def set_callbacks(self, **kwargs):
        self.gen_inputs = [self.output_gallery, self.src_video, self.src_mask, self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, self.prompt, self.negative_prompt, self.shift_scale, self.sample_steps, self.context_scale, self.guide_scale, self.infer_seed, self.output_height, self.output_width, self.frame_rate, self.num_frames]
        self.gen_outputs = [self.output_gallery]
        self.generate_button.click(self.generate, inputs=self.gen_inputs, outputs=self.gen_outputs, queue=True)
        self.refresh_button.click(lambda x: self.gallery_share_data.get() if self.gallery_share else x, inputs=[self.output_gallery], outputs=[self.output_gallery])


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Argparser for VACE-WAN Demo:\n')
    parser.add_argument('--server_port', dest='server_port', help='', type=int, default=7860)
    parser.add_argument('--server_name', dest='server_name', help='', default='0.0.0.0')
    parser.add_argument('--root_path', dest='root_path', help='', default=None)
    parser.add_argument('--save_dir', dest='save_dir', help='', default='cache')
    parser.add_argument("--mp", action="store_true", help="Use Multi-GPUs",)
    parser.add_argument("--model_name", type=str, default="vace-14B", choices=list(WAN_CONFIGS.keys()), help="The model name to run.")
    parser.add_argument("--ulysses_size", type=int, default=1, help="The size of the ulysses parallelism in DiT.")
    parser.add_argument("--ring_size", type=int, default=1, help="The size of the ring attention parallelism in DiT.")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        # default='models/VACE-Wan2.1-1.3B-Preview',
        default='models/Wan2.1-VACE-14B/',
        help="The path to the checkpoint directory.",
    )
    parser.add_argument(
        "--offload_to_cpu",
        action="store_true",
        help="Offloading unnecessary computations to CPU.",
    )

    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir, exist_ok=True)

    with gr.Blocks() as demo:
        infer_gr = VACEInference(args, skip_load=False, gallery_share=True, gallery_share_limit=5)
        infer_gr.create_ui()
        infer_gr.set_callbacks()
        allowed_paths = [args.save_dir]
        demo.queue(status_update_rate=1).launch(server_name=args.server_name,
                                                server_port=args.server_port,
                                                root_path=args.root_path,
                                                allowed_paths=allowed_paths,
                                                show_error=True, debug=True)
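
The `FixedSizeQueue` helper at the top of `gradio/vace.py` is what caps the shared gallery: it keeps only the most recent `max_size` paths, newest first. A quick illustration of its behaviour:

```python
q = FixedSizeQueue(max_size=2)
q.add("a.mp4")
q.add("b.mp4")
q.add("c.mp4")   # the oldest entry ("a.mp4") is evicted
print(q.get())   # ['c.mp4', 'b.mp4']
```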

i2v_api.py (new file, 526 lines)

import os
import torch
import uuid
import time
import asyncio
import numpy as np
from threading import Lock
from typing import Optional, Dict, List
from fastapi import FastAPI, HTTPException, status, Depends
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, field_validator, ValidationError
from diffusers.utils import export_to_video, load_image
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from transformers import CLIPVisionModel
from PIL import Image
import requests
from io import BytesIO
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from contextlib import asynccontextmanager
from requests.exceptions import RequestException, Timeout, ConnectionError, HTTPError

# Create storage directories
os.makedirs("generated_videos", exist_ok=True)
os.makedirs("temp_images", exist_ok=True)

# ======================
# Lifespan management
# ======================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Resource manager"""
    try:
        # Initialize the authentication system
        app.state.valid_api_keys = {
            "YOUR_API_KEY"
        }

        # Initialize the model
        model_id = "./Wan2.1-I2V-14B-480P-Diffusers"

        # Load the image encoder
        image_encoder = CLIPVisionModel.from_pretrained(
            model_id, subfolder="image_encoder", torch_dtype=torch.float32)

        # Load the VAE
        vae = AutoencoderKLWan.from_pretrained(
            model_id, subfolder="vae", torch_dtype=torch.float32)

        # Configure the scheduler
        scheduler = UniPCMultistepScheduler(
            prediction_type='flow_prediction',
            use_flow_sigmas=True,
            num_train_timesteps=1000,
            flow_shift=3.0)

        # Create the pipeline
        app.state.pipe = WanImageToVideoPipeline.from_pretrained(
            model_id,
            vae=vae,
            image_encoder=image_encoder,
            torch_dtype=torch.bfloat16).to("cuda")
        app.state.pipe.scheduler = scheduler

        # Initialize the task system
        app.state.tasks: Dict[str, dict] = {}
        app.state.pending_queue: List[str] = []
        app.state.model_lock = Lock()
        app.state.task_lock = Lock()
        app.state.base_url = "http://<ip>:<port>"
        app.state.semaphore = asyncio.Semaphore(2)  # concurrency limit

        # Start the background processor
        asyncio.create_task(task_processor())

        print("✅ System initialized")
        yield

    finally:
        # Release resources
        if hasattr(app.state, 'pipe'):
            del app.state.pipe
            torch.cuda.empty_cache()
        print("♻️ Resources released")

# ======================
# FastAPI application
# ======================
app = FastAPI(lifespan=lifespan)
app.mount("/videos", StaticFiles(directory="generated_videos"), name="videos")
# Authentication module
security = HTTPBearer(auto_error=False)

# ======================
# Data models -- request parameters
# ======================
class VideoSubmitRequest(BaseModel):
    model: str = Field(default="Wan2.1-I2V-14B-480P", description="Model version")
    prompt: str = Field(..., min_length=10, max_length=500, description="Video description prompt, 10-500 characters")
    image_url: str = Field(..., description="Input image URL, HTTP/HTTPS")
    image_size: str = Field(default="auto", description="Output resolution, 'widthxheight' or 'auto' (computed automatically)")
    negative_prompt: Optional[str] = Field(default=None, max_length=500, description="Content to exclude")
    seed: Optional[int] = Field(default=None, ge=0, le=2147483647, description="Random seed, 0-2147483647")
    num_frames: int = Field(default=81, ge=24, le=120, description="Number of video frames, 24-120")
    guidance_scale: float = Field(default=3.0, ge=1.0, le=20.0, description="Guidance scale, 1.0-20.0")
    infer_steps: int = Field(default=30, ge=20, le=100, description="Inference steps, 20-100")

    @field_validator('image_size')
    def validate_image_size(cls, v):
        allowed_sizes = {"480x832", "832x480", "auto"}
        if v not in allowed_sizes:
            raise ValueError(f"Supported resolutions: {', '.join(allowed_sizes)}")
        return v

class VideoStatusRequest(BaseModel):
    requestId: str = Field(..., min_length=32, max_length=32, description="32-character task ID")

class VideoStatusResponse(BaseModel):
    status: str = Field(..., description="Task status: Succeed, InQueue, InProgress, Failed, Cancelled")
    reason: Optional[str] = Field(None, description="Failure reason")
    results: Optional[dict] = Field(None, description="Generation results")
    queue_position: Optional[int] = Field(None, description="Position in the queue")

class VideoCancelRequest(BaseModel):
    requestId: str = Field(..., min_length=32, max_length=32, description="32-character task ID")

# ======================
# Core logic
# ======================
async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Shared authentication check"""
    if not credentials:
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": "Missing authentication header"},
            headers={"WWW-Authenticate": "Bearer"})
    if credentials.scheme != "Bearer":
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": "Invalid authentication scheme"},
            headers={"WWW-Authenticate": "Bearer"})
    if credentials.credentials not in app.state.valid_api_keys:
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": "Invalid API key"},
            headers={"WWW-Authenticate": "Bearer"})
    return True

async def task_processor():
    """Task processor"""
    while True:
        async with app.state.semaphore:
            task_id = await get_next_task()
            if task_id:
                await process_task(task_id)
            else:
                await asyncio.sleep(0.5)

async def get_next_task():
    """Fetch the next queued task"""
    with app.state.task_lock:
        return app.state.pending_queue.pop(0) if app.state.pending_queue else None

async def process_task(task_id: str):
    """Process a single task"""
    task = app.state.tasks.get(task_id)
    if not task:
        return

    try:
        # Update the task state
        task['status'] = 'InProgress'
        task['started_at'] = int(time.time())
        print(task['request'].image_url)
        # Download the input image
        image = await download_image(task['request'].image_url)
        image_path = f"temp_images/{task_id}.jpg"
        image.save(image_path)

        # Generate the video
        video_path = await generate_video(task['request'], task_id, image)

        # Build the download link
        download_url = f"{app.state.base_url}/videos/{os.path.basename(video_path)}"

        # Update the task state
        task.update({
            'status': 'Succeed',
            'download_url': download_url,
            'completed_at': int(time.time())
        })

        # Schedule cleanup
        asyncio.create_task(cleanup_files([image_path, video_path]))
    except Exception as e:
        handle_task_error(task, e)

def handle_task_error(task: dict, error: Exception):
    """Error handling (with detailed error messages)"""
    error_msg = str(error)

    # 1. Out-of-memory errors
    if isinstance(error, torch.cuda.OutOfMemoryError):
        error_msg = "Out of VRAM, please lower the resolution"

    # 2. Network-related errors
    elif isinstance(error, (RequestException, HTTPException)):
        # Extract the specific message from the exception
        if isinstance(error, HTTPException):
            # For HTTPException, use its detail field
            error_detail = getattr(error, "detail", "")
            error_msg = f"Image download failed: {error_detail}"

        elif isinstance(error, Timeout):
            error_msg = "Image download timed out, please check the network"

        elif isinstance(error, ConnectionError):
            error_msg = "Could not connect to the server, please check the URL"

        elif isinstance(error, HTTPError):
            # requests HTTPError (e.g. a 4xx/5xx status code)
            status_code = error.response.status_code
            error_msg = f"Server returned error status code: {status_code}"

        else:
            # Other RequestException errors
            error_msg = f"Image download failed: {str(error)}"

    # 3. Any other unknown error
    else:
        error_msg = f"Unknown error: {str(error)}"

    # Update the task state
    task.update({
        'status': 'Failed',
        'reason': error_msg,
        'completed_at': int(time.time())
    })

# ======================
# Video generation logic
# ======================
async def download_image(url: str) -> Image.Image:
    """Asynchronously download an image (with detailed error messages)"""
    loop = asyncio.get_event_loop()
    try:
        response = await loop.run_in_executor(
            None,
            lambda: requests.get(url)  # a timeout could be passed to requests.get here
        )

        # Raise an HTTPException for any non-200 status code
        if response.status_code != 200:
            raise HTTPException(
                status_code=response.status_code,
                detail=f"Server returned status code {response.status_code}")

        return Image.open(BytesIO(response.content)).convert("RGB")

    except RequestException as e:
        # Propagate the original requests error message
        raise HTTPException(
            status_code=500,
            detail=f"Request failed: {str(e)}")

async def generate_video(request: VideoSubmitRequest, task_id: str, image: Image.Image):
    """Asynchronous generation entry point"""
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(
        None,
        sync_generate_video,
        request,
        task_id,
        image
    )

def sync_generate_video(request: VideoSubmitRequest, task_id: str, image: Image.Image):
    """Synchronous generation core"""
    with app.state.model_lock:
        try:
            # Resolve the resolution
            mod_value = 16  # size granularity required by the model
            print(request.image_size)
            print('--------------------------------')
            if request.image_size == "auto":
                # Automatic resolution calculation
                aspect_ratio = image.height / image.width
                print(image.height, image.width)
                max_area = 399360  # base pixel budget of the model

                # Compute the ideal size
                height = round(np.sqrt(max_area * aspect_ratio))
                width = round(np.sqrt(max_area / aspect_ratio))

                # Snap to the size granularity
                height = height // mod_value * mod_value
                width = width // mod_value * mod_value
                resized_image = image.resize((width, height))
            else:
                width_str, height_str = request.image_size.split('x')
                width = int(width_str)
                height = int(height_str)
                mod_value = 16
                # Resize the image
                resized_image = image.resize((width, height))

            # Set the random seed (read from the request attributes)
            generator = None
            if request.seed is not None:
                generator = torch.Generator(device="cuda")
                generator.manual_seed(request.seed)
                print(f"🔮 Using random seed: {request.seed}")
            print(resized_image)
            print(height, width)

            # Run inference
            output = app.state.pipe(
                image=resized_image,
                prompt=request.prompt,
                negative_prompt=request.negative_prompt,
                height=height,
                width=width,
                num_frames=request.num_frames,
                guidance_scale=request.guidance_scale,
                num_inference_steps=request.infer_steps,
                generator=generator
            ).frames[0]

            # Export the video
            video_id = uuid.uuid4().hex
            output_path = f"generated_videos/{video_id}.mp4"
            export_to_video(output, output_path, fps=16)
            return output_path
        except Exception as e:
            raise RuntimeError(f"Video generation failed: {str(e)}") from e

# ======================
# API endpoints
# ======================
@app.post("/video/submit",
          response_model=dict,
          status_code=status.HTTP_202_ACCEPTED,
          tags=["video generation"])
async def submit_task(
    request: VideoSubmitRequest,
    auth: bool = Depends(verify_auth)
):
    """Submit a generation task"""
    # Parameter validation
    if request.image_url is None:
        raise HTTPException(
            status_code=422,
            detail={"status": "Failed", "reason": "An image URL is required"})

    # Create the task record
    task_id = uuid.uuid4().hex
    with app.state.task_lock:
        app.state.tasks[task_id] = {
            "request": request,
            "status": "InQueue",
            "created_at": int(time.time())
        }
        app.state.pending_queue.append(task_id)

    return {"requestId": task_id}

@app.post("/video/status",
          response_model=VideoStatusResponse,
          tags=["video generation"])
async def get_status(
    request: VideoStatusRequest,
    auth: bool = Depends(verify_auth)
):
    """Query task status"""
    task = app.state.tasks.get(request.requestId)
    if not task:
        raise HTTPException(
            status_code=404,
            detail={"status": "Failed", "reason": "Invalid task ID"})

    # Compute the queue position (only while queued)
    queue_pos = 0
    if task['status'] == "InQueue" and request.requestId in app.state.pending_queue:
        queue_pos = app.state.pending_queue.index(request.requestId) + 1

    response = {
        "status": task['status'],
        "reason": task.get('reason'),
        "queue_position": queue_pos if task['status'] == "InQueue" else None  # null when not queued
    }

    # Extra fields for successful tasks
    if task['status'] == "Succeed":
        response["results"] = {
            "videos": [{"url": task['download_url']}],
            "timings": {
                "inference": task['completed_at'] - task['started_at']
            },
            "seed": task['request'].seed
        }
    # Extra information for cancelled tasks
    elif task['status'] == "Cancelled":
        response["reason"] = task.get('reason', "Cancelled by the user")  # make sure the reason field exists

    return response

@app.post("/video/cancel",
          response_model=dict,
          tags=["video generation"])
async def cancel_task(
    request: VideoCancelRequest,
    auth: bool = Depends(verify_auth)
):
    """Cancel a queued generation task"""
    task_id = request.requestId

    with app.state.task_lock:
        task = app.state.tasks.get(task_id)

        # Check that the task exists
        if not task:
            raise HTTPException(
                status_code=404,
                detail={"status": "Failed", "reason": "Invalid task ID"})

        current_status = task['status']

        # Only queued tasks may be cancelled
        if current_status != "InQueue":
            raise HTTPException(
                status_code=400,
                detail={"status": "Failed", "reason": f"Only queued tasks can be cancelled, current status: {current_status}"})

        # Remove the task from the queue
        try:
            app.state.pending_queue.remove(task_id)
        except ValueError:
            pass  # it may already have been picked up

        # Update the task state
        task.update({
            "status": "Cancelled",
            "reason": "Cancelled by the user",
            "completed_at": int(time.time())
        })

    return {"status": "Succeed"}

async def cleanup_files(paths: List[str], delay: int = 3600):
    """Delete files after a delay"""
    await asyncio.sleep(delay)
    for path in paths:
        try:
            if os.path.exists(path):
                os.remove(path)
        except Exception as e:
            print(f"Cleanup failed {path}: {str(e)}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8088)

t2v-api.py (new file, 450 lines)

import os
|
||||||
|
import torch
|
||||||
|
import uuid
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from enum import Enum
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Optional, Dict, List
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import FastAPI, HTTPException, status, Depends
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from pydantic import BaseModel, Field, field_validator, ValidationError
|
||||||
|
from diffusers.utils import export_to_video
|
||||||
|
from diffusers import AutoencoderKLWan, WanPipeline
|
||||||
|
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
|
||||||
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
# 创建视频存储目录
|
||||||
|
os.makedirs("generated_videos", exist_ok=True)
|
||||||
|
|
||||||
|
# 生命周期管理器
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""管理应用生命周期"""
|
||||||
|
# 初始化模型和资源
|
||||||
|
try:
|
||||||
|
# 初始化认证密钥
|
||||||
|
app.state.valid_api_keys = {
|
||||||
|
"密钥"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 初始化视频生成模型
|
||||||
|
model_id = "./Wan2.1-T2V-1.3B-Diffusers"
|
||||||
|
vae = AutoencoderKLWan.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
subfolder="vae",
|
||||||
|
torch_dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler = UniPCMultistepScheduler(
|
||||||
|
prediction_type='flow_prediction',
|
||||||
|
use_flow_sigmas=True,
|
||||||
|
num_train_timesteps=1000,
|
||||||
|
flow_shift=3.0
|
||||||
|
)
|
||||||
|
|
||||||
|
app.state.pipe = WanPipeline.from_pretrained(
|
||||||
|
model_id,
|
||||||
|
vae=vae,
|
||||||
|
torch_dtype=torch.bfloat16
|
||||||
|
).to("cuda")
|
||||||
|
app.state.pipe.scheduler = scheduler
|
||||||
|
|
||||||
|
# 初始化任务系统
|
||||||
|
app.state.tasks: Dict[str, dict] = {}
|
||||||
|
app.state.pending_queue: List[str] = []
|
||||||
|
app.state.model_lock = Lock()
|
||||||
|
app.state.task_lock = Lock()
|
||||||
|
app.state.base_url = "ip地址+端口"
|
||||||
|
app.state.max_concurrent = 2
|
||||||
|
app.state.semaphore = asyncio.Semaphore(app.state.max_concurrent)
|
||||||
|
|
||||||
|
# 启动后台任务处理器
|
||||||
|
asyncio.create_task(task_processor())
|
||||||
|
|
||||||
|
print("✅ 应用初始化完成")
|
||||||
|
yield
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 清理资源
|
||||||
|
if hasattr(app.state, 'pipe'):
|
||||||
|
del app.state.pipe
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
print("♻️ 已释放模型资源")
|
||||||
|
|
||||||
|
# 创建FastAPI应用
app = FastAPI(lifespan=lifespan)
app.mount("/videos", StaticFiles(directory="generated_videos"), name="videos")

# 认证模块
security = HTTPBearer(auto_error=False)

# ======================
# 数据模型--查询参数模型
# ======================
class VideoSubmitRequest(BaseModel):
    model: str = Field(default="Wan2.1-T2V-1.3B", description="使用的模型版本")
    prompt: str = Field(
        ...,
        min_length=10,
        max_length=500,
        description="视频描述提示词,10-500个字符"
    )
    image_size: str = Field(
        ...,
        description="视频分辨率,仅支持480x832或832x480"
    )
    negative_prompt: Optional[str] = Field(
        default=None,
        max_length=500,
        description="排除不需要的内容"
    )
    seed: Optional[int] = Field(
        default=None,
        ge=0,
        le=2147483647,
        description="随机数种子,范围0-2147483647"
    )
    num_frames: int = Field(
        default=81,
        ge=24,
        le=120,
        description="视频帧数,24-120帧"
    )
    guidance_scale: float = Field(
        default=5.0,
        ge=1.0,
        le=20.0,
        description="引导系数,1.0-20.0"
    )
    infer_steps: int = Field(
        default=50,
        ge=20,
        le=100,
        description="推理步数,20-100步"
    )
    @field_validator('image_size', mode='before')
    @classmethod
    def validate_image_size(cls, value):
        allowed_sizes = {"480x832", "832x480"}
        if value not in allowed_sizes:
            raise ValueError(f"仅支持以下分辨率: {', '.join(allowed_sizes)}")
        return value


class VideoStatusRequest(BaseModel):
    requestId: str = Field(
        ...,
        min_length=32,
        max_length=32,
        description="32位任务ID"
    )


class VideoSubmitResponse(BaseModel):
    requestId: str


class VideoStatusResponse(BaseModel):
    status: str = Field(..., description="任务状态: Succeed, InQueue, InProgress, Failed, Cancelled")
    reason: Optional[str] = Field(None, description="失败原因")
    results: Optional[dict] = Field(None, description="生成结果")
    queue_position: Optional[int] = Field(None, description="队列位置")


class VideoCancelRequest(BaseModel):
    requestId: str = Field(
        ...,
        min_length=32,
        max_length=32,
        description="32位任务ID"
    )


# # 自定义HTTP异常处理器
# @app.exception_handler(HTTPException)
# async def http_exception_handler(request, exc):
#     return JSONResponse(
#         status_code=exc.status_code,
#         content=exc.detail,  # 直接返回detail内容(不再包装在detail字段)
#         headers=exc.headers
#     )
# ======================
# 后台任务处理
# ======================
async def task_processor():
    """处理任务队列"""
    while True:
        async with app.state.semaphore:
            task_id = await get_next_task()
            if task_id:
                await process_task(task_id)
            else:
                await asyncio.sleep(0.5)


async def get_next_task():
    """获取下一个待处理任务"""
    with app.state.task_lock:
        if app.state.pending_queue:
            return app.state.pending_queue.pop(0)
    return None


async def process_task(task_id: str):
    """处理单个任务"""
    task = app.state.tasks.get(task_id)
    if not task:
        return

    try:
        # 更新任务状态
        task['status'] = 'InProgress'
        task['started_at'] = int(time.time())

        # 执行视频生成
        video_path = await generate_video(task['request'], task_id)

        # 生成下载链接
        download_url = f"{app.state.base_url}/videos/{os.path.basename(video_path)}"

        # 更新任务状态
        task.update({
            'status': 'Succeed',
            'download_url': download_url,
            'completed_at': int(time.time())
        })

        # 安排自动清理
        asyncio.create_task(auto_cleanup(video_path))
    except Exception as e:
        handle_task_error(task, e)


def handle_task_error(task: dict, error: Exception):
    """统一处理任务错误"""
    error_msg = str(error)
    if isinstance(error, torch.cuda.OutOfMemoryError):
        error_msg = "显存不足,请降低分辨率或减少帧数"
    elif isinstance(error, ValidationError):
        error_msg = "参数校验失败: " + str(error)

    task.update({
        'status': 'Failed',
        'reason': error_msg,
        'completed_at': int(time.time())
    })
# ======================
# 视频生成核心逻辑
# ======================
async def generate_video(request: dict, task_id: str) -> str:
    """异步执行视频生成"""
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(
        None,
        sync_generate_video,
        request,
        task_id
    )


def sync_generate_video(request: dict, task_id: str) -> str:
    """同步生成视频"""
    with app.state.model_lock:
        try:
            generator = None
            if request.get('seed') is not None:
                generator = torch.Generator(device="cuda")
                generator.manual_seed(request['seed'])
                print(f"🔮 使用随机种子: {request['seed']}")

            # 执行模型推理
            result = app.state.pipe(
                prompt=request['prompt'],
                negative_prompt=request['negative_prompt'],
                height=request['height'],
                width=request['width'],
                num_frames=request['num_frames'],
                guidance_scale=request['guidance_scale'],
                num_inference_steps=request['infer_steps'],
                generator=generator
            )

            # 导出视频文件
            video_id = uuid.uuid4().hex
            output_path = f"generated_videos/{video_id}.mp4"
            export_to_video(result.frames[0], output_path, fps=16)

            return output_path
        except Exception as e:
            raise RuntimeError(f"视频生成失败: {str(e)}") from e
# ======================
# API端点
# ======================
async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """认证验证"""
    if not credentials:
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": "缺少认证头"},
            headers={"WWW-Authenticate": "Bearer"}
        )
    if credentials.scheme != "Bearer":
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": "无效的认证方案"},
            headers={"WWW-Authenticate": "Bearer"}
        )
    if credentials.credentials not in app.state.valid_api_keys:
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": "无效的API密钥"},
            headers={"WWW-Authenticate": "Bearer"}
        )
    return True


@app.post("/video/submit",
          response_model=VideoSubmitResponse,
          status_code=status.HTTP_202_ACCEPTED,
          tags=["视频生成"],
          summary="提交视频生成请求")
async def submit_video_task(
    request: VideoSubmitRequest,
    auth: bool = Depends(verify_auth)
):
    """提交新的视频生成任务"""
    try:
        # 解析分辨率参数
        width, height = map(int, request.image_size.split('x'))

        # 创建任务记录
        task_id = uuid.uuid4().hex
        task_data = {
            'request': {
                'prompt': request.prompt,
                'negative_prompt': request.negative_prompt,
                'width': width,
                'height': height,
                'num_frames': request.num_frames,
                'guidance_scale': request.guidance_scale,
                'infer_steps': request.infer_steps,
                'seed': request.seed
            },
            'status': 'InQueue',
            'created_at': int(time.time())
        }

        # 加入任务队列
        with app.state.task_lock:
            app.state.tasks[task_id] = task_data
            app.state.pending_queue.append(task_id)

        return {"requestId": task_id}

    except ValidationError as e:
        raise HTTPException(
            status_code=422,
            detail={"status": "Failed", "reason": str(e)}
        )
@app.post("/video/status",
|
||||||
|
response_model=VideoStatusResponse,
|
||||||
|
tags=["视频生成"],
|
||||||
|
summary="查询任务状态")
|
||||||
|
async def get_video_status(
|
||||||
|
request: VideoStatusRequest,
|
||||||
|
auth: bool = Depends(verify_auth)
|
||||||
|
):
|
||||||
|
"""查询任务状态"""
|
||||||
|
task = app.state.tasks.get(request.requestId)
|
||||||
|
if not task:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail={"status": "Failed", "reason": "无效的任务ID"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算队列位置(仅当在队列中时)
|
||||||
|
queue_pos = 0
|
||||||
|
if task['status'] == "InQueue" and request.requestId in app.state.pending_queue:
|
||||||
|
queue_pos = app.state.pending_queue.index(request.requestId) + 1
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"status": task['status'],
|
||||||
|
"reason": task.get('reason'),
|
||||||
|
"queue_position": queue_pos if task['status'] == "InQueue" else None # 非排队状态返回null
|
||||||
|
}
|
||||||
|
|
||||||
|
# 成功状态的特殊处理
|
||||||
|
if task['status'] == "Succeed":
|
||||||
|
response["results"] = {
|
||||||
|
"videos": [{"url": task['download_url']}],
|
||||||
|
"timings": {
|
||||||
|
"inference": task['completed_at'] - task['started_at']
|
||||||
|
},
|
||||||
|
"seed": task['request']['seed']
|
||||||
|
}
|
||||||
|
# 取消状态的补充信息
|
||||||
|
elif task['status'] == "Cancelled":
|
||||||
|
response["reason"] = task.get('reason', "用户主动取消") # 确保原因字段存在
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/video/cancel",
|
||||||
|
response_model=dict,
|
||||||
|
tags=["视频生成"])
|
||||||
|
async def cancel_task(
|
||||||
|
request: VideoCancelRequest,
|
||||||
|
auth: bool = Depends(verify_auth)
|
||||||
|
):
|
||||||
|
"""取消排队中的生成任务"""
|
||||||
|
task_id = request.requestId
|
||||||
|
|
||||||
|
with app.state.task_lock:
|
||||||
|
task = app.state.tasks.get(task_id)
|
||||||
|
|
||||||
|
# 检查任务是否存在
|
||||||
|
if not task:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail={"status": "Failed", "reason": "无效的任务ID"}
|
||||||
|
)
|
||||||
|
|
||||||
|
current_status = task['status']
|
||||||
|
|
||||||
|
# 仅允许取消排队中的任务
|
||||||
|
if current_status != "InQueue":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={"status": "Failed", "reason": f"仅允许取消排队任务,当前状态: {current_status}"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从队列移除
|
||||||
|
try:
|
||||||
|
app.state.pending_queue.remove(task_id)
|
||||||
|
except ValueError:
|
||||||
|
pass # 可能已被处理
|
||||||
|
|
||||||
|
# 更新任务状态
|
||||||
|
task.update({
|
||||||
|
"status": "Cancelled",
|
||||||
|
"reason": "用户主动取消",
|
||||||
|
"completed_at": int(time.time())
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"status": "Succeed"}
|
||||||
|
|
||||||
|
|
||||||
|
# ======================
# 工具函数
# ======================
async def auto_cleanup(file_path: str, delay: int = 3600):
    """自动清理生成的视频文件"""
    await asyncio.sleep(delay)
    try:
        if os.path.exists(file_path):
            os.remove(file_path)
            print(f"已清理文件: {file_path}")
    except Exception as e:
        print(f"文件清理失败: {str(e)}")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8088)
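调用示例(仅为示意):下面的 Python 片段演示按上面 t2v-api.py 暴露的接口提交文本生成视频任务并轮询状态。其中服务地址 `http://localhost:8088` 与 `YOUR_API_KEY` 均为占位假设,需替换为实际部署值。

```python
import time
import requests

BASE_URL = "http://localhost:8088"                   # 假设:服务部署地址(端口见上方 uvicorn 配置)
HEADERS = {"Authorization": "Bearer YOUR_API_KEY"}   # 假设:有效的 API Key

# 提交任务(字段与 VideoSubmitRequest 模型一一对应)
payload = {
    "prompt": "Sunset scene with mountains and a calm lake",
    "image_size": "480x832",
    "num_frames": 81,
    "guidance_scale": 5.0,
    "infer_steps": 50
}
resp = requests.post(f"{BASE_URL}/video/submit", json=payload, headers=HEADERS)
request_id = resp.json()["requestId"]

# 轮询任务状态,直到任务结束
while True:
    status = requests.post(f"{BASE_URL}/video/status",
                           json={"requestId": request_id},
                           headers=HEADERS).json()
    if status["status"] in ("Succeed", "Failed", "Cancelled"):
        break
    time.sleep(5)

print(status)  # 成功时 results["videos"][0]["url"] 为视频下载地址
```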
@@ -105,9 +105,16 @@ function i2v_14B_720p() {
     torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
 }
 
+function vace_1_3B() {
+    VACE_1_3B_CKPT_DIR="$MODEL_DIR/VACE-Wan2.1-1.3B-Preview/"
+    torchrun --nproc_per_node=$GPUS $PY_FILE --ulysses_size $GPUS --task vace-1.3B --size 480*832 --ckpt_dir $VACE_1_3B_CKPT_DIR
+}
+
+
 t2i_14B
 t2v_1_3B
 t2v_14B
 i2v_14B_480p
 i2v_14B_720p
+vace_1_3B
@@ -2,3 +2,4 @@ from . import configs, distributed, modules
 from .image2video import WanI2V
 from .text2video import WanT2V
 from .first_last_frame2video import WanFLF2V
+from .vace import WanVace, WanVaceMP
@@ -22,7 +22,9 @@ WAN_CONFIGS = {
     't2v-1.3B': t2v_1_3B,
     'i2v-14B': i2v_14B,
     't2i-14B': t2i_14B,
-    'flf2v-14B': flf2v_14B
+    'flf2v-14B': flf2v_14B,
+    'vace-1.3B': t2v_1_3B,
+    'vace-14B': t2v_14B,
 }
 
 SIZE_CONFIGS = {
@@ -46,4 +48,6 @@ SUPPORTED_SIZES = {
     'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
     'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
     't2i-14B': tuple(SIZE_CONFIGS.keys()),
+    'vace-1.3B': ('480*832', '832*480'),
+    'vace-14B': ('720*1280', '1280*720', '480*832', '832*480')
 }
|
|||||||
return torch.stack(output).float()
|
return torch.stack(output).float()
|
||||||
|
|
||||||
|
|
||||||
|
def usp_dit_forward_vace(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
vace_context,
|
||||||
|
seq_len,
|
||||||
|
kwargs
|
||||||
|
):
|
||||||
|
# embeddings
|
||||||
|
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||||
|
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||||
|
c = torch.cat([
|
||||||
|
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||||
|
dim=1) for u in c
|
||||||
|
])
|
||||||
|
|
||||||
|
# arguments
|
||||||
|
new_kwargs = dict(x=x)
|
||||||
|
new_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
# Context Parallel
|
||||||
|
c = torch.chunk(
|
||||||
|
c, get_sequence_parallel_world_size(),
|
||||||
|
dim=1)[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
|
hints = []
|
||||||
|
for block in self.vace_blocks:
|
||||||
|
c, c_skip = block(c, **new_kwargs)
|
||||||
|
hints.append(c_skip)
|
||||||
|
return hints
|
||||||
|
|
||||||
|
|
||||||
def usp_dit_forward(
|
def usp_dit_forward(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
t,
|
t,
|
||||||
context,
|
context,
|
||||||
seq_len,
|
seq_len,
|
||||||
|
vace_context=None,
|
||||||
|
vace_context_scale=1.0,
|
||||||
clip_fea=None,
|
clip_fea=None,
|
||||||
y=None,
|
y=None,
|
||||||
):
|
):
|
||||||
@ -84,7 +117,7 @@ def usp_dit_forward(
|
|||||||
if self.freqs.device != device:
|
if self.freqs.device != device:
|
||||||
self.freqs = self.freqs.to(device)
|
self.freqs = self.freqs.to(device)
|
||||||
|
|
||||||
if y is not None:
|
if self.model_type != 'vace' and y is not None:
|
||||||
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
||||||
|
|
||||||
# embeddings
|
# embeddings
|
||||||
@ -114,7 +147,7 @@ def usp_dit_forward(
|
|||||||
for u in context
|
for u in context
|
||||||
]))
|
]))
|
||||||
|
|
||||||
if clip_fea is not None:
|
if self.model_type != 'vace' and clip_fea is not None:
|
||||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||||
context = torch.concat([context_clip, context], dim=1)
|
context = torch.concat([context_clip, context], dim=1)
|
||||||
|
|
||||||
@ -132,6 +165,11 @@ def usp_dit_forward(
|
|||||||
x, get_sequence_parallel_world_size(),
|
x, get_sequence_parallel_world_size(),
|
||||||
dim=1)[get_sequence_parallel_rank()]
|
dim=1)[get_sequence_parallel_rank()]
|
||||||
|
|
||||||
|
if self.model_type == 'vace':
|
||||||
|
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
|
||||||
|
kwargs['hints'] = hints
|
||||||
|
kwargs['context_scale'] = vace_context_scale
|
||||||
|
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x, **kwargs)
|
x = block(x, **kwargs)
|
||||||
|
|
||||||
|
|||||||
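usp_dit_forward_vace 中标注 "Context Parallel" 的一步,本质是把序列维按并行进程数切块,各进程只保留属于自己的那一段。下面是脱离 xfuser 的最小示意(world_size、rank 为假设变量;真实代码中由 get_sequence_parallel_world_size / get_sequence_parallel_rank 提供):

```python
import torch

# 假设:4 路序列并行,当前进程 rank 为 1
world_size, rank = 4, 1

# c: [batch=1, seq_len=32, dim=16] 的占位张量
c = torch.randn(1, 32, 16)

# 沿序列维(dim=1)等分为 world_size 份,本进程只保留第 rank 份
c_local = torch.chunk(c, world_size, dim=1)[rank]
print(c_local.shape)  # torch.Size([1, 8, 16])
```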
@@ -2,11 +2,13 @@ from .attention import flash_attention
 from .model import WanModel
 from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
 from .tokenizers import HuggingfaceTokenizer
+from .vace_model import VaceWanModel
 from .vae import WanVAE
 
 __all__ = [
     'WanVAE',
     'WanModel',
+    'VaceWanModel',
     'T5Model',
     'T5Encoder',
     'T5Decoder',
@@ -400,7 +400,7 @@ class WanModel(ModelMixin, ConfigMixin):
 
     Args:
         model_type (`str`, *optional*, defaults to 't2v'):
-            Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video)
+            Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
         patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
             3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
         text_len (`int`, *optional*, defaults to 512):
@@ -433,7 +433,7 @@ class WanModel(ModelMixin, ConfigMixin):
 
         super().__init__()
 
-        assert model_type in ['t2v', 'i2v', 'flf2v']
+        assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
         self.model_type = model_type
 
         self.patch_size = patch_size
wan/modules/vace_model.py(新增文件,共 233 行)
|
|||||||
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
|
import torch
|
||||||
|
import torch.cuda.amp as amp
|
||||||
|
import torch.nn as nn
|
||||||
|
from diffusers.configuration_utils import register_to_config
|
||||||
|
from .model import WanModel, WanAttentionBlock, sinusoidal_embedding_1d
|
||||||
|
|
||||||
|
|
||||||
|
class VaceWanAttentionBlock(WanAttentionBlock):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cross_attn_type,
|
||||||
|
dim,
|
||||||
|
ffn_dim,
|
||||||
|
num_heads,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
cross_attn_norm=False,
|
||||||
|
eps=1e-6,
|
||||||
|
block_id=0
|
||||||
|
):
|
||||||
|
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
|
||||||
|
self.block_id = block_id
|
||||||
|
if block_id == 0:
|
||||||
|
self.before_proj = nn.Linear(self.dim, self.dim)
|
||||||
|
nn.init.zeros_(self.before_proj.weight)
|
||||||
|
nn.init.zeros_(self.before_proj.bias)
|
||||||
|
self.after_proj = nn.Linear(self.dim, self.dim)
|
||||||
|
nn.init.zeros_(self.after_proj.weight)
|
||||||
|
nn.init.zeros_(self.after_proj.bias)
|
||||||
|
|
||||||
|
def forward(self, c, x, **kwargs):
|
||||||
|
if self.block_id == 0:
|
||||||
|
c = self.before_proj(c) + x
|
||||||
|
|
||||||
|
c = super().forward(c, **kwargs)
|
||||||
|
c_skip = self.after_proj(c)
|
||||||
|
return c, c_skip
|
||||||
|
|
||||||
|
|
||||||
|
class BaseWanAttentionBlock(WanAttentionBlock):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cross_attn_type,
|
||||||
|
dim,
|
||||||
|
ffn_dim,
|
||||||
|
num_heads,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
cross_attn_norm=False,
|
||||||
|
eps=1e-6,
|
||||||
|
block_id=None
|
||||||
|
):
|
||||||
|
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
|
||||||
|
self.block_id = block_id
|
||||||
|
|
||||||
|
def forward(self, x, hints, context_scale=1.0, **kwargs):
|
||||||
|
x = super().forward(x, **kwargs)
|
||||||
|
if self.block_id is not None:
|
||||||
|
x = x + hints[self.block_id] * context_scale
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VaceWanModel(WanModel):
|
||||||
|
@register_to_config
|
||||||
|
def __init__(self,
|
||||||
|
vace_layers=None,
|
||||||
|
vace_in_dim=None,
|
||||||
|
model_type='vace',
|
||||||
|
patch_size=(1, 2, 2),
|
||||||
|
text_len=512,
|
||||||
|
in_dim=16,
|
||||||
|
dim=2048,
|
||||||
|
ffn_dim=8192,
|
||||||
|
freq_dim=256,
|
||||||
|
text_dim=4096,
|
||||||
|
out_dim=16,
|
||||||
|
num_heads=16,
|
||||||
|
num_layers=32,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
cross_attn_norm=True,
|
||||||
|
eps=1e-6):
|
||||||
|
super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
|
||||||
|
num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
|
||||||
|
|
||||||
|
self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
|
||||||
|
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
|
||||||
|
|
||||||
|
assert 0 in self.vace_layers
|
||||||
|
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
|
||||||
|
|
||||||
|
# blocks
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
||||||
|
self.cross_attn_norm, self.eps,
|
||||||
|
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
|
||||||
|
for i in range(self.num_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
# vace blocks
|
||||||
|
self.vace_blocks = nn.ModuleList([
|
||||||
|
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
||||||
|
self.cross_attn_norm, self.eps, block_id=i)
|
||||||
|
for i in self.vace_layers
|
||||||
|
])
|
||||||
|
|
||||||
|
# vace patch embeddings
|
||||||
|
self.vace_patch_embedding = nn.Conv3d(
|
||||||
|
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_vace(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
vace_context,
|
||||||
|
seq_len,
|
||||||
|
kwargs
|
||||||
|
):
|
||||||
|
# embeddings
|
||||||
|
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||||
|
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||||
|
c = torch.cat([
|
||||||
|
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||||
|
dim=1) for u in c
|
||||||
|
])
|
||||||
|
|
||||||
|
# arguments
|
||||||
|
new_kwargs = dict(x=x)
|
||||||
|
new_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
hints = []
|
||||||
|
for block in self.vace_blocks:
|
||||||
|
c, c_skip = block(c, **new_kwargs)
|
||||||
|
hints.append(c_skip)
|
||||||
|
return hints
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
vace_context,
|
||||||
|
context,
|
||||||
|
seq_len,
|
||||||
|
vace_context_scale=1.0,
|
||||||
|
clip_fea=None,
|
||||||
|
y=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Forward pass through the diffusion model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (List[Tensor]):
|
||||||
|
List of input video tensors, each with shape [C_in, F, H, W]
|
||||||
|
t (Tensor):
|
||||||
|
Diffusion timesteps tensor of shape [B]
|
||||||
|
context (List[Tensor]):
|
||||||
|
List of text embeddings each with shape [L, C]
|
||||||
|
seq_len (`int`):
|
||||||
|
Maximum sequence length for positional encoding
|
||||||
|
clip_fea (Tensor, *optional*):
|
||||||
|
CLIP image features for image-to-video mode
|
||||||
|
y (List[Tensor], *optional*):
|
||||||
|
Conditional video inputs for image-to-video mode, same shape as x
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Tensor]:
|
||||||
|
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
|
||||||
|
"""
|
||||||
|
# if self.model_type == 'i2v':
|
||||||
|
# assert clip_fea is not None and y is not None
|
||||||
|
# params
|
||||||
|
device = self.patch_embedding.weight.device
|
||||||
|
if self.freqs.device != device:
|
||||||
|
self.freqs = self.freqs.to(device)
|
||||||
|
|
||||||
|
# if y is not None:
|
||||||
|
# x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
|
||||||
|
|
||||||
|
# embeddings
|
||||||
|
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
|
||||||
|
grid_sizes = torch.stack(
|
||||||
|
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
|
||||||
|
x = [u.flatten(2).transpose(1, 2) for u in x]
|
||||||
|
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
|
||||||
|
assert seq_lens.max() <= seq_len
|
||||||
|
x = torch.cat([
|
||||||
|
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
|
||||||
|
dim=1) for u in x
|
||||||
|
])
|
||||||
|
|
||||||
|
# time embeddings
|
||||||
|
with amp.autocast(dtype=torch.float32):
|
||||||
|
e = self.time_embedding(
|
||||||
|
sinusoidal_embedding_1d(self.freq_dim, t).float())
|
||||||
|
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
|
||||||
|
assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
||||||
|
|
||||||
|
# context
|
||||||
|
context_lens = None
|
||||||
|
context = self.text_embedding(
|
||||||
|
torch.stack([
|
||||||
|
torch.cat(
|
||||||
|
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
|
||||||
|
for u in context
|
||||||
|
]))
|
||||||
|
|
||||||
|
# if clip_fea is not None:
|
||||||
|
# context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||||
|
# context = torch.concat([context_clip, context], dim=1)
|
||||||
|
|
||||||
|
# arguments
|
||||||
|
kwargs = dict(
|
||||||
|
e=e0,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
grid_sizes=grid_sizes,
|
||||||
|
freqs=self.freqs,
|
||||||
|
context=context,
|
||||||
|
context_lens=context_lens)
|
||||||
|
|
||||||
|
hints = self.forward_vace(x, vace_context, seq_len, kwargs)
|
||||||
|
kwargs['hints'] = hints
|
||||||
|
kwargs['context_scale'] = vace_context_scale
|
||||||
|
|
||||||
|
for block in self.blocks:
|
||||||
|
x = block(x, **kwargs)
|
||||||
|
|
||||||
|
# head
|
||||||
|
x = self.head(x, e)
|
||||||
|
|
||||||
|
# unpatchify
|
||||||
|
x = self.unpatchify(x, grid_sizes)
|
||||||
|
return [u.float() for u in x]
|
||||||
@@ -1,8 +1,10 @@
 from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas,
                          retrieve_timesteps)
 from .fm_solvers_unipc import FlowUniPCMultistepScheduler
+from .vace_processor import VaceVideoProcessor
 
 __all__ = [
     'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
-    'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
+    'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
+    'VaceVideoProcessor'
 ]
wan/utils/vace_processor.py(新增文件,共 270 行)
|
|||||||
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchvision.transforms.functional as TF
|
||||||
|
|
||||||
|
|
||||||
|
class VaceImageProcessor(object):
|
||||||
|
def __init__(self, downsample=None, seq_len=None):
|
||||||
|
self.downsample = downsample
|
||||||
|
self.seq_len = seq_len
|
||||||
|
|
||||||
|
def _pillow_convert(self, image, cvt_type='RGB'):
|
||||||
|
if image.mode != cvt_type:
|
||||||
|
if image.mode == 'P':
|
||||||
|
image = image.convert(f'{cvt_type}A')
|
||||||
|
if image.mode == f'{cvt_type}A':
|
||||||
|
bg = Image.new(cvt_type,
|
||||||
|
size=(image.width, image.height),
|
||||||
|
color=(255, 255, 255))
|
||||||
|
bg.paste(image, (0, 0), mask=image)
|
||||||
|
image = bg
|
||||||
|
else:
|
||||||
|
image = image.convert(cvt_type)
|
||||||
|
return image
|
||||||
|
|
||||||
|
def _load_image(self, img_path):
|
||||||
|
if img_path is None or img_path == '':
|
||||||
|
return None
|
||||||
|
img = Image.open(img_path)
|
||||||
|
img = self._pillow_convert(img)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def _resize_crop(self, img, oh, ow, normalize=True):
|
||||||
|
"""
|
||||||
|
Resize, center crop, convert to tensor, and normalize.
|
||||||
|
"""
|
||||||
|
# resize and crop
|
||||||
|
iw, ih = img.size
|
||||||
|
if iw != ow or ih != oh:
|
||||||
|
# resize
|
||||||
|
scale = max(ow / iw, oh / ih)
|
||||||
|
img = img.resize(
|
||||||
|
(round(scale * iw), round(scale * ih)),
|
||||||
|
resample=Image.Resampling.LANCZOS
|
||||||
|
)
|
||||||
|
assert img.width >= ow and img.height >= oh
|
||||||
|
|
||||||
|
# center crop
|
||||||
|
x1 = (img.width - ow) // 2
|
||||||
|
y1 = (img.height - oh) // 2
|
||||||
|
img = img.crop((x1, y1, x1 + ow, y1 + oh))
|
||||||
|
|
||||||
|
# normalize
|
||||||
|
if normalize:
|
||||||
|
img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
|
||||||
|
return self._resize_crop(img, oh, ow, normalize)
|
||||||
|
|
||||||
|
def load_image(self, data_key, **kwargs):
|
||||||
|
return self.load_image_batch(data_key, **kwargs)
|
||||||
|
|
||||||
|
def load_image_pair(self, data_key, data_key2, **kwargs):
|
||||||
|
return self.load_image_batch(data_key, data_key2, **kwargs)
|
||||||
|
|
||||||
|
def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
|
||||||
|
seq_len = self.seq_len if seq_len is None else seq_len
|
||||||
|
imgs = []
|
||||||
|
for data_key in data_key_batch:
|
||||||
|
img = self._load_image(data_key)
|
||||||
|
imgs.append(img)
|
||||||
|
w, h = imgs[0].size
|
||||||
|
dh, dw = self.downsample[1:]
|
||||||
|
|
||||||
|
# compute output size
|
||||||
|
scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
|
||||||
|
oh = int(h * scale) // dh * dh
|
||||||
|
ow = int(w * scale) // dw * dw
|
||||||
|
assert (oh // dh) * (ow // dw) <= seq_len
|
||||||
|
imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
|
||||||
|
return *imgs, (oh, ow)
|
||||||
|
|
||||||
|
|
||||||
|
class VaceVideoProcessor(object):
|
||||||
|
def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
|
||||||
|
self.downsample = downsample
|
||||||
|
self.min_area = min_area
|
||||||
|
self.max_area = max_area
|
||||||
|
self.min_fps = min_fps
|
||||||
|
self.max_fps = max_fps
|
||||||
|
self.zero_start = zero_start
|
||||||
|
self.keep_last = keep_last
|
||||||
|
self.seq_len = seq_len
|
||||||
|
assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
|
||||||
|
|
||||||
|
def set_area(self, area):
|
||||||
|
self.min_area = area
|
||||||
|
self.max_area = area
|
||||||
|
|
||||||
|
def set_seq_len(self, seq_len):
|
||||||
|
self.seq_len = seq_len
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def resize_crop(video: torch.Tensor, oh: int, ow: int):
|
||||||
|
"""
|
||||||
|
Resize, center crop and normalize for decord loaded video (torch.Tensor type)
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
|
||||||
|
oh - target height (int)
|
||||||
|
ow - target width (int)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
"""
|
||||||
|
# permute ([t, h, w, c] -> [t, c, h, w])
|
||||||
|
video = video.permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
# resize and crop
|
||||||
|
ih, iw = video.shape[2:]
|
||||||
|
if ih != oh or iw != ow:
|
||||||
|
# resize
|
||||||
|
scale = max(ow / iw, oh / ih)
|
||||||
|
video = F.interpolate(
|
||||||
|
video,
|
||||||
|
size=(round(scale * ih), round(scale * iw)),
|
||||||
|
mode='bicubic',
|
||||||
|
antialias=True
|
||||||
|
)
|
||||||
|
assert video.size(3) >= ow and video.size(2) >= oh
|
||||||
|
|
||||||
|
# center crop
|
||||||
|
x1 = (video.size(3) - ow) // 2
|
||||||
|
y1 = (video.size(2) - oh) // 2
|
||||||
|
video = video[:, :, y1:y1 + oh, x1:x1 + ow]
|
||||||
|
|
||||||
|
# permute ([t, c, h, w] -> [c, t, h, w]) and normalize
|
||||||
|
video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
|
||||||
|
return video
|
||||||
|
|
||||||
|
def _video_preprocess(self, video, oh, ow):
|
||||||
|
return self.resize_crop(video, oh, ow)
|
||||||
|
|
||||||
|
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
|
||||||
|
target_fps = min(fps, self.max_fps)
|
||||||
|
duration = frame_timestamps[-1].mean()
|
||||||
|
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
||||||
|
h, w = y2 - y1, x2 - x1
|
||||||
|
ratio = h / w
|
||||||
|
df, dh, dw = self.downsample
|
||||||
|
|
||||||
|
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
|
||||||
|
of = min(
|
||||||
|
(int(duration * target_fps) - 1) // df + 1,
|
||||||
|
int(self.seq_len / area_z)
|
||||||
|
)
|
||||||
|
|
||||||
|
# deduce target shape of the [latent video]
|
||||||
|
target_area_z = min(area_z, int(self.seq_len / of))
|
||||||
|
oh = round(np.sqrt(target_area_z * ratio))
|
||||||
|
ow = int(target_area_z / oh)
|
||||||
|
of = (of - 1) * df + 1
|
||||||
|
oh *= dh
|
||||||
|
ow *= dw
|
||||||
|
|
||||||
|
# sample frame ids
|
||||||
|
target_duration = of / target_fps
|
||||||
|
begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
|
||||||
|
timestamps = np.linspace(begin, begin + target_duration, of)
|
||||||
|
frame_ids = np.argmax(np.logical_and(
|
||||||
|
timestamps[:, None] >= frame_timestamps[None, :, 0],
|
||||||
|
timestamps[:, None] < frame_timestamps[None, :, 1]
|
||||||
|
), axis=1).tolist()
|
||||||
|
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
||||||
|
|
||||||
|
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
|
||||||
|
duration = frame_timestamps[-1].mean()
|
||||||
|
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
||||||
|
h, w = y2 - y1, x2 - x1
|
||||||
|
ratio = h / w
|
||||||
|
df, dh, dw = self.downsample
|
||||||
|
|
||||||
|
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
|
||||||
|
of = min(
|
||||||
|
(len(frame_timestamps) - 1) // df + 1,
|
||||||
|
int(self.seq_len / area_z)
|
||||||
|
)
|
||||||
|
|
||||||
|
# deduce target shape of the [latent video]
|
||||||
|
target_area_z = min(area_z, int(self.seq_len / of))
|
||||||
|
oh = round(np.sqrt(target_area_z * ratio))
|
||||||
|
ow = int(target_area_z / oh)
|
||||||
|
of = (of - 1) * df + 1
|
||||||
|
oh *= dh
|
||||||
|
ow *= dw
|
||||||
|
|
||||||
|
# sample frame ids
|
||||||
|
target_duration = duration
|
||||||
|
target_fps = of / target_duration
|
||||||
|
timestamps = np.linspace(0., target_duration, of)
|
||||||
|
frame_ids = np.argmax(np.logical_and(
|
||||||
|
timestamps[:, None] >= frame_timestamps[None, :, 0],
|
||||||
|
timestamps[:, None] <= frame_timestamps[None, :, 1]
|
||||||
|
), axis=1).tolist()
|
||||||
|
# print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
|
||||||
|
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
||||||
|
|
||||||
|
|
||||||
|
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
|
||||||
|
if self.keep_last:
|
||||||
|
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
|
||||||
|
else:
|
||||||
|
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)
|
||||||
|
|
||||||
|
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
|
||||||
|
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
|
||||||
|
|
||||||
|
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
|
||||||
|
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
|
||||||
|
|
||||||
|
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
|
||||||
|
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
|
||||||
|
# read video
|
||||||
|
import decord
|
||||||
|
decord.bridge.set_bridge('torch')
|
||||||
|
readers = []
|
||||||
|
for data_k in data_key_batch:
|
||||||
|
reader = decord.VideoReader(data_k)
|
||||||
|
readers.append(reader)
|
||||||
|
|
||||||
|
fps = readers[0].get_avg_fps()
|
||||||
|
length = min([len(r) for r in readers])
|
||||||
|
frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
|
||||||
|
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
|
||||||
|
h, w = readers[0].next().shape[:2]
|
||||||
|
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)
|
||||||
|
|
||||||
|
# preprocess video
|
||||||
|
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
|
||||||
|
videos = [self._video_preprocess(video, oh, ow) for video in videos]
|
||||||
|
return *videos, frame_ids, (oh, ow), fps
|
||||||
|
# return videos if len(videos) > 1 else videos[0]
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
|
||||||
|
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
|
||||||
|
if sub_src_video is None and sub_src_mask is None:
|
||||||
|
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
|
||||||
|
src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
|
||||||
|
for i, ref_images in enumerate(src_ref_images):
|
||||||
|
if ref_images is not None:
|
||||||
|
for j, ref_img in enumerate(ref_images):
|
||||||
|
if ref_img is not None and ref_img.shape[-2:] != image_size:
|
||||||
|
canvas_height, canvas_width = image_size
|
||||||
|
ref_height, ref_width = ref_img.shape[-2:]
|
||||||
|
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
|
||||||
|
scale = min(canvas_height / ref_height, canvas_width / ref_width)
|
||||||
|
new_height = int(ref_height * scale)
|
||||||
|
new_width = int(ref_width * scale)
|
||||||
|
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
|
||||||
|
top = (canvas_height - new_height) // 2
|
||||||
|
left = (canvas_width - new_width) // 2
|
||||||
|
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
|
||||||
|
src_ref_images[i][j] = white_canvas
|
||||||
|
return src_video, src_mask, src_ref_images
|
||||||
wan/vace.py(新增文件,共 718 行)
|
|||||||
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import gc
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import types
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
import torchvision.transforms.functional as TF
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.cuda.amp as amp
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .text2video import (WanT2V, T5EncoderModel, WanVAE, shard_model, FlowDPMSolverMultistepScheduler,
|
||||||
|
get_sampling_sigmas, retrieve_timesteps, FlowUniPCMultistepScheduler)
|
||||||
|
from .modules.vace_model import VaceWanModel
|
||||||
|
from .utils.vace_processor import VaceVideoProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class WanVace(WanT2V):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config,
|
||||||
|
checkpoint_dir,
|
||||||
|
device_id=0,
|
||||||
|
rank=0,
|
||||||
|
t5_fsdp=False,
|
||||||
|
dit_fsdp=False,
|
||||||
|
use_usp=False,
|
||||||
|
t5_cpu=False,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Initializes the Wan text-to-video generation model components.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (EasyDict):
|
||||||
|
Object containing model parameters initialized from config.py
|
||||||
|
checkpoint_dir (`str`):
|
||||||
|
Path to directory containing model checkpoints
|
||||||
|
device_id (`int`, *optional*, defaults to 0):
|
||||||
|
Id of target GPU device
|
||||||
|
rank (`int`, *optional*, defaults to 0):
|
||||||
|
Process rank for distributed training
|
||||||
|
t5_fsdp (`bool`, *optional*, defaults to False):
|
||||||
|
Enable FSDP sharding for T5 model
|
||||||
|
dit_fsdp (`bool`, *optional*, defaults to False):
|
||||||
|
Enable FSDP sharding for DiT model
|
||||||
|
use_usp (`bool`, *optional*, defaults to False):
|
||||||
|
Enable distribution strategy of USP.
|
||||||
|
t5_cpu (`bool`, *optional*, defaults to False):
|
||||||
|
Whether to place T5 model on CPU. Only works without t5_fsdp.
|
||||||
|
"""
|
||||||
|
self.device = torch.device(f"cuda:{device_id}")
|
||||||
|
self.config = config
|
||||||
|
self.rank = rank
|
||||||
|
self.t5_cpu = t5_cpu
|
||||||
|
|
||||||
|
self.num_train_timesteps = config.num_train_timesteps
|
||||||
|
self.param_dtype = config.param_dtype
|
||||||
|
|
||||||
|
shard_fn = partial(shard_model, device_id=device_id)
|
||||||
|
self.text_encoder = T5EncoderModel(
|
||||||
|
text_len=config.text_len,
|
||||||
|
dtype=config.t5_dtype,
|
||||||
|
device=torch.device('cpu'),
|
||||||
|
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
|
||||||
|
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
||||||
|
shard_fn=shard_fn if t5_fsdp else None)
|
||||||
|
|
||||||
|
self.vae_stride = config.vae_stride
|
||||||
|
self.patch_size = config.patch_size
|
||||||
|
self.vae = WanVAE(
|
||||||
|
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
|
||||||
|
self.model = VaceWanModel.from_pretrained(checkpoint_dir)
|
||||||
|
self.model.eval().requires_grad_(False)
|
||||||
|
|
||||||
|
if use_usp:
|
||||||
|
from xfuser.core.distributed import \
|
||||||
|
get_sequence_parallel_world_size
|
||||||
|
|
||||||
|
from .distributed.xdit_context_parallel import (usp_attn_forward,
|
||||||
|
usp_dit_forward,
|
||||||
|
usp_dit_forward_vace)
|
||||||
|
for block in self.model.blocks:
|
||||||
|
block.self_attn.forward = types.MethodType(
|
||||||
|
usp_attn_forward, block.self_attn)
|
||||||
|
for block in self.model.vace_blocks:
|
||||||
|
block.self_attn.forward = types.MethodType(
|
||||||
|
usp_attn_forward, block.self_attn)
|
||||||
|
self.model.forward = types.MethodType(usp_dit_forward, self.model)
|
||||||
|
self.model.forward_vace = types.MethodType(usp_dit_forward_vace, self.model)
|
||||||
|
self.sp_size = get_sequence_parallel_world_size()
|
||||||
|
else:
|
||||||
|
self.sp_size = 1
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
dist.barrier()
|
||||||
|
if dit_fsdp:
|
||||||
|
self.model = shard_fn(self.model)
|
||||||
|
else:
|
||||||
|
self.model.to(self.device)
|
||||||
|
|
||||||
|
self.sample_neg_prompt = config.sample_neg_prompt
|
||||||
|
|
||||||
|
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
|
||||||
|
min_area=720*1280,
|
||||||
|
max_area=720*1280,
|
||||||
|
min_fps=config.sample_fps,
|
||||||
|
max_fps=config.sample_fps,
|
||||||
|
zero_start=True,
|
||||||
|
seq_len=75600,
|
||||||
|
keep_last=True)
|
||||||
|
|
||||||
|
def vace_encode_frames(self, frames, ref_images, masks=None, vae=None):
|
||||||
|
vae = self.vae if vae is None else vae
|
||||||
|
if ref_images is None:
|
||||||
|
ref_images = [None] * len(frames)
|
||||||
|
else:
|
||||||
|
assert len(frames) == len(ref_images)
|
||||||
|
|
||||||
|
if masks is None:
|
||||||
|
latents = vae.encode(frames)
|
||||||
|
else:
|
||||||
|
masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
|
||||||
|
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
|
||||||
|
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
|
||||||
|
inactive = vae.encode(inactive)
|
||||||
|
reactive = vae.encode(reactive)
|
||||||
|
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
|
||||||
|
|
||||||
|
cat_latents = []
|
||||||
|
for latent, refs in zip(latents, ref_images):
|
||||||
|
if refs is not None:
|
||||||
|
if masks is None:
|
||||||
|
ref_latent = vae.encode(refs)
|
||||||
|
else:
|
||||||
|
ref_latent = vae.encode(refs)
|
||||||
|
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
|
||||||
|
assert all([x.shape[1] == 1 for x in ref_latent])
|
||||||
|
latent = torch.cat([*ref_latent, latent], dim=1)
|
||||||
|
cat_latents.append(latent)
|
||||||
|
return cat_latents
|
||||||
|
|
||||||
|
def vace_encode_masks(self, masks, ref_images=None, vae_stride=None):
|
||||||
|
vae_stride = self.vae_stride if vae_stride is None else vae_stride
|
||||||
|
if ref_images is None:
|
||||||
|
ref_images = [None] * len(masks)
|
||||||
|
else:
|
||||||
|
assert len(masks) == len(ref_images)
|
||||||
|
|
||||||
|
result_masks = []
|
||||||
|
for mask, refs in zip(masks, ref_images):
|
||||||
|
c, depth, height, width = mask.shape
|
||||||
|
new_depth = int((depth + 3) // vae_stride[0])
|
||||||
|
height = 2 * (int(height) // (vae_stride[1] * 2))
|
||||||
|
width = 2 * (int(width) // (vae_stride[2] * 2))
|
||||||
|
|
||||||
|
# reshape
|
||||||
|
mask = mask[0, :, :, :]
|
||||||
|
mask = mask.view(
|
||||||
|
depth, height, vae_stride[1], width, vae_stride[1]
|
||||||
|
) # depth, height, 8, width, 8
|
||||||
|
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
|
||||||
|
mask = mask.reshape(
|
||||||
|
vae_stride[1] * vae_stride[2], depth, height, width
|
||||||
|
) # 8*8, depth, height, width
|
||||||
|
|
||||||
|
# interpolation
|
||||||
|
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
|
||||||
|
|
||||||
|
if refs is not None:
|
||||||
|
length = len(refs)
|
||||||
|
mask_pad = torch.zeros_like(mask[:, :length, :, :])
|
||||||
|
mask = torch.cat((mask_pad, mask), dim=1)
|
||||||
|
result_masks.append(mask)
|
||||||
|
return result_masks
|
||||||
|
|
||||||
|
def vace_latent(self, z, m):
|
||||||
|
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
|
||||||
|
|
||||||
|
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device):
|
||||||
|
area = image_size[0] * image_size[1]
|
||||||
|
self.vid_proc.set_area(area)
|
||||||
|
if area == 720*1280:
|
||||||
|
self.vid_proc.set_seq_len(75600)
|
||||||
|
elif area == 480*832:
|
||||||
|
self.vid_proc.set_seq_len(32760)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'image_size {image_size} is not supported')
|
||||||
|
|
||||||
|
image_size = (image_size[1], image_size[0])
|
||||||
|
image_sizes = []
|
||||||
|
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
|
||||||
|
if sub_src_mask is not None and sub_src_video is not None:
|
||||||
|
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask)
|
||||||
|
src_video[i] = src_video[i].to(device)
|
||||||
|
src_mask[i] = src_mask[i].to(device)
|
||||||
|
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
|
||||||
|
image_sizes.append(src_video[i].shape[2:])
|
||||||
|
elif sub_src_video is None:
|
||||||
|
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
|
||||||
|
src_mask[i] = torch.ones_like(src_video[i], device=device)
|
||||||
|
image_sizes.append(image_size)
|
||||||
|
else:
|
||||||
|
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video)
|
||||||
|
src_video[i] = src_video[i].to(device)
|
||||||
|
src_mask[i] = torch.ones_like(src_video[i], device=device)
|
||||||
|
image_sizes.append(src_video[i].shape[2:])
|
||||||
|
|
||||||
|
for i, ref_images in enumerate(src_ref_images):
|
||||||
|
if ref_images is not None:
|
||||||
|
image_size = image_sizes[i]
|
||||||
|
for j, ref_img in enumerate(ref_images):
|
||||||
|
if ref_img is not None:
|
||||||
|
ref_img = Image.open(ref_img).convert("RGB")
|
||||||
|
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
|
||||||
|
if ref_img.shape[-2:] != image_size:
|
||||||
|
canvas_height, canvas_width = image_size
|
||||||
|
ref_height, ref_width = ref_img.shape[-2:]
|
||||||
|
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
|
||||||
|
scale = min(canvas_height / ref_height, canvas_width / ref_width)
|
||||||
|
new_height = int(ref_height * scale)
|
||||||
|
new_width = int(ref_width * scale)
|
||||||
|
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
|
||||||
|
top = (canvas_height - new_height) // 2
|
||||||
|
left = (canvas_width - new_width) // 2
|
||||||
|
white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
|
||||||
|
ref_img = white_canvas
|
||||||
|
src_ref_images[i][j] = ref_img.to(device)
|
||||||
|
return src_video, src_mask, src_ref_images
|
||||||
|
|
||||||
|
def decode_latent(self, zs, ref_images=None, vae=None):
|
||||||
|
vae = self.vae if vae is None else vae
|
||||||
|
if ref_images is None:
|
||||||
|
ref_images = [None] * len(zs)
|
||||||
|
else:
|
||||||
|
assert len(zs) == len(ref_images)
|
||||||
|
|
||||||
|
trimed_zs = []
|
||||||
|
for z, refs in zip(zs, ref_images):
|
||||||
|
if refs is not None:
|
||||||
|
z = z[:, len(refs):, :, :]
|
||||||
|
trimed_zs.append(z)
|
||||||
|
|
||||||
|
return vae.decode(trimed_zs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def generate(self,
|
||||||
|
input_prompt,
|
||||||
|
input_frames,
|
||||||
|
input_masks,
|
||||||
|
input_ref_images,
|
||||||
|
size=(1280, 720),
|
||||||
|
frame_num=81,
|
||||||
|
context_scale=1.0,
|
||||||
|
shift=5.0,
|
||||||
|
sample_solver='unipc',
|
||||||
|
sampling_steps=50,
|
||||||
|
guide_scale=5.0,
|
||||||
|
n_prompt="",
|
||||||
|
seed=-1,
|
||||||
|
offload_model=True):
|
||||||
|
r"""
|
||||||
|
Generates video frames from text prompt using diffusion process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_prompt (`str`):
|
||||||
|
Text prompt for content generation
|
||||||
|
size (tupele[`int`], *optional*, defaults to (1280,720)):
|
||||||
|
Controls video resolution, (width,height).
|
||||||
|
frame_num (`int`, *optional*, defaults to 81):
|
||||||
|
How many frames to sample from a video. The number should be 4n+1
|
||||||
|
shift (`float`, *optional*, defaults to 5.0):
|
||||||
|
Noise schedule shift parameter. Affects temporal dynamics
|
||||||
|
sample_solver (`str`, *optional*, defaults to 'unipc'):
|
||||||
|
Solver used to sample the video.
|
||||||
|
sampling_steps (`int`, *optional*, defaults to 40):
|
||||||
|
Number of diffusion sampling steps. Higher values improve quality but slow generation
|
||||||
|
guide_scale (`float`, *optional*, defaults 5.0):
|
||||||
|
Classifier-free guidance scale. Controls prompt adherence vs. creativity
|
||||||
|
n_prompt (`str`, *optional*, defaults to ""):
|
||||||
|
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
|
||||||
|
seed (`int`, *optional*, defaults to -1):
|
||||||
|
Random seed for noise generation. If -1, use random seed.
|
||||||
|
offload_model (`bool`, *optional*, defaults to True):
|
||||||
|
If True, offloads models to CPU during generation to save VRAM
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor:
|
||||||
|
Generated video frames tensor. Dimensions: (C, N H, W) where:
|
||||||
|
- C: Color channels (3 for RGB)
|
||||||
|
- N: Number of frames (81)
|
||||||
|
- H: Frame height (from size)
|
||||||
|
- W: Frame width from size)
|
||||||
|
"""
|
||||||
|
# preprocess
|
||||||
|
# F = frame_num
|
||||||
|
# target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
||||||
|
# size[1] // self.vae_stride[1],
|
||||||
|
# size[0] // self.vae_stride[2])
|
||||||
|
#
|
||||||
|
# seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
||||||
|
# (self.patch_size[1] * self.patch_size[2]) *
|
||||||
|
# target_shape[1] / self.sp_size) * self.sp_size
|
||||||
|
|
||||||
|
if n_prompt == "":
|
||||||
|
n_prompt = self.sample_neg_prompt
|
||||||
|
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
||||||
|
seed_g = torch.Generator(device=self.device)
|
||||||
|
seed_g.manual_seed(seed)
|
||||||
|
|
||||||
|
if not self.t5_cpu:
|
||||||
|
self.text_encoder.model.to(self.device)
|
||||||
|
context = self.text_encoder([input_prompt], self.device)
|
||||||
|
context_null = self.text_encoder([n_prompt], self.device)
|
||||||
|
if offload_model:
|
||||||
|
self.text_encoder.model.cpu()
|
||||||
|
else:
|
||||||
|
context = self.text_encoder([input_prompt], torch.device('cpu'))
|
||||||
|
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
||||||
|
context = [t.to(self.device) for t in context]
|
||||||
|
context_null = [t.to(self.device) for t in context_null]
|
||||||
|
|
||||||
|
# vace context encode
|
||||||
|
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks)
|
||||||
|
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
||||||
|
z = self.vace_latent(z0, m0)
|
||||||
|
|
||||||
|
target_shape = list(z0[0].shape)
|
||||||
|
target_shape[0] = int(target_shape[0] / 2)
|
||||||
|
noise = [
|
||||||
|
torch.randn(
|
||||||
|
target_shape[0],
|
||||||
|
target_shape[1],
|
||||||
|
target_shape[2],
|
||||||
|
target_shape[3],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
generator=seed_g)
|
||||||
|
]
|
||||||
|
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
||||||
|
(self.patch_size[1] * self.patch_size[2]) *
|
||||||
|
target_shape[1] / self.sp_size) * self.sp_size
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def noop_no_sync():
|
||||||
|
yield
|
||||||
|
|
||||||
|
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
|
||||||
|
|
||||||
|
# evaluation mode
|
||||||
|
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
|
||||||
|
|
||||||
|
            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sample_scheduler.set_timesteps(
                    sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas)
            else:
                raise NotImplementedError("Unsupported solver.")

            # sample videos
            latents = noise

            arg_c = {'context': context, 'seq_len': seq_len}
            arg_null = {'context': context_null, 'seq_len': seq_len}
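
            # Denoising loop: one conditional and one unconditional pass per step,
            # combined via classifier-free guidance with `guide_scale`.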
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = latents
                timestep = [t]

                timestep = torch.stack(timestep)

                self.model.to(self.device)
                noise_pred_cond = self.model(
                    latent_model_input, t=timestep, vace_context=z,
                    vace_context_scale=context_scale, **arg_c)[0]
                noise_pred_uncond = self.model(
                    latent_model_input, t=timestep, vace_context=z,
                    vace_context_scale=context_scale, **arg_null)[0]

                noise_pred = noise_pred_uncond + guide_scale * (
                    noise_pred_cond - noise_pred_uncond)
                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latents[0].unsqueeze(0),
                    return_dict=False,
                    generator=seed_g)[0]
                latents = [temp_x0.squeeze(0)]

            x0 = latents
            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()
            if self.rank == 0:
                videos = self.decode_latent(x0, input_ref_images)

        del noise, latents
        del sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        return videos[0] if self.rank == 0 else None
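
# Multi-GPU wrapper around WanVace: the main process spawns one persistent
# worker per GPU and exchanges requests/results through multiprocessing queues,
# so the heavy models are loaded once and reused for every generation call.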
class WanVaceMP(WanVace):

    def __init__(
        self,
        config,
        checkpoint_dir,
        use_usp=False,
        ulysses_size=None,
        ring_size=None
    ):
        self.config = config
        self.checkpoint_dir = checkpoint_dir
        self.use_usp = use_usp
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12345'
        os.environ['RANK'] = '0'
        os.environ['WORLD_SIZE'] = '1'
        self.in_q_list = None
        self.out_q = None
        self.inference_pids = None
        self.ulysses_size = ulysses_size
        self.ring_size = ring_size
        self.dynamic_load()

        # Both branches of the original conditional resolved to 'cpu'; the main
        # process only stages inputs, GPU inference runs in the spawned workers.
        self.device = 'cpu'
        self.vid_proc = VaceVideoProcessor(
            downsample=tuple(
                [x * y for x, y in zip(config.vae_stride, config.patch_size)]),
            min_area=480 * 832,
            max_area=480 * 832,
            min_fps=self.config.sample_fps,
            max_fps=self.config.sample_fps,
            zero_start=True,
            seq_len=32760,
            keep_last=True)
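
    # Spawn one persistent inference worker per GPU and block until every
    # worker signals that its models are fully loaded.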
    def dynamic_load(self):
        if hasattr(self, 'inference_pids') and self.inference_pids is not None:
            return
        # LOCAL_WORLD_SIZE arrives as a string; cast so nprocs/range get an int.
        gpu_infer = int(
            os.environ.get('LOCAL_WORLD_SIZE') or torch.cuda.device_count())
        pmi_rank = int(os.environ['RANK'])
        pmi_world_size = int(os.environ['WORLD_SIZE'])
        in_q_list = [
            torch.multiprocessing.Manager().Queue() for _ in range(gpu_infer)
        ]
        out_q = torch.multiprocessing.Manager().Queue()
        initialized_events = [
            torch.multiprocessing.Manager().Event() for _ in range(gpu_infer)
        ]
        context = mp.spawn(
            self.mp_worker,
            nprocs=gpu_infer,
            args=(gpu_infer, pmi_rank, pmi_world_size, in_q_list, out_q,
                  initialized_events, self),
            join=False)
        all_initialized = False
        while not all_initialized:
            all_initialized = all(
                event.is_set() for event in initialized_events)
            if not all_initialized:
                time.sleep(0.1)
        print('Inference model is initialized', flush=True)
        self.in_q_list = in_q_list
        self.out_q = out_q
        self.inference_pids = context.pids()
        self.initialized_events = initialized_events

    def transfer_data_to_cuda(self, data, device):
        if data is None:
            return None
        else:
            if isinstance(data, torch.Tensor):
                data = data.to(device)
            elif isinstance(data, list):
                data = [
                    self.transfer_data_to_cuda(subdata, device)
                    for subdata in data
                ]
            elif isinstance(data, dict):
                data = {
                    key: self.transfer_data_to_cuda(val, device)
                    for key, val in data.items()
                }
            return data
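
    # Worker entry point: set up distributed/sequence-parallel state, load the
    # T5 encoder, VAE and VACE DiT on this GPU, then serve requests from the
    # input queue until the process is terminated.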
    def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list,
                  out_q, initialized_events, work_env):
        try:
            world_size = pmi_world_size * gpu_infer
            rank = pmi_rank * gpu_infer + gpu
            print("world_size", world_size, "rank", rank, flush=True)

            torch.cuda.set_device(gpu)
            dist.init_process_group(
                backend='nccl',
                init_method='env://',
                rank=rank,
                world_size=world_size)

            from xfuser.core.distributed import (initialize_model_parallel,
                                                 init_distributed_environment)
            init_distributed_environment(
                rank=dist.get_rank(), world_size=dist.get_world_size())

            initialize_model_parallel(
                sequence_parallel_degree=dist.get_world_size(),
                ring_degree=self.ring_size or 1,
                ulysses_degree=self.ulysses_size or 1)

            num_train_timesteps = self.config.num_train_timesteps
            param_dtype = self.config.param_dtype
            shard_fn = partial(shard_model, device_id=gpu)
            text_encoder = T5EncoderModel(
                text_len=self.config.text_len,
                dtype=self.config.t5_dtype,
                device=torch.device('cpu'),
                checkpoint_path=os.path.join(self.checkpoint_dir,
                                             self.config.t5_checkpoint),
                tokenizer_path=os.path.join(self.checkpoint_dir,
                                            self.config.t5_tokenizer),
                shard_fn=shard_fn)
            text_encoder.model.to(gpu)
            vae_stride = self.config.vae_stride
            patch_size = self.config.patch_size
            vae = WanVAE(
                vae_pth=os.path.join(self.checkpoint_dir,
                                     self.config.vae_checkpoint),
                device=gpu)
            logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
            model = VaceWanModel.from_pretrained(self.checkpoint_dir)
            model.eval().requires_grad_(False)
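
            # Optional Ulysses/ring sequence parallelism (USP): patch attention
            # and forward methods so the DiT shards the sequence across GPUs.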
            if self.use_usp:
                from xfuser.core.distributed import get_sequence_parallel_world_size

                from .distributed.xdit_context_parallel import (
                    usp_attn_forward, usp_dit_forward, usp_dit_forward_vace)
                for block in model.blocks:
                    block.self_attn.forward = types.MethodType(
                        usp_attn_forward, block.self_attn)
                for block in model.vace_blocks:
                    block.self_attn.forward = types.MethodType(
                        usp_attn_forward, block.self_attn)
                model.forward = types.MethodType(usp_dit_forward, model)
                model.forward_vace = types.MethodType(usp_dit_forward_vace, model)
                sp_size = get_sequence_parallel_world_size()
            else:
                sp_size = 1

            dist.barrier()
            model = shard_fn(model)
            sample_neg_prompt = self.config.sample_neg_prompt

            torch.cuda.empty_cache()
            event = initialized_events[gpu]
            in_q = in_q_list[gpu]
            event.set()
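
            # Serve requests indefinitely; each queue item is the argument tuple
            # assembled by WanVaceMP.generate below.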
            while True:
                item = in_q.get()
                input_prompt, input_frames, input_masks, input_ref_images, size, \
                    frame_num, context_scale, shift, sample_solver, sampling_steps, \
                    guide_scale, n_prompt, seed, offload_model = item
                input_frames = self.transfer_data_to_cuda(input_frames, gpu)
                input_masks = self.transfer_data_to_cuda(input_masks, gpu)
                input_ref_images = self.transfer_data_to_cuda(input_ref_images, gpu)

                if n_prompt == "":
                    n_prompt = sample_neg_prompt
                seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
                seed_g = torch.Generator(device=gpu)
                seed_g.manual_seed(seed)

                context = text_encoder([input_prompt], gpu)
                context_null = text_encoder([n_prompt], gpu)

                # vace context encode
                z0 = self.vace_encode_frames(
                    input_frames, input_ref_images, masks=input_masks, vae=vae)
                m0 = self.vace_encode_masks(
                    input_masks, input_ref_images, vae_stride=vae_stride)
                z = self.vace_latent(z0, m0)

                target_shape = list(z0[0].shape)
                target_shape[0] = int(target_shape[0] / 2)
                noise = [
                    torch.randn(
                        target_shape[0],
                        target_shape[1],
                        target_shape[2],
                        target_shape[3],
                        dtype=torch.float32,
                        device=gpu,
                        generator=seed_g)
                ]
                seq_len = math.ceil((target_shape[2] * target_shape[3]) /
                                    (patch_size[1] * patch_size[2]) *
                                    target_shape[1] / sp_size) * sp_size

                @contextmanager
                def noop_no_sync():
                    yield

                no_sync = getattr(model, 'no_sync', noop_no_sync)

                # evaluation mode
                with amp.autocast(dtype=param_dtype), torch.no_grad(), no_sync():
                    if sample_solver == 'unipc':
                        sample_scheduler = FlowUniPCMultistepScheduler(
                            num_train_timesteps=num_train_timesteps,
                            shift=1,
                            use_dynamic_shifting=False)
                        sample_scheduler.set_timesteps(
                            sampling_steps, device=gpu, shift=shift)
                        timesteps = sample_scheduler.timesteps
                    elif sample_solver == 'dpm++':
                        sample_scheduler = FlowDPMSolverMultistepScheduler(
                            num_train_timesteps=num_train_timesteps,
                            shift=1,
                            use_dynamic_shifting=False)
                        sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                        timesteps, _ = retrieve_timesteps(
                            sample_scheduler,
                            device=gpu,
                            sigmas=sampling_sigmas)
                    else:
                        raise NotImplementedError("Unsupported solver.")

                    # sample videos
                    latents = noise

                    arg_c = {'context': context, 'seq_len': seq_len}
                    arg_null = {'context': context_null, 'seq_len': seq_len}
                    for _, t in enumerate(tqdm(timesteps)):
                        latent_model_input = latents
                        timestep = [t]

                        timestep = torch.stack(timestep)

                        model.to(gpu)
                        noise_pred_cond = model(
                            latent_model_input, t=timestep, vace_context=z,
                            vace_context_scale=context_scale, **arg_c)[0]
                        noise_pred_uncond = model(
                            latent_model_input, t=timestep, vace_context=z,
                            vace_context_scale=context_scale, **arg_null)[0]

                        noise_pred = noise_pred_uncond + guide_scale * (
                            noise_pred_cond - noise_pred_uncond)
                        temp_x0 = sample_scheduler.step(
                            noise_pred.unsqueeze(0),
                            t,
                            latents[0].unsqueeze(0),
                            return_dict=False,
                            generator=seed_g)[0]
                        latents = [temp_x0.squeeze(0)]

                    torch.cuda.empty_cache()
                    x0 = latents
                    if rank == 0:
                        videos = self.decode_latent(x0, input_ref_images, vae=vae)

                del noise, latents
                del sample_scheduler
                if offload_model:
                    gc.collect()
                    torch.cuda.synchronize()
                if dist.is_initialized():
                    dist.barrier()

                if rank == 0:
                    out_q.put(videos[0].cpu())

        except Exception as e:
            trace_info = traceback.format_exc()
            print(trace_info, flush=True)
            print(e, flush=True)
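
    # Broadcast the request to every worker queue and block until the rank-0
    # worker returns the decoded video on the shared output queue.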
    def generate(self,
                 input_prompt,
                 input_frames,
                 input_masks,
                 input_ref_images,
                 size=(1280, 720),
                 frame_num=81,
                 context_scale=1.0,
                 shift=5.0,
                 sample_solver='unipc',
                 sampling_steps=50,
                 guide_scale=5.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):

        input_data = (input_prompt, input_frames, input_masks, input_ref_images,
                      size, frame_num, context_scale, shift, sample_solver,
                      sampling_steps, guide_scale, n_prompt, seed, offload_model)
        for in_q in self.in_q_list:
            in_q.put(input_data)
        value_output = self.out_q.get()

        return value_output