diff --git a/wan/wan_flf2v_diffusers.py b/wan/wan_flf2v_diffusers.py
new file mode 100644
index 0000000..a855ea6
--- /dev/null
+++ b/wan/wan_flf2v_diffusers.py
@@ -0,0 +1,344 @@
+import inspect
+from typing import List, Optional, Union, Tuple
+
+import torch
+import numpy as np
+import PIL.Image
+from diffusers import DiffusionPipeline, UniPCMultistepScheduler
+from diffusers.utils import logging
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
+
+# Import Wan modules
+# Using absolute imports ensures this works when installed as a package
+from wan.modules.model import WanModel
+from wan.modules.t5 import T5EncoderModel
+from wan.modules.vae import WanVAE
+from wan.modules.clip import CLIPModel
+import torchvision.transforms.functional as TF
+import torch.nn.functional as F
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+class WanFLF2VPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for First-Last-Frame-to-Video generation using Wan2.1.
+    """
+    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+
+    def __init__(
+        self,
+        vae: WanVAE,
+        text_encoder: T5EncoderModel,
+        image_encoder: CLIPModel,
+        transformer: WanModel,
+        scheduler: UniPCMultistepScheduler,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            image_encoder=image_encoder,
+            transformer=transformer,
+            scheduler=scheduler,
+        )
+        self.vae_stride = [4, 8, 8]  # hardcoded based on config
+        self.patch_size = [1, 2, 2]  # hardcoded based on config
+
+    def check_inputs(
+        self,
+        prompt,
+        first_frame,
+        last_frame,
+        height,
+        width,
+        callback_steps,
+    ):
+        if height % 16 != 0 or width % 16 != 0:
+            raise ValueError(f"`height` and `width` must be divisible by 16 but are {height} and {width}.")
+
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+            raise ValueError(
+                f"`callback_steps` must be an integer > 0 if provided, but is {callback_steps}."
+            )
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels,
+        num_frames,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        shape = (
+            batch_size,
+            num_channels,
+            (num_frames - 1) // 4 + 1,
+            height // 8,
+            width // 8,
+        )
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        first_frame: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
+        last_frame: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
+        height: Optional[int] = 720,
+        width: Optional[int] = 1280,
+        num_frames: Optional[int] = 81,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 5.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        output_type: Optional[str] = "np",
+        callback: Optional[callable] = None,
+        callback_steps: Optional[int] = 1,
+        cross_attention_kwargs: Optional[dict] = None,
+    ):
+        # 1. Check inputs
+        self.check_inputs(prompt, first_frame, last_frame, height, width, callback_steps)
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = 1
+
+        device = self._execution_device
+
+        # 3. Encode input prompt
+        if isinstance(prompt, str):
+            prompt = [prompt]
+        if negative_prompt is None:
+            negative_prompt = [""] * len(prompt)
+        elif isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt]
+
+        self.text_encoder.model.to(device)
+        context = self.text_encoder(prompt, device)
+        context_null = self.text_encoder(negative_prompt, device)
+
+        # 4. Preprocess images
+        if not isinstance(first_frame, list):
+            first_frame_list = [first_frame]
+            last_frame_list = [last_frame]
+        else:
+            first_frame_list = first_frame
+            last_frame_list = last_frame
+
+        processed_first = []
+        processed_last = []
+
+        for f, l in zip(first_frame_list, last_frame_list):
+            f_tensor = TF.to_tensor(f).sub_(0.5).div_(0.5).to(device)
+            l_tensor = TF.to_tensor(l).sub_(0.5).div_(0.5).to(device)
+            f_tensor = F.interpolate(f_tensor.unsqueeze(0), size=(height, width), mode='bicubic', align_corners=False).squeeze(0)
+            l_tensor = F.interpolate(l_tensor.unsqueeze(0), size=(height, width), mode='bicubic', align_corners=False).squeeze(0)
+            processed_first.append(f_tensor)
+            processed_last.append(l_tensor)
+
+        # 5. Encode images with CLIP
+        clip_inputs = []
+        for pf, pl in zip(processed_first, processed_last):
+            clip_inputs.append(pf.unsqueeze(1))  # [3, 1, H, W]
+            clip_inputs.append(pl.unsqueeze(1))
+
+        self.image_encoder.model.to(device)
+        clip_context = self.image_encoder.visual(clip_inputs)
+
+        # 6. Encode with VAE
+        y_list = []
+        for pf, pl in zip(processed_first, processed_last):
+            pf_input = pf.unsqueeze(1)  # [3, 1, H, W]
+            pl_input = pl.unsqueeze(1)
+            zeros = torch.zeros(3, num_frames - 2, height, width, device=device)
+            vae_input = torch.cat([pf_input, zeros, pl_input], dim=1)  # [3, F, H, W]
+            y_list.append(vae_input)
+
+        self.vae.model.to(device)
+        y = self.vae.encode(y_list)  # Returns list of [C, T, H, W] latents
+
+        # 7. Create Mask and Concat
+        lat_h = height // 8
+        lat_w = width // 8
+
+        msk = torch.ones(1, num_frames, lat_h, lat_w, device=device)
+        msk[:, 1:-1] = 0
+        msk = torch.concat([
+            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
+        ], dim=1)
+        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+        msk = msk.transpose(1, 2)  # [1, 4, T_lat, H_lat, W_lat]
+
+        y_masked = []
+        for latent in y:
+            y_masked.append(torch.cat([msk[0], latent], dim=0))
+
+        # 8. Prepare Latents (Noise)
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        noise_shape = (16, (num_frames - 1) // 4 + 1, lat_h, lat_w)
+        # Use generator for reproducibility if provided
+        latents = randn_tensor(shape=noise_shape, generator=generator, device=device, dtype=torch.float32)
+        latents_list = [latents] * batch_size  # List of latents
+
+        # 9. Denoising Loop
+        seq_len = ((num_frames - 1) // 4 + 1) * lat_h * lat_w // 4
+
+        self.transformer.to(device)
+
+        for i, t in enumerate(self.progress_bar(timesteps)):
+            t_tensor = torch.stack([t] * batch_size).to(device)
+
+            # Predict noise for conditional
+            noise_pred_cond = self.transformer(
+                latents_list, t=t_tensor, context=context, seq_len=seq_len, clip_fea=clip_context, y=y_masked
+            )
+
+            # Predict noise for unconditional
+            if guidance_scale > 1.0:
+                noise_pred_uncond = self.transformer(
+                    latents_list, t=t_tensor, context=context_null, seq_len=seq_len, clip_fea=clip_context, y=y_masked
+                )
+
+                # Combine (CFG)
+                noise_pred_list = []
+                for cond, uncond in zip(noise_pred_cond, noise_pred_uncond):
+                    noise_pred_list.append(uncond + guidance_scale * (cond - uncond))
+            else:
+                noise_pred_list = noise_pred_cond
+
+            # Step
+            new_latents_list = []
+            for latent, noise_pred in zip(latents_list, noise_pred_list):
+                # Scheduler step usually expects [1, C, T, H, W] or similar.
+                # noise_pred is [C, T, H, W]
+                step_output = self.scheduler.step(noise_pred.unsqueeze(0), t, latent.unsqueeze(0), return_dict=False)[0]
+                new_latents_list.append(step_output.squeeze(0))
+            latents_list = new_latents_list
+
+        # 10. Decode
+        # VAE decode expects list
+        videos = self.vae.decode(latents_list)
+
+        output_videos = []
+        for vid in videos:
+            # video tensor [3, F, H, W] value range [-1, 1]
+            # Denormalize to [0, 1]
+            vid = (vid * 0.5 + 0.5).clamp(0, 1)
+            vid = vid.permute(1, 2, 3, 0).cpu().numpy()  # [F, H, W, C]
+            output_videos.append(vid)
+
+        # Frames are returned as NumPy float arrays in [0, 1] with shape [F, H, W, C];
+        # `output_type` values other than "np" are not implemented.
+        return ImagePipelineOutput(images=output_videos)
+
+if __name__ == "__main__":
+    import argparse
+    from wan.configs import WAN_CONFIGS
+    from functools import partial
+    import os
+
+    # Re-import Wan modules so the script section is self-contained
+    # (these are also imported at module level above)
+    from wan.modules.model import WanModel
+    from wan.modules.t5 import T5EncoderModel
+    from wan.modules.vae import WanVAE
+    from wan.modules.clip import CLIPModel
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--checkpoint_dir", type=str, required=True)
+    parser.add_argument("--first_frame", type=str, required=True)
+    parser.add_argument("--last_frame", type=str, required=True)
+    parser.add_argument("--prompt", type=str, required=True)
+    parser.add_argument("--output", type=str, default="output.mp4")
+    parser.add_argument("--device_id", type=int, default=0)
+    args = parser.parse_args()
+
+    config = WAN_CONFIGS['flf2v-14B']
+    device = torch.device(f"cuda:{args.device_id}")
+
+    print(f"Loading models from {args.checkpoint_dir}...")
+
+    # 1. Text Encoder
+    text_encoder = T5EncoderModel(
+        text_len=config.text_len,
+        dtype=config.t5_dtype,
+        device=torch.device('cpu'),
+        checkpoint_path=os.path.join(args.checkpoint_dir, config.t5_checkpoint),
+        tokenizer_path=os.path.join(args.checkpoint_dir, config.t5_tokenizer),
+    )
+
+    # 2. VAE
+    vae = WanVAE(
+        vae_pth=os.path.join(args.checkpoint_dir, config.vae_checkpoint),
+        device=device
+    )
+
+    # 3. CLIP
+    image_encoder = CLIPModel(
+        dtype=config.clip_dtype,
+        device=device,
+        checkpoint_path=os.path.join(args.checkpoint_dir, config.clip_checkpoint),
+        tokenizer_path=os.path.join(args.checkpoint_dir, config.clip_tokenizer)
+    )
+
+    # 4. Transformer
+    transformer = WanModel.from_pretrained(args.checkpoint_dir, model_type='flf2v')
+    transformer.eval().requires_grad_(False)
+
+    # 5. Scheduler
+    scheduler = UniPCMultistepScheduler(
+        prediction_type='flow_prediction',
+        use_flow_sigmas=True,
+        num_train_timesteps=1000,
+        flow_shift=16.0
+    )
+
+    pipe = WanFLF2VPipeline(
+        vae=vae,
+        text_encoder=text_encoder,
+        image_encoder=image_encoder,
+        transformer=transformer,
+        scheduler=scheduler
+    )
+    pipe.to(device)
+
+    print("Loading images...")
+    first_img = PIL.Image.open(args.first_frame).convert("RGB")
+    last_img = PIL.Image.open(args.last_frame).convert("RGB")
+
+    print("Generating video...")
+    output = pipe(
+        prompt=args.prompt,
+        first_frame=first_img,
+        last_frame=last_img,
+        height=720,
+        width=1280,
+        num_frames=81,
+        guidance_scale=5.0
+    )
+
+    import imageio
+    video = output.images[0]  # [F, H, W, C]
+    video = (video * 255).astype(np.uint8)
+    imageio.mimsave(args.output, video, fps=16)
+    print(f"Video saved to {args.output}")
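+
+# Example invocation (illustrative only; the checkpoint directory and image paths below are
+# placeholders and must be replaced with a local Wan2.1 FLF2V-14B checkpoint and your own frames):
+#
+#   python wan/wan_flf2v_diffusers.py \
+#       --checkpoint_dir ./Wan2.1-FLF2V-14B-720P \
+#       --first_frame first.png \
+#       --last_frame last.png \
+#       --prompt "a smooth dolly shot across a mountain lake at sunrise" \
+#       --output output.mp4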