Wan2.1/i2v_api.py

526 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import os
import secrets
import time
import uuid
from contextlib import asynccontextmanager
from io import BytesIO
from threading import Lock
from typing import Dict, List, Optional

import numpy as np
import requests
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel, Field, ValidationError, field_validator
from requests.exceptions import ConnectionError as RequestsConnectionError
from requests.exceptions import HTTPError, RequestException, Timeout
from transformers import CLIPVisionModel
# Working directories on local disk: finished videos (served under /videos)
# and downloaded input images. They must exist before StaticFiles is mounted.
for _dir in ("generated_videos", "temp_images"):
    os.makedirs(_dir, exist_ok=True)
del _dir
# ======================
# 生命周期管理
# ======================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load the Wan2.1 I2V pipeline, start the worker, clean up.

    Startup (everything before ``yield``):
      * Accepted API keys, the model directory and the public base URL come from
        the ``WAN_API_KEYS`` (comma-separated), ``WAN_MODEL_PATH`` and
        ``WAN_BASE_URL`` environment variables. When unset, they fall back to the
        original hard-coded placeholder values.
      * The CLIP image encoder and the VAE are loaded in float32 and the rest of
        the pipeline in bfloat16. The pipeline is moved to CUDA and gets a UniPC
        scheduler configured for flow prediction (``flow_shift=3.0``).
      * In-memory task bookkeeping is created on ``app.state`` and the
        ``task_processor`` background loop is started. A reference to it is kept
        in ``app.state.processor_task``.

    Shutdown (``finally``): the background loop is cancelled and awaited, the
    pipeline reference is dropped and the CUDA caching allocator is emptied.
    """
    processor = None
    try:
        # Authentication: the set of accepted bearer tokens.
        env_keys = {
            key.strip()
            for key in os.environ.get("WAN_API_KEYS", "").split(",")
            if key.strip()
        }
        app.state.valid_api_keys = env_keys or {"密钥"}  # placeholder default
        # Model weights (diffusers layout with image_encoder/ and vae/ subfolders).
        model_id = os.environ.get("WAN_MODEL_PATH", "./Wan2.1-I2V-14B-480P-Diffusers")
        image_encoder = CLIPVisionModel.from_pretrained(
            model_id,
            subfolder="image_encoder",
            torch_dtype=torch.float32
        )
        vae = AutoencoderKLWan.from_pretrained(
            model_id,
            subfolder="vae",
            torch_dtype=torch.float32
        )
        scheduler = UniPCMultistepScheduler(
            prediction_type='flow_prediction',
            use_flow_sigmas=True,
            num_train_timesteps=1000,
            flow_shift=3.0
        )
        app.state.pipe = WanImageToVideoPipeline.from_pretrained(
            model_id,
            vae=vae,
            image_encoder=image_encoder,
            torch_dtype=torch.bfloat16
        ).to("cuda")
        app.state.pipe.scheduler = scheduler
        # Task system: task_id -> record, plus FIFO of task ids still queued.
        app.state.tasks: Dict[str, dict] = {}
        app.state.pending_queue: List[str] = []
        app.state.model_lock = Lock()  # serializes GPU inference
        app.state.task_lock = Lock()   # guards pending_queue / task records
        # Prefix for download links handed back to clients.
        app.state.base_url = os.environ.get("WAN_BASE_URL", "ip地址+端口")
        app.state.semaphore = asyncio.Semaphore(2)  # concurrency limit
        # Keep a strong reference: the event loop only holds weak references
        # to tasks, and we need the handle to cancel the loop on shutdown.
        processor = asyncio.create_task(task_processor())
        app.state.processor_task = processor
        print("✅ 系统初始化完成")
        yield
    finally:
        if processor is not None:
            processor.cancel()
            try:
                await processor
            except asyncio.CancelledError:
                pass
            except Exception as exc:  # worker had already crashed; report it
                print(f"后台任务异常退出: {exc!r}")
        if hasattr(app.state, 'pipe'):
            del app.state.pipe
        torch.cuda.empty_cache()
        print("♻️ 资源已释放")
# ======================
# FastAPI应用
# ======================
# Bearer-token extractor. auto_error=False makes it return None instead of
# raising, so verify_auth can build its own 401 response bodies.
security = HTTPBearer(auto_error=False)
# The ASGI application; `lifespan` loads the model and starts the worker.
app = FastAPI(lifespan=lifespan)
# Finished videos are served as static files under /videos/<name>.mp4.
app.mount(path="/videos", app=StaticFiles(directory="generated_videos"), name="videos")
# ======================
# 数据模型--查询参数模型
# ======================
class VideoSubmitRequest(BaseModel):
    """Request body for ``POST /video/submit``.

    All bounds below are enforced by pydantic before a task is queued. The
    ``description`` strings are what clients see in the generated API docs,
    so they must match the enforced limits.
    """

    # Informational only; the service always runs the model loaded at startup.
    model: str = Field(
        default="Wan2.1-I2V-14B-480P",
        description="模型版本"
    )
    prompt: str = Field(
        ...,
        min_length=10,
        max_length=500,
        description="视频描述提示词10-500个字符"
    )
    image_url: str = Field(
        ...,
        description="输入图像URL需支持HTTP/HTTPS协议"
    )
    # "WxH" from a fixed allow-list, or "auto" (derived from the input aspect ratio).
    image_size: str = Field(
        default="auto",
        description="输出分辨率格式宽x高 或 auto自动计算"
    )
    negative_prompt: Optional[str] = Field(
        default=None,
        max_length=500,
        description="排除不需要的内容"
    )
    seed: Optional[int] = Field(
        default=None,
        ge=0,
        le=2147483647,
        description="随机数种子范围0-2147483647"
    )
    # Description fixed to match the enforced range (was "24-89帧" while le=120).
    num_frames: int = Field(
        default=81,
        ge=24,
        le=120,
        description="视频帧数24-120帧"
    )
    guidance_scale: float = Field(
        default=3.0,
        ge=1.0,
        le=20.0,
        description="引导系数1.0-20.0"
    )
    infer_steps: int = Field(
        default=30,
        ge=20,
        le=100,
        description="推理步数20-100步"
    )

    @field_validator('image_url')
    @classmethod
    def validate_image_url(cls, v):
        """Enforce the documented HTTP/HTTPS-only contract at submit time.

        Other schemes would otherwise be accepted here and only fail later,
        inside the download step. Leading/trailing whitespace is tolerated for
        the check; the value itself is returned unchanged.
        """
        if not v.strip().lower().startswith(("http://", "https://")):
            raise ValueError("图像URL必须以 http:// 或 https:// 开头")
        return v

    @field_validator('image_size')
    @classmethod
    def validate_image_size(cls, v):
        """Accept only the supported output resolutions or ``"auto"``."""
        # Tuple, not set: keeps the order in the error message stable across runs.
        allowed_sizes = ("480x832", "832x480", "auto")
        if v not in allowed_sizes:
            raise ValueError(f"支持的分辨率: {', '.join(allowed_sizes)}")
        return v
class VideoStatusRequest(BaseModel):
    # Request body for /video/status: the task id returned by /video/submit
    # (a uuid4 hex string, hence exactly 32 characters).
    requestId: str = Field(
        default=...,
        description="32位任务ID",
        max_length=32,
        min_length=32,
    )
class VideoStatusResponse(BaseModel):
    # Shape of every /video/status reply; only `status` is mandatory.
    status: str = Field(
        default=...,
        description="任务状态: Succeed, InQueue, InProgress, Failed,Cancelled",
    )
    # Human-readable reason, e.g. for failed or cancelled tasks.
    reason: Optional[str] = Field(default=None, description="失败原因")
    # Video URL, timing and seed; filled in for successful tasks.
    results: Optional[dict] = Field(default=None, description="生成结果")
    # Position in the pending queue; only meaningful while queued.
    queue_position: Optional[int] = Field(default=None, description="队列位置")
class VideoCancelRequest(BaseModel):
    # Request body for /video/cancel: the 32-character task id to cancel.
    requestId: str = Field(
        default=...,
        description="32位任务ID",
        max_length=32,
        min_length=32,
    )
# ======================
# 核心逻辑
# ======================
def _auth_error(reason: str) -> HTTPException:
    """Build the 401 response shared by every authentication failure."""
    return HTTPException(
        status_code=401,
        detail={"status": "Failed", "reason": reason},
        headers={"WWW-Authenticate": "Bearer"}
    )


async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Dependency that admits only requests bearing a configured API key.

    Args:
        credentials: Parsed ``Authorization`` header from ``HTTPBearer``, or
            ``None`` when the header is absent or not a bearer token.

    Returns:
        ``True`` when the token is one of ``app.state.valid_api_keys``.

    Raises:
        HTTPException: 401 with a ``{"status", "reason"}`` detail on failure.
    """
    if not credentials:
        raise _auth_error("缺少认证头")
    # Auth-scheme names are case-insensitive (RFC 7235), and HTTPBearer already
    # accepts "bearer"/"BEARER"; an exact "Bearer" match wrongly rejected them.
    if credentials.scheme.lower() != "bearer":
        raise _auth_error("无效的认证方案")
    # Constant-time comparison so response timing does not leak key contents.
    # compare_digest only accepts ASCII str, so compare UTF-8 bytes instead.
    supplied = credentials.credentials.encode("utf-8")
    if not any(
        secrets.compare_digest(supplied, key.encode("utf-8"))
        for key in app.state.valid_api_keys
    ):
        raise _auth_error("无效的API密钥")
    return True
async def task_processor():
    """Background worker (started from ``lifespan``) that drains ``pending_queue``.

    Tasks are taken in FIFO order and processed one at a time while holding
    ``app.state.semaphore``. NOTE(review): this loop appears to be the
    semaphore's only user, so it admits a single task at a time regardless of
    the limit configured at startup.

    The loop must never die. It is the only consumer, so an exception escaping
    ``process_task`` would leave every later task "InQueue" forever. Such
    exceptions are logged and the affected task is marked "Failed". The idle
    back-off sleep happens outside the semaphore.
    """
    while True:
        async with app.state.semaphore:
            task_id = await get_next_task()
            if task_id:
                try:
                    await process_task(task_id)
                except Exception as exc:
                    print(f"任务处理异常 {task_id}: {exc!r}")
                    task = app.state.tasks.get(task_id)
                    if task is not None and task.get('status') not in ("Succeed", "Failed", "Cancelled"):
                        task.update({
                            'status': 'Failed',
                            'reason': f"未知错误: {exc}",
                            'completed_at': int(time.time())
                        })
        if not task_id:
            await asyncio.sleep(0.5)
async def get_next_task():
    """Pop and return the oldest queued task id, or ``None`` if the queue is empty.

    The pop happens under ``app.state.task_lock``.
    """
    with app.state.task_lock:
        queue = app.state.pending_queue
        if not queue:
            return None
        return queue.pop(0)
async def process_task(task_id: str):
    """Run one task end to end: download the image, generate the video, record the result.

    Args:
        task_id: Key into ``app.state.tasks``. Unknown ids, and tasks already
            marked "Cancelled", are ignored.

    On success the task gets ``status="Succeed"``, a ``download_url`` and
    ``completed_at``. Any exception is turned into a "Failed" record by
    ``handle_task_error``. Every file this task created (input image and, on
    success, the output video) is handed to ``cleanup_files`` with its default
    delay. This now also happens on failure; previously a failed task left
    its temp image on disk forever.
    """
    task = app.state.tasks.get(task_id)
    if not task or task.get('status') == 'Cancelled':
        return
    image_path = None
    video_path = None
    try:
        task['status'] = 'InProgress'
        task['started_at'] = int(time.time())
        print(task['request'].image_url)
        # Download the input image and keep a local copy next to the task.
        image = await download_image(task['request'].image_url)
        image_path = f"temp_images/{task_id}.jpg"
        image.save(image_path)
        video_path = await generate_video(task['request'], task_id, image)
        # Public link served by the /videos static mount.
        download_url = f"{app.state.base_url}/videos/{os.path.basename(video_path)}"
        task.update({
            'status': 'Succeed',
            'download_url': download_url,
            'completed_at': int(time.time())
        })
    except Exception as e:
        handle_task_error(task, e)
    finally:
        created = [path for path in (image_path, video_path) if path]
        if created:
            asyncio.create_task(cleanup_files(created))
def _iter_causes(error: BaseException):
    """Yield ``error`` followed by its explicit ``__cause__`` chain (cycle-safe)."""
    seen = set()
    current = error
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        yield current
        current = current.__cause__


def handle_task_error(task: dict, error: Exception):
    """Mark ``task`` as "Failed" with a user-facing reason derived from ``error``.

    Args:
        task: Mutable task record; updated in place with ``status``,
            ``reason`` and ``completed_at``.
        error: Exception raised while processing the task.

    The CUDA OOM check walks the ``__cause__`` chain. Generation errors can
    arrive wrapped, because ``sync_generate_video`` raises
    ``RuntimeError(...) from e``, so checking only the outer exception could
    never recognise an OOM. The requests exception classes are the ones from
    ``requests.exceptions``. The builtin ``ConnectionError`` previously used
    here is unrelated to them, and ``Timeout``/``HTTPError`` were never
    imported at all.
    """
    if any(isinstance(exc, torch.cuda.OutOfMemoryError) for exc in _iter_causes(error)):
        error_msg = "显存不足,请降低分辨率"
    elif isinstance(error, HTTPException):
        # download_image converts download failures into HTTPException.
        error_msg = f"图像下载失败: {getattr(error, 'detail', '')}"
    elif isinstance(error, Timeout):
        error_msg = "图像下载超时,请检查网络"
    elif isinstance(error, RequestsConnectionError):
        error_msg = "无法连接到服务器,请检查 URL"
    elif isinstance(error, HTTPError):
        # requests' HTTPError (4xx/5xx); `response` may be None if raised manually.
        response = getattr(error, "response", None)
        if response is not None:
            error_msg = f"服务器返回错误状态码: {response.status_code}"
        else:
            error_msg = f"图像下载失败: {str(error)}"
    elif isinstance(error, RequestException):
        error_msg = f"图像下载失败: {str(error)}"
    else:
        error_msg = f"未知错误: {str(error)}"
    task.update({
        'status': 'Failed',
        'reason': error_msg,
        'completed_at': int(time.time())
    })
# ======================
# 视频生成逻辑
# ======================
async def download_image(url: str, timeout: float = 30.0) -> Image.Image:
    """Fetch ``url`` and decode the body as an RGB PIL image.

    The blocking ``requests.get`` runs on the default thread-pool executor so
    the event loop stays responsive.

    Args:
        url: Image location supplied by the API client (untrusted input).
        timeout: Seconds passed to ``requests.get``, applied to the connect
            and to each socket read. Without it, a stalled server could block
            the single worker forever. The original code claimed to pass a
            timeout but did not.

    Returns:
        The decoded image converted to RGB.

    Raises:
        HTTPException: with the upstream status code when the server answers
            with anything other than 200, or with 500 when the request itself
            fails (timeout, DNS, refused connection, invalid URL).
        PIL errors (e.g. ``UnidentifiedImageError``) propagate unchanged.
    """
    loop = asyncio.get_running_loop()
    try:
        # NOTE(review): the server fetches an arbitrary client-supplied URL
        # (SSRF risk if exposed publicly); consider a host allow-list.
        response = await loop.run_in_executor(
            None,
            lambda: requests.get(url, timeout=timeout)
        )
    except RequestException as e:
        raise HTTPException(
            status_code=500,
            detail=f"请求失败: {str(e)}"
        ) from e
    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"服务器返回状态码 {response.status_code}"
        )
    return Image.open(BytesIO(response.content)).convert("RGB")
async def generate_video(request: VideoSubmitRequest, task_id: str, image: Image.Image):
    """Await ``sync_generate_video`` without blocking the event loop.

    The blocking inference call is submitted to the loop's default
    thread-pool executor; the awaited value is whatever it returns (the
    generated video's file path).
    """
    def _blocking_job():
        return sync_generate_video(request, task_id, image)

    event_loop = asyncio.get_running_loop()
    return await event_loop.run_in_executor(None, _blocking_job)
def resolve_target_size(image_size: str, image_width: int, image_height: int,
                        mod_value: int = 16, max_area: int = 399360):
    """Return the ``(width, height)`` to generate at.

    Args:
        image_size: ``"WxH"`` (e.g. ``"832x480"``) or ``"auto"``.
        image_width, image_height: Size of the input image in pixels.
        mod_value: Both output sides are snapped down to a multiple of this.
        max_area: Target pixel area for ``"auto"`` (399360 = 832 * 480).

    In ``"auto"`` mode the input aspect ratio is kept at roughly ``max_area``
    pixels. Each side is floored to a multiple of ``mod_value`` but never
    below one multiple, so extreme aspect ratios cannot yield a 0-pixel side.
    Results are plain ``int`` (``round`` of a NumPy scalar is not an ``int``
    on every NumPy version).
    """
    if image_size == "auto":
        aspect_ratio = image_height / image_width
        height = int(round(np.sqrt(max_area * aspect_ratio)))
        width = int(round(np.sqrt(max_area / aspect_ratio)))
        height = max(mod_value, height // mod_value * mod_value)
        width = max(mod_value, width // mod_value * mod_value)
        return width, height
    width_str, height_str = image_size.split('x')
    return int(width_str), int(height_str)


def sync_generate_video(request: "VideoSubmitRequest", task_id: str, image: "Image.Image"):
    """Blocking generation core: resize the input, run the pipeline, export an MP4.

    Runs while holding ``app.state.model_lock``, so GPU inference is
    serialized across threads.

    Args:
        request: Validated submit request (prompt, size, frames, seed, ...).
        task_id: Not used here; kept for interface compatibility.
        image: RGB input image.

    Returns:
        Path of the exported video, ``generated_videos/<uuid>.mp4`` at 16 fps.

    Raises:
        torch.cuda.OutOfMemoryError: re-raised unwrapped (after emptying the
            CUDA cache) so callers can recognise it.
        RuntimeError: ``"视频生成失败: ..."`` chained to any other failure.
    """
    with app.state.model_lock:
        try:
            width, height = resolve_target_size(request.image_size, image.width, image.height)
            resized_image = image.resize((width, height))
            # Seeded generator only when the client asked for reproducibility.
            generator = None
            if request.seed is not None:
                generator = torch.Generator(device="cuda").manual_seed(request.seed)
                print(f"🔮 使用随机种子: {request.seed}")
            output = app.state.pipe(
                image=resized_image,
                prompt=request.prompt,
                negative_prompt=request.negative_prompt,
                height=height,
                width=width,
                num_frames=request.num_frames,
                guidance_scale=request.guidance_scale,
                num_inference_steps=request.infer_steps,
                generator=generator
            ).frames[0]
            video_id = uuid.uuid4().hex
            output_path = f"generated_videos/{video_id}.mp4"
            export_to_video(output, output_path, fps=16)
            return output_path
        except torch.cuda.OutOfMemoryError:
            # Release cached blocks so the next task has a chance to fit.
            torch.cuda.empty_cache()
            raise
        except Exception as e:
            raise RuntimeError(f"视频生成失败: {str(e)}") from e
# ======================
# API端点
# ======================
@app.post("/video/submit",
          response_model=dict,
          status_code=status.HTTP_202_ACCEPTED,
          tags=["视频生成"])
async def submit_task(
    request: VideoSubmitRequest,
    auth: bool = Depends(verify_auth)
):
    """Queue a new generation task and return its id (HTTP 202).

    The task record starts as "InQueue" and its id is appended to
    ``pending_queue`` under ``task_lock``. The response is
    ``{"requestId": <32-char hex id>}``.

    Raises:
        HTTPException: 422 if ``image_url`` is blank.
    """
    # image_url is a required `str`, so the old `is None` test could never
    # fire; reject empty/whitespace-only values instead.
    if not request.image_url or not request.image_url.strip():
        raise HTTPException(
            status_code=422,
            detail={"status": "Failed", "reason": "需要图像URL参数"}
        )
    task_id = uuid.uuid4().hex
    with app.state.task_lock:
        app.state.tasks[task_id] = {
            "request": request,
            "status": "InQueue",
            "created_at": int(time.time())
        }
        app.state.pending_queue.append(task_id)
    return {"requestId": task_id}
@app.post(
    "/video/status",
    response_model=VideoStatusResponse,
    tags=["视频生成"],
)
async def get_status(
    request: VideoStatusRequest,
    auth: bool = Depends(verify_auth)
):
    """Report the state of one task.

    The returned dict matches ``VideoStatusResponse``:

    * ``queue_position`` is a 1-based index while the task is "InQueue",
      0 if it is no longer found in ``pending_queue``, and ``None`` otherwise.
    * ``results`` (URL, elapsed seconds, seed) is added only for "Succeed".

    Unknown ids yield HTTP 404.
    """
    task_id = request.requestId
    task = app.state.tasks.get(task_id)
    if not task:
        raise HTTPException(
            status_code=404,
            detail={"status": "Failed", "reason": "无效的任务ID"}
        )
    state = task['status']
    position = None
    if state == "InQueue":
        pending = app.state.pending_queue
        position = pending.index(task_id) + 1 if task_id in pending else 0
    response = {
        "status": state,
        "reason": task.get('reason'),
        "queue_position": position,
    }
    if state == "Succeed":
        elapsed = task['completed_at'] - task['started_at']
        response["results"] = {
            "videos": [{"url": task['download_url']}],
            "timings": {"inference": elapsed},
            "seed": task['request'].seed,
        }
    elif state == "Cancelled":
        response["reason"] = task.get('reason', "用户主动取消")
    return response
@app.post(
    "/video/cancel",
    response_model=dict,
    tags=["视频生成"],
)
async def cancel_task(
    request: VideoCancelRequest,
    auth: bool = Depends(verify_auth)
):
    """Cancel a task that is still waiting in the queue.

    Under ``task_lock``, the id is removed from ``pending_queue`` (if present)
    and the record is marked "Cancelled" with a reason and ``completed_at``.

    Raises:
        HTTPException: 404 for an unknown id; 400 if the task is not "InQueue".
    """
    task_id = request.requestId
    with app.state.task_lock:
        task = app.state.tasks.get(task_id)
        if not task:
            raise HTTPException(
                status_code=404,
                detail={"status": "Failed", "reason": "无效的任务ID"}
            )
        state = task['status']
        if state != "InQueue":
            raise HTTPException(
                status_code=400,
                detail={"status": "Failed", "reason": f"仅允许取消排队任务,当前状态: {state}"}
            )
        pending = app.state.pending_queue
        if task_id in pending:
            pending.remove(task_id)
        task.update({
            "status": "Cancelled",
            "reason": "用户主动取消",
            "completed_at": int(time.time())
        })
    return {"status": "Succeed"}
async def cleanup_files(paths: List[str], delay: int = 3600):
    """Best-effort deletion of ``paths`` after sleeping ``delay`` seconds.

    Missing files are skipped. Any error while checking or removing a path
    is printed, and the loop moves on to the next path.
    """
    await asyncio.sleep(delay)
    for target in paths:
        try:
            if not os.path.exists(target):
                continue
            os.remove(target)
        except Exception as exc:
            print(f"清理失败 {target}: {str(exc)}")
if __name__ == "__main__":
    # Script entry point. uvicorn is imported only here, so importing this
    # module from another ASGI server does not need it at import time.
    import uvicorn

    uvicorn.run(app, port=8088, host="0.0.0.0")