mirror of
				https://github.com/Wan-Video/Wan2.1.git
				synced 2025-11-04 06:15:17 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			535 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			535 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import argparse
 | 
						|
from .constants import *
 | 
						|
import re
 | 
						|
from .modules.models import HUNYUAN_VIDEO_CONFIG
 | 
						|
 | 
						|
 | 
						|
def parse_args(namespace=None):
    """Build the full HunyuanVideo CLI parser, parse, validate and return args.

    Args:
        namespace: Optional namespace object to populate (passed through to
            ``ArgumentParser.parse_args``).

    Returns:
        The parsed and sanity-checked argument namespace.
    """
    parser = argparse.ArgumentParser(description="HunyuanVideo inference script")

    # Each helper attaches one argument group and returns the parser.
    for add_group_args in (
        add_network_args,
        add_extra_models_args,
        add_denoise_schedule_args,
        add_inference_args,
        add_parallel_args,
    ):
        parser = add_group_args(parser)

    return sanity_check_args(parser.parse_args(namespace=namespace))
 | 
						|
 | 
						|
 | 
						|
def add_network_args(parser: argparse.ArgumentParser):
    """Add network, UI/server and main-model arguments.

    Args:
        parser: Parser to extend.

    Returns:
        The same parser with the "HunyuanVideo network args" group added.
    """
    group = parser.add_argument_group(title="HunyuanVideo network args")

    group.add_argument(
        "--quantize-transformer",
        action="store_true",
        help="On the fly 'transformer' quantization"
    )

    group.add_argument(
        "--lora-dir-i2v",
        type=str,
        default="loras_i2v",
        help="Path to a directory that contains Loras for i2v"
    )

    group.add_argument(
        "--lora-dir",
        type=str,
        default="",
        help="Path to a directory that contains Loras"
    )

    group.add_argument(
        "--lora-preset",
        type=str,
        default="",
        help="Lora preset to preload"
    )

    group.add_argument(
        "--profile",
        # Fixed: was type=str while the default (-1) is an int, so a
        # CLI-supplied value arrived as a string and broke numeric
        # comparisons downstream.
        type=int,
        default=-1,
        help="Profile No"
    )

    group.add_argument(
        "--verbose",
        # Fixed: was type=str while the default (1) is an int.
        type=int,
        default=1,
        help="Verbose level"
    )

    group.add_argument(
        "--server-port",
        # Fixed: was type=str while the default (0) is an int.
        type=int,
        default=0,
        help="Server port"
    )

    group.add_argument(
        "--server-name",
        type=str,
        default="",
        help="Server name"
    )

    group.add_argument(
        "--open-browser",
        action="store_true",
        help="open browser"
    )

    group.add_argument(
        "--t2v",
        action="store_true",
        help="text to video mode"
    )

    group.add_argument(
        "--i2v",
        action="store_true",
        help="image to video mode"
    )

    group.add_argument(
        "--compile",
        action="store_true",
        help="Enable pytorch compilation"
    )

    group.add_argument(
        "--fast",
        action="store_true",
        help="use Fast HunyuanVideo model"
    )

    group.add_argument(
        "--fastest",
        action="store_true",
        help="activate the best config"
    )

    group.add_argument(
        "--attention",
        type=str,
        default="",
        help="attention mode"
    )

    group.add_argument(
        "--vae-config",
        type=str,
        default="",
        help="vae config mode"
    )

    # NOTE: the following options are attached directly to the parser (not
    # the group), preserving the original behavior.
    parser.add_argument(
        "--share",
        action="store_true",
        help="Create a shared URL to access webserver remotely"
    )

    parser.add_argument(
        "--lock-config",
        action="store_true",
        help="Prevent modifying the configuration from the web interface"
    )

    parser.add_argument(
        "--preload",
        type=str,
        default="0",
        help="Megabytes of the diffusion model to preload in VRAM"
    )

    parser.add_argument(
        "--multiple-images",
        action="store_true",
        help="Allow inputting multiple images with image to video"
    )

    # Main model
    group.add_argument(
        "--model",
        type=str,
        choices=list(HUNYUAN_VIDEO_CONFIG.keys()),
        default="HYVideo-T/2-cfgdistill",
    )
    group.add_argument(
        "--latent-channels",
        # Fixed: was type=str while the default (16) is an int; a
        # CLI-supplied value would then fail the integer channel-count
        # comparison in sanity_check_args.
        type=int,
        default=16,
        help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
        "it still needs to match the latent channels of the VAE model.",
    )
    group.add_argument(
        "--precision",
        type=str,
        default="bf16",
        choices=PRECISIONS,
        help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.",
    )

    # RoPE
    group.add_argument(
        "--rope-theta", type=int, default=256, help="Theta used in RoPE."
    )
    return parser
 | 
						|
 | 
						|
 | 
						|
def add_extra_models_args(parser: argparse.ArgumentParser):
    """Add arguments for the auxiliary models: VAE, text encoders, tokenizers.

    Args:
        parser: Parser to extend.

    Returns:
        The same parser with the extra-models group added.
    """
    group = parser.add_argument_group(
        # Fixed: title had an unbalanced closing parenthesis.
        title="Extra models args (including vae, text encoders and tokenizers)"
    )

    # - VAE
    group.add_argument(
        "--vae",
        type=str,
        default="884-16c-hy",
        choices=list(VAE_PATH),
        help="Name of the VAE model.",
    )
    group.add_argument(
        "--vae-precision",
        type=str,
        default="fp16",
        choices=PRECISIONS,
        help="Precision mode for the VAE model.",
    )
    group.add_argument(
        "--vae-tiling",
        action="store_true",
        help="Enable tiling for the VAE model to save GPU memory.",
    )
    # Tiling is on by default; --vae-tiling exists for explicitness.
    group.set_defaults(vae_tiling=True)

    group.add_argument(
        "--text-encoder",
        type=str,
        default="llm",
        choices=list(TEXT_ENCODER_PATH),
        help="Name of the text encoder model.",
    )
    group.add_argument(
        "--text-encoder-precision",
        type=str,
        default="fp16",
        choices=PRECISIONS,
        help="Precision mode for the text encoder model.",
    )
    group.add_argument(
        "--text-states-dim",
        type=int,
        default=4096,
        help="Dimension of the text encoder hidden states.",
    )
    group.add_argument(
        "--text-len", type=int, default=256, help="Maximum length of the text input."
    )
    group.add_argument(
        "--tokenizer",
        type=str,
        default="llm",
        choices=list(TOKENIZER_PATH),
        help="Name of the tokenizer model.",
    )
    group.add_argument(
        "--prompt-template",
        type=str,
        default="dit-llm-encode",
        choices=PROMPT_TEMPLATE,
        help="Image prompt template for the decoder-only text encoder model.",
    )
    group.add_argument(
        "--prompt-template-video",
        type=str,
        default="dit-llm-encode-video",
        choices=PROMPT_TEMPLATE,
        help="Video prompt template for the decoder-only text encoder model.",
    )
    group.add_argument(
        "--hidden-state-skip-layer",
        type=int,
        default=2,
        help="Skip layer for hidden states.",
    )
    group.add_argument(
        "--apply-final-norm",
        action="store_true",
        help="Apply final normalization to the used text encoder hidden states.",
    )

    # - CLIP (secondary text encoder)
    group.add_argument(
        "--text-encoder-2",
        type=str,
        default="clipL",
        choices=list(TEXT_ENCODER_PATH),
        help="Name of the second text encoder model.",
    )
    group.add_argument(
        "--text-encoder-precision-2",
        type=str,
        default="fp16",
        choices=PRECISIONS,
        help="Precision mode for the second text encoder model.",
    )
    group.add_argument(
        "--text-states-dim-2",
        type=int,
        default=768,
        help="Dimension of the second text encoder hidden states.",
    )
    group.add_argument(
        "--tokenizer-2",
        type=str,
        default="clipL",
        choices=list(TOKENIZER_PATH),
        help="Name of the second tokenizer model.",
    )
    group.add_argument(
        "--text-len-2",
        type=int,
        default=77,
        help="Maximum length of the second text input.",
    )

    return parser
 | 
						|
 | 
						|
 | 
						|
def add_denoise_schedule_args(parser: argparse.ArgumentParser):
    """Attach the denoise-schedule (flow matching) options to *parser*.

    Args:
        parser: Parser to extend.

    Returns:
        The same parser with the "Denoise schedule args" group added.
    """
    sched = parser.add_argument_group(title="Denoise schedule args")

    sched.add_argument(
        "--denoise-type",
        default="flow",
        type=str,
        help="Denoise type for noised inputs.",
    )

    # Flow Matching options.
    sched.add_argument(
        "--flow-shift",
        default=7.0,
        type=float,
        help="Shift factor for flow matching schedulers.",
    )
    sched.add_argument(
        "--flow-reverse",
        action="store_true",
        help="If reverse, learning/sampling from t=1 -> t=0.",
    )
    sched.add_argument(
        "--flow-solver",
        default="euler",
        type=str,
        help="Solver for flow matching.",
    )
    sched.add_argument(
        "--use-linear-quadratic-schedule",
        action="store_true",
        help=(
            "Use linear quadratic schedule for flow matching."
            "Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)"
        ),
    )
    sched.add_argument(
        "--linear-schedule-end",
        default=25,
        type=int,
        help="End step for linear quadratic schedule for flow matching.",
    )

    return sched and parser
 | 
						|
 | 
						|
 | 
						|
def add_inference_args(parser: argparse.ArgumentParser):
    """Add inference-time arguments: model paths, sampling and CFG settings.

    Args:
        parser: Parser to extend.

    Returns:
        The same parser with the "Inference args" group added.
    """
    group = parser.add_argument_group(title="Inference args")

    # ======================== Model loads ========================
    group.add_argument(
        "--model-base",
        type=str,
        default="ckpts",
        help="Root path of all the models, including t2v models and extra models.",
    )
    group.add_argument(
        "--dit-weight",
        type=str,
        default="ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
        help="Path to the HunyuanVideo model. If None, search the model in the args.model_root."
        "1. If it is a file, load the model directly."
        "2. If it is a directory, search the model in the directory. Support two types of models: "
        "1) named `pytorch_model_*.pt`"
        "2) named `*_model_states.pt`, where * can be `mp_rank_00`.",
    )
    group.add_argument(
        "--model-resolution",
        type=str,
        default="540p",
        choices=["540p", "720p"],
        # Fixed: help text was a copy-paste of the --model-base help.
        help="Resolution of the model to use.",
    )
    group.add_argument(
        "--load-key",
        type=str,
        default="module",
        help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
    )
    group.add_argument(
        "--use-cpu-offload",
        action="store_true",
        help="Use CPU offload for the model load.",
    )

    # ======================== Inference general setting ========================
    group.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size for inference and evaluation.",
    )
    group.add_argument(
        "--infer-steps",
        type=int,
        default=50,
        help="Number of denoising steps for inference.",
    )
    group.add_argument(
        "--disable-autocast",
        action="store_true",
        help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
    )
    group.add_argument(
        "--save-path",
        type=str,
        default="./results",
        help="Path to save the generated samples.",
    )
    group.add_argument(
        "--save-path-suffix",
        type=str,
        default="",
        help="Suffix for the directory of saved samples.",
    )
    group.add_argument(
        "--name-suffix",
        type=str,
        default="",
        help="Suffix for the names of saved samples.",
    )
    group.add_argument(
        "--num-videos",
        type=int,
        default=1,
        help="Number of videos to generate for each prompt.",
    )
    # ---sample size---
    group.add_argument(
        "--video-size",
        type=int,
        nargs="+",
        default=(720, 1280),
        help="Video size for training. If a single value is provided, it will be used for both height "
        "and width. If two values are provided, they will be used for height and width "
        "respectively.",
    )
    group.add_argument(
        "--video-length",
        type=int,
        default=129,
        help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1",
    )
    # --- prompt ---
    group.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt for sampling during evaluation.",
    )
    group.add_argument(
        "--seed-type",
        type=str,
        default="auto",
        choices=["file", "random", "fixed", "auto"],
        help="Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a "
        "random seed. If fixed, use the fixed seed given by `--seed`. If auto, `csv` will use the "
        "seed column if available, otherwise use the fixed `seed` value. `prompt` will use the "
        "fixed `seed` value.",
    )
    group.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")

    # Classifier-Free Guidance
    group.add_argument(
        "--neg-prompt", type=str, default=None, help="Negative prompt for sampling."
    )
    group.add_argument(
        "--cfg-scale", type=float, default=1.0, help="Classifier free guidance scale."
    )
    group.add_argument(
        "--embedded-cfg-scale",
        type=float,
        default=6.0,
        # Fixed: "Embeded" typo.
        help="Embedded classifier free guidance scale.",
    )

    group.add_argument(
        "--reproduce",
        action="store_true",
        help="Enable reproducibility by setting random seeds and deterministic algorithms.",
    )

    return parser
 | 
						|
 | 
						|
 | 
						|
def add_parallel_args(parser: argparse.ArgumentParser):
    """Add sequence-parallelism degree arguments (Ulysses / Ring attention).

    Args:
        parser: Parser to extend.

    Returns:
        The same parser with the "Parallel args" group added.
    """
    group = parser.add_argument_group(title="Parallel args")

    # ======================== Parallelism degrees ========================
    group.add_argument(
        "--ulysses-degree",
        type=int,
        default=1,
        help="Ulysses degree.",
    )
    group.add_argument(
        "--ring-degree",
        type=int,
        default=1,
        # Fixed: help text was a copy-paste of the --ulysses-degree help.
        help="Ring degree.",
    )

    return parser
 | 
						|
 | 
						|
 | 
						|
def sanity_check_args(args):
    """Validate and normalize parsed arguments.

    Checks the VAE name format and reconciles ``latent_channels`` with the
    channel count encoded in the VAE name.

    Args:
        args: Parsed argument namespace; mutated in place.

    Returns:
        The validated namespace.

    Raises:
        ValueError: If the VAE name is malformed or the latent channel count
            disagrees with the VAE.
    """
    # VAE names look like "884-16c-hy": <stride>-<channels>c-<variant>.
    vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
    if not re.match(vae_pattern, args.vae):
        raise ValueError(
            f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
        )
    vae_channels = int(args.vae.split("-")[1][:-1])
    if args.latent_channels is None:
        args.latent_channels = vae_channels
    else:
        # Robustness fix: coerce to int in case the value was parsed as a
        # string, so "16" doesn't spuriously fail the comparison below.
        args.latent_channels = int(args.latent_channels)
    if vae_channels != args.latent_channels:
        raise ValueError(
            f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
        )
    return args
 |