mirror of https://github.com/Wan-Video/Wan2.1.git (synced 2025-12-15 11:43:21 +00:00)

commit 2a2853a893 (parent 618ecdb999)

    fixes
@@ -148,7 +148,6 @@ class LTXV:
     def __init__(
         self,
         model_filepath: str,
-        loras_filepath: str,
         text_encoder_filepath: str,
         dtype = torch.bfloat16,
         VAE_dtype = torch.bfloat16,
@@ -156,8 +155,8 @@ class LTXV:
     ):

         self.mixed_precision_transformer = mixed_precision_transformer
-        self.distilled = loras_filepath != None
-
+        self.distilled = any("lora" in name for name in model_filepath)
+        model_filepath = [name for name in model_filepath if not "lora" in name ]
         # with safe_open(ckpt_path, framework="pt") as f:
         #     metadata = f.metadata()
         #     config_str = metadata.get("config")
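Net effect of the two hunks above: instead of taking an explicit loras_filepath, the constructor now infers distillation from the checkpoint list itself and strips the lora entries before the weights are loaded. A minimal sketch of that logic in isolation, with hypothetical checkpoint names:

# Sketch of the new detection logic from the hunk above; the checkpoint
# names here are hypothetical examples.
model_filepath = [
    "ckpts/ltxv_0.9.7_13B_bf16.safetensors",
    "ckpts/ltxv_0.9.7_13B_lora128_bf16.safetensors",
]

# Any entry whose name contains "lora" marks the model as distilled...
distilled = any("lora" in name for name in model_filepath)
# ...and the lora entries are then excluded from the main weight-loading pass.
model_filepath = [name for name in model_filepath if "lora" not in name]

assert distilled
assert model_filepath == ["ckpts/ltxv_0.9.7_13B_bf16.safetensors"]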
wgp.py (14 changed lines)
@@ -2066,14 +2066,13 @@ def load_wan_model(model_filename, quantizeTransformer = False, dtype = torch.bf
         pipe["text_encoder_2"] = wan_model.clip.model
     return wan_model, pipe

-def load_ltxv_model(model_filename, loras_filenames, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
+def load_ltxv_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
     filename = model_filename[-1]
     print(f"Loading '{filename}' model...")
     from ltx_video.ltxv import LTXV

     ltxv_model = LTXV(
         model_filepath = model_filename,
-        loras_filepath = loras_filenames,
         text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization),
         dtype = dtype,
         # quantizeTransformer = quantizeTransformer,
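With loras_filenames dropped from the signature, everything load_ltxv_model needs now travels in the single model_filename list, whose last entry is the checkpoint it announces. A small self-contained sketch of that convention, with hypothetical paths:

# Sketch of the single-list convention load_ltxv_model now relies on:
# one list of checkpoint paths, the last entry being the model it announces.
# All paths here are hypothetical examples.
def announce(model_filename: list) -> str:
    filename = model_filename[-1]   # last entry = the model itself
    return f"Loading '{filename}' model..."

files = [
    "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors",   # dependency first
    "ckpts/ltxv_0.9.7_13B_bf16.safetensors",              # main checkpoint last
]
assert announce(files) == "Loading 'ckpts/ltxv_0.9.7_13B_bf16.safetensors' model..."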
@@ -2126,11 +2125,8 @@ def load_models(model_filename):
     perc_reserved_mem_max = args.perc_reserved_mem_max
-    new_transformer_loras_filenames = None
     dependent_models = get_dependent_models(model_filename, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
-    if "_lora" in model_filename:
-        new_transformer_loras_filenames = [model_filename]
-        model_filelist = dependent_models
-    else:
-        model_filelist = dependent_models + [model_filename]
+    new_transformer_loras_filenames = [model_filename] if "_lora" in model_filename else None
+    model_filelist = dependent_models + [model_filename]
     for filename in model_filelist:
         download_models(filename)
     transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
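The five-line branch collapses into a conditional expression, and the lora checkpoint now stays in model_filelist so the LTXV constructor can detect and filter it itself. A before/after sketch, with hypothetical stand-ins for model_filename and dependent_models:

# Before/after sketch of the simplification; model_filename and
# dependent_models are hypothetical stand-ins.
model_filename = "ckpts/ltxv_0.9.7_13B_lora128_bf16.safetensors"
dependent_models = ["ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors"]

# Before: the lora checkpoint was diverted out of the load list.
if "_lora" in model_filename:
    new_transformer_loras_filenames = [model_filename]
    model_filelist = dependent_models
else:
    new_transformer_loras_filenames = None
    model_filelist = dependent_models + [model_filename]

# After: one expression, and the checkpoint always joins the list.
new_transformer_loras_filenames = [model_filename] if "_lora" in model_filename else None
model_filelist = dependent_models + [model_filename]

assert model_filelist[-1] == model_filename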
@@ -2142,7 +2138,7 @@ def load_models(model_filename):
     if model_family == "wan" :
         wan_model, pipe = load_wan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     elif model_family == "ltxv":
-        wan_model, pipe = load_ltxv_model(model_filelist, new_transformer_loras_filenames, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_ltxv_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     elif model_family == "hunyuan":
         wan_model, pipe = load_hunyuan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     else:
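After this change all three family loaders take the same argument shape, so the if/elif chain could equally be written as a lookup table. The sketch below only illustrates that uniformity; the stub loaders are hypothetical and the repo itself keeps the if/elif chain:

# Illustration only: with the loaders sharing one signature, dispatch
# reduces to a dictionary lookup. The stub loaders are hypothetical.
def load_wan_model(files, **kw):     return "wan_model", {}
def load_ltxv_model(files, **kw):    return "ltxv_model", {}
def load_hunyuan_model(files, **kw): return "hunyuan_model", {}

LOADERS = {"wan": load_wan_model, "ltxv": load_ltxv_model, "hunyuan": load_hunyuan_model}

def load_any(model_family, model_filelist, **kw):
    if model_family not in LOADERS:
        raise Exception(f"Model family '{model_family}' not supported")
    return LOADERS[model_family](model_filelist, **kw)

wan_model, pipe = load_any("ltxv", ["ckpts/ltxv_0.9.7_13B_bf16.safetensors"])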
@@ -2159,8 +2155,6 @@ def load_models(model_filename):
     global prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer
     if server_config.get("enhancer_enabled", 0) == 1:
         from transformers import ( AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlamaForCausalLM )
-
-
         prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( "ckpts/Florence2", trust_remote_code=True)
         prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( "ckpts/Florence2", trust_remote_code=True)
         prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/Llama3_2/Llama3_2_quanto_bf16_int8.safetensors")
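For context, the enhancer block (unchanged here apart from the two dropped blank lines) wires up a Florence2 captioner and an int8-quantized Llama 3.2. A self-contained sketch of the same loading pattern, assuming the checkpoints under ckpts/ are already downloaded and that offload comes from mmgp as elsewhere in wgp.py:

# Sketch of the enhancer loading pattern; assumes the checkpoints under
# ckpts/ were downloaded beforehand and that offload is mmgp's module,
# as used elsewhere in wgp.py.
from mmgp import offload
from transformers import AutoModelForCausalLM, AutoProcessor

def load_prompt_enhancer():
    caption_model = AutoModelForCausalLM.from_pretrained(
        "ckpts/Florence2", trust_remote_code=True)
    caption_processor = AutoProcessor.from_pretrained(
        "ckpts/Florence2", trust_remote_code=True)
    llm_model = offload.fast_load_transformers_model(
        "ckpts/Llama3_2/Llama3_2_quanto_bf16_int8.safetensors")
    return caption_model, caption_processor, llm_model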