import gc import logging import math import os import random import sys import types from contextlib import contextmanager from functools import partial import numpy as np import torch import torch.cuda.amp as amp import torch.distributed as dist import torchvision.transforms.functional as TF from tqdm import tqdm from wan.distributed.fsdp import shard_model from wan.modules.clip import CLIPModel from wan.modules.model import WanModel from wan.modules.t5 import T5EncoderModel from wan.modules.vae import WanVAE from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps) from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from wan import WanI2V def wan_i2v_generate(self:WanI2V, input_prompt, img, max_area=720 * 1280, frame_num=81, shift=5.0, sample_solver='unipc', sampling_steps=40, guide_scale=5.0, n_prompt="", seed=-1, offload_model=True): r""" Generates video frames from input image and text prompt using diffusion process. Args: input_prompt (`str`): Text prompt for content generation. img (PIL.Image.Image): Input image tensor. Shape: [3, H, W] max_area (`int`, *optional*, defaults to 720*1280): Maximum pixel area for latent space calculation. Controls video resolution scaling frame_num (`int`, *optional*, defaults to 81): How many frames to sample from a video. The number should be 4n+1 shift (`float`, *optional*, defaults to 5.0): Noise schedule shift parameter. Affects temporal dynamics [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. sample_solver (`str`, *optional*, defaults to 'unipc'): Solver used to sample the video. sampling_steps (`int`, *optional*, defaults to 40): Number of diffusion sampling steps. Higher values improve quality but slow generation guide_scale (`float`, *optional*, defaults 5.0): Classifier-free guidance scale. Controls prompt adherence vs. creativity n_prompt (`str`, *optional*, defaults to ""): Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` seed (`int`, *optional*, defaults to -1): Random seed for noise generation. If -1, use random seed offload_model (`bool`, *optional*, defaults to True): If True, offloads models to CPU during generation to save VRAM Returns: torch.Tensor: Generated video frames tensor. Dimensions: (C, N H, W) where: - C: Color channels (3 for RGB) - N: Number of frames (81) - H: Frame height (from max_area) - W: Frame width from max_area) """ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device) F = frame_num h, w = img.shape[1:] aspect_ratio = h / w lat_h = round( np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // self.patch_size[1] * self.patch_size[1]) lat_w = round( np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // self.patch_size[2] * self.patch_size[2]) h = lat_h * self.vae_stride[1] w = lat_w * self.vae_stride[2] max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2]) max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size seed = seed if seed >= 0 else random.randint(0, sys.maxsize) seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) noise = torch.randn( 16, 21, lat_h, lat_w, dtype=torch.float32, generator=seed_g, device=self.device) msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) msk[:, 1:] = 0 msk = torch.concat([ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] ], dim=1) msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) msk = msk.transpose(1, 2)[0] if n_prompt == "": n_prompt = self.sample_neg_prompt # preprocess if not self.t5_cpu: self.text_encoder.model.to(self.device) context = self.text_encoder([input_prompt], self.device) context_null = self.text_encoder([n_prompt], self.device) if offload_model: self.text_encoder.model.cpu() else: context = self.text_encoder([input_prompt], torch.device('cpu')) context_null = self.text_encoder([n_prompt], torch.device('cpu')) context = [t.to(self.device) for t in context] context_null = [t.to(self.device) for t in context_null] self.clip.model.to(self.device) clip_context = self.clip.visual([img[:, None, :, :]]) if offload_model: self.clip.model.cpu() y = self.vae.encode([ torch.concat([ torch.nn.functional.interpolate( img[None].cpu(), size=(h, w), mode='bicubic').transpose( 0, 1), torch.zeros(3, 80, h, w) ], dim=1).to(self.device) ])[0] y = torch.concat([msk, y]) @contextmanager def noop_no_sync(): yield no_sync = getattr(self.model, 'no_sync', noop_no_sync) # evaluation mode with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): if sample_solver == 'unipc': sample_scheduler = FlowUniPCMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False) sample_scheduler.set_timesteps( sampling_steps, device=self.device, shift=shift) timesteps = sample_scheduler.timesteps elif sample_solver == 'dpm++': sample_scheduler = FlowDPMSolverMultistepScheduler( num_train_timesteps=self.num_train_timesteps, shift=1, use_dynamic_shifting=False) sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) timesteps, _ = retrieve_timesteps( sample_scheduler, device=self.device, sigmas=sampling_sigmas) else: raise NotImplementedError("Unsupported solver.") # sample videos latent = noise arg_c = { 'context': [context[0]], 'clip_fea': clip_context, 'seq_len': max_seq_len, 'y': [y], } arg_null = { 'context': context_null, 'clip_fea': clip_context, 'seq_len': max_seq_len, 'y': [y], } if offload_model: torch.cuda.empty_cache() self.model.to(self.device) for i, t in enumerate(tqdm(timesteps)): current_step = i latent_model_input = [latent.to(self.device)] timestep = [t] timestep = torch.stack(timestep).to(self.device) current_stream = 'cond_stream' noise_pred_cond = self.model( latent_model_input, t=timestep, current_step = current_step, current_stream = current_stream, **arg_c)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: torch.cuda.empty_cache() current_stream = 'uncond_stream' noise_pred_uncond = self.model( latent_model_input, t=timestep, current_step = current_step, current_stream = current_stream, **arg_null)[0].to( torch.device('cpu') if offload_model else self.device) if offload_model: torch.cuda.empty_cache() noise_pred = noise_pred_uncond + guide_scale * ( noise_pred_cond - noise_pred_uncond) latent = latent.to( torch.device('cpu') if offload_model else self.device) temp_x0 = sample_scheduler.step( noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False, generator=seed_g)[0] latent = temp_x0.squeeze(0) x0 = [latent.to(self.device)] del latent_model_input, timestep if offload_model: self.model.cpu() torch.cuda.empty_cache() if self.rank == 0: videos = self.vae.decode(x0) del noise, latent del sample_scheduler if offload_model: gc.collect() torch.cuda.synchronize() if dist.is_initialized(): dist.barrier() return videos[0] if self.rank == 0 else None