mirror of
https://github.com/Wan-Video/Wan2.1.git
synced 2025-12-15 11:43:21 +00:00
weekend fun
This commit is contained in:
parent
ecd219322f
commit
58b99f85ba
24
README.md
24
README.md
@ -20,20 +20,28 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
|
||||
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
|
||||
|
||||
## 🔥 Latest Updates :
|
||||
### August 29 2025: WanGP v8.2 - Here Goes Your Weekend
|
||||
|
||||
- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is internally only image is used by Sliding Window. However thanks to the new *Smooth Transition* mode, each new clip is connected to the previous and all the camera work is done by InfiniteTalk. If you dont get any transition, increase the number of frames of a Sliding Window (81 frames recommended)
|
||||
|
||||
- **StandIn**: very light model specialized in Identity Transfer. I have provided two versions of Standin: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will be also used for StandIn
|
||||
|
||||
- **Flux ESO**: a new Flux dervied *Image Editing tool*, but this one is specialized both in *Identity Transfer* and *Style Transfer*. Style has to be understood in its wide meaning: give a reference picture of a person and another one of Sushis and you will turn this person into Sushis
|
||||
|
||||
### August 24 2025: WanGP v8.1 - the RAM Liberator
|
||||
|
||||
- Reserved RAM entirely freed when switching models, you should get much less out of memory related to RAM. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP
|
||||
- InfiniteTalk support: improved version of Multitalk that supposedly supports very long video generations based on an audio track. Exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesnt seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated to the same audio: each Reference frame you provide you will be associated to a new Sliding Window. If only Reference frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\
|
||||
- **Reserved RAM entirely freed when switching models**, you should get much less out of memory related to RAM. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP
|
||||
- **InfiniteTalk** support: improved version of Multitalk that supposedly supports very long video generations based on an audio track. Exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesnt seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated to the same audio: each Reference frame you provide you will be associated to a new Sliding Window. If only Reference frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\
|
||||
If you are not into audio, you can use still this model to generate infinite long image2video, just select "no speaker". Last but not least, Infinitetalk works works with all the Loras accelerators.
|
||||
- Flux Chroma 1 HD support: uncensored flux based model and lighter than Flux (8.9B versus 12B) and can fit entirely in VRAM with only 16 GB of VRAM. Unfortunalely it is not distilled and you will need CFG at minimum 20 steps
|
||||
- **Flux Chroma 1 HD** support: uncensored flux based model and lighter than Flux (8.9B versus 12B) and can fit entirely in VRAM with only 16 GB of VRAM. Unfortunalely it is not distilled and you will need CFG at minimum 20 steps
|
||||
|
||||
### August 21 2025: WanGP v8.01 - the killer of seven
|
||||
|
||||
- Qwen Image Edit : Flux Kontext challenger (prompt driven image edition). Best results (including Identity preservation) will be obtained at 720p. Beyond you may get image outpainting and / or lose identity preservation. Below 720p prompt adherence will be worse. Qwen Image Edit works with Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions but identity preservation won't be as good.
|
||||
- On demand Prompt Enhancer (needs to be enabled in Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt.
|
||||
- Choice of a non censored Prompt Enhancer. Beware this is one is VRAM hungry and will require 12 GB of VRAM to work
|
||||
- Memory Profile customizable per model : useful to set for instance Profile 3 (preload the model entirely in VRAM) with only Image Generation models, if you have 24 GB of VRAM. In that case Generation will be much faster because with Image generators (contrary to Video generators) as a lot of time is wasted in offloading
|
||||
- Expert Guidance Mode: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Ligthning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is a 8 steps process although the lora lightning is 4 steps. This expert guidance mode is also available with Wan 2.1.
|
||||
- **Qwen Image Edit** : Flux Kontext challenger (prompt driven image edition). Best results (including Identity preservation) will be obtained at 720p. Beyond you may get image outpainting and / or lose identity preservation. Below 720p prompt adherence will be worse. Qwen Image Edit works with Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions but identity preservation won't be as good.
|
||||
- **On demand Prompt Enhancer** (needs to be enabled in Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt.
|
||||
- Choice of a **Non censored Prompt Enhancer**. Beware this is one is VRAM hungry and will require 12 GB of VRAM to work
|
||||
- **Memory Profile customizable per model** : useful to set for instance Profile 3 (preload the model entirely in VRAM) with only Image Generation models, if you have 24 GB of VRAM. In that case Generation will be much faster because with Image generators (contrary to Video generators) as a lot of time is wasted in offloading
|
||||
- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Ligthning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is a 8 steps process although the lora lightning is 4 steps. This expert guidance mode is also available with Wan 2.1.
|
||||
|
||||
*WanGP 8.01 update, improved Qwen Image Edit Identity Preservation*
|
||||
### August 12 2025: WanGP v7.7777 - Lucky Day(s)
|
||||
|
||||
15
configs/standin.json
Normal file
15
configs/standin.json
Normal file
@ -0,0 +1,15 @@
|
||||
{
|
||||
"_class_name": "WanModel",
|
||||
"_diffusers_version": "0.30.0",
|
||||
"dim": 5120,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_dim": 16,
|
||||
"model_type": "t2v",
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_dim": 16,
|
||||
"text_len": 512,
|
||||
"standin": true
|
||||
}
|
||||
17
configs/vace_standin_14B.json
Normal file
17
configs/vace_standin_14B.json
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
"_class_name": "VaceWanModel",
|
||||
"_diffusers_version": "0.30.0",
|
||||
"dim": 5120,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_dim": 16,
|
||||
"model_type": "t2v",
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_dim": 16,
|
||||
"text_len": 512,
|
||||
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
|
||||
"vace_in_dim": 96,
|
||||
"standin": true
|
||||
}
|
||||
@ -16,6 +16,7 @@
|
||||
],
|
||||
"group": "wan2_2"
|
||||
},
|
||||
"guidance_phases": 2,
|
||||
"switch_threshold" : 900,
|
||||
"guidance_scale" : 3.5,
|
||||
"guidance2_scale" : 3.5,
|
||||
|
||||
@ -9,7 +9,7 @@
|
||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_infinitetalk_single_14B_quanto_mfp16_int8.safetensors"
|
||||
]
|
||||
],
|
||||
"description": "The Infinitetalk model is an improved version of Multitalk that supports very long videos. This is the single speaker version.",
|
||||
"description": "The Infinitetalk model is an improved version of Multitalk that supports very long videos. This is the single speaker version. Sliding Window size must be 81 frames to get smooth transitions between shots.",
|
||||
"one_speaker_only": true,
|
||||
"URLs": "i2v"
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@
|
||||
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_infinitetalk_multi_14B_quanto_mbf16_int8.safetensors"
|
||||
]
|
||||
],
|
||||
"description": "The Infinitetalk model is an improved version of Multitalk that supports very long videos. This is the multi speakers version.",
|
||||
"description": "The Infinitetalk model is an improved version of Multitalk that supports very long videos. This is the multi speakers version.Sliding Window size must be 81 frames to get smooth transitions between shots",
|
||||
"multi_speakers_only": true,
|
||||
"URLs": "i2v"
|
||||
}
|
||||
|
||||
10
defaults/standin.json
Normal file
10
defaults/standin.json
Normal file
@ -0,0 +1,10 @@
|
||||
{
|
||||
"model":
|
||||
{
|
||||
"name": "Wan2.1 Standin 14B",
|
||||
"modules": [ ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main//Stand-In_wan2.1_T2V_14B_ver1.0_bf16.safetensors"]],
|
||||
"architecture" : "standin",
|
||||
"description": "The original Wan Text 2 Video model combined with the StandIn module to improve Identity Preservation. You need to provide a Reference Image which is a close up of person face to transfer this person in the Video.",
|
||||
"URLs": "t2v"
|
||||
}
|
||||
}
|
||||
@ -16,6 +16,7 @@
|
||||
],
|
||||
"group": "wan2_2"
|
||||
},
|
||||
"guidance_phases": 2,
|
||||
"switch_threshold" : 875,
|
||||
"guidance_scale" : 4,
|
||||
"guidance2_scale" : 3,
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
"loras_multipliers": [1, 0.2, 0.5, 0.5],
|
||||
"group": "wan2_2"
|
||||
},
|
||||
"guidance_phases": 2,
|
||||
"num_inference_steps": 10,
|
||||
"guidance_scale": 1,
|
||||
"guidance2_scale": 1,
|
||||
|
||||
9
defaults/vace_standin_14B.json
Normal file
9
defaults/vace_standin_14B.json
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"model": {
|
||||
"name": "Vace Standin 14B",
|
||||
"architecture": "vace_standin_14B",
|
||||
"modules": [ "vace_14B", "standin"],
|
||||
"description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video.",
|
||||
"URLs": "t2v"
|
||||
}
|
||||
}
|
||||
@ -12,6 +12,7 @@ class family_handler():
|
||||
flux_model = model_def.get("flux-model", "flux-dev")
|
||||
flux_schnell = flux_model == "flux-schnell"
|
||||
flux_chroma = flux_model == "flux-chroma"
|
||||
flux_uso = flux_model == "flux-dev-uso"
|
||||
model_def_output = {
|
||||
"image_outputs" : True,
|
||||
"no_negative_prompt" : not flux_chroma,
|
||||
@ -20,7 +21,17 @@ class family_handler():
|
||||
model_def_output["guidance_max_phases"] = 1
|
||||
elif not flux_schnell:
|
||||
model_def_output["embedded_guidance"] = True
|
||||
|
||||
if flux_uso :
|
||||
model_def_output["any_image_refs_relative_size"] = True
|
||||
model_def_output["no_background_removal"] = True
|
||||
|
||||
model_def_output["image_ref_choices"] = {
|
||||
"choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "I"),
|
||||
("Up to two Images are Style Images", "IJ")],
|
||||
"default": "I",
|
||||
"letters_filter": "IJ",
|
||||
"label": "Reference Images / Style Images"
|
||||
}
|
||||
|
||||
return model_def_output
|
||||
|
||||
@ -53,8 +64,8 @@ class family_handler():
|
||||
return [
|
||||
{
|
||||
"repoId" : "DeepBeepMeep/Flux",
|
||||
"sourceFolderList" : [""],
|
||||
"fileList" : [ ["flux_vae.safetensors"] ]
|
||||
"sourceFolderList" : ["siglip-so400m-patch14-384", "",],
|
||||
"fileList" : [ ["config.json", "preprocessor_config.json", "model.safetensors"], ["flux_vae.safetensors"] ]
|
||||
},
|
||||
{
|
||||
"repoId" : "DeepBeepMeep/LTX_Video",
|
||||
@ -90,15 +101,21 @@ class family_handler():
|
||||
|
||||
pipe = { "transformer": flux_model.model, "vae" : flux_model.vae, "text_encoder" : flux_model.clip, "text_encoder_2" : flux_model.t5}
|
||||
|
||||
if flux_model.vision_encoder is not None:
|
||||
pipe["siglip_model"] = flux_model.vision_encoder
|
||||
if flux_model.feature_embedder is not None:
|
||||
pipe["feature_embedder"] = flux_model.feature_embedder
|
||||
return flux_model, pipe
|
||||
|
||||
@staticmethod
|
||||
def update_default_settings(base_model_type, model_def, ui_defaults):
|
||||
flux_model = model_def.get("flux-model", "flux-dev")
|
||||
flux_uso = flux_model == "flux-dev-uso"
|
||||
ui_defaults.update({
|
||||
"embedded_guidance": 2.5,
|
||||
})
|
||||
if model_def.get("reference_image", False):
|
||||
ui_defaults.update({
|
||||
"video_prompt_type": "KI",
|
||||
"video_prompt_type": "I" if flux_uso else "KI",
|
||||
})
|
||||
|
||||
|
||||
@ -6,8 +6,10 @@ from glob import iglob
|
||||
from mmgp import offload as offload
|
||||
import torch
|
||||
from shared.utils.utils import calculate_new_dimensions
|
||||
from .sampling import denoise, get_schedule, prepare_kontext, unpack
|
||||
from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack
|
||||
from .modules.layers import get_linear_split_map
|
||||
from transformers import SiglipVisionModel, SiglipImageProcessor
|
||||
|
||||
from .util import (
|
||||
aspect_ratio_to_height_width,
|
||||
load_ae,
|
||||
@ -65,6 +67,26 @@ class model_factory:
|
||||
|
||||
self.vae = load_ae(self.name, device=torch_device)
|
||||
|
||||
siglip_processor = siglip_model = feature_embedder = None
|
||||
if self.name == 'flux-dev-uso':
|
||||
siglip_path = "ckpts/siglip-so400m-patch14-384"
|
||||
siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path)
|
||||
siglip_model = SiglipVisionModel.from_pretrained(siglip_path)
|
||||
siglip_model.eval().to("cpu")
|
||||
if len(model_filename) > 1:
|
||||
from .modules.layers import SigLIPMultiFeatProjModel
|
||||
feature_embedder = SigLIPMultiFeatProjModel(
|
||||
siglip_token_nums=729,
|
||||
style_token_nums=64,
|
||||
siglip_token_dims=1152,
|
||||
hidden_size=3072, #self.hidden_size,
|
||||
context_layer_norm=True,
|
||||
)
|
||||
offload.load_model_data(feature_embedder, model_filename[1])
|
||||
self.vision_encoder = siglip_model
|
||||
self.vision_encoder_processor = siglip_processor
|
||||
self.feature_embedder = feature_embedder
|
||||
|
||||
# offload.change_dtype(self.model, dtype, True)
|
||||
# offload.save_model(self.model, "flux-dev.safetensors")
|
||||
|
||||
@ -97,22 +119,30 @@ class model_factory:
|
||||
loras_slists = None,
|
||||
batch_size = 1,
|
||||
video_prompt_type = "",
|
||||
joint_pass = False,
|
||||
joint_pass = False,
|
||||
image_refs_relative_size = 100,
|
||||
**bbargs
|
||||
):
|
||||
if self._interrupt:
|
||||
return None
|
||||
if self.guidance_max_phases < 1:
|
||||
guide_scale = 1
|
||||
if n_prompt is None or len(n_prompt) == 0:
|
||||
n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
|
||||
if self.guidance_max_phases < 1: guide_scale = 1
|
||||
if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
|
||||
device="cuda"
|
||||
if "I" in video_prompt_type and input_ref_images != None and len(input_ref_images) > 0:
|
||||
if "K" in video_prompt_type and False :
|
||||
# image latents tiling method
|
||||
w, h = input_ref_images[0].size
|
||||
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
|
||||
else:
|
||||
flux_dev_uso = self.name in ['flux-dev-uso']
|
||||
image_stiching = not self.name in ['flux-dev-uso']
|
||||
|
||||
input_ref_images = [] if input_ref_images is None else input_ref_images[:]
|
||||
ref_style_imgs = []
|
||||
|
||||
if "I" in video_prompt_type and len(input_ref_images) > 0:
|
||||
if flux_dev_uso :
|
||||
if "J" in video_prompt_type:
|
||||
ref_style_imgs = input_ref_images
|
||||
input_ref_images = []
|
||||
elif len(input_ref_images) > 1 :
|
||||
ref_style_imgs = input_ref_images[-1:]
|
||||
input_ref_images = input_ref_images[:-1]
|
||||
if image_stiching:
|
||||
# image stiching method
|
||||
stiched = input_ref_images[0]
|
||||
if "K" in video_prompt_type :
|
||||
@ -122,27 +152,61 @@ class model_factory:
|
||||
for new_img in input_ref_images[1:]:
|
||||
stiched = stitch_images(stiched, new_img)
|
||||
input_ref_images = [stiched]
|
||||
else:
|
||||
first_ref = 0
|
||||
if "K" in video_prompt_type:
|
||||
# image latents tiling method
|
||||
w, h = input_ref_images[0].size
|
||||
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
|
||||
input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS)
|
||||
first_ref = 1
|
||||
|
||||
for i in range(first_ref,len(input_ref_images)):
|
||||
w, h = input_ref_images[i].size
|
||||
image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
|
||||
input_ref_images[0] = input_ref_images[0].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
|
||||
else:
|
||||
input_ref_images = None
|
||||
|
||||
inp, height, width = prepare_kontext(
|
||||
t5=self.t5,
|
||||
clip=self.clip,
|
||||
prompt=input_prompt,
|
||||
neg_prompt= n_prompt,
|
||||
ae=self.vae,
|
||||
img_cond_list=input_ref_images,
|
||||
target_width=width,
|
||||
target_height=height,
|
||||
bs=batch_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
real_guidance_scale=guide_scale,
|
||||
)
|
||||
if flux_dev_uso :
|
||||
inp, height, width = prepare_multi_ip(
|
||||
ae=self.vae,
|
||||
img_cond_list=input_ref_images,
|
||||
target_width=width,
|
||||
target_height=height,
|
||||
bs=batch_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
inp, height, width = prepare_kontext(
|
||||
ae=self.vae,
|
||||
img_cond_list=input_ref_images,
|
||||
target_width=width,
|
||||
target_height=height,
|
||||
bs=batch_size,
|
||||
seed=seed,
|
||||
device=device,
|
||||
)
|
||||
|
||||
inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt))
|
||||
if guide_scale != 1:
|
||||
inp.update(prepare_prompt(self.t5, self.clip, batch_size, n_prompt, neg = True, device=device))
|
||||
|
||||
timesteps = get_schedule(sampling_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
|
||||
|
||||
ref_style_imgs = [self.vision_encoder_processor(img, return_tensors="pt").to(self.device) for img in ref_style_imgs]
|
||||
if self.feature_embedder is not None and ref_style_imgs is not None and len(ref_style_imgs) > 0 and self.vision_encoder is not None:
|
||||
# processing style feat into textural hidden space
|
||||
siglip_embedding = [self.vision_encoder(**emb, output_hidden_states=True) for emb in ref_style_imgs]
|
||||
siglip_embedding = torch.cat([self.feature_embedder(emb) for emb in siglip_embedding], dim=1)
|
||||
siglip_embedding_ids = torch.zeros( siglip_embedding.shape[0], siglip_embedding.shape[1], 3 ).to(device)
|
||||
inp["siglip_embedding"] = siglip_embedding
|
||||
inp["siglip_embedding_ids"] = siglip_embedding_ids
|
||||
|
||||
def unpack_latent(x):
|
||||
return unpack(x.float(), height, width)
|
||||
|
||||
# denoise initial noise
|
||||
x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass)
|
||||
if x==None: return None
|
||||
|
||||
@ -12,6 +12,7 @@ from .modules.layers import (
|
||||
timestep_embedding,
|
||||
DistilledGuidance,
|
||||
ChromaModulationOut,
|
||||
SigLIPMultiFeatProjModel,
|
||||
)
|
||||
from .modules.lora import LinearLora, replace_linear_with_lora
|
||||
|
||||
@ -32,7 +33,7 @@ class FluxParams:
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
chroma: bool = False
|
||||
|
||||
eso: bool = False
|
||||
|
||||
class Flux(nn.Module):
|
||||
"""
|
||||
@ -189,6 +190,8 @@ class Flux(nn.Module):
|
||||
v = swap_scale_shift(v)
|
||||
k = k.replace("norm_out.linear", "final_layer.adaLN_modulation.1")
|
||||
new_sd[k] = v
|
||||
else:
|
||||
new_sd = sd
|
||||
return new_sd
|
||||
|
||||
def forward(
|
||||
@ -198,11 +201,13 @@ class Flux(nn.Module):
|
||||
txt_list,
|
||||
txt_ids_list,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
y_list,
|
||||
img_len = 0,
|
||||
guidance: Tensor | None = None,
|
||||
callback= None,
|
||||
pipeline =None,
|
||||
|
||||
siglip_embedding = None,
|
||||
siglip_embedding_ids = None,
|
||||
) -> Tensor:
|
||||
|
||||
sz = len(txt_list)
|
||||
@ -226,40 +231,44 @@ class Flux(nn.Module):
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec += self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec += self.vector_in(y)
|
||||
vec_list = [ vec + self.vector_in(y) for y in y_list]
|
||||
|
||||
img = None
|
||||
txt_list = [self.txt_in(txt) for txt in txt_list ]
|
||||
if siglip_embedding is not None:
|
||||
txt_list = [torch.cat((siglip_embedding, txt) , dim=1) for txt in txt_list]
|
||||
txt_ids_list = [torch.cat((siglip_embedding_ids, txt_id) , dim=1) for txt_id in txt_ids_list]
|
||||
|
||||
pe_list = [self.pe_embedder(torch.cat((txt_ids, img_ids), dim=1)) for txt_ids in txt_ids_list]
|
||||
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if self.chroma: vec = ( self.get_modulations(mod_vectors, "double_img", idx=i), self.get_modulations(mod_vectors, "double_txt", idx=i))
|
||||
if self.chroma: vec_list = [( self.get_modulations(mod_vectors, "double_img", idx=i), self.get_modulations(mod_vectors, "double_txt", idx=i))] * sz
|
||||
if callback != None:
|
||||
callback(-1, None, False, True)
|
||||
if pipeline._interrupt:
|
||||
return [None] * sz
|
||||
for img, txt, pe in zip(img_list, txt_list, pe_list):
|
||||
for img, txt, pe, vec in zip(img_list, txt_list, pe_list, vec_list):
|
||||
img[...], txt[...] = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
img = txt = pe = None
|
||||
img = txt = pe = vec= None
|
||||
|
||||
img_list = [torch.cat((txt, img), 1) for txt, img in zip(txt_list, img_list)]
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if self.chroma: vec = self.get_modulations(mod_vectors, "single", idx=i)
|
||||
if self.chroma: vec_list= [self.get_modulations(mod_vectors, "single", idx=i)] * sz
|
||||
if callback != None:
|
||||
callback(-1, None, False, True)
|
||||
if pipeline._interrupt:
|
||||
return [None] * sz
|
||||
for img, pe in zip(img_list, pe_list):
|
||||
for img, pe, vec in zip(img_list, pe_list, vec_list):
|
||||
img[...]= block(x=img, vec=vec, pe=pe)
|
||||
img = pe = None
|
||||
img_list = [ img[:, txt.shape[1] :, ...] for img, txt in zip(img_list, txt_list)]
|
||||
img = pe = vec = None
|
||||
img_list = [ img[:, txt.shape[1] : txt.shape[1] + img_len, ...] for img, txt in zip(img_list, txt_list)]
|
||||
|
||||
if self.chroma: vec = self.get_modulations(mod_vectors, "final")
|
||||
if self.chroma: vec_list = [self.get_modulations(mod_vectors, "final")] * sz
|
||||
out_list = []
|
||||
for i, img in enumerate(img_list):
|
||||
for i, (img, vec) in enumerate(zip(img_list, vec_list)):
|
||||
out_list.append( self.final_layer(img, vec)) # (N, T, patch_size ** 2 * out_channels)
|
||||
img_list[i] = img = None
|
||||
img_list[i] = img = vec = None
|
||||
return out_list
|
||||
|
||||
|
||||
|
||||
@ -11,7 +11,8 @@ def get_linear_split_map():
|
||||
hidden_size = 3072
|
||||
split_linear_modules_map = {
|
||||
"qkv" : {"mapped_modules" : ["q", "k", "v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]},
|
||||
"linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]}
|
||||
"linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]},
|
||||
"linear1_qkv" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]},
|
||||
}
|
||||
return split_linear_modules_map
|
||||
|
||||
@ -375,4 +376,134 @@ class DistilledGuidance(nn.Module):
|
||||
|
||||
x = self.out_proj(x)
|
||||
|
||||
return x
|
||||
return x
|
||||
|
||||
|
||||
class SigLIPMultiFeatProjModel(torch.nn.Module):
|
||||
"""
|
||||
SigLIP Multi-Feature Projection Model for processing style features from different layers
|
||||
and projecting them into a unified hidden space.
|
||||
|
||||
Args:
|
||||
siglip_token_nums (int): Number of SigLIP tokens, default 257
|
||||
style_token_nums (int): Number of style tokens, default 256
|
||||
siglip_token_dims (int): Dimension of SigLIP tokens, default 1536
|
||||
hidden_size (int): Hidden layer size, default 3072
|
||||
context_layer_norm (bool): Whether to use context layer normalization, default False
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
siglip_token_nums: int = 257,
|
||||
style_token_nums: int = 256,
|
||||
siglip_token_dims: int = 1536,
|
||||
hidden_size: int = 3072,
|
||||
context_layer_norm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# High-level feature processing (layer -2)
|
||||
self.high_embedding_linear = nn.Sequential(
|
||||
nn.Linear(siglip_token_nums, style_token_nums),
|
||||
nn.SiLU()
|
||||
)
|
||||
self.high_layer_norm = (
|
||||
nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
|
||||
)
|
||||
self.high_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True)
|
||||
|
||||
# Mid-level feature processing (layer -11)
|
||||
self.mid_embedding_linear = nn.Sequential(
|
||||
nn.Linear(siglip_token_nums, style_token_nums),
|
||||
nn.SiLU()
|
||||
)
|
||||
self.mid_layer_norm = (
|
||||
nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
|
||||
)
|
||||
self.mid_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True)
|
||||
|
||||
# Low-level feature processing (layer -20)
|
||||
self.low_embedding_linear = nn.Sequential(
|
||||
nn.Linear(siglip_token_nums, style_token_nums),
|
||||
nn.SiLU()
|
||||
)
|
||||
self.low_layer_norm = (
|
||||
nn.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
|
||||
)
|
||||
self.low_projection = nn.Linear(siglip_token_dims, hidden_size, bias=True)
|
||||
|
||||
def forward(self, siglip_outputs):
|
||||
"""
|
||||
Forward pass function
|
||||
|
||||
Args:
|
||||
siglip_outputs: Output from SigLIP model, containing hidden_states
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size]
|
||||
"""
|
||||
dtype = next(self.high_embedding_linear.parameters()).dtype
|
||||
|
||||
# Process high-level features (layer -2)
|
||||
high_embedding = self._process_layer_features(
|
||||
siglip_outputs.hidden_states[-2],
|
||||
self.high_embedding_linear,
|
||||
self.high_layer_norm,
|
||||
self.high_projection,
|
||||
dtype
|
||||
)
|
||||
|
||||
# Process mid-level features (layer -11)
|
||||
mid_embedding = self._process_layer_features(
|
||||
siglip_outputs.hidden_states[-11],
|
||||
self.mid_embedding_linear,
|
||||
self.mid_layer_norm,
|
||||
self.mid_projection,
|
||||
dtype
|
||||
)
|
||||
|
||||
# Process low-level features (layer -20)
|
||||
low_embedding = self._process_layer_features(
|
||||
siglip_outputs.hidden_states[-20],
|
||||
self.low_embedding_linear,
|
||||
self.low_layer_norm,
|
||||
self.low_projection,
|
||||
dtype
|
||||
)
|
||||
|
||||
# Concatenate features from all layers
|
||||
return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1)
|
||||
|
||||
def _process_layer_features(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
embedding_linear: nn.Module,
|
||||
layer_norm: nn.Module,
|
||||
projection: nn.Module,
|
||||
dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Helper function to process features from a single layer
|
||||
|
||||
Args:
|
||||
hidden_states: Input hidden states [bs, seq_len, dim]
|
||||
embedding_linear: Embedding linear layer
|
||||
layer_norm: Layer normalization
|
||||
projection: Projection layer
|
||||
dtype: Target data type
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Processed features [bs, style_token_nums, hidden_size]
|
||||
"""
|
||||
# Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim]
|
||||
embedding = embedding_linear(
|
||||
hidden_states.to(dtype).transpose(1, 2)
|
||||
).transpose(1, 2)
|
||||
|
||||
# Apply layer normalization
|
||||
embedding = layer_norm(embedding)
|
||||
|
||||
# Project to target hidden space
|
||||
embedding = projection(embedding)
|
||||
|
||||
return embedding
|
||||
|
||||
@ -13,6 +13,8 @@ from .modules.conditioner import HFEmbedder
|
||||
from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder
|
||||
from .util import PREFERED_KONTEXT_RESOLUTIONS
|
||||
from einops import rearrange, repeat
|
||||
from typing import Literal
|
||||
import torchvision.transforms.functional as TVF
|
||||
|
||||
|
||||
def get_noise(
|
||||
@ -35,20 +37,10 @@ def get_noise(
|
||||
)
|
||||
|
||||
|
||||
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
||||
bs, c, h, w = img.shape
|
||||
def prepare_prompt(t5: HFEmbedder, clip: HFEmbedder, bs: int, prompt: str | list[str], neg: bool = False, device: str = "cuda") -> dict[str, Tensor]:
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
txt = t5(prompt)
|
||||
@ -61,102 +53,31 @@ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[st
|
||||
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
"txt": txt.to(img.device),
|
||||
"txt_ids": txt_ids.to(img.device),
|
||||
"vec": vec.to(img.device),
|
||||
"neg_txt" if neg else "txt": txt.to(device),
|
||||
"neg_txt_ids" if neg else "txt_ids": txt_ids.to(device),
|
||||
"neg_vec" if neg else "vec": vec.to(device),
|
||||
}
|
||||
|
||||
|
||||
def prepare_control(
|
||||
t5: HFEmbedder,
|
||||
clip: HFEmbedder,
|
||||
img: Tensor,
|
||||
prompt: str | list[str],
|
||||
ae: AutoEncoder,
|
||||
encoder: DepthImageEncoder | CannyImageEncoder,
|
||||
img_cond_path: str,
|
||||
) -> dict[str, Tensor]:
|
||||
# load and encode the conditioning image
|
||||
bs, _, h, w = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
def prepare_img( img: Tensor) -> dict[str, Tensor]:
|
||||
bs, c, h, w = img.shape
|
||||
|
||||
img_cond = Image.open(img_cond_path).convert("RGB")
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
width = w * 8
|
||||
height = h * 8
|
||||
img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS)
|
||||
img_cond = np.array(img_cond)
|
||||
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
|
||||
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
with torch.no_grad():
|
||||
img_cond = encoder(img_cond)
|
||||
img_cond = ae.encode(img_cond)
|
||||
|
||||
img_cond = img_cond.to(torch.bfloat16)
|
||||
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img_cond.shape[0] == 1 and bs > 1:
|
||||
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return_dict = prepare(t5, clip, img, prompt)
|
||||
return_dict["img_cond"] = img_cond
|
||||
return return_dict
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
}
|
||||
|
||||
|
||||
def prepare_fill(
|
||||
t5: HFEmbedder,
|
||||
clip: HFEmbedder,
|
||||
img: Tensor,
|
||||
prompt: str | list[str],
|
||||
ae: AutoEncoder,
|
||||
img_cond_path: str,
|
||||
mask_path: str,
|
||||
) -> dict[str, Tensor]:
|
||||
# load and encode the conditioning image and the mask
|
||||
bs, _, _, _ = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img_cond = Image.open(img_cond_path).convert("RGB")
|
||||
img_cond = np.array(img_cond)
|
||||
img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
|
||||
img_cond = rearrange(img_cond, "h w c -> 1 c h w")
|
||||
|
||||
mask = Image.open(mask_path).convert("L")
|
||||
mask = np.array(mask)
|
||||
mask = torch.from_numpy(mask).float() / 255.0
|
||||
mask = rearrange(mask, "h w -> 1 1 h w")
|
||||
|
||||
with torch.no_grad():
|
||||
img_cond = img_cond.to(img.device)
|
||||
mask = mask.to(img.device)
|
||||
img_cond = img_cond * (1 - mask)
|
||||
img_cond = ae.encode(img_cond)
|
||||
mask = mask[:, 0, :, :]
|
||||
mask = mask.to(torch.bfloat16)
|
||||
mask = rearrange(
|
||||
mask,
|
||||
"b (h ph) (w pw) -> b (ph pw) h w",
|
||||
ph=8,
|
||||
pw=8,
|
||||
)
|
||||
mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if mask.shape[0] == 1 and bs > 1:
|
||||
mask = repeat(mask, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_cond = img_cond.to(torch.bfloat16)
|
||||
img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img_cond.shape[0] == 1 and bs > 1:
|
||||
img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_cond = torch.cat((img_cond, mask), dim=-1)
|
||||
|
||||
return_dict = prepare(t5, clip, img, prompt)
|
||||
return_dict["img_cond"] = img_cond.to(img.device)
|
||||
return return_dict
|
||||
|
||||
|
||||
def prepare_redux(
|
||||
@ -210,9 +131,6 @@ def prepare_redux(
|
||||
|
||||
|
||||
def prepare_kontext(
|
||||
t5: HFEmbedder,
|
||||
clip: HFEmbedder,
|
||||
prompt: str | list[str],
|
||||
ae: AutoEncoder,
|
||||
img_cond_list: list,
|
||||
seed: int,
|
||||
@ -220,17 +138,15 @@ def prepare_kontext(
|
||||
target_width: int | None = None,
|
||||
target_height: int | None = None,
|
||||
bs: int = 1,
|
||||
neg_prompt: str | list[str] = None,
|
||||
real_guidance_scale = False,
|
||||
|
||||
) -> tuple[dict[str, Tensor], int, int]:
|
||||
# load and encode the conditioning image
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img_cond_seq = None
|
||||
img_cond_seq_ids = None
|
||||
if img_cond_list == None: img_cond_list = []
|
||||
height_offset = 0
|
||||
width_offset = 0
|
||||
for cond_no, img_cond in enumerate(img_cond_list):
|
||||
width, height = img_cond.size
|
||||
aspect_ratio = width / height
|
||||
@ -257,10 +173,12 @@ def prepare_kontext(
|
||||
# image ids are the same as base image with the first dimension set to 1
|
||||
# instead of 0
|
||||
img_cond_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
img_cond_ids[..., 0] = cond_no + 1
|
||||
img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
img_cond_ids[..., 0] = 1
|
||||
img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None] + height_offset
|
||||
img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :] + width_offset
|
||||
img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs)
|
||||
height_offset += height // 2
|
||||
width_offset += width // 2
|
||||
|
||||
if target_width is None:
|
||||
target_width = 8 * width
|
||||
@ -272,6 +190,10 @@ def prepare_kontext(
|
||||
else:
|
||||
img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, img_cond_latents], dim=1), torch.cat([img_cond_seq_ids, img_cond_ids], dim=1)
|
||||
|
||||
return_dict = {
|
||||
"img_cond_seq": img_cond_seq,
|
||||
"img_cond_seq_ids": img_cond_seq_ids,
|
||||
}
|
||||
img = get_noise(
|
||||
bs,
|
||||
target_height,
|
||||
@ -280,17 +202,7 @@ def prepare_kontext(
|
||||
dtype=torch.bfloat16,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
return_dict = prepare(t5, clip, img, prompt)
|
||||
if real_guidance_scale != 1:
|
||||
temp_dict = prepare(t5, clip, img, neg_prompt)
|
||||
return_dict["neg_txt"]= temp_dict["txt"]
|
||||
return_dict["neg_txt_ids"]= temp_dict["txt_ids"]
|
||||
else:
|
||||
return_dict["neg_txt"]= None
|
||||
return_dict["neg_txt_ids"]= None
|
||||
return_dict["img_cond_seq"] = img_cond_seq
|
||||
return_dict["img_cond_seq_ids"] = img_cond_seq_ids
|
||||
return_dict.update(prepare_img(img))
|
||||
|
||||
return return_dict, target_height, target_width
|
||||
|
||||
@ -341,10 +253,13 @@ def denoise(
|
||||
# extra img tokens (channel-wise)
|
||||
neg_txt: Tensor = None,
|
||||
neg_txt_ids: Tensor= None,
|
||||
neg_vec: Tensor = None,
|
||||
img_cond: Tensor | None = None,
|
||||
# extra img tokens (sequence-wise)
|
||||
img_cond_seq: Tensor | None = None,
|
||||
img_cond_seq_ids: Tensor | None = None,
|
||||
siglip_embedding = None,
|
||||
siglip_embedding_ids = None,
|
||||
callback=None,
|
||||
pipeline=None,
|
||||
loras_slists=None,
|
||||
@ -352,7 +267,8 @@ def denoise(
|
||||
joint_pass= False,
|
||||
):
|
||||
|
||||
kwargs = {'pipeline': pipeline, 'callback': callback}
|
||||
kwargs = {'pipeline': pipeline, 'callback': callback, "img_len" : img.shape[1], "siglip_embedding": siglip_embedding, "siglip_embedding_ids": siglip_embedding_ids}
|
||||
|
||||
if callback != None:
|
||||
callback(-1, None, True)
|
||||
|
||||
@ -383,7 +299,7 @@ def denoise(
|
||||
img_ids=img_input_ids,
|
||||
txt_list=[txt],
|
||||
txt_ids_list=[txt_ids],
|
||||
y=vec,
|
||||
y_list=[vec],
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
**kwargs
|
||||
@ -395,7 +311,7 @@ def denoise(
|
||||
img_ids=img_input_ids,
|
||||
txt_list=[neg_txt],
|
||||
txt_ids_list=[neg_txt_ids],
|
||||
y=vec,
|
||||
y_list=[neg_vec],
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
**kwargs
|
||||
@ -407,19 +323,14 @@ def denoise(
|
||||
img_ids=img_input_ids,
|
||||
txt_list=[txt, neg_txt],
|
||||
txt_ids_list=[txt_ids, neg_txt_ids],
|
||||
y=vec,
|
||||
y_list=[vec, neg_vec],
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
**kwargs
|
||||
)
|
||||
if pred == None: return None
|
||||
|
||||
|
||||
if img_input_ids is not None:
|
||||
pred = pred[:, : img.shape[1]]
|
||||
if real_guidance_scale > 1:
|
||||
if img_input_ids is not None:
|
||||
neg_pred = neg_pred[:, : img.shape[1]]
|
||||
pred = neg_pred + real_guidance_scale * (pred - neg_pred)
|
||||
|
||||
img += (t_prev - t_curr) * pred
|
||||
@ -430,6 +341,82 @@ def denoise(
|
||||
|
||||
return img
|
||||
|
||||
def prepare_multi_ip(
|
||||
ae: AutoEncoder,
|
||||
img_cond_list: list,
|
||||
seed: int,
|
||||
device: torch.device,
|
||||
target_width: int | None = None,
|
||||
target_height: int | None = None,
|
||||
bs: int = 1,
|
||||
pe: Literal["d", "h", "w", "o"] = "d",
|
||||
) -> dict[str, Tensor]:
|
||||
ref_imgs = img_cond_list
|
||||
assert pe in ["d", "h", "w", "o"]
|
||||
|
||||
ref_imgs = [
|
||||
ae.encode(
|
||||
(TVF.to_tensor(ref_img) * 2.0 - 1.0)
|
||||
.unsqueeze(0)
|
||||
.to(device, torch.float32)
|
||||
).to(torch.bfloat16)
|
||||
for ref_img in img_cond_list
|
||||
]
|
||||
|
||||
img = get_noise( bs, target_height, target_width, device=device, dtype=torch.bfloat16, seed=seed)
|
||||
bs, c, h, w = img.shape
|
||||
# tgt img
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
img_cond_seq = img_cond_seq_ids = None
|
||||
pe_shift_w, pe_shift_h = w // 2, h // 2
|
||||
for cond_no, ref_img in enumerate(ref_imgs):
|
||||
_, _, ref_h1, ref_w1 = ref_img.shape
|
||||
ref_img = rearrange(
|
||||
ref_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2
|
||||
)
|
||||
if ref_img.shape[0] == 1 and bs > 1:
|
||||
ref_img = repeat(ref_img, "1 ... -> bs ...", bs=bs)
|
||||
ref_img_ids1 = torch.zeros(ref_h1 // 2, ref_w1 // 2, 3)
|
||||
# img id分别在宽高偏移各自最大值
|
||||
h_offset = pe_shift_h if pe in {"d", "h"} else 0
|
||||
w_offset = pe_shift_w if pe in {"d", "w"} else 0
|
||||
ref_img_ids1[..., 1] = (
|
||||
ref_img_ids1[..., 1] + torch.arange(ref_h1 // 2)[:, None] + h_offset
|
||||
)
|
||||
ref_img_ids1[..., 2] = (
|
||||
ref_img_ids1[..., 2] + torch.arange(ref_w1 // 2)[None, :] + w_offset
|
||||
)
|
||||
ref_img_ids1 = repeat(ref_img_ids1, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if target_width is None:
|
||||
target_width = 8 * ref_w1
|
||||
if target_height is None:
|
||||
target_height = 8 * ref_h1
|
||||
ref_img_ids1 = ref_img_ids1.to(device)
|
||||
if cond_no == 0:
|
||||
img_cond_seq, img_cond_seq_ids = ref_img, ref_img_ids1
|
||||
else:
|
||||
img_cond_seq, img_cond_seq_ids = torch.cat([img_cond_seq, ref_img], dim=1), torch.cat([img_cond_seq_ids, ref_img_ids1], dim=1)
|
||||
|
||||
|
||||
# 更新pe shift
|
||||
pe_shift_h += ref_h1 // 2
|
||||
pe_shift_w += ref_w1 // 2
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
"img_cond_seq": img_cond_seq,
|
||||
"img_cond_seq_ids": img_cond_seq_ids,
|
||||
}, target_height, target_width
|
||||
|
||||
|
||||
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
return rearrange(
|
||||
|
||||
@ -608,6 +608,38 @@ configs = {
|
||||
shift_factor=0.1159,
|
||||
),
|
||||
),
|
||||
"flux-dev-uso": ModelSpec(
|
||||
repo_id="",
|
||||
repo_flow="",
|
||||
repo_ae="ckpts/flux_vae.safetensors",
|
||||
params=FluxParams(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
vec_in_dim=768,
|
||||
context_in_dim=4096,
|
||||
hidden_size=3072,
|
||||
mlp_ratio=4.0,
|
||||
num_heads=24,
|
||||
depth=19,
|
||||
depth_single_blocks=38,
|
||||
axes_dim=[16, 56, 56],
|
||||
theta=10_000,
|
||||
qkv_bias=True,
|
||||
guidance_embed=True,
|
||||
eso= True,
|
||||
),
|
||||
ae_params=AutoEncoderParams(
|
||||
resolution=256,
|
||||
in_channels=3,
|
||||
ch=128,
|
||||
out_ch=3,
|
||||
ch_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
z_channels=16,
|
||||
scale_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -249,7 +249,7 @@ class DetFace():
|
||||
for scale in [8,16,32]:
|
||||
ny = h1//scale
|
||||
nx = w1//scale
|
||||
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
|
||||
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij")
|
||||
grid = torch.stack((xv, yv), 2).view((1,1,ny, nx, 2)).float()
|
||||
grids.append(grid.to(self.test_device))
|
||||
self.grids = grids
|
||||
|
||||
@ -19,7 +19,7 @@ from PIL import Image
|
||||
import torchvision.transforms.functional as TF
|
||||
import torch.nn.functional as F
|
||||
from .distributed.fsdp import shard_model
|
||||
from .modules.model import WanModel
|
||||
from .modules.model import WanModel, clear_caches
|
||||
from .modules.t5 import T5EncoderModel
|
||||
from .modules.vae import WanVAE
|
||||
from .modules.vae2_2 import Wan2_2_VAE
|
||||
@ -28,7 +28,7 @@ from .modules.clip import CLIPModel
|
||||
from shared.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
||||
get_sampling_sigmas, retrieve_timesteps)
|
||||
from shared.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
||||
from .modules.posemb_layers import get_rotary_pos_embed
|
||||
from .modules.posemb_layers import get_rotary_pos_embed, get_nd_rotary_pos_embed
|
||||
from shared.utils.vace_preprocessor import VaceVideoProcessor
|
||||
from shared.utils.basic_flowmatch import FlowMatchScheduler
|
||||
from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor
|
||||
@ -442,6 +442,8 @@ class WanAny2V:
|
||||
window_no = 0,
|
||||
set_header_text = None,
|
||||
pre_video_frame = None,
|
||||
video_prompt_type= "",
|
||||
original_input_ref_images = [],
|
||||
**bbargs
|
||||
):
|
||||
|
||||
@ -507,11 +509,12 @@ class WanAny2V:
|
||||
# if NAG_scale > 1: context = torch.cat([context, context_NAG], dim=0)
|
||||
if self._interrupt: return None
|
||||
|
||||
vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B"]
|
||||
vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B", "vace_standin_14B"]
|
||||
phantom = model_type in ["phantom_1.3B", "phantom_14B"]
|
||||
fantasy = model_type in ["fantasy"]
|
||||
multitalk = model_type in ["multitalk", "infinitetalk", "vace_multitalk_14B", "i2v_2_2_multitalk"]
|
||||
infinitetalk = model_type in ["infinitetalk"]
|
||||
standin = model_type in ["standin", "vace_standin_14B"]
|
||||
recam = model_type in ["recam_1.3B"]
|
||||
ti2v = model_type in ["ti2v_2_2"]
|
||||
start_step_no = 0
|
||||
@ -526,17 +529,25 @@ class WanAny2V:
|
||||
any_end_frame = False
|
||||
if image_start is None:
|
||||
if infinitetalk:
|
||||
if pre_video_frame is None:
|
||||
new_shot = True
|
||||
if input_frames is not None:
|
||||
image_ref = input_frames[:, -1]
|
||||
if input_video is None: input_video = input_frames[:, -1:]
|
||||
new_shot = "Q" in video_prompt_type
|
||||
else:
|
||||
if input_ref_images is None:
|
||||
input_ref_images, new_shot = [pre_video_frame], False
|
||||
if pre_video_frame is None:
|
||||
new_shot = True
|
||||
else:
|
||||
input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], True
|
||||
if input_ref_images is None: raise Exception("Missing Reference Image")
|
||||
image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
|
||||
if new_shot and window_no <= len(input_ref_images):
|
||||
if input_ref_images is None:
|
||||
input_ref_images, new_shot = [pre_video_frame], False
|
||||
else:
|
||||
input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], "Q" in video_prompt_type
|
||||
if input_ref_images is None: raise Exception("Missing Reference Image")
|
||||
new_shot = new_shot and window_no <= len(input_ref_images)
|
||||
image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
|
||||
if new_shot:
|
||||
input_video = image_ref.unsqueeze(1)
|
||||
else:
|
||||
color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot
|
||||
_ , preframes_count, height, width = input_video.shape
|
||||
input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype)
|
||||
if infinitetalk:
|
||||
@ -762,6 +773,25 @@ class WanAny2V:
|
||||
|
||||
kwargs["freqs"] = freqs
|
||||
|
||||
#Standin
|
||||
if standin:
|
||||
from preprocessing.face_preprocessor import FaceProcessor
|
||||
standin_ref_pos = 1 if "K" in video_prompt_type else 0
|
||||
if len(original_input_ref_images) < standin_ref_pos + 1: raise Exception("Missing Standin ref image")
|
||||
standin_ref_pos = -1
|
||||
image_ref = original_input_ref_images[standin_ref_pos]
|
||||
image_ref.save("si.png")
|
||||
# face_processor = FaceProcessor(antelopv2_path="ckpts/antelopev2")
|
||||
face_processor = FaceProcessor()
|
||||
standin_ref = face_processor.process(image_ref, remove_bg = model_type in ["vace_standin_14B"])
|
||||
face_processor = None
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
standin_freqs = get_nd_rotary_pos_embed((-1, int(target_shape[-2]/2), int(target_shape[-1]/2) ), (-1, int(target_shape[-2]/2 + standin_ref.height/16), int(target_shape[-1]/2 + standin_ref.width/16) ))
|
||||
standin_ref = self.vae.encode([ convert_image_to_tensor(standin_ref).unsqueeze(1) ], VAE_tile_size)[0].unsqueeze(0)
|
||||
kwargs.update({ "standin_freqs": standin_freqs, "standin_ref": standin_ref, })
|
||||
|
||||
|
||||
# Steps Skipping
|
||||
skip_steps_cache = self.model.cache
|
||||
if skip_steps_cache != None:
|
||||
@ -804,6 +834,11 @@ class WanAny2V:
|
||||
if self.model2 is not None: update_loras_slists(self.model2, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
|
||||
callback(-1, None, True, override_num_inference_steps = updated_num_steps, denoising_extra = denoising_extra)
|
||||
|
||||
def clear():
|
||||
clear_caches()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return None
|
||||
|
||||
if sample_scheduler != None:
|
||||
scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g}
|
||||
@ -862,6 +897,7 @@ class WanAny2V:
|
||||
else:
|
||||
latent_model_input = latents
|
||||
|
||||
any_guidance = guide_scale != 1
|
||||
if phantom:
|
||||
gen_args = {
|
||||
"x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
|
||||
@ -875,33 +911,42 @@ class WanAny2V:
|
||||
"audio_scale": [audio_scale, None, None ]
|
||||
}
|
||||
elif multitalk and audio_proj != None:
|
||||
gen_args = {
|
||||
"x" : [latent_model_input, latent_model_input, latent_model_input],
|
||||
"context" : [context, context_null, context_null],
|
||||
"multitalk_audio": [audio_proj, audio_proj, [torch.zeros_like(audio_proj[0][-1:]), torch.zeros_like(audio_proj[1][-1:])]],
|
||||
"multitalk_masks": [token_ref_target_masks, token_ref_target_masks, None]
|
||||
}
|
||||
if guide_scale == 1:
|
||||
gen_args = {
|
||||
"x" : [latent_model_input, latent_model_input],
|
||||
"context" : [context, context],
|
||||
"multitalk_audio": [audio_proj, [torch.zeros_like(audio_proj[0][-1:]), torch.zeros_like(audio_proj[1][-1:])]],
|
||||
"multitalk_masks": [token_ref_target_masks, None]
|
||||
}
|
||||
any_guidance = audio_cfg_scale != 1
|
||||
else:
|
||||
gen_args = {
|
||||
"x" : [latent_model_input, latent_model_input, latent_model_input],
|
||||
"context" : [context, context_null, context_null],
|
||||
"multitalk_audio": [audio_proj, audio_proj, [torch.zeros_like(audio_proj[0][-1:]), torch.zeros_like(audio_proj[1][-1:])]],
|
||||
"multitalk_masks": [token_ref_target_masks, token_ref_target_masks, None]
|
||||
}
|
||||
else:
|
||||
gen_args = {
|
||||
"x" : [latent_model_input, latent_model_input],
|
||||
"context": [context, context_null]
|
||||
}
|
||||
|
||||
if joint_pass and guide_scale > 1:
|
||||
if joint_pass and any_guidance:
|
||||
ret_values = trans( **gen_args , **kwargs)
|
||||
if self._interrupt:
|
||||
return None
|
||||
return clear()
|
||||
else:
|
||||
size = 1 if guide_scale == 1 else len(gen_args["x"])
|
||||
size = len(gen_args["x"]) if any_guidance else 1
|
||||
ret_values = [None] * size
|
||||
for x_id in range(size):
|
||||
sub_gen_args = {k : [v[x_id]] for k, v in gen_args.items() }
|
||||
ret_values[x_id] = trans( **sub_gen_args, x_id= x_id , **kwargs)[0]
|
||||
if self._interrupt:
|
||||
return None
|
||||
return clear()
|
||||
sub_gen_args = None
|
||||
if guide_scale == 1:
|
||||
noise_pred = ret_values[0]
|
||||
if not any_guidance:
|
||||
noise_pred = ret_values[0]
|
||||
elif phantom:
|
||||
guide_scale_img= 5.0
|
||||
guide_scale_text= guide_scale #7.5
|
||||
@ -913,19 +958,32 @@ class WanAny2V:
|
||||
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
|
||||
noise_pred_noaudio = None
|
||||
elif multitalk and audio_proj != None:
|
||||
noise_pred_cond, noise_pred_drop_text, noise_pred_uncond = ret_values
|
||||
if apg_switch != 0:
|
||||
noise_pred = noise_pred_cond + (guide_scale - 1) * adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_text,
|
||||
noise_pred_cond,
|
||||
momentum_buffer=text_momentumbuffer,
|
||||
norm_threshold=apg_norm_threshold) \
|
||||
+ (audio_cfg_scale - 1) * adaptive_projected_guidance(noise_pred_drop_text - noise_pred_uncond,
|
||||
noise_pred_cond,
|
||||
momentum_buffer=audio_momentumbuffer,
|
||||
norm_threshold=apg_norm_threshold)
|
||||
if guide_scale == 1:
|
||||
noise_pred_cond, noise_pred_drop_audio = ret_values
|
||||
noise_pred = noise_pred_cond + (audio_cfg_scale - 1)* adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_audio,
|
||||
noise_pred_cond,
|
||||
momentum_buffer=audio_momentumbuffer,
|
||||
norm_threshold=apg_norm_threshold)
|
||||
|
||||
else:
|
||||
noise_pred_cond, noise_pred_drop_text, noise_pred_uncond = ret_values
|
||||
noise_pred = noise_pred_cond + (guide_scale - 1) * adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_text,
|
||||
noise_pred_cond,
|
||||
momentum_buffer=text_momentumbuffer,
|
||||
norm_threshold=apg_norm_threshold) \
|
||||
+ (audio_cfg_scale - 1) * adaptive_projected_guidance(noise_pred_drop_text - noise_pred_uncond,
|
||||
noise_pred_cond,
|
||||
momentum_buffer=audio_momentumbuffer,
|
||||
norm_threshold=apg_norm_threshold)
|
||||
else:
|
||||
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_drop_text) + audio_cfg_scale * (noise_pred_drop_text - noise_pred_uncond)
|
||||
noise_pred_uncond = noise_pred_cond = noise_pred_drop_text = None
|
||||
if guide_scale == 1:
|
||||
noise_pred_cond, noise_pred_drop_audio = ret_values
|
||||
noise_pred = noise_pred_drop_audio + audio_cfg_scale* (noise_pred_cond - noise_pred_drop_audio)
|
||||
else:
|
||||
noise_pred_cond, noise_pred_drop_text, noise_pred_uncond = ret_values
|
||||
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_drop_text) + audio_cfg_scale * (noise_pred_drop_text - noise_pred_uncond)
|
||||
noise_pred_uncond = noise_pred_cond = noise_pred_drop_text = noise_pred_drop_audio = None
|
||||
else:
|
||||
noise_pred_cond, noise_pred_uncond = ret_values
|
||||
if apg_switch != 0:
|
||||
@ -970,6 +1028,7 @@ class WanAny2V:
|
||||
callback(i, latents_preview[0], False, denoising_extra =denoising_extra )
|
||||
latents_preview = None
|
||||
|
||||
clear()
|
||||
if timestep_injection:
|
||||
latents[:, :, :source_latents.shape[2]] = source_latents
|
||||
|
||||
|
||||
@ -19,6 +19,22 @@ from ..multitalk.multitalk_utils import get_attn_map_with_target

__all__ = ['WanModel']


def get_cache(cache_name):
    all_cache = offload.shared_state.get("_cache", None)
    if all_cache is None:
        all_cache = {}
        offload.shared_state["_cache"] = all_cache
    # store the named cache inside the shared "_cache" registry so that
    # clear_caches() and clear_gen_cache() actually release it
    cache = all_cache.get(cache_name, None)
    if cache is None:
        cache = {}
        all_cache[cache_name] = cache
    return cache

def clear_caches():
    all_cache = offload.shared_state.get("_cache", None)
    if all_cache is not None:
        all_cache.clear()

def sinusoidal_embedding_1d(dim, position):
    # preprocess
    assert dim % 2 == 0

@ -93,6 +109,32 @@ def relative_l1_distance(last_tensor, current_tensor):
    relative_l1_distance = l1_distance / norm
    return relative_l1_distance.to(torch.float32)

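A standalone illustration of the get_cache / clear_caches helpers above; here a plain dict stands in for offload.shared_state, which the model code treats the same way:

shared_state = {}   # stand-in for offload.shared_state

def get_cache(cache_name):
    all_cache = shared_state.setdefault("_cache", {})
    return all_cache.setdefault(cache_name, {})

def clear_caches():
    shared_state.get("_cache", {}).clear()

standin_cache = get_cache("standin")
standin_cache[0] = ("k_block0", "v_block0")        # e.g. per-block key/value tensors
assert get_cache("standin") is standin_cache       # the same dict comes back on later calls
clear_caches()                                     # invoked at step 0 of a new generation
assert get_cache("standin") == {}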
class LoRALinearLayer(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 128,
        dtype: Optional[torch.dtype] = torch.float32,
    ):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False, dtype=dtype)
        self.up = nn.Linear(rank, out_features, bias=False, dtype=dtype)
        self.rank = rank
        self.out_features = out_features
        self.in_features = in_features

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)
        return up_hidden_states.to(orig_dtype)

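LoRALinearLayer is the standard low-rank adapter: a rank-r down projection followed by a zero-initialized up projection, so a freshly constructed adapter contributes nothing until trained weights are loaded. A small sketch of how such an adapter sits next to a frozen linear layer, reusing the class defined just above (the width is illustrative):

import torch
from torch import nn

dim = 5120                              # illustrative transformer width
frozen_q = nn.Linear(dim, dim)
q_lora = LoRALinearLayer(dim, dim, rank=128)

x = torch.randn(2, 77, dim)
q = frozen_q(x) + q_lora(x)             # identical to frozen_q(x) while the up projection is still zero
assert torch.allclose(q, frozen_q(x))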
class WanRMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
@ -244,7 +286,7 @@ class WanSelfAttention(nn.Module):
|
||||
else:
|
||||
return x, None
|
||||
|
||||
def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0):
|
||||
def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0, standin_phase =-1):
|
||||
r"""
|
||||
Args:
|
||||
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
||||
@ -257,19 +299,28 @@ class WanSelfAttention(nn.Module):
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
# query, key, value function
|
||||
q = self.q(x)
|
||||
q, k, v = self.q(x), self.k(x), self.v(x)
|
||||
if standin_phase == 1:
|
||||
q += self.q_loras(x)
|
||||
k += self.k_loras(x)
|
||||
v += self.v_loras(x)
|
||||
self.norm_q(q)
|
||||
q = q.view(b, s, n, d)
|
||||
k = self.k(x)
|
||||
self.norm_k(k)
|
||||
k = k.view(b, s, n, d)
|
||||
v = self.v(x).view(b, s, n, d)
|
||||
q,k,v = q.view(b, s, n, d), k.view(b, s, n, d), v.view(b, s, n, d)
|
||||
del x
|
||||
qklist = [q,k]
|
||||
del q,k
|
||||
|
||||
q,k = apply_rotary_emb(qklist, freqs, head_first=False)
|
||||
|
||||
if standin_phase >= 1:
|
||||
standin_cache = get_cache("standin")
|
||||
if standin_phase == 1:
|
||||
standin_cache[self.block_no] = (k,v)
|
||||
elif standin_phase == 2:
|
||||
k_ip, v_ip = standin_cache[self.block_no]
|
||||
k, v = torch.concat([k, k_ip], dim=1), torch.concat([v, v_ip], dim=1)
|
||||
del k_ip, v_ip
|
||||
if ref_target_masks != None:
|
||||
x_ref_attn_map = get_attn_map_with_target(q, k , grid_sizes, ref_target_masks=ref_target_masks, ref_images_count = ref_images_count)
|
||||
else:
|
||||
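The standin_phase logic above is a two-pass identity injection: phase 1 pushes the reference image through the same blocks and stores its rotary-embedded keys/values per block, phase 2 appends those cached keys/values to the video tokens so every query can also attend to the identity tokens. A condensed sketch of that flow with illustrative tensor shapes; the real code keys the cache by self.block_no and fetches it with get_cache("standin"):

import torch

def inject_reference_kv(k, v, block_no, standin_phase, cache):
    if standin_phase == 1:                     # reference pass: record K/V for this block
        cache[block_no] = (k, v)
    elif standin_phase == 2:                   # video pass: append the cached reference K/V
        k_ip, v_ip = cache[block_no]
        k = torch.concat([k, k_ip], dim=1)
        v = torch.concat([v, v_ip], dim=1)
    return k, v

cache = {}
k_ref = torch.randn(1, 256, 12, 128)           # [batch, reference tokens, heads, head_dim]
v_ref = torch.randn(1, 256, 12, 128)
inject_reference_kv(k_ref, v_ref, block_no=0, standin_phase=1, cache=cache)

k_vid = torch.randn(1, 4096, 12, 128)          # video tokens of one window
v_vid = torch.randn(1, 4096, 12, 128)
k, v = inject_reference_kv(k_vid, v_vid, block_no=0, standin_phase=2, cache=cache)
assert k.shape[1] == 4096 + 256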
@ -289,6 +340,7 @@ class WanSelfAttention(nn.Module):
|
||||
x = pay_attention(
|
||||
qkv_list,
|
||||
window_size=self.window_size)
|
||||
|
||||
else:
|
||||
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
||||
x = (
|
||||
@ -461,6 +513,7 @@ class WanAttentionBlock(nn.Module):
|
||||
multitalk_audio=None,
|
||||
multitalk_masks=None,
|
||||
ref_images_count=0,
|
||||
standin_phase=-1,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
@ -504,7 +557,7 @@ class WanAttentionBlock(nn.Module):
|
||||
|
||||
xlist = [x_mod.to(attention_dtype)]
|
||||
del x_mod
|
||||
y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count)
|
||||
y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count, standin_phase= standin_phase, )
|
||||
y = y.to(dtype)
|
||||
|
||||
if cam_emb != None: y = self.projector(y)
|
||||
@ -513,11 +566,13 @@ class WanAttentionBlock(nn.Module):
|
||||
x.addcmul_(y, e[2])
|
||||
x, y = restore_latent_shape(x), restore_latent_shape(y)
|
||||
del y
|
||||
y = self.norm3(x)
|
||||
y = y.to(attention_dtype)
|
||||
ylist= [y]
|
||||
del y
|
||||
x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype)
|
||||
|
||||
if context is not None:
|
||||
y = self.norm3(x)
|
||||
y = y.to(attention_dtype)
|
||||
ylist= [y]
|
||||
del y
|
||||
x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype)
|
||||
|
||||
if multitalk_audio != None:
|
||||
# cross attn of multitalk audio
|
||||
@ -853,6 +908,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
vae_scale=4, # vae timedownsample scale
|
||||
norm_input_visual=True,
|
||||
norm_output_audio=True,
|
||||
standin= False,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
@ -976,6 +1032,11 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
for block in self.blocks:
|
||||
block.cross_attn.processor = WanCrossAttentionProcessor(fantasytalking_dim, dim)
|
||||
|
||||
if standin:
|
||||
for block in self.blocks:
|
||||
block.self_attn.q_loras = LoRALinearLayer(dim, dim, rank=128)
|
||||
block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128)
|
||||
block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128)
|
||||
|
||||
def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
|
||||
layer_list = [self.head, self.head.head, self.patch_embedding]
|
||||
@ -1155,8 +1216,9 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
audio_scale=None,
|
||||
multitalk_audio = None,
|
||||
multitalk_masks = None,
|
||||
ref_images_count = 0,
|
||||
|
||||
ref_images_count = 0,
|
||||
standin_freqs = None,
|
||||
standin_ref = None,
|
||||
):
|
||||
# patch_dtype = self.patch_embedding.weight.dtype
|
||||
modulation_dtype = self.time_projection[1].weight.dtype
|
||||
@ -1220,6 +1282,18 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
offload.shared_state["embed_sizes"] = grid_sizes
|
||||
offload.shared_state["step_no"] = current_step
|
||||
offload.shared_state["max_steps"] = max_steps
|
||||
if current_step == 0 and x_id == 0: clear_caches()
|
||||
# arguments
|
||||
|
||||
kwargs = dict(
|
||||
grid_sizes=grid_sizes,
|
||||
freqs=freqs,
|
||||
cam_emb = cam_emb,
|
||||
block_mask = block_mask,
|
||||
audio_proj=audio_proj,
|
||||
audio_context_lens=audio_context_lens,
|
||||
ref_images_count=ref_images_count,
|
||||
)
|
||||
|
||||
_flag_df = t.dim() == 2
|
||||
|
||||
@ -1228,6 +1302,16 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
) # b, dim
|
||||
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
|
||||
|
||||
standin_x = None
|
||||
if standin_ref is not None:
|
||||
standin_cache_enabled = False
|
||||
kwargs["standin_phase"] = 2
|
||||
if (current_step == 0 or not standin_cache_enabled) and x_id == 0:
|
||||
standin_x = self.patch_embedding(standin_ref).to(modulation_dtype).flatten(2).transpose(1, 2)
|
||||
standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)) )
|
||||
standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype)
|
||||
standin_e = standin_ref = None
|
||||
|
||||
if self.inject_sample_info and fps!=None:
|
||||
fps = torch.tensor(fps, dtype=torch.long, device=device)
|
||||
|
||||
@ -1254,8 +1338,9 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
if multitalk_audio != None:
|
||||
multitalk_audio_list = []
|
||||
for audio in multitalk_audio:
|
||||
audio = self.audio_proj(*audio)
|
||||
audio = torch.concat(audio.split(1), dim=2).to(context[0])
|
||||
if audio is not None:
|
||||
audio = self.audio_proj(*audio)
|
||||
audio = torch.concat(audio.split(1), dim=2).to(context[0])
|
||||
multitalk_audio_list.append(audio)
|
||||
audio = None
|
||||
else:
|
||||
@ -1271,17 +1356,6 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
else:
|
||||
audio_scale_list = [None] * len(x_list)
|
||||
|
||||
# arguments
|
||||
|
||||
kwargs = dict(
|
||||
grid_sizes=grid_sizes,
|
||||
freqs=freqs,
|
||||
cam_emb = cam_emb,
|
||||
block_mask = block_mask,
|
||||
audio_proj=audio_proj,
|
||||
audio_context_lens=audio_context_lens,
|
||||
ref_images_count=ref_images_count,
|
||||
)
|
||||
|
||||
if vace_context == None:
|
||||
hints_list = [None ] *len(x_list)
|
||||
@ -1378,8 +1452,9 @@ class WanModel(ModelMixin, ConfigMixin):
|
||||
if pipeline._interrupt:
|
||||
return [None] * len(x_list)
|
||||
|
||||
# if (x_id != 0 or joint_pass) and slg_layers is not None and block_idx in slg_layers:
|
||||
# if not joint_pass or not x_should_calc[0]:
|
||||
if standin_x is not None:
|
||||
if not standin_cache_enabled and x_id ==0 : get_cache("standin").clear()
|
||||
standin_x = block(standin_x, context = None, grid_sizes = None, e= standin_e0, freqs = standin_freqs, standin_phase = 1)
|
||||
|
||||
if slg_layers is not None and block_idx in slg_layers:
|
||||
if x_id != 0 or not x_should_calc[0]:
|
||||
|
||||
@ -137,7 +137,10 @@ def get_meshgrid_nd(start, *args, dim=2):
|
||||
axis_grid = []
|
||||
for i in range(dim):
|
||||
a, b, n = start[i], stop[i], num[i]
|
||||
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
||||
if a == b:
|
||||
g = torch.tensor([a], dtype=torch.float32 )
|
||||
else:
|
||||
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
||||
axis_grid.append(g)
|
||||
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
||||
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
||||
@ -297,16 +300,15 @@ def apply_rotary_emb( qklist,
|
||||
|
||||
return xq_out, xk_out
|
||||
def get_nd_rotary_pos_embed(
|
||||
rope_dim_list,
|
||||
start,
|
||||
*args,
|
||||
theta=10000.0,
|
||||
use_real=False,
|
||||
use_real=True,
|
||||
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
||||
interpolation_factor: Union[float, List[float]] = 1.0,
|
||||
k = 6,
|
||||
L_test = 66,
|
||||
enable_riflex = True
|
||||
enable_riflex = False
|
||||
):
|
||||
"""
|
||||
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
|
||||
@ -326,6 +328,13 @@ def get_nd_rotary_pos_embed(
|
||||
Returns:
|
||||
pos_embed (torch.Tensor): [HW, D/2]
|
||||
"""
|
||||
head_dim = 128
|
||||
rope_dim_list = [44, 42, 42]
|
||||
if rope_dim_list is None:
|
||||
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
||||
assert (
|
||||
sum(rope_dim_list) == head_dim
|
||||
), "sum(rope_dim_list) should equal to head_dim of attention layer"
|
||||
|
||||
grid = get_meshgrid_nd(
|
||||
start, *args, dim=len(rope_dim_list)
|
||||
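The hard-coded rope_dim_list above splits the 128-dim attention head across the temporal axis and the two spatial axes, and the assertion that follows only passes because 44 + 42 + 42 = 128. A quick arithmetic check:

head_dim = 128
rope_dim_list = [44, 42, 42]                 # [time, height, width] rotary dims per head
assert sum(rope_dim_list) == head_dim
print([d // 2 for d in rope_dim_list])       # each axis gets half as many complex frequencies: [22, 21, 21]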
@ -454,15 +463,7 @@ def get_rotary_pos_embed(latents_size, enable_RIFLEx = False):
|
||||
|
||||
if len(rope_sizes) != target_ndim:
|
||||
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
|
||||
head_dim = 128
|
||||
rope_dim_list = [44, 42, 42]
|
||||
if rope_dim_list is None:
|
||||
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
||||
assert (
|
||||
sum(rope_dim_list) == head_dim
|
||||
), "sum(rope_dim_list) should equal to head_dim of attention layer"
|
||||
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
||||
rope_dim_list,
|
||||
rope_sizes,
|
||||
theta=10000,
|
||||
use_real=True,
|
||||
|
||||
@ -75,12 +75,14 @@ def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=1
|
||||
if not (left_path==None or right_path==None):
|
||||
human_speech_array1 = audio_prepare_single(left_path, duration = duration)
|
||||
human_speech_array2 = audio_prepare_single(right_path, duration = duration)
|
||||
elif left_path==None:
|
||||
human_speech_array2 = audio_prepare_single(right_path, duration = duration)
|
||||
human_speech_array1 = np.zeros(human_speech_array2.shape[0])
|
||||
elif right_path==None:
|
||||
human_speech_array1 = audio_prepare_single(left_path, duration = duration)
|
||||
human_speech_array2 = np.zeros(human_speech_array1.shape[0])
|
||||
else:
|
||||
audio_type='para'
|
||||
if left_path==None:
|
||||
human_speech_array2 = audio_prepare_single(right_path, duration = duration)
|
||||
human_speech_array1 = np.zeros(human_speech_array2.shape[0])
|
||||
elif right_path==None:
|
||||
human_speech_array1 = audio_prepare_single(left_path, duration = duration)
|
||||
human_speech_array2 = np.zeros(human_speech_array1.shape[0])
|
||||
|
||||
if audio_type=='para':
|
||||
new_human_speech1 = human_speech_array1
|
||||
@ -94,21 +96,27 @@ def audio_prepare_multi(left_path, right_path, audio_type = "add", sample_rate=1
|
||||
new_human_speech1 = np.concatenate([human_speech_array1[: human_speech_array1.shape[0]], np.zeros(human_speech_array2.shape[0])])
|
||||
new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2[:human_speech_array2.shape[0]]])
|
||||
|
||||
#dont include the padding on the summed audio which is used to build the output audio track
|
||||
sum_human_speechs = new_human_speech1 + new_human_speech2
|
||||
if pad > 0:
|
||||
new_human_speech1 = np.concatenate([np.zeros(pad), new_human_speech1])
|
||||
new_human_speech2 = np.concatenate([np.zeros(pad), new_human_speech2])
|
||||
|
||||
duration_changed = False
|
||||
if min_audio_duration > 0:
|
||||
min_samples = math.ceil( min_audio_duration * sample_rate)
|
||||
if len(new_human_speech1) < min_samples:
|
||||
new_human_speech1 = np.concatenate([new_human_speech1, np.zeros(min_samples -len(new_human_speech1)) ])
|
||||
duration_changed = True
|
||||
if len(new_human_speech2) < min_samples:
|
||||
new_human_speech2 = np.concatenate([new_human_speech2, np.zeros(min_samples -len(new_human_speech2)) ])
|
||||
duration_changed = True
|
||||
|
||||
# don't include the padding on the summed audio which is used to build the output audio track
|
||||
sum_human_speechs = new_human_speech1 + new_human_speech2
|
||||
|
||||
if pad > 0:
|
||||
duration_changed = True
|
||||
new_human_speech1 = np.concatenate([np.zeros(pad), new_human_speech1])
|
||||
new_human_speech2 = np.concatenate([np.zeros(pad), new_human_speech2])
|
||||
|
||||
return new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed
|
||||
|
||||
return new_human_speech1, new_human_speech2, sum_human_speechs
|
||||
|
||||
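The duration_changed flag introduced above marks any speaker track that had to grow, either zero-padded up to min_audio_duration or prefixed with pad samples for window alignment, so the caller knows the summed track no longer matches the source files. A small numeric sketch of the sample-count math, assuming the 16 kHz rate expected by the wav2vec feature extractor:

import math
import numpy as np

sample_rate = 16000
min_audio_duration = 5.0                      # seconds
speech = np.zeros(3 * sample_rate)            # a 3 s speaker track (illustrative)

duration_changed = False
min_samples = math.ceil(min_audio_duration * sample_rate)
if len(speech) < min_samples:
    speech = np.concatenate([speech, np.zeros(min_samples - len(speech))])
    duration_changed = True

assert len(speech) == 80000 and duration_changed   # padded from 48000 samples up to 5 s worth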
def process_tts_single(text, save_dir, voice1):
|
||||
s1_sentences = []
|
||||
@ -187,14 +195,14 @@ def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combinat
|
||||
wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base")
|
||||
# wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec")
|
||||
pad = int(padded_frames_for_embeddings/ fps * sr)
|
||||
new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad, min_audio_duration = min_audio_duration )
|
||||
new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad, min_audio_duration = min_audio_duration )
|
||||
audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
|
||||
audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
|
||||
full_audio_embs = []
|
||||
if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
|
||||
# if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
|
||||
if audio_guide2 != None: full_audio_embs.append(audio_embedding_2)
|
||||
if audio_guide2 == None: sum_human_speechs = None
|
||||
if audio_guide2 == None and not duration_changed: sum_human_speechs = None
|
||||
return full_audio_embs, sum_human_speechs
|
||||
|
||||
|
||||
|
||||
@ -10,6 +10,9 @@ def test_class_1_3B(base_model_type):
|
||||
def test_multitalk(base_model_type):
|
||||
return base_model_type in ["multitalk", "vace_multitalk_14B", "i2v_2_2_multitalk", "infinitetalk"]
|
||||
|
||||
def test_standin(base_model_type):
|
||||
return base_model_type in ["standin", "vace_standin_14B"]
|
||||
|
||||
class family_handler():
|
||||
|
||||
@staticmethod
|
||||
@ -72,7 +75,8 @@ class family_handler():
|
||||
i2v = test_class_i2v(base_model_type)
|
||||
extra_model_def["i2v_class"] = i2v
|
||||
extra_model_def["multitalk_class"] = test_multitalk(base_model_type)
|
||||
vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B"]
|
||||
extra_model_def["standin_class"] = test_standin(base_model_type)
|
||||
vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"]
|
||||
extra_model_def["vace_class"] = vace_class
|
||||
|
||||
if test_multitalk(base_model_type):
|
||||
@ -101,6 +105,7 @@ class family_handler():
|
||||
"adaptive_projected_guidance" : True,
|
||||
"tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels),
|
||||
"mag_cache" : True,
|
||||
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
|
||||
"sample_solvers":[
|
||||
("unipc", "unipc"),
|
||||
("euler", "euler"),
|
||||
@ -115,8 +120,8 @@ class family_handler():
|
||||
|
||||
@staticmethod
|
||||
def query_supported_types():
|
||||
return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B",
|
||||
"t2v_1.3B", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
|
||||
return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B",
|
||||
"t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
|
||||
"recam_1.3B",
|
||||
"i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
|
||||
|
||||
@ -130,8 +135,8 @@ class family_handler():
|
||||
}
|
||||
|
||||
models_comp_map = {
|
||||
"vace_14B" : [ "vace_multitalk_14B"],
|
||||
"t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B"],
|
||||
"vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B"],
|
||||
"t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin"],
|
||||
"i2v" : [ "fantasy", "multitalk", "flf2v_720p" ],
|
||||
"i2v_2_2" : ["i2v_2_2_multitalk"],
|
||||
"fantasy": ["multitalk"],
|
||||
@ -221,7 +226,15 @@ class family_handler():
|
||||
mult = model_def.get("loras_multipliers","")
|
||||
if len(mult)> 1 and len(mult[0].split(";"))==3: ui_defaults["guidance_phases"] = 3
|
||||
|
||||
|
||||
if settings_version < 2.27:
|
||||
if base_model_type in "infinitetalk":
|
||||
guidance_scale = ui_defaults.get("guidance_scale", None)
|
||||
if guidance_scale == 1:
|
||||
ui_defaults["audio_guidance_scale"]= 1
|
||||
video_prompt_type = ui_defaults.get("video_prompt_type", "")
|
||||
if "I" in video_prompt_type:
|
||||
video_prompt_type = video_prompt_type.replace("KI", "QKI")
|
||||
ui_defaults["video_prompt_type"] = video_prompt_type
|
||||
@staticmethod
|
||||
def update_default_settings(base_model_type, model_def, ui_defaults):
|
||||
ui_defaults.update({
|
||||
@ -248,11 +261,19 @@ class family_handler():
|
||||
"flow_shift": 7, # 11 for 720p
|
||||
"sliding_window_overlap" : 9,
|
||||
"sample_solver" : "euler",
|
||||
"video_prompt_type": "KI",
|
||||
"video_prompt_type": "QKI",
|
||||
"remove_background_images_ref" : 0,
|
||||
"adaptive_switch" : 1,
|
||||
})
|
||||
|
||||
elif base_model_type in ["standin"]:
|
||||
ui_defaults.update({
|
||||
"guidance_scale": 5.0,
|
||||
"flow_shift": 7, # 11 for 720p
|
||||
"sliding_window_overlap" : 9,
|
||||
"video_prompt_type": "I",
|
||||
"remove_background_images_ref" : 1,
|
||||
})
|
||||
elif base_model_type in ["phantom_1.3B", "phantom_14B"]:
|
||||
ui_defaults.update({
|
||||
"guidance_scale": 7.5,
|
||||
|
||||
preprocessing/face_preprocessor.py (new file, 259 lines)
@ -0,0 +1,259 @@
|
||||
import os
|
||||
import cv2
|
||||
import requests
|
||||
import torch
|
||||
import numpy as np
|
||||
import PIL.Image as Image
|
||||
import PIL.ImageOps
|
||||
# from insightface.app import FaceAnalysis
|
||||
# from facexlib.parsing import init_parsing_model
|
||||
from torchvision.transforms.functional import normalize
|
||||
from typing import Union, Optional
|
||||
from models.hyvideo.data_kits.face_align import AlignImage
|
||||
|
||||
|
||||
def _img2tensor(img: np.ndarray, bgr2rgb: bool = True) -> torch.Tensor:
|
||||
if bgr2rgb:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img = img.astype(np.float32) / 255.0
|
||||
img = np.transpose(img, (2, 0, 1))
|
||||
return torch.from_numpy(img)
|
||||
|
||||
|
||||
def _pad_to_square(img: np.ndarray, pad_color: int = 255) -> np.ndarray:
|
||||
h, w, _ = img.shape
|
||||
if h == w:
|
||||
return img
|
||||
|
||||
if h > w:
|
||||
pad_size = (h - w) // 2
|
||||
padded_img = cv2.copyMakeBorder(
|
||||
img,
|
||||
0,
|
||||
0,
|
||||
pad_size,
|
||||
h - w - pad_size,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=[pad_color] * 3,
|
||||
)
|
||||
else:
|
||||
pad_size = (w - h) // 2
|
||||
padded_img = cv2.copyMakeBorder(
|
||||
img,
|
||||
pad_size,
|
||||
w - h - pad_size,
|
||||
0,
|
||||
0,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=[pad_color] * 3,
|
||||
)
|
||||
|
||||
return padded_img
|
||||
|
||||
|
||||
class FaceProcessor:
|
||||
def __init__(self):
|
||||
self.align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt")
|
||||
self.align_instance.facedet.model.to("cpu")
|
||||
|
||||
|
||||
def process(
|
||||
self,
|
||||
image: Union[str, PIL.Image.Image],
|
||||
resize_to: int = 512,
|
||||
border_thresh: int = 10,
|
||||
face_crop_scale: float = 1.5,
|
||||
remove_bg= False,
|
||||
# area=1.25
|
||||
) -> PIL.Image.Image:
|
||||
|
||||
image_pil = PIL.ImageOps.exif_transpose(image).convert("RGB")
|
||||
w, h = image_pil.size
|
||||
self.align_instance.facedet.model.to("cuda")
|
||||
_, _, bboxes_list = self.align_instance(np.array(image_pil)[:,:,[2,1,0]], maxface=True)
|
||||
self.align_instance.facedet.model.to("cpu")
|
||||
|
||||
try:
|
||||
bboxSrc = bboxes_list[0]
|
||||
except:
|
||||
bboxSrc = [0, 0, w, h]
|
||||
x1, y1, ww, hh = bboxSrc
|
||||
x2, y2 = x1 + ww, y1 + hh
|
||||
# ww, hh = (x2-x1) * area, (y2-y1) * area
|
||||
# center = [(x2+x1)//2, (y2+y1)//2]
|
||||
# x1 = max(center[0] - ww//2, 0)
|
||||
# y1 = max(center[1] - hh//2, 0)
|
||||
# x2 = min(center[0] + ww//2, w)
|
||||
# y2 = min(center[1] + hh//2, h)
|
||||
|
||||
frame = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
|
||||
h, w, _ = frame.shape
|
||||
image_to_process = None
|
||||
|
||||
is_close_to_border = (
|
||||
x1 <= border_thresh
|
||||
and y1 <= border_thresh
|
||||
and x2 >= w - border_thresh
|
||||
and y2 >= h - border_thresh
|
||||
)
|
||||
|
||||
if is_close_to_border:
|
||||
# print(
|
||||
# "[Info] Face is close to border, padding original image to square."
|
||||
# )
|
||||
image_to_process = _pad_to_square(frame, pad_color=255)
|
||||
else:
|
||||
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
|
||||
side = int(max(x2 - x1, y2 - y1) * face_crop_scale)
|
||||
half = side // 2
|
||||
|
||||
left = int(max(cx - half, 0))
|
||||
top = int(max(cy - half, 0))
|
||||
right = int(min(cx + half, w))
|
||||
bottom = int(min(cy + half, h))
|
||||
|
||||
cropped_face = frame[top:bottom, left:right]
|
||||
image_to_process = _pad_to_square(cropped_face, pad_color=255)
|
||||
|
||||
image_resized = cv2.resize(
|
||||
image_to_process, (resize_to, resize_to), interpolation=cv2.INTER_LANCZOS4 # .INTER_AREA
|
||||
)
|
||||
|
||||
face_tensor = _img2tensor(image_resized).to("cpu")
|
||||
|
||||
from shared.utils.utils import remove_background, convert_tensor_to_image
|
||||
if remove_bg:
|
||||
face_tensor = remove_background(face_tensor)
|
||||
img_out = Image.fromarray(face_tensor.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy())
|
||||
return img_out
|
||||
|
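The new FaceProcessor detects the main face, crops around it by face_crop_scale (or pads the whole frame to a square when the face already touches the borders), resizes to a square and optionally strips the background. A hedged usage sketch: it assumes the repo root is on the import path and that ckpts/det_align/detface.pt has been fetched, which the extended download_models list in wgp.py below takes care of; the image paths are illustrative:

import PIL.Image as Image
from preprocessing.face_preprocessor import FaceProcessor

processor = FaceProcessor()                       # loads the face detector, then parks it on CPU
ref = Image.open("person.jpg")                    # illustrative reference photo
face = processor.process(ref, resize_to=512, remove_bg=True)
face.save("face_512.png")                         # square identity crop, e.g. as a StandIn reference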
||||
|
||||
# class FaceProcessor2:
|
||||
# def __init__(self, antelopv2_path=".", device: Optional[torch.device] = None):
|
||||
# if device is None:
|
||||
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# else:
|
||||
# self.device = device
|
||||
|
||||
# providers = (
|
||||
# ["CUDAExecutionProvider"]
|
||||
# if self.device.type == "cuda"
|
||||
# else ["CPUExecutionProvider"]
|
||||
# )
|
||||
# self.app = FaceAnalysis(
|
||||
# name="antelopev2", root=antelopv2_path, providers=providers
|
||||
# )
|
||||
# self.app.prepare(ctx_id=0, det_size=(640, 640))
|
||||
|
||||
# self.parsing_model = init_parsing_model(
|
||||
# model_name="bisenet", device=self.device
|
||||
# )
|
||||
# self.parsing_model.eval()
|
||||
|
||||
# print("FaceProcessor initialized successfully.")
|
||||
|
||||
# def process(
|
||||
# self,
|
||||
# image: Union[str, PIL.Image.Image],
|
||||
# resize_to: int = 512,
|
||||
# border_thresh: int = 10,
|
||||
# face_crop_scale: float = 1.5,
|
||||
# extra_input: bool = False,
|
||||
# ) -> PIL.Image.Image:
|
||||
# if isinstance(image, str):
|
||||
# if image.startswith("http://") or image.startswith("https://"):
|
||||
# image = PIL.Image.open(requests.get(image, stream=True, timeout=10).raw)
|
||||
# elif os.path.isfile(image):
|
||||
# image = PIL.Image.open(image)
|
||||
# else:
|
||||
# raise ValueError(
|
||||
# f"Input string is not a valid URL or file path: {image}"
|
||||
# )
|
||||
# elif not isinstance(image, PIL.Image.Image):
|
||||
# raise TypeError(
|
||||
# "Input must be a file path, a URL, or a PIL.Image.Image object."
|
||||
# )
|
||||
|
||||
# image = PIL.ImageOps.exif_transpose(image).convert("RGB")
|
||||
|
||||
# frame = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
||||
|
||||
# faces = self.app.get(frame)
|
||||
# h, w, _ = frame.shape
|
||||
# image_to_process = None
|
||||
|
||||
# if not faces:
|
||||
# print(
|
||||
# "[Warning] No face detected. Using the whole image, padded to square."
|
||||
# )
|
||||
# image_to_process = _pad_to_square(frame, pad_color=255)
|
||||
# else:
|
||||
# largest_face = max(
|
||||
# faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1])
|
||||
# )
|
||||
# x1, y1, x2, y2 = map(int, largest_face.bbox)
|
||||
|
||||
# is_close_to_border = (
|
||||
# x1 <= border_thresh
|
||||
# and y1 <= border_thresh
|
||||
# and x2 >= w - border_thresh
|
||||
# and y2 >= h - border_thresh
|
||||
# )
|
||||
|
||||
# if is_close_to_border:
|
||||
# print(
|
||||
# "[Info] Face is close to border, padding original image to square."
|
||||
# )
|
||||
# image_to_process = _pad_to_square(frame, pad_color=255)
|
||||
# else:
|
||||
# cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
|
||||
# side = int(max(x2 - x1, y2 - y1) * face_crop_scale)
|
||||
# half = side // 2
|
||||
|
||||
# left = max(cx - half, 0)
|
||||
# top = max(cy - half, 0)
|
||||
# right = min(cx + half, w)
|
||||
# bottom = min(cy + half, h)
|
||||
|
||||
# cropped_face = frame[top:bottom, left:right]
|
||||
# image_to_process = _pad_to_square(cropped_face, pad_color=255)
|
||||
|
||||
# image_resized = cv2.resize(
|
||||
# image_to_process, (resize_to, resize_to), interpolation=cv2.INTER_AREA
|
||||
# )
|
||||
|
||||
# face_tensor = (
|
||||
# _img2tensor(image_resized, bgr2rgb=True).unsqueeze(0).to(self.device)
|
||||
# )
|
||||
# with torch.no_grad():
|
||||
# normalized_face = normalize(face_tensor, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
||||
# parsing_out = self.parsing_model(normalized_face)[0]
|
||||
# parsing_mask = parsing_out.argmax(dim=1, keepdim=True)
|
||||
|
||||
# background_mask_np = (parsing_mask.squeeze().cpu().numpy() == 0).astype(
|
||||
# np.uint8
|
||||
# )
|
||||
# white_background = np.ones_like(image_resized, dtype=np.uint8) * 255
|
||||
# mask_3channel = cv2.cvtColor(background_mask_np * 255, cv2.COLOR_GRAY2BGR)
|
||||
# result_img_bgr = np.where(mask_3channel == 255, white_background, image_resized)
|
||||
# result_img_rgb = cv2.cvtColor(result_img_bgr, cv2.COLOR_BGR2RGB)
|
||||
# img_white_bg = PIL.Image.fromarray(result_img_rgb)
|
||||
# if extra_input:
|
||||
# # 2. Create image with transparent background (new logic)
|
||||
# # Create an alpha channel: 255 for foreground (not background), 0 for background
|
||||
# alpha_channel = (parsing_mask.squeeze().cpu().numpy() != 0).astype(
|
||||
# np.uint8
|
||||
# ) * 255
|
||||
|
||||
# # Convert the resized BGR image to RGB
|
||||
# image_resized_rgb = cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# # Stack RGB channels with the new alpha channel
|
||||
# rgba_image = np.dstack((image_resized_rgb, alpha_channel))
|
||||
|
||||
# # Create PIL image from the RGBA numpy array
|
||||
# img_transparent_bg = PIL.Image.fromarray(rgba_image, "RGBA")
|
||||
|
||||
# return img_white_bg, img_transparent_bg
|
||||
# else:
|
||||
# return img_white_bg
|
||||
@ -92,13 +92,44 @@ def get_video_info(video_path):
|
||||
|
||||
return fps, width, height, frame_count
|
||||
|
||||
def get_video_frame(file_name, frame_no):
|
||||
decord.bridge.set_bridge('torch')
|
||||
reader = decord.VideoReader(file_name)
|
||||
def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor:
|
||||
"""Extract nth frame from video as PyTorch tensor normalized to [-1, 1]."""
|
||||
cap = cv2.VideoCapture(file_name)
|
||||
|
||||
if not cap.isOpened():
|
||||
raise ValueError(f"Cannot open video: {file_name}")
|
||||
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
# Handle out of bounds
|
||||
if frame_no >= total_frames or frame_no < 0:
|
||||
if return_last_if_missing:
|
||||
frame_no = total_frames - 1
|
||||
else:
|
||||
cap.release()
|
||||
raise IndexError(f"Frame {frame_no} out of bounds (0-{total_frames-1})")
|
||||
|
||||
# Get frame
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
|
||||
ret, frame = cap.read()
|
||||
cap.release()
|
||||
|
||||
if not ret:
|
||||
raise ValueError(f"Failed to read frame {frame_no}")
|
||||
|
||||
# Convert BGR->RGB, reshape to (C,H,W), normalize to [-1,1]
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
if return_PIL:
|
||||
return Image.fromarray(frame)
|
||||
else:
|
||||
return (torch.from_numpy(frame).permute(2, 0, 1).float() / 127.5) - 1.0
|
||||
# def get_video_frame(file_name, frame_no):
|
||||
# decord.bridge.set_bridge('torch')
|
||||
# reader = decord.VideoReader(file_name)
|
||||
|
||||
frame = reader.get_batch([frame_no]).squeeze(0)
|
||||
img = Image.fromarray(frame.numpy().astype(np.uint8))
|
||||
return img
|
||||
# frame = reader.get_batch([frame_no]).squeeze(0)
|
||||
# img = Image.fromarray(frame.numpy().astype(np.uint8))
|
||||
# return img
|
||||
|
||||
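The rewritten get_video_frame trades decord for OpenCV and either returns a PIL image (the default) or a CHW float tensor scaled to [-1, 1]; an out-of-range index raises unless return_last_if_missing clamps it to the final frame. A short usage sketch with an illustrative file name:

# last frame of a control video as a PIL image, even if frame_no overshoots
frame_img = get_video_frame("control.mp4", frame_no=10**9, return_last_if_missing=True)

# first frame as a float tensor in [-1, 1], e.g. before encoding with the VAE
frame_t = get_video_frame("control.mp4", frame_no=0, return_PIL=False)
assert frame_t.min() >= -1.0 and frame_t.max() <= 1.0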
def convert_image_to_video(image):
|
||||
if image is None:
|
||||
@ -174,18 +205,19 @@ def get_outpainting_frame_location(final_height, final_width, outpainting_dims
|
||||
if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
|
||||
return height, width, margin_top, margin_left
|
||||
|
||||
def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size = 16):
|
||||
def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
|
||||
if fit_into_canvas == None:
|
||||
return height, width
|
||||
# return image_height, image_width
|
||||
return canvas_height, canvas_width
|
||||
if fit_into_canvas:
|
||||
scale1 = min(canvas_height / height, canvas_width / width)
|
||||
scale2 = min(canvas_width / height, canvas_height / width)
|
||||
scale1 = min(canvas_height / image_height, canvas_width / image_width)
|
||||
scale2 = min(canvas_width / image_height, canvas_height / image_width)
|
||||
scale = max(scale1, scale2)
|
||||
else:
|
||||
scale = (canvas_height * canvas_width / (height * width))**(1/2)
|
||||
scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2)
|
||||
|
||||
new_height = round( height * scale / block_size) * block_size
|
||||
new_width = round( width * scale / block_size) * block_size
|
||||
new_height = round( image_height * scale / block_size) * block_size
|
||||
new_width = round( image_width * scale / block_size) * block_size
|
||||
return new_height, new_width
|
||||
|
||||
def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ):
|
||||
|
||||
wgp.py (191 changed lines)
@ -17,7 +17,7 @@ import numpy as np
|
||||
import importlib
|
||||
from shared.utils import notification_sound
|
||||
from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers
|
||||
from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor
|
||||
from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, get_video_frame
|
||||
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image
|
||||
from shared.utils.audio_video import save_image_metadata, read_image_metadata
|
||||
from shared.match_archi import match_nvidia_architecture
|
||||
@ -55,8 +55,8 @@ AUTOSAVE_FILENAME = "queue.zip"
|
||||
PROMPT_VARS_MAX = 10
|
||||
|
||||
target_mmgp_version = "3.5.10"
|
||||
WanGP_version = "8.11"
|
||||
settings_version = 2.26
|
||||
WanGP_version = "8.2"
|
||||
settings_version = 2.27
|
||||
max_source_video_frames = 3000
|
||||
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
|
||||
|
||||
@ -75,9 +75,14 @@ gen_lock = threading.Lock()
|
||||
offloadobj = enhancer_offloadobj = wan_model = None
|
||||
reload_needed = True
|
||||
|
||||
def clear_gen_cache():
|
||||
if "_cache" in offload.shared_state:
|
||||
del offload.shared_state["_cache"]
|
||||
|
||||
def release_model():
|
||||
global wan_model, offloadobj, reload_needed
|
||||
wan_model = None
|
||||
clear_gen_cache()
|
||||
|
||||
if offloadobj is not None:
|
||||
offloadobj.release()
|
||||
offloadobj = None
|
||||
@ -2066,7 +2071,7 @@ def fix_settings(model_type, ui_defaults):
|
||||
if base_model_type in ["flux"] and settings_version < 2.23:
|
||||
video_prompt_type = video_prompt_type.replace("K", "").replace("I", "KI")
|
||||
|
||||
remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 1)
|
||||
remove_background_images_ref = ui_defaults.get("remove_background_images_ref", None)
|
||||
if settings_version < 2.22:
|
||||
if "I" in video_prompt_type:
|
||||
if remove_background_images_ref == 2:
|
||||
@ -2077,7 +2082,8 @@ def fix_settings(model_type, ui_defaults):
|
||||
remove_background_images_ref = 0
|
||||
if settings_version < 2.26:
|
||||
if not "K" in video_prompt_type: video_prompt_type = video_prompt_type.replace("I", "KI")
|
||||
ui_defaults["remove_background_images_ref"] = remove_background_images_ref
|
||||
if remove_background_images_ref is not None:
|
||||
ui_defaults["remove_background_images_ref"] = remove_background_images_ref
|
||||
|
||||
ui_defaults["video_prompt_type"] = video_prompt_type
|
||||
|
||||
@ -2438,12 +2444,12 @@ def download_models(model_filename = None, model_type= None, module_type = None,
|
||||
|
||||
shared_def = {
|
||||
"repoId" : "DeepBeepMeep/Wan2.1",
|
||||
"sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote", "" ],
|
||||
"sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote", "det_align", "" ],
|
||||
"fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"],
|
||||
["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"],
|
||||
["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
|
||||
["config.json", "pytorch_model.bin", "preprocessor_config.json"],
|
||||
["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", "pytorch_model_segmentation-3.0.bin"], [ "flownet.pkl" ] ]
|
||||
["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", "pytorch_model_segmentation-3.0.bin"], ["detface.pt"], [ "flownet.pkl" ] ]
|
||||
}
|
||||
process_files_def(**shared_def)
|
||||
|
||||
@ -3314,6 +3320,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
|
||||
video_guidance_scale = configs.get("guidance_scale", None)
|
||||
video_guidance2_scale = configs.get("guidance2_scale", None)
|
||||
video_guidance3_scale = configs.get("guidance3_scale", None)
|
||||
video_audio_guidance_scale = configs.get("audio_guidance_scale", None)
|
||||
video_switch_threshold = configs.get("switch_threshold", 0)
|
||||
video_switch_threshold2 = configs.get("switch_threshold2", 0)
|
||||
video_model_switch_phase = configs.get("model_switch_phase", 1)
|
||||
@ -3364,8 +3371,8 @@ def select_video(state, input_file_list, event_data: gr.EventData):
|
||||
if model_def.get("sample_solvers", None) is not None and len(video_sample_solver) > 0 :
|
||||
values += [video_sample_solver]
|
||||
labels += ["Sampler Solver"]
|
||||
values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_flow_shift, video_num_inference_steps]
|
||||
labels += [ "Resolution", video_length_label, "Seed", video_guidance_label, "Shift Scale", "Num Inference steps"]
|
||||
values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_audio_guidance_scale, video_flow_shift, video_num_inference_steps]
|
||||
labels += [ "Resolution", video_length_label, "Seed", video_guidance_label, "Audio Guidance Scale", "Shift Scale", "Num Inference steps"]
|
||||
video_negative_prompt = configs.get("negative_prompt", "")
|
||||
if len(video_negative_prompt) > 0:
|
||||
values += [video_negative_prompt]
|
||||
@ -4272,6 +4279,7 @@ def generate_video(
|
||||
sliding_window_color_correction_strength,
|
||||
sliding_window_overlap_noise,
|
||||
sliding_window_discard_last_frames,
|
||||
image_refs_relative_size,
|
||||
remove_background_images_ref,
|
||||
temporal_upsampling,
|
||||
spatial_upsampling,
|
||||
@ -4438,10 +4446,8 @@ def generate_video(
|
||||
i2v = test_class_i2v(model_type)
|
||||
diffusion_forcing = "diffusion_forcing" in model_filename
|
||||
t2v = base_model_type in ["t2v"]
|
||||
recam = base_model_type in ["recam_1.3B"]
|
||||
ltxv = "ltxv" in model_filename
|
||||
vace = test_vace_module(base_model_type)
|
||||
phantom = "phantom" in model_filename
|
||||
hunyuan_t2v = "hunyuan_video_720" in model_filename
|
||||
hunyuan_i2v = "hunyuan_video_i2v" in model_filename
|
||||
hunyuan_custom = "hunyuan_video_custom" in model_filename
|
||||
@ -4450,8 +4456,8 @@ def generate_video(
|
||||
hunyuan_avatar = "hunyuan_video_avatar" in model_filename
|
||||
fantasy = base_model_type in ["fantasy"]
|
||||
multitalk = model_def.get("multitalk_class", False)
|
||||
flux = model_family in ["flux"]
|
||||
qwen = model_family in ["qwen"]
|
||||
standin = model_def.get("standin_class", False)
|
||||
infinitetalk = base_model_type in ["infinitetalk"]
|
||||
|
||||
if "B" in audio_prompt_type or "X" in audio_prompt_type:
|
||||
from models.wan.multitalk.multitalk import parse_speakers_locations
|
||||
@ -4489,7 +4495,9 @@ def generate_video(
|
||||
|
||||
_, latent_size = get_model_min_frames_and_step(model_type)
|
||||
if diffusion_forcing: latent_size = 4
|
||||
original_image_refs = image_refs
|
||||
original_image_refs = image_refs
|
||||
# image_refs = None
|
||||
# nb_frames_positions= 0
|
||||
frames_to_inject = []
|
||||
any_background_ref = False
|
||||
outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
|
||||
@ -4519,7 +4527,7 @@ def generate_video(
|
||||
send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
|
||||
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
|
||||
from shared.utils.utils import resize_and_remove_background
|
||||
image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (any_background_ref or vace) ) # no fit for vace ref images as it is done later
|
||||
image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (any_background_ref or vace or standin) ) # no fit for vace ref images as it is done later
|
||||
update_task_thumbnails(task, locals())
|
||||
send_cmd("output")
|
||||
joint_pass = boost ==1 #and profile != 1 and profile != 3
|
||||
@ -4580,7 +4588,7 @@ def generate_video(
|
||||
# pad audio_proj_full if aligned to beginning of window to simulate source window overlap
|
||||
min_audio_duration = current_video_length/fps if reset_control_aligment else video_source_duration + current_video_length/fps
|
||||
audio_proj_full, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = audio_guide, audio_guide2= audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0), min_audio_duration = min_audio_duration)
|
||||
if output_new_audio_filepath is not None: output_new_audio_data = None
|
||||
if output_new_audio_data is not None: output_new_audio_filepath= None # need to build original speaker track if it changed size (due to padding at the end) or if it has been combined
|
||||
if not args.save_speakers and "X" in audio_prompt_type:
|
||||
os.remove(audio_guide)
|
||||
os.remove(audio_guide2)
|
||||
@ -4698,7 +4706,6 @@ def generate_video(
|
||||
image_end_tensor = convert_image_to_tensor(image_end_tensor)
|
||||
else:
|
||||
if "L" in image_prompt_type:
|
||||
from shared.utils.utils import get_video_frame
|
||||
refresh_preview["video_source"] = get_video_frame(video_source, 0)
|
||||
prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, target_fps = fps, block_size = block_size )
|
||||
prefix_video = prefix_video.permute(3, 0, 1, 2)
|
||||
@ -4707,9 +4714,10 @@ def generate_video(
|
||||
pre_video_frame = convert_tensor_to_image(prefix_video[:, -1])
|
||||
source_video_overlap_frames_count = pre_video_guide.shape[1]
|
||||
source_video_frames_count = prefix_video.shape[1]
|
||||
if sample_fit_canvas != None: image_size = pre_video_guide.shape[-2:]
|
||||
if sample_fit_canvas != None:
|
||||
image_size = pre_video_guide.shape[-2:]
|
||||
sample_fit_canvas = None
|
||||
guide_start_frame = prefix_video.shape[1]
|
||||
sample_fit_canvas = None
|
||||
|
||||
window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count)
|
||||
guide_end_frame = guide_start_frame + current_video_length - (source_video_overlap_frames_count if window_no == 1 else reuse_frames)
|
||||
@ -4729,12 +4737,20 @@ def generate_video(
|
||||
if len(error) > 0:
|
||||
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
|
||||
keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ]
|
||||
|
||||
if infinitetalk and video_guide is not None:
|
||||
src_image = get_video_frame(video_guide, aligned_guide_start_frame-1, return_last_if_missing = True, return_PIL = True)
|
||||
new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size)
|
||||
src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
|
||||
refresh_preview["video_guide"] = src_image
|
||||
src_video = convert_image_to_tensor(src_image).unsqueeze(1)
|
||||
if sample_fit_canvas != None:
|
||||
image_size = src_video.shape[-2:]
|
||||
sample_fit_canvas = None
|
||||
if ltxv and video_guide is not None:
|
||||
preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
|
||||
status_info = "Extracting " + processes_names[preprocess_type]
|
||||
send_cmd("progress", [0, get_latest_status(state, status_info)])
|
||||
# start one frame ealier to faciliate latents merging later
|
||||
# start one frame earlier to facilitate latents merging later
|
||||
src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
|
||||
if src_video != None:
|
||||
src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ]
|
||||
@ -4912,6 +4928,7 @@ def generate_video(
|
||||
overlapped_latents = overlapped_latents,
|
||||
return_latent_slice= return_latent_slice,
|
||||
overlap_noise = sliding_window_overlap_noise,
|
||||
overlap_size = sliding_window_overlap,
|
||||
color_correction_strength = sliding_window_color_correction_strength,
|
||||
conditioning_latents_size = conditioning_latents_size,
|
||||
keep_frames_parsed = keep_frames_parsed,
|
||||
@ -4928,11 +4945,14 @@ def generate_video(
|
||||
offloadobj = offloadobj,
|
||||
set_header_text= set_header_text,
|
||||
pre_video_frame = pre_video_frame,
|
||||
original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [],
|
||||
image_refs_relative_size = image_refs_relative_size,
|
||||
)
|
||||
except Exception as e:
|
||||
if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0:
|
||||
cleanup_temp_audio_files(control_audio_tracks + source_audio_tracks)
|
||||
remove_temp_filenames(temp_filenames_list)
|
||||
clear_gen_cache()
|
||||
offloadobj.unload_all()
|
||||
trans.cache = None
|
||||
offload.unload_loras_from_model(trans)
|
||||
@ -4977,6 +4997,7 @@ def generate_video(
|
||||
overlapped_latents = samples.get("latent_slice", None)
|
||||
samples= samples["x"]
|
||||
samples = samples.to("cpu")
|
||||
clear_gen_cache()
|
||||
offloadobj.unload_all()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
@ -5891,9 +5912,15 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
|
||||
pop += ["model_mode"]
|
||||
|
||||
if not vace and not phantom and not hunyuan_video_custom:
|
||||
unsaved_params = ["keep_frames_video_guide", "remove_background_images_ref", "mask_expand"] #"video_prompt_type",
|
||||
unsaved_params = ["keep_frames_video_guide", "mask_expand"] #"video_prompt_type",
|
||||
if base_model_type in ["t2v"]: unsaved_params = unsaved_params[1:]
|
||||
pop += unsaved_params
|
||||
|
||||
if not "I" in video_prompt_type:
|
||||
pop += ["remove_background_images_ref"]
|
||||
if not model_def.get("any_image_refs_relative_size", False):
|
||||
pop += ["image_refs_relative_size"]
|
||||
|
||||
if not vace:
|
||||
pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"]
|
||||
|
||||
@ -6328,7 +6355,8 @@ def save_inputs(
|
||||
sliding_window_overlap,
|
||||
sliding_window_color_correction_strength,
|
||||
sliding_window_overlap_noise,
|
||||
sliding_window_discard_last_frames,
|
||||
sliding_window_discard_last_frames,
|
||||
image_refs_relative_size,
|
||||
remove_background_images_ref,
|
||||
temporal_upsampling,
|
||||
spatial_upsampling,
|
||||
@ -6543,11 +6571,20 @@ def refresh_image_prompt_type(state, image_prompt_type):
|
||||
return gr.update(visible = "S" in image_prompt_type ), gr.update(visible = "E" in image_prompt_type ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source)
|
||||
|
||||
def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs):
|
||||
video_prompt_type = del_in_sequence(video_prompt_type, "KFI")
|
||||
model_type = state["model_type"]
|
||||
model_def = get_model_def(model_type)
|
||||
image_ref_choices = model_def.get("image_ref_choices", None)
|
||||
if image_ref_choices is not None:
|
||||
video_prompt_type = del_in_sequence(video_prompt_type, image_ref_choices["letters_filter"])
|
||||
else:
|
||||
video_prompt_type = del_in_sequence(video_prompt_type, "KFI")
|
||||
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_image_refs)
|
||||
visible = "I" in video_prompt_type
|
||||
vace= test_vace_module(state["model_type"])
|
||||
return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )
|
||||
|
||||
rm_bg_visible= visible and not model_def.get("no_background_removal", False)
|
||||
img_rel_size_visible = visible and model_def.get("any_image_refs_relative_size", False)
|
||||
return video_prompt_type, gr.update(visible = visible),gr.update(visible = rm_bg_visible), gr.update(visible = img_rel_size_visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )
|
||||
|
||||
def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask, image_mode):
|
||||
video_prompt_type = del_in_sequence(video_prompt_type, "XYZWNA")
|
||||
@ -6573,8 +6610,16 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt
|
||||
model_def = get_model_def(model_type)
|
||||
image_outputs = image_mode == 1
|
||||
vace= test_vace_module(model_type)
|
||||
return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = visible and not image_outputs), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)
|
||||
keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False)
|
||||
return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)
|
||||
|
||||
def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt):
|
||||
video_prompt_type = del_in_sequence(video_prompt_type, "UVQKI")
|
||||
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt)
|
||||
control_video_visible = "V" in video_prompt_type
|
||||
ref_images_visible = "I" in video_prompt_type
|
||||
return video_prompt_type, gr.update(visible = control_video_visible), gr.update(visible = ref_images_visible )
|
||||
|
||||
# def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
|
||||
# video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
|
||||
# return refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide)
|
||||
@ -6799,6 +6844,13 @@ def detect_auto_save_form(state, evt:gr.SelectData):
|
||||
else:
|
||||
return gr.update()
|
||||
|
||||
def compute_video_length_label(fps, current_video_length):
|
||||
return f"Number of frames ({fps} frames = 1s), current duration: {(current_video_length / fps):.1f}s",
|
||||
|
||||
def refresh_video_length_label(state, current_video_length):
|
||||
fps = get_model_fps(get_base_model_type(state["model_type"]))
|
||||
return gr.update(label= compute_video_length_label(fps, current_video_length))
|
||||
|
||||
def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None):
|
||||
global inputs_names #, advanced
|
||||
|
||||
@ -6918,6 +6970,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
phantom = base_model_type in ["phantom_1.3B", "phantom_14B"]
|
||||
fantasy = base_model_type in ["fantasy"]
|
||||
multitalk = model_def.get("multitalk_class", False)
|
||||
standin = model_def.get("standin_class", False)
|
||||
infinitetalk = base_model_type in ["infinitetalk"]
|
||||
hunyuan_t2v = "hunyuan_video_720" in model_filename
|
||||
hunyuan_i2v = "hunyuan_video_i2v" in model_filename
|
||||
@ -6939,7 +6992,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
any_start_image = False
|
||||
any_end_image = False
|
||||
any_reference_image = False
|
||||
v2i_switch_supported = (vace or t2v) and not image_outputs
|
||||
v2i_switch_supported = (vace or t2v or standin) and not image_outputs
|
||||
ti2v_2_2 = base_model_type in ["ti2v_2_2"]
|
||||
|
||||
image_mode_value = ui_defaults.get("image_mode", 1 if image_outputs else 0 )
|
||||
@ -7060,7 +7113,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
model_mode = gr.Dropdown(value=None, visible=False)
|
||||
keep_frames_video_source = gr.Text(visible=False)
|
||||
|
||||
with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or ltxv or infinitetalk or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column:
|
||||
with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column:
|
||||
video_prompt_type_value= ui_defaults.get("video_prompt_type","")
|
||||
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
|
||||
any_control_video = True
|
||||
@@ -7119,12 +7172,28 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
],
value=filter_letters(video_prompt_type_value, "PDSLCMUV"),
label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= True, show_label= True,
)
)
elif infinitetalk:
video_prompt_type_video_guide = gr.Dropdown(value="", choices = [("","")], visible=False)
else:
any_control_video = False
any_control_image = False
video_prompt_type_video_guide = gr.Dropdown(visible= False)

if infinitetalk:
video_prompt_type_video_guide_alt = gr.Dropdown(
choices=[
("Sparse Video to Video, one Image will be extracted from Video for each new Sliding Window - Smooth Transitions", "UV"),
("Sparse Video to Video, one Image will be extracted from Video for each new Sliding Window - Sharp Transitions", "QUV"),
("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
],
value=filter_letters(video_prompt_type_value, "UVQKI"),
label="Video to Video", scale = 3, visible= True, show_label= False,
)
else:
video_prompt_type_video_guide_alt = gr.Dropdown(value="", choices = [("","")], visible=False)

# video_prompt_video_guide_trigger = gr.Text(visible=False, value="")
if t2v:
video_prompt_type_video_mask = gr.Dropdown(value = "", choices = [""], visible = False)
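`filter_letters` is used throughout this block to pull a sub-selection out of the letter-coded `video_prompt_type` string so the dropdown shows the matching option. Its implementation is not in this hunk; a plausible sketch, assuming it simply keeps the characters that belong to the given letter set in their original order (an illustration, not the project's actual code):

```python
def filter_letters(source: str, letters: str) -> str:
    # Illustration only: keep the flag characters of `source` that appear in
    # `letters`, preserving order, e.g. filter_letters("PDKI", "UVQKI") == "KI".
    return "".join(ch for ch in source if ch in letters)
```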
@@ -7170,7 +7239,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom and not ltxv,
label="Area Processed", scale = 2, show_label= True,
)
if t2v:
image_ref_choices = model_def.get("image_ref_choices", None)
if image_ref_choices is not None:
video_prompt_type_image_refs = gr.Dropdown(
choices= image_ref_choices["choices"],
value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]),
visible = True,
label=image_ref_choices["label"], show_label= True, scale = 2
)
elif t2v:
video_prompt_type_image_refs = gr.Dropdown(value="", label="Ref Image", choices=[""], visible =False)
elif vace:
video_prompt_type_image_refs = gr.Dropdown(
@@ -7184,7 +7261,17 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
visible = True,
label="Reference Images", show_label= True, scale = 2
)

elif standin: # and not vace
video_prompt_type_image_refs = gr.Dropdown(
choices=[
("No Reference Image", ""),
("Reference Image is a Person Face", "I"),
],
value=filter_letters(video_prompt_type_value, "I"),
visible = True,
show_label=False,
label="Reference Image", scale = 2
)

elif (flux or qwen) and model_reference_image:
video_prompt_type_image_refs = gr.Dropdown(
@@ -7209,7 +7296,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_guide = gr.Video(label= "Control Video", visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None))

denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label="Denoising Strength (the Lower the Closer to the Control Video)", visible = "G" in video_prompt_type_value, show_reset_button= False)
keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= (not image_outputs) and "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
keep_frames_video_guide_visible = not image_outputs and "V" in video_prompt_type_value and not model_def.get("keep_frames_video_guide_not_supported", False)
keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= keep_frames_video_guide_visible , scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last

with gr.Column(visible= ("V" in video_prompt_type_value or "K" in video_prompt_type_value or "F" in video_prompt_type_value) and vace) as video_guide_outpainting_col:
video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
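The `keep_frames_video_guide` field above accepts a small textual format (empty = all frames, `1` = the first frame, `a:b` = an inclusive range, entries separated by spaces, with a hinted but commented-out `-1=last`). How WanGP actually parses it is not shown in this diff; a hedged sketch of a parser for the documented 1-based format (`parse_keep_frames` is a hypothetical helper, and negative indices are not handled):

```python
def parse_keep_frames(spec: str, total_frames: int) -> list[int]:
    # Illustration of the documented format only, not the project's parser.
    # "" -> every frame, "1" -> frame 1, "10:20" -> frames 10..20 inclusive.
    if not spec.strip():
        return list(range(1, total_frames + 1))
    kept: list[int] = []
    for token in spec.split():
        if ":" in token:
            start, end = token.split(":")
            kept.extend(range(int(start), int(end) + 1))
        else:
            kept.append(int(token))
    return [f for f in kept if 1 <= f <= total_frames]
```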
@@ -7236,13 +7324,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
)

frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" )
image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internally Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs)

no_background_removal = model_def.get("no_background_removal", False)
remove_background_images_ref = gr.Dropdown(
choices=[
("Keep Backgrounds behind all Reference Images", 0),
("Remove Backgrounds only behind People / Objects except main Subject / Landscape" if (flux or qwen) else "Remove Backgrounds only behind People / Objects" , 1),
("Remove Backgrounds only behind People / Objects except main Subject / Landscape" if (flux or qwen) else ("Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames" if vace else "Remove Backgrounds behind People / Objects") , 1),
],
value=ui_defaults.get("remove_background_images_ref",1),
label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not model_def.get("no_background_removal", False)
value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1),
label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal
)

any_audio_voices_support = any_audio_track(base_model_type)
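`frames_positions` is another space-separated, 1-based mini-format: each number positions an injected frame, and reference images without a position are treated as objects or people. A hedged sketch of how such a field could be parsed (`parse_frames_positions` is a hypothetical helper, not taken from the project):

```python
def parse_frames_positions(spec: str) -> list[int]:
    # Illustration only: "1 33 65" -> [1, 33, 65]; an empty field means no
    # positioned frames, so every reference image is an object / person.
    return [int(token) for token in spec.split()] if spec.strip() else []
```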
@@ -7348,16 +7439,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", get_max_frames(81)), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True)
else:
min_frames, frames_step = get_model_min_frames_and_step(base_model_type)

current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97)

video_length = gr.Slider(min_frames, get_max_frames(737 if test_any_sliding_window(base_model_type) else 337), value=ui_defaults.get(
"video_length", 81 if get_model_family(base_model_type)=="wan" else 97),
step=frames_step, label=f"Number of frames ({fps} = 1s)", visible = True, interactive= True)
video_length = gr.Slider(min_frames, get_max_frames(737 if test_any_sliding_window(base_model_type) else 337), value=current_video_length,
step=frames_step, label=compute_video_length_label(fps, current_video_length) , visible = True, interactive= True)

with gr.Row(visible = not lock_inference_steps) as inference_steps_row:
num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps", visible = True)



show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced_ui)
with gr.Tabs(visible=advanced_ui) as advanced_row:
guidance_max_phases = model_def.get("guidance_max_phases", 0)
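The video length slider above constrains frame counts through `min_frames` and `frames_step` (for Wan models the values shown are a default of 81 frames with a step of 4, i.e. counts of the form 4n+1 at 16 fps). A small, hedged illustration of snapping a requested length to the nearest valid count and converting it to seconds; `snap_video_length` and its conventions are assumptions, not WanGP code:

```python
def snap_video_length(requested: int, min_frames: int = 5, frames_step: int = 4, fps: int = 16) -> tuple[int, float]:
    # Assumed convention: valid lengths are min_frames, min_frames + step, min_frames + 2*step, ...
    steps = max(0, round((requested - min_frames) / frames_step))
    frames = min_frames + steps * frames_step
    return frames, frames / fps

# e.g. snap_video_length(80) -> (81, 5.0625) at 16 fps
```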
@@ -7396,6 +7487,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
guidance2_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance2_scale",5), step=0.5, label="Guidance2 (CFG)", visible= guidance_max_phases >=2 and guidance_phases_value >= 2)
guidance3_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("guidance3_scale",5), step=0.5, label="Guidance3 (CFG)", visible= guidance_max_phases >=3 and guidance_phases_value >= 3)

any_embedded_guidance = model_def.get("embedded_guidance", False)
with gr.Row(visible =any_embedded_guidance or any_audio_guidance) as embedded_guidance_row:
audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale", 4), step=0.5, label="Audio Guidance", visible= any_audio_guidance )
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("embedded_guidance", 6.0), step=0.5, label="Embedded Guidance Scale", visible=any_embedded_guidance )

sample_solver_choices = model_def.get("sample_solvers", None)
with gr.Row(visible = sample_solver_choices is not None or not image_outputs) as sample_solver_row:
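Guidance2 and Guidance3 apply different CFG scales to later denoising phases, switched by the `switch_threshold` / `switch_threshold2` values wired elsewhere in this tab. A hedged sketch of how a per-step scale could be picked from such settings; the phase convention and `guidance_for_step` are assumptions, not WanGP's sampler:

```python
def guidance_for_step(t: float, guidance_scale: float, guidance2_scale: float,
                      guidance3_scale: float, switch_threshold: float, switch_threshold2: float) -> float:
    # Assumed convention: t is the current noise level (1.0 = pure noise, 0.0 = clean);
    # phase 1 runs until switch_threshold, phase 2 until switch_threshold2, then phase 3.
    if t > switch_threshold:
        return guidance_scale
    if t > switch_threshold2:
        return guidance2_scale
    return guidance3_scale
```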
@@ -7407,12 +7502,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
)
flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale", visible = not image_outputs)


any_embedded_guidance = model_def.get("embedded_guidance", False)
with gr.Row(visible =any_embedded_guidance or any_audio_guidance) as embedded_guidance_row:
audio_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("audio_guidance_scale", 4), step=0.5, label="Audio Guidance", visible= any_audio_guidance )
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=ui_defaults.get("embedded_guidance", 6.0), step=0.5, label="Embedded Guidance Scale", visible=any_embedded_guidance )

with gr.Row(visible = vace) as control_net_weights_row:
control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace)
control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace)
@@ -7584,7 +7673,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
value=ui_defaults.get("apg_switch",0),
visible=True,
scale = 1,
label="Adaptive Projected Guidance (requires Guidance > 1) "
label="Adaptive Projected Guidance (requires Guidance > 1 or Audio Guidance > 1) " if multitalk else "Adaptive Projected Guidance (requires Guidance > 1)",
)

with gr.Column(visible = any_cfg_star) as cfg_free_guidance_col:
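Adaptive Projected Guidance (APG) reshapes the classifier-free-guidance update rather than simply scaling it: one common formulation splits the (cond - uncond) difference into components parallel and orthogonal to the conditional prediction and down-weights the parallel part to limit oversaturation at high guidance scales. A hedged sketch of that idea; `apg_update`, the `eta` parameter and the tensor layout are assumptions, not WanGP's implementation:

```python
import torch

def apg_update(cond: torch.Tensor, uncond: torch.Tensor, guidance_scale: float, eta: float = 0.0) -> torch.Tensor:
    # Split the guidance direction into parts parallel / orthogonal to the
    # conditional prediction, then keep only a fraction eta of the parallel part.
    diff = cond - uncond                                    # (B, ...) batched predictions
    flat_cond = cond.flatten(1)
    flat_diff = diff.flatten(1)
    denom = (flat_cond * flat_cond).sum(dim=1, keepdim=True).clamp_min(1e-8)
    coef = (flat_diff * flat_cond).sum(dim=1, keepdim=True) / denom
    parallel = (coef * flat_cond).view_as(diff)
    orthogonal = diff - parallel
    return cond + (guidance_scale - 1.0) * (orthogonal + eta * parallel)
```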
@@ -7602,7 +7691,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
with gr.Row():
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)", visible = any_cfg_zero)

with gr.Column(visible = (vace or t2v) and image_outputs) as min_frames_if_references_col:
with gr.Column(visible = (vace or t2v or standin) and image_outputs) as min_frames_if_references_col:
gr.Markdown("<B>Generating a single Frame alone may not be sufficient to preserve Reference Image Identity / Control Image Information or simply to get a good Image Quality. A workaround is to generate a short Video and keep the First Frame.")
min_frames_if_references = gr.Dropdown(
choices=[
@@ -7839,7 +7928,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, save_lset_prompt_drop, advanced_row, speed_tab, audio_tab, mmaudio_col, quality_tab,
sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, audio_guide_row, RIFLEx_setting_col,
video_prompt_type_video_guide, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row,
video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row,
video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right,
video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row,
video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row,
@@ -7859,13 +7948,15 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
resolution.change(fn=record_last_resolution, inputs=[state, resolution])


video_length.release(fn=refresh_video_length_label, inputs=[state, video_length ], outputs = video_length, trigger_mode="always_last" )
guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ])
audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type])
audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row])
image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] )
# video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions, video_guide_outpainting_col])
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col])
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs ])
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type])
multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt])
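These handlers follow the usual Gradio pattern: an `.input`, `.release`, or `.change` callback receives the current state and control values and returns `gr.update(...)` objects (or plain values) for each declared output. A minimal, self-contained sketch of that pattern, not tied to WanGP's actual handlers (`refresh_label` and the demo layout are illustrative):

```python
import gradio as gr

def refresh_label(length_value):
    # Return an update for the component itself, mirroring the
    # video_length.release(..., outputs=video_length) wiring above.
    return gr.update(label=f"Number of frames, currently {length_value}")

with gr.Blocks() as demo:
    length = gr.Slider(5, 193, value=81, step=4, label="Number of frames")
    # trigger_mode="always_last" (also used above) coalesces rapid events and
    # only runs the handler for the last one.
    length.release(fn=refresh_label, inputs=[length], outputs=[length], trigger_mode="always_last")

# demo.launch()
```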