mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
86 lines
2.9 KiB
Python
86 lines
2.9 KiB
Python
import torch
|
|
|
|
class family_handler():
|
|
@staticmethod
|
|
def query_model_def(base_model_type, model_def):
|
|
extra_model_def = {}
|
|
if base_model_type in ["sky_df_14B"]:
|
|
fps = 24
|
|
else:
|
|
fps = 16
|
|
extra_model_def["fps"] =fps
|
|
extra_model_def["frames_minimum"] = 17
|
|
extra_model_def["frames_steps"] = 20
|
|
extra_model_def["sliding_window"] = True
|
|
extra_model_def["skip_layer_guidance"] = True
|
|
return extra_model_def
|
|
|
|
@staticmethod
|
|
def query_supported_types():
|
|
return ["sky_df_1.3B", "sky_df_14B"]
|
|
|
|
|
|
@staticmethod
|
|
def query_family_maps():
|
|
models_eqv_map = {
|
|
"sky_df_1.3B" : "sky_df_14B",
|
|
}
|
|
|
|
models_comp_map = {
|
|
"sky_df_14B": ["sky_df_1.3B"],
|
|
}
|
|
return models_eqv_map, models_comp_map
|
|
|
|
|
|
|
|
@staticmethod
|
|
def query_model_family():
|
|
return "wan"
|
|
|
|
@staticmethod
|
|
def query_family_infos():
|
|
return {}
|
|
|
|
|
|
|
|
@staticmethod
|
|
def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
|
|
from .wan_handler import family_handler
|
|
return family_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization)
|
|
|
|
@staticmethod
|
|
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False):
|
|
from .configs import WAN_CONFIGS
|
|
from .wan_handler import family_handler
|
|
cfg = WAN_CONFIGS['t2v-14B']
|
|
from . import DTT2V
|
|
wan_model = DTT2V(
|
|
config=cfg,
|
|
checkpoint_dir="ckpts",
|
|
model_filename=model_filename,
|
|
model_type = model_type,
|
|
model_def = model_def,
|
|
base_model_type=base_model_type,
|
|
text_encoder_filename= family_handler.get_wan_text_encoder_filename(text_encoder_quantization),
|
|
quantizeTransformer = quantizeTransformer,
|
|
dtype = dtype,
|
|
VAE_dtype = VAE_dtype,
|
|
mixed_precision_transformer = mixed_precision_transformer,
|
|
save_quantized = save_quantized
|
|
)
|
|
|
|
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
|
|
return wan_model, pipe
|
|
|
|
@staticmethod
|
|
def update_default_settings(base_model_type, model_def, ui_defaults):
|
|
ui_defaults.update({
|
|
"guidance_scale": 6.0,
|
|
"flow_shift": 8,
|
|
"sliding_window_discard_last_frames" : 0,
|
|
"resolution": "1280x720" if "720" in base_model_type else "960x544",
|
|
"sliding_window_size" : 121 if "720" in base_model_type else 97,
|
|
"RIFLEx_setting": 2,
|
|
"guidance_scale": 6,
|
|
"flow_shift": 8,
|
|
}) |