add teacache implementation for Wan2.1

kousakawang 2025-05-23 19:15:52 +08:00
parent e5a741309d
commit a11333634c


@@ -4,21 +4,32 @@ import logging
import os
import sys
import warnings
from tqdm import tqdm
from datetime import datetime
warnings.filterwarnings('ignore')
import random
import torch
import torch.distributed as dist
from PIL import Image
import torchvision.transforms.functional as TF
import torch.cuda.amp as amp
import numpy as np
import math
import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image, cache_video, str2bool
import gc
from contextlib import contextmanager
from wan.modules.model import sinusoidal_embedding_1d
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
EXAMPLE_PROMPT = {
"t2v-1.3B": {
@@ -60,6 +71,537 @@ EXAMPLE_PROMPT = {
}
}
def t2v_generate(self,
input_prompt,
size=(1280, 720),
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation
size (tuple[`int`], *optional*, defaults to (1280,720)):
Controls video resolution, (width,height).
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 50):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed.
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from size)
- W: Frame width (from size)
"""
# preprocess
F = frame_num
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
size[1] // self.vae_stride[1],
size[0] // self.vae_stride[2])
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
(self.patch_size[1] * self.patch_size[2]) *
target_shape[1] / self.sp_size) * self.sp_size
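# seq_len: number of patch tokens per sample (latent H*W divided by the patch area,
# times the number of latent frames), rounded up to a multiple of the
# sequence-parallel size so the sequence splits evenly across ranks.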
if n_prompt == "":
n_prompt = self.sample_neg_prompt
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
noise = [
torch.randn(
target_shape[0],
target_shape[1],
target_shape[2],
target_shape[3],
dtype=torch.float32,
device=self.device,
generator=seed_g)
]
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
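# no_sync disables gradient synchronization on distributed model wrappers; fall back
# to a no-op context manager when the model is not wrapped.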
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latents = noise
arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = latents
timestep = [t]
timestep = torch.stack(timestep)
self.model.to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0]
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latents[0].unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latents = [temp_x0.squeeze(0)]
x0 = latents
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latents
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
def i2v_generate(self,
input_prompt,
img,
max_area=720 * 1280,
frame_num=81,
shift=5.0,
sample_solver='unipc',
sampling_steps=40,
guide_scale=5.0,
n_prompt="",
seed=-1,
offload_model=True):
r"""
Generates video frames from input image and text prompt using diffusion process.
Args:
input_prompt (`str`):
Text prompt for content generation.
img (PIL.Image.Image):
Input image. Converted internally to a normalized tensor of shape [3, H, W]
max_area (`int`, *optional*, defaults to 720*1280):
Maximum pixel area for latent space calculation. Controls video resolution scaling
frame_num (`int`, *optional*, defaults to 81):
How many frames to sample from a video. The number should be 4n+1
shift (`float`, *optional*, defaults to 5.0):
Noise schedule shift parameter. Affects temporal dynamics
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
sample_solver (`str`, *optional*, defaults to 'unipc'):
Solver used to sample the video.
sampling_steps (`int`, *optional*, defaults to 40):
Number of diffusion sampling steps. Higher values improve quality but slow generation
guide_scale (`float`, *optional*, defaults 5.0):
Classifier-free guidance scale. Controls prompt adherence vs. creativity
n_prompt (`str`, *optional*, defaults to ""):
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
seed (`int`, *optional*, defaults to -1):
Random seed for noise generation. If -1, use random seed
offload_model (`bool`, *optional*, defaults to True):
If True, offloads models to CPU during generation to save VRAM
Returns:
torch.Tensor:
Generated video frames tensor. Dimensions: (C, N, H, W) where:
- C: Color channels (3 for RGB)
- N: Number of frames (81)
- H: Frame height (from max_area)
- W: Frame width (from max_area)
"""
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
F = frame_num
h, w = img.shape[1:]
aspect_ratio = h / w
lat_h = round(
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
self.patch_size[1] * self.patch_size[1])
lat_w = round(
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
self.patch_size[2] * self.patch_size[2])
h = lat_h * self.vae_stride[1]
w = lat_w * self.vae_stride[2]
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
self.patch_size[1] * self.patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
noise = torch.randn(
self.vae.model.z_dim,
(F - 1) // self.vae_stride[0] + 1,
lat_h,
lat_w,
dtype=torch.float32,
generator=seed_g,
device=self.device)
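# Build the conditioning mask in pixel-frame space: frame 0 (the input image) is kept
# (mask = 1) and all later frames are generated (mask = 0). The first frame's mask is
# repeated 4x so the F-frame mask regroups cleanly into 4-frame chunks matching the
# VAE's temporal compression, giving one mask entry per latent frame.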
msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
],
dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
if n_prompt == "":
n_prompt = self.sample_neg_prompt
# preprocess
if not self.t5_cpu:
self.text_encoder.model.to(self.device)
context = self.text_encoder([input_prompt], self.device)
context_null = self.text_encoder([n_prompt], self.device)
if offload_model:
self.text_encoder.model.cpu()
else:
context = self.text_encoder([input_prompt], torch.device('cpu'))
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
context = [t.to(self.device) for t in context]
context_null = [t.to(self.device) for t in context_null]
self.clip.model.to(self.device)
clip_context = self.clip.visual([img[:, None, :, :]])
if offload_model:
self.clip.model.cpu()
y = self.vae.encode([
torch.concat([
torch.nn.functional.interpolate(
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
0, 1),
torch.zeros(3, F-1, h, w)
],
dim=1).to(self.device)
])[0]
y = torch.concat([msk, y])
@contextmanager
def noop_no_sync():
yield
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
# evaluation mode
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
if sample_solver == 'unipc':
sample_scheduler = FlowUniPCMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sample_scheduler.set_timesteps(
sampling_steps, device=self.device, shift=shift)
timesteps = sample_scheduler.timesteps
elif sample_solver == 'dpm++':
sample_scheduler = FlowDPMSolverMultistepScheduler(
num_train_timesteps=self.num_train_timesteps,
shift=1,
use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
timesteps, _ = retrieve_timesteps(
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
else:
raise NotImplementedError("Unsupported solver.")
# sample videos
latent = noise
arg_c = {
'context': [context[0]],
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
# 'cond_flag': True,
}
arg_null = {
'context': context_null,
'clip_fea': clip_context,
'seq_len': max_seq_len,
'y': [y],
# 'cond_flag': False,
}
if offload_model:
torch.cuda.empty_cache()
self.model.to(self.device)
for _, t in enumerate(tqdm(timesteps)):
latent_model_input = [latent.to(self.device)]
timestep = [t]
timestep = torch.stack(timestep).to(self.device)
noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0].to(
torch.device('cpu') if offload_model else self.device)
if offload_model:
torch.cuda.empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
noise_pred_cond - noise_pred_uncond)
latent = latent.to(
torch.device('cpu') if offload_model else self.device)
temp_x0 = sample_scheduler.step(
noise_pred.unsqueeze(0),
t,
latent.unsqueeze(0),
return_dict=False,
generator=seed_g)[0]
latent = temp_x0.squeeze(0)
x0 = [latent.to(self.device)]
del latent_model_input, timestep
if offload_model:
self.model.cpu()
torch.cuda.empty_cache()
if self.rank == 0:
videos = self.vae.decode(x0)
del noise, latent
del sample_scheduler
if offload_model:
gc.collect()
torch.cuda.synchronize()
if dist.is_initialized():
dist.barrier()
return videos[0] if self.rank == 0 else None
def teacache_forward(
self,
x,
t,
context,
seq_len,
clip_fea=None,
y=None,
):
r"""
Forward pass through the diffusion model
Args:
x (List[Tensor]):
List of input video tensors, each with shape [C_in, F, H, W]
t (Tensor):
Diffusion timesteps tensor of shape [B]
context (List[Tensor]):
List of text embeddings each with shape [L, C]
seq_len (`int`):
Maximum sequence length for positional encoding
clip_fea (Tensor, *optional*):
CLIP image features for image-to-video mode
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
Returns:
List[Tensor]:
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
"""
logging.info("via teacache forward process")
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
dim=1) for u in x
])
# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
if clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens)
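# TeaCache: use the (modulated) timestep embedding as a cheap proxy for how much the
# transformer output will change at this step. A polynomial rescaling of the relative
# L1 distance between the current and previous embedding is accumulated; while it
# stays below teacache_thresh the transformer blocks are skipped and the cached
# residual from the last fully computed pass is reused. Conditional and unconditional
# passes are tracked separately via even/odd values of self.cnt.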
if self.enable_teacache:
modulated_inp = e0 if self.use_ref_steps else e
# teacache
if self.cnt % 2 == 0:  # even step -> conditional pass
self.is_even = True
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc_even = True
self.accumulated_rel_l1_distance_even = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_even += rescale_func(
((modulated_inp - self.previous_e0_even).abs().mean() /
self.previous_e0_even.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
should_calc_even = False
else:
should_calc_even = True
self.accumulated_rel_l1_distance_even = 0
self.previous_e0_even = modulated_inp.clone()
else:  # odd step -> unconditional pass
self.is_even = False
if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
should_calc_odd = True
self.accumulated_rel_l1_distance_odd = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_odd += rescale_func(
((modulated_inp - self.previous_e0_odd).abs().mean() /
self.previous_e0_odd.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
should_calc_odd = False
else:
should_calc_odd = True
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_odd = modulated_inp.clone()
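# Apply the skip decision: either reuse the cached residual (no block computation) or
# run the full block stack and store the new residual (output minus input) for reuse.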
if self.enable_teacache:
if self.is_even:
if not should_calc_even:
logging.info("use residual estimation for this difusion step")
x += self.previous_residual_even
else:
ori_x = x.clone()
for block in self.blocks:
x = block(x, **kwargs)
self.previous_residual_even = x - ori_x
else:
if not should_calc_odd:
logging.info("use residual estimation for thi8s difusion step")
x += self.previous_residual_odd
else:
ori_x = x.clone()
for block in self.blocks:
x = block(x, **kwargs)
self.previous_residual_odd = x - ori_x
else:
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
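# self.cnt counts forward calls (two per sampling step: conditional + unconditional);
# reset after a full sampling run of num_steps calls.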
self.cnt += 1
if self.cnt >= self.num_steps:
self.cnt = 0
return [u.float() for u in x]
def _validate_args(args):
# Basic check
@@ -243,6 +785,25 @@ def _parse_args():
type=float,
default=5.0,
help="Classifier free guidance scale.")
parser.add_argument(
"--use_ret_steps",
action="store_true",
default=False,
help="Use Retention Steps for TeaCache (always compute the first steps fully).")
parser.add_argument(
"--enable_teacache",
action="store_true",
default=False,
help="Enable TeaCache to skip redundant diffusion steps and speed up sampling.")
parser.add_argument(
"--teacache_thresh",
type=float,
default=0.2,
help="TeaCache threshold: higher values skip more steps (faster) at some cost in quality.")
args = parser.parse_args()
@@ -367,6 +928,35 @@ def generate(args):
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
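# Monkey-patch TeaCache onto the T2V pipeline: swap in the patched generate/forward
# and attach the cache state on the classes. num_steps is sample_steps * 2 because
# each sampling step makes two forward calls (conditional and unconditional). The
# polynomial coefficients rescale the embedding distance and are specific to the
# model variant, selected here by matching the checkpoint directory name.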
if args.enable_teacache:
wan_t2v.__class__.generate = t2v_generate
wan_t2v.model.__class__.enable_teacache = True
wan_t2v.model.__class__.forward = teacache_forward
wan_t2v.model.__class__.cnt = 0
wan_t2v.model.__class__.num_steps = args.sample_steps*2
wan_t2v.model.__class__.teacache_thresh = args.teacache_thresh
wan_t2v.model.__class__.accumulated_rel_l1_distance_even = 0
wan_t2v.model.__class__.accumulated_rel_l1_distance_odd = 0
wan_t2v.model.__class__.previous_e0_even = None
wan_t2v.model.__class__.previous_e0_odd = None
wan_t2v.model.__class__.previous_residual_even = None
wan_t2v.model.__class__.previous_residual_odd = None
wan_t2v.model.__class__.use_ref_steps = args.use_ret_steps
if args.use_ret_steps:
if '1.3B' in args.ckpt_dir:
wan_t2v.model.__class__.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
if '14B' in args.ckpt_dir:
wan_t2v.model.__class__.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
wan_t2v.model.__class__.ret_steps = 5*2
wan_t2v.model.__class__.cutoff_steps = args.sample_steps*2
else:
if '1.3B' in args.ckpt_dir:
wan_t2v.model.__class__.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
if '14B' in args.ckpt_dir:
wan_t2v.model.__class__.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
wan_t2v.model.__class__.ret_steps = 1*2
wan_t2v.model.__class__.cutoff_steps = args.sample_steps*2 - 2
logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
@@ -424,7 +1014,36 @@ def generate(args):
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
)
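# Same TeaCache patching for the I2V pipeline; here the rescaling coefficients are
# selected by checkpoint resolution (480P vs. 720P) rather than model size.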
if args.enable_teacache:
wan_i2v.__class__.generate = i2v_generate
wan_i2v.model.__class__.enable_teacache = True
wan_i2v.model.__class__.forward = teacache_forward
wan_i2v.model.__class__.cnt = 0
wan_i2v.model.__class__.num_steps = args.sample_steps*2
wan_i2v.model.__class__.teacache_thresh = args.teacache_thresh
wan_i2v.model.__class__.accumulated_rel_l1_distance_even = 0
wan_i2v.model.__class__.accumulated_rel_l1_distance_odd = 0
wan_i2v.model.__class__.previous_e0_even = None
wan_i2v.model.__class__.previous_e0_odd = None
wan_i2v.model.__class__.previous_residual_even = None
wan_i2v.model.__class__.previous_residual_odd = None
wan_i2v.model.__class__.use_ref_steps = args.use_ret_steps
if args.use_ret_steps:
if '480P' in args.ckpt_dir:
wan_i2v.model.__class__.coefficients = [ 2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
if '720P' in args.ckpt_dir:
wan_i2v.model.__class__.coefficients = [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
wan_i2v.model.__class__.ret_steps = 5*2
wan_i2v.model.__class__.cutoff_steps = args.sample_steps*2
else:
if '480P' in args.ckpt_dir:
wan_i2v.model.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
if '720P' in args.ckpt_dir:
wan_i2v.model.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
wan_i2v.model.__class__.ret_steps = 1*2
wan_i2v.model.__class__.cutoff_steps = args.sample_steps*2 - 2
logging.info("Generating video ...")
video = wan_i2v.generate(
args.prompt,