diff --git a/generate.py b/generate.py
index c841c19..ce19aae 100644
--- a/generate.py
+++ b/generate.py
@@ -4,21 +4,32 @@ import logging
 import os
 import sys
 import warnings
+from tqdm import tqdm
 from datetime import datetime
 
 warnings.filterwarnings('ignore')
 
 import random
-
 import torch
 import torch.distributed as dist
 from PIL import Image
+import torchvision.transforms.functional as TF
+import torch.cuda.amp as amp
+import numpy as np
+import math
 
 import wan
 from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
 from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
 from wan.utils.utils import cache_image, cache_video, str2bool
+import gc
+from contextlib import contextmanager
+from wan.modules.model import sinusoidal_embedding_1d
+from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
+                                  get_sampling_sigmas, retrieve_timesteps)
+from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 
 EXAMPLE_PROMPT = {
     "t2v-1.3B": {
@@ -60,6 +71,537 @@ EXAMPLE_PROMPT = {
     }
 }
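# --- Editor's note (not part of the patch) ----------------------------------
# The core TeaCache decision used by the forward pass added below: rescale the
# relative L1 change of the timestep modulation with a fitted polynomial and
# accumulate it until it crosses a threshold. A minimal self-contained sketch
# of that rule, with made-up coefficients and inputs (the real per-checkpoint
# coefficients are set further down in generate()):

import numpy as np
import torch

def should_recompute(curr, prev, acc, coefficients, thresh):
    """Return (recompute?, new accumulated distance). `prev` may be None."""
    if prev is None:
        return True, 0.0
    rescale = np.poly1d(coefficients)
    rel_l1 = ((curr - prev).abs().mean() / prev.abs().mean()).item()
    acc += rescale(rel_l1)
    return (True, 0.0) if acc >= thresh else (False, acc)

# toy usage with hypothetical values: identity polynomial, threshold 0.2
e_prev, e_curr = torch.randn(6, 128), torch.randn(6, 128)
recompute, acc = should_recompute(
    e_curr, e_prev, 0.0, [0.0, 0.0, 0.0, 1.0, 0.0], 0.2)
# -----------------------------------------------------------------------------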
+def t2v_generate(self,
+                 input_prompt,
+                 size=(1280, 720),
+                 frame_num=81,
+                 shift=5.0,
+                 sample_solver='unipc',
+                 sampling_steps=50,
+                 guide_scale=5.0,
+                 n_prompt="",
+                 seed=-1,
+                 offload_model=True):
+    r"""
+    Generates video frames from text prompt using diffusion process.
+
+    Args:
+        input_prompt (`str`):
+            Text prompt for content generation
+        size (tuple[`int`], *optional*, defaults to (1280,720)):
+            Controls video resolution, (width,height).
+        frame_num (`int`, *optional*, defaults to 81):
+            How many frames to sample from a video. The number should be 4n+1
+        shift (`float`, *optional*, defaults to 5.0):
+            Noise schedule shift parameter. Affects temporal dynamics
+        sample_solver (`str`, *optional*, defaults to 'unipc'):
+            Solver used to sample the video.
+        sampling_steps (`int`, *optional*, defaults to 50):
+            Number of diffusion sampling steps. Higher values improve quality but slow generation
+        guide_scale (`float`, *optional*, defaults to 5.0):
+            Classifier-free guidance scale. Controls prompt adherence vs. creativity
+        n_prompt (`str`, *optional*, defaults to ""):
+            Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
+        seed (`int`, *optional*, defaults to -1):
+            Random seed for noise generation. If -1, use a random seed.
+        offload_model (`bool`, *optional*, defaults to True):
+            If True, offloads models to CPU during generation to save VRAM
+
+    Returns:
+        torch.Tensor:
+            Generated video frames tensor. Dimensions: (C, N, H, W) where:
+            - C: Color channels (3 for RGB)
+            - N: Number of frames (81)
+            - H: Frame height (from size)
+            - W: Frame width (from size)
+    """
+    # preprocess
+    F = frame_num
+    target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
+                    size[1] // self.vae_stride[1],
+                    size[0] // self.vae_stride[2])
+
+    seq_len = math.ceil((target_shape[2] * target_shape[3]) /
+                        (self.patch_size[1] * self.patch_size[2]) *
+                        target_shape[1] / self.sp_size) * self.sp_size
+
+    if n_prompt == "":
+        n_prompt = self.sample_neg_prompt
+    seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+    seed_g = torch.Generator(device=self.device)
+    seed_g.manual_seed(seed)
+
+    if not self.t5_cpu:
+        self.text_encoder.model.to(self.device)
+        context = self.text_encoder([input_prompt], self.device)
+        context_null = self.text_encoder([n_prompt], self.device)
+        if offload_model:
+            self.text_encoder.model.cpu()
+    else:
+        context = self.text_encoder([input_prompt], torch.device('cpu'))
+        context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+        context = [t.to(self.device) for t in context]
+        context_null = [t.to(self.device) for t in context_null]
+
+    noise = [
+        torch.randn(
+            target_shape[0],
+            target_shape[1],
+            target_shape[2],
+            target_shape[3],
+            dtype=torch.float32,
+            device=self.device,
+            generator=seed_g)
+    ]
+
+    @contextmanager
+    def noop_no_sync():
+        yield
+
+    no_sync = getattr(self.model, 'no_sync', noop_no_sync)
+
+    # evaluation mode
+    with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
+
+        if sample_solver == 'unipc':
+            sample_scheduler = FlowUniPCMultistepScheduler(
+                num_train_timesteps=self.num_train_timesteps,
+                shift=1,
+                use_dynamic_shifting=False)
+            sample_scheduler.set_timesteps(
+                sampling_steps, device=self.device, shift=shift)
+            timesteps = sample_scheduler.timesteps
+        elif sample_solver == 'dpm++':
+            sample_scheduler = FlowDPMSolverMultistepScheduler(
+                num_train_timesteps=self.num_train_timesteps,
+                shift=1,
+                use_dynamic_shifting=False)
+            sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+            timesteps, _ = retrieve_timesteps(
+                sample_scheduler,
+                device=self.device,
+                sigmas=sampling_sigmas)
+        else:
+            raise NotImplementedError("Unsupported solver.")
+
+        # sample videos
+        latents = noise
+
+        arg_c = {'context': context, 'seq_len': seq_len}
+        arg_null = {'context': context_null, 'seq_len': seq_len}
+
+        for _, t in enumerate(tqdm(timesteps)):
+            latent_model_input = latents
+            timestep = [t]
+
+            timestep = torch.stack(timestep)
+
+            self.model.to(self.device)
+            noise_pred_cond = self.model(
+                latent_model_input, t=timestep, **arg_c)[0]
+            noise_pred_uncond = self.model(
+                latent_model_input, t=timestep, **arg_null)[0]
+
+            noise_pred = noise_pred_uncond + guide_scale * (
+                noise_pred_cond - noise_pred_uncond)
+
+            temp_x0 = sample_scheduler.step(
+                noise_pred.unsqueeze(0),
+                t,
+                latents[0].unsqueeze(0),
+                return_dict=False,
+                generator=seed_g)[0]
+            latents = [temp_x0.squeeze(0)]
+
+        x0 = latents
+        if offload_model:
+            self.model.cpu()
+            torch.cuda.empty_cache()
+        if self.rank == 0:
+            videos = self.vae.decode(x0)
+
+    del noise, latents
+    del sample_scheduler
+    if offload_model:
+        gc.collect()
+        torch.cuda.synchronize()
+    if dist.is_initialized():
+        dist.barrier()
+
+    return videos[0] if self.rank == 0 else None
+
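# --- Editor's note (not part of the patch) ----------------------------------
# A worked example of the shape arithmetic in t2v_generate, assuming the Wan
# defaults vae_stride=(4, 8, 8), patch_size=(1, 2, 2), z_dim=16 and a single
# GPU (sp_size=1); these values are assumptions, read the real ones from the
# model config:
import math

vae_stride, patch_size, z_dim, sp_size = (4, 8, 8), (1, 2, 2), 16, 1
size, frame_num = (1280, 720), 81

target_shape = (z_dim, (frame_num - 1) // vae_stride[0] + 1,  # (16, 21,
                size[1] // vae_stride[1],                     #  90,
                size[0] // vae_stride[2])                     #  160)
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
                    (patch_size[1] * patch_size[2]) *
                    target_shape[1] / sp_size) * sp_size
assert seq_len == 75600  # (90 * 160) / (2 * 2) tokens per frame, x 21 frames
# -----------------------------------------------------------------------------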
+def i2v_generate(self,
+                 input_prompt,
+                 img,
+                 max_area=720 * 1280,
+                 frame_num=81,
+                 shift=5.0,
+                 sample_solver='unipc',
+                 sampling_steps=40,
+                 guide_scale=5.0,
+                 n_prompt="",
+                 seed=-1,
+                 offload_model=True):
+    r"""
+    Generates video frames from input image and text prompt using diffusion process.
+
+    Args:
+        input_prompt (`str`):
+            Text prompt for content generation.
+        img (PIL.Image.Image):
+            Input image. Converted to a tensor of shape [3, H, W]
+        max_area (`int`, *optional*, defaults to 720*1280):
+            Maximum pixel area for latent space calculation. Controls video resolution scaling
+        frame_num (`int`, *optional*, defaults to 81):
+            How many frames to sample from a video. The number should be 4n+1
+        shift (`float`, *optional*, defaults to 5.0):
+            Noise schedule shift parameter. Affects temporal dynamics
+            [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
+        sample_solver (`str`, *optional*, defaults to 'unipc'):
+            Solver used to sample the video.
+        sampling_steps (`int`, *optional*, defaults to 40):
+            Number of diffusion sampling steps. Higher values improve quality but slow generation
+        guide_scale (`float`, *optional*, defaults to 5.0):
+            Classifier-free guidance scale. Controls prompt adherence vs. creativity
+        n_prompt (`str`, *optional*, defaults to ""):
+            Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
+        seed (`int`, *optional*, defaults to -1):
+            Random seed for noise generation. If -1, use a random seed
+        offload_model (`bool`, *optional*, defaults to True):
+            If True, offloads models to CPU during generation to save VRAM
+
+    Returns:
+        torch.Tensor:
+            Generated video frames tensor. Dimensions: (C, N, H, W) where:
+            - C: Color channels (3 for RGB)
+            - N: Number of frames (81)
+            - H: Frame height (from max_area)
+            - W: Frame width (from max_area)
+    """
+    img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
+
+    F = frame_num
+    h, w = img.shape[1:]
+    aspect_ratio = h / w
+    lat_h = round(
+        np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
+        self.patch_size[1] * self.patch_size[1])
+    lat_w = round(
+        np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
+        self.patch_size[2] * self.patch_size[2])
+    h = lat_h * self.vae_stride[1]
+    w = lat_w * self.vae_stride[2]
+
+    max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
+        self.patch_size[1] * self.patch_size[2])
+    max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
+
+    seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+    seed_g = torch.Generator(device=self.device)
+    seed_g.manual_seed(seed)
+    noise = torch.randn(
+        self.vae.model.z_dim,
+        (F - 1) // self.vae_stride[0] + 1,
+        lat_h,
+        lat_w,
+        dtype=torch.float32,
+        generator=seed_g,
+        device=self.device)
+
+    msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
+    msk[:, 1:] = 0
+    msk = torch.concat([
+        torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
+    ],
+                       dim=1)
+    msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+    msk = msk.transpose(1, 2)[0]
+
+    if n_prompt == "":
+        n_prompt = self.sample_neg_prompt
+
+    # preprocess
+    if not self.t5_cpu:
+        self.text_encoder.model.to(self.device)
+        context = self.text_encoder([input_prompt], self.device)
+        context_null = self.text_encoder([n_prompt], self.device)
+        if offload_model:
+            self.text_encoder.model.cpu()
+    else:
+        context = self.text_encoder([input_prompt], torch.device('cpu'))
+        context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+        context = [t.to(self.device) for t in context]
+        context_null = [t.to(self.device) for t in context_null]
+
+    self.clip.model.to(self.device)
+    clip_context = self.clip.visual([img[:, None, :, :]])
+    if offload_model:
+        self.clip.model.cpu()
+
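+    # Encode the conditioning signal: the reference image, resized to the
+    # target resolution, occupies frame 0; the remaining F - 1 frames are
+    # zeros. Concatenating `msk` on top lets the model distinguish which
+    # latent positions are conditioned.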
+    y = self.vae.encode([
+        torch.concat([
+            torch.nn.functional.interpolate(
+                img[None].cpu(), size=(h, w), mode='bicubic').transpose(
+                    0, 1),
+            torch.zeros(3, F - 1, h, w)
+        ],
+                     dim=1).to(self.device)
+    ])[0]
+    y = torch.concat([msk, y])
+
+    @contextmanager
+    def noop_no_sync():
+        yield
+
+    no_sync = getattr(self.model, 'no_sync', noop_no_sync)
+
+    # evaluation mode
+    with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
+
+        if sample_solver == 'unipc':
+            sample_scheduler = FlowUniPCMultistepScheduler(
+                num_train_timesteps=self.num_train_timesteps,
+                shift=1,
+                use_dynamic_shifting=False)
+            sample_scheduler.set_timesteps(
+                sampling_steps, device=self.device, shift=shift)
+            timesteps = sample_scheduler.timesteps
+        elif sample_solver == 'dpm++':
+            sample_scheduler = FlowDPMSolverMultistepScheduler(
+                num_train_timesteps=self.num_train_timesteps,
+                shift=1,
+                use_dynamic_shifting=False)
+            sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+            timesteps, _ = retrieve_timesteps(
+                sample_scheduler,
+                device=self.device,
+                sigmas=sampling_sigmas)
+        else:
+            raise NotImplementedError("Unsupported solver.")
+
+        # sample videos
+        latent = noise
+
+        arg_c = {
+            'context': [context[0]],
+            'clip_fea': clip_context,
+            'seq_len': max_seq_len,
+            'y': [y],
+        }
+
+        arg_null = {
+            'context': context_null,
+            'clip_fea': clip_context,
+            'seq_len': max_seq_len,
+            'y': [y],
+        }
+
+        if offload_model:
+            torch.cuda.empty_cache()
+
+        self.model.to(self.device)
+        for _, t in enumerate(tqdm(timesteps)):
+            latent_model_input = [latent.to(self.device)]
+            timestep = [t]
+
+            timestep = torch.stack(timestep).to(self.device)
+
+            noise_pred_cond = self.model(
+                latent_model_input, t=timestep, **arg_c)[0].to(
+                    torch.device('cpu') if offload_model else self.device)
+            if offload_model:
+                torch.cuda.empty_cache()
+            noise_pred_uncond = self.model(
+                latent_model_input, t=timestep, **arg_null)[0].to(
+                    torch.device('cpu') if offload_model else self.device)
+            if offload_model:
+                torch.cuda.empty_cache()
+
+            noise_pred = noise_pred_uncond + guide_scale * (
+                noise_pred_cond - noise_pred_uncond)
+
+            latent = latent.to(
+                torch.device('cpu') if offload_model else self.device)
+
+            temp_x0 = sample_scheduler.step(
+                noise_pred.unsqueeze(0),
+                t,
+                latent.unsqueeze(0),
+                return_dict=False,
+                generator=seed_g)[0]
+            latent = temp_x0.squeeze(0)
+
+        x0 = [latent.to(self.device)]
+        del latent_model_input, timestep
+
+        if offload_model:
+            self.model.cpu()
+            torch.cuda.empty_cache()
+
+        if self.rank == 0:
+            videos = self.vae.decode(x0)
+
+    del noise, latent
+    del sample_scheduler
+    if offload_model:
+        gc.collect()
+        torch.cuda.synchronize()
+    if dist.is_initialized():
+        dist.barrier()
+
+    return videos[0] if self.rank == 0 else None
+
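# --- Editor's note (not part of the patch) ----------------------------------
# The `msk` construction in i2v_generate packs a per-frame conditioning mask
# into the 4x temporally-compressed latent grid: frame 0 is repeated 4 times
# so that, after the view/transpose, the latent time slot covering the
# reference frame is fully marked. A minimal sketch with toy spatial sizes:
import torch

F, lat_h, lat_w = 81, 2, 2                       # toy latent height/width
msk = torch.ones(1, F, lat_h, lat_w)
msk[:, 1:] = 0                                   # only frame 0 is conditioned
msk = torch.concat(
    [torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]],
    dim=1)                                       # (1, 84, h, w): 4 + (F - 1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]                     # (4, 21, h, w)
assert msk.shape == (4, (F - 1) // 4 + 1, lat_h, lat_w)
assert msk[:, 0].all() and not msk[:, 1:].any()  # only latent frame 0 is set
# -----------------------------------------------------------------------------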
forward process") + if self.model_type == 'i2v': + assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack( + [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([ + torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], + dim=1) for u in x + ]) + + # time embeddings + with amp.autocast(dtype=torch.float32): + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).float()) + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding( + torch.stack([ + torch.cat( + [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) + for u in context + ])) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = dict( + e=e0, + seq_lens=seq_lens, + grid_sizes=grid_sizes, + freqs=self.freqs, + context=context, + context_lens=context_lens) + + if self.enable_teacache: + modulated_inp = e0 if self.use_ref_steps else e + # teacache + if self.cnt%2==0: # even -> conditon + self.is_even = True + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: + should_calc_even = True + self.accumulated_rel_l1_distance_even = 0 + else: + rescale_func = np.poly1d(self.coefficients) + self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance_even < self.teacache_thresh: + should_calc_even = False + else: + should_calc_even = True + self.accumulated_rel_l1_distance_even = 0 + self.previous_e0_even = modulated_inp.clone() + + else: # odd -> unconditon + self.is_even = False + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: + should_calc_odd = True + self.accumulated_rel_l1_distance_odd = 0 + else: + rescale_func = np.poly1d(self.coefficients) + self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp-self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance_odd < self.teacache_thresh: + should_calc_odd = False + else: + should_calc_odd = True + self.accumulated_rel_l1_distance_odd = 0 + self.previous_e0_odd = modulated_inp.clone() + + if self.enable_teacache: + if self.is_even: + if not should_calc_even: + logging.info("use residual estimation for this difusion step") + x += self.previous_residual_even + else: + ori_x = x.clone() + for block in self.blocks: + x = block(x, **kwargs) + self.previous_residual_even = x - ori_x + else: + if not should_calc_odd: + logging.info("use residual estimation for thi8s difusion step") + x += self.previous_residual_odd + else: + ori_x = x.clone() + for block in self.blocks: + x = block(x, **kwargs) + self.previous_residual_odd = x - ori_x + + else: + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + self.cnt += 1 + if self.cnt >= self.num_steps: 
 
 
 def _validate_args(args):
     # Basic check
@@ -243,6 +785,25 @@ def _parse_args():
         type=float,
         default=5.0,
         help="Classifier free guidance scale.")
+
+    parser.add_argument(
+        "--use_ret_steps",
+        action="store_true",
+        default=False,
+        help="Use retention steps: always compute the first scheduler steps "
+        "and select the matching TeaCache coefficients.")
+
+    parser.add_argument(
+        "--enable_teacache",
+        action="store_true",
+        default=False,
+        help="Enable TeaCache to reuse cached residuals and skip redundant "
+        "diffusion steps.")
+
+    parser.add_argument(
+        "--teacache_thresh",
+        type=float,
+        default=0.2,
+        help="TeaCache threshold; higher values give more speedup at some "
+        "cost in quality.")
 
     args = parser.parse_args()
 
@@ -367,6 +928,35 @@ def generate(args):
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
         )
+
+        if args.enable_teacache:
+            wan_t2v.__class__.generate = t2v_generate
+            wan_t2v.model.__class__.enable_teacache = True
+            wan_t2v.model.__class__.forward = teacache_forward
+            wan_t2v.model.__class__.cnt = 0
+            wan_t2v.model.__class__.num_steps = args.sample_steps * 2
+            wan_t2v.model.__class__.teacache_thresh = args.teacache_thresh
+            wan_t2v.model.__class__.accumulated_rel_l1_distance_even = 0
+            wan_t2v.model.__class__.accumulated_rel_l1_distance_odd = 0
+            wan_t2v.model.__class__.previous_e0_even = None
+            wan_t2v.model.__class__.previous_e0_odd = None
+            wan_t2v.model.__class__.previous_residual_even = None
+            wan_t2v.model.__class__.previous_residual_odd = None
+            wan_t2v.model.__class__.use_ret_steps = args.use_ret_steps
+            if args.use_ret_steps:
+                if '1.3B' in args.ckpt_dir:
+                    wan_t2v.model.__class__.coefficients = [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
+                if '14B' in args.ckpt_dir:
+                    wan_t2v.model.__class__.coefficients = [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
+                wan_t2v.model.__class__.ret_steps = 5 * 2
+                wan_t2v.model.__class__.cutoff_steps = args.sample_steps * 2
+            else:
+                if '1.3B' in args.ckpt_dir:
+                    wan_t2v.model.__class__.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
+                if '14B' in args.ckpt_dir:
+                    wan_t2v.model.__class__.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
+                wan_t2v.model.__class__.ret_steps = 1 * 2
+                wan_t2v.model.__class__.cutoff_steps = args.sample_steps * 2 - 2
 
         logging.info(
             f"Generating {'image' if 't2i' in args.task else 'video'} ...")
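# --- Editor's note (not part of the patch) ----------------------------------
# The block above retrofits TeaCache by assigning onto the *class* of the
# pipeline and of its inner model, so the unmodified wan package picks up the
# new forward and the cache state without subclassing. A minimal sketch of
# the same pattern on stand-in classes (Pipeline and Model are hypothetical):

class Model:
    def forward(self, x):
        return x

class Pipeline:
    def __init__(self):
        self.model = Model()

def cached_forward(self, x):
    # stand-in for teacache_forward: reads class-level state via self
    self.cnt += 1
    return x

pipe = Pipeline()
pipe.model.__class__.forward = cached_forward  # every Model instance affected
pipe.model.__class__.cnt = 0                   # shared default for the counter
pipe.model.forward(42)
assert pipe.model.cnt == 1  # `self.cnt += 1` shadows the class default
assert Model.cnt == 0       # per instance, leaving the class value untouched
# -----------------------------------------------------------------------------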
@@ -424,7 +1014,36 @@ def generate(args):
             use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
             t5_cpu=args.t5_cpu,
         )
-
+
+        if args.enable_teacache:
+            wan_i2v.__class__.generate = i2v_generate
+            wan_i2v.model.__class__.enable_teacache = True
+            wan_i2v.model.__class__.forward = teacache_forward
+            wan_i2v.model.__class__.cnt = 0
+            wan_i2v.model.__class__.num_steps = args.sample_steps * 2
+            wan_i2v.model.__class__.teacache_thresh = args.teacache_thresh
+            wan_i2v.model.__class__.accumulated_rel_l1_distance_even = 0
+            wan_i2v.model.__class__.accumulated_rel_l1_distance_odd = 0
+            wan_i2v.model.__class__.previous_e0_even = None
+            wan_i2v.model.__class__.previous_e0_odd = None
+            wan_i2v.model.__class__.previous_residual_even = None
+            wan_i2v.model.__class__.previous_residual_odd = None
+            wan_i2v.model.__class__.use_ret_steps = args.use_ret_steps
+            if args.use_ret_steps:
+                if '480P' in args.ckpt_dir:
+                    wan_i2v.model.__class__.coefficients = [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
+                if '720P' in args.ckpt_dir:
+                    wan_i2v.model.__class__.coefficients = [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
+                wan_i2v.model.__class__.ret_steps = 5 * 2
+                wan_i2v.model.__class__.cutoff_steps = args.sample_steps * 2
+            else:
+                if '480P' in args.ckpt_dir:
+                    wan_i2v.model.__class__.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
+                if '720P' in args.ckpt_dir:
+                    wan_i2v.model.__class__.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+                wan_i2v.model.__class__.ret_steps = 1 * 2
+                wan_i2v.model.__class__.cutoff_steps = args.sample_steps * 2 - 2
+
         logging.info("Generating video ...")
         video = wan_i2v.generate(
             args.prompt,
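# --- Editor's note (not part of the patch) ----------------------------------
# With the flags added in _parse_args, a TeaCache run looks like the
# following (task name, size, checkpoint path, and prompt are placeholders):
#
#   python generate.py --task t2v-1.3B --size 832*480 \
#       --ckpt_dir ./Wan2.1-T2V-1.3B --prompt "a cat surfing a wave" \
#       --enable_teacache --teacache_thresh 0.2 --use_ret_steps
#
# --use_ret_steps keeps the first 5 scheduler steps always computed and
# selects the matching polynomial coefficients; raising --teacache_thresh
# trades quality for additional speedup.
# -----------------------------------------------------------------------------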