# Wan2.1/wan/t2v_pipeline.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
from contextlib import contextmanager
import torch
import torch.cuda.amp as amp
from tqdm import tqdm
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


class WanT2V:
    def __init__(
        self,
        model_hyperparam,
        checkpoint_dir,
        t5_cpu=False,
    ):
        r"""
        Initializes the Wan text-to-video generation model components (single-GPU only).

        Args:
            model_hyperparam (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to the directory containing model checkpoints
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to keep the T5 model on CPU during encoding.
        """
        self.device = torch.device("cuda:0")
        self.config = model_hyperparam
        self.t5_cpu = t5_cpu
        self.num_train_timesteps = model_hyperparam.num_train_timesteps
        self.param_dtype = model_hyperparam.param_dtype
        # Load the T5 text encoder on CPU; it is moved to GPU on demand in generate().
        self.text_encoder = T5EncoderModel(
            text_len=model_hyperparam.text_len,
            dtype=model_hyperparam.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, model_hyperparam.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, model_hyperparam.t5_tokenizer),
            shard_fn=None,  # no FSDP in the single-GPU pipeline
        )
        # Load the VAE used to decode latents back to pixel space.
        self.vae_stride = model_hyperparam.vae_stride
        self.patch_size = model_hyperparam.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, model_hyperparam.vae_checkpoint),
            device=self.device,
        )

        # Load the diffusion transformer in inference mode.
        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)
        self.model.to(self.device)

        self.sample_neg_prompt = model_hyperparam.sample_neg_prompt

    def generate(
        self,
        input_prompt,
        size=(1280, 720),
        frame_num=81,
        shift=5.0,
        sample_solver='unipc',
        sampling_steps=50,
        guide_scale=5.0,
        n_prompt="",
        seed=-1,
        offload_model=True,
    ):
        r"""
        Generates video frames from a text prompt using a diffusion process (single-GPU).

        Args:
            input_prompt (`str`): Text prompt for content generation
            size (tuple[`int`], *optional*, defaults to (1280, 720)): (width, height) of the output video
            frame_num (`int`, *optional*, defaults to 81): Number of frames (must be 4n+1)
            shift (`float`, *optional*, defaults to 5.0): Noise schedule shift
            sample_solver (`str`, *optional*, defaults to 'unipc'): Sampling solver, 'unipc' or 'dpm++'
            sampling_steps (`int`, *optional*, defaults to 50): Number of diffusion sampling steps
            guide_scale (`float`, *optional*, defaults to 5.0): Classifier-free guidance scale
            n_prompt (`str`, *optional*, defaults to ""): Negative prompt; falls back to the config default if empty
            seed (`int`, *optional*, defaults to -1): Random seed; a negative value draws one at random
            offload_model (`bool`, *optional*, defaults to True): Offload models to CPU after use to save VRAM

        Returns:
            torch.Tensor: Video tensor of shape (C, N, H, W), where C is channels, N frames, H height, W width
        """
        # Latent shape: the VAE compresses time by vae_stride[0] (keeping the
        # first frame) and each spatial dimension by vae_stride[1:].
        F = frame_num
        target_shape = (
            self.vae.model.z_dim,
            (F - 1) // self.vae_stride[0] + 1,
            size[1] // self.vae_stride[1],
            size[0] // self.vae_stride[2],
        )
        # Transformer sequence length: spatial patches per latent frame times
        # the number of latent frames.
        seq_len = math.ceil(
            (target_shape[2] * target_shape[3]) /
            (self.patch_size[1] * self.patch_size[2]) *
            target_shape[1]
        )
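        # Worked example (assuming the released config values
        # vae_stride=(4, 8, 8) and patch_size=(1, 2, 2)): at 1280x720 with 81
        # frames the latent grid is 21 x 90 x 160; each latent frame yields
        # (90 * 160) / (2 * 2) = 3600 tokens, so seq_len = 3600 * 21 = 75600.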
        if n_prompt == "":
            n_prompt = self.sample_neg_prompt
        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)
        # Encode positive and negative prompts with T5. When t5_cpu is set,
        # the encoder stays on CPU and only the embeddings move to GPU.
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]
        # Initial latent noise; the model API expects a list of latent tensors.
        noise = [
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=torch.float32,
                device=self.device,
                generator=seed_g,
            )
        ]

        @contextmanager
        def noop():
            # Stand-in for the distributed no_sync() context used by the
            # multi-GPU pipeline; a no-op on a single GPU.
            yield
        # Sampling runs under autocast in the configured param dtype with
        # gradients disabled.
        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), noop():
            # Build the sampling schedule. shift=1 at construction; the actual
            # schedule shift is applied when timesteps/sigmas are computed.
            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False,
                )
                sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False,
                )
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas,
                )
            else:
                raise NotImplementedError(f"Unsupported solver: {sample_solver}")
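            # As a sketch of the shift semantics (assuming the usual
            # flow-matching warp used by schedulers of this family), a shift s
            # remaps each sigma as s * sigma / (1 + (s - 1) * sigma), so s > 1
            # spends more of the step budget at high noise levels.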
            latents = noise
            arg_c = {'context': context, 'seq_len': seq_len}
            arg_null = {'context': context_null, 'seq_len': seq_len}

            for t in tqdm(timesteps):
                latent_model_input = latents
                timestep = torch.stack([t]).to(self.device)
                # No-op after the first iteration unless the model was offloaded.
                self.model.to(self.device)
                # Classifier-free guidance: two forward passes, combined as
                # uncond + scale * (cond - uncond).
                noise_pred_cond = self.model(latent_model_input, t=timestep, **arg_c)[0]
                noise_pred_uncond = self.model(latent_model_input, t=timestep, **arg_null)[0]
                noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
                # One solver step: advance the latent toward the denoised sample.
                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latents[0].unsqueeze(0),
                    return_dict=False,
                    generator=seed_g,
                )[0]
                latents = [temp_x0.squeeze(0)]
            x0 = latents
            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()
            # Decode the final latents back to pixel space.
            videos = self.vae.decode(x0)

        del noise, latents, sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        return videos[0]
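

if __name__ == "__main__":
    # Minimal usage sketch. WAN_CONFIGS and the checkpoint path below are
    # illustrative assumptions, not guaranteed names; substitute the config
    # object and local paths from your checkout. Run as a module
    # (python -m wan.t2v_pipeline) so the package-relative imports resolve.
    from .configs import WAN_CONFIGS  # assumed config registry; adjust to your tree

    cfg = WAN_CONFIGS['t2v-1.3B']
    pipe = WanT2V(model_hyperparam=cfg, checkpoint_dir='./Wan2.1-T2V-1.3B')
    video = pipe.generate(
        "A corgi running on a beach at sunset",
        size=(832, 480),
        frame_num=81,
        sampling_steps=50,
        guide_scale=5.0,
        seed=42,
    )
    # video: float tensor of shape (C, N, H, W); persist it for inspection.
    torch.save(video.cpu(), 't2v_sample.pt')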