Wan2.1/wan/text2video.py
Stan Campbell fdbc5f0588 feat: add --vae_cpu flag for improved VRAM optimization
Add --vae_cpu argument to enable VAE offloading for consumer GPUs with
limited VRAM. When enabled, VAE initializes on CPU and moves to GPU only
when needed for encoding/decoding operations.

Key changes:
- Add --vae_cpu argument to generate.py (mirrors --t5_cpu pattern)
- Update all 4 pipelines (T2V, I2V, FLF2V, VACE) with conditional VAE offloading
- Fix DiT offloading to free VRAM before T5 loading when offload_model=True
- Handle VAE scale tensors (mean/std) during device transfers

Benefits:
- Saves ~100-200MB VRAM without performance degradation
- Enables T2V-1.3B on more consumer GPUs (tested on an 11.49 GiB GPU)
- Backward compatible (default=False)
- Consistent with existing --t5_cpu flag

Test results on 11.49 GiB VRAM GPU:
- Baseline: OOM (an 80 MiB allocation failed with only 85 MiB reported free)
- With --vae_cpu: Success
- With --t5_cpu: Success
- With both flags: Success (maximum VRAM savings)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-17 03:14:28 -07:00

298 lines
11 KiB
Python

# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
FlowDPMSolverMultistepScheduler,
get_sampling_sigmas,
retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
class WanT2V:
    """Text-to-video pipeline: T5 text encoder, WanModel (DiT) denoiser and
    WanVAE decoder.

    ``t5_cpu`` / ``vae_cpu`` keep the text encoder / VAE on CPU to lower GPU
    memory use; ``offload_model`` in :meth:`generate` additionally moves models
    back to CPU between generation stages.
    """

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        vae_cpu=False,
    ):
        r"""
        Initializes the Wan text-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_usp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of USP.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            vae_cpu (`bool`, *optional*, defaults to False):
                Whether to place the VAE on CPU to save VRAM. The VAE is moved
                to the GPU only for decoding in `generate`, and moved back to
                CPU afterwards when `generate` is called with
                `offload_model=True`.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.t5_cpu = t5_cpu
        self.vae_cpu = vae_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        # T5 is always created on CPU; `generate` moves it to the GPU on demand
        # unless t5_cpu is set.
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None)

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        # With vae_cpu the VAE stays off the GPU during initialization and the
        # diffusion loop; it is only needed at the end to decode latents.
        vae_device = torch.device('cpu') if vae_cpu else self.device
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=vae_device)

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (
                usp_attn_forward,
                usp_dit_forward,
            )
            # Swap in sequence-parallel forward implementations.
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
            self.model.forward = types.MethodType(usp_dit_forward, self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1

        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt

    def _move_vae_to(self, device):
        """Move the VAE weights and its `mean`/`std` tensors to `device`.

        `self.vae.scale` is rebuilt from the moved `mean`/`std` so the scale
        tensors live on the same device as the VAE model.
        """
        self.vae.model.to(device)
        self.vae.mean = self.vae.mean.to(device)
        self.vae.std = self.vae.std.to(device)
        self.vae.scale = [self.vae.mean, 1.0 / self.vae.std]

    def _decode_latents(self, latents, offload_model):
        """Decode `latents` with the VAE and return `self.vae.decode`'s result.

        With `vae_cpu`, the VAE is moved to `self.device` for decoding. If
        `offload_model` is also set, it is moved back to CPU afterwards. This
        restore also runs when decoding raises, so VRAM is not left occupied.
        """
        if not self.vae_cpu:
            return self.vae.decode(latents)
        self._move_vae_to(self.device)
        try:
            return self.vae.decode(latents)
        finally:
            if offload_model:
                self._move_vae_to(torch.device('cpu'))
                torch.cuda.empty_cache()

    def generate(self,
                 input_prompt,
                 size=(1280, 720),
                 frame_num=81,
                 shift=5.0,
                 sample_solver='unipc',
                 sampling_steps=50,
                 guide_scale=5.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):
        r"""
        Generates video frames from text prompt using diffusion process.

        Args:
            input_prompt (`str`):
                Text prompt for content generation
            size (`tuple[int]`, *optional*, defaults to (1280,720)):
                Controls video resolution, (width,height).
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter. Affects temporal dynamics
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video: 'unipc' or 'dpm++'.
            sampling_steps (`int`, *optional*, defaults to 50):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            guide_scale (`float`, *optional*, defaults 5.0):
                Classifier-free guidance scale. Controls prompt adherence vs. creativity
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use random seed.
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM

        Returns:
            torch.Tensor or None:
                On rank 0, the generated video frames tensor with dimensions
                (C, N, H, W) where:
                    - C: Color channels (3 for RGB)
                    - N: Number of frames (81 with the default `frame_num`)
                    - H: Frame height (from size)
                    - W: Frame width (from size)
                On every other rank, None.

        Raises:
            NotImplementedError: If `sample_solver` is neither 'unipc' nor 'dpm++'.
        """
        # preprocess
        F = frame_num
        target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
                        size[1] // self.vae_stride[1],
                        size[0] // self.vae_stride[2])

        # Token count after patchifying, rounded up to a multiple of sp_size so
        # the sequence splits evenly across sequence-parallel ranks.
        seq_len = math.ceil((target_shape[2] * target_shape[3]) /
                            (self.patch_size[1] * self.patch_size[2]) *
                            target_shape[1] / self.sp_size) * self.sp_size

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt
        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)

        if not self.t5_cpu:
            # Offload the DiT first so there is room for T5 on the GPU.
            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
                torch.cuda.empty_cache()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        noise = [
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=torch.float32,
                device=self.device,
                generator=seed_g)
        ]

        @contextmanager
        def noop_no_sync():
            yield

        # FSDP-wrapped models expose `no_sync`; fall back to a no-op otherwise.
        no_sync = getattr(self.model, 'no_sync', noop_no_sync)

        # evaluation mode
        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sample_scheduler.set_timesteps(
                    sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas)
            else:
                raise NotImplementedError("Unsupported solver.")

            # sample videos
            latents = noise

            arg_c = {'context': context, 'seq_len': seq_len}
            arg_null = {'context': context_null, 'seq_len': seq_len}

            # The DiT may have been offloaded for T5 above. Nothing moves it
            # during sampling, so a single transfer before the loop suffices.
            self.model.to(self.device)
            for t in tqdm(timesteps):
                latent_model_input = latents
                timestep = torch.stack([t])

                noise_pred_cond = self.model(
                    latent_model_input, t=timestep, **arg_c)[0]
                noise_pred_uncond = self.model(
                    latent_model_input, t=timestep, **arg_null)[0]

                # classifier-free guidance
                noise_pred = noise_pred_uncond + guide_scale * (
                    noise_pred_cond - noise_pred_uncond)

                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latents[0].unsqueeze(0),
                    return_dict=False,
                    generator=seed_g)[0]
                latents = [temp_x0.squeeze(0)]

            x0 = latents
            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()
            if self.rank == 0:
                videos = self._decode_latents(x0, offload_model)

        del noise, latents
        del sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        return videos[0] if self.rank == 0 else None