mirror of https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-15 11:43:21 +00:00

Commit "just another version"
This commit is contained in:
  parent 6850d60caa
  commit 8d12cf08b6

README.md (14 lines changed)
@@ -20,6 +20,20 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models

**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

## 🔥 Latest Updates

### July 29 2025: WanGP v7.3 : Wan 2.2 Preview

Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to fully leverage this new model, since it has twice as many parameters.

So here is a preview version of Wan 2.2 that, for the moment, ships without the 5B model and without Wan 2.2 image-to-video.

However, as I felt bad delivering only half of the wares, I am giving you instead... Wan 2.2 Vace Cocktail!

Very good surprise indeed: the loras and Vace mostly work with Wan 2.2! I have also made a light version of the Cocktail that uses only half of the parameters of Wan 2.2, so it has exactly the same RAM requirements as Wan 2.1. Videos baked with half of the model are not as good, but maybe they are still better than Wan 2.1, so you tell me whether we should keep the light version.

Multitalk should probably work too, but I have a life to attend to, so I will let you test it.

Bonus zone: Flux multi-image conditioning has been added. Or maybe not, if I broke everything while being distracted by Wan...

### July 27 2025: WanGP v7.3 : Interlude

While waiting for Wan 2.2, you will appreciate the model selection hierarchy, which is very useful for collecting even more models. You will also appreciate that WanGP remembers which model you used last in each model family.

defaults/t2v_2_2.json (new file, 24 lines)

@@ -0,0 +1,24 @@
{
    "model":
    {
        "name": "Wan2.2 Text2video 14B",
        "architecture" : "t2v",
        "description": "Wan 2.2 Text 2 Video model",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_quanto_mbf16_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_quanto_mfp16_int8.safetensors"
        ],
        "URLs2": [
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_mbf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_quanto_mbf16_int8.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_low_quanto_mfp16_int8.safetensors"
        ],
        "group": "wan2_2"
    },
    "switch_threshold" : 875,
    "guidance_scale" : 4,
    "guidance2_scale" : 3,
    "flow_shift" : 12
}
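
For orientation (this note is not part of the commit): Wan 2.2 t2v ships as two 14B transformers, a high-noise model ("URLs") and a low-noise model ("URLs2"), and "switch_threshold" is the denoising timestep at which WanGP is meant to swap from the first to the second while also switching guidance scales. A minimal sketch of that selection logic, mirroring the denoising-loop change in wan/any2video.py further below (names are illustrative, not the repository API):

    # Illustrative sketch only, not the repository code.
    def pick_model_and_guidance(t, high_model, low_model, switch_threshold=875,
                                guidance_scale=4.0, guidance2_scale=3.0):
        # Early (noisy) timesteps use the high-noise transformer and the first guidance scale;
        # once t drops to or below the threshold, switch to the low-noise transformer.
        if t > switch_threshold:
            return high_model, guidance_scale
        return low_model, guidance2_scale
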

defaults/vace_14B_cocktail_2_2.json (new file, 25 lines)

@@ -0,0 +1,25 @@
{
    "model": {
        "name": "Wan2.2 Vace Cocktail 14B",
        "architecture": "vace_14B",
        "modules": [
            "vace_14B"
        ],
        "description": "This model has been created on the fly using the Wan 2.2 text 2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. Copy the model def in the finetune folder to change the Cocktail composition.",
        "URLs": "t2v_2_2",
        "URLs2": "t2v_2_2",
        "loras": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors"
        ],
        "loras_multipliers": [1, 0.2, 0.5, 0.5],
        "group": "wan2_2"
    },
    "num_inference_steps": 10,
    "guidance_scale": 1,
    "guidance2_scale": 2,
    "flow_shift": 2,
    "switch_threshold" : 875
}

defaults/vace_14B_cocktail_2_2_light.json (new file, 19 lines)

@@ -0,0 +1,19 @@
{
    "model": {
        "name": "Wan2.2 Vace Cocktail Light 14B",
        "architecture": "vace_14B",
        "modules": [
            "vace_14B"
        ],
        "description": "This model has been created on the fly using the Wan 2.2 text 2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. Only the high noise part of the v2.2 model is used to reduce RAM usage.",
        "URLs": "t2v_2_2",
        "URLs2": "t2v_2_2",
        "loras": "vace_14B_cocktail_2_2",
        "loras_multipliers": "vace_14B_cocktail_2_2",
        "group": "wan2_2"
    },
    "num_inference_steps": 10,
    "guidance_scale": 1,
    "guidance2_scale": 2,
    "flow_shift": 2
}

@@ -17,6 +17,20 @@ from flux.util import (
    save_image,
)

from PIL import Image

def stitch_images(img1, img2):
    # Resize img2 to match img1's height
    width1, height1 = img1.size
    width2, height2 = img2.size
    new_width2 = int(width2 * height1 / height2)
    img2_resized = img2.resize((new_width2, height1), Image.Resampling.LANCZOS)

    stitched = Image.new('RGB', (width1 + new_width2, height1))
    stitched.paste(img1, (0, 0))
    stitched.paste(img2_resized, (width1, 0))
    return stitched
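
A minimal usage sketch of the helper above (the file names are hypothetical, this snippet is not part of the commit):

    # Hypothetical usage of stitch_images, for illustration only.
    from PIL import Image

    ref_subject = Image.open("subject.png").convert("RGB")      # hypothetical input
    ref_landscape = Image.open("landscape.png").convert("RGB")  # hypothetical input
    combined = stitch_images(ref_subject, ref_landscape)  # second image is rescaled to the first one's height
    combined.save("stitched_refs.png")
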
class model_factory:
    def __init__(
        self,

@@ -72,6 +86,7 @@ class model_factory:
        callback = None,
        loras_slists = None,
        batch_size = 1,
        video_prompt_type = "",
        **bbargs
    ):

@@ -79,19 +94,30 @@
            return None

        device="cuda"
        if input_ref_images != None and len(input_ref_images) > 0:
            image_ref = input_ref_images[0]
            w, h = image_ref.size
            height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
        if "I" in video_prompt_type and input_ref_images != None and len(input_ref_images) > 0:
            if "K" in video_prompt_type and False :
                # image latents tiling method
                w, h = input_ref_images[0].size
                height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
            else:
                # image stitching method
                stiched = input_ref_images[0]
                if "K" in video_prompt_type :
                    w, h = input_ref_images[0].size
                    height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)

                for new_img in input_ref_images[1:]:
                    stiched = stitch_images(stiched, new_img)
                input_ref_images = [stiched]
        else:
            image_ref = None
            input_ref_images = None

        inp, height, width = prepare_kontext(
            t5=self.t5,
            clip=self.clip,
            prompt=input_prompt,
            ae=self.vae,
            img_cond=image_ref,
            img_cond_list=input_ref_images,
            target_width=width,
            target_height=height,
            bs=batch_size,

@@ -99,7 +125,6 @@
            device=device,
        )

        inp.pop("img_cond_orig")
        timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
        def unpack_latent(x):
            return unpack(x.float(), height, width)

@@ -214,7 +214,7 @@ def prepare_kontext(
    clip: HFEmbedder,
    prompt: str | list[str],
    ae: AutoEncoder,
    img_cond: str,
    img_cond_list: list,
    seed: int,
    device: torch.device,
    target_width: int | None = None,

@@ -225,7 +225,10 @@
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    if img_cond != None:
    img_cond_seq = None
    img_cond_seq_ids = None

    for cond_no, img_cond in enumerate(img_cond_list):
        width, height = img_cond.size
        aspect_ratio = width / height

@@ -239,20 +242,19 @@
        img_cond = np.array(img_cond)
        img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
        img_cond = rearrange(img_cond, "h w c -> 1 c h w")
        img_cond_orig = img_cond.clone()

        with torch.no_grad():
            img_cond = ae.encode(img_cond.to(device))
            img_cond_latents = ae.encode(img_cond.to(device))

        img_cond = img_cond.to(torch.bfloat16)
        img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        img_cond_latents = img_cond_latents.to(torch.bfloat16)
        img_cond_latents = rearrange(img_cond_latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        if img_cond.shape[0] == 1 and bs > 1:
            img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
            img_cond_latents = repeat(img_cond_latents, "1 ... -> bs ...", bs=bs)
        img_cond = None

        # image ids are the same as base image with the first dimension set to 1
        # instead of 0
        img_cond_ids = torch.zeros(height // 2, width // 2, 3)
        img_cond_ids[..., 0] = 1
        img_cond_ids[..., 0] = cond_no + 1
        img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None]
        img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :]
        img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs)

@@ -262,10 +264,10 @@
        if target_height is None:
            target_height = 8 * height
        img_cond_ids = img_cond_ids.to(device)
    else:
        img_cond = None
        img_cond_ids = None
        img_cond_orig = None
        if cond_no == 0:
            img_cond_seq, img_cond_seq_ids = img_cond_latents, img_cond_ids
        else:
            img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, img_cond_latents], dim=1), torch.cat([img_cond_seq_ids, img_cond_ids], dim=1)

    img = get_noise(
        bs,

@@ -277,9 +279,8 @@
    )

    return_dict = prepare(t5, clip, img, prompt)
    return_dict["img_cond_seq"] = img_cond
    return_dict["img_cond_seq_ids"] = img_cond_ids
    return_dict["img_cond_orig"] = img_cond_orig
    return_dict["img_cond_seq"] = img_cond_seq
    return_dict["img_cond_seq_ids"] = img_cond_seq_ids
    return return_dict, target_height, target_width
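
Summarizing the multi-image conditioning change above with a hedged sketch (illustrative names, not the repository code): each reference image's latents are patchified into tokens, given position ids whose first channel carries the image index (cond_no + 1), and concatenated into a single conditioning sequence:

    # Illustrative sketch only, not the repository code.
    import torch
    from einops import rearrange, repeat

    def pack_cond_latents(latent_list, bs=1):
        # latent_list: per-image VAE latents shaped (1, c, h, w), with h and w divisible by 2
        seq, seq_ids = None, None
        for cond_no, lat in enumerate(latent_list):
            _, _, h, w = lat.shape
            tokens = rearrange(lat, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
            ids = torch.zeros(h // 2, w // 2, 3)
            ids[..., 0] = cond_no + 1                     # per-image index instead of a constant 1
            ids[..., 1] += torch.arange(h // 2)[:, None]  # row position
            ids[..., 2] += torch.arange(w // 2)[None, :]  # column position
            ids = repeat(ids, "h w c -> b (h w) c", b=bs)
            if tokens.shape[0] == 1 and bs > 1:
                tokens = repeat(tokens, "1 ... -> b ...", b=bs)
            if seq is None:
                seq, seq_ids = tokens, ids
            else:
                seq = torch.cat([seq, tokens], dim=1)
                seq_ids = torch.cat([seq_ids, ids], dim=1)
        return seq, seq_ids
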
@@ -78,6 +78,8 @@ class WanAny2V:
        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype
        self.model_def = model_def
        self.model2 = None
        self.transformer_switch = model_def.get("URLs2", None) is not None
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,

@@ -101,24 +103,31 @@
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
            device=self.device)

        xmodel_filename = "c:/temp/wan2.1_text2video_1.3B_bf16.safetensors"
        # config_filename= "configs/t2v_1.3B.json"
        # import json
        # with open(config_filename, 'r', encoding='utf-8') as f:
        #     config = json.load(f)
        # sd = safetensors2.torch_load_file(xmodel_filename)
        # model_filename = "c:/temp/vace1_3B.safetensors"
        # model_filename = "c:/temp/wan2.2t2v/high/diffusion_pytorch_model-00001-of-00006.safetensors"
        base_config_file = f"configs/{base_model_type}.json"
        forcedConfigPath = base_config_file if len(model_filename) > 1 else None
        # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json"
        # model_filename[1] = xmodel_filename
        model_filename2 = None
        if self.transformer_switch:
            model_filename2 = model_filename[1:]
            model_filename = model_filename[:1] + model_filename[2:]
        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
        if model_filename2 is not None:
            self.model2 = offload.fast_load_transformers_model(model_filename2, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)

        # self.model = offload.load_model_data(self.model, xmodel_filename )
        # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
        self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
        offload.change_dtype(self.model, dtype, True)
        # offload.save_model(self.model, "wan2.1_text2video_1.3B_mbf16.safetensors", do_quantize= False, config_file_path=base_config_file, filter_sd=sd)
        # offload.save_model(self.model, "wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
        # offload.save_model(self.model, "wan2.2_text2video_14B_high_mbf16.safetensors", config_file_path=base_config_file)
        # offload.save_model(self.model, "wan2.2_text2video_14B_high_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
        self.model.eval().requires_grad_(False)
        if save_quantized:
            from wgp import save_quantized_model

@@ -136,7 +145,8 @@ class WanAny2V:
            seq_len=32760,
            keep_last=True)

        self.adapt_vace_model()
        self.adapt_vace_model(self.model)
        if self.model2 is not None: self.adapt_vace_model(self.model2)

        self.num_timesteps = 1000
        self.use_timestep_transform = True

@@ -353,6 +363,8 @@ class WanAny2V:
        sample_solver='unipc',
        sampling_steps=50,
        guide_scale=5.0,
        guide2_scale = 5.0,
        switch_threshold = 0,
        n_prompt="",
        seed=-1,
        callback = None,

@@ -384,6 +396,7 @@ class WanAny2V:
        color_correction_strength = 1,
        prefix_frames_count = 0,
        image_mode = 0,
        **bbargs
    ):

@@ -699,9 +712,18 @@ class WanAny2V:
        apg_norm_threshold = 55
        text_momentumbuffer = MomentumBuffer(apg_momentum)
        audio_momentumbuffer = MomentumBuffer(apg_momentum)

        guidance_switch_done = False

        # denoising
        trans = self.model
        for i, t in enumerate(tqdm(timesteps)):
            offload.set_step_no_for_lora(self.model, i)
            if not guidance_switch_done and t <= switch_threshold:
                guide_scale = guide2_scale
                if self.model2 is not None: trans = self.model2
                guidance_switch_done = True

            offload.set_step_no_for_lora(trans, i)
            timestep = torch.stack([t])
            kwargs.update({"t": timestep, "current_step": i})
            kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None

@@ -760,7 +782,7 @@ class WanAny2V:
            }

            if joint_pass and guide_scale > 1:
                ret_values = self.model( **gen_args , **kwargs)
                ret_values = trans( **gen_args , **kwargs)
                if self._interrupt:
                    return None
            else:

@@ -768,7 +790,7 @@ class WanAny2V:
                ret_values = [None] * size
                for x_id in range(size):
                    sub_gen_args = {k : [v[x_id]] for k, v in gen_args.items() }
                    ret_values[x_id] = self.model( **sub_gen_args, x_id= x_id , **kwargs)[0]
                    ret_values[x_id] = trans( **sub_gen_args, x_id= x_id , **kwargs)[0]
                    if self._interrupt:
                        return None
                sub_gen_args = None

@@ -870,8 +892,7 @@ class WanAny2V:
            return { "x" : videos, "latent_slice" : latent_slice }
        return videos

    def adapt_vace_model(self):
        model = self.model
    def adapt_vace_model(self, model):
        modules_dict= { k: m for k, m in model.named_modules()}
        for model_layer, vace_layer in model.vace_layers_mapping.items():
            module = modules_dict[f"vace_blocks.{vace_layer}"]

@@ -880,4 +901,7 @@ class WanAny2V:
        delattr(model, "vace_blocks")

def query_model_def(model_type, model_def):
    return None
    if "URLs2" in model_def:
        return { "no_steps_skipping":True}
    else:
        return None

@@ -211,7 +211,7 @@ class DTT2V:
        guide_scale: float = 5.0,
        seed: float = 0.0,
        overlap_noise: int = 0,
        ar_step: int = 5,
        model_mode: int = 5,
        causal_block_size: int = 5,
        causal_attention: bool = True,
        fps: int = 24,

@@ -231,7 +231,7 @@ class DTT2V:
        if frame_num > 1:
            frame_num = max(17, frame_num) # must match causal_block_size for value of 5
            frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 )

        ar_step = model_mode
        if ar_step == 0:
            causal_block_size = 1
            causal_attention = False

wgp.py (228 lines changed)

@@ -52,7 +52,7 @@ PROMPT_VARS_MAX = 10

target_mmgp_version = "3.5.1"
WanGP_version = "7.3"
settings_version = 2.22
settings_version = 2.23
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None

@@ -178,6 +178,7 @@ def process_prompt_and_add_tasks(state, model_choice):
        return get_queue_table(queue)
    model_def = get_model_def(model_type)
    image_outputs = model_def.get("image_outputs", False)
    no_steps_skipping = model_def.get("no_steps_skipping", False)
    model_type = get_base_model_type(model_type)
    inputs["model_filename"] = model_filename

@@ -278,8 +279,12 @@ def process_prompt_and_add_tasks(state, model_choice):
    skip_steps_cache_type= inputs["skip_steps_cache_type"]
    MMAudio_setting = inputs["MMAudio_setting"]
    image_mode = inputs["image_mode"]
    switch_threshold = inputs["switch_threshold"]

    if no_steps_skipping: skip_steps_cache_type = ""
    if switch_threshold is not None and switch_threshold != 0 and len(skip_steps_cache_type) > 0:
        gr.Info("Steps skipping is not yet supported if Switch Threshold is not null")
        return
    if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20:
        gr.Info("The minimum number of steps should be 20")
        return

@@ -495,7 +500,8 @@ def process_prompt_and_add_tasks(state, model_choice):
        "denoising_strength": denoising_strength,
        "image_prompt_type": image_prompt_type,
        "video_prompt_type": video_prompt_type,
        "audio_prompt_type": audio_prompt_type,
        "audio_prompt_type": audio_prompt_type,
        "skip_steps_cache_type": skip_steps_cache_type
    }

    if inputs["multi_prompts_gen_type"] == 0:

@@ -1691,7 +1697,7 @@ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion
        print(f"Removing old version of model '{path}'. A new version of this model will be downloaded next time you use it.")
        os.remove( os.path.join("ckpts" , path))

families_infos = {"wan":(0, "Wan2.1"), "ltxv":(1, "LTX Video"), "hunyuan":(2, "Hunyuan Video"), "flux":(3, "Flux 1"), "unknown": (100, "Unknown") }
families_infos = {"wan":(0, "Wan2.1"), "wan2_2":(1, "Wan2.2"), "ltxv":(10, "LTX Video"), "hunyuan":(20, "Hunyuan Video"), "flux":(30, "Flux 1"), "unknown": (100, "Unknown") }

models_def = {}

@@ -1764,15 +1770,22 @@ def get_model_type(model_filename):
    return None
    # raise Exception("Unknown model:" + model_filename)

def get_model_family(model_type):
    model_type = get_base_model_type(model_type)
    if model_type == None:
def get_model_family(model_type, for_ui = False):
    base_model_type = get_base_model_type(model_type)
    if base_model_type is None:
        return "unknown"
    if "hunyuan" in model_type :

    if for_ui :
        model_def = get_model_def(model_type)
        model_family = model_def.get("group", None)
        if model_family is not None and model_family in families_infos:
            return model_family

    if "hunyuan" in base_model_type :
        return "hunyuan"
    elif "ltxv" in model_type:
    elif "ltxv" in base_model_type:
        return "ltxv"
    elif "flux" in model_type:
    elif "flux" in base_model_type:
        return "flux"
    else:
        return "wan"

@@ -1855,17 +1868,19 @@ def get_model_recursive_prop(model_type, prop = "URLs", return_list = True, sta
        raise Exception(f"Unknown model type '{model_type}'")

def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_module = False, stack=[]):
def get_model_filename(model_type, quantization ="int8", dtype_policy = "", is_module = False, submodel_no = 1, stack=[]):
    if is_module:
        choices = modules_files.get(model_type, None)
        if choices == None: raise Exception(f"Invalid Module Id '{model_type}'")
    else:
        key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}"

        model_def = models_def.get(model_type, None)
        if model_def == None: return ""
        URLs = model_def["URLs"]
        URLs = model_def[key_name]
        if isinstance(URLs, str):
            if len(stack) > 10: raise Exception(f"Circular Reference in Model URLs dependencies: {stack}")
            return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, stack = stack + [URLs])
            if len(stack) > 10: raise Exception(f"Circular Reference in Model {key_name} dependencies: {stack}")
            return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, submodel_no = submodel_no, stack = stack + [URLs])
        else:
            choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ]
    if len(quantization) == 0:

@@ -1953,6 +1968,8 @@ def fix_settings(model_type, ui_defaults):
    if model_type in ["hunyuan"]:
        video_prompt_type = video_prompt_type.replace("I", "")

    if model_type in ["flux"] and video_settings_version < 2.23:
        video_prompt_type = video_prompt_type.replace("K", "").replace("I", "KI")

    remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 1)
    if video_settings_version < 2.22:

@@ -2037,7 +2054,7 @@ def get_default_settings(model_type):
            })
        if model_def.get("reference_image", False):
            ui_defaults.update({
                "video_prompt_type": "I",
                "video_prompt_type": "KI",
            })
    elif base_model_type in ["sky_df_1.3B", "sky_df_14B"]:
        ui_defaults.update({

@@ -2124,9 +2141,9 @@ def get_model_query_handler(model_type):
    model_family= get_model_family(base_model_type)
    if model_family == "wan":
        if base_model_type in ("sky_df_1.3B", "sky_df_14B"):
            from wan.any2video import query_model_def
        else:
            from wan.diffusion_forcing import query_model_def
        else:
            from wan.any2video import query_model_def
    elif model_family == "hunyuan":
        from hyvideo.hunyuan import query_model_def
    elif model_family == "ltxv":

@@ -2340,7 +2357,7 @@ def download_mmaudio():
    }
    process_files_def(**enhancer_def)

def download_models(model_filename, model_type):
def download_models(model_filename, model_type, submodel_no = 1):
    def computeList(filename):
        if filename == None:
            return []

@@ -2397,23 +2414,27 @@ def download_models(model_filename, model_type):

    model_family = get_model_family(model_type)
    model_def = get_model_def(model_type)
    if model_def != None and not model_type in modules_files:

    key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}"
    if not model_type in modules_files:
        if not os.path.isfile(model_filename ):
            URLs = get_model_recursive_prop(model_type, "URLs", return_list= False)
            if not isinstance(URLs, str): # dont download anything right now if a base type is referenced as the download will occur just after
                use_url = model_filename
                for url in URLs:
                    if os.path.basename(model_filename) in url:
                        use_url = url
                        break
                if not url.startswith("http"):
                    raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.")
                try:
                    download_file(use_url, model_filename)
                except Exception as e:
                    if os.path.isfile(model_filename): os.remove(model_filename)
                    raise Exception(f"URL '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'")
            model_filename = None
            URLs = get_model_recursive_prop(model_type, key_name, return_list= False)
            if isinstance(URLs, str):
                raise Exception("Missing model " + URLs)
            use_url = model_filename
            for url in URLs:
                if os.path.basename(model_filename) in url:
                    use_url = url
                    break
            if not url.startswith("http"):
                raise Exception(f"Model '{model_filename}' in field '{key_name}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.")
            try:
                download_file(use_url, model_filename)
            except Exception as e:
                if os.path.isfile(model_filename): os.remove(model_filename)
                raise Exception(f"{key_name} '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'")

        model_filename = None

    preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True)
    for url in preload_URLs:

@@ -2609,6 +2630,9 @@ def load_wan_model(model_filename, model_type, base_model_type, model_def, quant
    )

    pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
    if wan_model.model2 is not None:
        pipe["transformer2"] = wan_model.model2

    if hasattr(wan_model, "clip"):
        pipe["text_encoder_2"] = wan_model.clip.model
    return wan_model, pipe

@@ -2689,9 +2713,16 @@ def load_hunyuan_model(model_filename, model_type = None, base_model_type = Non

    return hunyuan_model, pipe

def get_transformer_model(model):
def get_transformer_model(model, submodel_no = 1):
    if submodel_no > 1:
        model_key = f"model{submodel_no}"
        if not hasattr(model, model_key): return None

    if hasattr(model, "model"):
        return model.model
        if submodel_no > 1:
            return getattr(model, f"model{submodel_no}")
        else:
            return model.model
    elif hasattr(model, "transformer"):
        return model.transformer
    else:

@@ -2705,6 +2736,10 @@ def load_models(model_type):
    preload =int(args.preload)
    save_quantized = args.save_quantized and model_def != None
    model_filename = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy)
    if "URLs2" in model_def:
        model_filename2 = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy, submodel_no=2) # !!!!
    else:
        model_filename2 = None
    modules = get_model_recursive_prop(model_type, "modules", return_list= True)
    if save_quantized and "quanto" in model_filename:
        save_quantized = False

@@ -2726,7 +2761,9 @@ def load_models(model_type):
        preload = server_config.get("preload_in_VRAM", 0)
    model_file_list = [model_filename]
    model_type_list = [model_type]
    new_transformer_filename = model_file_list[-1]
    if model_filename2 != None:
        model_file_list += [model_filename2]
        model_type_list += [model_type]
    for module_type in modules:
        model_file_list.append(get_model_filename(module_type, transformer_quantization, transformer_dtype, is_module= True))
        model_type_list.append(module_type)

@@ -2750,14 +2787,22 @@ def load_models(model_type):
    elif model_family == "hunyuan":
        wan_model, pipe = load_hunyuan_model(model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized)
    else:
        raise Exception(f"Model '{new_transformer_filename}' not supported.")
    wan_model._model_file_name = new_transformer_filename
        raise Exception(f"Model '{model_filename}' not supported.")
    kwargs = { "extraModelsToQuantize": None }
    loras_transformer = ["transformer"]
    if profile in (2, 4, 5):
        kwargs["budgets"] = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100 if preload == 0 else preload, "*" : max(1000 if profile==5 else 3000 , preload) }
        budgets = { "transformer" : 100 if preload == 0 else preload, "text_encoder" : 100 if preload == 0 else preload, "*" : max(1000 if profile==5 else 3000 , preload) }
        if "transformer2" in pipe:
            budgets["transformer2"] = 100 if preload == 0 else preload
            loras_transformer += ["transformer2"]
        kwargs["budgets"] = budgets
    elif profile == 3:
        kwargs["budgets"] = { "*" : "70%" }

    if "transformer2" in pipe and profile in [2,4]:
        kwargs["pinnedMemory"] = ["transformer", "transformer2"]

    global prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer
    if server_config.get("enhancer_enabled", 0) == 1:
        from transformers import ( AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlamaForCausalLM )

@@ -2777,21 +2822,21 @@ def load_models(model_type):
        prompt_enhancer_llm_tokenizer = None

    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = "transformer", coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
    if len(args.gpu) > 0:
        torch.set_default_device(args.gpu)
    transformer_type = model_type
    return wan_model, offloadobj, pipe["transformer"]
    return wan_model, offloadobj

if not "P" in preload_model_policy:
    wan_model, offloadobj, transformer = None, None, None
    reload_needed = True
else:
    wan_model, offloadobj, transformer = load_models(transformer_type)
    wan_model, offloadobj = load_models(transformer_type)
    if check_loras:
        transformer = get_transformer_model(wan_model)
        setup_loras(transformer_type, transformer, get_lora_dir(transformer_type), "", None)
        exit()
    del transformer

gen_in_progress = False

@@ -2929,6 +2974,7 @@ def apply_changes( state,
    model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy)
    state["model_filename"] = model_filename
    if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier", "display_stats"] for change in changes ):
        model_family = gr.Dropdown()
        model_choice = gr.Dropdown()
    else:
        reload_needed = True

@@ -3303,7 +3349,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
        values += ["unipc" if len(video_sample_solver) ==0 else video_sample_solver]
        labels += ["Sampler Solver"]
    values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_flow_shift, video_num_inference_steps]
    labels += [ "Resolution", video_length_label, "Seed", video_guidance_label, "Flow Shift", "Num Inference steps"]
    labels += [ "Resolution", video_length_label, "Seed", video_guidance_label, "Shift Scale", "Num Inference steps"]
    video_negative_prompt = configs.get("negative_prompt", "")
    if len(video_negative_prompt) > 0:
        values += [video_negative_prompt]

@@ -3973,6 +4019,8 @@ def generate_video(
    force_fps,
    num_inference_steps,
    guidance_scale,
    guidance2_scale,
    switch_threshold,
    audio_guidance_scale,
    flow_shift,
    sample_solver,

@@ -4094,7 +4142,7 @@ def generate_video(
        offloadobj = None
        gc.collect()
        send_cmd("status", f"Loading model {get_model_name(model_type)}...")
        wan_model, offloadobj, trans = load_models(model_type)
        wan_model, offloadobj = load_models(model_type)
        send_cmd("status", "Model loaded")
        reload_needed= False

@@ -4120,6 +4168,7 @@ def generate_video(
    VAE_tile_size = wan_model.vae.get_VAE_tile_size(vae_config, device_mem_capacity, server_config.get("vae_precision", "16") == "32")

    trans = get_transformer_model(wan_model)
    trans2 = get_transformer_model(wan_model, 2)
    audio_sampling_rate = 16000
    base_model_type = get_base_model_type(model_type)

@@ -4179,11 +4228,15 @@ def generate_video(
        loras_selected = transformer_loras_filenames + loras_selected
        loras_list_mult_choices_nums = transformer_loras_multipliers + loras_list_mult_choices_nums
        loras_slists = transformer_loras_multipliers + loras_slists
    offload.load_loras_into_model(trans, loras_selected, loras_list_mult_choices_nums, activate_all_loras=True, preprocess_sd=get_loras_preprocessor(trans, base_model_type), pinnedLora=pinnedLora, split_linear_modules_map = split_linear_modules_map)
    errors = trans._loras_errors
    if len(errors) > 0:
        error_files = [msg for _ , msg in errors]
        raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
    trans_list = [trans]
    if trans2 is not None: trans_list += [trans2]
    for trans_item in trans_list:
        offload.load_loras_into_model(trans_item, loras_selected, loras_list_mult_choices_nums, activate_all_loras=True, preprocess_sd=get_loras_preprocessor(trans, base_model_type), pinnedLora=pinnedLora, split_linear_modules_map = split_linear_modules_map)
        errors = trans._loras_errors
        if len(errors) > 0:
            error_files = [msg for _ , msg in errors]
            raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
    trans_item = trans_list = None
    seed = None if seed == -1 else seed
    # negative_prompt = "" # not applicable in the inference
    original_filename = model_filename

@@ -4243,7 +4296,7 @@ def generate_video(
    any_background_ref = False
    outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]

    if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace):
    if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace or flux):
        frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions !=None and len(frames_positions)> 0 else []
        frames_positions_list = frames_positions_list[:len(image_refs)]
        nb_frames_positions = len(frames_positions_list)

@@ -4265,11 +4318,13 @@ def generate_video(
            send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
            os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
            from wan.utils.utils import resize_and_remove_background
            image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (vace or hunyuan_avatar) ) # no fit for vace ref images as it is done later
            image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (vace or hunyuan_avatar or flux) ) # no fit for vace ref images as it is done later
            update_task_thumbnails(task, locals())
            send_cmd("output")
    joint_pass = boost ==1 #and profile != 1 and profile != 3
    trans.enable_cache = None if len(skip_steps_cache_type) == 0 else skip_steps_cache_type
    if trans2 is not None:
        trans2.enable_cache = None

    if trans.enable_cache != None:
        trans.cache_multiplier = skip_steps_multiplier

@@ -4660,6 +4715,8 @@ def generate_video(
        sample_solver=sample_solver,
        sampling_steps=num_inference_steps,
        guide_scale=guidance_scale,
        guide2_scale = guidance2_scale,
        switch_threshold = switch_threshold,
        embedded_guidance_scale=embedded_guidance_scale,
        n_prompt=negative_prompt,
        seed=seed,

@@ -4680,7 +4737,7 @@ def generate_video(
        audio_scale= audio_scale,
        audio_context_lens= audio_context_lens,
        context_scale = context_scale,
        ar_step = model_mode, #5
        model_mode = model_mode,
        causal_block_size = 5,
        causal_attention = True,
        fps = fps,

@@ -4698,6 +4755,7 @@ def generate_video(
        NAG_alpha = NAG_alpha,
        speakers_bboxes =speakers_bboxes,
        image_mode = image_mode,
        video_prompt_type= video_prompt_type,
        offloadobj = offloadobj,
    )
    except Exception as e:

@@ -4915,6 +4973,9 @@ def generate_video(
            seed = set_seed(-1)
    clear_status(state)
    offload.unload_loras_from_model(trans)
    if not trans2 is None:
        offload.unload_loras_from_model(trans2)

    if len(control_audio_tracks) > 0:
        cleanup_temp_audio_files(control_audio_tracks)

@@ -5793,6 +5854,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
    ltxv = base_model_type in ["ltxv_13B"]
    recammaster = base_model_type in ["recam_1.3B"]
    phantom = base_model_type in ["phantom_1.3B", "phantom_14B"]
    flux = base_model_type in ["flux"]
    hunyuan_video_custom = base_model_type in ["hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit"]
    model_family = get_model_family(base_model_type)
    if target == "settings":

@@ -5818,7 +5880,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
    if not server_config.get("enhancer_enabled", 0) == 1:
        pop += ["prompt_enhancer"]

    if not recammaster and not diffusion_forcing:
    if not recammaster and not diffusion_forcing and not flux:
        pop += ["model_mode"]

    if not vace and not phantom and not hunyuan_video_custom:

@@ -5837,14 +5899,14 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
    if not base_model_type in ["fantasy", "multitalk", "vace_multitalk_14B"]:
        pop += ["audio_guidance_scale", "speakers_locations"]

    if not model_family in ["hunyuan", "flux"]:
    if not model_family in ["hunyuan", "flux"] or model_def.get("no_guidance", False):
        pop += ["embedded_guidance_scale"]

    if not model_family in ["hunyuan", "wan"]:
        pop += ["skip_steps_cache_type", "skip_steps_multiplier", "skip_steps_start_step_perc"]

    if model_def.get("no_guidance", False) or ltxv:
        pop += ["guidance_scale", "audio_guidance_scale", "embedded_guidance_scale"]
    if model_def.get("no_guidance", False) or ltxv or model_family in ["hunyuan", "flux"] :
        pop += ["guidance_scale", "guidance2_scale", "switch_threshold", "audio_guidance_scale"]

    if model_def.get("image_outputs", False) or ltxv:
        pop += ["flow_shift"]

@@ -6171,6 +6233,8 @@ def save_inputs(
    force_fps,
    num_inference_steps,
    guidance_scale,
    guidance2_scale,
    switch_threshold,
    audio_guidance_scale,
    flow_shift,
    sample_solver,

@@ -6334,7 +6398,7 @@ def change_model(state, model_choice):
    model_filename = get_model_filename(model_choice, transformer_quantization, transformer_dtype_policy)
    state["model_filename"] = model_filename
    last_model_per_family = state["last_model_per_family"]
    last_model_per_family[get_model_family(model_choice)] = model_choice
    last_model_per_family[get_model_family(model_choice, for_ui= True)] = model_choice
    server_config["last_model_per_family"] = last_model_per_family
    server_config["last_model_type"] = model_choice

@@ -6366,7 +6430,7 @@ def preload_model_when_switching(state):
            gc.collect()
            model_filename = get_model_name(model_type)
            yield f"Loading model {model_filename}..."
            wan_model, offloadobj, _ = load_models(model_type)
            wan_model, offloadobj = load_models(model_type)
            yield f"Model loaded"
            reload_needed= False
        return

@@ -6760,6 +6824,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
    diffusion_forcing = "diffusion_forcing" in model_filename
    ltxv = "ltxv" in model_filename
    lock_inference_steps = model_def.get("lock_inference_steps", False)
    model_reference_image = model_def.get("reference_image", False)
    no_steps_skipping = model_def.get("no_steps_skipping", False)
    recammaster = base_model_type in ["recam_1.3B"]
    vace = test_vace_module(base_model_type)
    phantom = base_model_type in ["phantom_1.3B", "phantom_14B"]

@@ -6889,7 +6955,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                keep_frames_video_source = gr.Text(visible=False)
                any_video_source = False

            with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or ltxv or flux and model_def.get("reference_image", False) ) as video_prompt_column:
            with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or ltxv or flux and model_reference_image) as video_prompt_column:
                video_prompt_type_value= ui_defaults.get("video_prompt_type","")
                video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
                any_control_video = True

@@ -6947,7 +7013,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                        ],
                        value=filter_letters(video_prompt_type_value, "PDSLCMUV"),
                        label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= True, show_label= True,
                    )
                    )
                else:
                    any_control_video = False
                    any_control_image = False

@@ -7012,6 +7078,20 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                        visible = True,
                        label="Reference Images", scale = 2
                    )

                elif flux and model_reference_image:
                    video_prompt_type_image_refs = gr.Dropdown(
                        choices=[
                            ("None", ""),
                            ("Inject only People / Objects", "I"),
                            ("Inject Main Subject / Landscape and then People / Objects", "KI"),
                        ],
                        value=filter_letters(video_prompt_type_value, "KFI"),
                        visible = True,
                        show_label=False,
                        label="Reference Images Combination Method", scale = 2
                    )
                else:
                    video_prompt_type_image_refs = gr.Dropdown(
                        choices=[ ("Start / Ref Image", "I")],

@@ -7052,11 +7132,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" )
                remove_background_images_ref = gr.Dropdown(
                    choices=[
                        ("Keep Backgrounds behind People / Objects", 0),
                        ("Remove Backgrounds behind People / Objects", 1),
                        ("Keep Backgrounds behind all Reference Images", 0),
                        ("Remove Backgrounds only behind People / Objects", 1),
                    ],
                    value=ui_defaults.get("remove_background_images_ref",1),
                    label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar and not flux
                    label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar
                )

            any_audio_voices_support = any_audio_track(base_model_type)

@@ -7174,6 +7254,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                    audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale", 5 if fantasy else 4), step=0.5, label="Audio Guidance", visible=(fantasy or multitalk) and not no_guidance)
                    embedded_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("embedded_guidance", 2.5 if flux else 6.0), step=0.5, label="Embedded Guidance Scale", visible=(hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance)
                    flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale", visible = not image_outputs)
                with gr.Row(visible = not ltxv and not (no_guidance and image_outputs)) as guidance_row2:
                    guidance2_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance2_scale",5), step=0.5, label="Guidance2 (CFG)", visible=not (hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance)
                    switch_threshold = gr.Slider(0, 1000, value=ui_defaults.get("switch_threshold",0), step=1, label="Guidance / Model Switch Threshold", visible=not (hunyuan_t2v or hunyuan_i2v or flux) and not no_guidance)

                with gr.Row(visible = get_model_family(model_type) == "wan" and not diffusion_forcing ) as sample_solver_row:
                    sample_solver = gr.Dropdown( value=ui_defaults.get("sample_solver",""),
                        choices=[

@@ -7214,7 +7298,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                        label="Activated Loras"
                    )
                    loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by Space chars or CR, lines that start with # are ignored", value=launch_multis_str)
                with gr.Tab("Steps Skipping", visible = not (ltxv or image_outputs)) as speed_tab:
                with gr.Tab("Steps Skipping", visible = not (ltxv or image_outputs) and not no_steps_skipping) as speed_tab:
                    with gr.Column():
                        gr.Markdown("<B>Tea Cache and Mag Cache accelerate the Video Generation by skipping intelligently some steps, the more steps are skipped the lower the quality of the video.</B>")
                        gr.Markdown("<B>Steps Skipping consumes also VRAM. It is recommended not to skip at least the first 10% steps.</B>")

@@ -7577,7 +7661,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right,
                video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row,
                video_buttons_row, image_buttons_row, video_postprocessing_tab, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn,
                NAG_col, speakers_locations_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, min_frames_if_references_col] # presets_column,
                NAG_col, speakers_locations_row, guidance_row, guidance_row2, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, min_frames_if_references_col] # presets_column,
    if update_form:
        locals_dict = locals()
        gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs

@@ -8265,7 +8349,7 @@ def compact_name(family_name, model_name):
    return model_name

def get_sorted_dropdown(dropdown_types, current_model_family):
    models_families = [get_model_family(type) for type in dropdown_types]
    models_families = [get_model_family(type, for_ui= True) for type in dropdown_types]
    families = {}
    for family in models_families:
        if family not in families: families[family] = 1

@@ -8285,7 +8369,7 @@ def generate_dropdown_model_list(current_model_type):
    dropdown_types= transformer_types if len(transformer_types) > 0 else displayed_model_types
    if current_model_type not in dropdown_types:
        dropdown_types.append(current_model_type)
    current_model_family = get_model_family(current_model_type)
    current_model_family = get_model_family(current_model_type, for_ui= True)
    sorted_familes, dropdown_choices = get_sorted_dropdown(dropdown_types, current_model_family)

    dropdown_families = gr.Dropdown(

@@ -8308,12 +8392,12 @@ def generate_dropdown_model_list(current_model_type):
def change_model_family(state, current_model_family):
    dropdown_types= transformer_types if len(transformer_types) > 0 else displayed_model_types
    current_family_name = families_infos[current_model_family][1]
    models_families = [get_model_family(type) for type in dropdown_types]
    models_families = [get_model_family(type, for_ui= True) for type in dropdown_types]
    dropdown_choices = [ (compact_name(current_family_name, get_model_name(model_type)), model_type) for model_type, family in zip(dropdown_types, models_families) if family == current_model_family ]
    dropdown_choices = sorted(dropdown_choices, key=lambda c: c[0])
    last_model_per_family = state.get("last_model_per_family", {})
    model_type = last_model_per_family.get(current_model_family, "")
    if len(model_type) == "" or model_type not in displayed_model_types: model_type = dropdown_choices[0][1]
    if len(model_type) == "" or model_type not in [choice[1] for choice in dropdown_choices] : model_type = dropdown_choices[0][1]
    return gr.Dropdown(choices= dropdown_choices, value = model_type )

def set_new_tab(tab_state, new_tab_no):