weekend fun

deepbeepmeep 2025-08-30 01:09:03 +02:00
parent ecd219322f
commit 58b99f85ba
25 changed files with 1192 additions and 344 deletions


@ -20,20 +20,28 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
### August 29 2025: WanGP v8.2 - Here Goes Your Weekend
- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*: internally only one image is used per Sliding Window. However, thanks to the new *Smooth Transition* mode, each new clip is connected to the previous one and all the camera work is done by InfiniteTalk. If you don't get any transition, increase the number of frames of a Sliding Window (81 frames recommended)
- **StandIn**: a very light model specialized in Identity Transfer. I have provided two versions of StandIn: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will also be used by StandIn
- **Flux ESO**: a new Flux-derived *Image Editing tool*, but this one is specialized in both *Identity Transfer* and *Style Transfer*. Style has to be understood in its wide meaning: give a reference picture of a person and another one of Sushi and you will turn this person into Sushi
### August 24 2025: WanGP v8.1 - the RAM Liberator
- **Reserved RAM entirely freed when switching models**: you should get far fewer RAM-related out-of-memory errors. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP
- **InfiniteTalk** support: an improved version of Multitalk that supposedly supports very long video generations based on an audio track. It exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesn't seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated with the same audio: each Reference Frame you provide will be associated with a new Sliding Window (see the small sketch after this list). If only one Reference Frame is provided, it will be used for all windows. When continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\
If you are not into audio, you can still use this model to generate infinitely long image2video: just select "no speaker". Last but not least, InfiniteTalk works with all the Loras accelerators.
- **Flux Chroma 1 HD** support: an uncensored Flux-based model that is lighter than Flux (8.9B versus 12B) and can fit entirely in VRAM with only 16 GB of VRAM. Unfortunately it is not distilled, so you will need CFG and at least 20 steps
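
A quick sketch (editor's illustration; the helper name is hypothetical, not repo code) of the shot-per-window rule described in the InfiniteTalk note above:

# one Reference Frame per Sliding Window; window_no is 1-based, and once the provided
# frames run out, the last one keeps being reused for the remaining windows
def reference_frame_for_window(window_no, reference_frames):
    return reference_frames[min(window_no, len(reference_frames)) - 1]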
### August 21 2025: WanGP v8.01 - the killer of seven
- **Qwen Image Edit**: a Flux Kontext challenger (prompt-driven image editing). Best results (including identity preservation) will be obtained at 720p; beyond that you may get image outpainting and/or lose identity preservation, and below 720p prompt adherence will be worse. Qwen Image Edit works with the Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions, but identity preservation won't be as good.
- **On-demand Prompt Enhancer** (needs to be enabled in the Configuration Tab) that you can use to enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt.
- Choice of a **Non-censored Prompt Enhancer**. Beware: this one is VRAM hungry and will require 12 GB of VRAM to work
- **Memory Profile customizable per model**: useful, for instance, to set Profile 3 (preload the model entirely in VRAM) only for Image Generation models if you have 24 GB of VRAM. In that case generation will be much faster, because with Image generators (contrary to Video generators) a lot of time is otherwise wasted in offloading
- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Lightning to reduce the slow-motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration (see the configuration sketch after these notes). Please note that it is an 8-step process although the Lora Lightning is 4 steps. This expert guidance mode is also available with Wan 2.1.
*WanGP 8.01 update, improved Qwen Image Edit Identity Preservation*
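
For reference, the guidance phases mentioned above are driven by plain keys in the finetune defaults. The snippet below is an editor's illustration assembled from the `guidance_phases`, `switch_threshold`, `guidance_scale` and `guidance2_scale` values that appear in the defaults files changed by this commit (a 2-phase example; the 3-phase Lightning finetune extends the same pattern):

{
    "guidance_phases": 2,
    "switch_threshold": 900,
    "guidance_scale": 3.5,
    "guidance2_scale": 3.5
}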
### August 12 2025: WanGP v7.7777 - Lucky Day(s)

configs/standin.json (new file, 15 lines)

@ -0,0 +1,15 @@
{
"_class_name": "WanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 16,
"model_type": "t2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512,
"standin": true
}


@ -0,0 +1,17 @@
{
"_class_name": "VaceWanModel",
"_diffusers_version": "0.30.0",
"dim": 5120,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_dim": 16,
"model_type": "t2v",
"num_heads": 40,
"num_layers": 40,
"out_dim": 16,
"text_len": 512,
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
"vace_in_dim": 96,
"standin": true
}


@ -16,6 +16,7 @@
],
"group": "wan2_2"
},
"guidance_phases": 2,
"switch_threshold" : 900,
"guidance_scale" : 3.5,
"guidance2_scale" : 3.5,


@ -9,7 +9,7 @@
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_infinitetalk_single_14B_quanto_mfp16_int8.safetensors"
]
],
"description": "The Infinitetalk model is an improved version of Multitalk that supports very long videos. This is the single speaker version.",
"description": "The Infinitetalk model is an improved version of Multitalk that supports very long videos. This is the single speaker version. Sliding Window size must be 81 frames to get smooth transitions between shots.",
"one_speaker_only": true,
"URLs": "i2v"
}


@ -9,7 +9,7 @@
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_infinitetalk_multi_14B_quanto_mbf16_int8.safetensors"
]
],
"description": "The Infinitetalk model is an improved version of Multitalk that supports very long videos. This is the multi speakers version.",
"description": "The Infinitetalk model is an improved version of Multitalk that supports very long videos. This is the multi speakers version.Sliding Window size must be 81 frames to get smooth transitions between shots",
"multi_speakers_only": true,
"URLs": "i2v"
}

defaults/standin.json (new file, 10 lines)

@ -0,0 +1,10 @@
{
"model":
{
"name": "Wan2.1 Standin 14B",
"modules": [ ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main//Stand-In_wan2.1_T2V_14B_ver1.0_bf16.safetensors"]],
"architecture" : "standin",
"description": "The original Wan Text 2 Video model combined with the StandIn module to improve Identity Preservation. You need to provide a Reference Image which is a close up of person face to transfer this person in the Video.",
"URLs": "t2v"
}
}


@ -16,6 +16,7 @@
],
"group": "wan2_2"
},
"guidance_phases": 2,
"switch_threshold" : 875,
"guidance_scale" : 4,
"guidance2_scale" : 3,


@ -17,6 +17,7 @@
"loras_multipliers": [1, 0.2, 0.5, 0.5],
"group": "wan2_2"
},
"guidance_phases": 2,
"num_inference_steps": 10,
"guidance_scale": 1,
"guidance2_scale": 1,


@ -0,0 +1,9 @@
{
"model": {
"name": "Vace Standin 14B",
"architecture": "vace_standin_14B",
"modules": [ "vace_14B", "standin"],
"description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video.",
"URLs": "t2v"
}
}


@ -12,6 +12,7 @@ class family_handler():
flux_model = model_def.get("flux-model", "flux-dev")
flux_schnell = flux_model == "flux-schnell"
flux_chroma = flux_model == "flux-chroma"
flux_uso = flux_model == "flux-dev-uso"
model_def_output = {
"image_outputs" : True,
"no_negative_prompt" : not flux_chroma,
@ -20,7 +21,17 @@ class family_handler():
model_def_output["guidance_max_phases"] = 1
elif not flux_schnell:
model_def_output["embedded_guidance"] = True
if flux_uso :
model_def_output["any_image_refs_relative_size"] = True
model_def_output["no_background_removal"] = True
model_def_output["image_ref_choices"] = {
"choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "I"),
("Up to two Images are Style Images", "IJ")],
"default": "I",
"letters_filter": "IJ",
"label": "Reference Images / Style Images"
}
return model_def_output
@ -53,8 +64,8 @@ class family_handler():
return [
{
"repoId" : "DeepBeepMeep/Flux",
"sourceFolderList" : [""],
"fileList" : [ ["flux_vae.safetensors"] ]
"sourceFolderList" : ["siglip-so400m-patch14-384", "",],
"fileList" : [ ["config.json", "preprocessor_config.json", "model.safetensors"], ["flux_vae.safetensors"] ]
},
{
"repoId" : "DeepBeepMeep/LTX_Video",
@ -90,15 +101,21 @@ class family_handler():
pipe = { "transformer": flux_model.model, "vae" : flux_model.vae, "text_encoder" : flux_model.clip, "text_encoder_2" : flux_model.t5}
if flux_model.vision_encoder is not None:
pipe["siglip_model"] = flux_model.vision_encoder
if flux_model.feature_embedder is not None:
pipe["feature_embedder"] = flux_model.feature_embedder
return flux_model, pipe
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
flux_model = model_def.get("flux-model", "flux-dev")
flux_uso = flux_model == "flux-dev-uso"
ui_defaults.update({
"embedded_guidance": 2.5,
})
if model_def.get("reference_image", False):
ui_defaults.update({
"video_prompt_type": "KI",
"video_prompt_type": "I" if flux_uso else "KI",
})


@ -6,8 +6,10 @@ from glob import iglob
from mmgp import offload as offload
import torch
from shared.utils.utils import calculate_new_dimensions
from .sampling import denoise, get_schedule, prepare_kontext, unpack
from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack
from .modules.layers import get_linear_split_map
from transformers import SiglipVisionModel, SiglipImageProcessor
from .util import (
aspect_ratio_to_height_width,
load_ae,
@ -65,6 +67,26 @@ class model_factory:
self.vae = load_ae(self.name, device=torch_device)
siglip_processor = siglip_model = feature_embedder = None
if self.name == 'flux-dev-uso':
siglip_path = "ckpts/siglip-so400m-patch14-384"
siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path)
siglip_model = SiglipVisionModel.from_pretrained(siglip_path)
siglip_model.eval().to("cpu")
if len(model_filename) > 1:
from .modules.layers import SigLIPMultiFeatProjModel
feature_embedder = SigLIPMultiFeatProjModel(
siglip_token_nums=729,
style_token_nums=64,
siglip_token_dims=1152,
hidden_size=3072, #self.hidden_size,
context_layer_norm=True,
)
offload.load_model_data(feature_embedder, model_filename[1])
self.vision_encoder = siglip_model
self.vision_encoder_processor = siglip_processor
self.feature_embedder = feature_embedder
# offload.change_dtype(self.model, dtype, True)
# offload.save_model(self.model, "flux-dev.safetensors")
@ -97,22 +119,30 @@ class model_factory:
loras_slists = None,
batch_size = 1,
video_prompt_type = "",
joint_pass = False,
joint_pass = False,
image_refs_relative_size = 100,
**bbargs
):
if self._interrupt:
return None
if self.guidance_max_phases < 1:
guide_scale = 1
if n_prompt is None or len(n_prompt) == 0:
n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
if self.guidance_max_phases < 1: guide_scale = 1
if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
device="cuda"
if "I" in video_prompt_type and input_ref_images != None and len(input_ref_images) > 0:
if "K" in video_prompt_type and False :
# image latents tiling method
w, h = input_ref_images[0].size
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
else:
flux_dev_uso = self.name in ['flux-dev-uso']
image_stiching = not self.name in ['flux-dev-uso']
input_ref_images = [] if input_ref_images is None else input_ref_images[:]
ref_style_imgs = []
if "I" in video_prompt_type and len(input_ref_images) > 0:
if flux_dev_uso :
if "J" in video_prompt_type:
ref_style_imgs = input_ref_images
input_ref_images = []
elif len(input_ref_images) > 1 :
ref_style_imgs = input_ref_images[-1:]
input_ref_images = input_ref_images[:-1]
if image_stiching:
# image stiching method
stiched = input_ref_images[0]
if "K" in video_prompt_type :
@ -122,27 +152,61 @@ class model_factory:
for new_img in input_ref_images[1:]:
stiched = stitch_images(stiched, new_img)
input_ref_images = [stiched]
else:
first_ref = 0
if "K" in video_prompt_type:
# image latents tiling method
w, h = input_ref_images[0].size
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS)
first_ref = 1
for i in range(first_ref,len(input_ref_images)):
w, h = input_ref_images[i].size
image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
else:
input_ref_images = None
inp, height, width = prepare_kontext(
t5=self.t5,
clip=self.clip,
prompt=input_prompt,
neg_prompt= n_prompt,
ae=self.vae,
img_cond_list=input_ref_images,
target_width=width,
target_height=height,
bs=batch_size,
seed=seed,
device=device,
real_guidance_scale=guide_scale,
)
if flux_dev_uso :
inp, height, width = prepare_multi_ip(
ae=self.vae,
img_cond_list=input_ref_images,
target_width=width,
target_height=height,
bs=batch_size,
seed=seed,
device=device,
)
else:
inp, height, width = prepare_kontext(
ae=self.vae,
img_cond_list=input_ref_images,
target_width=width,
target_height=height,
bs=batch_size,
seed=seed,
device=device,
)
inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt))
if guide_scale != 1:
inp.update(prepare_prompt(self.t5, self.clip, batch_size, n_prompt, neg = True, device=device))
timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
ref_style_imgs = [self.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs]
if self.feature_embedder is not None and ref_style_imgs is not None and len(ref_style_imgs) > 0 and self.vision_encoder is not None:
# process style features into the textual hidden space
siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in ref_style_imgs]
siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1)
siglip_embedding_ids = torch.zeros( siglip_embedding.shape[0], siglip_embedding.shape[1], 3 ).to(device)
inp["siglip_embedding"] = siglip_embedding
inp["siglip_embedding_ids"] = siglip_embedding_ids
def unpack_latent(x):
return unpack(x.float(), height, width)
# denoise initial noise
x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass)
if x==None: return None


@ -12,6 +12,7 @@ from .modules.layers import (
timestep_embedding,
DistilledGuidance,
ChromaModulationOut,
SigLIPMultiFeatProjModel,
)
from .modules.lora import LinearLora, replace_linear_with_lora
@ -32,7 +33,7 @@ class FluxParams:
qkv_bias: bool
guidance_embed: bool
chroma: bool = False
eso: bool = False
class Flux(nn.Module):
"""
@ -189,6 +190,8 @@ class Flux(nn.Module):
v = swap_scale_shift(v)
k = k.replace("norm_out.linear", "final_layer.adaLN_modulation.1")
new_sd[k] = v
else:
new_sd = sd
return new_sd
def forward(
@ -198,11 +201,13 @@ class Flux(nn.Module):
txt_list,
txt_ids_list,
timesteps: Tensor,
y: Tensor,
y_list,
img_len = 0,
guidance: Tensor | None = None,
callback= None,
pipeline =None,
siglip_embedding = None,
siglip_embedding_ids = None,
) -> Tensor:
sz = len(txt_list)
@ -226,40 +231,44 @@ class Flux(nn.Module):
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec += self.guidance_in(timestep_embedding(guidance, 256))
vec += self.vector_in(y)
vec_list = [ vec + self.vector_in(y) for y in y_list]
img = None
txt_list = [self.txt_in(txt) for txt in txt_list ]
if siglip_embedding is not None:
txt_list = [torch.cat((siglip_embedding, txt) , dim=1) for txt in txt_list]
txt_ids_list = [torch.cat((siglip_embedding_ids, txt_id) , dim=1) for txt_id in txt_ids_list]
pe_list = [self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1)) for txt_ids in txt_ids_list]
for i, block in enumerate(self.double_blocks):
if self.chroma: vec = ( self.get_modulations(mod_vectors, "double_img", idx=i), self.get_modulations(mod_vectors, "double_txt", idx=i))
if self.chroma: vec_list = [( self.get_modulations(mod_vectors, "double_img", idx=i), self.get_modulations(mod_vectors, "double_txt", idx=i))] * sz
if callback != None:
callback(-1, None, False, True)
if pipeline._interrupt:
return [None] * sz
for img, txt, pe in zip(img_list, txt_list, pe_list):
for img, txt, pe, vec in zip(img_list, txt_list, pe_list, vec_list):
img[...], txt[...] = block(img=img, txt=txt, vec=vec, pe=pe)
img = txt = pe = None
img = txt = pe = vec= None
img_list = [torch.cat((txt, img), 1) for txt, img in zip(txt_list, img_list)]
for i, block in enumerate(self.single_blocks):
if self.chroma: vec = self.get_modulations(mod_vectors, "single", idx=i)
if self.chroma: vec_list= [self.get_modulations(mod_vectors, "single", idx=i)] * sz
if callback != None:
callback(-1, None, False, True)
if pipeline._interrupt:
return [None] * sz
for img, pe in zip(img_list, pe_list):
for img, pe, vec in zip(img_list, pe_list, vec_list):
img[...]= block(x=img, vec=vec, pe=pe)
img = pe = None
img_list = [ img[:, txt.shape[1] :, ...] for img, txt in zip(img_list, txt_list)]
img = pe = vec = None
img_list = [ img[:, txt.shape[1] : txt.shape[1] + img_len, ...] for img, txt in zip(img_list, txt_list)]
if self.chroma: vec = self.get_modulations(mod_vectors, "final")
if self.chroma: vec_list = [self.get_modulations(mod_vectors, "final")] * sz
out_list = []
for i, img in enumerate(img_list):
for i, (img, vec) in enumerate(zip(img_list, vec_list)):
out_list.append( self.final_layer(img, vec)) # (N, T, patch_size ** 2 * out_channels)
img_list[i] = img = None
img_list[i] = img = vec = None
return out_list


@ -11,7 +11,8 @@ def get_linear_split_map():
hidden_size = 3072
split_linear_modules_map = {
"qkv" : {"mapped_modules" : ["q", "k", "v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]},
"linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]}
"linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]},
"linear1_qkv" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]},
}
return split_linear_modules_map
@ -375,4 +376,134 @@ class DistilledGuidance(nn.Module):
x = self.out_proj(x)
return x
return x
class SigLIPMultiFeatProjModel(torch.nn.Module):
"""
SigLIP Multi-Feature Projection Model for processing style features from different layers
and projecting them into a unified hidden space.
Args:
siglip_token_nums (int): Number of SigLIP tokens, default 257
style_token_nums (int): Number of style tokens, default 256
siglip_token_dims (int): Dimension of SigLIP tokens, default 1536
hidden_size (int): Hidden layer size, default 3072
context_layer_norm (bool): Whether to use context layer normalization, default False
"""
def __init__(
self,
siglip_token_nums: int = 257,
style_token_nums: int = 256,
siglip_token_dims: int = 1536,
hidden_size: int = 3072,
context_layer_norm: bool = False,
):
super().__init__()
# High-level feature processing (layer -2)
self.high_embedding_linear = nn.Sequential(
nn.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.high_layer_norm = (
nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.high_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True)
# Mid-level feature processing (layer -11)
self.mid_embedding_linear = nn.Sequential(
nn.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.mid_layer_norm = (
nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.mid_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True)
# Low-level feature processing (layer -20)
self.low_embedding_linear = nn.Sequential(
nn.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.low_layer_norm = (
nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.low_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True)
def forward(self, siglip_outputs):
"""
Forward pass function
Args:
siglip_outputs: Output from SigLIP model, containing hidden_states
Returns:
torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size]
"""
dtype = next(self.high_embedding_linear.parameters()).dtype
# Process high-level features (layer -2)
high_embedding = self._process_layer_features(
siglip_outputs.hidden_states[-2],
self.high_embedding_linear,
self.high_layer_norm,
self.high_projection,
dtype
)
# Process mid-level features (layer -11)
mid_embedding = self._process_layer_features(
siglip_outputs.hidden_states[-11],
self.mid_embedding_linear,
self.mid_layer_norm,
self.mid_projection,
dtype
)
# Process low-level features (layer -20)
low_embedding = self._process_layer_features(
siglip_outputs.hidden_states[-20],
self.low_embedding_linear,
self.low_layer_norm,
self.low_projection,
dtype
)
# Concatenate features from all layers
return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1)
def _process_layer_features(
self,
hidden_states: torch.Tensor,
embedding_linear: nn.Module,
layer_norm: nn.Module,
projection: nn.Module,
dtype: torch.dtype
) -> torch.Tensor:
"""
Helper function to process features from a single layer
Args:
hidden_states: Input hidden states [bs, seq_len, dim]
embedding_linear: Embedding linear layer
layer_norm: Layer normalization
projection: Projection layer
dtype: Target data type
Returns:
torch.Tensor: Processed features [bs, style_token_nums, hidden_size]
"""
# Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim]
embedding = embedding_linear(
hidden_states.to(dtype).transpose(1, 2)
).transpose(1, 2)
# Apply layer normalization
embedding = layer_norm(embedding)
# Project to target hidden space
embedding = projection(embedding)
return embedding
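
A minimal usage sketch (editor's illustration, not part of this diff) of the projector above, using the same values the Flux model factory passes in (729 SigLIP tokens of dimension 1152, 64 style tokens, hidden size 3072); the dummy SigLIP output below is an assumption for illustration only.

import torch
from types import SimpleNamespace

proj = SigLIPMultiFeatProjModel(siglip_token_nums=729, style_token_nums=64,
                                siglip_token_dims=1152, hidden_size=3072,
                                context_layer_norm=True)
# stand-in for SiglipVisionModel(..., output_hidden_states=True): a list of per-layer
# hidden states shaped [batch, siglip_token_nums, siglip_token_dims]
fake_siglip_out = SimpleNamespace(hidden_states=[torch.randn(1, 729, 1152) for _ in range(27)])
style_tokens = proj(fake_siglip_out)  # -> [1, 3 * 64, 3072], later concatenated with the text tokens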


@ -13,6 +13,8 @@ from .modules.conditioner import HFEmbedder
from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder
from .util import PREFERED_KONTEXT_RESOLUTIONS
from einops import rearrange, repeat
from typing import Literal
import torchvision.transforms.functional as TVF
def get_noise(
@ -35,20 +37,10 @@ def get_noise(
)
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
bs, c, h, w = img.shape
def prepare_prompt(t5: HFEmbedder, clip: HFEmbedder, bs: int, prompt: str | list[str], neg: bool = False, device: str = "cuda") -> dict[str, Tensor]:
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
if isinstance(prompt, str):
prompt = [prompt]
txt = t5(prompt)
@ -61,102 +53,31 @@ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[st
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
return {
"img": img,
"img_ids": img_ids.to(img.device),
"txt": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
"vec": vec.to(img.device),
"neg_txt" if neg else "txt": txt.to(device),
"neg_txt_ids" if neg else "txt_ids": txt_ids.to(device),
"neg_vec" if neg else "vec": vec.to(device),
}
def prepare_control(
t5: HFEmbedder,
clip: HFEmbedder,
img: Tensor,
prompt: str | list[str],
ae: AutoEncoder,
encoder: DepthImageEncoder | CannyImageEncoder,
img_cond_path: str,
) -> dict[str, Tensor]:
# load and encode the conditioning image
bs, _, h, w = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
def prepare_img( img: Tensor) -> dict[str, Tensor]:
bs, c, h, w = img.shape
img_cond = Image.open(img_cond_path).convert("RGB")
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
width = w * 8
height = h * 8
img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS)
img_cond = np.array(img_cond)
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
with torch.no_grad():
img_cond = encoder(img_cond)
img_cond = ae.encode(img_cond)
img_cond = img_cond.to(torch.bfloat16)
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img_cond.shape[0] == 1 and bs > 1:
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
return_dict = prepare(t5, clip, img, prompt)
return_dict["img_cond"] = img_cond
return return_dict
return {
"img": img,
"img_ids": img_ids.to(img.device),
}
def prepare_fill(
t5: HFEmbedder,
clip: HFEmbedder,
img: Tensor,
prompt: str | list[str],
ae: AutoEncoder,
img_cond_path: str,
mask_path: str,
) -> dict[str, Tensor]:
# load and encode the conditioning image and the mask
bs, _, _, _ = img.shape
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img_cond = Image.open(img_cond_path).convert("RGB")
img_cond = np.array(img_cond)
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
mask = Image.open(mask_path).convert("L")
mask = np.array(mask)
mask = torch.from_numpy(mask).float() / 255.0
mask = rearrange(mask, "h w -> 1 1 h w")
with torch.no_grad():
img_cond = img_cond.to(img.device)
mask = mask.to(img.device)
img_cond = img_cond * (1 - mask)
img_cond = ae.encode(img_cond)
mask = mask[:, 0, :, :]
mask = mask.to(torch.bfloat16)
mask = rearrange(
mask,
"b (h ph) (w pw) -> b (ph pw) h w",
ph=8,
pw=8,
)
mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if mask.shape[0] == 1 and bs > 1:
mask = repeat(mask, "1 ... -> bs ...", bs=bs)
img_cond = img_cond.to(torch.bfloat16)
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img_cond.shape[0] == 1 and bs > 1:
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
img_cond = torch.cat((img_cond, mask), dim=-1)
return_dict = prepare(t5, clip, img, prompt)
return_dict["img_cond"] = img_cond.to(img.device)
return return_dict
def prepare_redux(
@ -210,9 +131,6 @@ def prepare_redux(
def prepare_kontext(
t5: HFEmbedder,
clip: HFEmbedder,
prompt: str | list[str],
ae: AutoEncoder,
img_cond_list: list,
seed: int,
@ -220,17 +138,15 @@ def prepare_kontext(
target_width: int | None = None,
target_height: int | None = None,
bs: int = 1,
neg_prompt: str | list[str] = None,
real_guidance_scale = False,
) -> tuple[dict[str, Tensor], int, int]:
# load and encode the conditioning image
if bs == 1 and not isinstance(prompt, str):
bs = len(prompt)
img_cond_seq = None
img_cond_seq_ids = None
if img_cond_list == None: img_cond_list = []
height_offset = 0
width_offset = 0
for cond_no, img_cond in enumerate(img_cond_list):
width, height = img_cond.size
aspect_ratio = width / height
@ -257,10 +173,12 @@ def prepare_kontext(
# image ids are the same as base image with the first dimension set to 1
# instead of 0
img_cond_ids = torch.zeros(height // 2, width // 2, 3)
img_cond_ids[..., 0] = cond_no + 1
img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None]
img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :]
img_cond_ids[..., 0] = 1
img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None] + height_offset
img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :] + width_offset
img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs)
height_offset += height // 2
width_offset += width // 2
if target_width is None:
target_width = 8 * width
@ -272,6 +190,10 @@ def prepare_kontext(
else:
img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, img_cond_latents], dim=1), torch.cat([img_cond_seq_ids, img_cond_ids], dim=1)
return_dict = {
"img_cond_seq": img_cond_seq,
"img_cond_seq_ids": img_cond_seq_ids,
}
img = get_noise(
bs,
target_height,
@ -280,17 +202,7 @@ def prepare_kontext(
dtype=torch.bfloat16,
seed=seed,
)
return_dict = prepare(t5, clip, img, prompt)
if real_guidance_scale != 1:
temp_dict = prepare(t5, clip, img, neg_prompt)
return_dict["neg_txt"]= temp_dict["txt"]
return_dict["neg_txt_ids"]= temp_dict["txt_ids"]
else:
return_dict["neg_txt"]= None
return_dict["neg_txt_ids"]= None
return_dict["img_cond_seq"] = img_cond_seq
return_dict["img_cond_seq_ids"] = img_cond_seq_ids
return_dict.update(prepare_img(img))
return return_dict, target_height, target_width
@ -341,10 +253,13 @@ def denoise(
# extra img tokens (channel-wise)
neg_txt: Tensor = None,
neg_txt_ids: Tensor= None,
neg_vec: Tensor = None,
img_cond: Tensor | None = None,
# extra img tokens (sequence-wise)
img_cond_seq: Tensor | None = None,
img_cond_seq_ids: Tensor | None = None,
siglip_embedding = None,
siglip_embedding_ids = None,
callback=None,
pipeline=None,
loras_slists=None,
@ -352,7 +267,8 @@ def denoise(
joint_pass= False,
):
kwargs = {'pipeline': pipeline, 'callback': callback}
kwargs = {'pipeline': pipeline, 'callback': callback, "img_len" : img.shape[1], "siglip_embedding": siglip_embedding, "siglip_embedding_ids": siglip_embedding_ids}
if callback != None:
callback(-1, None, True)
@ -383,7 +299,7 @@ def denoise(
img_ids=img_input_ids,
txt_list=[txt],
txt_ids_list=[txt_ids],
y=vec,
y_list=[vec],
timesteps=t_vec,
guidance=guidance_vec,
**kwargs
@ -395,7 +311,7 @@ def denoise(
img_ids=img_input_ids,
txt_list=[neg_txt],
txt_ids_list=[neg_txt_ids],
y=vec,
y_list=[neg_vec],
timesteps=t_vec,
guidance=guidance_vec,
**kwargs
@ -407,19 +323,14 @@ def denoise(
img_ids=img_input_ids,
txt_list=[txt, neg_txt],
txt_ids_list=[txt_ids, neg_txt_ids],
y=vec,
y_list=[vec, neg_vec],
timesteps=t_vec,
guidance=guidance_vec,
**kwargs
)
if pred == None: return None
if img_input_ids is not None:
pred = pred[:, : img.shape[1]]
if real_guidance_scale > 1:
if img_input_ids is not None:
neg_pred = neg_pred[:, : img.shape[1]]
pred = neg_pred + real_guidance_scale * (pred - neg_pred)
img += (t_prev - t_curr) * pred
@ -430,6 +341,82 @@ def denoise(
return img
def prepare_multi_ip(
ae: AutoEncoder,
img_cond_list: list,
seed: int,
device: torch.device,
target_width: int | None = None,
target_height: int | None = None,
bs: int = 1,
pe: Literal["d", "h", "w", "o"] = "d",
) -> dict[str, Tensor]:
ref_imgs = img_cond_list
assert pe in ["d", "h", "w", "o"]
ref_imgs = [
ae.encode(
(TVF.to_tensor(ref_img) * 2.0 - 1.0)
.unsqueeze(0)
.to(device, torch.float32)
).to(torch.bfloat16)
for ref_img in img_cond_list
]
img = get_noise( bs, target_height, target_width, device=device, dtype=torch.bfloat16, seed=seed)
bs, c, h, w = img.shape
# tgt img
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
img_cond_seq = img_cond_seq_ids = None
pe_shift_w, pe_shift_h = w // 2, h // 2
for cond_no, ref_img in enumerate(ref_imgs):
_, _, ref_h1, ref_w1 = ref_img.shape
ref_img = rearrange(
ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
)
if ref_img.shape[0] == 1 and bs > 1:
ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
# offset the ref image ids in height and width by the current maximum shifts (pe_shift_h / pe_shift_w)
h_offset = pe_shift_h if pe in {"d", "h"} else 0
w_offset = pe_shift_w if pe in {"d", "w"} else 0
ref_img_ids1[..., 1] = (
ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
)
ref_img_ids1[..., 2] = (
ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
)
ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
if target_width is None:
target_width = 8 * ref_w1
if target_height is None:
target_height = 8 * ref_h1
ref_img_ids1 = ref_img_ids1.to(device)
if cond_no == 0:
img_cond_seq, img_cond_seq_ids = ref_img, ref_img_ids1
else:
img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, ref_img], dim=1), torch.cat([img_cond_seq_ids, ref_img_ids1], dim=1)
# update the pe shifts
pe_shift_h += ref_h1 // 2
pe_shift_w += ref_w1 // 2
return {
"img": img,
"img_ids": img_ids.to(img.device),
"img_cond_seq": img_cond_seq,
"img_cond_seq_ids": img_cond_seq_ids,
}, target_height, target_width
def unpack(x: Tensor, height: int, width: int) -> Tensor:
return rearrange(


@ -608,6 +608,38 @@ configs = {
shift_factor=0.1159,
),
),
"flux-dev-uso": ModelSpec(
repo_id="",
repo_flow="",
repo_ae="ckpts/flux_vae.safetensors",
params=FluxParams(
in_channels=64,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
eso= True,
),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
}


@ -249,7 +249,7 @@ class DetFace():
for scale in [8,16,32]:
ny = h1//scale
nx = w1//scale
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij")
grid = torch.stack((xv, yv), 2).view((1,1,ny, nx, 2)).float()
grids.append(grid.to(self.test_device))
self.grids = grids


@ -19,7 +19,7 @@ from PIL import Image
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from .distributed.fsdp import shard_model
from .modules.model import WanModel
from .modules.model import WanModel, clear_caches
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .modules.vae2_2 import Wan2_2_VAE
@ -28,7 +28,7 @@ from .modules.clip import CLIPModel
from shared.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
get_sampling_sigmas, retrieve_timesteps)
from shared.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from .modules.posemb_layers import get_rotary_pos_embed
from .modules.posemb_layers import get_rotary_pos_embed, get_nd_rotary_pos_embed
from shared.utils.vace_preprocessor import VaceVideoProcessor
from shared.utils.basic_flowmatch import FlowMatchScheduler
from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor
@ -442,6 +442,8 @@ class WanAny2V:
window_no = 0,
set_header_text = None,
pre_video_frame = None,
video_prompt_type= "",
original_input_ref_images = [],
**bbargs
):
@ -507,11 +509,12 @@ class WanAny2V:
# if NAG_scale > 1: context = torch.cat([context, context_NAG], dim=0)
if self._interrupt: return None
vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B"]
vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B", "vace_standin_14B"]
phantom = model_type in ["phantom_1.3B", "phantom_14B"]
fantasy = model_type in ["fantasy"]
multitalk = model_type in ["multitalk", "infinitetalk", "vace_multitalk_14B", "i2v_2_2_multitalk"]
infinitetalk = model_type in ["infinitetalk"]
standin = model_type in ["standin", "vace_standin_14B"]
recam = model_type in ["recam_1.3B"]
ti2v = model_type in ["ti2v_2_2"]
start_step_no = 0
@ -526,17 +529,25 @@ class WanAny2V:
any_end_frame = False
if image_start is None:
if infinitetalk:
if pre_video_frame is None:
new_shot = True
if input_frames is not None:
image_ref = input_frames[:, -1]
if input_video is None: input_video = input_frames[:, -1:]
new_shot = "Q" in video_prompt_type
else:
if input_ref_images is None:
input_ref_images, new_shot = [pre_video_frame], False
if pre_video_frame is None:
new_shot = True
else:
input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], True
if input_ref_images is None: raise Exception("Missing Reference Image")
image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
if new_shot and window_no <= len(input_ref_images):
if input_ref_images is None:
input_ref_images, new_shot = [pre_video_frame], False
else:
input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], "Q" in video_prompt_type
if input_ref_images is None: raise Exception("Missing Reference Image")
new_shot = new_shot and window_no <= len(input_ref_images)
image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
if new_shot:
input_video = image_ref.unsqueeze(1)
else:
color_correction_strength = 0 #disable color correction as transition frames between shots may have a completely different color level than the colors of the new shot
_ , preframes_count, height, width = input_video.shape
input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype)
if infinitetalk:
@ -762,6 +773,25 @@ class WanAny2V:
kwargs["freqs"] = freqs
#Standin
if standin:
from preprocessing.face_preprocessor import FaceProcessor
standin_ref_pos = 1 if "K" in video_prompt_type else 0
if len(original_input_ref_images) < standin_ref_pos + 1: raise Exception("Missing Standin ref image")
standin_ref_pos = -1
image_ref = original_input_ref_images[standin_ref_pos]
image_ref.save("si.png")
# face_processor = FaceProcessor(antelopv2_path="ckpts/antelopev2")
face_processor = FaceProcessor()
standin_ref = face_processor.process(image_ref, remove_bg = model_type in ["vace_standin_14B"])
face_processor = None
gc.collect()
torch.cuda.empty_cache()
standin_freqs = get_nd_rotary_pos_embed((-1, int(target_shape[-2]/2), int(target_shape[-1]/2) ), (-1, int(target_shape[-2]/2 + standin_ref.height/16), int(target_shape[-1]/2 + standin_ref.width/16) ))
standin_ref = self.vae.encode([ convert_image_to_tensor(standin_ref).unsqueeze(1) ], VAE_tile_size)[0].unsqueeze(0)
kwargs.update({ "standin_freqs": standin_freqs, "standin_ref": standin_ref, })
# Steps Skipping
skip_steps_cache = self.model.cache
if skip_steps_cache != None:
@ -804,6 +834,11 @@ class WanAny2V:
if self.model2 is not None: update_loras_slists(self.model2, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
callback(-1, None, True, override_num_inference_steps = updated_num_steps, denoising_extra = denoising_extra)
def clear():
clear_caches()
gc.collect()
torch.cuda.empty_cache()
return None
if sample_scheduler != None:
scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g}
@ -862,6 +897,7 @@ class WanAny2V:
else:
latent_model_input = latents
any_guidance = guide_scale != 1
if phantom:
gen_args = {
"x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
@ -875,33 +911,42 @@ class WanAny2V:
"audio_scale": [audio_scale, None, None ]
}
elif multitalk and audio_proj != None:
gen_args = {
"x" : [latent_model_input, latent_model_input, latent_model_input],
"context" : [context, context_null, context_null],
"multitalk_audio": [audio_proj, audio_proj, [torch.zeros_like(audio_proj[0][-1:]), torch.zeros_like(audio_proj[1][-1:])]],
"multitalk_masks": [token_ref_target_masks, token_ref_target_masks, None]
}
if guide_scale == 1:
gen_args = {
"x" : [latent_model_input, latent_model_input],
"context" : [context, context],
"multitalk_audio": [audio_proj, [torch.zeros_like(audio_proj[0][-1:]), torch.zeros_like(audio_proj[1][-1:])]],
"multitalk_masks": [token_ref_target_masks, None]
}
any_guidance = audio_cfg_scale != 1
else:
gen_args = {
"x" : [latent_model_input, latent_model_input, latent_model_input],
"context" : [context, context_null, context_null],
"multitalk_audio": [audio_proj, audio_proj, [torch.zeros_like(audio_proj[0][-1:]), torch.zeros_like(audio_proj[1][-1:])]],
"multitalk_masks": [token_ref_target_masks, token_ref_target_masks, None]
}
else:
gen_args = {
"x" : [latent_model_input, latent_model_input],
"context": [context, context_null]
}
if joint_pass and guide_scale > 1:
if joint_pass and any_guidance:
ret_values = trans( **gen_args , **kwargs)
if self._interrupt:
return None
return clear()
else:
size = 1 if guide_scale == 1 else len(gen_args["x"])
size = len(gen_args["x"]) if any_guidance else 1
ret_values = [None] * size
for x_id in range(size):
sub_gen_args = {k : [v[x_id]] for k, v in gen_args.items() }
ret_values[x_id] = trans( **sub_gen_args, x_id= x_id , **kwargs)[0]
if self._interrupt:
return None
return clear()
sub_gen_args = None
if guide_scale == 1:
noise_pred = ret_values[0]
if not any_guidance:
noise_pred = ret_values[0]
elif phantom:
guide_scale_img= 5.0
guide_scale_text= guide_scale #7.5
@ -913,19 +958,32 @@ class WanAny2V:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
noise_pred_noaudio = None
elif multitalk and audio_proj != None:
noise_pred_cond, noise_pred_drop_text, noise_pred_uncond = ret_values
if apg_switch != 0:
noise_pred = noise_pred_cond + (guide_scale - 1) * adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_text,
noise_pred_cond,
momentum_buffer=text_momentumbuffer,
norm_threshold=apg_norm_threshold) \
+ (audio_cfg_scale - 1) * adaptive_projected_guidance(noise_pred_drop_text - noise_pred_uncond,
noise_pred_cond,
momentum_buffer=audio_momentumbuffer,
norm_threshold=apg_norm_threshold)
if guide_scale == 1:
noise_pred_cond, noise_pred_drop_audio = ret_values
noise_pred = noise_pred_cond + (audio_cfg_scale - 1)* adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_audio,
noise_pred_cond,
momentum_buffer=audio_momentumbuffer,
norm_threshold=apg_norm_threshold)
else:
noise_pred_cond, noise_pred_drop_text, noise_pred_uncond = ret_values
noise_pred = noise_pred_cond + (guide_scale - 1) * adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_text,
noise_pred_cond,
momentum_buffer=text_momentumbuffer,
norm_threshold=apg_norm_threshold) \
+ (audio_cfg_scale - 1) * adaptive_projected_guidance(noise_pred_drop_text - noise_pred_uncond,
noise_pred_cond,
momentum_buffer=audio_momentumbuffer,
norm_threshold=apg_norm_threshold)
else:
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_drop_text) + audio_cfg_scale * (noise_pred_drop_text - noise_pred_uncond)
noise_pred_uncond = noise_pred_cond = noise_pred_drop_text = None
if guide_scale == 1:
noise_pred_cond, noise_pred_drop_audio = ret_values
noise_pred = noise_pred_drop_audio + audio_cfg_scale* (noise_pred_cond - noise_pred_drop_audio)
else:
noise_pred_cond, noise_pred_drop_text, noise_pred_uncond = ret_values
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_drop_text) + audio_cfg_scale * (noise_pred_drop_text - noise_pred_uncond)
noise_pred_uncond = noise_pred_cond = noise_pred_drop_text = noise_pred_drop_audio = None
else:
noise_pred_cond, noise_pred_uncond = ret_values
if apg_switch != 0:
@ -970,6 +1028,7 @@ class WanAny2V:
callback(i, latents_preview[0], False, denoising_extra =denoising_extra )
latents_preview = None
clear()
if timestep_injection:
latents[:, :, :source_latents.shape[2]] = source_latents


@ -19,6 +19,22 @@ from ..multitalk.multitalk_utils import get_attn_map_with_target
__all__ = ['WanModel']
def get_cache(cache_name):
all_cache = offload.shared_state.get("_cache", None)
if all_cache is None:
all_cache = {}
offload.shared_state["_cache"]= all_cache
cache = all_cache.get(cache_name, None)
if cache is None:
cache = {}
all_cache[cache_name] = cache
return cache
def clear_caches():
all_cache = offload.shared_state.get("_cache", None)
if all_cache is not None:
all_cache.clear()
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
@ -93,6 +109,32 @@ def relative_l1_distance(last_tensor, current_tensor):
relative_l1_distance = l1_distance / norm
return relative_l1_distance.to(torch.float32)
class LoRALinearLayer(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 128,
dtype: Optional[torch.dtype] = torch.float32,
):
super().__init__()
self.down = nn.Linear(in_features, rank, bias=False, dtype=dtype)
self.up = nn.Linear(rank, out_features, bias=False, dtype=dtype)
self.rank = rank
self.out_features = out_features
self.in_features = in_features
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
return up_hidden_states.to(orig_dtype)
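
A short sketch (editor's illustration, not repo code) of how these StandIn adapters are applied: the existing q/k/v projections keep their weights and the LoRA output is simply added on top during StandIn phase 1, as in WanSelfAttention.forward below. The dim of 5120 comes from configs/standin.json and rank 128 from the standin branch of WanModel.__init__.

import torch
import torch.nn as nn

dim = 5120                                    # Wan 14B hidden size (configs/standin.json)
q_proj = nn.Linear(dim, dim)                  # existing attention projection
q_lora = LoRALinearLayer(dim, dim, rank=128)  # adapter loaded from the StandIn module weights
x = torch.randn(1, 16, dim)
q = q_proj(x) + q_lora(x)                     # the standin_phase == 1 path adds the LoRA output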
class WanRMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
@ -244,7 +286,7 @@ class WanSelfAttention(nn.Module):
else:
return x, None
def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0):
def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0, standin_phase =-1):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@ -257,19 +299,28 @@ class WanSelfAttention(nn.Module):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
q = self.q(x)
q, k, v = self.q(x), self.k(x), self.v(x)
if standin_phase == 1:
q += self.q_loras(x)
k += self.k_loras(x)
v += self.v_loras(x)
self.norm_q(q)
q = q.view(b, s, n, d)
k = self.k(x)
self.norm_k(k)
k = k.view(b, s, n, d)
v = self.v(x).view(b, s, n, d)
q,k,v = q.view(b, s, n, d), k.view(b, s, n, d), v.view(b, s, n, d)
del x
qklist = [q,k]
del q,k
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
if standin_phase >= 1:
standin_cache = get_cache("standin")
if standin_phase == 1:
standin_cache[self.block_no] = (k,v)
elif standin_phase == 2:
k_ip, v_ip = standin_cache[self.block_no]
k, v = torch.concat([k, k_ip], dim=1), torch.concat([v, v_ip], dim=1)
del k_ip, v_ip
if ref_target_masks != None:
x_ref_attn_map = get_attn_map_with_target(q, k , grid_sizes, ref_target_masks=ref_target_masks, ref_images_count = ref_images_count)
else:
@ -289,6 +340,7 @@ class WanSelfAttention(nn.Module):
x = pay_attention(
qkv_list,
window_size=self.window_size)
else:
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
x = (
@ -461,6 +513,7 @@ class WanAttentionBlock(nn.Module):
multitalk_audio=None,
multitalk_masks=None,
ref_images_count=0,
standin_phase=-1,
):
r"""
Args:
@ -504,7 +557,7 @@ class WanAttentionBlock(nn.Module):
xlist = [x_mod.to(attention_dtype)]
del x_mod
y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count)
y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count, standin_phase= standin_phase, )
y = y.to(dtype)
if cam_emb != None: y = self.projector(y)
@ -513,11 +566,13 @@ class WanAttentionBlock(nn.Module):
x.addcmul_(y, e[2])
x, y = restore_latent_shape(x), restore_latent_shape(y)
del y
y = self.norm3(x)
y = y.to(attention_dtype)
ylist= [y]
del y
x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype)
if context is not None:
y = self.norm3(x)
y = y.to(attention_dtype)
ylist= [y]
del y
x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype)
if multitalk_audio != None:
# cross attn of multitalk audio
@ -853,6 +908,7 @@ class WanModel(ModelMixin, ConfigMixin):
vae_scale=4, # vae timedownsample scale
norm_input_visual=True,
norm_output_audio=True,
standin= False,
):
super().__init__()
@ -976,6 +1032,11 @@ class WanModel(ModelMixin, ConfigMixin):
for block in self.blocks:
block.cross_attn.processor = WanCrossAttentionProcessor(fantasytalking_dim, dim)
if standin:
for block in self.blocks:
block.self_attn.q_loras = LoRALinearLayer(dim, dim, rank=128)
block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128)
block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128)
def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
layer_list = [self.head, self.head.head, self.patch_embedding]
@ -1155,8 +1216,9 @@ class WanModel(ModelMixin, ConfigMixin):
audio_scale=None,
multitalk_audio = None,
multitalk_masks = None,
ref_images_count = 0,
ref_images_count = 0,
standin_freqs = None,
standin_ref = None,
):
# patch_dtype = self.patch_embedding.weight.dtype
modulation_dtype = self.time_projection[1].weight.dtype
@ -1220,6 +1282,18 @@ class WanModel(ModelMixin, ConfigMixin):
offload.shared_state["embed_sizes"] = grid_sizes
offload.shared_state["step_no"] = current_step
offload.shared_state["max_steps"] = max_steps
if current_step == 0 and x_id == 0: clear_caches()
# arguments
kwargs = dict(
grid_sizes=grid_sizes,
freqs=freqs,
cam_emb = cam_emb,
block_mask = block_mask,
audio_proj=audio_proj,
audio_context_lens=audio_context_lens,
ref_images_count=ref_images_count,
)
_flag_df = t.dim() == 2
@ -1228,6 +1302,16 @@ class WanModel(ModelMixin, ConfigMixin):
) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
standin_x = None
if standin_ref is not None:
standin_cache_enabled = False
kwargs["standin_phase"] = 2
if (current_step == 0 or not standin_cache_enabled) and x_id == 0:
standin_x = self.patch_embedding(standin_ref).to(modulation_dtype).flatten(2).transpose(1, 2)
standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)) )
standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype)
standin_e = standin_ref = None
if self.inject_sample_info and fps!=None:
fps = torch.tensor(fps, dtype=torch.long, device=device)
@ -1254,8 +1338,9 @@ class WanModel(ModelMixin, ConfigMixin):
if multitalk_audio != None:
multitalk_audio_list = []
for audio in multitalk_audio:
audio = self.audio_proj(*audio)
audio = torch.concat(audio.split(1), dim=2).to(context[0])
if audio is not None:
audio = self.audio_proj(*audio)
audio = torch.concat(audio.split(1), dim=2).to(context[0])
multitalk_audio_list.append(audio)
audio = None
else:
@ -1271,17 +1356,6 @@ class WanModel(ModelMixin, ConfigMixin):
else:
audio_scale_list = [None] * len(x_list)
# arguments
kwargs = dict(
grid_sizes=grid_sizes,
freqs=freqs,
cam_emb = cam_emb,
block_mask = block_mask,
audio_proj=audio_proj,
audio_context_lens=audio_context_lens,
ref_images_count=ref_images_count,
)
if vace_context == None:
hints_list = [None ] *len(x_list)
@ -1378,8 +1452,9 @@ class WanModel(ModelMixin, ConfigMixin):
if pipeline._interrupt:
return [None] * len(x_list)
# if (x_id != 0 or joint_pass) and slg_layers is not None and block_idx in slg_layers:
# if not joint_pass or not x_should_calc[0]:
if standin_x is not None:
if not standin_cache_enabled and x_id ==0 : get_cache("standin").clear()
standin_x = block(standin_x, context = None, grid_sizes = None, e= standin_e0, freqs = standin_freqs, standin_phase = 1)
if slg_layers is not None and block_idx in slg_layers:
if x_id != 0 or not x_should_calc[0]:


@ -137,7 +137,10 @@ def get_meshgrid_nd(start, *args, dim=2):
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
if a == b:
g = torch.tensor([a], dtype=torch.float32 )
else:
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
@ -297,16 +300,15 @@ def apply_rotary_emb( qklist,
return xq_out, xk_out
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
use_real=False,
use_real=True,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
k = 6,
L_test = 66,
enable_riflex = True
enable_riflex = False
):
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
@ -326,6 +328,13 @@ def get_nd_rotary_pos_embed(
Returns:
pos_embed (torch.Tensor): [HW, D/2]
"""
head_dim = 128
rope_dim_list = [44, 42, 42]
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert (
sum(rope_dim_list) == head_dim
), "sum(rope_dim_list) should equal to head_dim of attention layer"
grid = get_meshgrid_nd(
start, *args, dim=len(rope_dim_list)
@ -454,15 +463,7 @@ def get_rotary_pos_embed(latents_size, enable_RIFLEx = False):
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
head_dim = 128
rope_dim_list = [44, 42, 42]
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert (
sum(rope_dim_list) == head_dim
), "sum(rope_dim_list) should equal to head_dim of attention layer"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=10000,
use_real=True,

View File

@ -75,12 +75,14 @@ def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=1
if not (left_path==None or right_path==None):
human_speech_array1 = audio_prepare_single(left_path, duration = duration)
human_speech_array2 = audio_prepare_single(right_path, duration = duration)
elif left_path==None:
human_speech_array2 = audio_prepare_single(right_path, duration = duration)
human_speech_array1 = np.zeros(human_speech_array2.shape[0])
elif right_path==None:
human_speech_array1 = audio_prepare_single(left_path, duration = duration)
human_speech_array2 = np.zeros(human_speech_array1.shape[0])
else:
audio_type='para'
if left_path==None:
human_speech_array2 = audio_prepare_single(right_path, duration = duration)
human_speech_array1 = np.zeros(human_speech_array2.shape[0])
elif right_path==None:
human_speech_array1 = audio_prepare_single(left_path, duration = duration)
human_speech_array2 = np.zeros(human_speech_array1.shape[0])
if audio_type=='para':
new_human_speech1 = human_speech_array1
@ -94,21 +96,27 @@ def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=1
new_human_speech1 = np.concatenate([human_speech_array1[: human_speech_array1.shape[0]], np.zeros(human_speech_array2.shape[0])])
new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2[:human_speech_array2.shape[0]]])
#dont include the padding on the summed audio which is used to build the output audio track
sum_human_speechs = new_human_speech1 + new_human_speech2
if pad > 0:
new_human_speech1 = np.concatenate([np.zeros(pad), new_human_speech1])
new_human_speech2 = np.concatenate([np.zeros(pad), new_human_speech2])
duration_changed = False
if min_audio_duration > 0:
min_samples = math.ceil( min_audio_duration * sample_rate)
if len(new_human_speech1) < min_samples:
new_human_speech1 = np.concatenate([new_human_speech1, np.zeros(min_samples -len(new_human_speech1)) ])
duration_changed = True
if len(new_human_speech2) < min_samples:
new_human_speech2 = np.concatenate([new_human_speech2, np.zeros(min_samples -len(new_human_speech2)) ])
duration_changed = True
#don't include the padding on the summed audio, which is used to build the output audio track
sum_human_speechs = new_human_speech1 + new_human_speech2
if pad > 0:
duration_changed = True
new_human_speech1 = np.concatenate([np.zeros(pad), new_human_speech1])
new_human_speech2 = np.concatenate([np.zeros(pad), new_human_speech2])
return new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed
return new_human_speech1, new_human_speech2, sum_human_speechs
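The reworked block above now reports whether either speaker track was lengthened (tail-filled to min_audio_duration and/or prefixed with pad samples of silence). A simplified, self-contained sketch of that per-track bookkeeping, assuming 16 kHz mono numpy arrays (the helper name is illustrative, not the repository's):

import math
import numpy as np

def pad_speech(track, pad, min_audio_duration, sample_rate=16000):
    """Return (padded_track, duration_changed), mirroring the padding logic above."""
    duration_changed = False
    if min_audio_duration > 0:
        # Fill the tail with silence up to the requested minimum duration.
        min_samples = math.ceil(min_audio_duration * sample_rate)
        if len(track) < min_samples:
            track = np.concatenate([track, np.zeros(min_samples - len(track))])
            duration_changed = True
    if pad > 0:
        # Prepend silence used to simulate the source-window overlap.
        track = np.concatenate([np.zeros(pad), track])
        duration_changed = True
    return track, duration_changed

speech, changed = pad_speech(np.ones(8000), pad=400, min_audio_duration=1.0)
print(len(speech), changed)   # 16400 True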
def process_tts_single(text, save_dir, voice1):
s1_sentences = []
@ -187,14 +195,14 @@ def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combinat
wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base")
# wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec")
pad = int(padded_frames_for_embeddings/ fps * sr)
new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad, min_audio_duration = min_audio_duration )
new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad, min_audio_duration = min_audio_duration )
audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
full_audio_embs = []
if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
# if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
if audio_guide2 != None: full_audio_embs.append(audio_embedding_2)
if audio_guide2 == None: sum_human_speechs = None
if audio_guide2 == None and not duration_changed: sum_human_speechs = None
return full_audio_embs, sum_human_speechs

View File

@ -10,6 +10,9 @@ def test_class_1_3B(base_model_type):
def test_multitalk(base_model_type):
return base_model_type in ["multitalk", "vace_multitalk_14B", "i2v_2_2_multitalk", "infinitetalk"]
def test_standin(base_model_type):
return base_model_type in ["standin", "vace_standin_14B"]
class family_handler():
@staticmethod
@ -72,7 +75,8 @@ class family_handler():
i2v = test_class_i2v(base_model_type)
extra_model_def["i2v_class"] = i2v
extra_model_def["multitalk_class"] = test_multitalk(base_model_type)
vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B"]
extra_model_def["standin_class"] = test_standin(base_model_type)
vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"]
extra_model_def["vace_class"] = vace_class
if test_multitalk(base_model_type):
@ -101,6 +105,7 @@ class family_handler():
"adaptive_projected_guidance" : True,
"tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels),
"mag_cache" : True,
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
"sample_solvers":[
("unipc", "unipc"),
("euler", "euler"),
@ -115,8 +120,8 @@ class family_handler():
@staticmethod
def query_supported_types():
return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B",
"t2v_1.3B", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B",
"t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
"recam_1.3B",
"i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
@ -130,8 +135,8 @@ class family_handler():
}
models_comp_map = {
"vace_14B" : [ "vace_multitalk_14B"],
"t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B"],
"vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B"],
"t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin"],
"i2v" : [ "fantasy", "multitalk", "flf2v_720p" ],
"i2v_2_2" : ["i2v_2_2_multitalk"],
"fantasy": ["multitalk"],
@ -221,7 +226,15 @@ class family_handler():
mult = model_def.get("loras_multipliers","")
if len(mult)> 1 and len(mult[0].split(";"))==3: ui_defaults["guidance_phases"] = 3
if settings_version < 2.27:
if base_model_type in "infinitetalk":
guidance_scale = ui_defaults.get("guidance_scale", None)
if guidance_scale == 1:
ui_defaults["audio_guidance_scale"]= 1
video_prompt_type = ui_defaults.get("video_prompt_type", "")
if "I" in video_prompt_type:
video_prompt_type = video_prompt_type.replace("KI", "QKI")
ui_defaults["video_prompt_type"] = video_prompt_type
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
@ -248,11 +261,19 @@ class family_handler():
"flow_shift": 7, # 11 for 720p
"sliding_window_overlap" : 9,
"sample_solver" : "euler",
"video_prompt_type": "KI",
"video_prompt_type": "QKI",
"remove_background_images_ref" : 0,
"adaptive_switch" : 1,
})
elif base_model_type in ["standin"]:
ui_defaults.update({
"guidance_scale": 5.0,
"flow_shift": 7, # 11 for 720p
"sliding_window_overlap" : 9,
"video_prompt_type": "I",
"remove_background_images_ref" : 1,
})
elif base_model_type in ["phantom_1.3B", "phantom_14B"]:
ui_defaults.update({
"guidance_scale": 7.5,

View File

@ -0,0 +1,259 @@
import os
import cv2
import requests
import torch
import numpy as np
import PIL.Image as Image
import PIL.ImageOps
# from insightface.app import FaceAnalysis
# from facexlib.parsing import init_parsing_model
from torchvision.transforms.functional import normalize
from typing import Union, Optional
from models.hyvideo.data_kits.face_align import AlignImage
def _img2tensor(img: np.ndarray, bgr2rgb: bool = True) -> torch.Tensor:
if bgr2rgb:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = img.astype(np.float32) / 255.0
img = np.transpose(img, (2, 0, 1))
return torch.from_numpy(img)
def _pad_to_square(img: np.ndarray, pad_color: int = 255) -> np.ndarray:
h, w, _ = img.shape
if h == w:
return img
if h > w:
pad_size = (h - w) // 2
padded_img = cv2.copyMakeBorder(
img,
0,
0,
pad_size,
h - w - pad_size,
cv2.BORDER_CONSTANT,
value=[pad_color] * 3,
)
else:
pad_size = (w - h) // 2
padded_img = cv2.copyMakeBorder(
img,
pad_size,
w - h - pad_size,
0,
0,
cv2.BORDER_CONSTANT,
value=[pad_color] * 3,
)
return padded_img
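A quick usage check of _pad_to_square (the module name in the import is a placeholder for wherever this new file lives):

import numpy as np
from face_processor import _pad_to_square  # hypothetical module name for the file above

tall = np.zeros((300, 200, 3), dtype=np.uint8)
wide = np.zeros((200, 512, 3), dtype=np.uint8)
print(_pad_to_square(tall).shape)  # (300, 300, 3): width padded with white borders
print(_pad_to_square(wide).shape)  # (512, 512, 3): height padded with white borders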
class FaceProcessor:
def __init__(self):
self.align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt")
self.align_instance.facedet.model.to("cpu")
def process(
self,
image: Union[str, PIL.Image.Image],
resize_to: int = 512,
border_thresh: int = 10,
face_crop_scale: float = 1.5,
remove_bg= False,
# area=1.25
) -> PIL.Image.Image:
image_pil = PIL.ImageOps.exif_transpose(image).convert("RGB")
w, h = image_pil.size
self.align_instance.facedet.model.to("cuda")
_, _, bboxes_list = self.align_instance(np.array(image_pil)[:,:,[2,1,0]], maxface=True)
self.align_instance.facedet.model.to("cpu")
try:
bboxSrc = bboxes_list[0]
except:
bboxSrc = [0, 0, w, h]
x1, y1, ww, hh = bboxSrc
x2, y2 = x1 + ww, y1 + hh
# ww, hh = (x2-x1) * area, (y2-y1) * area
# center = [(x2+x1)//2, (y2+y1)//2]
# x1 = max(center[0] - ww//2, 0)
# y1 = max(center[1] - hh//2, 0)
# x2 = min(center[0] + ww//2, w)
# y2 = min(center[1] + hh//2, h)
frame = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
h, w, _ = frame.shape
image_to_process = None
is_close_to_border = (
x1 <= border_thresh
and y1 <= border_thresh
and x2 >= w - border_thresh
and y2 >= h - border_thresh
)
if is_close_to_border:
# print(
# "[Info] Face is close to border, padding original image to square."
# )
image_to_process = _pad_to_square(frame, pad_color=255)
else:
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
side = int(max(x2 - x1, y2 - y1) * face_crop_scale)
half = side // 2
left = int(max(cx - half, 0))
top = int(max(cy - half, 0))
right = int(min(cx + half, w))
bottom = int(min(cy + half, h))
cropped_face = frame[top:bottom, left:right]
image_to_process = _pad_to_square(cropped_face, pad_color=255)
image_resized = cv2.resize(
image_to_process, (resize_to, resize_to), interpolation=cv2.INTER_LANCZOS4 # .INTER_AREA
)
face_tensor = _img2tensor(image_resized).to("cpu")
from shared.utils.utils import remove_background, convert_tensor_to_image
if remove_bg:
face_tensor = remove_background(face_tensor)
img_out = Image.fromarray(face_tensor.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy())
return img_out
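A hedged usage sketch of the new FaceProcessor (the image path and the module name are placeholders; the detector checkpoint ckpts/det_align/detface.pt is fetched via the "det_align" entry added to the download list in wgp.py below, and a CUDA device is assumed):

from PIL import Image
from face_processor import FaceProcessor  # hypothetical module name for the file above

processor = FaceProcessor()
face_img = processor.process(Image.open("person.jpg"), resize_to=512, remove_bg=True)
face_img.save("person_face_512.png")  # 512x512 face crop, padded to square with white borders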
# class FaceProcessor2:
# def __init__(self, antelopv2_path=".", device: Optional[torch.device] = None):
# if device is None:
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# else:
# self.device = device
# providers = (
# ["CUDAExecutionProvider"]
# if self.device.type == "cuda"
# else ["CPUExecutionProvider"]
# )
# self.app = FaceAnalysis(
# name="antelopev2", root=antelopv2_path, providers=providers
# )
# self.app.prepare(ctx_id=0, det_size=(640, 640))
# self.parsing_model = init_parsing_model(
# model_name="bisenet", device=self.device
# )
# self.parsing_model.eval()
# print("FaceProcessor initialized successfully.")
# def process(
# self,
# image: Union[str, PIL.Image.Image],
# resize_to: int = 512,
# border_thresh: int = 10,
# face_crop_scale: float = 1.5,
# extra_input: bool = False,
# ) -> PIL.Image.Image:
# if isinstance(image, str):
# if image.startswith("http://") or image.startswith("https://"):
# image = PIL.Image.open(requests.get(image, stream=True, timeout=10).raw)
# elif os.path.isfile(image):
# image = PIL.Image.open(image)
# else:
# raise ValueError(
# f"Input string is not a valid URL or file path: {image}"
# )
# elif not isinstance(image, PIL.Image.Image):
# raise TypeError(
# "Input must be a file path, a URL, or a PIL.Image.Image object."
# )
# image = PIL.ImageOps.exif_transpose(image).convert("RGB")
# frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
# faces = self.app.get(frame)
# h, w, _ = frame.shape
# image_to_process = None
# if not faces:
# print(
# "[Warning] No face detected. Using the whole image, padded to square."
# )
# image_to_process = _pad_to_square(frame, pad_color=255)
# else:
# largest_face = max(
# faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1])
# )
# x1, y1, x2, y2 = map(int, largest_face.bbox)
# is_close_to_border = (
# x1 <= border_thresh
# and y1 <= border_thresh
# and x2 >= w - border_thresh
# and y2 >= h - border_thresh
# )
# if is_close_to_border:
# print(
# "[Info] Face is close to border, padding original image to square."
# )
# image_to_process = _pad_to_square(frame, pad_color=255)
# else:
# cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
# side = int(max(x2 - x1, y2 - y1) * face_crop_scale)
# half = side // 2
# left = max(cx - half, 0)
# top = max(cy - half, 0)
# right = min(cx + half, w)
# bottom = min(cy + half, h)
# cropped_face = frame[top:bottom, left:right]
# image_to_process = _pad_to_square(cropped_face, pad_color=255)
# image_resized = cv2.resize(
# image_to_process, (resize_to, resize_to), interpolation=cv2.INTER_AREA
# )
# face_tensor = (
# _img2tensor(image_resized, bgr2rgb=True).unsqueeze(0).to(self.device)
# )
# with torch.no_grad():
# normalized_face = normalize(face_tensor, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# parsing_out = self.parsing_model(normalized_face)[0]
# parsing_mask = parsing_out.argmax(dim=1, keepdim=True)
# background_mask_np = (parsing_mask.squeeze().cpu().numpy() == 0).astype(
# np.uint8
# )
# white_background = np.ones_like(image_resized, dtype=np.uint8) * 255
# mask_3channel = cv2.cvtColor(background_mask_np * 255, cv2.COLOR_GRAY2BGR)
# result_img_bgr = np.where(mask_3channel == 255, white_background, image_resized)
# result_img_rgb = cv2.cvtColor(result_img_bgr, cv2.COLOR_BGR2RGB)
# img_white_bg = PIL.Image.fromarray(result_img_rgb)
# if extra_input:
# # 2. Create image with transparent background (new logic)
# # Create an alpha channel: 255 for foreground (not background), 0 for background
# alpha_channel = (parsing_mask.squeeze().cpu().numpy() != 0).astype(
# np.uint8
# ) * 255
# # Convert the resized BGR image to RGB
# image_resized_rgb = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB)
# # Stack RGB channels with the new alpha channel
# rgba_image = np.dstack((image_resized_rgb, alpha_channel))
# # Create PIL image from the RGBA numpy array
# img_transparent_bg = PIL.Image.fromarray(rgba_image, "RGBA")
# return img_white_bg, img_transparent_bg
# else:
# return img_white_bg

View File

@ -92,13 +92,44 @@ def get_video_info(video_path):
return fps, width, height, frame_count
def get_video_frame(file_name, frame_no):
decord.bridge.set_bridge('torch')
reader = decord.VideoReader(file_name)
def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor:
"""Extract nth frame from video as PyTorch tensor normalized to [-1, 1]."""
cap = cv2.VideoCapture(file_name)
if not cap.isOpened():
raise ValueError(f"Cannot open video: {file_name}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Handle out of bounds
if frame_no >= total_frames or frame_no < 0:
if return_last_if_missing:
frame_no = total_frames - 1
else:
cap.release()
raise IndexError(f"Frame {frame_no} out of bounds (0-{total_frames-1})")
# Get frame
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
ret, frame = cap.read()
cap.release()
if not ret:
raise ValueError(f"Failed to read frame {frame_no}")
# Convert BGR->RGB, reshape to (C,H,W), normalize to [-1,1]
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
if return_PIL:
return Image.fromarray(frame)
else:
return (torch.from_numpy(frame).permute(2, 0, 1).float() / 127.5) - 1.0
# def get_video_frame(file_name, frame_no):
# decord.bridge.set_bridge('torch')
# reader = decord.VideoReader(file_name)
frame = reader.get_batch([frame_no]).squeeze(0)
img = Image.fromarray(frame.numpy().astype(np.uint8))
return img
# frame = reader.get_batch([frame_no]).squeeze(0)
# img = Image.fromarray(frame.numpy().astype(np.uint8))
# return img
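A usage sketch of the rewritten get_video_frame (the clip path is a placeholder; the import path matches the one added to wgp.py below):

from shared.utils.utils import get_video_frame

# PIL image of the first frame (as used for the video-source preview).
preview = get_video_frame("clip.mp4", 0)

# Tensor form in [-1, 1], clamped to the last frame when the index runs past the end.
frame = get_video_frame("clip.mp4", 10_000, return_last_if_missing=True, return_PIL=False)
print(frame.shape)  # torch.Size([3, H, W])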
def convert_image_to_video(image):
if image is None:
@ -174,18 +205,19 @@ def get_outpainting_frame_location(final_height, final_width, outpainting_dims
if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
return height, width, margin_top, margin_left
def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size = 16):
def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
if fit_into_canvas == None:
return height, width
# return image_height, image_width
return canvas_height, canvas_width
if fit_into_canvas:
scale1 = min(canvas_height / height, canvas_width / width)
scale2 = min(canvas_width / height, canvas_height / width)
scale1 = min(canvas_height / image_height, canvas_width / image_width)
scale2 = min(canvas_width / image_height, canvas_height / image_width)
scale = max(scale1, scale2)
else:
scale = (canvas_height * canvas_width / (height * width))**(1/2)
scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2)
new_height = round( height * scale / block_size) * block_size
new_width = round( width * scale / block_size) * block_size
new_height = round( image_height * scale / block_size) * block_size
new_width = round( image_width * scale / block_size) * block_size
return new_height, new_width
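A worked example of the branch taken when fit_into_canvas is False/0 (rescale to the canvas pixel budget), illustrative only:

# block_size = 16; the canvas is the output budget, image_* the source dimensions.
canvas_h, canvas_w = 720, 1280          # ~0.92 MP budget
image_h, image_w = 1080, 1920           # ~2.07 MP source
scale = (canvas_h * canvas_w / (image_h * image_w)) ** 0.5   # ~0.667
new_h = round(image_h * scale / 16) * 16                     # 720
new_w = round(image_w * scale / 16) * 16                     # 1280
print(new_h, new_w)  # 720 1280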
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ):

wgp.py
View File

@ -17,7 +17,7 @@ import numpy as np
import importlib
from shared.utils import notification_sound
from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers
from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor
from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, get_video_frame
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image
from shared.utils.audio_video import save_image_metadata, read_image_metadata
from shared.match_archi import match_nvidia_architecture
@ -55,8 +55,8 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.5.10"
WanGP_version = "8.11"
settings_version = 2.26
WanGP_version = "8.2"
settings_version = 2.27
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@ -75,9 +75,14 @@ gen_lock = threading.Lock()
offloadobj = enhancer_offloadobj = wan_model = None
reload_needed = True
def clear_gen_cache():
if "_cache" in offload.shared_state:
del offload.shared_state["_cache"]
def release_model():
global wan_model, offloadobj, reload_needed
wan_model = None
clear_gen_cache()
if offloadobj is not None:
offloadobj.release()
offloadobj = None
@ -2066,7 +2071,7 @@ def fix_settings(model_type, ui_defaults):
if base_model_type in ["flux"] and settings_version < 2.23:
video_prompt_type = video_prompt_type.replace("K", "").replace("I", "KI")
remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 1)
remove_background_images_ref = ui_defaults.get("remove_background_images_ref", None)
if settings_version < 2.22:
if "I" in video_prompt_type:
if remove_background_images_ref == 2:
@ -2077,7 +2082,8 @@ def fix_settings(model_type, ui_defaults):
remove_background_images_ref = 0
if settings_version < 2.26:
if not "K" in video_prompt_type: video_prompt_type = video_prompt_type.replace("I", "KI")
ui_defaults["remove_background_images_ref"] = remove_background_images_ref
if remove_background_images_ref is not None:
ui_defaults["remove_background_images_ref"] = remove_background_images_ref
ui_defaults["video_prompt_type"] = video_prompt_type
@ -2438,12 +2444,12 @@ def download_models(model_filename = None, model_type= None, module_type = None,
shared_def = {
"repoId" : "DeepBeepMeep/Wan2.1",
"sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote", "" ],
"sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote", "det_align", "" ],
"fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"],
["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"],
["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
["config.json", "pytorch_model.bin", "preprocessor_config.json"],
["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", "pytorch_model_segmentation-3.0.bin"], [ "flownet.pkl" ] ]
["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", "pytorch_model_segmentation-3.0.bin"], ["detface.pt"], [ "flownet.pkl" ] ]
}
process_files_def(**shared_def)
@ -3314,6 +3320,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
video_guidance_scale = configs.get("guidance_scale", None)
video_guidance2_scale = configs.get("guidance2_scale", None)
video_guidance3_scale = configs.get("guidance3_scale", None)
video_audio_guidance_scale = configs.get("audio_guidance_scale", None)
video_switch_threshold = configs.get("switch_threshold", 0)
video_switch_threshold2 = configs.get("switch_threshold2", 0)
video_model_switch_phase = configs.get("model_switch_phase", 1)
@ -3364,8 +3371,8 @@ def select_video(state, input_file_list, event_data: gr.EventData):
if model_def.get("sample_solvers", None) is not None and len(video_sample_solver) > 0 :
values += [video_sample_solver]
labels += ["Sampler Solver"]
values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_flow_shift, video_num_inference_steps]
labels += [ "Resolution", video_length_label, "Seed", video_guidance_label, "Shift Scale", "Num Inference steps"]
values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_audio_guidance_scale, video_flow_shift, video_num_inference_steps]
labels += [ "Resolution", video_length_label, "Seed", video_guidance_label, "Audio Guidance Scale", "Shift Scale", "Num Inference steps"]
video_negative_prompt = configs.get("negative_prompt", "")
if len(video_negative_prompt) > 0:
values += [video_negative_prompt]
@ -4272,6 +4279,7 @@ def generate_video(
sliding_window_color_correction_strength,
sliding_window_overlap_noise,
sliding_window_discard_last_frames,
image_refs_relative_size,
remove_background_images_ref,
temporal_upsampling,
spatial_upsampling,
@ -4438,10 +4446,8 @@ def generate_video(
i2v = test_class_i2v(model_type)
diffusion_forcing = "diffusion_forcing" in model_filename
t2v = base_model_type in ["t2v"]
recam = base_model_type in ["recam_1.3B"]
ltxv = "ltxv" in model_filename
vace = test_vace_module(base_model_type)
phantom = "phantom" in model_filename
hunyuan_t2v = "hunyuan_video_720" in model_filename
hunyuan_i2v = "hunyuan_video_i2v" in model_filename
hunyuan_custom = "hunyuan_video_custom" in model_filename
@ -4450,8 +4456,8 @@ def generate_video(
hunyuan_avatar = "hunyuan_video_avatar" in model_filename
fantasy = base_model_type in ["fantasy"]
multitalk = model_def.get("multitalk_class", False)
flux = model_family in ["flux"]
qwen = model_family in ["qwen"]
standin = model_def.get("standin_class", False)
infinitetalk = base_model_type in ["infinitetalk"]
if "B" in audio_prompt_type or "X" in audio_prompt_type:
from models.wan.multitalk.multitalk import parse_speakers_locations
@ -4489,7 +4495,9 @@ def generate_video(
_, latent_size = get_model_min_frames_and_step(model_type)
if diffusion_forcing: latent_size = 4
original_image_refs = image_refs
original_image_refs = image_refs
# image_refs = None
# nb_frames_positions= 0
frames_to_inject = []
any_background_ref = False
outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
@ -4519,7 +4527,7 @@ def generate_video(
send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
from shared.utils.utils import resize_and_remove_background
image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (any_background_ref or vace) ) # no fit for vace ref images as it is done later
image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (any_background_ref or vace or standin) ) # no fit for vace ref images as it is done later
update_task_thumbnails(task, locals())
send_cmd("output")
joint_pass = boost ==1 #and profile != 1 and profile != 3
@ -4580,7 +4588,7 @@ def generate_video(
# pad audio_proj_full if aligned to beginning of window to simulate source window overlap
min_audio_duration = current_video_length/fps if reset_control_aligment else video_source_duration + current_video_length/fps
audio_proj_full, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = audio_guide, audio_guide2= audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0), min_audio_duration = min_audio_duration)
if output_new_audio_filepath is not None: output_new_audio_data = None
if output_new_audio_data is not None: output_new_audio_filepath= None # need to build original speaker track if it changed size (due to padding at the end) or if it has been combined
if not args.save_speakers and "X" in audio_prompt_type:
os.remove(audio_guide)
os.remove(audio_guide2)
@ -4698,7 +4706,6 @@ def generate_video(
image_end_tensor = convert_image_to_tensor(image_end_tensor)
else:
if "L" in image_prompt_type:
from shared.utils.utils import get_video_frame
refresh_preview["video_source"] = get_video_frame(video_source, 0)
prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, target_fps = fps, block_size = block_size )
prefix_video = prefix_video.permute(3, 0, 1, 2)
@ -4707,9 +4714,10 @@ def generate_video(
pre_video_frame = convert_tensor_to_image(prefix_video[:, -1])
source_video_overlap_frames_count = pre_video_guide.shape[1]
source_video_frames_count = prefix_video.shape[1]
if sample_fit_canvas != None: image_size = pre_video_guide.shape[-2:]
if sample_fit_canvas != None:
image_size = pre_video_guide.shape[-2:]
sample_fit_canvas = None
guide_start_frame = prefix_video.shape[1]
sample_fit_canvas = None
window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count)
guide_end_frame = guide_start_frame + current_video_length - (source_video_overlap_frames_count if window_no == 1 else reuse_frames)
@ -4729,12 +4737,20 @@ def generate_video(
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ]
if infinitetalk and video_guide is not None:
src_image = get_video_frame(video_guide, aligned_guide_start_frame-1, return_last_if_missing = True, return_PIL = True)
new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size)
src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
refresh_preview["video_guide"] = src_image
src_video = convert_image_to_tensor(src_image).unsqueeze(1)
if sample_fit_canvas != None:
image_size = src_video.shape[-2:]
sample_fit_canvas = None
if ltxv and video_guide is not None:
preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
status_info = "Extracting " + processes_names[preprocess_type]
send_cmd("progress", [0, get_latest_status(state, status_info)])
# start one frame ealier to faciliate latents merging later
# start one frame earlier to facilitate latents merging later
src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
if src_video != None:
src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ]
@ -4912,6 +4928,7 @@ def generate_video(
overlapped_latents = overlapped_latents,
return_latent_slice= return_latent_slice,
overlap_noise = sliding_window_overlap_noise,
overlap_size = sliding_window_overlap,
color_correction_strength = sliding_window_color_correction_strength,
conditioning_latents_size = conditioning_latents_size,
keep_frames_parsed = keep_frames_parsed,
@ -4928,11 +4945,14 @@ def generate_video(
offloadobj = offloadobj,
set_header_text= set_header_text,
pre_video_frame = pre_video_frame,
original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [],
image_refs_relative_size = image_refs_relative_size,
)
except Exception as e:
if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0:
cleanup_temp_audio_files(control_audio_tracks + source_audio_tracks)
remove_temp_filenames(temp_filenames_list)
clear_gen_cache()
offloadobj.unload_all()
trans.cache = None
offload.unload_loras_from_model(trans)
@ -4977,6 +4997,7 @@ def generate_video(
overlapped_latents = samples.get("latent_slice", None)
samples= samples["x"]
samples = samples.to("cpu")
clear_gen_cache()
offloadobj.unload_all()
gc.collect()
torch.cuda.empty_cache()
@ -5891,9 +5912,15 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
pop += ["model_mode"]
if not vace and not phantom and not hunyuan_video_custom:
unsaved_params = ["keep_frames_video_guide", "remove_background_images_ref", "mask_expand"] #"video_prompt_type",
unsaved_params = ["keep_frames_video_guide", "mask_expand"] #"video_prompt_type",
if base_model_type in ["t2v"]: unsaved_params = unsaved_params[1:]
pop += unsaved_params
if not "I" in video_prompt_type:
pop += ["remove_background_images_ref"]
if not model_def.get("any_image_refs_relative_size", False):
pop += ["image_refs_relative_size"]
if not vace:
pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"]
@ -6328,7 +6355,8 @@ def save_inputs(
sliding_window_overlap,
sliding_window_color_correction_strength,
sliding_window_overlap_noise,
sliding_window_discard_last_frames,
sliding_window_discard_last_frames,
image_refs_relative_size,
remove_background_images_ref,
temporal_upsampling,
spatial_upsampling,
@ -6543,11 +6571,20 @@ def refresh_image_prompt_type(state, image_prompt_type):
return gr.update(visible = "S" in image_prompt_type ), gr.update(visible = "E" in image_prompt_type ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source)
def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs):
video_prompt_type = del_in_sequence(video_prompt_type, "KFI")
model_type = state["model_type"]
model_def = get_model_def(model_type)
image_ref_choices = model_def.get("image_ref_choices", None)
if image_ref_choices is not None:
video_prompt_type = del_in_sequence(video_prompt_type, image_ref_choices["letters_filter"])
else:
video_prompt_type = del_in_sequence(video_prompt_type, "KFI")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_image_refs)
visible = "I" in video_prompt_type
vace= test_vace_module(state["model_type"])
return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )
rm_bg_visible= visible and not model_def.get("no_background_removal", False)
img_rel_size_visible = visible and model_def.get("any_image_refs_relative_size", False)
return video_prompt_type, gr.update(visible = visible),gr.update(visible = rm_bg_visible), gr.update(visible = img_rel_size_visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )
def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask, image_mode):
video_prompt_type = del_in_sequence(video_prompt_type, "XYZWNA")
@ -6573,8 +6610,16 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt
model_def = get_model_def(model_type)
image_outputs = image_mode == 1
vace= test_vace_module(model_type)
return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = visible and not image_outputs), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)
keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False)
return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)
def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt):
video_prompt_type = del_in_sequence(video_prompt_type, "UVQKI")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt)
control_video_visible = "V" in video_prompt_type
ref_images_visible = "I" in video_prompt_type
return video_prompt_type, gr.update(visible = control_video_visible), gr.update(visible = ref_images_visible )
# def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
# video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
# return refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide)
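video_prompt_type is stored as a string of single-letter flags; the new handler above strips the letters this dropdown manages ("UVQKI") and then appends the selected combination. A minimal sketch of that composition, with simplified stand-ins for del_in_sequence / add_to_sequence (the real helpers live elsewhere in the repository):

def del_in_sequence(sequence, letters):
    # Simplified stand-in: drop any of the managed letters from the flag string.
    return "".join(c for c in sequence if c not in letters)

def add_to_sequence(sequence, letters):
    # Simplified stand-in: append letters that are not already present.
    return sequence + "".join(c for c in letters if c not in sequence)

flags = "QKI"                               # current value: Images to Video, sharp transitions
flags = del_in_sequence(flags, "UVQKI")     # -> ""
flags = add_to_sequence(flags, "UV")        # -> "UV", Sparse Video to Video, smooth transitions
print(flags, "V" in flags, "I" in flags)    # UV True False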
@ -6799,6 +6844,13 @@ def detect_auto_save_form(state, evt:gr.SelectData):
else:
return gr.update()
def compute_video_length_label(fps, current_video_length):
return f"Number of frames ({fps} frames = 1s), current duration: {(current_video_length / fps):.1f}s",
def refresh_video_length_label(state, current_video_length):
fps = get_model_fps(get_base_model_type(state["model_type"]))
return gr.update(label= compute_video_length_label(fps, current_video_length))
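A quick check of the label helper above (run in the same module; output shown as a comment):

print(compute_video_length_label(16, 81))
# -> Number of frames (16 frames = 1s), current duration: 5.1s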
def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None):
global inputs_names #, advanced
@ -6918,6 +6970,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
phantom = base_model_type in ["phantom_1.3B", "phantom_14B"]
fantasy = base_model_type in ["fantasy"]
multitalk = model_def.get("multitalk_class", False)
standin = model_def.get("standin_class", False)
infinitetalk = base_model_type in ["infinitetalk"]
hunyuan_t2v = "hunyuan_video_720" in model_filename
hunyuan_i2v = "hunyuan_video_i2v" in model_filename
@ -6939,7 +6992,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
any_start_image = False
any_end_image = False
any_reference_image = False
v2i_switch_supported = (vace or t2v) and not image_outputs
v2i_switch_supported = (vace or t2v or standin) and not image_outputs
ti2v_2_2 = base_model_type in ["ti2v_2_2"]
image_mode_value = ui_defaults.get("image_mode", 1 if image_outputs else 0 )
@ -7060,7 +7113,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
model_mode = gr.Dropdown(value=None, visible=False)
keep_frames_video_source = gr.Text(visible=False)
with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or ltxv or infinitetalk or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column:
with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column:
video_prompt_type_value= ui_defaults.get("video_prompt_type","")
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
any_control_video = True
@ -7119,12 +7172,28 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
],
value=filter_letters(video_prompt_type_value, "PDSLCMUV"),
label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= True, show_label= True,
)
)
elif infinitetalk:
video_prompt_type_video_guide = gr.Dropdown(value="", choices = [("","")], visible=False)
else:
any_control_video = False
any_control_image = False
video_prompt_type_video_guide = gr.Dropdown(visible= False)
if infinitetalk:
video_prompt_type_video_guide_alt = gr.Dropdown(
choices=[
("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "UV"),
("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QUV"),
("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
],
value=filter_letters(video_prompt_type_value, "UVQKI"),
label="Video to Video", scale = 3, visible= True, show_label= False,
)
else:
video_prompt_type_video_guide_alt = gr.Dropdown(value="", choices = [("","")], visible=False)
# video_prompt_video_guide_trigger = gr.Text(visible=False, value="")
if t2v:
video_prompt_type_video_mask = gr.Dropdown(value = "", choices = [""], visible = False)
@ -7170,7 +7239,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom and not ltxv,
label="Area Processed", scale = 2, show_label= True,
)
if t2v:
image_ref_choices = model_def.get("image_ref_choices", None)
if image_ref_choices is not None:
video_prompt_type_image_refs = gr.Dropdown(
choices= image_ref_choices["choices"],
value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]),
visible = True,
label=image_ref_choices["label"], show_label= True, scale = 2
)
elif t2v:
video_prompt_type_image_refs = gr.Dropdown(value="", label="Ref Image", choices=[""], visible =False)
elif vace:
video_prompt_type_image_refs = gr.Dropdown(
@ -7184,7 +7261,17 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
visible = True,
label="Reference Images", show_label= True, scale = 2
)
elif standin: # and not vace
video_prompt_type_image_refs = gr.Dropdown(
choices=[
("No Reference Image", ""),
("Reference Image is a Person Face", "I"),
],
value=filter_letters(video_prompt_type_value, "I"),
visible = True,
show_label=False,
label="Reference Image", scale = 2
)
elif (flux or qwen) and model_reference_image:
video_prompt_type_image_refs = gr.Dropdown(
@ -7209,7 +7296,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_guide = gr.Video(label= "Control Video", visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None))
denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label="Denoising Strength (the Lower the Closer to the Control Video)", visible = "G" in video_prompt_type_value, show_reset_button= False)
keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= (not image_outputs) and "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
keep_frames_video_guide_visible = not image_outputs and "V" in video_prompt_type_value and not model_def.get("keep_frames_video_guide_not_supported", False)
keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= keep_frames_video_guide_visible , scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
with gr.Column(visible= ("V" in video_prompt_type_value or "K" in video_prompt_type_value or "F" in video_prompt_type_value) and vace) as video_guide_outpainting_col:
video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
@ -7236,13 +7324,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
)
frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" )
image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Image Ref Internally (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs)
no_background_removal = model_def.get("no_background_removal", False)
remove_background_images_ref = gr.Dropdown(
choices=[
("Keep Backgrounds behind all Reference Images", 0),
("Remove Backgrounds only behind People / Objects except main Subject / Landscape" if (flux or qwen) else "Remove Backgrounds only behind People / Objects" , 1),
("Remove Backgrounds only behind People / Objects except main Subject / Landscape" if (flux or qwen) else ("Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames" if vace else "Remove Backgrounds behind People / Objects") , 1),
],
value=ui_defaults.get("remove_background_images_ref",1),
label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not model_def.get("no_background_removal", False)
value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1),
label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal
)
any_audio_voices_support = any_audio_track(base_model_type)
@ -7348,16 +7439,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", get_max_frames(81)), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True)
else:
min_frames, frames_step = get_model_min_frames_and_step(base_model_type)
current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97)
video_length = gr.Slider(min_frames, get_max_frames(737 if test_any_sliding_window(base_model_type) else 337), value=ui_defaults.get(
"video_length", 81 if get_model_family(base_model_type)=="wan" else 97),
step=frames_step, label=f"Number of frames ({fps} = 1s)", visible = True, interactive= True)
video_length = gr.Slider(min_frames, get_max_frames(737 if test_any_sliding_window(base_model_type) else 337), value=current_video_length,
step=frames_step, label=compute_video_length_label(fps, current_video_length) , visible = True, interactive= True)
with gr.Row(visible = not lock_inference_steps) as inference_steps_row:
num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps", visible = True)
show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced_ui)
with gr.Tabs(visible=advanced_ui) as advanced_row:
guidance_max_phases = model_def.get("guidance_max_phases", 0)
@ -7396,6 +7487,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
guidance2_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance2_scale",5), step=0.5, label="Guidance2 (CFG)", visible= guidance_max_phases >=2 and guidance_phases_value >= 2)
guidance3_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance3_scale",5), step=0.5, label="Guidance3 (CFG)", visible= guidance_max_phases >=3 and guidance_phases_value >= 3)
any_embedded_guidance = model_def.get("embedded_guidance", False)
with gr.Row(visible =any_embedded_guidance or any_audio_guidance) as embedded_guidance_row:
audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale", 4), step=0.5, label="Audio Guidance", visible= any_audio_guidance )
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("embedded_guidance", 6.0), step=0.5, label="Embedded Guidance Scale", visible=any_embedded_guidance )
sample_solver_choices = model_def.get("sample_solvers", None)
with gr.Row(visible = sample_solver_choices is not None or not image_outputs) as sample_solver_row:
@ -7407,12 +7502,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
)
flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale", visible = not image_outputs)
any_embedded_guidance = model_def.get("embedded_guidance", False)
with gr.Row(visible =any_embedded_guidance or any_audio_guidance) as embedded_guidance_row:
audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale", 4), step=0.5, label="Audio Guidance", visible= any_audio_guidance )
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("embedded_guidance", 6.0), step=0.5, label="Embedded Guidance Scale", visible=any_embedded_guidance )
with gr.Row(visible = vace) as control_net_weights_row:
control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace)
control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace)
@ -7584,7 +7673,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
value=ui_defaults.get("apg_switch",0),
visible=True,
scale = 1,
label="Adaptive Projected Guidance (requires Guidance > 1) "
label="Adaptive Projected Guidance (requires Guidance > 1 or Audio Guidance > 1) " if multitalk else "Adaptive Projected Guidance (requires Guidance > 1)",
)
with gr.Column(visible = any_cfg_star) as cfg_free_guidance_col:
@ -7602,7 +7691,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
with gr.Row():
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)", visible = any_cfg_zero)
with gr.Column(visible = (vace or t2v) and image_outputs) as min_frames_if_references_col:
with gr.Column(visible = (vace or t2v or standin) and image_outputs) as min_frames_if_references_col:
gr.Markdown("<B>Generating a single Frame alone may not be sufficient to preserve Reference Image Identity / Control Image Information or simply to get a good Image Quality. A workaround is to generate a short Video and keep the First Frame.")
min_frames_if_references = gr.Dropdown(
choices=[
@ -7839,7 +7928,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, save_lset_prompt_drop, advanced_row, speed_tab, audio_tab, mmaudio_col, quality_tab,
sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, audio_guide_row, RIFLEx_setting_col,
video_prompt_type_video_guide, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row,
video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row,
video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right,
video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row,
video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row,
@ -7859,13 +7948,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
resolution.change(fn=record_last_resolution, inputs=[state, resolution])
video_length.release(fn=refresh_video_length_label, inputs=[state, video_length ], outputs = video_length, trigger_mode="always_last" )
guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ])
audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type])
audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row])
image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] )
# video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions, video_guide_outpainting_col])
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col])
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs ])
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type])
multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt])