mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-16 12:13:27 +00:00

fixes

commit 2a2853a893
parent 618ecdb999
ltx_video/ltxv.py

@@ -148,7 +148,6 @@ class LTXV:
     def __init__(
         self,
         model_filepath: str,
-        loras_filepath: str,
         text_encoder_filepath: str,
         dtype = torch.bfloat16,
         VAE_dtype = torch.bfloat16,
@@ -156,8 +155,8 @@ class LTXV:
     ):

         self.mixed_precision_transformer = mixed_precision_transformer
-        self.distilled = loras_filepath != None
-
+        self.distilled = any("lora" in name for name in model_filepath)
+        model_filepath = [name for name in model_filepath if not "lora" in name ]
         # with safe_open(ckpt_path, framework="pt") as f:
         #     metadata = f.metadata()
         #     config_str = metadata.get("config")
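Net effect of the LTXV change: the separate loras_filepath argument is gone, and a distilled build is instead inferred from the checkpoint list itself. A standalone sketch of the new detection logic (all filenames below are hypothetical, not from the repo):

# Hypothetical file list: base checkpoint plus a LoRA file in one list.
model_filepath = [
    "ckpts/ltxv_13B_bf16.safetensors",                 # hypothetical base checkpoint
    "ckpts/ltxv_13B_distilled_lora_bf16.safetensors",  # hypothetical LoRA weights
]

# Any entry whose name contains "lora" marks the model as distilled ...
distilled = any("lora" in name for name in model_filepath)
# ... and LoRA entries are then filtered out before the checkpoints are loaded.
model_filepath = [name for name in model_filepath if "lora" not in name]

print(distilled)       # True
print(model_filepath)  # ['ckpts/ltxv_13B_bf16.safetensors']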
wgp.py (14 changed lines)
@@ -2066,14 +2066,13 @@ def load_wan_model(model_filename, quantizeTransformer = False, dtype = torch.bf
     pipe["text_encoder_2"] = wan_model.clip.model
     return wan_model, pipe

-def load_ltxv_model(model_filename, loras_filenames, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
+def load_ltxv_model(model_filename, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
     filename = model_filename[-1]
     print(f"Loading '{filename}' model...")
     from ltx_video.ltxv import LTXV

     ltxv_model = LTXV(
         model_filepath = model_filename,
-        loras_filepath = loras_filenames,
         text_encoder_filepath = get_ltxv_text_encoder_filename(text_encoder_quantization),
         dtype = dtype,
         # quantizeTransformer = quantizeTransformer,
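On the wgp.py side, load_ltxv_model drops its loras_filenames parameter; the LoRA file simply rides along in the single file list. A minimal stand-in for the new call shape (names and paths hypothetical; the real function also takes dtype and quantization arguments omitted here):

# Hypothetical stub mirroring the new one-list signature.
def load_ltxv_model_stub(model_filename, mixed_precision_transformer=False):
    filename = model_filename[-1]  # wgp.py convention: the last entry is named in the log line
    print(f"Loading '{filename}' model...")

# Old call shape: load_ltxv_model(model_filelist, loras_filenames, ...)
# New call shape: one list, LoRA entry included (hypothetical paths):
load_ltxv_model_stub([
    "ckpts/ltxv_13B_bf16.safetensors",                 # hypothetical base checkpoint
    "ckpts/ltxv_13B_distilled_lora_bf16.safetensors",  # hypothetical LoRA, now in-list
])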
@@ -2126,11 +2125,8 @@ def load_models(model_filename):
     perc_reserved_mem_max = args.perc_reserved_mem_max
     new_transformer_loras_filenames = None
     dependent_models = get_dependent_models(model_filename, quantization= transformer_quantization, dtype_policy = transformer_dtype_policy)
-    if "_lora" in model_filename:
-        new_transformer_loras_filenames = [model_filename]
-        model_filelist = dependent_models
-    else:
-        model_filelist = dependent_models + [model_filename]
+    new_transformer_loras_filenames = [model_filename] if "_lora" in model_filename else None
+    model_filelist = dependent_models + [model_filename]
     for filename in model_filelist:
         download_models(filename)
     transformer_dtype = get_transformer_dtype(model_family, transformer_dtype_policy)
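Here the if/else collapses into a conditional expression, and, more importantly, a "_lora" model file is no longer kept out of model_filelist, so it is downloaded and loaded with everything else. A standalone sketch contrasting old and new behavior (filenames hypothetical):

model_filename = "ltxv_13B_distilled_lora_bf16.safetensors"  # hypothetical "_lora" file
dependent_models = ["ltxv_13B_bf16.safetensors"]             # hypothetical dependency

# Old behavior: a "_lora" file was recorded but excluded from model_filelist.
if "_lora" in model_filename:
    old_filelist = dependent_models
else:
    old_filelist = dependent_models + [model_filename]

# New behavior: the conditional expression still records the LoRA filename,
# and the file now always joins model_filelist, so download_models() sees it too.
new_transformer_loras_filenames = [model_filename] if "_lora" in model_filename else None
new_filelist = dependent_models + [model_filename]

print(old_filelist)  # ['ltxv_13B_bf16.safetensors']
print(new_filelist)  # ['ltxv_13B_bf16.safetensors', 'ltxv_13B_distilled_lora_bf16.safetensors']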
@@ -2142,7 +2138,7 @@ def load_models(model_filename):
     if model_family == "wan" :
         wan_model, pipe = load_wan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     elif model_family == "ltxv":
-        wan_model, pipe = load_ltxv_model(model_filelist, new_transformer_loras_filenames, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
+        wan_model, pipe = load_ltxv_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     elif model_family == "hunyuan":
         wan_model, pipe = load_hunyuan_model(model_filelist, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     else:
@@ -2159,8 +2155,6 @@ def load_models(model_filename):
     global prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer
     if server_config.get("enhancer_enabled", 0) == 1:
         from transformers import ( AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlamaForCausalLM )
-
-
         prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( "ckpts/Florence2", trust_remote_code=True)
         prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( "ckpts/Florence2", trust_remote_code=True)
         prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/Llama3_2/Llama3_2_quanto_bf16_int8.safetensors")