# Wan2.1/models/wan/wan_handler.py

import torch


def test_class_i2v(base_model_type):
    return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk"]  # "hunyuan_i2v",
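
# Illustrative: test_class_i2v("i2v") and test_class_i2v("multitalk") return True,
# while test_class_i2v("t2v") returns False.
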
class family_handler():
    @staticmethod
    def get_wan_text_encoder_filename(text_encoder_quantization):
        text_encoder_filename = "ckpts/umt5-xxl/models_t5_umt5-xxl-enc-bf16.safetensors"
        if text_encoder_quantization == "int8":
            text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_int8")
        return text_encoder_filename
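    # Illustrative: get_wan_text_encoder_filename("int8") returns
    # "ckpts/umt5-xxl/models_t5_umt5-xxl-enc-quanto_int8.safetensors".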

    @staticmethod
    def query_modules_files():
        return {
            "vace_14B": ["ckpts/wan2.1_Vace_14B_module_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mfp16_int8.safetensors"],
            "vace_1.3B": ["ckpts/wan2.1_Vace_1_3B_module.safetensors"],
            "fantasy": ["ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"],
            "multitalk": ["ckpts/wan2.1_multitalk_14B_mbf16.safetensors", "ckpts/wan2.1_multitalk_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_multitalk_14B_quanto_mfp16_int8.safetensors"],
        }

    @staticmethod
    def query_model_def(base_model_type, model_def):
        extra_model_def = {}
        if "URLs2" in model_def:
            extra_model_def["no_steps_skipping"] = True
        extra_model_def["i2v_class"] = test_class_i2v(base_model_type)
        vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B"]
        extra_model_def["vace_class"] = vace_class

        # Output frame rate varies by model variant.
        if base_model_type in ["multitalk", "vace_multitalk_14B"]:
            fps = 25
        elif base_model_type in ["fantasy"]:
            fps = 23
        elif base_model_type in ["ti2v_2_2"]:
            fps = 24
        else:
            fps = 16
        extra_model_def["fps"] = fps

        if vace_class:
            frames_minimum, frames_steps = 17, 4
        else:
            frames_minimum, frames_steps = 5, 4
        extra_model_def.update({
            "frames_minimum": frames_minimum,
            "frames_steps": frames_steps,
            "sliding_window": base_model_type in ["multitalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class,  # "ti2v_2_2",
            "guidance_max_phases": 2,
            "skip_layer_guidance": True,
            "cfg_zero": True,
            "cfg_star": True,
            "adaptive_projected_guidance": True,
            "skip_steps_cache": True,
        })
        return extra_model_def
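    # Illustrative: query_model_def("multitalk", {}) yields fps=25, i2v_class=True,
    # frames_minimum=5, frames_steps=4 and sliding_window=True, plus the guidance
    # flags set above.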

    @staticmethod
    def query_supported_types():
        return ["multitalk", "fantasy", "vace_14B", "vace_multitalk_14B",
                "t2v_1.3B", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
                "recam_1.3B",
                "i2v", "i2v_2_2", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]

    @staticmethod
    def query_family_maps():
        # Maps each alias onto the canonical type it is equivalent to.
        models_eqv_map = {
            "flf2v_720p": "i2v",
            "t2v_1.3B": "t2v",
        }
        # Lists, per type, the types whose settings are compatible with it.
        models_comp_map = {
            "vace_14B": ["vace_multitalk_14B"],
            "t2v": ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B", "phantom_14B"],
            "i2v": ["fantasy", "multitalk", "flf2v_720p"],
            "fantasy": ["multitalk"],
        }
        return models_eqv_map, models_comp_map

    @staticmethod
    def query_model_family():
        return "wan"

    @staticmethod
    def query_family_infos():
        return {"wan": (0, "Wan2.1"), "wan2_2": (1, "Wan2.2")}

    @staticmethod
    def get_vae_block_size(base_model_type):
        return 32 if base_model_type == "ti2v_2_2" else 16
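    # A caller would presumably snap requested dimensions to this block size
    # before VAE encoding, e.g. width = (width // block) * block (an assumption
    # about the surrounding pipeline, shown only for illustration).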

    @staticmethod
    def get_rgb_factors(model_type):
        from shared.RGB_factors import get_rgb_factors
        if model_type == "ti2v_2_2":
            return None, None
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan")
        return latent_rgb_factors, latent_rgb_factors_bias

    @staticmethod
    def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
        text_encoder_filename = family_handler.get_wan_text_encoder_filename(text_encoder_quantization)
        vae_filepath = "Wan2.2_VAE.safetensors" if base_model_type == "ti2v_2_2" else "Wan2.1_VAE.safetensors"
        return {
            "repoId": "DeepBeepMeep/Wan2.1",
            "sourceFolderList": ["xlm-roberta-large", "umt5-xxl", ""],
            "fileList": [
                ["models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"],
                ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename),
                [vae_filepath, "fantasy_proj_model.safetensors"] + computeList(model_filename),
            ],
        }
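    # Each entry of sourceFolderList pairs with the file list at the same index,
    # so a downloader would fetch e.g. "umt5-xxl/tokenizer.json" from the
    # DeepBeepMeep/Wan2.1 repo. computeList is caller-supplied and presumably
    # expands a checkpoint path into the file names to download for it.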

    @staticmethod
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer=False, text_encoder_quantization=None, dtype=torch.bfloat16, VAE_dtype=torch.float32, mixed_precision_transformer=False, save_quantized=False):
        from .configs import WAN_CONFIGS
        if test_class_i2v(base_model_type):
            cfg = WAN_CONFIGS['i2v-14B']
        else:
            cfg = WAN_CONFIGS['t2v-14B']
            # cfg = WAN_CONFIGS['t2v-1.3B']
        from . import WanAny2V
        wan_model = WanAny2V(
            config=cfg,
            checkpoint_dir="ckpts",
            model_filename=model_filename,
            model_type=model_type,
            model_def=model_def,
            base_model_type=base_model_type,
            text_encoder_filename=family_handler.get_wan_text_encoder_filename(text_encoder_quantization),
            quantizeTransformer=quantizeTransformer,
            dtype=dtype,
            VAE_dtype=VAE_dtype,
            mixed_precision_transformer=mixed_precision_transformer,
            save_quantized=save_quantized,
        )
        pipe = {"transformer": wan_model.model, "text_encoder": wan_model.text_encoder.model, "vae": wan_model.vae.model}
        if hasattr(wan_model, "model2") and wan_model.model2 is not None:
            pipe["transformer2"] = wan_model.model2
        if hasattr(wan_model, "clip"):
            pipe["text_encoder_2"] = wan_model.clip.model
        return wan_model, pipe
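    # Hypothetical caller sketch (the checkpoint filename below is an assumed
    # example, not something this handler guarantees to exist):
    #   wan_model, pipe = family_handler.load_model(
    #       model_filename="ckpts/wan2.1_text2video_14B_bf16.safetensors",
    #       model_type="t2v", base_model_type="t2v", model_def={},
    #   )
    #   transformer = pipe["transformer"]  # hand off to offload/quantization hooks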

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        if base_model_type in ["fantasy"]:
            ui_defaults.update({
                "audio_guidance_scale": 5.0,
                "sliding_window_size": 1,
            })
        elif base_model_type in ["multitalk"]:
            ui_defaults.update({
                "guidance_scale": 5.0,
                "flow_shift": 7,  # 11 for 720p
                "audio_guidance_scale": 4,
                "sliding_window_discard_last_frames": 4,
                "sample_solver": "euler",
                "adaptive_switch": 1,
            })
        elif base_model_type in ["phantom_1.3B", "phantom_14B"]:
            ui_defaults.update({
                "guidance_scale": 7.5,
                "flow_shift": 5,
                "remove_background_images_ref": 1,
                "video_prompt_type": "I",
                # "resolution": "1280x720"
            })
        elif base_model_type in ["vace_14B", "vace_multitalk_14B"]:
            ui_defaults.update({
                "sliding_window_discard_last_frames": 0,
            })
        elif base_model_type in ["ti2v_2_2"]:
            ui_defaults.update({
                "image_prompt_type": "T",
            })
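

# Minimal smoke test of the pure helpers; a sketch assuming only that torch is
# importable (none of these calls reads a checkpoint from disk).
if __name__ == "__main__":
    print(family_handler.get_wan_text_encoder_filename("int8"))
    print(family_handler.query_supported_types())
    print(family_handler.get_vae_block_size("ti2v_2_2"))  # expected: 32
    eqv_map, comp_map = family_handler.query_family_maps()
    assert eqv_map["flf2v_720p"] == "i2v"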