Added the service deployment code of FastAPI, including key authentication, task submission, task details viewing, and task canceling

This commit is contained in:
knoka812 2025-04-09 15:37:10 +08:00
parent 679ccc6c68
commit 0961b7b888
4 changed files with 1260 additions and 0 deletions

151
I2V-FastAPI文档.md Normal file
View File

@ -0,0 +1,151 @@
# 图像到视频生成服务API文档
## 一、功能概述
基于Wan2.1-I2V-14B-480P模型实现图像到视频生成核心功能包括
1. **异步任务队列**支持多任务排队和并发控制最大2个并行任务
2. **智能分辨率适配**
- 支持自动计算最佳分辨率(保持原图比例)
- 支持手动指定分辨率480x832/832x480
3. **资源管理**
- 显存优化bfloat16精度
- 生成文件自动清理默认1小时
4. **安全认证**基于API Key的Bearer Token验证
5. **任务控制**:支持任务提交/状态查询/取消操作
技术栈:
- FastAPI框架
- CUDA加速
- 异步任务处理
- Diffusers推理库
---
## 二、接口说明
### 1. 提交生成任务
**POST /video/submit**
```json
{
"model": "Wan2.1-I2V-14B-480P",
"prompt": "A dancing cat in the style of Van Gogh",
"image_url": "https://example.com/input.jpg",
"image_size": "auto",
"num_frames": 81,
"guidance_scale": 3.0,
"infer_steps": 30
}
```
**响应示例**
```json
{
"requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```
### 2. 查询任务状态
**POST /video/status**
```json
{
"requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```
**响应示例**
```json
{
"status": "Succeed",
"results": {
"videos": [{"url": "http://localhost:8088/videos/abcd1234.mp4"}],
"timings": {"inference": 90},
"seed": 123456
}
}
```
### 3. 取消任务
**POST /video/cancel**
```json
{
"requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```
**响应示例**
```json
{
"status": "Succeed"
}
```
---
## 三、Postman使用指南
### 1. 基础配置
- 服务器地址:`http://ip地址:8088`
- 认证方式Bearer Token
- Token值需替换为有效API Key
### 2. 提交任务
1. 选择POST方法URL填写`/video/submit`
2. Headers添加
```text
Authorization: Bearer YOUR_API_KEY
Content-Type: application/json
```
3. Body示例图像生成视频
```json
{
"prompt": "Sunset scene with mountains",
"image_url": "https://example.com/mountain.jpg",
"image_size": "auto",
"num_frames": 50
}
```
### 3. 特殊处理
- **图像下载失败**返回400错误包含具体原因如URL无效/超时)
- **显存不足**返回500错误并提示降低分辨率
---
## 四、参数规范
| 参数名 | 允许值范围 | 必填 | 说明 |
|------------------|-------------------------------|------|------------------------------------------|
| image_url | 有效HTTP/HTTPS URL | 是 | 输入图像地址 |
| prompt | 10-500字符 | 是 | 视频内容描述 |
| image_size | "480x832", "832x480", "auto" | 是 | auto模式自动适配原图比例 |
| num_frames | 24-120 | 是 | 视频总帧数 |
| guidance_scale | 1.0-20.0 | 是 | 文本引导强度 |
| infer_steps | 20-100 | 是 | 推理步数 |
| seed | 0-2147483647 | 否 | 随机种子 |
---
## 五、状态码说明
| 状态码 | 含义 |
|--------|-----------------------------------|
| 202 | 任务已接受 |
| 400 | 图像下载失败/参数错误 |
| 401 | 认证失败 |
| 404 | 任务不存在 |
| 422 | 参数校验失败 |
| 500 | 服务端错误(显存不足/模型异常等) |
---
## 六、特殊功能说明
1. **智能分辨率适配**
- 当`image_size="auto"`时,自动计算符合模型要求的最优分辨率
- 保持原始图像宽高比最大像素面积不超过399,360约640x624
2. **图像预处理**
- 自动转换为RGB模式
- 根据目标分辨率进行等比缩放
**重要提示**输入图像URL需保证公开可访问私有资源需提供有效鉴权
**提示** :访问`http://服务器地址:8088/docs`可查看交互式API文档支持在线测试所有接口

133
T2V-FastAPI文档.md Normal file
View File

@ -0,0 +1,133 @@
# 视频生成服务API文档
## 一、功能概述
本服务基于Wan2.1-T2V-1.3B模型实现文本到视频生成,包含以下核心功能:
1. **异步任务队列**支持多任务排队和并发控制最大2个并行任务
2. **资源管理**
- 显存优化使用bfloat16精度
- 生成视频自动清理默认1小时后删除
3. **安全认证**基于API Key的Bearer Token验证
4. **任务控制**:支持任务提交/状态查询/取消操作
技术栈:
- FastAPI框架
- CUDA加速
- 异步任务处理
- Diffusers推理库
---
## 二、接口说明
### 1. 提交生成任务
**POST /video/submit**
```json
{
"model": "Wan2.1-T2V-1.3B",
"prompt": "A beautiful sunset over the mountains",
"image_size": "480x832",
"num_frames": 81,
"guidance_scale": 5.0,
"infer_steps": 50
}
```
**响应示例**
```json
{
"requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```
### 2. 查询任务状态
**POST /video/status**
```json
{
"requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```
**响应示例**
```json
{
"status": "Succeed",
"results": {
"videos": [{"url": "http://localhost:8088/videos/abcd1234.mp4"}],
"timings": {"inference": 120}
}
}
```
### 3. 取消任务
**POST /video/cancel**
```json
{
"requestId": "a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5p6"
}
```
**响应示例**
```json
{
"status": "Succeed"
}
```
---
## 三、Postman使用指南
### 1. 基础配置
- 服务器地址:`http://ip地址:8088`
- 认证方式Bearer Token
- Token值需替换为有效API Key
### 2. 提交任务
1. 选择POST方法输入URL`/video/submit`
2. Headers添加
```text
Authorization: Bearer YOUR_API_KEY
Content-Type: application/json
```
3. Body选择raw/JSON格式输入请求参数
### 3. 查询状态
1. 新建请求URL填写`/video/status`
2. 使用相同认证头
3. Body中携带requestId
### 4. 取消任务
1. 新建DELETE请求URL填写`/video/cancel`
2. Body携带需要取消的requestId
### 注意事项
1. 所有接口必须携带有效API Key
2. 视频生成耗时约2-5分钟根据参数配置
3. 生成视频默认保留1小时
---
## 四、参数规范
| 参数名 | 允许值范围 | 必填 | 说明 |
|------------------|-------------------------------|------|--------------------------|
| prompt | 10-500字符 | 是 | 视频内容描述 |
| image_size | "480x832" 或 "832x480" | 是 | 分辨率 |
| num_frames | 24-120 | 是 | 视频总帧数 |
| guidance_scale | 1.0-20.0 | 是 | 文本引导强度 |
| infer_steps | 20-100 | 是 | 推理步数 |
| seed | 0-2147483647 | 否 | 随机种子 |
---
## 五、状态码说明
| 状态码 | 含义 |
|--------|--------------------------|
| 202 | 任务已接受 |
| 401 | 认证失败 |
| 404 | 任务不存在 |
| 422 | 参数校验失败 |
| 500 | 服务端错误(显存不足等) |
**提示**建议使用Swagger文档进行接口测试访问`http://服务器地址:8088/docs`可查看自动生成的API文档界面

526
i2v_api.py Normal file
View File

@ -0,0 +1,526 @@
import os
import torch
import uuid
import time
import asyncio
import numpy as np
from threading import Lock
from typing import Optional, Dict, List
from fastapi import FastAPI, HTTPException, status, Depends
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, field_validator, ValidationError
from diffusers.utils import export_to_video, load_image
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from transformers import CLIPVisionModel
from PIL import Image
import requests
from io import BytesIO
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from contextlib import asynccontextmanager
from requests.exceptions import RequestException
# 创建存储目录
os.makedirs("generated_videos", exist_ok=True)
os.makedirs("temp_images", exist_ok=True)
# ======================
# 生命周期管理
# ======================
@asynccontextmanager
async def lifespan(app: FastAPI):
"""资源管理器"""
try:
# 初始化认证系统
app.state.valid_api_keys = {
"密钥"
}
# 初始化模型
model_id = "./Wan2.1-I2V-14B-480P-Diffusers"
# 加载图像编码器
image_encoder = CLIPVisionModel.from_pretrained(
model_id,
subfolder="image_encoder",
torch_dtype=torch.float32
)
# 加载VAE
vae = AutoencoderKLWan.from_pretrained(
model_id,
subfolder="vae",
torch_dtype=torch.float32
)
# 配置调度器
scheduler = UniPCMultistepScheduler(
prediction_type='flow_prediction',
use_flow_sigmas=True,
num_train_timesteps=1000,
flow_shift=3.0
)
# 创建管道
app.state.pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
).to("cuda")
app.state.pipe.scheduler = scheduler
# 初始化任务系统
app.state.tasks: Dict[str, dict] = {}
app.state.pending_queue: List[str] = []
app.state.model_lock = Lock()
app.state.task_lock = Lock()
app.state.base_url = "ip地址+端口"
app.state.semaphore = asyncio.Semaphore(2) # 并发限制
# 启动后台处理器
asyncio.create_task(task_processor())
print("✅ 系统初始化完成")
yield
finally:
# 资源清理
if hasattr(app.state, 'pipe'):
del app.state.pipe
torch.cuda.empty_cache()
print("♻️ 资源已释放")
# ======================
# FastAPI应用
# ======================
app = FastAPI(lifespan=lifespan)
app.mount("/videos", StaticFiles(directory="generated_videos"), name="videos")
# 认证模块
security = HTTPBearer(auto_error=False)
# ======================
# 数据模型--查询参数模型
# ======================
class VideoSubmitRequest(BaseModel):
model: str = Field(
default="Wan2.1-I2V-14B-480P",
description="模型版本"
)
prompt: str = Field(
...,
min_length=10,
max_length=500,
description="视频描述提示词10-500个字符"
)
image_url: str = Field(
...,
description="输入图像URL需支持HTTP/HTTPS协议"
)
image_size: str = Field(
default="auto",
description="输出分辨率格式宽x高 或 auto自动计算"
)
negative_prompt: Optional[str] = Field(
default=None,
max_length=500,
description="排除不需要的内容"
)
seed: Optional[int] = Field(
default=None,
ge=0,
le=2147483647,
description="随机数种子范围0-2147483647"
)
num_frames: int = Field(
default=81,
ge=24,
le=120,
description="视频帧数24-89帧"
)
guidance_scale: float = Field(
default=3.0,
ge=1.0,
le=20.0,
description="引导系数1.0-20.0"
)
infer_steps: int = Field(
default=30,
ge=20,
le=100,
description="推理步数20-100步"
)
@field_validator('image_size')
def validate_image_size(cls, v):
allowed_sizes = {"480x832", "832x480", "auto"}
if v not in allowed_sizes:
raise ValueError(f"支持的分辨率: {', '.join(allowed_sizes)}")
return v
class VideoStatusRequest(BaseModel):
requestId: str = Field(
...,
min_length=32,
max_length=32,
description="32位任务ID"
)
class VideoStatusResponse(BaseModel):
status: str = Field(..., description="任务状态: Succeed, InQueue, InProgress, Failed,Cancelled")
reason: Optional[str] = Field(None, description="失败原因")
results: Optional[dict] = Field(None, description="生成结果")
queue_position: Optional[int] = Field(None, description="队列位置")
class VideoCancelRequest(BaseModel):
requestId: str = Field(
...,
min_length=32,
max_length=32,
description="32位任务ID"
)
# ======================
# 核心逻辑
# ======================
async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""统一认证验证"""
if not credentials:
raise HTTPException(
status_code=401,
detail={"status": "Failed", "reason": "缺少认证头"},
headers={"WWW-Authenticate": "Bearer"}
)
if credentials.scheme != "Bearer":
raise HTTPException(
status_code=401,
detail={"status": "Failed", "reason": "无效的认证方案"},
headers={"WWW-Authenticate": "Bearer"}
)
if credentials.credentials not in app.state.valid_api_keys:
raise HTTPException(
status_code=401,
detail={"status": "Failed", "reason": "无效的API密钥"},
headers={"WWW-Authenticate": "Bearer"}
)
return True
async def task_processor():
"""任务处理器"""
while True:
async with app.state.semaphore:
task_id = await get_next_task()
if task_id:
await process_task(task_id)
else:
await asyncio.sleep(0.5)
async def get_next_task():
"""获取下一个任务"""
with app.state.task_lock:
return app.state.pending_queue.pop(0) if app.state.pending_queue else None
async def process_task(task_id: str):
"""处理单个任务"""
task = app.state.tasks.get(task_id)
if not task:
return
try:
# 更新任务状态
task['status'] = 'InProgress'
task['started_at'] = int(time.time())
print(task['request'].image_url)
# 下载输入图像
image = await download_image(task['request'].image_url)
image_path = f"temp_images/{task_id}.jpg"
image.save(image_path)
# 生成视频
video_path = await generate_video(task['request'], task_id, image)
# 生成下载链接
download_url = f"{app.state.base_url}/videos/{os.path.basename(video_path)}"
# 更新任务状态
task.update({
'status': 'Succeed',
'download_url': download_url,
'completed_at': int(time.time())
})
# 安排清理
asyncio.create_task(cleanup_files([image_path, video_path]))
except Exception as e:
handle_task_error(task, e)
def handle_task_error(task: dict, error: Exception):
"""错误处理(包含详细错误信息)"""
error_msg = str(error)
# 1. 显存不足错误
if isinstance(error, torch.cuda.OutOfMemoryError):
error_msg = "显存不足,请降低分辨率"
# 2. 网络请求相关错误
elif isinstance(error, (RequestException, HTTPException)):
# 从异常中提取具体信息
if isinstance(error, HTTPException):
# 如果是 HTTPException获取其 detail 字段
error_detail = getattr(error, "detail", "")
error_msg = f"图像下载失败: {error_detail}"
elif isinstance(error, Timeout):
error_msg = "图像下载超时,请检查网络"
elif isinstance(error, ConnectionError):
error_msg = "无法连接到服务器,请检查 URL"
elif isinstance(error, HTTPError):
# requests 的 HTTPError例如 4xx/5xx 状态码)
status_code = error.response.status_code
error_msg = f"服务器返回错误状态码: {status_code}"
else:
# 其他 RequestException 错误
error_msg = f"图像下载失败: {str(error)}"
# 3. 其他未知错误
else:
error_msg = f"未知错误: {str(error)}"
# 更新任务状态
task.update({
'status': 'Failed',
'reason': error_msg,
'completed_at': int(time.time())
})
# ======================
# 视频生成逻辑
# ======================
async def download_image(url: str) -> Image.Image:
"""异步下载图像(包含详细错误信息)"""
loop = asyncio.get_event_loop()
try:
response = await loop.run_in_executor(
None,
lambda: requests.get(url) # 将 timeout 传递给 requests.get
)
# 如果状态码非 200主动抛出 HTTPException
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=f"服务器返回状态码 {response.status_code}"
)
return Image.open(BytesIO(response.content)).convert("RGB")
except RequestException as e:
# 将原始 requests 错误信息抛出
raise HTTPException(
status_code=500,
detail=f"请求失败: {str(e)}"
)
async def generate_video(request: VideoSubmitRequest, task_id: str, image: Image.Image):
"""异步生成入口"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
sync_generate_video,
request,
task_id,
image
)
def sync_generate_video(request: VideoSubmitRequest, task_id: str, image: Image.Image):
"""同步生成核心"""
with app.state.model_lock:
try:
# 解析分辨率
mod_value = 16 # 模型要求的模数
print(request.image_size)
print('--------------------------------')
if request.image_size == "auto":
# 原版自动计算逻辑
aspect_ratio = image.height / image.width
print(image.height,image.width)
max_area = 399360 # 模型基础分辨率
# 计算理想尺寸
height = round(np.sqrt(max_area * aspect_ratio))
width = round(np.sqrt(max_area / aspect_ratio))
# 应用模数调整
height = height // mod_value * mod_value
width = width // mod_value * mod_value
resized_image = image.resize((width, height))
else:
width_str, height_str = request.image_size.split('x')
width = int(width_str)
height = int(height_str)
mod_value = 16
# 调整图像尺寸
resized_image = image.resize((width, height))
# 设置随机种子
generator = None
# 修改点1: 使用属性访问seed
if request.seed is not None:
generator = torch.Generator(device="cuda")
generator.manual_seed(request.seed) # 修改点2
print(f"🔮 使用随机种子: {request.seed}")
print(resized_image)
print(height,width)
# 执行推理
output = app.state.pipe(
image=resized_image,
prompt=request.prompt,
negative_prompt=request.negative_prompt,
height=height,
width=width,
num_frames=request.num_frames,
guidance_scale=request.guidance_scale,
num_inference_steps=request.infer_steps,
generator=generator
).frames[0]
# 导出视频
video_id = uuid.uuid4().hex
output_path = f"generated_videos/{video_id}.mp4"
export_to_video(output, output_path, fps=16)
return output_path
except Exception as e:
raise RuntimeError(f"视频生成失败: {str(e)}") from e
# ======================
# API端点
# ======================
@app.post("/video/submit",
response_model=dict,
status_code=status.HTTP_202_ACCEPTED,
tags=["视频生成"])
async def submit_task(
request: VideoSubmitRequest,
auth: bool = Depends(verify_auth)
):
"""提交生成任务"""
# 参数验证
if request.image_url is None:
raise HTTPException(
status_code=422,
detail={"status": "Failed", "reason": "需要图像URL参数"}
)
# 创建任务记录
task_id = uuid.uuid4().hex
with app.state.task_lock:
app.state.tasks[task_id] = {
"request": request,
"status": "InQueue",
"created_at": int(time.time())
}
app.state.pending_queue.append(task_id)
return {"requestId": task_id}
@app.post("/video/status",
response_model=VideoStatusResponse,
tags=["视频生成"])
async def get_status(
request: VideoStatusRequest,
auth: bool = Depends(verify_auth)
):
"""查询任务状态"""
task = app.state.tasks.get(request.requestId)
if not task:
raise HTTPException(
status_code=404,
detail={"status": "Failed", "reason": "无效的任务ID"}
)
# 计算队列位置(仅当在队列中时)
queue_pos = 0
if task['status'] == "InQueue" and request.requestId in app.state.pending_queue:
queue_pos = app.state.pending_queue.index(request.requestId) + 1
response = {
"status": task['status'],
"reason": task.get('reason'),
"queue_position": queue_pos if task['status'] == "InQueue" else None # 非排队状态返回null
}
# 成功状态的特殊处理
if task['status'] == "Succeed":
response["results"] = {
"videos": [{"url": task['download_url']}],
"timings": {
"inference": task['completed_at'] - task['started_at']
},
"seed": task['request'].seed
}
# 取消状态的补充信息
elif task['status'] == "Cancelled":
response["reason"] = task.get('reason', "用户主动取消") # 确保原因字段存在
return response
@app.post("/video/cancel",
response_model=dict,
tags=["视频生成"])
async def cancel_task(
request: VideoCancelRequest,
auth: bool = Depends(verify_auth)
):
"""取消排队中的生成任务"""
task_id = request.requestId
with app.state.task_lock:
task = app.state.tasks.get(task_id)
# 检查任务是否存在
if not task:
raise HTTPException(
status_code=404,
detail={"status": "Failed", "reason": "无效的任务ID"}
)
current_status = task['status']
# 仅允许取消排队中的任务
if current_status != "InQueue":
raise HTTPException(
status_code=400,
detail={"status": "Failed", "reason": f"仅允许取消排队任务,当前状态: {current_status}"}
)
# 从队列移除
try:
app.state.pending_queue.remove(task_id)
except ValueError:
pass # 可能已被处理
# 更新任务状态
task.update({
"status": "Cancelled",
"reason": "用户主动取消",
"completed_at": int(time.time())
})
return {"status": "Succeed"}
async def cleanup_files(paths: List[str], delay: int = 3600):
"""定时清理文件"""
await asyncio.sleep(delay)
for path in paths:
try:
if os.path.exists(path):
os.remove(path)
except Exception as e:
print(f"清理失败 {path}: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8088)

450
t2v-api.py Normal file
View File

@ -0,0 +1,450 @@
import os
import torch
import uuid
import time
import asyncio
from enum import Enum
from threading import Lock
from typing import Optional, Dict, List
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, status, Depends
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, field_validator, ValidationError
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse
# 创建视频存储目录
os.makedirs("generated_videos", exist_ok=True)
# 生命周期管理器
@asynccontextmanager
async def lifespan(app: FastAPI):
"""管理应用生命周期"""
# 初始化模型和资源
try:
# 初始化认证密钥
app.state.valid_api_keys = {
"密钥"
}
# 初始化视频生成模型
model_id = "./Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(
model_id,
subfolder="vae",
torch_dtype=torch.float32
)
scheduler = UniPCMultistepScheduler(
prediction_type='flow_prediction',
use_flow_sigmas=True,
num_train_timesteps=1000,
flow_shift=3.0
)
app.state.pipe = WanPipeline.from_pretrained(
model_id,
vae=vae,
torch_dtype=torch.bfloat16
).to("cuda")
app.state.pipe.scheduler = scheduler
# 初始化任务系统
app.state.tasks: Dict[str, dict] = {}
app.state.pending_queue: List[str] = []
app.state.model_lock = Lock()
app.state.task_lock = Lock()
app.state.base_url = "ip地址+端口"
app.state.max_concurrent = 2
app.state.semaphore = asyncio.Semaphore(app.state.max_concurrent)
# 启动后台任务处理器
asyncio.create_task(task_processor())
print("✅ 应用初始化完成")
yield
finally:
# 清理资源
if hasattr(app.state, 'pipe'):
del app.state.pipe
torch.cuda.empty_cache()
print("♻️ 已释放模型资源")
# 创建FastAPI应用
app = FastAPI(lifespan=lifespan)
app.mount("/videos", StaticFiles(directory="generated_videos"), name="videos")
# 认证模块
security = HTTPBearer(auto_error=False)
# ======================
# 数据模型--查询参数模型
# ======================
class VideoSubmitRequest(BaseModel):
model: str = Field(default="Wan2.1-T2V-1.3B",description="使用的模型版本")
prompt: str = Field(
...,
min_length=10,
max_length=500,
description="视频描述提示词10-500个字符"
)
image_size: str = Field(
...,
description="视频分辨率仅支持480x832或832x480"
)
negative_prompt: Optional[str] = Field(
default=None,
max_length=500,
description="排除不需要的内容"
)
seed: Optional[int] = Field(
default=None,
ge=0,
le=2147483647,
description="随机数种子范围0-2147483647"
)
num_frames: int = Field(
default=81,
ge=24,
le=120,
description="视频帧数24-120帧"
)
guidance_scale: float = Field(
default=5.0,
ge=1.0,
le=20.0,
description="引导系数1.0-20.0"
)
infer_steps: int = Field(
default=50,
ge=20,
le=100,
description="推理步数20-100步"
)
@field_validator('image_size', mode='before')
@classmethod
def validate_image_size(cls, value):
allowed_sizes = {"480x832", "832x480"}
if value not in allowed_sizes:
raise ValueError(f"仅支持以下分辨率: {', '.join(allowed_sizes)}")
return value
class VideoStatusRequest(BaseModel):
requestId: str = Field(
...,
min_length=32,
max_length=32,
description="32位任务ID"
)
class VideoSubmitResponse(BaseModel):
requestId: str
class VideoStatusResponse(BaseModel):
status: str = Field(..., description="任务状态: Succeed, InQueue, InProgress, Failed,Cancelled")
reason: Optional[str] = Field(None, description="失败原因")
results: Optional[dict] = Field(None, description="生成结果")
queue_position: Optional[int] = Field(None, description="队列位置")
class VideoCancelRequest(BaseModel):
requestId: str = Field(
...,
min_length=32,
max_length=32,
description="32位任务ID"
)
# # 自定义HTTP异常处理器
# @app.exception_handler(HTTPException)
# async def http_exception_handler(request, exc):
# return JSONResponse(
# status_code=exc.status_code,
# content=exc.detail, # 直接返回detail内容不再包装在detail字段
# headers=exc.headers
# )
# ======================
# 后台任务处理
# ======================
async def task_processor():
"""处理任务队列"""
while True:
async with app.state.semaphore:
task_id = await get_next_task()
if task_id:
await process_task(task_id)
else:
await asyncio.sleep(0.5)
async def get_next_task():
"""获取下一个待处理任务"""
with app.state.task_lock:
if app.state.pending_queue:
return app.state.pending_queue.pop(0)
return None
async def process_task(task_id: str):
"""处理单个任务"""
task = app.state.tasks.get(task_id)
if not task:
return
try:
# 更新任务状态
task['status'] = 'InProgress'
task['started_at'] = int(time.time())
# 执行视频生成
video_path = await generate_video(task['request'], task_id)
# 生成下载链接
download_url = f"{app.state.base_url}/videos/{os.path.basename(video_path)}"
# 更新任务状态
task.update({
'status': 'Succeed',
'download_url': download_url,
'completed_at': int(time.time())
})
# 安排自动清理
asyncio.create_task(auto_cleanup(video_path))
except Exception as e:
handle_task_error(task, e)
def handle_task_error(task: dict, error: Exception):
"""统一处理任务错误"""
error_msg = str(error)
if isinstance(error, torch.cuda.OutOfMemoryError):
error_msg = "显存不足,请降低分辨率或减少帧数"
elif isinstance(error, ValidationError):
error_msg = "参数校验失败: " + str(error)
task.update({
'status': 'Failed',
'reason': error_msg,
'completed_at': int(time.time())
})
# ======================
# 视频生成核心逻辑
# ======================
async def generate_video(request: dict, task_id: str) -> str:
"""异步执行视频生成"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
sync_generate_video,
request,
task_id
)
def sync_generate_video(request: dict, task_id: str) -> str:
"""同步生成视频"""
with app.state.model_lock:
try:
generator = None
if request.get('seed') is not None:
generator = torch.Generator(device="cuda")
generator.manual_seed(request['seed'])
print(f"🔮 使用随机种子: {request['seed']}")
# 执行模型推理
result = app.state.pipe(
prompt=request['prompt'],
negative_prompt=request['negative_prompt'],
height=request['height'],
width=request['width'],
num_frames=request['num_frames'],
guidance_scale=request['guidance_scale'],
num_inference_steps=request['infer_steps'],
generator=generator
)
# 导出视频文件
video_id = uuid.uuid4().hex
output_path = f"generated_videos/{video_id}.mp4"
export_to_video(result.frames[0], output_path, fps=16)
return output_path
except Exception as e:
raise RuntimeError(f"视频生成失败: {str(e)}") from e
# ======================
# API端点
# ======================
async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""认证验证"""
if not credentials:
raise HTTPException(
status_code=401,
detail={"status": "Failed", "reason": "缺少认证头"},
headers={"WWW-Authenticate": "Bearer"}
)
if credentials.scheme != "Bearer":
raise HTTPException(
status_code=401,
detail={"status": "Failed", "reason": "无效的认证方案"},
headers={"WWW-Authenticate": "Bearer"}
)
if credentials.credentials not in app.state.valid_api_keys:
raise HTTPException(
status_code=401,
detail={"status": "Failed", "reason": "无效的API密钥"},
headers={"WWW-Authenticate": "Bearer"}
)
return True
@app.post("/video/submit",
response_model=VideoSubmitResponse,
status_code=status.HTTP_202_ACCEPTED,
tags=["视频生成"],
summary="提交视频生成请求")
async def submit_video_task(
request: VideoSubmitRequest,
auth: bool = Depends(verify_auth)
):
"""提交新的视频生成任务"""
try:
# 解析分辨率参数
width, height = map(int, request.image_size.split('x'))
# 创建任务记录
task_id = uuid.uuid4().hex
task_data = {
'request': {
'prompt': request.prompt,
'negative_prompt': request.negative_prompt,
'width': width,
'height': height,
'num_frames': request.num_frames,
'guidance_scale': request.guidance_scale,
'infer_steps': request.infer_steps,
'seed': request.seed
},
'status': 'InQueue',
'created_at': int(time.time())
}
# 加入任务队列
with app.state.task_lock:
app.state.tasks[task_id] = task_data
app.state.pending_queue.append(task_id)
return {"requestId": task_id}
except ValidationError as e:
raise HTTPException(
status_code=422,
detail={"status": "Failed", "reason": str(e)}
)
@app.post("/video/status",
response_model=VideoStatusResponse,
tags=["视频生成"],
summary="查询任务状态")
async def get_video_status(
request: VideoStatusRequest,
auth: bool = Depends(verify_auth)
):
"""查询任务状态"""
task = app.state.tasks.get(request.requestId)
if not task:
raise HTTPException(
status_code=404,
detail={"status": "Failed", "reason": "无效的任务ID"}
)
# 计算队列位置(仅当在队列中时)
queue_pos = 0
if task['status'] == "InQueue" and request.requestId in app.state.pending_queue:
queue_pos = app.state.pending_queue.index(request.requestId) + 1
response = {
"status": task['status'],
"reason": task.get('reason'),
"queue_position": queue_pos if task['status'] == "InQueue" else None # 非排队状态返回null
}
# 成功状态的特殊处理
if task['status'] == "Succeed":
response["results"] = {
"videos": [{"url": task['download_url']}],
"timings": {
"inference": task['completed_at'] - task['started_at']
},
"seed": task['request']['seed']
}
# 取消状态的补充信息
elif task['status'] == "Cancelled":
response["reason"] = task.get('reason', "用户主动取消") # 确保原因字段存在
return response
@app.post("/video/cancel",
response_model=dict,
tags=["视频生成"])
async def cancel_task(
request: VideoCancelRequest,
auth: bool = Depends(verify_auth)
):
"""取消排队中的生成任务"""
task_id = request.requestId
with app.state.task_lock:
task = app.state.tasks.get(task_id)
# 检查任务是否存在
if not task:
raise HTTPException(
status_code=404,
detail={"status": "Failed", "reason": "无效的任务ID"}
)
current_status = task['status']
# 仅允许取消排队中的任务
if current_status != "InQueue":
raise HTTPException(
status_code=400,
detail={"status": "Failed", "reason": f"仅允许取消排队任务,当前状态: {current_status}"}
)
# 从队列移除
try:
app.state.pending_queue.remove(task_id)
except ValueError:
pass # 可能已被处理
# 更新任务状态
task.update({
"status": "Cancelled",
"reason": "用户主动取消",
"completed_at": int(time.time())
})
return {"status": "Succeed"}
# ======================
# 工具函数
# ======================
async def auto_cleanup(file_path: str, delay: int = 3600):
"""自动清理生成的视频文件"""
await asyncio.sleep(delay)
try:
if os.path.exists(file_path):
os.remove(file_path)
print(f"已清理文件: {file_path}")
except Exception as e:
print(f"文件清理失败: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8088)