# Mirror of https://github.com/Wan-Video/Wan2.1.git
# (synced 2025-07-14 11:40:10 +00:00; original listing: 265 lines, 8.8 KiB, Python)
import gc
|
|
import logging
|
|
import math
|
|
import os
|
|
import random
|
|
import sys
|
|
import types
|
|
from contextlib import contextmanager
|
|
from functools import partial
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.cuda.amp as amp
|
|
import torch.distributed as dist
|
|
import torchvision.transforms.functional as TF
|
|
from tqdm import tqdm
|
|
|
|
from wan.distributed.fsdp import shard_model
|
|
from wan.modules.clip import CLIPModel
|
|
from wan.modules.model import WanModel
|
|
from wan.modules.t5 import T5EncoderModel
|
|
from wan.modules.vae import WanVAE
|
|
from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
|
get_sampling_sigmas, retrieve_timesteps)
|
|
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
|
from wan import WanI2V
|
|
|
|
def wan_i2v_generate(self:WanI2V,
|
|
input_prompt,
|
|
img,
|
|
max_area=720 * 1280,
|
|
frame_num=81,
|
|
shift=5.0,
|
|
sample_solver='unipc',
|
|
sampling_steps=40,
|
|
guide_scale=5.0,
|
|
n_prompt="",
|
|
seed=-1,
|
|
offload_model=True):
|
|
r"""
|
|
Generates video frames from input image and text prompt using diffusion process.
|
|
|
|
Args:
|
|
input_prompt (`str`):
|
|
Text prompt for content generation.
|
|
img (PIL.Image.Image):
|
|
Input image tensor. Shape: [3, H, W]
|
|
max_area (`int`, *optional*, defaults to 720*1280):
|
|
Maximum pixel area for latent space calculation. Controls video resolution scaling
|
|
frame_num (`int`, *optional*, defaults to 81):
|
|
How many frames to sample from a video. The number should be 4n+1
|
|
shift (`float`, *optional*, defaults to 5.0):
|
|
Noise schedule shift parameter. Affects temporal dynamics
|
|
[NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
|
|
sample_solver (`str`, *optional*, defaults to 'unipc'):
|
|
Solver used to sample the video.
|
|
sampling_steps (`int`, *optional*, defaults to 40):
|
|
Number of diffusion sampling steps. Higher values improve quality but slow generation
|
|
guide_scale (`float`, *optional*, defaults 5.0):
|
|
Classifier-free guidance scale. Controls prompt adherence vs. creativity
|
|
n_prompt (`str`, *optional*, defaults to ""):
|
|
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
|
|
seed (`int`, *optional*, defaults to -1):
|
|
Random seed for noise generation. If -1, use random seed
|
|
offload_model (`bool`, *optional*, defaults to True):
|
|
If True, offloads models to CPU during generation to save VRAM
|
|
|
|
Returns:
|
|
torch.Tensor:
|
|
Generated video frames tensor. Dimensions: (C, N H, W) where:
|
|
- C: Color channels (3 for RGB)
|
|
- N: Number of frames (81)
|
|
- H: Frame height (from max_area)
|
|
- W: Frame width from max_area)
|
|
"""
|
|
img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
|
|
|
|
F = frame_num
|
|
h, w = img.shape[1:]
|
|
aspect_ratio = h / w
|
|
lat_h = round(
|
|
np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
|
|
self.patch_size[1] * self.patch_size[1])
|
|
lat_w = round(
|
|
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
|
|
self.patch_size[2] * self.patch_size[2])
|
|
h = lat_h * self.vae_stride[1]
|
|
w = lat_w * self.vae_stride[2]
|
|
|
|
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
|
|
self.patch_size[1] * self.patch_size[2])
|
|
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
|
|
|
|
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
|
seed_g = torch.Generator(device=self.device)
|
|
seed_g.manual_seed(seed)
|
|
noise = torch.randn(
|
|
16,
|
|
21,
|
|
lat_h,
|
|
lat_w,
|
|
dtype=torch.float32,
|
|
generator=seed_g,
|
|
device=self.device)
|
|
|
|
msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
|
|
msk[:, 1:] = 0
|
|
msk = torch.concat([
|
|
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
|
|
],
|
|
dim=1)
|
|
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
|
msk = msk.transpose(1, 2)[0]
|
|
|
|
if n_prompt == "":
|
|
n_prompt = self.sample_neg_prompt
|
|
|
|
# preprocess
|
|
if not self.t5_cpu:
|
|
self.text_encoder.model.to(self.device)
|
|
context = self.text_encoder([input_prompt], self.device)
|
|
context_null = self.text_encoder([n_prompt], self.device)
|
|
if offload_model:
|
|
self.text_encoder.model.cpu()
|
|
else:
|
|
context = self.text_encoder([input_prompt], torch.device('cpu'))
|
|
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
|
context = [t.to(self.device) for t in context]
|
|
context_null = [t.to(self.device) for t in context_null]
|
|
|
|
self.clip.model.to(self.device)
|
|
clip_context = self.clip.visual([img[:, None, :, :]])
|
|
if offload_model:
|
|
self.clip.model.cpu()
|
|
|
|
y = self.vae.encode([
|
|
torch.concat([
|
|
torch.nn.functional.interpolate(
|
|
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
|
|
0, 1),
|
|
torch.zeros(3, 80, h, w)
|
|
],
|
|
dim=1).to(self.device)
|
|
])[0]
|
|
y = torch.concat([msk, y])
|
|
|
|
@contextmanager
|
|
def noop_no_sync():
|
|
yield
|
|
|
|
no_sync = getattr(self.model, 'no_sync', noop_no_sync)
|
|
|
|
# evaluation mode
|
|
with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
|
|
|
|
if sample_solver == 'unipc':
|
|
sample_scheduler = FlowUniPCMultistepScheduler(
|
|
num_train_timesteps=self.num_train_timesteps,
|
|
shift=1,
|
|
use_dynamic_shifting=False)
|
|
sample_scheduler.set_timesteps(
|
|
sampling_steps, device=self.device, shift=shift)
|
|
timesteps = sample_scheduler.timesteps
|
|
elif sample_solver == 'dpm++':
|
|
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
|
num_train_timesteps=self.num_train_timesteps,
|
|
shift=1,
|
|
use_dynamic_shifting=False)
|
|
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
|
timesteps, _ = retrieve_timesteps(
|
|
sample_scheduler,
|
|
device=self.device,
|
|
sigmas=sampling_sigmas)
|
|
else:
|
|
raise NotImplementedError("Unsupported solver.")
|
|
|
|
# sample videos
|
|
latent = noise
|
|
|
|
arg_c = {
|
|
'context': [context[0]],
|
|
'clip_fea': clip_context,
|
|
'seq_len': max_seq_len,
|
|
'y': [y],
|
|
}
|
|
|
|
arg_null = {
|
|
'context': context_null,
|
|
'clip_fea': clip_context,
|
|
'seq_len': max_seq_len,
|
|
'y': [y],
|
|
}
|
|
|
|
if offload_model:
|
|
torch.cuda.empty_cache()
|
|
|
|
self.model.to(self.device)
|
|
|
|
for i, t in enumerate(tqdm(timesteps)):
|
|
|
|
current_step = i
|
|
|
|
latent_model_input = [latent.to(self.device)]
|
|
timestep = [t]
|
|
|
|
timestep = torch.stack(timestep).to(self.device)
|
|
|
|
current_stream = 'cond_stream'
|
|
|
|
noise_pred_cond = self.model(
|
|
latent_model_input, t=timestep,
|
|
current_step = current_step,
|
|
current_stream = current_stream,
|
|
**arg_c)[0].to(
|
|
torch.device('cpu') if offload_model else self.device)
|
|
|
|
if offload_model:
|
|
torch.cuda.empty_cache()
|
|
|
|
current_stream = 'uncond_stream'
|
|
|
|
noise_pred_uncond = self.model(
|
|
latent_model_input, t=timestep,
|
|
current_step = current_step,
|
|
current_stream = current_stream,
|
|
**arg_null)[0].to(
|
|
torch.device('cpu') if offload_model else self.device)
|
|
|
|
if offload_model:
|
|
torch.cuda.empty_cache()
|
|
|
|
noise_pred = noise_pred_uncond + guide_scale * (
|
|
noise_pred_cond - noise_pred_uncond)
|
|
|
|
latent = latent.to(
|
|
torch.device('cpu') if offload_model else self.device)
|
|
|
|
temp_x0 = sample_scheduler.step(
|
|
noise_pred.unsqueeze(0),
|
|
t,
|
|
latent.unsqueeze(0),
|
|
return_dict=False,
|
|
generator=seed_g)[0]
|
|
latent = temp_x0.squeeze(0)
|
|
|
|
x0 = [latent.to(self.device)]
|
|
del latent_model_input, timestep
|
|
|
|
if offload_model:
|
|
self.model.cpu()
|
|
torch.cuda.empty_cache()
|
|
|
|
if self.rank == 0:
|
|
videos = self.vae.decode(x0)
|
|
|
|
del noise, latent
|
|
del sample_scheduler
|
|
if offload_model:
|
|
gc.collect()
|
|
torch.cuda.synchronize()
|
|
if dist.is_initialized():
|
|
dist.barrier()
|
|
|
|
return videos[0] if self.rank == 0 else None
|