Wan2.1/t2v-api.py

450 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import torch
import uuid
import time
import asyncio
from enum import Enum
from threading import Lock
from typing import Optional, Dict, List
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, status, Depends
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, field_validator, ValidationError
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import JSONResponse
# Ensure the output directory for generated videos exists (idempotent).
os.makedirs("generated_videos", exist_ok=True)
# Application lifespan manager: load model/resources on startup, release on shutdown.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the application lifecycle.

    Startup: register API keys, load the Wan2.1 T2V pipeline onto the GPU,
    initialize the task queue/locks, and launch the background processor.
    Shutdown: cancel the processor and free GPU memory.
    """
    try:
        # Bearer tokens accepted by verify_auth.
        # NOTE(review): hardcoded placeholder — load from config/env before deployment.
        app.state.valid_api_keys = {
            "密钥"
        }
        # Load the text-to-video pipeline (VAE in fp32, transformer in bf16).
        model_id = "./Wan2.1-T2V-1.3B-Diffusers"
        vae = AutoencoderKLWan.from_pretrained(
            model_id,
            subfolder="vae",
            torch_dtype=torch.float32
        )
        scheduler = UniPCMultistepScheduler(
            prediction_type='flow_prediction',
            use_flow_sigmas=True,
            num_train_timesteps=1000,
            flow_shift=3.0
        )
        app.state.pipe = WanPipeline.from_pretrained(
            model_id,
            vae=vae,
            torch_dtype=torch.bfloat16
        ).to("cuda")
        app.state.pipe.scheduler = scheduler
        # Task bookkeeping: records by id, FIFO queue of pending ids, locks.
        app.state.tasks: Dict[str, dict] = {}
        app.state.pending_queue: List[str] = []
        app.state.model_lock = Lock()  # serializes GPU inference
        app.state.task_lock = Lock()   # guards tasks / pending_queue
        # Public base URL used to build download links.
        # NOTE(review): placeholder — set to the real "http://host:port".
        app.state.base_url = "ip地址+端口"
        app.state.max_concurrent = 2
        app.state.semaphore = asyncio.Semaphore(app.state.max_concurrent)
        # Start the background queue processor. Keep a strong reference:
        # the event loop holds only weak refs to tasks, so an unreferenced
        # task can be garbage-collected mid-execution.
        app.state.processor_task = asyncio.create_task(task_processor())
        print("✅ 应用初始化完成")
        yield
    finally:
        # Stop the background processor, then release model resources.
        if hasattr(app.state, 'processor_task'):
            app.state.processor_task.cancel()
        if hasattr(app.state, 'pipe'):
            del app.state.pipe
            torch.cuda.empty_cache()
            print("♻️ 已释放模型资源")
# Create the FastAPI application and serve generated videos as static files.
app = FastAPI(lifespan=lifespan)
app.mount("/videos", StaticFiles(directory="generated_videos"), name="videos")
# Authentication scheme: Bearer token; auto_error=False so missing credentials
# are handled manually (verify_auth returns a custom 401 payload).
security = HTTPBearer(auto_error=False)
# ======================
# Data models — request parameters
# ======================
class VideoSubmitRequest(BaseModel):
    """Request body for /video/submit (text-to-video generation)."""
    model: str = Field(default="Wan2.1-T2V-1.3B", description="使用的模型版本")
    prompt: str = Field(
        ...,
        min_length=10,
        max_length=500,
        description="视频描述提示词10-500个字符"
    )
    image_size: str = Field(
        ...,
        description="视频分辨率仅支持480x832或832x480"
    )
    negative_prompt: Optional[str] = Field(
        default=None,
        max_length=500,
        description="排除不需要的内容"
    )
    seed: Optional[int] = Field(
        default=None,
        ge=0,
        le=2147483647,
        description="随机数种子范围0-2147483647"
    )
    num_frames: int = Field(
        default=81,
        ge=24,
        le=120,
        description="视频帧数24-120帧"
    )
    guidance_scale: float = Field(
        default=5.0,
        ge=1.0,
        le=20.0,
        description="引导系数1.0-20.0"
    )
    infer_steps: int = Field(
        default=50,
        ge=20,
        le=100,
        description="推理步数20-100步"
    )

    @field_validator('image_size', mode='before')
    @classmethod
    def validate_image_size(cls, value):
        """Accept only the two supported resolutions ("WxH" strings)."""
        # Tuple instead of set: with a set, the order of sizes in the error
        # message varied per process (str hash randomization). A tuple keeps
        # the message deterministic.
        allowed_sizes = ("480x832", "832x480")
        if value not in allowed_sizes:
            raise ValueError(f"仅支持以下分辨率: {', '.join(allowed_sizes)}")
        return value
class VideoStatusRequest(BaseModel):
    """Request body for /video/status: identifies the task to query."""
    # 32-char hex id returned by /video/submit (uuid4().hex).
    requestId: str = Field(
        ...,
        min_length=32,
        max_length=32,
        description="32位任务ID"
    )
class VideoSubmitResponse(BaseModel):
    """Response body for /video/submit: the id used to poll /video/status."""
    requestId: str  # 32-char hex task id
class VideoStatusResponse(BaseModel):
    """Response body for /video/status."""
    status: str = Field(..., description="任务状态: Succeed, InQueue, InProgress, Failed,Cancelled")
    # Populated for Failed/Cancelled states.
    reason: Optional[str] = Field(None, description="失败原因")
    # Populated only when status == "Succeed": video URLs, timings, seed.
    results: Optional[dict] = Field(None, description="生成结果")
    # 1-based position while queued; None for all other states.
    queue_position: Optional[int] = Field(None, description="队列位置")
class VideoCancelRequest(BaseModel):
    """Request body for /video/cancel: identifies the queued task to cancel."""
    # 32-char hex id returned by /video/submit (uuid4().hex).
    requestId: str = Field(
        ...,
        min_length=32,
        max_length=32,
        description="32位任务ID"
    )
# # 自定义HTTP异常处理器
# @app.exception_handler(HTTPException)
# async def http_exception_handler(request, exc):
# return JSONResponse(
# status_code=exc.status_code,
# content=exc.detail, # 直接返回detail内容不再包装在detail字段
# headers=exc.headers
# )
# ======================
# Background task processing
# ======================
async def task_processor():
    """Endless loop that drains the pending task queue.

    Fix: acquire the semaphore only when a task is actually available.
    The original acquired it before polling, so an idle loop kept one
    concurrency slot occupied through every 0.5 s sleep.
    """
    while True:
        task_id = await get_next_task()
        if task_id is None:
            # Queue empty — back off briefly before polling again.
            await asyncio.sleep(0.5)
            continue
        async with app.state.semaphore:
            await process_task(task_id)
async def get_next_task():
    """Pop and return the oldest pending task id, or None when the queue is empty."""
    with app.state.task_lock:
        queue = app.state.pending_queue
        return queue.pop(0) if queue else None
async def process_task(task_id: str):
    """Run a single queued task end-to-end and record its outcome."""
    record = app.state.tasks.get(task_id)
    if not record:
        return
    try:
        # Mark as running before the (long) generation step.
        record['status'] = 'InProgress'
        record['started_at'] = int(time.time())
        # Generate the video (executes in a worker thread).
        output_path = await generate_video(record['request'], task_id)
        filename = os.path.basename(output_path)
        # Record success together with the public download link.
        record.update({
            'status': 'Succeed',
            'download_url': f"{app.state.base_url}/videos/{filename}",
            'completed_at': int(time.time())
        })
        # Schedule deletion of the file after its retention window.
        asyncio.create_task(auto_cleanup(output_path))
    except Exception as exc:
        handle_task_error(record, exc)
def handle_task_error(task: dict, error: Exception):
    """Mark *task* as failed, mapping known error types to friendly messages."""
    if isinstance(error, torch.cuda.OutOfMemoryError):
        reason = "显存不足,请降低分辨率或减少帧数"
    elif isinstance(error, ValidationError):
        reason = "参数校验失败: " + str(error)
    else:
        # Fall back to the raw exception text.
        reason = str(error)
    task.update({
        'status': 'Failed',
        'reason': reason,
        'completed_at': int(time.time())
    })
# ======================
# Video generation core
# ======================
async def generate_video(request: dict, task_id: str) -> str:
    """Run the blocking video generation in a worker thread.

    Returns the path of the exported .mp4 file; raises whatever
    sync_generate_video raises.
    """
    # get_running_loop() is the supported way to obtain the loop from inside
    # a coroutine; get_event_loop() here is deprecated since Python 3.10.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None,
        sync_generate_video,
        request,
        task_id
    )
def sync_generate_video(request: dict, task_id: str) -> str:
    """Synchronously run the diffusion pipeline and export an .mp4.

    Serialized by model_lock so only one inference uses the GPU at a time.
    Returns the path of the written video file; raises RuntimeError on
    generic failure, or torch.cuda.OutOfMemoryError unchanged on OOM.
    """
    with app.state.model_lock:
        try:
            # Optional deterministic seeding.
            generator = None
            if request.get('seed') is not None:
                generator = torch.Generator(device="cuda")
                generator.manual_seed(request['seed'])
                print(f"🔮 使用随机种子: {request['seed']}")
            # Run model inference.
            result = app.state.pipe(
                prompt=request['prompt'],
                negative_prompt=request['negative_prompt'],
                height=request['height'],
                width=request['width'],
                num_frames=request['num_frames'],
                guidance_scale=request['guidance_scale'],
                num_inference_steps=request['infer_steps'],
                generator=generator
            )
            # Export the first clip to disk at 16 fps.
            video_id = uuid.uuid4().hex
            output_path = f"generated_videos/{video_id}.mp4"
            export_to_video(result.frames[0], output_path, fps=16)
            return output_path
        except torch.cuda.OutOfMemoryError:
            # Fix: propagate OOM unchanged so handle_task_error can recognize
            # it; wrapping it in RuntimeError made that branch unreachable.
            raise
        except Exception as e:
            raise RuntimeError(f"视频生成失败: {str(e)}") from e
# ======================
# API endpoints
# ======================
async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the Bearer token against the configured API keys.

    Raises HTTPException(401) with a structured reason on any failure;
    returns True when the credentials are valid.
    """
    def reject(reason: str):
        # All failures share the same 401 shape and WWW-Authenticate header.
        raise HTTPException(
            status_code=401,
            detail={"status": "Failed", "reason": reason},
            headers={"WWW-Authenticate": "Bearer"}
        )

    if not credentials:
        reject("缺少认证头")
    if credentials.scheme != "Bearer":
        reject("无效的认证方案")
    if credentials.credentials not in app.state.valid_api_keys:
        reject("无效的API密钥")
    return True
@app.post("/video/submit",
          response_model=VideoSubmitResponse,
          status_code=status.HTTP_202_ACCEPTED,
          tags=["视频生成"],
          summary="提交视频生成请求")
async def submit_video_task(
    request: VideoSubmitRequest,
    auth: bool = Depends(verify_auth)
):
    """Queue a new video generation task and return its request id."""
    try:
        # "480x832" -> width=480, height=832 (validator guarantees the format).
        size_parts = request.image_size.split('x')
        width, height = int(size_parts[0]), int(size_parts[1])
        task_id = uuid.uuid4().hex
        record = {
            'request': {
                'prompt': request.prompt,
                'negative_prompt': request.negative_prompt,
                'width': width,
                'height': height,
                'num_frames': request.num_frames,
                'guidance_scale': request.guidance_scale,
                'infer_steps': request.infer_steps,
                'seed': request.seed
            },
            'status': 'InQueue',
            'created_at': int(time.time())
        }
        # Publish the record and enqueue it atomically under the task lock.
        with app.state.task_lock:
            app.state.tasks[task_id] = record
            app.state.pending_queue.append(task_id)
        return {"requestId": task_id}
    except ValidationError as e:
        raise HTTPException(
            status_code=422,
            detail={"status": "Failed", "reason": str(e)}
        )
@app.post("/video/status",
          response_model=VideoStatusResponse,
          tags=["视频生成"],
          summary="查询任务状态")
async def get_video_status(
    request: VideoStatusRequest,
    auth: bool = Depends(verify_auth)
):
    """Return the current state of a task, including results when finished."""
    task = app.state.tasks.get(request.requestId)
    if not task:
        raise HTTPException(
            status_code=404,
            detail={"status": "Failed", "reason": "无效的任务ID"}
        )
    # Fix: inspect the queue under task_lock — cancel_task and get_next_task
    # mutate pending_queue under that lock, so the original unlocked
    # membership test + index() could race with a concurrent removal.
    queue_pos = 0
    with app.state.task_lock:
        if task['status'] == "InQueue" and request.requestId in app.state.pending_queue:
            queue_pos = app.state.pending_queue.index(request.requestId) + 1
    response = {
        "status": task['status'],
        "reason": task.get('reason'),
        # Non-queued states report null instead of a position.
        "queue_position": queue_pos if task['status'] == "InQueue" else None
    }
    # Success: attach download link, timing, and the seed actually used.
    if task['status'] == "Succeed":
        response["results"] = {
            "videos": [{"url": task['download_url']}],
            "timings": {
                "inference": task['completed_at'] - task['started_at']
            },
            "seed": task['request']['seed']
        }
    # Cancelled: guarantee a human-readable reason field.
    elif task['status'] == "Cancelled":
        response["reason"] = task.get('reason', "用户主动取消")
    return response
@app.post("/video/cancel",
          response_model=dict,
          tags=["视频生成"])
async def cancel_task(
    request: VideoCancelRequest,
    auth: bool = Depends(verify_auth)
):
    """Cancel a task that is still waiting in the queue."""
    tid = request.requestId
    with app.state.task_lock:
        record = app.state.tasks.get(tid)
        # Unknown id -> 404.
        if not record:
            raise HTTPException(
                status_code=404,
                detail={"status": "Failed", "reason": "无效的任务ID"}
            )
        state = record['status']
        # Only tasks that have not started yet may be cancelled.
        if state != "InQueue":
            raise HTTPException(
                status_code=400,
                detail={"status": "Failed", "reason": f"仅允许取消排队任务,当前状态: {state}"}
            )
        # Drop from the queue; it may already have been picked up.
        try:
            app.state.pending_queue.remove(tid)
        except ValueError:
            pass
        record.update({
            "status": "Cancelled",
            "reason": "用户主动取消",
            "completed_at": int(time.time())
        })
    return {"status": "Succeed"}
# ======================
# Utility functions
# ======================
async def auto_cleanup(file_path: str, delay: int = 3600):
    """Delete *file_path* after *delay* seconds (best effort).

    Fix: EAFP — attempt the removal and ignore a missing file, instead of
    the original exists()-then-remove, which could raise if the file was
    deleted between the check and the removal.
    """
    await asyncio.sleep(delay)
    try:
        os.remove(file_path)
        print(f"已清理文件: {file_path}")
    except FileNotFoundError:
        # Already gone — nothing to do.
        pass
    except Exception as e:
        # Best-effort cleanup: report and move on.
        print(f"文件清理失败: {str(e)}")
if __name__ == "__main__":
    # Run the API directly with uvicorn (development entry point).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8088)