# Mirror of https://github.com/Wan-Video/Wan2.1.git
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
from contextlib import contextmanager

import torch
import torch.cuda.amp as amp
from tqdm import tqdm

from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


class WanT2V:

    def __init__(
        self,
        model_hyperparam,
        checkpoint_dir,
        t5_cpu=False,
    ):
        r"""
        Initializes the Wan text-to-video generation model components (single-GPU only).

        Args:
            model_hyperparam (EasyDict):
                Object containing model hyperparameters initialized from config.py
            checkpoint_dir (`str`):
                Path to the directory containing model checkpoints
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to keep the T5 model on CPU during encoding.
        """
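        # Attributes read from `model_hyperparam` below: num_train_timesteps,
        # param_dtype, text_len, t5_dtype, t5_checkpoint, t5_tokenizer,
        # vae_stride, vae_checkpoint, patch_size, sample_neg_prompt.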
        self.device = torch.device("cuda:0")
        self.config = model_hyperparam
        self.t5_cpu = t5_cpu

        self.num_train_timesteps = model_hyperparam.num_train_timesteps
        self.param_dtype = model_hyperparam.param_dtype

        # Load T5 text encoder
        self.text_encoder = T5EncoderModel(
            text_len=model_hyperparam.text_len,
            dtype=model_hyperparam.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, model_hyperparam.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, model_hyperparam.t5_tokenizer),
            shard_fn=None  # No FSDP
        )

        # Load VAE
        self.vae_stride = model_hyperparam.vae_stride
        self.patch_size = model_hyperparam.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, model_hyperparam.vae_checkpoint),
            device=self.device
        )

        # Load the denoising diffusion model
        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)
        self.model.to(self.device)

        self.sample_neg_prompt = model_hyperparam.sample_neg_prompt

    def generate(
        self,
        input_prompt,
        size=(1280, 720),
        frame_num=81,
        shift=5.0,
        sample_solver='unipc',
        sampling_steps=50,
        guide_scale=5.0,
        n_prompt="",
        seed=-1,
        offload_model=True
    ):
        r"""
        Generates video frames from a text prompt using a diffusion process (single-GPU).

        Args:
            input_prompt (`str`): Text prompt for content generation
            size (tuple[`int`], *optional*, defaults to (1280, 720)): (width, height)
            frame_num (`int`, *optional*, defaults to 81): Number of frames (must be 4n+1)
            shift (`float`, *optional*, defaults to 5.0): Noise schedule shift
            sample_solver (`str`, *optional*, defaults to 'unipc'): Solver name
            sampling_steps (`int`, *optional*, defaults to 50): Diffusion steps
            guide_scale (`float`, *optional*, defaults to 5.0): CFG scale
            n_prompt (`str`, *optional*, defaults to ""): Negative prompt
            seed (`int`, *optional*, defaults to -1): Random seed
            offload_model (`bool`, *optional*, defaults to True): Offload models to CPU to save VRAM

        Returns:
            torch.Tensor: Video tensor of shape (C, N, H, W)
        """
        F = frame_num
        target_shape = (
            self.vae.model.z_dim,
            (F - 1) // self.vae_stride[0] + 1,
            size[1] // self.vae_stride[1],
            size[0] // self.vae_stride[2]
        )

        seq_len = math.ceil(
            (target_shape[2] * target_shape[3]) /
            (self.patch_size[1] * self.patch_size[2]) *
            target_shape[1]
        )
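        # Worked example (a sketch assuming the default T2V config values
        # vae_stride=(4, 8, 8) and patch_size=(1, 2, 2); check config.py for
        # your checkpoint): size=(1280, 720), frame_num=81 gives
        #   target_shape = (z_dim, (81-1)//4 + 1, 720//8, 1280//8)
        #                = (z_dim, 21, 90, 160)
        #   seq_len      = ceil(90*160 / (2*2) * 21) = 75600 tokens.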
        if n_prompt == "":
            n_prompt = self.sample_neg_prompt

        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)

        # Encode prompts
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            # T5 stays on CPU; only the resulting embeddings move to the GPU.
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        noise = [
            torch.randn(
                target_shape[0],
                target_shape[1],
                target_shape[2],
                target_shape[3],
                dtype=torch.float32,
                device=self.device,
                generator=seed_g
            )
        ]
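        # `noise` is a one-element list holding a single latent tensor of shape
        # `target_shape`; the sampling loop below keeps working with that
        # one-video "batch" (see `latents[0]` and `unsqueeze(0)`).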
        # No-op context manager (presumably a stand-in for the distributed
        # sync context used by the multi-GPU version).
        @contextmanager
        def noop():
            yield

        # Evaluation
        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), noop():

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False
                )
                sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False
                )
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas
                )
            else:
                raise NotImplementedError("Unsupported solver.")
            latents = noise
            arg_c = {'context': context, 'seq_len': seq_len}
            arg_null = {'context': context_null, 'seq_len': seq_len}

            for t in tqdm(timesteps):
                latent_model_input = latents
                timestep = torch.tensor([t], device=self.device)

                self.model.to(self.device)
                noise_pred_cond = self.model(latent_model_input, t=timestep, **arg_c)[0]
                noise_pred_uncond = self.model(latent_model_input, t=timestep, **arg_null)[0]
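                # Classifier-free guidance: combine the two predictions.
                # guide_scale=1.0 reproduces the conditional prediction; larger
                # values strengthen prompt adherence.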
                noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latents[0].unsqueeze(0),
                    return_dict=False,
                    generator=seed_g
                )[0]
                latents = [temp_x0.squeeze(0)]

            x0 = latents
            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()

            videos = self.vae.decode(x0)
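            # `videos` holds one (C, N, H, W) pixel-space tensor per latent in `x0`.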
        del noise, latents, sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()

        return videos[0]
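# Example usage (a minimal sketch; the config import path, checkpoint
# directory, and output writer below are assumptions, not part of this file):
#
#     from wan.configs import WAN_CONFIGS
#     from wan.text2video import WanT2V
#
#     pipe = WanT2V(
#         model_hyperparam=WAN_CONFIGS['t2v-14B'],
#         checkpoint_dir='./Wan2.1-T2V-14B',
#     )
#     video = pipe.generate(
#         'A cat surfing a wave at sunset',
#         size=(1280, 720),
#         frame_num=81,
#         seed=42,
#     )  # (C, N, H, W) float tensor
#     # Save it with your preferred video writer.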