mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 22:26:36 +00:00
195 lines
7.7 KiB
Python
195 lines
7.7 KiB
Python
import torch
|
|
|
|
def test_class_i2v(base_model_type):
    """Return True if *base_model_type* belongs to the i2v class of models.

    Despite the ``test_`` prefix this is a runtime predicate, not a unit test.

    Args:
        base_model_type: Base model type identifier (e.g. ``"i2v"``).

    Returns:
        bool: whether the identifier is in the hard-coded i2v-class list.
    """
    # "hunyuan_i2v" is commented out of this list in the original source.
    return base_model_type in [
        "i2v",
        "i2v_2_2",
        "fun_inp_1.3B",
        "fun_inp",
        "flf2v_720p",
        "fantasy",
        "multitalk",
    ]


# The name matches pytest's default ``test_*`` collection pattern; opt out so
# that a test module importing this helper does not try to run it as a test
# (which would fail looking for a ``base_model_type`` fixture).
test_class_i2v.__test__ = False
|
|
|
|
class family_handler():
    """Collection of static hooks describing the "wan" model family.

    Every method is a ``@staticmethod``; the class is used as a namespace and
    holds no instance state.
    """

    @staticmethod
    def get_wan_text_encoder_filename(text_encoder_quantization):
        """Return the checkpoint path of the umt5-xxl text encoder.

        Args:
            text_encoder_quantization: ``"int8"`` selects the quanto int8
                file; any other value selects the bf16 file.

        Returns:
            str: path relative to the working directory (under ``ckpts/``).
        """
        text_encoder_filename = "ckpts/umt5-xxl/models_t5_umt5-xxl-enc-bf16.safetensors"
        if text_encoder_quantization == "int8":
            text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_int8")
        return text_encoder_filename

    @staticmethod
    def query_modules_files():
        """Return the module checkpoint files, keyed by module name.

        Returns:
            dict: module name -> list of candidate checkpoint paths.
        """
        return {
            "vace_14B": [
                "ckpts/wan2.1_Vace_14B_module_mbf16.safetensors",
                "ckpts/wan2.1_Vace_14B_module_quanto_mbf16_int8.safetensors",
                "ckpts/wan2.1_Vace_14B_module_quanto_mfp16_int8.safetensors",
            ],
            "vace_1.3B": ["ckpts/wan2.1_Vace_1_3B_module.safetensors"],
            "fantasy": ["ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"],
            "multitalk": [
                "ckpts/wan2.1_multitalk_14B_mbf16.safetensors",
                "ckpts/wan2.1_multitalk_14B_quanto_mbf16_int8.safetensors",
                "ckpts/wan2.1_multitalk_14B_quanto_mfp16_int8.safetensors",
            ],
        }

    @staticmethod
    def query_model_def(base_model_type, model_def):
        """Build the extra model-definition entries for *base_model_type*.

        Args:
            base_model_type: Base model type identifier.
            model_def: Existing model definition; only tested for the
                presence of the ``"URLs2"`` key.

        Returns:
            dict: entries such as ``i2v_class``, ``vace_class``, ``fps``,
            ``frames_minimum``, ``frames_steps``, ``sliding_window`` and a
            set of feature flags that are always ``True`` here.
        """
        extra_model_def = {}
        if "URLs2" in model_def:
            extra_model_def["no_steps_skipping"] = True

        extra_model_def["i2v_class"] = test_class_i2v(base_model_type)

        vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B"]
        extra_model_def["vace_class"] = vace_class

        # Frame rate depends on the model type; 16 is the fallback.
        if base_model_type in ["multitalk", "vace_multitalk_14B"]:
            fps = 25
        elif base_model_type in ["fantasy"]:
            fps = 23
        elif base_model_type in ["ti2v_2_2"]:
            fps = 24
        else:
            fps = 16
        extra_model_def["fps"] = fps

        # Only the minimum differs between vace and non-vace models;
        # the step is 4 in both cases.
        frames_minimum = 17 if vace_class else 5
        frames_steps = 4

        extra_model_def.update({
            "frames_minimum": frames_minimum,
            "frames_steps": frames_steps,
            # "ti2v_2_2" is commented out of the sliding-window list in the original source.
            "sliding_window": base_model_type in ["multitalk", "t2v", "fantasy"]
            or test_class_i2v(base_model_type)
            or vace_class,
            "guidance_max_phases": 2,
            "skip_layer_guidance": True,
            "cfg_zero": True,
            "cfg_star": True,
            "adaptive_projected_guidance": True,
            "skip_steps_cache": True,
        })

        return extra_model_def

    @staticmethod
    def query_supported_types():
        """Return the list of base model types handled by this family."""
        return [
            "multitalk", "fantasy", "vace_14B", "vace_multitalk_14B",
            "t2v_1.3B", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
            "recam_1.3B",
            "i2v", "i2v_2_2", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp",
        ]

    @staticmethod
    def query_family_maps():
        """Return the equivalence and compatibility maps between model types.

        Returns:
            tuple: ``(models_eqv_map, models_comp_map)`` where the first maps
            a model type to a single model type and the second maps a model
            type to a list of model types.
        """
        models_eqv_map = {
            "flf2v_720p": "i2v",
            "t2v_1.3B": "t2v",
        }

        models_comp_map = {
            "vace_14B": ["vace_multitalk_14B"],
            # A comma was missing between "vace_1.3B" and "vace_multitalk_14B":
            # Python concatenated the adjacent literals into one bogus entry,
            # so neither model type was actually present in this list.
            "t2v": ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B", "phantom_14B"],
            "i2v": ["fantasy", "multitalk", "flf2v_720p"],
            "fantasy": ["multitalk"],
        }
        return models_eqv_map, models_comp_map

    @staticmethod
    def query_model_family():
        """Return the identifier of this model family (``"wan"``)."""
        return "wan"

    @staticmethod
    def query_family_infos():
        """Return family id -> ``(index, label)`` for the families declared here."""
        return {"wan": (0, "Wan2.1"), "wan2_2": (1, "Wan2.2")}

    @staticmethod
    def get_vae_block_size(base_model_type):
        """Return the VAE block size: 32 for ``"ti2v_2_2"``, otherwise 16."""
        return 32 if base_model_type == "ti2v_2_2" else 16

    @staticmethod
    def get_rgb_factors(model_type):
        """Return ``(latent_rgb_factors, latent_rgb_factors_bias)`` for *model_type*.

        ``"ti2v_2_2"`` yields ``(None, None)``; every other type uses the
        factors registered under ``"wan"`` in ``shared.RGB_factors``.
        """
        # Imported lazily, as in the original code.
        from shared.RGB_factors import get_rgb_factors

        if model_type == "ti2v_2_2":
            return None, None
        latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan")
        return latent_rgb_factors, latent_rgb_factors_bias

    @staticmethod
    def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
        """Describe the files that have to be downloaded for a model.

        Args:
            computeList: Callable applied to a filename; its result is
                concatenated to a list of file names, so it must return a list.
            base_model_type: Base model type identifier.
            model_filename: Passed to ``computeList`` to obtain the model
                file entries.
            text_encoder_quantization: Forwarded to
                ``get_wan_text_encoder_filename``.

        Returns:
            list: download definitions (dicts with ``repoId``,
            ``sourceFolderList`` and ``fileList``); ``"ti2v_2_2"`` gets an
            additional entry for the Wan2.2 VAE.
        """
        text_encoder_filename = family_handler.get_wan_text_encoder_filename(text_encoder_quantization)

        # One file list per entry of "sourceFolderList", in the same order.
        download_def = [{
            "repoId": "DeepBeepMeep/Wan2.1",
            "sourceFolderList": ["xlm-roberta-large", "umt5-xxl", ""],
            "fileList": [
                [
                    "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors",
                    "sentencepiece.bpe.model",
                    "special_tokens_map.json",
                    "tokenizer.json",
                    "tokenizer_config.json",
                ],
                ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"]
                + computeList(text_encoder_filename),
                ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors"]
                + computeList(model_filename),
            ],
        }]

        if base_model_type == "ti2v_2_2":
            download_def += [{
                "repoId": "DeepBeepMeep/Wan2.2",
                "sourceFolderList": [""],
                "fileList": [["Wan2.2_VAE.safetensors"]],
            }]

        return download_def

    @staticmethod
    def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False):
        """Instantiate ``WanAny2V`` and collect its sub-models.

        The i2v-14B config is used for i2v-class model types and the t2v-14B
        config for everything else. All other arguments are forwarded to the
        ``WanAny2V`` constructor unchanged.

        Returns:
            tuple: ``(wan_model, pipe)`` where ``pipe`` is a dict with the
            keys ``"transformer"``, ``"text_encoder"`` and ``"vae"``, plus
            ``"transformer2"`` / ``"text_encoder_2"`` when the model exposes
            ``model2`` / ``clip``.
        """
        from .configs import WAN_CONFIGS

        if test_class_i2v(base_model_type):
            cfg = WAN_CONFIGS['i2v-14B']
        else:
            cfg = WAN_CONFIGS['t2v-14B']
            # cfg = WAN_CONFIGS['t2v-1.3B']
        from . import WanAny2V

        wan_model = WanAny2V(
            config=cfg,
            checkpoint_dir="ckpts",
            model_filename=model_filename,
            model_type=model_type,
            model_def=model_def,
            base_model_type=base_model_type,
            text_encoder_filename=family_handler.get_wan_text_encoder_filename(text_encoder_quantization),
            quantizeTransformer=quantizeTransformer,
            dtype=dtype,
            VAE_dtype=VAE_dtype,
            mixed_precision_transformer=mixed_precision_transformer,
            save_quantized=save_quantized,
        )

        pipe = {
            "transformer": wan_model.model,
            "text_encoder": wan_model.text_encoder.model,
            "vae": wan_model.vae.model,
        }
        if hasattr(wan_model, "model2") and wan_model.model2 is not None:
            pipe["transformer2"] = wan_model.model2
        # Same guard as for model2: an attribute that exists but is None is
        # skipped instead of raising AttributeError on ``.model``.
        if hasattr(wan_model, "clip") and wan_model.clip is not None:
            pipe["text_encoder_2"] = wan_model.clip.model
        return wan_model, pipe

    @staticmethod
    def update_default_settings(base_model_type, model_def, ui_defaults):
        """Apply per-model default UI settings to *ui_defaults* in place.

        Args:
            base_model_type: Base model type identifier selecting the overrides.
            model_def: Unused by this implementation.
            ui_defaults: Dict updated in place; left untouched for model
                types without specific overrides.

        Returns:
            None.
        """
        if base_model_type in ["fantasy"]:
            ui_defaults.update({
                "audio_guidance_scale": 5.0,
                "sliding_window_size": 1,
            })

        elif base_model_type in ["multitalk"]:
            ui_defaults.update({
                "guidance_scale": 5.0,
                "flow_shift": 7,  # 11 for 720p
                "audio_guidance_scale": 4,
                "sliding_window_discard_last_frames": 4,
                "sample_solver": "euler",
                "adaptive_switch": 1,
            })

        elif base_model_type in ["phantom_1.3B", "phantom_14B"]:
            ui_defaults.update({
                "guidance_scale": 7.5,
                "flow_shift": 5,
                "remove_background_images_ref": 1,
                "video_prompt_type": "I",
                # "resolution": "1280x720"
            })

        elif base_model_type in ["vace_14B", "vace_multitalk_14B"]:
            ui_defaults.update({
                "sliding_window_discard_last_frames": 0,
            })

        elif base_model_type in ["ti2v_2_2"]:
            ui_defaults.update({
                "image_prompt_type": "T",
            })
|
|
|
|
|