just another version

This commit is contained in:
deepbeepmeep 2025-07-29 03:52:07 +02:00
parent 6850d60caa
commit 8d12cf08b6
9 changed files with 323 additions and 107 deletions

View File

@ -20,6 +20,20 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates
### July 29 2025: WanGP v7.3 : Wan 2.2 Preview
Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to fully leverage this new model, since it has twice as many parameters.
So here is a preview version of Wan 2.2 that, for the moment, leaves out the 5B model and Wan 2.2 image-to-video.
However, as I felt bad delivering only half of the wares, I gave you instead ..... Wan 2.2 Vace Cocktail !
Very good surprise indeed: the loras and Vace mostly work with Wan 2.2 !!! I have made a light version of the Cocktail that uses only half of the parameters of Wan 2.2, so this version has exactly the same RAM requirements as Wan 2.1. Videos baked with only half of the model are not as good, but maybe they are still better than Wan 2.1, so you tell me if we should keep the light version.
Multitalk should probably work too, but I have a life to attend to, so I will let you test it.
Bonus zone: Flux multi-image conditioning has been added, or maybe not if I broke everything while being distracted by Wan...
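To make the RAM vs. VRAM trade-off concrete, here is a minimal, self-contained sketch of the dual-expert schedule (illustrative only, not WanGP's actual code: the stand-in experts are plain functions, and the default values are simply copied from defaults/t2v_2_2.json). Only one expert runs at any given step, which is why VRAM stays at Wan 2.1 levels while RAM roughly doubles.

```python
# Illustrative sketch of the Wan 2.2 high-noise / low-noise hand-over.
def make_expert(name):
    def step(latent, t, cfg):
        print(f"{name}: t={t}, cfg={cfg}")
        return latent                      # a real expert would denoise here
    return step

def denoise(latent, timesteps, switch_threshold=875, guidance=4.0, guidance2=3.0):
    high, low = make_expert("high-noise expert"), make_expert("low-noise expert")
    expert, cfg, switched = high, guidance, False
    for t in timesteps:                    # timesteps run from ~1000 down to 0
        if not switched and t <= switch_threshold:
            expert, cfg, switched = low, guidance2, True   # one-time hand-over
        latent = expert(latent, t, cfg)
    return latent

denoise(latent=0, timesteps=[999, 937, 875, 812, 750, 500, 250, 0])
```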
### July 27 2025: WanGP v7.3 : Interlude
While waiting for Wan 2.2, you will appreciate the new model selection hierarchy, which is very useful for collecting even more models. You will also appreciate that WanGP remembers which model you used last in each model family.

defaults/t2v_2_2.json Normal file
View File

@ -0,0 +1,24 @@
{
"model":
{
"name": "Wan2.2 Text2video 14B",
"architecture" : "t2v",
"description": "Wan 2.2 Text 2 Video model",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_quanto_mfp16_int8.safetensors"
],
"URLs2": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_quanto_mfp16_int8.safetensors"
],
"group": "wan2_2"
},
"switch_threshold" : 875,
"guidance_scale" : 4,
"guidance2_scale" : 3,
"flow_shift" : 12
}

View File

@ -0,0 +1,25 @@
{
"model": {
"name": "Wan2.2 Vace Cocktail 14B",
"architecture": "vace_14B",
"modules": [
"vace_14B"
],
"description": "This model has been created on the fly using the Wan text 2.2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. Copy the model def in the finetune folder to change the Cocktail composition.",
"URLs": "t2v_2_2",
"URLs2": "t2v_2_2",
"loras": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors"
],
"loras_multipliers": [1, 0.2, 0.5, 0.5],
"group": "wan2_2"
},
"num_inference_steps": 10,
"guidance_scale": 1,
"guidance_scale": 2,
"flow_shift": 2,
"switch_threshold" : 875
}

View File

@ -0,0 +1,19 @@
{
"model": {
"name": "Wan2.2 Vace Cocktail Light 14B",
"architecture": "vace_14B",
"modules": [
"vace_14B"
],
"description": "This model has been created on the fly using the Wan text 2.2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. Only the high noise part of the v2.2 model is used to reduce RAM usage.",
"URLs": "t2v_2_2",
"URLs2": "t2v_2_2",
"loras": "vace_14B_cocktail_2_2",
"loras_multipliers": "vace_14B_cocktail_2_2",
"group": "wan2_2"
},
"num_inference_steps": 10,
"guidance_scale": 1,
"guidance_scale": 2,
"flow_shift": 2
}

View File

@ -17,6 +17,20 @@ from flux.util import (
save_image,
)
from PIL import Image
def stitch_images(img1, img2):
# Resize img2 to match img1's height
width1, height1 = img1.size
width2, height2 = img2.size
new_width2 = int(width2 * height1 / height2)
img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS)
stitched = Image.new('RGB', (width1 + new_width2, height1))
stitched.paste(img1, (0, 0))
stitched.paste(img2_resized, (width1, 0))
return stitched
class model_factory:
def __init__(
self,
@ -72,6 +86,7 @@ class model_factory:
callback = None,
loras_slists = None,
batch_size = 1,
video_prompt_type = "",
**bbargs
):
@ -79,19 +94,30 @@ class model_factory:
return None
device="cuda"
if input_ref_images != None and len(input_ref_images) > 0:
image_ref = input_ref_images[0]
w, h = image_ref.size
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
if "I" in video_prompt_type and input_ref_images != None and len(input_ref_images) > 0:
if "K" in video_prompt_type and False :
# image latents tiling method
w, h = input_ref_images[0].size
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
else:
# image stitching method
stitched = input_ref_images[0]
if "K" in video_prompt_type :
w, h = input_ref_images[0].size
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
for new_img in input_ref_images[1:]:
stitched = stitch_images(stitched, new_img)
input_ref_images = [stitched]
else:
image_ref = None
input_ref_images = None
inp, height, width = prepare_kontext(
t5=self.t5,
clip=self.clip,
prompt=input_prompt,
ae=self.vae,
img_cond=image_ref,
img_cond_list=input_ref_images,
target_width=width,
target_height=height,
bs=batch_size,
@ -99,7 +125,6 @@ class model_factory:
device=device,
)
inp.pop("img_cond_orig")
timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
def unpack_latent(x):
return unpack(x.float(), height, width)
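# Quick usage illustration for the new multi-image conditioning path (run it on your own
# images; "ref1.png" / "ref2.png" are placeholder file names, not files shipped with WanGP).
# stitch_images() rescales each extra image to the first one's height and pastes them side
# by side, so several reference images reach Flux Kontext as a single conditioning image:
#
#   from PIL import Image
#   refs = [Image.open("ref1.png"), Image.open("ref2.png")]
#   combined = refs[0]
#   for extra in refs[1:]:
#       combined = stitch_images(combined, extra)   # width grows, height stays fixed
#   combined.save("stitched_condition.png")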

View File

@ -214,7 +214,7 @@ def prepare_kontext(
clip: HFEmbedder,
prompt: str | list[str],
ae: AutoEncoder,
img_cond: str,
img_cond_list: list,
seed: int,
device: torch.device,
target_width: int | None = None,
@ -225,7 +225,10 @@ def prepare_kontext(
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
if img_cond != None:
img_cond_seq = None
img_cond_seq_ids = None
for cond_no, img_cond in enumerate(img_cond_list):
width, height = img_cond.size
aspect_ratio = width / height
@ -239,20 +242,19 @@ def prepare_kontext(
img_cond = np.array(img_cond)
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
img_cond_orig = img_cond.clone()
with torch.no_grad():
img_cond = ae.encode(img_cond.to(device))
img_cond_latents = ae.encode(img_cond.to(device))
img_cond = img_cond.to(torch.bfloat16)
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
img_cond_latents = img_cond_latents.to(torch.bfloat16)
img_cond_latents = rearrange(img_cond_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img_cond.shape[0] == 1 and bs > 1:
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
img_cond_latents = repeat(img_cond_latents, "1 ... -> bs ...", bs=bs)
img_cond = None
# image ids match the base image ids, except the first channel is set to
# cond_no + 1 instead of 0 so each conditioning image gets its own index
img_cond_ids = torch.zeros(height // 2, width // 2, 3)
img_cond_ids[..., 0] = 1
img_cond_ids[..., 0] = cond_no + 1
img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None]
img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :]
img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs)
@ -262,10 +264,10 @@ def prepare_kontext(
if target_height is None:
target_height = 8 * height
img_cond_ids = img_cond_ids.to(device)
else:
img_cond = None
img_cond_ids = None
img_cond_orig = None
if cond_no == 0:
img_cond_seq, img_cond_seq_ids = img_cond_latents, img_cond_ids
else:
img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, img_cond_latents], dim=1), torch.cat([img_cond_seq_ids, img_cond_ids], dim=1)
img = get_noise(
bs,
@ -277,9 +279,8 @@ def prepare_kontext(
)
return_dict = prepare(t5, clip, img, prompt)
return_dict["img_cond_seq"] = img_cond
return_dict["img_cond_seq_ids"] = img_cond_ids
return_dict["img_cond_orig"] = img_cond_orig
return_dict["img_cond_seq"] = img_cond_seq
return_dict["img_cond_seq_ids"] = img_cond_seq_ids
return return_dict, target_height, target_width
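# Sketch of the position-id layout built above (illustrative, tiny sizes): every
# conditioning image keeps its own (h, w) grid in channels 1 and 2, while channel 0
# is set to cond_no + 1, so the concatenated reference latents stay distinguishable
# from each other and from the generated image (which uses index 0).
#
#   import torch
#   ids = []
#   for cond_no, (h2, w2) in enumerate([(2, 2), (2, 3)]):   # two fake latent grids
#       grid = torch.zeros(h2, w2, 3)
#       grid[..., 0] = cond_no + 1
#       grid[..., 1] = torch.arange(h2)[:, None]
#       grid[..., 2] = torch.arange(w2)[None, :]
#       ids.append(grid.reshape(-1, 3))
#   img_cond_seq_ids = torch.cat(ids, dim=0)                 # shape (10, 3)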

View File

@ -78,6 +78,8 @@ class WanAny2V:
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
self.model_def = model_def
self.model2 = None
self.transformer_switch = model_def.get("URLs2", None) is not None
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
@ -101,24 +103,31 @@ class WanAny2V:
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
device=self.device)
xmodel_filename = "c:/temp/wan2.1_text2video_1.3B_bf16.safetensors"
# config_filename= "configs/t2v_1.3B.json"
# import json
# with open(config_filename, 'r', encoding='utf-8') as f:
# config = json.load(f)
# sd = safetensors2.torch_load_file(xmodel_filename)
# model_filename = "c:/temp/vace1_3B.safetensors"
# model_filename = "c:/temp/wan2.2t2v/high/diffusion_pytorch_model-00001-of-00006.safetensors"
base_config_file = f"configs/{base_model_type}.json"
forcedConfigPath = base_config_file if len(model_filename) > 1 else None
# forcedConfigPath = base_config_file = f"configs/flf2v_720p.json"
# model_filename[1] = xmodel_filename
model_filename2 = None
if self.transformer_switch:
model_filename2 = model_filename[1:]
model_filename = model_filename[:1] + model_filename[2:]
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
if model_filename2 is not None:
self.model2 = offload.fast_load_transformers_model(model_filename2, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
# self.model = offload.load_model_data(self.model, xmodel_filename )
# offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "wan2.1_text2video_1.3B_mbf16.safetensors", do_quantize= False, config_file_path=base_config_file, filter_sd=sd)
# offload.save_model(self.model, "wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
# offload.save_model(self.model, "wan2.2_text2video_14B_high_mbf16.safetensors", config_file_path=base_config_file)
# offload.save_model(self.model, "wan2.2_text2video_14B_high_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
self.model.eval().requires_grad_(False)
if save_quantized:
from wgp import save_quantized_model
@ -136,7 +145,8 @@ class WanAny2V:
seq_len=32760,
keep_last=True)
self.adapt_vace_model()
self.adapt_vace_model(self.model)
if self.model2 is not None: self.adapt_vace_model(self.model2)
self.num_timesteps = 1000
self.use_timestep_transform = True
@ -353,6 +363,8 @@ class WanAny2V:
sample_solver='unipc',
sampling_steps=50,
guide_scale=5.0,
guide2_scale = 5.0,
switch_threshold = 0,
n_prompt="",
seed=-1,
callback = None,
@ -384,6 +396,7 @@ class WanAny2V:
color_correction_strength = 1,
prefix_frames_count = 0,
image_mode = 0,
**bbargs
):
@ -699,9 +712,18 @@ class WanAny2V:
apg_norm_threshold = 55
text_momentumbuffer = MomentumBuffer(apg_momentum)
audio_momentumbuffer = MomentumBuffer(apg_momentum)
guidance_switch_done = False
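# Wan 2.2 ships two 14B transformers: self.model denoises the early, high-noise timesteps
# and self.model2 (loaded from the "URLs2" checkpoints) takes over the later, low-noise
# ones. The loop below performs a one-time hand-over as soon as the current timestep t
# drops to switch_threshold (875 in the t2v_2_2 defaults), swapping guide_scale for
# guide2_scale at the same time. For Wan 2.1 models self.model2 is None and only the
# guidance value changes.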
# denoising
trans = self.model
for i, t in enumerate(tqdm(timesteps)):
offload.set_step_no_for_lora(self.model, i)
if not guidance_switch_done and t <= switch_threshold:
guide_scale = guide2_scale
if self.model2 is not None: trans = self.model2
guidance_switch_done = True
offload.set_step_no_for_lora(trans, i)
timestep = torch.stack([t])
kwargs.update({"t": timestep, "current_step": i})
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
@ -760,7 +782,7 @@ class WanAny2V:
}
if joint_pass and guide_scale > 1:
ret_values = self.model( **gen_args , **kwargs)
ret_values = trans( **gen_args , **kwargs)
if self._interrupt:
return None
else:
@ -768,7 +790,7 @@ class WanAny2V:
ret_values = [None] * size
for x_id in range(size):
sub_gen_args = {k : [v[x_id]] for k, v in gen_args.items() }
ret_values[x_id] = self.model( **sub_gen_args, x_id= x_id , **kwargs)[0]
ret_values[x_id] = trans( **sub_gen_args, x_id= x_id , **kwargs)[0]
if self._interrupt:
return None
sub_gen_args = None
@ -870,8 +892,7 @@ class WanAny2V:
return { "x" : videos, "latent_slice" : latent_slice }
return videos
def adapt_vace_model(self):
model = self.model
def adapt_vace_model(self, model):
modules_dict= { k: m for k, m in model.named_modules()}
for model_layer, vace_layer in model.vace_layers_mapping.items():
module = modules_dict[f"vace_blocks.{vace_layer}"]
@ -880,4 +901,7 @@ class WanAny2V:
delattr(model, "vace_blocks")
def query_model_def(model_type, model_def):
return None
if "URLs2" in model_def:
return { "no_steps_skipping":True}
else:
return None

View File

@ -211,7 +211,7 @@ class DTT2V:
guide_scale: float = 5.0,
seed: float = 0.0,
overlap_noise: int = 0,
ar_step: int = 5,
model_mode: int = 5,
causal_block_size: int = 5,
causal_attention: bool = True,
fps: int = 24,
@ -231,7 +231,7 @@ class DTT2V:
if frame_num > 1:
frame_num = max(17, frame_num) # must match causal_block_size for value of 5
frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 )
ar_step = model_mode
if ar_step == 0:
causal_block_size = 1
causal_attention = False

wgp.py
View File

@ -52,7 +52,7 @@ PROMPT_VARS_MAX = 10
target_mmgp_version = "3.5.1"
WanGP_version = "7.3"
settings_version = 2.22
settings_version = 2.23
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@ -178,6 +178,7 @@ def process_prompt_and_add_tasks(state, model_choice):
return get_queue_table(queue)
model_def = get_model_def(model_type)
image_outputs = model_def.get("image_outputs", False)
no_steps_skipping = model_def.get("no_steps_skipping", False)
model_type = get_base_model_type(model_type)
inputs["model_filename"] = model_filename
@ -278,8 +279,12 @@ def process_prompt_and_add_tasks(state, model_choice):
skip_steps_cache_type= inputs["skip_steps_cache_type"]
MMAudio_setting = inputs["MMAudio_setting"]
image_mode = inputs["image_mode"]
switch_threshold = inputs["switch_threshold"]
if no_steps_skipping: skip_steps_cache_type = ""
if switch_threshold is not None and switch_threshold != 0 and len(skip_steps_cache_type) > 0:
gr.Info("Steps skipping is not yet supported if Switch Threshold is not null")
return
if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20:
gr.Info("The minimum number of steps should be 20")
return
@ -495,7 +500,8 @@ def process_prompt_and_add_tasks(state, model_choice):
"denoising_strength": denoising_strength,
"image_prompt_type": image_prompt_type,
"video_prompt_type": video_prompt_type,
"audio_prompt_type": audio_prompt_type,
"audio_prompt_type": audio_prompt_type,
"skip_steps_cache_type": skip_steps_cache_type
}
if inputs["multi_prompts_gen_type"] == 0:
@ -1691,7 +1697,7 @@ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion
print(f"Removing old version of model '{path}'. A new version of this model will be downloaded next time you use it.")
os.remove( os.path.join("ckpts" , path))
families_infos = {"wan":(0, "Wan2.1"), "ltxv":(1, "LTX Video"), "hunyuan":(2, "Hunyuan Video"), "flux":(3, "Flux 1"), "unknown": (100, "Unknown") }
families_infos = {"wan":(0, "Wan2.1"), "wan2_2":(1, "Wan2.2"), "ltxv":(10, "LTX Video"), "hunyuan":(20, "Hunyuan Video"), "flux":(30, "Flux 1"), "unknown": (100, "Unknown") }
models_def = {}
@ -1764,15 +1770,22 @@ def get_model_type(model_filename):
return None
# raise Exception("Unknown model:" + model_filename)
def get_model_family(model_type):
model_type = get_base_model_type(model_type)
if model_type == None:
def get_model_family(model_type, for_ui = False):
base_model_type = get_base_model_type(model_type)
if base_model_type is None:
return "unknown"
if "hunyuan" in model_type :
if for_ui :
model_def = get_model_def(model_type)
model_family = model_def.get("group", None)
if model_family is not None and model_family in families_infos:
return model_family
if "hunyuan" in base_model_type :
return "hunyuan"
elif "ltxv" in model_type:
elif "ltxv" in base_model_type:
return "ltxv"
elif "flux" in model_type:
elif "flux" in base_model_type:
return "flux"
else:
return "wan"
@ -1855,17 +1868,19 @@ def get_model_recursive_prop(model_type, prop = "URLs", return_list = True, sta
raise Exception(f"Unknown model type '{model_type}'")
def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_module = False, stack=[]):
def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_module = False, submodel_no = 1, stack=[]):
if is_module:
choices = modules_files.get(model_type, None)
if choices == None: raise Exception(f"Invalid Module Id '{model_type}'")
else:
key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}"
model_def = models_def.get(model_type, None)
if model_def == None: return ""
URLs = model_def["URLs"]
URLs = model_def[key_name]
if isinstance(URLs, str):
if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}")
return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, stack = stack + [URLs])
if len(stack) > 10: raise Exception(f"Circular Reference in Model {key_name} dependencies: {stack}")
return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, submodel_no = submodel_no, stack = stack + [URLs])
else:
choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ]
if len(quantization) == 0:
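# With the new submodel_no argument the same model definition can reference two
# checkpoint sets: "URLs" for the high-noise transformer and "URLs2" for the low-noise
# one. Hypothetical example against the Wan 2.2 t2v definition added in this commit:
#
#   get_model_filename("t2v_2_2", quantization="int8")                 # file from URLs
#   get_model_filename("t2v_2_2", quantization="int8", submodel_no=2)  # file from URLs2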
@ -1953,6 +1968,8 @@ def fix_settings(model_type, ui_defaults):
if model_type in ["hunyuan"]:
video_prompt_type = video_prompt_type.replace("I", "")
if model_type in ["flux"] and video_settings_version < 2.23:
video_prompt_type = video_prompt_type.replace("K", "").replace("I", "KI")
remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 1)
if video_settings_version < 2.22:
@ -2037,7 +2054,7 @@ def get_default_settings(model_type):
})
if model_def.get("reference_image", False):
ui_defaults.update({
"video_prompt_type": "I",
"video_prompt_type": "KI",
})
elif base_model_type in ["sky_df_1.3B", "sky_df_14B"]:
ui_defaults.update({
@ -2124,9 +2141,9 @@ def get_model_query_handler(model_type):
model_family= get_model_family(base_model_type)
if model_family == "wan":
if base_model_type in ("sky_df_1.3B", "sky_df_14B"):
from wan.any2video import query_model_def
else:
from wan.diffusion_forcing import query_model_def
else:
from wan.any2video import query_model_def
elif model_family == "hunyuan":
from hyvideo.hunyuan import query_model_def
elif model_family == "ltxv":
@ -2340,7 +2357,7 @@ def download_mmaudio():
}
process_files_def(**enhancer_def)
def download_models(model_filename, model_type):
def download_models(model_filename, model_type, submodel_no = 1):
def computeList(filename):
if filename == None:
return []
@ -2397,23 +2414,27 @@ def download_models(model_filename, model_type):
model_family = get_model_family(model_type)
model_def = get_model_def(model_type)
if model_def != None and not model_type in modules_files:
key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}"
if not model_type in modules_files:
if not os.path.isfile(model_filename ):
URLs = get_model_recursive_prop(model_type, "URLs", return_list= False)
if not isinstance(URLs, str): # dont download anything right now if a base type is referenced as the download will occur just after
use_url = model_filename
for url in URLs:
if os.path.basename(model_filename) in url:
use_url = url
break
if not url.startswith("http"):
raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.")
try:
download_file(use_url, model_filename)
except Exception as e:
if os.path.isfile(model_filename): os.remove(model_filename)
raise Exception(f"URL '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'")
model_filename = None
URLs = get_model_recursive_prop(model_type, key_name, return_list= False)
if isinstance(URLs, str):
raise Exception("Missing model " + URLs)
use_url = model_filename
for url in URLs:
if os.path.basename(model_filename) in url:
use_url = url
break
if not url.startswith("http"):
raise Exception(f"Model '{model_filename}' in field '{key_name}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.")
try:
download_file(use_url, model_filename)
except Exception as e:
if os.path.isfile(model_filename): os.remove(model_filename)
raise Exception(f"{key_name} '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'")
model_filename = None
preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True)
for url in preload_URLs:
@ -2609,6 +2630,9 @@ def load_wan_model(model_filename, model_type, base_model_type, model_def, quant
)
pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
if wan_model.model2 is not None:
pipe["transformer2"] = wan_model.model2
if hasattr(wan_model, "clip"):
pipe["text_encoder_2"] = wan_model.clip.model
return wan_model, pipe
@ -2689,9 +2713,16 @@ def load_hunyuan_model(model_filename, model_type = None, base_model_type = Non
return hunyuan_model, pipe
def get_transformer_model(model):
def get_transformer_model(model, submodel_no = 1):
if submodel_no > 1:
model_key = f"model{submodel_no}"
if not hasattr(model, model_key): return None
if hasattr(model, "model"):
return model.model
if submodel_no > 1:
return getattr(model, f"model{submodel_no}")
else:
return model.model
elif hasattr(model, "transformer"):
return model.transformer
else:
@ -2705,6 +2736,10 @@ def load_models(model_type):
preload =int(args.preload)
save_quantized = args.save_quantized and model_def != None
model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy)
if "URLs2" in model_def:
model_filename2 = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy, submodel_no=2) # !!!!
else:
model_filename2 = None
modules = get_model_recursive_prop(model_type, "modules", return_list= True)
if save_quantized and "quanto" in model_filename:
save_quantized = False
@ -2726,7 +2761,9 @@ def load_models(model_type):
preload = server_config.get("preload_in_VRAM", 0)
model_file_list = [model_filename]
model_type_list = [model_type]
new_transformer_filename = model_file_list[-1]
if model_filename2 != None:
model_file_list += [model_filename2]
model_type_list += [model_type]
for module_type in modules:
model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype, is_module= True))
model_type_list.append(module_type)
@ -2750,14 +2787,22 @@ def load_models(model_type):
elif model_family == "hunyuan":
wan_model, pipe = load_hunyuan_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized)
else:
raise Exception(f"Model '{new_transformer_filename}' not supported.")
wan_model._model_file_name = new_transformer_filename
raise Exception(f"Model '{model_filename}' not supported.")
kwargs = { "extraModelsToQuantize": None }
loras_transformer = ["transformer"]
if profile in (2, 4, 5):
kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100 if preload == 0 else preload, "*" : max(1000 if profile==5 else 3000 , preload) }
budgets = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100 if preload == 0 else preload, "*" : max(1000 if profile==5 else 3000 , preload) }
if "transformer2" in pipe:
budgets["transformer2"] = 100 if preload == 0 else preload
loras_transformer += ["transformer2"]
kwargs["budgets"] = budgets
elif profile == 3:
kwargs["budgets"] = { "*" : "70%" }
if "transformer2" in pipe and profile in [2,4]:
kwargs["pinnedMemory"] = ["transformer", "transformer2"]
global prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer
if server_config.get("enhancer_enabled", 0) == 1:
from transformers import ( AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlamaForCausalLM )
@ -2777,21 +2822,21 @@ def load_models(model_type):
prompt_enhancer_llm_tokenizer = None
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
if len(args.gpu) > 0:
torch.set_default_device(args.gpu)
transformer_type = model_type
return wan_model, offloadobj, pipe["transformer"]
return wan_model, offloadobj
if not "P" in preload_model_policy:
wan_model, offloadobj, transformer = None, None, None
reload_needed = True
else:
wan_model, offloadobj, transformer = load_models(transformer_type)
wan_model, offloadobj = load_models(transformer_type)
if check_loras:
transformer = get_transformer_model(wan_model)
setup_loras(transformer_type, transformer, get_lora_dir(transformer_type), "", None)
exit()
del transformer
gen_in_progress = False
@ -2929,6 +2974,7 @@ def apply_changes( state,
model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy)
state["model_filename"] = model_filename
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier", "display_stats"] for change in changes ):
model_family = gr.Dropdown()
model_choice = gr.Dropdown()
else:
reload_needed = True
@ -3303,7 +3349,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
values += ["unipc" if len(video_sample_solver) ==0 else video_sample_solver]
labels += ["Sampler Solver"]
values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_flow_shift, video_num_inference_steps]
labels += [ "Resolution", video_length_label, "Seed", video_guidance_label, "Flow Shift", "Num Inference steps"]
labels += [ "Resolution", video_length_label, "Seed", video_guidance_label, "Shift Scale", "Num Inference steps"]
video_negative_prompt = configs.get("negative_prompt", "")
if len(video_negative_prompt) > 0:
values += [video_negative_prompt]
@ -3973,6 +4019,8 @@ def generate_video(
force_fps,
num_inference_steps,
guidance_scale,
guidance2_scale,
switch_threshold,
audio_guidance_scale,
flow_shift,
sample_solver,
@ -4094,7 +4142,7 @@ def generate_video(
offloadobj = None
gc.collect()
send_cmd("status", f"Loading model {get_model_name(model_type)}...")
wan_model, offloadobj, trans = load_models(model_type)
wan_model, offloadobj = load_models(model_type)
send_cmd("status", "Model loaded")
reload_needed= False
@ -4120,6 +4168,7 @@ def generate_video(
VAE_tile_size = wan_model.vae.get_VAE_tile_size(vae_config, device_mem_capacity, server_config.get("vae_precision", "16") == "32")
trans = get_transformer_model(wan_model)
trans2 = get_transformer_model(wan_model, 2)
audio_sampling_rate = 16000
base_model_type = get_base_model_type(model_type)
@ -4179,11 +4228,15 @@ def generate_video(
loras_selected = transformer_loras_filenames + loras_selected
loras_list_mult_choices_nums = transformer_loras_multipliers + loras_list_mult_choices_nums
loras_slists = transformer_loras_multipliers + loras_slists
offload.load_loras_into_model(trans, loras_selected, loras_list_mult_choices_nums, activate_all_loras=True, preprocess_sd=get_loras_preprocessor(trans, base_model_type), pinnedLora=pinnedLora, split_linear_modules_map = split_linear_modules_map)
errors = trans._loras_errors
if len(errors) > 0:
error_files = [msg for _ , msg in errors]
raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
trans_list = [trans]
if trans2 is not None: trans_list += [trans2]
for trans_item in trans_list:
offload.load_loras_into_model(trans_item, loras_selected, loras_list_mult_choices_nums, activate_all_loras=True, preprocess_sd=get_loras_preprocessor(trans, base_model_type), pinnedLora=pinnedLora, split_linear_modules_map = split_linear_modules_map)
errors = trans_item._loras_errors
if len(errors) > 0:
error_files = [msg for _ , msg in errors]
raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
trans_item = trans_list = None
seed = None if seed == -1 else seed
# negative_prompt = "" # not applicable in the inference
original_filename = model_filename
@ -4243,7 +4296,7 @@ def generate_video(
any_background_ref = False
outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace):
if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace or flux):
frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions !=None and len(frames_positions)> 0 else []
frames_positions_list = frames_positions_list[:len(image_refs)]
nb_frames_positions = len(frames_positions_list)
@ -4265,11 +4318,13 @@ def generate_video(
send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
from wan.utils.utils import resize_and_remove_background
image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (vace or hunyuan_avatar) ) # no fit for vace ref images as it is done later
image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (vace or hunyuan_avatar or flux) ) # no fit for vace ref images as it is done later
update_task_thumbnails(task, locals())
send_cmd("output")
joint_pass = boost ==1 #and profile != 1 and profile != 3
trans.enable_cache = None if len(skip_steps_cache_type) == 0 else skip_steps_cache_type
if trans2 is not None:
trans2.enable_cache = None
if trans.enable_cache != None:
trans.cache_multiplier = skip_steps_multiplier
@ -4660,6 +4715,8 @@ def generate_video(
sample_solver=sample_solver,
sampling_steps=num_inference_steps,
guide_scale=guidance_scale,
guide2_scale = guidance2_scale,
switch_threshold = switch_threshold,
embedded_guidance_scale=embedded_guidance_scale,
n_prompt=negative_prompt,
seed=seed,
@ -4680,7 +4737,7 @@ def generate_video(
audio_scale= audio_scale,
audio_context_lens= audio_context_lens,
context_scale = context_scale,
ar_step = model_mode, #5
model_mode = model_mode,
causal_block_size = 5,
causal_attention = True,
fps = fps,
@ -4698,6 +4755,7 @@ def generate_video(
NAG_alpha = NAG_alpha,
speakers_bboxes =speakers_bboxes,
image_mode = image_mode,
video_prompt_type= video_prompt_type,
offloadobj = offloadobj,
)
except Exception as e:
@ -4915,6 +4973,9 @@ def generate_video(
seed = set_seed(-1)
clear_status(state)
offload.unload_loras_from_model(trans)
if not trans2 is None:
offload.unload_loras_from_model(trans2)
if len(control_audio_tracks) > 0:
cleanup_temp_audio_files(control_audio_tracks)
@ -5793,6 +5854,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
ltxv = base_model_type in ["ltxv_13B"]
recammaster = base_model_type in ["recam_1.3B"]
phantom = base_model_type in ["phantom_1.3B", "phantom_14B"]
flux = base_model_type in ["flux"]
hunyuan_video_custom = base_model_type in ["hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit"]
model_family = get_model_family(base_model_type)
if target == "settings":
@ -5818,7 +5880,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
if not server_config.get("enhancer_enabled", 0) == 1:
pop += ["prompt_enhancer"]
if not recammaster and not diffusion_forcing:
if not recammaster and not diffusion_forcing and not flux:
pop += ["model_mode"]
if not vace and not phantom and not hunyuan_video_custom:
@ -5837,14 +5899,14 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
if not base_model_type in ["fantasy", "multitalk", "vace_multitalk_14B"]:
pop += ["audio_guidance_scale", "speakers_locations"]
if not model_family in ["hunyuan", "flux"]:
if not model_family in ["hunyuan", "flux"] or model_def.get("no_guidance", False):
pop += ["embedded_guidance_scale"]
if not model_family in ["hunyuan", "wan"]:
pop += ["skip_steps_cache_type", "skip_steps_multiplier", "skip_steps_start_step_perc"]
if model_def.get("no_guidance", False) or ltxv:
pop += ["guidance_scale", "audio_guidance_scale", "embedded_guidance_scale"]
if model_def.get("no_guidance", False) or ltxv or model_family in ["hunyuan", "flux"] :
pop += ["guidance_scale", "guidance2_scale", "switch_threshold", "audio_guidance_scale"]
if model_def.get("image_outputs", False) or ltxv:
pop += ["flow_shift"]
@ -6171,6 +6233,8 @@ def save_inputs(
force_fps,
num_inference_steps,
guidance_scale,
guidance2_scale,
switch_threshold,
audio_guidance_scale,
flow_shift,
sample_solver,
@ -6334,7 +6398,7 @@ def change_model(state, model_choice):
model_filename = get_model_filename(model_choice, transformer_quantization, transformer_dtype_policy)
state["model_filename"] = model_filename
last_model_per_family = state["last_model_per_family"]
last_model_per_family[get_model_family(model_choice)] = model_choice
last_model_per_family[get_model_family(model_choice, for_ui= True)] = model_choice
server_config["last_model_per_family"] = last_model_per_family
server_config["last_model_type"] = model_choice
@ -6366,7 +6430,7 @@ def preload_model_when_switching(state):
gc.collect()
model_filename = get_model_name(model_type)
yield f"Loading model {model_filename}..."
wan_model, offloadobj, _ = load_models(model_type)
wan_model, offloadobj = load_models(model_type)
yield f"Model loaded"
reload_needed= False
return
@ -6760,6 +6824,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
diffusion_forcing = "diffusion_forcing" in model_filename
ltxv = "ltxv" in model_filename
lock_inference_steps = model_def.get("lock_inference_steps", False)
model_reference_image = model_def.get("reference_image", False)
no_steps_skipping = model_def.get("no_steps_skipping", False)
recammaster = base_model_type in ["recam_1.3B"]
vace = test_vace_module(base_model_type)
phantom = base_model_type in ["phantom_1.3B", "phantom_14B"]
@ -6889,7 +6955,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
keep_frames_video_source = gr.Text(visible=False)
any_video_source = False
with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or ltxv or flux and model_def.get("reference_image", False) ) as video_prompt_column:
with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or ltxv or flux and model_reference_image) as video_prompt_column:
video_prompt_type_value= ui_defaults.get("video_prompt_type","")
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
any_control_video = True
@ -6947,7 +7013,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
],
value=filter_letters(video_prompt_type_value, "PDSLCMUV"),
label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= True, show_label= True,
)
)
else:
any_control_video = False
any_control_image = False
@ -7012,6 +7078,20 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
visible = True,
label="Reference Images", scale = 2
)
elif flux and model_reference_image:
video_prompt_type_image_refs = gr.Dropdown(
choices=[
("None", ""),
("Inject only People / Objects", "I"),
("Inject Main Subject / Landscape and then People / Objects", "KI"),
],
value=filter_letters(video_prompt_type_value, "KFI"),
visible = True,
show_label=False,
label="Reference Images Combination Method", scale = 2
)
else:
video_prompt_type_image_refs = gr.Dropdown(
choices=[ ("Start / Ref Image", "I")],
@ -7052,11 +7132,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" )
remove_background_images_ref = gr.Dropdown(
choices=[
("Keep Backgrounds behind People / Objects", 0),
("Remove Backgrounds behind People / Objects", 1),
("Keep Backgrounds behind all Reference Images", 0),
("Remove Backgrounds only behind People / Objects", 1),
],
value=ui_defaults.get("remove_background_images_ref",1),
label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar and not flux
label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar
)
any_audio_voices_support = any_audio_track(base_model_type)
@ -7174,6 +7254,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale", 5 if fantasy else 4), step=0.5, label="Audio Guidance", visible=(fantasy or multitalk) and not no_guidance)
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("embedded_guidance", 2.5 if flux else 6.0), step=0.5, label="Embedded Guidance Scale", visible=(hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance)
flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale", visible = not image_outputs)
with gr.Row(visible = not ltxv and not (no_guidance and image_outputs)) as guidance_row2:
guidance2_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance2_scale",5), step=0.5, label="Guidance2 (CFG)", visible=not (hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance)
switch_threshold = gr.Slider(0, 1000, value=ui_defaults.get("switch_threshold",0), step=1, label="Guidance / Model Switch Threshold", visible=not (hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance)
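# The threshold is a raw diffusion timestep (0-1000): 0 effectively disables switching,
# while e.g. 875 hands over to the second guidance value (and, for Wan 2.2, to the
# low-noise transformer) once the current timestep drops to or below that value.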
with gr.Row(visible = get_model_family(model_type) == "wan" and not diffusion_forcing ) as sample_solver_row:
sample_solver = gr.Dropdown( value=ui_defaults.get("sample_solver",""),
choices=[
@ -7214,7 +7298,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
label="Activated Loras"
)
loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by Space chars or CR, lines that start with # are ignored", value=launch_multis_str)
with gr.Tab("Steps Skipping", visible = not (ltxv or image_outputs)) as speed_tab:
with gr.Tab("Steps Skipping", visible = not (ltxv or image_outputs) and not no_steps_skipping) as speed_tab:
with gr.Column():
gr.Markdown("<B>Tea Cache and Mag Cache accelerate the Video Generation by skipping intelligently some steps, the more steps are skipped the lower the quality of the video.</B>")
gr.Markdown("<B>Steps Skipping consumes also VRAM. It is recommended not to skip at least the first 10% steps.</B>")
@ -7577,7 +7661,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right,
video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row,
video_buttons_row, image_buttons_row, video_postprocessing_tab, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn,
NAG_col, speakers_locations_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, min_frames_if_references_col] # presets_column,
NAG_col, speakers_locations_row, guidance_row, guidance_row2, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, min_frames_if_references_col] # presets_column,
if update_form:
locals_dict = locals()
gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
@ -8265,7 +8349,7 @@ def compact_name(family_name, model_name):
return model_name
def get_sorted_dropdown(dropdown_types, current_model_family):
models_families = [get_model_family(type) for type in dropdown_types]
models_families = [get_model_family(type, for_ui= True) for type in dropdown_types]
families = {}
for family in models_families:
if family not in families: families[family] = 1
@ -8285,7 +8369,7 @@ def generate_dropdown_model_list(current_model_type):
dropdown_types= transformer_types if len(transformer_types) > 0 else displayed_model_types
if current_model_type not in dropdown_types:
dropdown_types.append(current_model_type)
current_model_family = get_model_family(current_model_type)
current_model_family = get_model_family(current_model_type, for_ui= True)
sorted_familes, dropdown_choices = get_sorted_dropdown(dropdown_types, current_model_family)
dropdown_families = gr.Dropdown(
@ -8308,12 +8392,12 @@ def generate_dropdown_model_list(current_model_type):
def change_model_family(state, current_model_family):
dropdown_types= transformer_types if len(transformer_types) > 0 else displayed_model_types
current_family_name = families_infos[current_model_family][1]
models_families = [get_model_family(type) for type in dropdown_types]
models_families = [get_model_family(type, for_ui= True) for type in dropdown_types]
dropdown_choices = [ (compact_name(current_family_name, get_model_name(model_type)), model_type) for model_type, family in zip(dropdown_types, models_families) if family == current_model_family ]
dropdown_choices = sorted(dropdown_choices, key=lambda c: c[0])
last_model_per_family = state.get("last_model_per_family", {})
model_type = last_model_per_family.get(current_model_family, "")
if len(model_type) == "" or model_type not in displayed_model_types: model_type = dropdown_choices[0][1]
if len(model_type) == "" or model_type not in [choice[1] for choice in dropdown_choices] : model_type = dropdown_choices[0][1]
return gr.Dropdown(choices= dropdown_choices, value = model_type )
def set_new_tab(tab_state, new_tab_no):