mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	Merge c6c5675a06 into 7c81b2f27d
				
					
				
			This commit is contained in:
		
						commit
						9614f70259
					
				
							
								
								
									
										151
									
								
								I2V-FastAPI文档.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										151
									
								
								I2V-FastAPI文档.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,151 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					# 图像到视频生成服务API文档
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 一、功能概述
 | 
				
			||||||
 | 
					基于Wan2.1-I2V-14B-480P模型实现图像到视频生成,核心功能包括:
 | 
				
			||||||
 | 
					1. **异步任务队列**:支持多任务排队和并发控制(最大2个并行任务)
 | 
				
			||||||
 | 
					2. **智能分辨率适配**:
 | 
				
			||||||
 | 
					   - 支持自动计算最佳分辨率(保持原图比例)
 | 
				
			||||||
 | 
					   - 支持手动指定分辨率(480x832/832x480)
 | 
				
			||||||
 | 
					3. **资源管理**:
 | 
				
			||||||
 | 
					   - 显存优化(bfloat16精度)
 | 
				
			||||||
 | 
					   - 生成文件自动清理(默认1小时)
 | 
				
			||||||
 | 
					4. **安全认证**:基于API Key的Bearer Token验证
 | 
				
			||||||
 | 
					5. **任务控制**:支持任务提交/状态查询/取消操作
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					技术栈:
 | 
				
			||||||
 | 
					- FastAPI框架
 | 
				
			||||||
 | 
					- CUDA加速
 | 
				
			||||||
 | 
					- 异步任务处理
 | 
				
			||||||
 | 
					- Diffusers推理库
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 二、接口说明
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 1. 提交生成任务
 | 
				
			||||||
 | 
					**POST /video/submit**
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "model": "Wan2.1-I2V-14B-480P",
 | 
				
			||||||
 | 
					  "prompt": "A dancing cat in the style of Van Gogh",
 | 
				
			||||||
 | 
					  "image_url": "https://example.com/input.jpg",
 | 
				
			||||||
 | 
					  "image_size": "auto",
 | 
				
			||||||
 | 
					  "num_frames": 81,
 | 
				
			||||||
 | 
					  "guidance_scale": 3.0,
 | 
				
			||||||
 | 
					  "infer_steps": 30
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**响应示例**:
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 2. 查询任务状态
 | 
				
			||||||
 | 
					**POST /video/status**
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**响应示例**:
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "status": "Succeed",
 | 
				
			||||||
 | 
					  "results": {
 | 
				
			||||||
 | 
					    "videos": [{"url": "http://localhost:8088/videos/abcd1234.mp4"}],
 | 
				
			||||||
 | 
					    "timings": {"inference": 90},
 | 
				
			||||||
 | 
					    "seed": 123456
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 3. 取消任务
 | 
				
			||||||
 | 
					**POST /video/cancel**
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**响应示例**:
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "status": "Succeed"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 三、Postman使用指南
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 1. 基础配置
 | 
				
			||||||
 | 
					- 服务器地址:`http://ip地址:8088`
 | 
				
			||||||
 | 
					- 认证方式:Bearer Token
 | 
				
			||||||
 | 
					- Token值:需替换为有效API Key
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 2. 提交任务
 | 
				
			||||||
 | 
					1. 选择POST方法,URL填写`/video/submit`
 | 
				
			||||||
 | 
					2. Headers添加:
 | 
				
			||||||
 | 
					   ```text
 | 
				
			||||||
 | 
					   Authorization: Bearer YOUR_API_KEY
 | 
				
			||||||
 | 
					   Content-Type: application/json
 | 
				
			||||||
 | 
					   ```
 | 
				
			||||||
 | 
					3. Body示例(图像生成视频):
 | 
				
			||||||
 | 
					   ```json
 | 
				
			||||||
 | 
					   {
 | 
				
			||||||
 | 
					     "prompt": "Sunset scene with mountains",
 | 
				
			||||||
 | 
					     "image_url": "https://example.com/mountain.jpg",
 | 
				
			||||||
 | 
					     "image_size": "auto",
 | 
				
			||||||
 | 
					     "num_frames": 50
 | 
				
			||||||
 | 
					   }
 | 
				
			||||||
 | 
					   ```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 3. 特殊处理
 | 
				
			||||||
 | 
					- **图像下载失败**:返回400错误,包含具体原因(如URL无效/超时)
 | 
				
			||||||
 | 
					- **显存不足**:返回500错误并提示降低分辨率
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 四、参数规范
 | 
				
			||||||
 | 
					| 参数名           | 允许值范围                     | 必填 | 说明                                      |
 | 
				
			||||||
 | 
					|------------------|-------------------------------|------|------------------------------------------|
 | 
				
			||||||
 | 
					| image_url        | 有效HTTP/HTTPS URL            | 是   | 输入图像地址                              |
 | 
				
			||||||
 | 
					| prompt           | 10-500字符                    | 是   | 视频内容描述                              |
 | 
				
			||||||
 | 
					| image_size       | "480x832", "832x480", "auto"  | 是   | auto模式自动适配原图比例                  |
 | 
				
			||||||
 | 
					| num_frames       | 24-120                        | 是   | 视频总帧数                                |
 | 
				
			||||||
 | 
					| guidance_scale   | 1.0-20.0                      | 是   | 文本引导强度                              |
 | 
				
			||||||
 | 
					| infer_steps      | 20-100                        | 是   | 推理步数                                  |
 | 
				
			||||||
 | 
					| seed             | 0-2147483647                  | 否   | 随机种子                                  |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 五、状态码说明
 | 
				
			||||||
 | 
					| 状态码 | 含义                               |
 | 
				
			||||||
 | 
					|--------|-----------------------------------|
 | 
				
			||||||
 | 
					| 202    | 任务已接受                         |
 | 
				
			||||||
 | 
					| 400    | 图像下载失败/参数错误              |
 | 
				
			||||||
 | 
					| 401    | 认证失败                           |
 | 
				
			||||||
 | 
					| 404    | 任务不存在                         |
 | 
				
			||||||
 | 
					| 422    | 参数校验失败                       |
 | 
				
			||||||
 | 
					| 500    | 服务端错误(显存不足/模型异常等)  |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 六、特殊功能说明
 | 
				
			||||||
 | 
					1. **智能分辨率适配**:
 | 
				
			||||||
 | 
					   - 当`image_size="auto"`时,自动计算符合模型要求的最优分辨率
 | 
				
			||||||
 | 
					   - 保持原始图像宽高比,最大像素面积不超过399,360(约640x624)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					2. **图像预处理**:
 | 
				
			||||||
 | 
					   - 自动转换为RGB模式
 | 
				
			||||||
 | 
					   - 根据目标分辨率进行等比缩放
 | 
				
			||||||
 | 
					   
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**重要提示**:输入图像URL需保证公开可访问,私有资源需提供有效鉴权
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**提示** :访问`http://服务器地址:8088/docs`可查看交互式API文档,支持在线测试所有接口
 | 
				
			||||||
							
								
								
									
										133
									
								
								T2V-FastAPI文档.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								T2V-FastAPI文档.md
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,133 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					# 视频生成服务API文档
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 一、功能概述
 | 
				
			||||||
 | 
					本服务基于Wan2.1-T2V-1.3B模型实现文本到视频生成,包含以下核心功能:
 | 
				
			||||||
 | 
					1. **异步任务队列**:支持多任务排队和并发控制(最大2个并行任务)
 | 
				
			||||||
 | 
					2. **资源管理**:
 | 
				
			||||||
 | 
					   - 显存优化(使用bfloat16精度)
 | 
				
			||||||
 | 
					   - 生成视频自动清理(默认1小时后删除)
 | 
				
			||||||
 | 
					3. **安全认证**:基于API Key的Bearer Token验证
 | 
				
			||||||
 | 
					4. **任务控制**:支持任务提交/状态查询/取消操作
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					技术栈:
 | 
				
			||||||
 | 
					- FastAPI框架
 | 
				
			||||||
 | 
					- CUDA加速
 | 
				
			||||||
 | 
					- 异步任务处理
 | 
				
			||||||
 | 
					- Diffusers推理库
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 二、接口说明
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 1. 提交生成任务
 | 
				
			||||||
 | 
					**POST /video/submit**
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "model": "Wan2.1-T2V-1.3B",
 | 
				
			||||||
 | 
					  "prompt": "A beautiful sunset over the mountains",
 | 
				
			||||||
 | 
					  "image_size": "480x832",
 | 
				
			||||||
 | 
					  "num_frames": 81,
 | 
				
			||||||
 | 
					  "guidance_scale": 5.0,
 | 
				
			||||||
 | 
					  "infer_steps": 50
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**响应示例**:
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 2. 查询任务状态
 | 
				
			||||||
 | 
					**POST /video/status**
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**响应示例**:
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "status": "Succeed",
 | 
				
			||||||
 | 
					  "results": {
 | 
				
			||||||
 | 
					    "videos": [{"url": "http://localhost:8088/videos/abcd1234.mp4"}],
 | 
				
			||||||
 | 
					    "timings": {"inference": 120}
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 3. 取消任务
 | 
				
			||||||
 | 
					**POST /video/cancel**
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**响应示例**:
 | 
				
			||||||
 | 
					```json
 | 
				
			||||||
 | 
					{
 | 
				
			||||||
 | 
					  "status": "Succeed"
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 三、Postman使用指南
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 1. 基础配置
 | 
				
			||||||
 | 
					- 服务器地址:`http://ip地址:8088`
 | 
				
			||||||
 | 
					- 认证方式:Bearer Token
 | 
				
			||||||
 | 
					- Token值:需替换为有效API Key
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 2. 提交任务
 | 
				
			||||||
 | 
					1. 选择POST方法,输入URL:`/video/submit`
 | 
				
			||||||
 | 
					2. Headers添加:
 | 
				
			||||||
 | 
					   ```text
 | 
				
			||||||
 | 
					   Authorization: Bearer YOUR_API_KEY
 | 
				
			||||||
 | 
					   Content-Type: application/json
 | 
				
			||||||
 | 
					   ```
 | 
				
			||||||
 | 
					3. Body选择raw/JSON格式,输入请求参数
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 3. 查询状态
 | 
				
			||||||
 | 
					1. 新建请求,URL填写`/video/status`
 | 
				
			||||||
 | 
					2. 使用相同认证头
 | 
				
			||||||
 | 
					3. Body中携带requestId
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 4. 取消任务
 | 
				
			||||||
 | 
					1. 新建DELETE请求,URL填写`/video/cancel`
 | 
				
			||||||
 | 
					2. Body携带需要取消的requestId
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### 注意事项
 | 
				
			||||||
 | 
					1. 所有接口必须携带有效API Key
 | 
				
			||||||
 | 
					2. 视频生成耗时约2-5分钟(根据参数配置)
 | 
				
			||||||
 | 
					3. 生成视频默认保留1小时
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 四、参数规范
 | 
				
			||||||
 | 
					| 参数名           | 允许值范围                     | 必填 | 说明                     |
 | 
				
			||||||
 | 
					|------------------|-------------------------------|------|--------------------------|
 | 
				
			||||||
 | 
					| prompt           | 10-500字符                    | 是   | 视频内容描述             |
 | 
				
			||||||
 | 
					| image_size       | "480x832" 或 "832x480"        | 是   | 分辨率                   |
 | 
				
			||||||
 | 
					| num_frames       | 24-120                        | 是   | 视频总帧数               |
 | 
				
			||||||
 | 
					| guidance_scale   | 1.0-20.0                      | 是   | 文本引导强度             |
 | 
				
			||||||
 | 
					| infer_steps      | 20-100                        | 是   | 推理步数                 |
 | 
				
			||||||
 | 
					| seed             | 0-2147483647                  | 否   | 随机种子                 |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					---
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## 五、状态码说明
 | 
				
			||||||
 | 
					| 状态码 | 含义                     |
 | 
				
			||||||
 | 
					|--------|--------------------------|
 | 
				
			||||||
 | 
					| 202    | 任务已接受               |
 | 
				
			||||||
 | 
					| 401    | 认证失败                 |
 | 
				
			||||||
 | 
					| 404    | 任务不存在               |
 | 
				
			||||||
 | 
					| 422    | 参数校验失败             |
 | 
				
			||||||
 | 
					| 500    | 服务端错误(显存不足等) |
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					**提示**:建议使用Swagger文档进行接口测试,访问`http://服务器地址:8088/docs`可查看自动生成的API文档界面
 | 
				
			||||||
							
								
								
									
										526
									
								
								i2v_api.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										526
									
								
								i2v_api.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,526 @@
 | 
				
			|||||||
 | 
					import os
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import uuid
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					from threading import Lock
 | 
				
			||||||
 | 
					from typing import Optional, Dict, List
 | 
				
			||||||
 | 
					from fastapi import FastAPI, HTTPException, status, Depends
 | 
				
			||||||
 | 
					from fastapi.staticfiles import StaticFiles
 | 
				
			||||||
 | 
					from pydantic import BaseModel, Field, field_validator, ValidationError
 | 
				
			||||||
 | 
					from diffusers.utils import export_to_video, load_image
 | 
				
			||||||
 | 
					from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
 | 
				
			||||||
 | 
					from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 | 
				
			||||||
 | 
					from transformers import CLIPVisionModel
 | 
				
			||||||
 | 
					from PIL import Image
 | 
				
			||||||
 | 
					import requests
 | 
				
			||||||
 | 
					from io import BytesIO
 | 
				
			||||||
 | 
					from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 | 
				
			||||||
 | 
					from contextlib import asynccontextmanager
 | 
				
			||||||
 | 
					from requests.exceptions import RequestException
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 创建存储目录
 | 
				
			||||||
 | 
					os.makedirs("generated_videos", exist_ok=True)
 | 
				
			||||||
 | 
					os.makedirs("temp_images", exist_ok=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					# 生命周期管理
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					@asynccontextmanager
 | 
				
			||||||
 | 
					async def lifespan(app: FastAPI):
 | 
				
			||||||
 | 
					    """资源管理器"""
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        # 初始化认证系统
 | 
				
			||||||
 | 
					        app.state.valid_api_keys = {
 | 
				
			||||||
 | 
					            "密钥"
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 初始化模型
 | 
				
			||||||
 | 
					        model_id = "./Wan2.1-I2V-14B-480P-Diffusers"
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 加载图像编码器
 | 
				
			||||||
 | 
					        image_encoder = CLIPVisionModel.from_pretrained(
 | 
				
			||||||
 | 
					            model_id,
 | 
				
			||||||
 | 
					            subfolder="image_encoder",
 | 
				
			||||||
 | 
					            torch_dtype=torch.float32
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 加载VAE
 | 
				
			||||||
 | 
					        vae = AutoencoderKLWan.from_pretrained(
 | 
				
			||||||
 | 
					            model_id,
 | 
				
			||||||
 | 
					            subfolder="vae",
 | 
				
			||||||
 | 
					            torch_dtype=torch.float32
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 配置调度器
 | 
				
			||||||
 | 
					        scheduler = UniPCMultistepScheduler(
 | 
				
			||||||
 | 
					            prediction_type='flow_prediction',
 | 
				
			||||||
 | 
					            use_flow_sigmas=True,
 | 
				
			||||||
 | 
					            num_train_timesteps=1000,
 | 
				
			||||||
 | 
					            flow_shift=3.0
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 创建管道
 | 
				
			||||||
 | 
					        app.state.pipe = WanImageToVideoPipeline.from_pretrained(
 | 
				
			||||||
 | 
					            model_id,
 | 
				
			||||||
 | 
					            vae=vae,
 | 
				
			||||||
 | 
					            image_encoder=image_encoder,
 | 
				
			||||||
 | 
					            torch_dtype=torch.bfloat16
 | 
				
			||||||
 | 
					        ).to("cuda")
 | 
				
			||||||
 | 
					        app.state.pipe.scheduler = scheduler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 初始化任务系统
 | 
				
			||||||
 | 
					        app.state.tasks: Dict[str, dict] = {}
 | 
				
			||||||
 | 
					        app.state.pending_queue: List[str] = []
 | 
				
			||||||
 | 
					        app.state.model_lock = Lock()
 | 
				
			||||||
 | 
					        app.state.task_lock = Lock()
 | 
				
			||||||
 | 
					        app.state.base_url = "ip地址+端口"
 | 
				
			||||||
 | 
					        app.state.semaphore = asyncio.Semaphore(2)  # 并发限制
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 启动后台处理器
 | 
				
			||||||
 | 
					        asyncio.create_task(task_processor())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        print("✅ 系统初始化完成")
 | 
				
			||||||
 | 
					        yield
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    finally:
 | 
				
			||||||
 | 
					        # 资源清理
 | 
				
			||||||
 | 
					        if hasattr(app.state, 'pipe'):
 | 
				
			||||||
 | 
					            del app.state.pipe
 | 
				
			||||||
 | 
					            torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					            print("♻️ 资源已释放")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					# FastAPI应用
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					app = FastAPI(lifespan=lifespan)
 | 
				
			||||||
 | 
					app.mount("/videos", StaticFiles(directory="generated_videos"), name="videos")
 | 
				
			||||||
 | 
					# 认证模块
 | 
				
			||||||
 | 
					security = HTTPBearer(auto_error=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					# 数据模型--查询参数模型
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					class VideoSubmitRequest(BaseModel):
 | 
				
			||||||
 | 
					    model: str = Field(
 | 
				
			||||||
 | 
					        default="Wan2.1-I2V-14B-480P",
 | 
				
			||||||
 | 
					        description="模型版本"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    prompt: str = Field(
 | 
				
			||||||
 | 
					        ...,
 | 
				
			||||||
 | 
					        min_length=10,
 | 
				
			||||||
 | 
					        max_length=500,
 | 
				
			||||||
 | 
					        description="视频描述提示词,10-500个字符"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    image_url: str = Field(
 | 
				
			||||||
 | 
					        ...,
 | 
				
			||||||
 | 
					        description="输入图像URL,需支持HTTP/HTTPS协议"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    image_size: str = Field(
 | 
				
			||||||
 | 
					        default="auto",
 | 
				
			||||||
 | 
					        description="输出分辨率,格式:宽x高 或 auto(自动计算)"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    negative_prompt: Optional[str] = Field(
 | 
				
			||||||
 | 
					        default=None,
 | 
				
			||||||
 | 
					        max_length=500,
 | 
				
			||||||
 | 
					        description="排除不需要的内容"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    seed: Optional[int] = Field(
 | 
				
			||||||
 | 
					        default=None,
 | 
				
			||||||
 | 
					        ge=0,
 | 
				
			||||||
 | 
					        le=2147483647,
 | 
				
			||||||
 | 
					        description="随机数种子,范围0-2147483647"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    num_frames: int = Field(
 | 
				
			||||||
 | 
					        default=81,
 | 
				
			||||||
 | 
					        ge=24,
 | 
				
			||||||
 | 
					        le=120,
 | 
				
			||||||
 | 
					        description="视频帧数,24-89帧"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    guidance_scale: float = Field(
 | 
				
			||||||
 | 
					        default=3.0,
 | 
				
			||||||
 | 
					        ge=1.0,
 | 
				
			||||||
 | 
					        le=20.0,
 | 
				
			||||||
 | 
					        description="引导系数,1.0-20.0"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    infer_steps: int = Field(
 | 
				
			||||||
 | 
					        default=30,
 | 
				
			||||||
 | 
					        ge=20,
 | 
				
			||||||
 | 
					        le=100,
 | 
				
			||||||
 | 
					        description="推理步数,20-100步"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @field_validator('image_size')
 | 
				
			||||||
 | 
					    def validate_image_size(cls, v):
 | 
				
			||||||
 | 
					        allowed_sizes = {"480x832", "832x480", "auto"}
 | 
				
			||||||
 | 
					        if v not in allowed_sizes:
 | 
				
			||||||
 | 
					            raise ValueError(f"支持的分辨率: {', '.join(allowed_sizes)}")
 | 
				
			||||||
 | 
					        return v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VideoStatusRequest(BaseModel):
 | 
				
			||||||
 | 
					    requestId: str = Field(
 | 
				
			||||||
 | 
					        ...,
 | 
				
			||||||
 | 
					        min_length=32,
 | 
				
			||||||
 | 
					        max_length=32,
 | 
				
			||||||
 | 
					        description="32位任务ID"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VideoStatusResponse(BaseModel):
 | 
				
			||||||
 | 
					    status: str = Field(..., description="任务状态: Succeed, InQueue, InProgress, Failed,Cancelled")
 | 
				
			||||||
 | 
					    reason: Optional[str] = Field(None, description="失败原因")
 | 
				
			||||||
 | 
					    results: Optional[dict] = Field(None, description="生成结果")
 | 
				
			||||||
 | 
					    queue_position: Optional[int] = Field(None, description="队列位置")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VideoCancelRequest(BaseModel):
 | 
				
			||||||
 | 
					    requestId: str = Field(
 | 
				
			||||||
 | 
					        ...,
 | 
				
			||||||
 | 
					        min_length=32,
 | 
				
			||||||
 | 
					        max_length=32,
 | 
				
			||||||
 | 
					        description="32位任务ID"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					# 核心逻辑
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
 | 
				
			||||||
 | 
					    """统一认证验证"""
 | 
				
			||||||
 | 
					    if not credentials:
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=401,
 | 
				
			||||||
 | 
					            detail={"status": "Failed", "reason": "缺少认证头"},
 | 
				
			||||||
 | 
					            headers={"WWW-Authenticate": "Bearer"}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    if credentials.scheme != "Bearer":
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=401,
 | 
				
			||||||
 | 
					            detail={"status": "Failed", "reason": "无效的认证方案"},
 | 
				
			||||||
 | 
					            headers={"WWW-Authenticate": "Bearer"}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    if credentials.credentials not in app.state.valid_api_keys:
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=401,
 | 
				
			||||||
 | 
					            detail={"status": "Failed", "reason": "无效的API密钥"},
 | 
				
			||||||
 | 
					            headers={"WWW-Authenticate": "Bearer"}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def task_processor():
 | 
				
			||||||
 | 
					    """任务处理器"""
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        async with app.state.semaphore:
 | 
				
			||||||
 | 
					            task_id = await get_next_task()
 | 
				
			||||||
 | 
					            if task_id:
 | 
				
			||||||
 | 
					                await process_task(task_id)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                await asyncio.sleep(0.5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def get_next_task():
 | 
				
			||||||
 | 
					    """获取下一个任务"""
 | 
				
			||||||
 | 
					    with app.state.task_lock:
 | 
				
			||||||
 | 
					        return app.state.pending_queue.pop(0) if app.state.pending_queue else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def process_task(task_id: str):
 | 
				
			||||||
 | 
					    """处理单个任务"""
 | 
				
			||||||
 | 
					    task = app.state.tasks.get(task_id)
 | 
				
			||||||
 | 
					    if not task:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        # 更新任务状态
 | 
				
			||||||
 | 
					        task['status'] = 'InProgress'
 | 
				
			||||||
 | 
					        task['started_at'] = int(time.time())
 | 
				
			||||||
 | 
					        print(task['request'].image_url)
 | 
				
			||||||
 | 
					        # 下载输入图像
 | 
				
			||||||
 | 
					        image = await download_image(task['request'].image_url)
 | 
				
			||||||
 | 
					        image_path = f"temp_images/{task_id}.jpg"
 | 
				
			||||||
 | 
					        image.save(image_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 生成视频
 | 
				
			||||||
 | 
					        video_path = await generate_video(task['request'], task_id, image)
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 生成下载链接
 | 
				
			||||||
 | 
					        download_url = f"{app.state.base_url}/videos/{os.path.basename(video_path)}"
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 更新任务状态
 | 
				
			||||||
 | 
					        task.update({
 | 
				
			||||||
 | 
					            'status': 'Succeed',
 | 
				
			||||||
 | 
					            'download_url': download_url,
 | 
				
			||||||
 | 
					            'completed_at': int(time.time())
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 安排清理
 | 
				
			||||||
 | 
					        asyncio.create_task(cleanup_files([image_path, video_path]))
 | 
				
			||||||
 | 
					    except Exception as e:
 | 
				
			||||||
 | 
					        handle_task_error(task, e)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def handle_task_error(task: dict, error: Exception):
 | 
				
			||||||
 | 
					    """错误处理(包含详细错误信息)"""
 | 
				
			||||||
 | 
					    error_msg = str(error)
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					    # 1. 显存不足错误
 | 
				
			||||||
 | 
					    if isinstance(error, torch.cuda.OutOfMemoryError):
 | 
				
			||||||
 | 
					        error_msg = "显存不足,请降低分辨率"
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					    # 2. 网络请求相关错误
 | 
				
			||||||
 | 
					    elif isinstance(error, (RequestException, HTTPException)):
 | 
				
			||||||
 | 
					        # 从异常中提取具体信息
 | 
				
			||||||
 | 
					        if isinstance(error, HTTPException):
 | 
				
			||||||
 | 
					            # 如果是 HTTPException,获取其 detail 字段
 | 
				
			||||||
 | 
					            error_detail = getattr(error, "detail", "")
 | 
				
			||||||
 | 
					            error_msg = f"图像下载失败: {error_detail}"
 | 
				
			||||||
 | 
					          
 | 
				
			||||||
 | 
					        elif isinstance(error, Timeout):
 | 
				
			||||||
 | 
					            error_msg = "图像下载超时,请检查网络"
 | 
				
			||||||
 | 
					          
 | 
				
			||||||
 | 
					        elif isinstance(error, ConnectionError):
 | 
				
			||||||
 | 
					            error_msg = "无法连接到服务器,请检查 URL"
 | 
				
			||||||
 | 
					          
 | 
				
			||||||
 | 
					        elif isinstance(error, HTTPError):
 | 
				
			||||||
 | 
					            # requests 的 HTTPError(例如 4xx/5xx 状态码)
 | 
				
			||||||
 | 
					            status_code = error.response.status_code
 | 
				
			||||||
 | 
					            error_msg = f"服务器返回错误状态码: {status_code}"
 | 
				
			||||||
 | 
					          
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            # 其他 RequestException 错误
 | 
				
			||||||
 | 
					            error_msg = f"图像下载失败: {str(error)}"
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					    # 3. 其他未知错误
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        error_msg = f"未知错误: {str(error)}"
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					    # 更新任务状态
 | 
				
			||||||
 | 
					    task.update({
 | 
				
			||||||
 | 
					        'status': 'Failed',
 | 
				
			||||||
 | 
					        'reason': error_msg,
 | 
				
			||||||
 | 
					        'completed_at': int(time.time())
 | 
				
			||||||
 | 
					    })
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					# 视频生成逻辑
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					async def download_image(url: str) -> Image.Image:
 | 
				
			||||||
 | 
					    """异步下载图像(包含详细错误信息)"""
 | 
				
			||||||
 | 
					    loop = asyncio.get_event_loop()
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        response = await loop.run_in_executor(
 | 
				
			||||||
 | 
					            None, 
 | 
				
			||||||
 | 
					            lambda: requests.get(url)  # 将 timeout 传递给 requests.get
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 如果状态码非 200,主动抛出 HTTPException
 | 
				
			||||||
 | 
					        if response.status_code != 200:
 | 
				
			||||||
 | 
					            raise HTTPException(
 | 
				
			||||||
 | 
					                status_code=response.status_code,
 | 
				
			||||||
 | 
					                detail=f"服务器返回状态码 {response.status_code}"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					          
 | 
				
			||||||
 | 
					        return Image.open(BytesIO(response.content)).convert("RGB")
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					    except RequestException as e:
 | 
				
			||||||
 | 
					        # 将原始 requests 错误信息抛出
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=500,
 | 
				
			||||||
 | 
					            detail=f"请求失败: {str(e)}"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					async def generate_video(request: VideoSubmitRequest, task_id: str, image: Image.Image):
 | 
				
			||||||
 | 
					    """异步生成入口"""
 | 
				
			||||||
 | 
					    loop = asyncio.get_event_loop()
 | 
				
			||||||
 | 
					    return await loop.run_in_executor(
 | 
				
			||||||
 | 
					        None,
 | 
				
			||||||
 | 
					        sync_generate_video,
 | 
				
			||||||
 | 
					        request,
 | 
				
			||||||
 | 
					        task_id,
 | 
				
			||||||
 | 
					        image
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def sync_generate_video(request: VideoSubmitRequest, task_id: str, image: Image.Image):
 | 
				
			||||||
 | 
					    """同步生成核心"""
 | 
				
			||||||
 | 
					    with app.state.model_lock:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            # 解析分辨率
 | 
				
			||||||
 | 
					            mod_value = 16  # 模型要求的模数
 | 
				
			||||||
 | 
					            print(request.image_size)
 | 
				
			||||||
 | 
					            print('--------------------------------')
 | 
				
			||||||
 | 
					            if request.image_size == "auto":
 | 
				
			||||||
 | 
					                # 原版自动计算逻辑
 | 
				
			||||||
 | 
					                aspect_ratio = image.height / image.width
 | 
				
			||||||
 | 
					                print(image.height,image.width)
 | 
				
			||||||
 | 
					                max_area = 399360  # 模型基础分辨率
 | 
				
			||||||
 | 
					              
 | 
				
			||||||
 | 
					                # 计算理想尺寸
 | 
				
			||||||
 | 
					                height = round(np.sqrt(max_area * aspect_ratio)) 
 | 
				
			||||||
 | 
					                width = round(np.sqrt(max_area / aspect_ratio))
 | 
				
			||||||
 | 
					              
 | 
				
			||||||
 | 
					                # 应用模数调整
 | 
				
			||||||
 | 
					                height = height // mod_value * mod_value
 | 
				
			||||||
 | 
					                width = width // mod_value * mod_value
 | 
				
			||||||
 | 
					                resized_image = image.resize((width, height))
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                width_str, height_str = request.image_size.split('x')
 | 
				
			||||||
 | 
					                width = int(width_str)
 | 
				
			||||||
 | 
					                height = int(height_str)
 | 
				
			||||||
 | 
					                mod_value = 16          
 | 
				
			||||||
 | 
					                # 调整图像尺寸
 | 
				
			||||||
 | 
					                resized_image = image.resize((width, height))
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					              
 | 
				
			||||||
 | 
					            # 设置随机种子
 | 
				
			||||||
 | 
					            generator = None
 | 
				
			||||||
 | 
					            # 修改点1: 使用属性访问seed
 | 
				
			||||||
 | 
					            if request.seed is not None:
 | 
				
			||||||
 | 
					                generator = torch.Generator(device="cuda")
 | 
				
			||||||
 | 
					                generator.manual_seed(request.seed)  # 修改点2
 | 
				
			||||||
 | 
					                print(f"🔮 使用随机种子: {request.seed}")
 | 
				
			||||||
 | 
					            print(resized_image)
 | 
				
			||||||
 | 
					            print(height,width)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # 执行推理
 | 
				
			||||||
 | 
					            output = app.state.pipe(
 | 
				
			||||||
 | 
					                image=resized_image,
 | 
				
			||||||
 | 
					                prompt=request.prompt,
 | 
				
			||||||
 | 
					                negative_prompt=request.negative_prompt,
 | 
				
			||||||
 | 
					                height=height,
 | 
				
			||||||
 | 
					                width=width,
 | 
				
			||||||
 | 
					                num_frames=request.num_frames,
 | 
				
			||||||
 | 
					                guidance_scale=request.guidance_scale,
 | 
				
			||||||
 | 
					                num_inference_steps=request.infer_steps,
 | 
				
			||||||
 | 
					                generator=generator
 | 
				
			||||||
 | 
					            ).frames[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # 导出视频
 | 
				
			||||||
 | 
					            video_id = uuid.uuid4().hex
 | 
				
			||||||
 | 
					            output_path = f"generated_videos/{video_id}.mp4"
 | 
				
			||||||
 | 
					            export_to_video(output, output_path, fps=16)
 | 
				
			||||||
 | 
					            return output_path
 | 
				
			||||||
 | 
					        except Exception as e:
 | 
				
			||||||
 | 
					            raise RuntimeError(f"视频生成失败: {str(e)}") from e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					# API端点
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					@app.post("/video/submit",
 | 
				
			||||||
 | 
					          response_model=dict,
 | 
				
			||||||
 | 
					          status_code=status.HTTP_202_ACCEPTED,
 | 
				
			||||||
 | 
					          tags=["视频生成"])
 | 
				
			||||||
 | 
					async def submit_task(
 | 
				
			||||||
 | 
					    request: VideoSubmitRequest,
 | 
				
			||||||
 | 
					    auth: bool = Depends(verify_auth)
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """提交生成任务"""
 | 
				
			||||||
 | 
					    # 参数验证
 | 
				
			||||||
 | 
					    if request.image_url is None:
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=422,
 | 
				
			||||||
 | 
					            detail={"status": "Failed", "reason": "需要图像URL参数"}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 创建任务记录
 | 
				
			||||||
 | 
					    task_id = uuid.uuid4().hex
 | 
				
			||||||
 | 
					    with app.state.task_lock:
 | 
				
			||||||
 | 
					        app.state.tasks[task_id] = {
 | 
				
			||||||
 | 
					            "request": request,
 | 
				
			||||||
 | 
					            "status": "InQueue",
 | 
				
			||||||
 | 
					            "created_at": int(time.time())
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        app.state.pending_queue.append(task_id)
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					    return {"requestId": task_id}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.post("/video/status",
 | 
				
			||||||
 | 
					          response_model=VideoStatusResponse,
 | 
				
			||||||
 | 
					          tags=["视频生成"])
 | 
				
			||||||
 | 
					async def get_status(
 | 
				
			||||||
 | 
					    request: VideoStatusRequest,
 | 
				
			||||||
 | 
					    auth: bool = Depends(verify_auth)
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """查询任务状态"""
 | 
				
			||||||
 | 
					    task = app.state.tasks.get(request.requestId)
 | 
				
			||||||
 | 
					    if not task:
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=404,
 | 
				
			||||||
 | 
					            detail={"status": "Failed", "reason": "无效的任务ID"}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 计算队列位置(仅当在队列中时)
 | 
				
			||||||
 | 
					    queue_pos = 0
 | 
				
			||||||
 | 
					    if task['status'] == "InQueue" and request.requestId in app.state.pending_queue:
 | 
				
			||||||
 | 
					        queue_pos = app.state.pending_queue.index(request.requestId) + 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    response = {
 | 
				
			||||||
 | 
					        "status": task['status'],
 | 
				
			||||||
 | 
					        "reason": task.get('reason'),
 | 
				
			||||||
 | 
					        "queue_position": queue_pos if task['status'] == "InQueue" else None  # 非排队状态返回null
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 成功状态的特殊处理
 | 
				
			||||||
 | 
					    if task['status'] == "Succeed":
 | 
				
			||||||
 | 
					        response["results"] = {
 | 
				
			||||||
 | 
					            "videos": [{"url": task['download_url']}],
 | 
				
			||||||
 | 
					            "timings": {
 | 
				
			||||||
 | 
					                "inference": task['completed_at'] - task['started_at']
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            "seed": task['request'].seed
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    # 取消状态的补充信息
 | 
				
			||||||
 | 
					    elif task['status'] == "Cancelled":
 | 
				
			||||||
 | 
					        response["reason"] = task.get('reason', "用户主动取消")  # 确保原因字段存在
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.post("/video/cancel",
 | 
				
			||||||
 | 
					         response_model=dict,
 | 
				
			||||||
 | 
					         tags=["视频生成"])
 | 
				
			||||||
 | 
					async def cancel_task(
 | 
				
			||||||
 | 
					    request: VideoCancelRequest,
 | 
				
			||||||
 | 
					    auth: bool = Depends(verify_auth)
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """取消排队中的生成任务"""
 | 
				
			||||||
 | 
					    task_id = request.requestId
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					    with app.state.task_lock:
 | 
				
			||||||
 | 
					        task = app.state.tasks.get(task_id)
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 检查任务是否存在
 | 
				
			||||||
 | 
					        if not task:
 | 
				
			||||||
 | 
					            raise HTTPException(
 | 
				
			||||||
 | 
					                status_code=404,
 | 
				
			||||||
 | 
					                detail={"status": "Failed", "reason": "无效的任务ID"}
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					          
 | 
				
			||||||
 | 
					        current_status = task['status']
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 仅允许取消排队中的任务
 | 
				
			||||||
 | 
					        if current_status != "InQueue":
 | 
				
			||||||
 | 
					            raise HTTPException(
 | 
				
			||||||
 | 
					                status_code=400,
 | 
				
			||||||
 | 
					                detail={"status": "Failed", "reason": f"仅允许取消排队任务,当前状态: {current_status}"}
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 从队列移除
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            app.state.pending_queue.remove(task_id)
 | 
				
			||||||
 | 
					        except ValueError:
 | 
				
			||||||
 | 
					            pass  # 可能已被处理
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 更新任务状态
 | 
				
			||||||
 | 
					        task.update({
 | 
				
			||||||
 | 
					            "status": "Cancelled",
 | 
				
			||||||
 | 
					            "reason": "用户主动取消",
 | 
				
			||||||
 | 
					            "completed_at": int(time.time())
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					    return {"status": "Succeed"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def cleanup_files(paths: List[str], delay: int = 3600):
 | 
				
			||||||
 | 
					    """定时清理文件"""
 | 
				
			||||||
 | 
					    await asyncio.sleep(delay)
 | 
				
			||||||
 | 
					    for path in paths:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            if os.path.exists(path):
 | 
				
			||||||
 | 
					                os.remove(path)
 | 
				
			||||||
 | 
					        except Exception as e:
 | 
				
			||||||
 | 
					            print(f"清理失败 {path}: {str(e)}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    import uvicorn
 | 
				
			||||||
 | 
					    uvicorn.run(app, host="0.0.0.0", port=8088)
 | 
				
			||||||
							
								
								
									
										450
									
								
								t2v-api.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										450
									
								
								t2v-api.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,450 @@
 | 
				
			|||||||
 | 
					import os
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import uuid
 | 
				
			||||||
 | 
					import time
 | 
				
			||||||
 | 
					import asyncio
 | 
				
			||||||
 | 
					from enum import Enum
 | 
				
			||||||
 | 
					from threading import Lock
 | 
				
			||||||
 | 
					from typing import Optional, Dict, List
 | 
				
			||||||
 | 
					from contextlib import asynccontextmanager
 | 
				
			||||||
 | 
					from fastapi import FastAPI, HTTPException, status, Depends
 | 
				
			||||||
 | 
					from fastapi.staticfiles import StaticFiles
 | 
				
			||||||
 | 
					from pydantic import BaseModel, Field, field_validator, ValidationError
 | 
				
			||||||
 | 
					from diffusers.utils import export_to_video
 | 
				
			||||||
 | 
					from diffusers import AutoencoderKLWan, WanPipeline
 | 
				
			||||||
 | 
					from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 | 
				
			||||||
 | 
					from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 | 
				
			||||||
 | 
					from fastapi.responses import JSONResponse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 创建视频存储目录
 | 
				
			||||||
 | 
					os.makedirs("generated_videos", exist_ok=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 生命周期管理器
 | 
				
			||||||
 | 
					@asynccontextmanager
 | 
				
			||||||
 | 
					async def lifespan(app: FastAPI):
 | 
				
			||||||
 | 
					    """管理应用生命周期"""
 | 
				
			||||||
 | 
					    # 初始化模型和资源
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        # 初始化认证密钥
 | 
				
			||||||
 | 
					        app.state.valid_api_keys = {
 | 
				
			||||||
 | 
					            "密钥"
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 初始化视频生成模型
 | 
				
			||||||
 | 
					        model_id = "./Wan2.1-T2V-1.3B-Diffusers"
 | 
				
			||||||
 | 
					        vae = AutoencoderKLWan.from_pretrained(
 | 
				
			||||||
 | 
					            model_id,
 | 
				
			||||||
 | 
					            subfolder="vae",
 | 
				
			||||||
 | 
					            torch_dtype=torch.float32
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        scheduler = UniPCMultistepScheduler(
 | 
				
			||||||
 | 
					            prediction_type='flow_prediction',
 | 
				
			||||||
 | 
					            use_flow_sigmas=True,
 | 
				
			||||||
 | 
					            num_train_timesteps=1000,
 | 
				
			||||||
 | 
					            flow_shift=3.0
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        app.state.pipe = WanPipeline.from_pretrained(
 | 
				
			||||||
 | 
					            model_id,
 | 
				
			||||||
 | 
					            vae=vae,
 | 
				
			||||||
 | 
					            torch_dtype=torch.bfloat16
 | 
				
			||||||
 | 
					        ).to("cuda")
 | 
				
			||||||
 | 
					        app.state.pipe.scheduler = scheduler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 初始化任务系统
 | 
				
			||||||
 | 
					        app.state.tasks: Dict[str, dict] = {}
 | 
				
			||||||
 | 
					        app.state.pending_queue: List[str] = []
 | 
				
			||||||
 | 
					        app.state.model_lock = Lock()
 | 
				
			||||||
 | 
					        app.state.task_lock = Lock()
 | 
				
			||||||
 | 
					        app.state.base_url = "ip地址+端口"
 | 
				
			||||||
 | 
					        app.state.max_concurrent = 2
 | 
				
			||||||
 | 
					        app.state.semaphore = asyncio.Semaphore(app.state.max_concurrent)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 启动后台任务处理器
 | 
				
			||||||
 | 
					        asyncio.create_task(task_processor())
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        print("✅ 应用初始化完成")
 | 
				
			||||||
 | 
					        yield
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					    finally:
 | 
				
			||||||
 | 
					        # 清理资源
 | 
				
			||||||
 | 
					        if hasattr(app.state, 'pipe'):
 | 
				
			||||||
 | 
					            del app.state.pipe
 | 
				
			||||||
 | 
					            torch.cuda.empty_cache()
 | 
				
			||||||
 | 
					            print("♻️ 已释放模型资源")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 创建FastAPI应用
 | 
				
			||||||
 | 
					app = FastAPI(lifespan=lifespan)
 | 
				
			||||||
 | 
					app.mount("/videos", StaticFiles(directory="generated_videos"), name="videos")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# 认证模块
 | 
				
			||||||
 | 
					security = HTTPBearer(auto_error=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					# 数据模型--查询参数模型
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					class VideoSubmitRequest(BaseModel):
 | 
				
			||||||
 | 
					    model: str = Field(default="Wan2.1-T2V-1.3B",description="使用的模型版本")
 | 
				
			||||||
 | 
					    prompt: str = Field(
 | 
				
			||||||
 | 
					        ...,
 | 
				
			||||||
 | 
					        min_length=10,
 | 
				
			||||||
 | 
					        max_length=500,
 | 
				
			||||||
 | 
					        description="视频描述提示词,10-500个字符"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    image_size: str = Field(
 | 
				
			||||||
 | 
					        ...,
 | 
				
			||||||
 | 
					        description="视频分辨率,仅支持480x832或832x480"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    negative_prompt: Optional[str] = Field(
 | 
				
			||||||
 | 
					        default=None,
 | 
				
			||||||
 | 
					        max_length=500,
 | 
				
			||||||
 | 
					        description="排除不需要的内容"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    seed: Optional[int] = Field(
 | 
				
			||||||
 | 
					        default=None,
 | 
				
			||||||
 | 
					        ge=0,
 | 
				
			||||||
 | 
					        le=2147483647,
 | 
				
			||||||
 | 
					        description="随机数种子,范围0-2147483647"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    num_frames: int = Field(
 | 
				
			||||||
 | 
					        default=81,
 | 
				
			||||||
 | 
					        ge=24,
 | 
				
			||||||
 | 
					        le=120,
 | 
				
			||||||
 | 
					        description="视频帧数,24-120帧"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    guidance_scale: float = Field(
 | 
				
			||||||
 | 
					        default=5.0,
 | 
				
			||||||
 | 
					        ge=1.0,
 | 
				
			||||||
 | 
					        le=20.0,
 | 
				
			||||||
 | 
					        description="引导系数,1.0-20.0"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    infer_steps: int = Field(
 | 
				
			||||||
 | 
					        default=50,
 | 
				
			||||||
 | 
					        ge=20,
 | 
				
			||||||
 | 
					        le=100,
 | 
				
			||||||
 | 
					        description="推理步数,20-100步"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @field_validator('image_size', mode='before')
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def validate_image_size(cls, value):
 | 
				
			||||||
 | 
					        allowed_sizes = {"480x832", "832x480"}
 | 
				
			||||||
 | 
					        if value not in allowed_sizes:
 | 
				
			||||||
 | 
					            raise ValueError(f"仅支持以下分辨率: {', '.join(allowed_sizes)}")
 | 
				
			||||||
 | 
					        return value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VideoStatusRequest(BaseModel):
 | 
				
			||||||
 | 
					    requestId: str = Field(
 | 
				
			||||||
 | 
					        ...,
 | 
				
			||||||
 | 
					        min_length=32,
 | 
				
			||||||
 | 
					        max_length=32,
 | 
				
			||||||
 | 
					        description="32位任务ID"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VideoSubmitResponse(BaseModel):
 | 
				
			||||||
 | 
					    requestId: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VideoStatusResponse(BaseModel):
 | 
				
			||||||
 | 
					    status: str = Field(..., description="任务状态: Succeed, InQueue, InProgress, Failed,Cancelled")
 | 
				
			||||||
 | 
					    reason: Optional[str] = Field(None, description="失败原因")
 | 
				
			||||||
 | 
					    results: Optional[dict] = Field(None, description="生成结果")
 | 
				
			||||||
 | 
					    queue_position: Optional[int] = Field(None, description="队列位置")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VideoCancelRequest(BaseModel):
 | 
				
			||||||
 | 
					    requestId: str = Field(
 | 
				
			||||||
 | 
					        ...,
 | 
				
			||||||
 | 
					        min_length=32,
 | 
				
			||||||
 | 
					        max_length=32,
 | 
				
			||||||
 | 
					        description="32位任务ID"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# # 自定义HTTP异常处理器
 | 
				
			||||||
 | 
					# @app.exception_handler(HTTPException)
 | 
				
			||||||
 | 
					# async def http_exception_handler(request, exc):
 | 
				
			||||||
 | 
					#     return JSONResponse(
 | 
				
			||||||
 | 
					#         status_code=exc.status_code,
 | 
				
			||||||
 | 
					#         content=exc.detail,  # 直接返回detail内容(不再包装在detail字段)
 | 
				
			||||||
 | 
					#         headers=exc.headers
 | 
				
			||||||
 | 
					#     )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					# 后台任务处理
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					async def task_processor():
 | 
				
			||||||
 | 
					    """处理任务队列"""
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        async with app.state.semaphore:
 | 
				
			||||||
 | 
					            task_id = await get_next_task()
 | 
				
			||||||
 | 
					            if task_id:
 | 
				
			||||||
 | 
					                await process_task(task_id)
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                await asyncio.sleep(0.5)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def get_next_task():
 | 
				
			||||||
 | 
					    """获取下一个待处理任务"""
 | 
				
			||||||
 | 
					    with app.state.task_lock:
 | 
				
			||||||
 | 
					        if app.state.pending_queue:
 | 
				
			||||||
 | 
					            return app.state.pending_queue.pop(0)
 | 
				
			||||||
 | 
					    return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def process_task(task_id: str):
 | 
				
			||||||
 | 
					    """处理单个任务"""
 | 
				
			||||||
 | 
					    task = app.state.tasks.get(task_id)
 | 
				
			||||||
 | 
					    if not task:
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        # 更新任务状态
 | 
				
			||||||
 | 
					        task['status'] = 'InProgress'
 | 
				
			||||||
 | 
					        task['started_at'] = int(time.time())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # 执行视频生成
 | 
				
			||||||
 | 
					        video_path = await generate_video(task['request'], task_id)
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 生成下载链接
 | 
				
			||||||
 | 
					        download_url = f"{app.state.base_url}/videos/{os.path.basename(video_path)}"
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 更新任务状态
 | 
				
			||||||
 | 
					        task.update({
 | 
				
			||||||
 | 
					            'status': 'Succeed',
 | 
				
			||||||
 | 
					            'download_url': download_url,
 | 
				
			||||||
 | 
					            'completed_at': int(time.time())
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 安排自动清理
 | 
				
			||||||
 | 
					        asyncio.create_task(auto_cleanup(video_path))
 | 
				
			||||||
 | 
					    except Exception as e:
 | 
				
			||||||
 | 
					        handle_task_error(task, e)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def handle_task_error(task: dict, error: Exception):
 | 
				
			||||||
 | 
					    """统一处理任务错误"""
 | 
				
			||||||
 | 
					    error_msg = str(error)
 | 
				
			||||||
 | 
					    if isinstance(error, torch.cuda.OutOfMemoryError):
 | 
				
			||||||
 | 
					        error_msg = "显存不足,请降低分辨率或减少帧数"
 | 
				
			||||||
 | 
					    elif isinstance(error, ValidationError):
 | 
				
			||||||
 | 
					        error_msg = "参数校验失败: " + str(error)
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					    task.update({
 | 
				
			||||||
 | 
					        'status': 'Failed',
 | 
				
			||||||
 | 
					        'reason': error_msg,
 | 
				
			||||||
 | 
					        'completed_at': int(time.time())
 | 
				
			||||||
 | 
					    })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					# 视频生成核心逻辑
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					async def generate_video(request: dict, task_id: str) -> str:
 | 
				
			||||||
 | 
					    """异步执行视频生成"""
 | 
				
			||||||
 | 
					    loop = asyncio.get_event_loop()
 | 
				
			||||||
 | 
					    return await loop.run_in_executor(
 | 
				
			||||||
 | 
					        None,
 | 
				
			||||||
 | 
					        sync_generate_video,
 | 
				
			||||||
 | 
					        request,
 | 
				
			||||||
 | 
					        task_id
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def sync_generate_video(request: dict, task_id: str) -> str:
 | 
				
			||||||
 | 
					    """同步生成视频"""
 | 
				
			||||||
 | 
					    with app.state.model_lock:
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            generator = None
 | 
				
			||||||
 | 
					            if request.get('seed') is not None:
 | 
				
			||||||
 | 
					                generator = torch.Generator(device="cuda")
 | 
				
			||||||
 | 
					                generator.manual_seed(request['seed'])
 | 
				
			||||||
 | 
					                print(f"🔮 使用随机种子: {request['seed']}")
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            # 执行模型推理
 | 
				
			||||||
 | 
					            result = app.state.pipe(
 | 
				
			||||||
 | 
					                prompt=request['prompt'],
 | 
				
			||||||
 | 
					                negative_prompt=request['negative_prompt'],
 | 
				
			||||||
 | 
					                height=request['height'],
 | 
				
			||||||
 | 
					                width=request['width'],
 | 
				
			||||||
 | 
					                num_frames=request['num_frames'],
 | 
				
			||||||
 | 
					                guidance_scale=request['guidance_scale'],
 | 
				
			||||||
 | 
					                num_inference_steps=request['infer_steps'],
 | 
				
			||||||
 | 
					                generator=generator
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					          
 | 
				
			||||||
 | 
					            # 导出视频文件
 | 
				
			||||||
 | 
					            video_id = uuid.uuid4().hex
 | 
				
			||||||
 | 
					            output_path = f"generated_videos/{video_id}.mp4"
 | 
				
			||||||
 | 
					            export_to_video(result.frames[0], output_path, fps=16)
 | 
				
			||||||
 | 
					          
 | 
				
			||||||
 | 
					            return output_path
 | 
				
			||||||
 | 
					        except Exception as e:
 | 
				
			||||||
 | 
					            raise RuntimeError(f"视频生成失败: {str(e)}") from e
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					# API端点
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
 | 
				
			||||||
 | 
					    """认证验证"""
 | 
				
			||||||
 | 
					    if not credentials:
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=401,
 | 
				
			||||||
 | 
					            detail={"status": "Failed", "reason": "缺少认证头"},
 | 
				
			||||||
 | 
					            headers={"WWW-Authenticate": "Bearer"}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    if credentials.scheme != "Bearer":
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=401,
 | 
				
			||||||
 | 
					            detail={"status": "Failed", "reason": "无效的认证方案"},
 | 
				
			||||||
 | 
					            headers={"WWW-Authenticate": "Bearer"}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    if credentials.credentials not in app.state.valid_api_keys:
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=401,
 | 
				
			||||||
 | 
					            detail={"status": "Failed", "reason": "无效的API密钥"},
 | 
				
			||||||
 | 
					            headers={"WWW-Authenticate": "Bearer"}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.post("/video/submit",
 | 
				
			||||||
 | 
					          response_model=VideoSubmitResponse,
 | 
				
			||||||
 | 
					          status_code=status.HTTP_202_ACCEPTED,
 | 
				
			||||||
 | 
					          tags=["视频生成"],
 | 
				
			||||||
 | 
					          summary="提交视频生成请求")
 | 
				
			||||||
 | 
					async def submit_video_task(
 | 
				
			||||||
 | 
					    request: VideoSubmitRequest,
 | 
				
			||||||
 | 
					    auth: bool = Depends(verify_auth)
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """提交新的视频生成任务"""
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        # 解析分辨率参数
 | 
				
			||||||
 | 
					        width, height = map(int, request.image_size.split('x'))
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 创建任务记录
 | 
				
			||||||
 | 
					        task_id = uuid.uuid4().hex
 | 
				
			||||||
 | 
					        task_data = {
 | 
				
			||||||
 | 
					            'request': {
 | 
				
			||||||
 | 
					                'prompt': request.prompt,
 | 
				
			||||||
 | 
					                'negative_prompt': request.negative_prompt,
 | 
				
			||||||
 | 
					                'width': width,
 | 
				
			||||||
 | 
					                'height': height,
 | 
				
			||||||
 | 
					                'num_frames': request.num_frames,
 | 
				
			||||||
 | 
					                'guidance_scale': request.guidance_scale,
 | 
				
			||||||
 | 
					                'infer_steps': request.infer_steps,
 | 
				
			||||||
 | 
					                'seed': request.seed
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            'status': 'InQueue',
 | 
				
			||||||
 | 
					            'created_at': int(time.time())
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 加入任务队列
 | 
				
			||||||
 | 
					        with app.state.task_lock:
 | 
				
			||||||
 | 
					            app.state.tasks[task_id] = task_data
 | 
				
			||||||
 | 
					            app.state.pending_queue.append(task_id)
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        return {"requestId": task_id}
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					    except ValidationError as e:
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=422,
 | 
				
			||||||
 | 
					            detail={"status": "Failed", "reason": str(e)}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.post("/video/status",
 | 
				
			||||||
 | 
					          response_model=VideoStatusResponse,
 | 
				
			||||||
 | 
					          tags=["视频生成"],
 | 
				
			||||||
 | 
					          summary="查询任务状态")
 | 
				
			||||||
 | 
					async def get_video_status(
 | 
				
			||||||
 | 
					    request: VideoStatusRequest,
 | 
				
			||||||
 | 
					    auth: bool = Depends(verify_auth)
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """查询任务状态"""
 | 
				
			||||||
 | 
					    task = app.state.tasks.get(request.requestId)
 | 
				
			||||||
 | 
					    if not task:
 | 
				
			||||||
 | 
					        raise HTTPException(
 | 
				
			||||||
 | 
					            status_code=404,
 | 
				
			||||||
 | 
					            detail={"status": "Failed", "reason": "无效的任务ID"}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 计算队列位置(仅当在队列中时)
 | 
				
			||||||
 | 
					    queue_pos = 0
 | 
				
			||||||
 | 
					    if task['status'] == "InQueue" and request.requestId in app.state.pending_queue:
 | 
				
			||||||
 | 
					        queue_pos = app.state.pending_queue.index(request.requestId) + 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    response = {
 | 
				
			||||||
 | 
					        "status": task['status'],
 | 
				
			||||||
 | 
					        "reason": task.get('reason'),
 | 
				
			||||||
 | 
					        "queue_position": queue_pos if task['status'] == "InQueue" else None  # 非排队状态返回null
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # 成功状态的特殊处理
 | 
				
			||||||
 | 
					    if task['status'] == "Succeed":
 | 
				
			||||||
 | 
					        response["results"] = {
 | 
				
			||||||
 | 
					            "videos": [{"url": task['download_url']}],
 | 
				
			||||||
 | 
					            "timings": {
 | 
				
			||||||
 | 
					                "inference": task['completed_at'] - task['started_at']
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					            "seed": task['request']['seed']
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    # 取消状态的补充信息
 | 
				
			||||||
 | 
					    elif task['status'] == "Cancelled":
 | 
				
			||||||
 | 
					        response["reason"] = task.get('reason', "用户主动取消")  # 确保原因字段存在
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return response
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@app.post("/video/cancel",
 | 
				
			||||||
 | 
					         response_model=dict,
 | 
				
			||||||
 | 
					         tags=["视频生成"])
 | 
				
			||||||
 | 
					async def cancel_task(
 | 
				
			||||||
 | 
					    request: VideoCancelRequest,
 | 
				
			||||||
 | 
					    auth: bool = Depends(verify_auth)
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """取消排队中的生成任务"""
 | 
				
			||||||
 | 
					    task_id = request.requestId
 | 
				
			||||||
 | 
					  
 | 
				
			||||||
 | 
					    with app.state.task_lock:
 | 
				
			||||||
 | 
					        task = app.state.tasks.get(task_id)
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 检查任务是否存在
 | 
				
			||||||
 | 
					        if not task:
 | 
				
			||||||
 | 
					            raise HTTPException(
 | 
				
			||||||
 | 
					                status_code=404,
 | 
				
			||||||
 | 
					                detail={"status": "Failed", "reason": "无效的任务ID"}
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					          
 | 
				
			||||||
 | 
					        current_status = task['status']
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 仅允许取消排队中的任务
 | 
				
			||||||
 | 
					        if current_status != "InQueue":
 | 
				
			||||||
 | 
					            raise HTTPException(
 | 
				
			||||||
 | 
					                status_code=400,
 | 
				
			||||||
 | 
					                detail={"status": "Failed", "reason": f"仅允许取消排队任务,当前状态: {current_status}"}
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 从队列移除
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            app.state.pending_queue.remove(task_id)
 | 
				
			||||||
 | 
					        except ValueError:
 | 
				
			||||||
 | 
					            pass  # 可能已被处理
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					        # 更新任务状态
 | 
				
			||||||
 | 
					        task.update({
 | 
				
			||||||
 | 
					            "status": "Cancelled",
 | 
				
			||||||
 | 
					            "reason": "用户主动取消",
 | 
				
			||||||
 | 
					            "completed_at": int(time.time())
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					      
 | 
				
			||||||
 | 
					    return {"status": "Succeed"}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					# 工具函数
 | 
				
			||||||
 | 
					# ======================
 | 
				
			||||||
 | 
					async def auto_cleanup(file_path: str, delay: int = 3600):
 | 
				
			||||||
 | 
					    """自动清理生成的视频文件"""
 | 
				
			||||||
 | 
					    await asyncio.sleep(delay)
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        if os.path.exists(file_path):
 | 
				
			||||||
 | 
					            os.remove(file_path)
 | 
				
			||||||
 | 
					            print(f"已清理文件: {file_path}")
 | 
				
			||||||
 | 
					    except Exception as e:
 | 
				
			||||||
 | 
					        print(f"文件清理失败: {str(e)}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    import uvicorn
 | 
				
			||||||
 | 
					    uvicorn.run(app, host="0.0.0.0", port=8088)
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user