commit 98f01b83a7
ZVAXEROWS 2026-01-04 00:09:23 +02:00 committed by GitHub

wan/wan_flf2v_diffusers.py (new file, 344 lines)
@@ -0,0 +1,344 @@
import inspect
from typing import Callable, List, Optional, Tuple, Union

import torch
import numpy as np
import PIL.Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F

from diffusers import DiffusionPipeline, UniPCMultistepScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput

# Import Wan modules.
# Using absolute imports ensures this works when installed as a package.
from wan.modules.model import WanModel
from wan.modules.t5 import T5EncoderModel
from wan.modules.vae import WanVAE
from wan.modules.clip import CLIPModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class WanFLF2VPipeline(DiffusionPipeline):
    r"""
    Pipeline for First-Last-Frame-to-Video generation using Wan2.1.
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"

    def __init__(
        self,
        vae: WanVAE,
        text_encoder: T5EncoderModel,
        image_encoder: CLIPModel,
        transformer: WanModel,
        scheduler: UniPCMultistepScheduler,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            image_encoder=image_encoder,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_stride = [4, 8, 8]  # hardcoded based on config
        self.patch_size = [1, 2, 2]  # hardcoded based on config
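        # Illustrative geometry (not enforced here): with the defaults used by
        # __call__ (height=720, width=1280, num_frames=81), the VAE stride
        # [4, 8, 8] maps the video to latents with
        #   T_lat = (81 - 1) // 4 + 1 = 21,  H_lat = 720 // 8 = 90,  W_lat = 1280 // 8 = 160,
        # and the transformer patch size [1, 2, 2] then yields
        #   seq_len = 21 * (90 // 2) * (160 // 2) = 75600 tokens,
        # which matches the formula __call__ uses when it computes `seq_len`.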
    def check_inputs(
        self,
        prompt,
        first_frame,
        last_frame,
        height,
        width,
        callback_steps,
    ):
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` must be divisible by 16 but are {height} and {width}.")
        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` must be an integer > 0 if provided, but is {callback_steps}."
            )
    def prepare_latents(
        self,
        batch_size,
        num_channels,
        num_frames,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels,
            (num_frames - 1) // 4 + 1,
            height // 8,
            width // 8,
        )
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
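    # Note: the diffusers-style helper above appears to be kept for API parity;
    # __call__ below builds its initial noise directly as a list of per-sample
    # [C, T_lat, H_lat, W_lat] tensors, since the Wan transformer and VAE
    # wrappers used here operate on lists rather than batched 5-D tensors.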
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        first_frame: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
        last_frame: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
        height: Optional[int] = 720,
        width: Optional[int] = 1280,
        num_frames: Optional[int] = 81,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "np",
        callback: Optional[Callable] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[dict] = None,
    ):
        # 1. Check inputs
        self.check_inputs(prompt, first_frame, last_frame, height, width, callback_steps)
        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = 1

        device = self._execution_device

        # 3. Encode input prompt
        if isinstance(prompt, str):
            prompt = [prompt]
        if negative_prompt is None:
            negative_prompt = [""] * len(prompt)
        elif isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]

        self.text_encoder.model.to(device)
        context = self.text_encoder(prompt, device)
        context_null = self.text_encoder(negative_prompt, device)
        # 4. Preprocess images
        if not isinstance(first_frame, list):
            first_frame_list = [first_frame]
            last_frame_list = [last_frame]
        else:
            first_frame_list = first_frame
            last_frame_list = last_frame

        processed_first = []
        processed_last = []
        for f_img, l_img in zip(first_frame_list, last_frame_list):
            # PIL -> float tensor in [0, 1], then normalize to [-1, 1]
            f_tensor = TF.to_tensor(f_img).sub_(0.5).div_(0.5).to(device)
            l_tensor = TF.to_tensor(l_img).sub_(0.5).div_(0.5).to(device)
            # Resize to the target resolution
            f_tensor = F.interpolate(f_tensor.unsqueeze(0), size=(height, width), mode='bicubic', align_corners=False).squeeze(0)
            l_tensor = F.interpolate(l_tensor.unsqueeze(0), size=(height, width), mode='bicubic', align_corners=False).squeeze(0)
            processed_first.append(f_tensor)
            processed_last.append(l_tensor)
        # 5. Encode images with CLIP
        clip_inputs = []
        for pf, pl in zip(processed_first, processed_last):
            clip_inputs.append(pf.unsqueeze(1))  # [3, 1, H, W]
            clip_inputs.append(pl.unsqueeze(1))
        self.image_encoder.model.to(device)
        clip_context = self.image_encoder.visual(clip_inputs)
        # 6. Encode with VAE
        y_list = []
        for pf, pl in zip(processed_first, processed_last):
            pf_input = pf.unsqueeze(1)  # [3, 1, H, W]
            pl_input = pl.unsqueeze(1)
            # Zero placeholder for the intermediate frames that will be generated
            zeros = torch.zeros(3, num_frames - 2, height, width, device=device)
            vae_input = torch.cat([pf_input, zeros, pl_input], dim=1)  # [3, F, H, W]
            y_list.append(vae_input)
        self.vae.model.to(device)
        y = self.vae.encode(y_list)  # Returns list of [C, T, H, W] latents
        # 7. Create Mask and Concat
        lat_h = height // 8
        lat_w = width // 8
        msk = torch.ones(1, num_frames, lat_h, lat_w, device=device)
        msk[:, 1:-1] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
        ], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)  # [1, 4, T_lat, H_lat, W_lat]
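        # What the mask encodes (a reading of the code above, mirroring Wan's
        # first/last-frame conditioning): 1 marks pixel frames whose content is
        # provided (the first and last frame), 0 marks frames to be generated.
        # The first frame's mask row is repeated 4x so that, after grouping
        # every 4 pixel frames into one latent frame (the VAE temporal stride),
        # the mask lines up with the [4, T_lat, H_lat, W_lat] layout that is
        # concatenated onto each VAE latent below.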
        y_masked = []
        for latent in y:
            # [4 + C, T_lat, H_lat, W_lat] conditioning tensor per sample
            y_masked.append(torch.cat([msk[0], latent], dim=0))
        # 8. Prepare Latents (Noise)
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        noise_shape = (16, (num_frames - 1) // 4 + 1, lat_h, lat_w)
        # Use the generator for reproducibility if provided; draw independent
        # noise for each element of the batch.
        latents_list = [
            randn_tensor(shape=noise_shape, generator=generator, device=device, dtype=torch.float32)
            for _ in range(batch_size)
        ]

        # 9. Denoising Loop
        # Number of patch tokens: T_lat * (lat_h // 2) * (lat_w // 2) for patch_size [1, 2, 2]
        seq_len = ((num_frames - 1) // 4 + 1) * lat_h * lat_w // 4
        self.transformer.to(device)
        for i, t in enumerate(self.progress_bar(timesteps)):
            t_tensor = torch.stack([t] * batch_size).to(device)

            # Predict noise for conditional
            noise_pred_cond = self.transformer(
                latents_list, t=t_tensor, context=context, seq_len=seq_len, clip_fea=clip_context, y=y_masked
            )

            # Predict noise for unconditional
            if guidance_scale > 1.0:
                noise_pred_uncond = self.transformer(
                    latents_list, t=t_tensor, context=context_null, seq_len=seq_len, clip_fea=clip_context, y=y_masked
                )
                # Combine (CFG)
                noise_pred_list = []
                for cond, uncond in zip(noise_pred_cond, noise_pred_uncond):
                    noise_pred_list.append(uncond + guidance_scale * (cond - uncond))
            else:
                noise_pred_list = noise_pred_cond
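            # Classifier-free guidance as applied above:
            #   pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
            # With guidance_scale = 1.0 this reduces to the conditional
            # prediction, which is why the unconditional pass is skipped
            # whenever guidance_scale <= 1.0.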
            # Step
            new_latents_list = []
            for latent, noise_pred in zip(latents_list, noise_pred_list):
                # The scheduler step expects a leading batch dim; noise_pred is [C, T, H, W]
                step_output = self.scheduler.step(noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False)[0]
                new_latents_list.append(step_output.squeeze(0))
            latents_list = new_latents_list

            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents_list)
        # 10. Decode
        # VAE decode expects a list of per-sample latents
        videos = self.vae.decode(latents_list)

        output_videos = []
        for vid in videos:
            # video tensor [3, F, H, W] with values in [-1, 1]
            # Denormalize to [0, 1]
            vid = (vid * 0.5 + 0.5).clamp(0, 1)
            vid = vid.permute(1, 2, 3, 0).cpu().numpy()  # [F, H, W, C]
            output_videos.append(vid)

        # Frames are returned as numpy arrays regardless of `output_type`.
        return ImagePipelineOutput(images=output_videos)
if __name__ == "__main__":
    import argparse
    import os
    from functools import partial

    import imageio

    from wan.configs import WAN_CONFIGS
    # Re-import Wan modules so they are in scope when the script is run directly
    from wan.modules.model import WanModel
    from wan.modules.t5 import T5EncoderModel
    from wan.modules.vae import WanVAE
    from wan.modules.clip import CLIPModel

    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_dir", type=str, required=True)
    parser.add_argument("--first_frame", type=str, required=True)
    parser.add_argument("--last_frame", type=str, required=True)
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--output", type=str, default="output.mp4")
    parser.add_argument("--device_id", type=int, default=0)
    args = parser.parse_args()

    config = WAN_CONFIGS['flf2v-14B']
    device = torch.device(f"cuda:{args.device_id}")

    print(f"Loading models from {args.checkpoint_dir}...")
    # 1. Text Encoder
    text_encoder = T5EncoderModel(
        text_len=config.text_len,
        dtype=config.t5_dtype,
        device=torch.device('cpu'),
        checkpoint_path=os.path.join(args.checkpoint_dir, config.t5_checkpoint),
        tokenizer_path=os.path.join(args.checkpoint_dir, config.t5_tokenizer),
    )

    # 2. VAE
    vae = WanVAE(
        vae_pth=os.path.join(args.checkpoint_dir, config.vae_checkpoint),
        device=device
    )

    # 3. CLIP
    image_encoder = CLIPModel(
        dtype=config.clip_dtype,
        device=device,
        checkpoint_path=os.path.join(args.checkpoint_dir, config.clip_checkpoint),
        tokenizer_path=os.path.join(args.checkpoint_dir, config.clip_tokenizer)
    )

    # 4. Transformer
    transformer = WanModel.from_pretrained(args.checkpoint_dir, model_type='flf2v')
    transformer.eval().requires_grad_(False)
    # 5. Scheduler
    scheduler = UniPCMultistepScheduler(
        prediction_type='flow_prediction',
        use_flow_sigmas=True,
        num_train_timesteps=1000,
        flow_shift=16.0
    )
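    # Note: `flow_prediction` / `use_flow_sigmas` configure UniPC for the
    # flow-matching formulation Wan2.1 is trained with, and `flow_shift` skews
    # the sigma schedule toward higher noise levels, which is generally
    # recommended for high-resolution (720p) sampling. The value 16.0 is taken
    # as-is from this script and is not tuned here.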
    pipe = WanFLF2VPipeline(
        vae=vae,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        transformer=transformer,
        scheduler=scheduler
    )
    pipe.to(device)
print(f"Loading images...")
first_img = PIL.Image.open(args.first_frame).convert("RGB")
last_img = PIL.Image.open(args.last_frame).convert("RGB")
print("Generating video...")
output = pipe(
prompt=args.prompt,
first_frame=first_img,
last_frame=last_img,
height=720,
width=1280,
num_frames=81,
guidance_scale=5.0
)
import imageio
video = output.images[0] # [F, H, W, C]
video = (video * 255).astype(np.uint8)
imageio.mimsave(args.output, video, fps=16)
print(f"Video saved to {args.output}")