This commit is contained in:
DeepBeepMeep 2025-05-17 18:28:40 +02:00
parent 618ecdb999
commit 2a2853a893
2 changed files with 6 additions and 13 deletions

View File

@ -148,7 +148,6 @@ class LTXV:
def __init__(
self,
model_filepath: str,
loras_filepath: str,
text_encoder_filepath: str,
dtype = torch.bfloat16,
VAE_dtype = torch.bfloat16,
@ -156,8 +155,8 @@ class LTXV:
):
self.mixed_precision_transformer = mixed_precision_transformer
self.distilled = loras_filepath != None
self.distilled = any("lora" in name for name in model_filepath)
model_filepath = [name for name in model_filepath if not "lora" in name ]
# with safe_open(ckpt_path, framework="pt") as f:
# metadata = f.metadata()
# config_str = metadata.get("config")

14
wgp.py
View File

@ -2066,14 +2066,13 @@ def load_wan_model(model_filename, quantizeTransformer = False, dtype = torch.bf
pipe["text_encoder_2"] = wan_model.clip.model
return wan_model, pipe
def load_ltxv_model(model_filename, loras_filenames, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
def load_ltxv_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
filename = model_filename[-1]
print(f"Loading '{filename}' model...")
from ltx_video.ltxv import LTXV
ltxv_model = LTXV(
model_filepath = model_filename,
loras_filepath = loras_filenames,
text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization),
dtype = dtype,
# quantizeTransformer = quantizeTransformer,
@ -2126,11 +2125,8 @@ def load_models(model_filename):
perc_reserved_mem_max = args.perc_reserved_mem_max
new_transformer_loras_filenames = None
dependent_models = get_dependent_models(model_filename, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
if "_lora" in model_filename:
new_transformer_loras_filenames = [model_filename]
model_filelist = dependent_models
else:
model_filelist = dependent_models + [model_filename]
new_transformer_loras_filenames = [model_filename] if "_lora" in model_filename else None
model_filelist = dependent_models + [model_filename]
for filename in model_filelist:
download_models(filename)
transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
@ -2142,7 +2138,7 @@ def load_models(model_filename):
if model_family == "wan" :
wan_model, pipe = load_wan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
elif model_family == "ltxv":
wan_model, pipe = load_ltxv_model(model_filelist, new_transformer_loras_filenames, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
wan_model, pipe = load_ltxv_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
elif model_family == "hunyuan":
wan_model, pipe = load_hunyuan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
else:
@ -2159,8 +2155,6 @@ def load_models(model_filename):
global prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer
if server_config.get("enhancer_enabled", 0) == 1:
from transformers import ( AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlamaForCausalLM )
prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( "ckpts/Florence2", trust_remote_code=True)
prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( "ckpts/Florence2", trust_remote_code=True)
prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/Llama3_2/Llama3_2_quanto_bf16_int8.safetensors")