mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-11-04 14:16:57 +00:00
102 lines
3.8 KiB
Python
102 lines
3.8 KiB
Python
import torch
|
|
|
|
|
|
def get_ltxv_text_encoder_filename(text_encoder_quantization):
|
|
text_encoder_filename = "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors"
|
|
if text_encoder_quantization =="int8":
|
|
text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8")
|
|
return text_encoder_filename
|
|
|
|
class family_handler():
|
|
@staticmethod
|
|
def query_model_def(base_model_type, model_def):
|
|
LTXV_config = model_def.get("LTXV_config", "")
|
|
distilled= "distilled" in LTXV_config
|
|
extra_model_def = {}
|
|
if distilled:
|
|
extra_model_def.update({
|
|
"lock_inference_steps": True,
|
|
"no_negative_prompt" : True,
|
|
})
|
|
|
|
|
|
extra_model_def["fps"] = 30
|
|
extra_model_def["frames_minimum"] = 17
|
|
extra_model_def["frames_steps"] = 8
|
|
extra_model_def["sliding_window"] = True
|
|
extra_model_def["image_prompt_types_allowed"] = "TSEV"
|
|
|
|
extra_model_def["guide_preprocessing"] = {
|
|
"selection": ["", "PV", "DV", "EV", "V"],
|
|
"labels" : { "V": "Use LTXV raw format"}
|
|
}
|
|
|
|
extra_model_def["mask_preprocessing"] = {
|
|
"selection": ["", "A", "NA", "XA", "XNA"],
|
|
}
|
|
|
|
extra_model_def["extra_control_frames"] = 1
|
|
extra_model_def["dont_cat_preguide"]= True
|
|
return extra_model_def
|
|
|
|
@staticmethod
|
|
def query_supported_types():
|
|
return ["ltxv_13B"]
|
|
|
|
@staticmethod
|
|
def query_family_maps():
|
|
return {}, {}
|
|
|
|
@staticmethod
|
|
def get_rgb_factors(base_model_type ):
|
|
from shared.RGB_factors import get_rgb_factors
|
|
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("ltxv")
|
|
return latent_rgb_factors, latent_rgb_factors_bias
|
|
|
|
@staticmethod
|
|
def query_model_family():
|
|
return "ltxv"
|
|
|
|
@staticmethod
|
|
def query_family_infos():
|
|
return {"ltxv":(10, "LTX Video")}
|
|
|
|
@staticmethod
|
|
def get_vae_block_size(base_model_type):
|
|
return 32
|
|
|
|
@staticmethod
|
|
def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization):
|
|
text_encoder_filename = get_ltxv_text_encoder_filename(text_encoder_quantization)
|
|
return {
|
|
"repoId" : "DeepBeepMeep/LTX_Video",
|
|
"sourceFolderList" : ["T5_xxl_1.1", "" ],
|
|
"fileList" : [ ["added_tokens.json", "special_tokens_map.json", "spiece.model", "tokenizer_config.json"] + computeList(text_encoder_filename), ["ltxv_0.9.7_VAE.safetensors", "ltxv_0.9.7_spatial_upscaler.safetensors", "ltxv_scheduler.json"] + computeList(model_filename) ]
|
|
}
|
|
|
|
|
|
@staticmethod
|
|
def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
|
|
from .ltxv import LTXV
|
|
|
|
ltxv_model = LTXV(
|
|
model_filepath = model_filename,
|
|
text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization),
|
|
model_type = model_type,
|
|
base_model_type = base_model_type,
|
|
model_def = model_def,
|
|
dtype = dtype,
|
|
# quantizeTransformer = quantizeTransformer,
|
|
VAE_dtype = VAE_dtype,
|
|
mixed_precision_transformer = mixed_precision_transformer
|
|
)
|
|
|
|
pipeline = ltxv_model.pipeline
|
|
pipe = {"transformer" : pipeline.video_pipeline.transformer, "vae" : pipeline.vae, "text_encoder" : pipeline.video_pipeline.text_encoder, "latent_upsampler" : pipeline.latent_upsampler}
|
|
|
|
return ltxv_model, pipe
|
|
|
|
@staticmethod
|
|
def update_default_settings(base_model_type, model_def, ui_defaults):
|
|
pass
|
|
|