Merge 58a0244c0f into ae487cc653
Commit 98f01b83a7

wan/wan_flf2v_diffusers.py (new file, 344 lines)
@@ -0,0 +1,344 @@
import inspect
from typing import List, Optional, Union, Tuple

import torch
import numpy as np
import PIL.Image
from diffusers import DiffusionPipeline, UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput

# Import Wan modules
# Using absolute imports ensures this works when installed as a package
from wan.modules.model import WanModel
from wan.modules.t5 import T5EncoderModel
from wan.modules.vae import WanVAE
from wan.modules.clip import CLIPModel
import torchvision.transforms.functional as TF
import torch.nn.functional as F

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class WanFLF2VPipeline(DiffusionPipeline):
    r"""
    Pipeline for First-Last-Frame-to-Video generation using Wan2.1.
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"

    def __init__(
        self,
        vae: WanVAE,
        text_encoder: T5EncoderModel,
        image_encoder: CLIPModel,
        transformer: WanModel,
        scheduler: UniPCMultistepScheduler,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            image_encoder=image_encoder,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_stride = [4, 8, 8]  # hardcoded based on config
        self.patch_size = [1, 2, 2]  # hardcoded based on config

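    # Note on the hardcoded strides: `vae_stride = [4, 8, 8]` means the VAE
    # compresses time by 4x and each spatial side by 8x, so e.g. 81 frames at
    # 720x1280 become latents of shape [C, (81 - 1) // 4 + 1 = 21, 90, 160].
    # `patch_size = [1, 2, 2]` then halves each spatial side again when the
    # transformer patchifies the latents (see the `seq_len` computation below).
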
    def check_inputs(
        self,
        prompt,
        first_frame,
        last_frame,
        height,
        width,
        callback_steps,
    ):
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` must be divisible by 16 but are {height} and {width}.")

        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` must be an integer > 0 if provided, but is {callback_steps}."
            )

    def prepare_latents(
        self,
        batch_size,
        num_channels,
        num_frames,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels,
            (num_frames - 1) // 4 + 1,
            height // 8,
            width // 8,
        )
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

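    # Note: `prepare_latents` follows the standard diffusers helper signature,
    # but `__call__` below currently samples its noise directly with
    # `randn_tensor`; the helper is kept for API parity with other pipelines.
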
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        first_frame: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
        last_frame: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
        height: Optional[int] = 720,
        width: Optional[int] = 1280,
        num_frames: Optional[int] = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "np",
        callback: Optional[callable] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[dict] = None,
    ):
        # 1. Check inputs
        self.check_inputs(prompt, first_frame, last_frame, height, width, callback_steps)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = 1

        device = self._execution_device

        # 3. Encode input prompt
        if isinstance(prompt, str):
            prompt = [prompt]
        if negative_prompt is None:
            negative_prompt = [""] * len(prompt)
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]

        self.text_encoder.model.to(device)
        context = self.text_encoder(prompt, device)
        context_null = self.text_encoder(negative_prompt, device)

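        # `context` (positive prompt) and `context_null` (negative prompt) are
        # the two text conditions consumed by classifier-free guidance in the
        # denoising loop of step 9.
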
        # 4. Preprocess images
        if not isinstance(first_frame, list):
            first_frame_list = [first_frame]
            last_frame_list = [last_frame]
        else:
            first_frame_list = first_frame
            last_frame_list = last_frame

        processed_first = []
        processed_last = []

        for f, l in zip(first_frame_list, last_frame_list):
            f_tensor = TF.to_tensor(f).sub_(0.5).div_(0.5).to(device)
            l_tensor = TF.to_tensor(l).sub_(0.5).div_(0.5).to(device)
            f_tensor = F.interpolate(f_tensor.unsqueeze(0), size=(height, width), mode='bicubic', align_corners=False).squeeze(0)
            l_tensor = F.interpolate(l_tensor.unsqueeze(0), size=(height, width), mode='bicubic', align_corners=False).squeeze(0)
            processed_first.append(f_tensor)
            processed_last.append(l_tensor)

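        # Caveat: the bicubic resize above maps every input directly to
        # (height, width) and therefore does not preserve the source aspect
        # ratio; callers should crop or pad beforehand if that matters.
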
        # 5. Encode images with CLIP
        clip_inputs = []
        for pf, pl in zip(processed_first, processed_last):
            clip_inputs.append(pf.unsqueeze(1))  # [3, 1, H, W]
            clip_inputs.append(pl.unsqueeze(1))

        self.image_encoder.model.to(device)
        clip_context = self.image_encoder.visual(clip_inputs)

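        # `clip_inputs` interleaves the first and last frame of each batch
        # item, so `clip_context` carries both endpoint embeddings; it is
        # passed to the transformer as `clip_fea` in step 9.
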
        # 6. Encode with VAE
        y_list = []
        for pf, pl in zip(processed_first, processed_last):
            pf_input = pf.unsqueeze(1)  # [3, 1, H, W]
            pl_input = pl.unsqueeze(1)
            zeros = torch.zeros(3, num_frames - 2, height, width, device=device)
            vae_input = torch.cat([pf_input, zeros, pl_input], dim=1)  # [3, F, H, W]
            y_list.append(vae_input)

        self.vae.model.to(device)
        y = self.vae.encode(y_list)  # Returns list of [C, T, H, W] latents

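        # The zero frames between the two endpoints are placeholders for the
        # `num_frames - 2` frames to be generated; the mask built in step 7
        # marks which temporal positions hold real conditioning frames (the
        # two endpoints) and which are to be synthesized.
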
        # 7. Create Mask and Concat
        lat_h = height // 8
        lat_w = width // 8

        msk = torch.ones(1, num_frames, lat_h, lat_w, device=device)
        msk[:, 1:-1] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
        ], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)  # [1, 4, T_lat, H_lat, W_lat]

        y_masked = []
        for latent in y:
            y_masked.append(torch.cat([msk[0], latent], dim=0))

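        # Worked shape example for num_frames = 81: the mask starts as
        # [1, 81, lat_h, lat_w]; repeating the first frame 4x gives
        # 4 + 80 = 84 = 4 * 21 temporal entries, which the view/transpose
        # folds into [1, 4, 21, lat_h, lat_w] so the mask lines up with the
        # VAE's 4x temporal compression. Its 4 mask channels are then
        # concatenated channel-wise with each latent in `y`.
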
        # 8. Prepare Latents (Noise)
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        noise_shape = (16, (num_frames - 1) // 4 + 1, lat_h, lat_w)
        # Use generator for reproducibility if provided
        latents = randn_tensor(shape=noise_shape, generator=generator, device=device, dtype=torch.float32)
        latents_list = [latents] * batch_size  # List of latents

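        # Note: `[latents] * batch_size` reuses a single noise tensor, so
        # every batch item starts from identical noise; per-item noise would
        # require one `randn_tensor` draw per prompt.
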
        # 9. Denoising Loop
        seq_len = ((num_frames - 1) // 4 + 1) * lat_h * lat_w // 4

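        # `seq_len` is the token count after patchification with
        # patch_size = [1, 2, 2]: T_lat * (lat_h / 2) * (lat_w / 2).
        # E.g. 81 frames at 720x1280: lat_h = 90, lat_w = 160, so
        # seq_len = 21 * 45 * 80 = 75,600 tokens.
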
        self.transformer.to(device)

        for i, t in enumerate(self.progress_bar(timesteps)):
            t_tensor = torch.stack([t] * batch_size).to(device)

            # Predict noise for conditional
            noise_pred_cond = self.transformer(
                latents_list, t=t_tensor, context=context, seq_len=seq_len, clip_fea=clip_context, y=y_masked
            )

            # Predict noise for unconditional
            if guidance_scale > 1.0:
                noise_pred_uncond = self.transformer(
                    latents_list, t=t_tensor, context=context_null, seq_len=seq_len, clip_fea=clip_context, y=y_masked
                )

                # Combine (CFG)
                noise_pred_list = []
                for cond, uncond in zip(noise_pred_cond, noise_pred_uncond):
                    noise_pred_list.append(uncond + guidance_scale * (cond - uncond))
            else:
                noise_pred_list = noise_pred_cond

            # Step
            new_latents_list = []
            for latent, noise_pred in zip(latents_list, noise_pred_list):
                # Scheduler step usually expects [1, C, T, H, W] or similar.
                # noise_pred is [C, T, H, W]
                step_output = self.scheduler.step(noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False)[0]
                new_latents_list.append(step_output.squeeze(0))
            latents_list = new_latents_list

        # 10. Decode
        # VAE decode expects list
        videos = self.vae.decode(latents_list)

        output_videos = []
        for vid in videos:
            # video tensor [3, F, H, W] value range [-1, 1]
            # Denormalize to [0, 1]
            vid = (vid * 0.5 + 0.5).clamp(0, 1)
            vid = vid.permute(1, 2, 3, 0).cpu().numpy()  # [F, H, W, C]
            output_videos.append(vid)

        # Only numpy output is implemented, so the same result is returned
        # regardless of `output_type`.
        return ImagePipelineOutput(images=output_videos)


if __name__ == "__main__":
    import argparse
    from wan.configs import WAN_CONFIGS
    from functools import partial
    import os

    # Re-import to ensure we are using the module definitions available in scope if needed
    from wan.modules.model import WanModel
    from wan.modules.t5 import T5EncoderModel
    from wan.modules.vae import WanVAE
    from wan.modules.clip import CLIPModel

    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_dir", type=str, required=True)
    parser.add_argument("--first_frame", type=str, required=True)
    parser.add_argument("--last_frame", type=str, required=True)
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--output", type=str, default="output.mp4")
    parser.add_argument("--device_id", type=int, default=0)
    args = parser.parse_args()

    config = WAN_CONFIGS['flf2v-14B']
    device = torch.device(f"cuda:{args.device_id}")

    print(f"Loading models from {args.checkpoint_dir}...")

    # 1. Text Encoder
    text_encoder = T5EncoderModel(
        text_len=config.text_len,
        dtype=config.t5_dtype,
        device=torch.device('cpu'),
        checkpoint_path=os.path.join(args.checkpoint_dir, config.t5_checkpoint),
        tokenizer_path=os.path.join(args.checkpoint_dir, config.t5_tokenizer),
    )

    # 2. VAE
    vae = WanVAE(
        vae_pth=os.path.join(args.checkpoint_dir, config.vae_checkpoint),
        device=device
    )

    # 3. CLIP
    image_encoder = CLIPModel(
        dtype=config.clip_dtype,
        device=device,
        checkpoint_path=os.path.join(args.checkpoint_dir, config.clip_checkpoint),
        tokenizer_path=os.path.join(args.checkpoint_dir, config.clip_tokenizer)
    )

    # 4. Transformer
    transformer = WanModel.from_pretrained(args.checkpoint_dir, model_type='flf2v')
    transformer.eval().requires_grad_(False)

    # 5. Scheduler
    scheduler = UniPCMultistepScheduler(
        prediction_type='flow_prediction',
        use_flow_sigmas=True,
        num_train_timesteps=1000,
        flow_shift=16.0
    )

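    # The scheduler is configured for flow matching; flow_shift=16.0 is
    # presumably chosen to match the shift Wan2.1 recommends for 720p
    # first-last-frame generation (smaller shifts are typically used at
    # lower resolutions such as 480p).
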
    pipe = WanFLF2VPipeline(
        vae=vae,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        transformer=transformer,
        scheduler=scheduler
    )
    pipe.to(device)

print(f"Loading images...")
|
||||||
|
first_img = PIL.Image.open(args.first_frame).convert("RGB")
|
||||||
|
last_img = PIL.Image.open(args.last_frame).convert("RGB")
|
||||||
|
|
||||||
|
print("Generating video...")
|
||||||
|
output = pipe(
|
||||||
|
prompt=args.prompt,
|
||||||
|
first_frame=first_img,
|
||||||
|
last_frame=last_img,
|
||||||
|
height=720,
|
||||||
|
width=1280,
|
||||||
|
num_frames=81,
|
||||||
|
guidance_scale=5.0
|
||||||
|
)
|
||||||
|
|
||||||
|
import imageio
|
||||||
|
video = output.images[0] # [F, H, W, C]
|
||||||
|
video = (video * 255).astype(np.uint8)
|
||||||
|
imageio.mimsave(args.output, video, fps=16)
|
||||||
|
print(f"Video saved to {args.output}")
|
||||||
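
    # Example invocation (illustrative paths; the checkpoint directory name
    # is an assumption, not shipped with this file):
    #   python wan/wan_flf2v_diffusers.py \
    #       --checkpoint_dir ./Wan2.1-FLF2V-14B-720P \
    #       --first_frame first.png --last_frame last.png \
    #       --prompt "a smooth transition between the two frames" \
    #       --output output.mp4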