diff --git a/ltx_video/ltxv.py b/ltx_video/ltxv.py
index e9b7d8f..7be215b 100644
--- a/ltx_video/ltxv.py
+++ b/ltx_video/ltxv.py
@@ -148,7 +148,6 @@ class LTXV:
     def __init__(
         self,
         model_filepath: str,
-        loras_filepath: str,
         text_encoder_filepath: str,
         dtype = torch.bfloat16,
         VAE_dtype = torch.bfloat16,
@@ -156,8 +155,8 @@ class LTXV:
     ):
         self.mixed_precision_transformer = mixed_precision_transformer
-        self.distilled = loras_filepath != None
-
+        self.distilled = any("lora" in name for name in model_filepath)
+        model_filepath = [name for name in model_filepath if "lora" not in name]
 
         # with safe_open(ckpt_path, framework="pt") as f:
         #     metadata = f.metadata()
         #     config_str = metadata.get("config")
diff --git a/wgp.py b/wgp.py
index ac8d4ca..1a03a92 100644
--- a/wgp.py
+++ b/wgp.py
@@ -2066,14 +2066,13 @@ def load_wan_model(model_filename, quantizeTransformer = False, dtype = torch.bf
         pipe["text_encoder_2"] = wan_model.clip.model
     return wan_model, pipe
 
-def load_ltxv_model(model_filename, loras_filenames, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
+def load_ltxv_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
     filename = model_filename[-1]
     print(f"Loading '{filename}' model...")
     from ltx_video.ltxv import LTXV
 
     ltxv_model = LTXV(
         model_filepath = model_filename,
-        loras_filepath = loras_filenames,
         text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization),
         dtype = dtype,
         # quantizeTransformer = quantizeTransformer,
@@ -2126,11 +2125,8 @@ def load_models(model_filename):
     perc_reserved_mem_max = args.perc_reserved_mem_max
     new_transformer_loras_filenames = None
     dependent_models = get_dependent_models(model_filename, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
-    if "_lora" in model_filename:
-        new_transformer_loras_filenames = [model_filename]
-        model_filelist = dependent_models
-    else:
-        model_filelist = dependent_models + [model_filename]
+    new_transformer_loras_filenames = [model_filename] if "_lora" in model_filename else None
+    model_filelist = dependent_models + [model_filename]
     for filename in model_filelist:
         download_models(filename)
     transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
@@ -2142,7 +2138,7 @@ def load_models(model_filename):
     if model_family == "wan" :
         wan_model, pipe = load_wan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     elif model_family == "ltxv":
-        wan_model, pipe = load_ltxv_model(model_filelist, new_transformer_loras_filenames, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_ltxv_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     elif model_family == "hunyuan":
         wan_model, pipe = load_hunyuan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     else:
@@ -2159,8 +2155,6 @@ def load_models(model_filename):
     global prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer
     if server_config.get("enhancer_enabled", 0) == 1:
         from transformers import ( AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlamaForCausalLM )
-
-
         prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( "ckpts/Florence2", trust_remote_code=True)
         prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( "ckpts/Florence2", trust_remote_code=True)
         prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/Llama3_2/Llama3_2_quanto_bf16_int8.safetensors")
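
Net effect of the change: LoRA weights are no longer passed to LTXV through a separate loras_filepath argument. Instead, load_models() always appends the model file to the combined file list, and LTXV.__init__ infers distilled mode by scanning that list for names containing "lora", then drops those entries before loading checkpoints. A minimal sketch of that split, as a hypothetical standalone helper (the function name and example filenames below are illustrative, not part of the patch):

    def split_lora_files(model_filepath):
        """Illustrative mirror of the new LTXV.__init__ logic: detect
        distilled mode and filter LoRA files out of the checkpoint list."""
        distilled = any("lora" in name for name in model_filepath)
        checkpoints = [name for name in model_filepath if "lora" not in name]
        return distilled, checkpoints

    # Hypothetical filenames for illustration:
    # split_lora_files(["ltxv_13B_bf16.safetensors", "ltxv_13B_lora_distilled.safetensors"])
    # -> (True, ["ltxv_13B_bf16.safetensors"])

Note that load_models() keys its check on the substring "_lora" while LTXV.__init__ matches the looser "lora", so any checkpoint whose name happens to contain "lora" would be filtered on the LTXV side.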