# Mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-11-04 06:15:17 +00:00)
# Python source file, 450 lines, 14 KiB
import os
 | 
						||
import torch
 | 
						||
import uuid
 | 
						||
import time
 | 
						||
import asyncio
 | 
						||
from enum import Enum
 | 
						||
from threading import Lock
 | 
						||
from typing import Optional, Dict, List
 | 
						||
from contextlib import asynccontextmanager
 | 
						||
from fastapi import FastAPI, HTTPException, status, Depends
 | 
						||
from fastapi.staticfiles import StaticFiles
 | 
						||
from pydantic import BaseModel, Field, field_validator, ValidationError
 | 
						||
from diffusers.utils import export_to_video
 | 
						||
from diffusers import AutoencoderKLWan, WanPipeline
 | 
						||
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 | 
						||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 | 
						||
from fastapi.responses import JSONResponse
 | 
						||
 | 
						||
# Make sure the directory that stores generated videos exists at import time.
os.makedirs(name="generated_videos", exist_ok=True)
 | 
						||
 | 
						||
# Application lifespan manager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise shared state at startup and release it at shutdown.

    Startup stores the following on ``app.state``: the accepted API keys, the
    Wan text-to-video pipeline (moved to CUDA), the task table, the pending
    queue, the two locks, the base URL used to build download links, and the
    concurrency semaphore.  It then starts ``task_processor`` as a background
    asyncio task.

    Shutdown cancels that background task, waits for it to finish, and then
    drops the pipeline and empties the CUDA cache.
    """
    # Handle of the background queue worker.  Keeping it gives the task a
    # strong reference while it runs and lets shutdown cancel it.
    processor = None
    try:
        # API keys accepted by verify_auth.
        # NOTE(review): the single entry looks like a placeholder - confirm.
        app.state.valid_api_keys = {
            "密钥"
        }

        # Video generation model, loaded from a local directory.
        model_id = "./Wan2.1-T2V-1.3B-Diffusers"
        # The VAE is loaded in float32 while the pipeline itself uses bfloat16.
        vae = AutoencoderKLWan.from_pretrained(
            model_id,
            subfolder="vae",
            torch_dtype=torch.float32
        )

        scheduler = UniPCMultistepScheduler(
            prediction_type='flow_prediction',
            use_flow_sigmas=True,
            num_train_timesteps=1000,
            flow_shift=3.0
        )

        app.state.pipe = WanPipeline.from_pretrained(
            model_id,
            vae=vae,
            torch_dtype=torch.bfloat16
        ).to("cuda")
        # Replace the pipeline's scheduler with the one configured above.
        app.state.pipe.scheduler = scheduler

        # Task bookkeeping: task records by ID plus a FIFO list of queued IDs.
        app.state.tasks: Dict[str, dict] = {}
        app.state.pending_queue: List[str] = []
        app.state.model_lock = Lock()
        app.state.task_lock = Lock()
        # NOTE(review): placeholder text ("IP address + port") - must be set
        # to a real base URL for download links to work; confirm.
        app.state.base_url = "ip地址+端口"
        app.state.max_concurrent = 2
        app.state.semaphore = asyncio.Semaphore(app.state.max_concurrent)

        # Start the background queue worker.
        processor = asyncio.create_task(task_processor())

        print("✅ 应用初始化完成")
        yield

    finally:
        # Stop the queue worker so it is not left pending when the loop closes.
        if processor is not None:
            processor.cancel()
            try:
                await processor
            except asyncio.CancelledError:
                # Expected outcome of the cancel() above.
                pass
            except Exception as exc:
                # The worker had already died with an error; report it but
                # still continue with resource cleanup.
                print(f"后台任务处理器异常退出: {exc}")

        # Release model resources.
        if hasattr(app.state, 'pipe'):
            del app.state.pipe
            torch.cuda.empty_cache()
            print("♻️ 已释放模型资源")
 | 
						||
 | 
						||
# Authentication: bearer-token extractor used by verify_auth.  verify_auth
# itself handles the case where no credentials were supplied.
security = HTTPBearer(auto_error=False)

# FastAPI application; files in generated_videos/ are served under /videos.
app = FastAPI(lifespan=lifespan)
app.mount(
    "/videos",
    StaticFiles(directory="generated_videos"),
    name="videos",
)
 | 
						||
 | 
						||
# ======================
# Data models -- request and response schemas
# ======================
 | 
						||
class VideoSubmitRequest(BaseModel):
    # Request body for POST /video/submit.  Field descriptions are published
    # in the API schema and are therefore kept verbatim.

    # Model version to use.
    model: str = Field(default="Wan2.1-T2V-1.3B",description="使用的模型版本")
    # Text prompt describing the video, 10-500 characters.
    prompt: str = Field(
        ...,
        min_length=10,
        max_length=500,
        description="视频描述提示词,10-500个字符"
    )
    # Resolution string; restricted to two values by validate_image_size.
    image_size: str = Field(
        ...,
        description="视频分辨率,仅支持480x832或832x480"
    )
    # Optional text describing content to exclude.
    negative_prompt: Optional[str] = Field(
        default=None,
        max_length=500,
        description="排除不需要的内容"
    )
    # Optional random seed in 0..2147483647.
    seed: Optional[int] = Field(
        default=None,
        ge=0,
        le=2147483647,
        description="随机数种子,范围0-2147483647"
    )
    # Number of frames, 24-120.
    num_frames: int = Field(
        default=81,
        ge=24,
        le=120,
        description="视频帧数,24-120帧"
    )
    # Guidance scale, 1.0-20.0.
    guidance_scale: float = Field(
        default=5.0,
        ge=1.0,
        le=20.0,
        description="引导系数,1.0-20.0"
    )
    # Number of inference steps, 20-100.
    infer_steps: int = Field(
        default=50,
        ge=20,
        le=100,
        description="推理步数,20-100步"
    )

    @field_validator('image_size', mode='before')
    @classmethod
    def validate_image_size(cls, value):
        """Reject any image_size other than the two supported resolutions.

        Runs before type coercion, so ``value`` may be any JSON value.  A
        tuple is used instead of a set for two reasons: membership is checked
        with ``==`` so an unhashable value (list, dict) is rejected with
        ValueError rather than crashing with TypeError, and the order of the
        sizes in the error message is stable.

        Raises:
            ValueError: if ``value`` is not one of the supported sizes.
        """
        allowed_sizes = ("480x832", "832x480")
        if value not in allowed_sizes:
            raise ValueError(f"仅支持以下分辨率: {', '.join(allowed_sizes)}")
        return value
 | 
						||
 | 
						||
class VideoStatusRequest(BaseModel):
    # Request body for POST /video/status.
    # requestId must be exactly 32 characters long.
    requestId: str = Field(..., min_length=32, max_length=32, description="32位任务ID")
 | 
						||
 | 
						||
class VideoSubmitResponse(BaseModel):
    # Response body of POST /video/submit: the ID assigned to the new task.
    requestId: str
 | 
						||
 | 
						||
class VideoStatusResponse(BaseModel):
    # Response body of POST /video/status.

    # Task state; the description lists the possible values.
    status: str = Field(
        ...,
        description="任务状态: Succeed, InQueue, InProgress, Failed,Cancelled",
    )
    # Failure (or cancellation) reason, when there is one.
    reason: Optional[str] = Field(default=None, description="失败原因")
    # Generation results, filled in for succeeded tasks.
    results: Optional[dict] = Field(default=None, description="生成结果")
    # Position in the pending queue, only for queued tasks.
    queue_position: Optional[int] = Field(default=None, description="队列位置")
 | 
						||
 | 
						||
class VideoCancelRequest(BaseModel):
    # Request body for POST /video/cancel.
    # requestId must be exactly 32 characters long.
    requestId: str = Field(..., min_length=32, max_length=32, description="32位任务ID")
 | 
						||
 | 
						||
# # Custom HTTP exception handler (currently disabled)
# @app.exception_handler(HTTPException)
# async def http_exception_handler(request, exc):
#     return JSONResponse(
#         status_code=exc.status_code,
#         content=exc.detail,  # return the detail content directly (not wrapped in a "detail" field)
#         headers=exc.headers
#     )
 | 
						||
 | 
						||
# ======================
# Background task processing
# ======================
 | 
						||
async def task_processor():
    """Run forever, taking queued task IDs and processing them one by one.

    Every iteration holds ``app.state.semaphore`` while it either processes
    the next queued task or, when nothing is queued, sleeps for half a second
    before polling again.
    """
    idle_delay = 0.5
    while True:
        async with app.state.semaphore:
            next_id = await get_next_task()
            if not next_id:
                # Queue is empty: back off briefly instead of spinning.
                await asyncio.sleep(idle_delay)
                continue
            await process_task(next_id)
 | 
						||
 | 
						||
async def get_next_task():
    """Remove and return the oldest queued task ID, or None if none is queued."""
    with app.state.task_lock:
        queue = app.state.pending_queue
        return queue.pop(0) if queue else None
 | 
						||
 | 
						||
async def process_task(task_id: str):
    """Run one queued task to completion and record the outcome on its record.

    The task record is looked up in ``app.state.tasks``; an unknown ID is
    ignored.  On success the record gets status ``Succeed``, a download URL
    and a completion timestamp, and deletion of the video file is scheduled
    through ``auto_cleanup``.  Any exception is passed to
    ``handle_task_error``.

    Args:
        task_id: ID of the task to process.
    """
    task = app.state.tasks.get(task_id)
    if not task:
        return

    try:
        # Mark the task as running.
        task['status'] = 'InProgress'
        task['started_at'] = int(time.time())

        # Run the video generation.
        video_path = await generate_video(task['request'], task_id)

        # Build the download link from the file name only.
        download_url = f"{app.state.base_url}/videos/{os.path.basename(video_path)}"

        # Record the successful result.
        task.update({
            'status': 'Succeed',
            'download_url': download_url,
            'completed_at': int(time.time())
        })

        # Schedule automatic cleanup.  The asyncio documentation advises
        # keeping a reference to tasks created with create_task, so each
        # cleanup task is held in a set until it finishes.
        cleanup_tasks = getattr(app.state, 'cleanup_tasks', None)
        if cleanup_tasks is None:
            cleanup_tasks = set()
            app.state.cleanup_tasks = cleanup_tasks
        cleanup = asyncio.create_task(auto_cleanup(video_path))
        cleanup_tasks.add(cleanup)
        cleanup.add_done_callback(cleanup_tasks.discard)
    except Exception as e:
        handle_task_error(task, e)
 | 
						||
 | 
						||
def handle_task_error(task: dict, error: Exception):
    """Mark a task as failed and store a readable failure reason on it.

    The error is classified by looking at it and at every exception it was
    raised from (its ``__cause__`` chain).  This matters because
    ``sync_generate_video`` re-raises every failure as
    ``RuntimeError(...) from e``, so the original exception type is only
    visible through ``__cause__``.

    Args:
        task: Task record to update in place with ``status``, ``reason`` and
            ``completed_at``.
        error: The exception that ended the task.
    """
    # Collect the error and its explicit causes; the id() set guards against
    # a cyclic chain.
    chain = []
    seen_ids = set()
    current = error
    while current is not None and id(current) not in seen_ids:
        seen_ids.add(id(current))
        chain.append(current)
        current = current.__cause__

    error_msg = str(error)
    if any(isinstance(exc, torch.cuda.OutOfMemoryError) for exc in chain):
        error_msg = "显存不足,请降低分辨率或减少帧数"
    elif any(isinstance(exc, ValidationError) for exc in chain):
        error_msg = "参数校验失败: " + str(error)

    task.update({
        'status': 'Failed',
        'reason': error_msg,
        'completed_at': int(time.time())
    })
 | 
						||
 | 
						||
# ======================
# Core video generation logic
# ======================
 | 
						||
async def generate_video(request: dict, task_id: str) -> str:
    """Run ``sync_generate_video`` in the default executor and await its result.

    Args:
        request: Generation parameters, forwarded unchanged.
        task_id: ID of the task, forwarded unchanged.

    Returns:
        Whatever ``sync_generate_video`` returns (the output file path).
    """
    # get_running_loop() is the preferred call inside a coroutine; it returns
    # the loop that is executing this coroutine.
    loop = asyncio.get_running_loop()
    # Passing None selects the loop's default executor.
    return await loop.run_in_executor(
        None,
        sync_generate_video,
        request,
        task_id
    )
 | 
						||
 | 
						||
def sync_generate_video(request: dict, task_id: str) -> str:
    """Generate one video synchronously and return the path of the MP4 file.

    The whole call runs while holding ``app.state.model_lock``.  Any failure
    is re-raised as ``RuntimeError`` chained to the original exception.

    Args:
        request: Dict with the keys ``prompt``, ``negative_prompt``,
            ``height``, ``width``, ``num_frames``, ``guidance_scale``,
            ``infer_steps`` and optionally ``seed``.
        task_id: Accepted for the caller's convenience; not used here.

    Returns:
        Relative path of the exported file under ``generated_videos/``.

    Raises:
        RuntimeError: if inference or export fails.
    """
    with app.state.model_lock:
        try:
            # Seeded CUDA generator only when the caller supplied a seed.
            rng = None
            if request.get('seed') is not None:
                rng = torch.Generator(device="cuda")
                rng.manual_seed(request['seed'])
                print(f"🔮 使用随机种子: {request['seed']}")

            # Run model inference.
            pipe_args = {
                'prompt': request['prompt'],
                'negative_prompt': request['negative_prompt'],
                'height': request['height'],
                'width': request['width'],
                'num_frames': request['num_frames'],
                'guidance_scale': request['guidance_scale'],
                'num_inference_steps': request['infer_steps'],
                'generator': rng,
            }
            result = app.state.pipe(**pipe_args)

            # Export the first result's frames under a fresh random file name.
            video_id = uuid.uuid4().hex
            output_path = f"generated_videos/{video_id}.mp4"
            export_to_video(result.frames[0], output_path, fps=16)
            return output_path
        except Exception as e:
            raise RuntimeError(f"视频生成失败: {str(e)}") from e
 | 
						||
 | 
						||
# ======================
# API endpoints
# ======================
 | 
						||
def _unauthorized(reason: str) -> HTTPException:
    """Build the 401 error shared by every authentication failure."""
    return HTTPException(
        status_code=401,
        detail={"status": "Failed", "reason": reason},
        headers={"WWW-Authenticate": "Bearer"}
    )


async def verify_auth(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Check the bearer token of a request against the configured API keys.

    Args:
        credentials: Parsed ``Authorization`` header from the ``security``
            dependency, or a falsy value when none was supplied.

    Returns:
        True when the token is one of ``app.state.valid_api_keys``.

    Raises:
        HTTPException: 401 when the header is missing, the scheme is not
            Bearer, or the key is not accepted.
    """
    if not credentials:
        raise _unauthorized("缺少认证头")
    # HTTP authentication scheme names are case-insensitive, so "bearer" and
    # "BEARER" must be accepted as well as "Bearer".
    if credentials.scheme.lower() != "bearer":
        raise _unauthorized("无效的认证方案")
    if credentials.credentials not in app.state.valid_api_keys:
        raise _unauthorized("无效的API密钥")
    return True
 | 
						||
 | 
						||
# description= carries the text of the original docstring so the published
# operation description stays the same now that the docstring is in English.
@app.post(
    "/video/submit",
    response_model=VideoSubmitResponse,
    status_code=status.HTTP_202_ACCEPTED,
    tags=["视频生成"],
    summary="提交视频生成请求",
    description="提交新的视频生成任务",
)
async def submit_video_task(
    request: VideoSubmitRequest,
    auth: bool = Depends(verify_auth)
):
    """Queue a new video generation task and return its ID.

    Args:
        request: Validated submission parameters.
        auth: Result of ``verify_auth``; present only to enforce authentication.

    Returns:
        ``{"requestId": <32-character hex ID>}``.

    Raises:
        HTTPException: 422 if a ``ValidationError`` is raised while building
            the task.
    """
    try:
        # image_size has the form "<first>x<second>"; the first number is
        # used as the width and the second as the height.
        width, height = (int(part) for part in request.image_size.split('x'))

        # Parameters later handed to the generation pipeline.
        generation_args = {
            'prompt': request.prompt,
            'negative_prompt': request.negative_prompt,
            'width': width,
            'height': height,
            'num_frames': request.num_frames,
            'guidance_scale': request.guidance_scale,
            'infer_steps': request.infer_steps,
            'seed': request.seed,
        }

        # Task record, starting in the queued state.
        task_id = uuid.uuid4().hex
        record = {
            'request': generation_args,
            'status': 'InQueue',
            'created_at': int(time.time()),
        }

        # Register the record and enqueue its ID under the task lock.
        with app.state.task_lock:
            app.state.tasks[task_id] = record
            app.state.pending_queue.append(task_id)

        return {"requestId": task_id}

    except ValidationError as e:
        raise HTTPException(
            status_code=422,
            detail={"status": "Failed", "reason": str(e)}
        )
 | 
						||
 | 
						||
# description= carries the text of the original docstring so the published
# operation description stays the same now that the docstring is in English.
@app.post(
    "/video/status",
    response_model=VideoStatusResponse,
    tags=["视频生成"],
    summary="查询任务状态",
    description="查询任务状态",
)
async def get_video_status(
    request: VideoStatusRequest,
    auth: bool = Depends(verify_auth)
):
    """Report the current state of a task.

    Args:
        request: Body carrying the ``requestId`` to look up.
        auth: Result of ``verify_auth``; present only to enforce authentication.

    Returns:
        Dict with ``status``, ``reason`` and ``queue_position``; succeeded
        tasks additionally carry ``results`` (video URL, inference time in
        seconds, seed).

    Raises:
        HTTPException: 404 when the task ID is unknown.
    """
    task_key = request.requestId
    task = app.state.tasks.get(task_key)
    if not task:
        raise HTTPException(
            status_code=404,
            detail={"status": "Failed", "reason": "无效的任务ID"}
        )

    state = task['status']

    # Queue position is 1-based and reported only for queued tasks; a queued
    # task that is not in the pending list reports 0, anything else None.
    if state == "InQueue":
        pending = app.state.pending_queue
        position = pending.index(task_key) + 1 if task_key in pending else 0
    else:
        position = None

    response = {
        "status": state,
        "reason": task.get('reason'),
        "queue_position": position,
    }

    if state == "Succeed":
        # Attach the generated video and timing information.
        elapsed = task['completed_at'] - task['started_at']
        response["results"] = {
            "videos": [{"url": task['download_url']}],
            "timings": {"inference": elapsed},
            "seed": task['request']['seed'],
        }
    elif state == "Cancelled":
        # Make sure a cancelled task always reports a reason.
        response["reason"] = task.get('reason', "用户主动取消")

    return response
 | 
						||
 | 
						||
 | 
						||
# description= carries the text of the original docstring so the published
# operation description stays the same now that the docstring is in English.
@app.post(
    "/video/cancel",
    response_model=dict,
    tags=["视频生成"],
    description="取消排队中的生成任务",
)
async def cancel_task(
    request: VideoCancelRequest,
    auth: bool = Depends(verify_auth)
):
    """Cancel a generation task that is still waiting in the queue.

    Args:
        request: Body carrying the ``requestId`` to cancel.
        auth: Result of ``verify_auth``; present only to enforce authentication.

    Returns:
        ``{"status": "Succeed"}`` once the task has been marked cancelled.

    Raises:
        HTTPException: 404 when the task ID is unknown, 400 when the task is
            not in the ``InQueue`` state.
    """
    task_id = request.requestId

    with app.state.task_lock:
        task = app.state.tasks.get(task_id)

        # Unknown task ID.
        if not task:
            raise HTTPException(
                status_code=404,
                detail={"status": "Failed", "reason": "无效的任务ID"}
            )

        # Only queued tasks may be cancelled.
        current_status = task['status']
        if current_status != "InQueue":
            raise HTTPException(
                status_code=400,
                detail={"status": "Failed", "reason": f"仅允许取消排队任务,当前状态: {current_status}"}
            )

        # Drop the ID from the queue if it is still there.
        pending = app.state.pending_queue
        if task_id in pending:
            pending.remove(task_id)

        # Record the cancellation on the task.
        task.update(
            status="Cancelled",
            reason="用户主动取消",
            completed_at=int(time.time()),
        )

    return {"status": "Succeed"}
 | 
						||
 | 
						||
 | 
						||
# ======================
# Utility functions
# ======================
 | 
						||
async def auto_cleanup(file_path: str, delay: int = 3600):
    """Delete a generated video file after a delay.

    Best effort: a file that is already gone is ignored silently, and any
    other failure is printed rather than raised.

    Args:
        file_path: Path of the file to delete.
        delay: Seconds to wait before deleting (default 3600).
    """
    await asyncio.sleep(delay)
    try:
        # Attempt the removal directly instead of checking for existence
        # first, so a file deleted in between cannot cause a spurious error.
        os.remove(file_path)
    except FileNotFoundError:
        # Nothing to clean up.
        return
    except Exception as e:
        print(f"文件清理失败: {str(e)}")
    else:
        print(f"已清理文件: {file_path}")
 | 
						||
 | 
						||
if __name__ == "__main__":
    # Started as a script: serve the app with uvicorn on port 8088, bound to
    # all network interfaces.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8088,
    )