ode to Vace

This commit is contained in:
deepbeepmeep 2025-07-26 11:45:42 +02:00
parent d2843303a2
commit 90275dfc78
13 changed files with 1101 additions and 169 deletions

View File

@ -20,10 +20,22 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates
### July 26 2025: WanGP v7.2 : Ode to Vace
I am really convinced that Vace can do everything the other models can do, and in a better way, especially as Vace can be combined with Multitalk.
Here are some new Vace improvements:
- I have provided a default finetune named *Vace Cocktail*, a model created on the fly using the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition into the finetune folder to change the Cocktail composition.
- Speaking of identity preservation, it tends to go away when one generates a single frame instead of a video, which is a shame for our Vace photoshop. But there is a solution: I have added an Advanced Quality option that tells WanGP to generate a little more than one frame (it will still keep only the first frame). It will be a little slower, but you will be amazed how well Vace Cocktail combined with this option preserves identities (bye bye *Phantom*).
- Since in practice one switches frequently between *Vace text2video* and *Vace text2image*, I have put them in the same place: they are now just one tab away, no need to reload the model. Likewise, *Wan text2video* and *Wan text2image* have been merged.
- Color fixing when using Sliding Windows: a new *Color Correction* postprocessing step, applied automatically by default (you can disable it in the *Sliding Window* section of the Advanced tab), will try to match the colors of the new window with those of the previous window. It doesn't fix all the unwanted artifacts of the new window, but at least it makes the transition smoother. Thanks to the Multitalk team for the original code.

You will also enjoy our new real-time statistics (CPU / GPU usage, RAM / VRAM used, ...). Many thanks to **Redtash1** for providing the framework for this new feature! You need to go to the Config tab to enable real-time stats.
### July 21 2025: WanGP v7.12
- Flux Family Reunion: *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate, Lora support for the Flux *diffusers* format has also been added.
- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video!) in one go without a sliding window. With the distilled model it will take only 5 minutes on an RTX 4090 (you will need 22 GB of VRAM though). I have added options to select a higher number of frames if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App).
- LTX Video ControlNet: a ControlNet that allows you, for instance, to transfer human motion or depth from a control video. It is not as powerful as Vace but can produce interesting results, especially now that you can quickly generate a 1-minute video. Under the hood, IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them.

View File

@ -1,16 +0,0 @@
{
"model": {
"name": "Flux 1 Dev 12B",
"architecture": "flux",
"description": "FLUX.1 Dev is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev_quanto_bf16_int8.safetensors"
],
"image_outputs": true,
"flux-model": "flux-dev"
},
"prompt": "draw a hat",
"resolution": "1280x720",
"batch_size": 1
}

View File

@ -1,13 +0,0 @@
{
"model": {
"name": "Wan2.1 text2image 14B",
"architecture": "t2v",
"description": "The original Wan Text 2 Video model configured to generate an image instead of a video.",
"image_outputs": true,
"URLs": "t2v"
},
"batch_size": 1,
"resolution": "1280x720"
}

View File

@ -0,0 +1,21 @@
{
"model": {
"name": "Vace Cocktail 14B",
"architecture": "vace_14B",
"modules": [
"vace_14B"
],
"description": "This model has been created on the fly using the Wan text 2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. Copy the model def in the finetune folder to change the Cocktail composition.",
"URLs": "t2v",
"loras": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors"
],
"loras_multipliers": [1, 0.5, 0.5, 0.5]
},
"num_inference_steps": 10,
"guidance_scale": 1,
"flow_shift": 2
}

View File

@ -1,16 +0,0 @@
{
"model": {
"name": "Vace FusioniX image2image 14B",
"architecture": "vace_14B",
"modules": [
"vace_14B"
],
"image_outputs": true,
"description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
"URLs": "t2v_fusionix"
},
"resolution": "1280x720",
"guidance_scale": 1,
"num_inference_steps": 10,
"batch_size": 1
}

View File

@ -94,17 +94,20 @@ class Flux(nn.Module):
if first_key.startswith("lora_unet_"):
new_sd = {}
print("Converting Lora Safetensors format to Lora Diffusers format")
repl_list = ["linear1", "linear2", "modulation_lin"]
repl_list = ["linear1", "linear2", "modulation", "img_attn", "txt_attn", "img_mlp", "txt_mlp", "img_mod", "txt_mod"]
src_list = ["_" + k + "." for k in repl_list]
tgt_list = ["." + k.replace("_", ".") + "." for k in repl_list]
src_list2 = ["_" + k + "_" for k in repl_list]
tgt_list = ["." + k + "." for k in repl_list]
for k,v in sd.items():
k = k.replace("lora_unet_blocks_","diffusion_model.blocks.")
k = k.replace("lora_unet__blocks_","diffusion_model.blocks.")
k = k.replace("lora_unet_single_blocks_","diffusion_model.single_blocks.")
k = k.replace("lora_unet_double_blocks_","diffusion_model.double_blocks.")
for s,t in zip(src_list, tgt_list):
for s,s2, t in zip(src_list, src_list2, tgt_list):
k = k.replace(s,t)
k = k.replace(s2,t)
k = k.replace("lora_up","lora_B")
k = k.replace("lora_down","lora_A")

View File

@ -30,6 +30,7 @@ def get_noise(
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
dtype=dtype,
device=device,
generator=torch.Generator(device=device).manual_seed(seed),
)

View File

@ -45,5 +45,6 @@ misaki
soundfile
ffmpeg-python
pyannote.audio
pynvml
# num2words
# spacy

View File

@ -30,7 +30,7 @@ from wan.modules.posemb_layers import get_rotary_pos_embed
from .utils.vace_preprocessor import VaceVideoProcessor
from wan.utils.basic_flowmatch import FlowMatchScheduler
from wan.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions
from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance
from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask
from mmgp import safetensors2
def optimized_scale(positive_flat, negative_flat):
@ -78,7 +78,6 @@ class WanAny2V:
self.num_train_timesteps = config.num_train_timesteps
self.param_dtype = config.param_dtype
self.model_def = model_def
self.image_outputs = model_def.get("image_outputs", False)
self.text_encoder = T5EncoderModel(
text_len=config.text_len,
dtype=config.t5_dtype,
@ -382,6 +381,9 @@ class WanAny2V:
offloadobj = None,
apg_switch = False,
speakers_bboxes = None,
color_correction_strength = 1,
prefix_frames_count = 0,
image_mode = 0,
**bbargs
):
@ -418,9 +420,9 @@ class WanAny2V:
seed_g = torch.Generator(device=self.device)
seed_g.manual_seed(seed)
image_outputs = image_mode == 1
kwargs = {'pipeline': self, 'callback': callback}
color_reference_frame = None
if self._interrupt:
return None
@ -468,6 +470,7 @@ class WanAny2V:
enc = torch.concat( [input_frames, torch.zeros( (3, frame_num-preframes_count, height, width),
device=self.device, dtype= self.VAE_dtype)],
dim = 1).to(self.device)
color_reference_frame = input_frames[:, -1:].clone()
input_frames = None
else:
preframes_count = 1
@ -498,6 +501,7 @@ class WanAny2V:
img_interpolated = resize_lanczos(image_start, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
image_start = resize_lanczos(image_start, clip_image_size, clip_image_size)
image_start = image_start.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
color_reference_frame = image_start.clone()
if image_end!= None:
img_interpolated2 = resize_lanczos(image_end, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
image_end = resize_lanczos(image_end, clip_image_size, clip_image_size)
@ -566,12 +570,13 @@ class WanAny2V:
injection_denoising_step = 0
inject_from_start = False
if input_frames != None and denoising_strength < 1 :
color_reference_frame = input_frames[:, -1:].clone()
if overlapped_latents != None:
overlapped_latents_frames_num = overlapped_latents.shape[1]
overlapped_latents_frames_num = overlapped_latents.shape[2]
overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
else:
overlapped_latents_frames_num = overlapped_frames_num = 0
if len(keep_frames_parsed) == 0 or self.image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = []
if len(keep_frames_parsed) == 0 or image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = []
injection_denoising_step = int(sampling_steps * (1. - denoising_strength) )
latent_keep_frames = []
if source_latents.shape[1] < lat_frames or len(keep_frames_parsed) > 0:
@ -606,6 +611,7 @@ class WanAny2V:
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
m0 = self.vace_encode_masks(input_masks, input_ref_images)
if self.background_mask != None:
color_reference_frame = input_ref_images[0][0].clone()
zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
mbg = self.vace_encode_masks(self.background_mask, None)
for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
@ -621,6 +627,8 @@ class WanAny2V:
if overlapped_latents != None :
overlapped_latents_size = overlapped_latents.shape[2]
extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0)
if prefix_frames_count > 0:
color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone()
target_shape = list(z0[0].shape)
target_shape[0] = int(target_shape[0] / 2)
@ -691,7 +699,6 @@ class WanAny2V:
apg_norm_threshold = 55
text_momentumbuffer = MomentumBuffer(apg_momentum)
audio_momentumbuffer = MomentumBuffer(apg_momentum)
# self.image_outputs = False
# denoising
for i, t in enumerate(tqdm(timesteps)):
offload.set_step_no_for_lora(self.model, i)
@ -777,7 +784,7 @@ class WanAny2V:
noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = ret_values
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
noise_pred_noaudio = None
elif multitalk:
elif multitalk and audio_proj != None:
noise_pred_cond, noise_pred_drop_text, noise_pred_uncond = ret_values
if apg_switch != 0:
noise_pred = noise_pred_cond + (guide_scale - 1) * adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_text,
@ -830,6 +837,7 @@ class WanAny2V:
latents_preview = latents
if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames]
if image_outputs: latents_preview= latents_preview[:, :,:1]
if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
callback(i, latents_preview[0], False)
latents_preview = None
@ -846,10 +854,18 @@ class WanAny2V:
videos = self.vae.decode(x0, VAE_tile_size)
if self.image_outputs:
videos = torch.cat(videos, dim=1) if len(videos) > 1 else videos[0]
if image_outputs:
videos = torch.cat([video[:,:1] for video in videos], dim=1) if len(videos) > 1 else videos[0][:,:1]
else:
videos = videos[0] # return only first video
videos = videos[0] # return only first video
if color_correction_strength > 0:
if vace and False:
# videos = match_and_blend_colors_with_mask(videos.unsqueeze(0), input_frames[0].unsqueeze(0), input_masks[0][:1].unsqueeze(0), color_correction_strength,copy_mode= "progressive_blend").squeeze(0)
videos = match_and_blend_colors_with_mask(videos.unsqueeze(0), input_frames[0].unsqueeze(0), input_masks[0][:1].unsqueeze(0), color_correction_strength,copy_mode= "reference").squeeze(0)
# videos = match_and_blend_colors_with_mask(videos.unsqueeze(0), videos.unsqueeze(0), input_masks[0][:1].unsqueeze(0), color_correction_strength,copy_mode= "reference").squeeze(0)
elif color_reference_frame is not None:
videos = match_and_blend_colors(videos.unsqueeze(0), color_reference_frame.unsqueeze(0), color_correction_strength).squeeze(0)
if return_latent_slice != None:
return { "x" : videos, "latent_slice" : latent_slice }
return videos

View File

@ -15,6 +15,7 @@ import soundfile as sf
import torchvision
import binascii
import os.path as osp
from skimage import color
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
@ -351,3 +352,531 @@ def adaptive_projected_guidance(
diff_parallel, diff_orthogonal = project(diff, pred_cond)
normalized_update = diff_orthogonal + eta * diff_parallel
return normalized_update
def match_and_blend_colors(source_chunk: torch.Tensor, reference_image: torch.Tensor, strength: float) -> torch.Tensor:
"""
Matches the color of a source video chunk to a reference image and blends with the original.
Args:
source_chunk (torch.Tensor): The video chunk to be color-corrected (B, C, T, H, W) in range [-1, 1].
Assumes B=1 (batch size of 1).
reference_image (torch.Tensor): The reference image (B, C, 1, H, W) in range [-1, 1].
Assumes B=1 and T=1 (single reference frame).
strength (float): The strength of the color correction (0.0 to 1.0).
0.0 means no correction, 1.0 means full correction.
Returns:
torch.Tensor: The color-corrected and blended video chunk.
"""
# print(f"[match_and_blend_colors] Input source_chunk shape: {source_chunk.shape}, reference_image shape: {reference_image.shape}, strength: {strength}")
if strength == 0.0:
# print(f"[match_and_blend_colors] Strength is 0, returning original source_chunk.")
return source_chunk
if not 0.0 <= strength <= 1.0:
raise ValueError(f"Strength must be between 0.0 and 1.0, got {strength}")
device = source_chunk.device
dtype = source_chunk.dtype
# Squeeze batch dimension, permute to T, H, W, C for skimage
# Source: (1, C, T, H, W) -> (T, H, W, C)
source_np = source_chunk.squeeze(0).permute(1, 2, 3, 0).cpu().numpy()
# Reference: (1, C, 1, H, W) -> (H, W, C)
ref_np = reference_image.squeeze(0).squeeze(1).permute(1, 2, 0).cpu().numpy() # Squeeze T dimension as well
# Normalize from [-1, 1] to [0, 1] for skimage
source_np_01 = (source_np + 1.0) / 2.0
ref_np_01 = (ref_np + 1.0) / 2.0
# Clip to ensure values are strictly in [0, 1] after potential float precision issues
source_np_01 = np.clip(source_np_01, 0.0, 1.0)
ref_np_01 = np.clip(ref_np_01, 0.0, 1.0)
# Convert reference to Lab
try:
ref_lab = color.rgb2lab(ref_np_01)
except ValueError as e:
# Handle potential errors if image data is not valid for conversion
print(f"Warning: Could not convert reference image to Lab: {e}. Skipping color correction for this chunk.")
return source_chunk
corrected_frames_np_01 = []
for i in range(source_np_01.shape[0]): # Iterate over time (T)
source_frame_rgb_01 = source_np_01[i]
try:
source_lab = color.rgb2lab(source_frame_rgb_01)
except ValueError as e:
print(f"Warning: Could not convert source frame {i} to Lab: {e}. Using original frame.")
corrected_frames_np_01.append(source_frame_rgb_01)
continue
corrected_lab_frame = source_lab.copy()
# Perform color transfer for L, a, b channels
for j in range(3): # L, a, b
mean_src, std_src = source_lab[:, :, j].mean(), source_lab[:, :, j].std()
mean_ref, std_ref = ref_lab[:, :, j].mean(), ref_lab[:, :, j].std()
# Avoid division by zero if std_src is 0
if std_src == 0:
# If source channel has no variation, keep it as is, but shift by reference mean
# This case is debatable, could also just copy source or target mean.
# Shifting by target mean helps if source is flat but target isn't.
corrected_lab_frame[:, :, j] = mean_ref
else:
corrected_lab_frame[:, :, j] = (corrected_lab_frame[:, :, j] - mean_src) * (std_ref / std_src) + mean_ref
try:
fully_corrected_frame_rgb_01 = color.lab2rgb(corrected_lab_frame)
except ValueError as e:
print(f"Warning: Could not convert corrected frame {i} back to RGB: {e}. Using original frame.")
corrected_frames_np_01.append(source_frame_rgb_01)
continue
# Clip again after lab2rgb as it can go slightly out of [0,1]
fully_corrected_frame_rgb_01 = np.clip(fully_corrected_frame_rgb_01, 0.0, 1.0)
# Blend with original source frame (in [0,1] RGB)
blended_frame_rgb_01 = (1 - strength) * source_frame_rgb_01 + strength * fully_corrected_frame_rgb_01
corrected_frames_np_01.append(blended_frame_rgb_01)
corrected_chunk_np_01 = np.stack(corrected_frames_np_01, axis=0)
# Convert back to [-1, 1]
corrected_chunk_np_minus1_1 = (corrected_chunk_np_01 * 2.0) - 1.0
# Permute back to (C, T, H, W), add batch dim, and convert to original torch.Tensor type and device
# (T, H, W, C) -> (C, T, H, W)
corrected_chunk_tensor = torch.from_numpy(corrected_chunk_np_minus1_1).permute(3, 0, 1, 2).unsqueeze(0)
corrected_chunk_tensor = corrected_chunk_tensor.contiguous() # Ensure contiguous memory layout
output_tensor = corrected_chunk_tensor.to(device=device, dtype=dtype)
# print(f"[match_and_blend_colors] Output tensor shape: {output_tensor.shape}")
return output_tensor
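# Editor's note: illustrative usage sketch, not part of the original commit.
# Color-match a decoded (C, T, H, W) chunk in [-1, 1] to a single (C, 1, H, W) reference frame.
# Per Lab channel the transfer is (src - mean_src) * (std_ref / std_src) + mean_ref, then blended
# back into the source according to `strength`. The tensors below are random placeholders.
def _demo_match_and_blend_colors():
    chunk = torch.rand(3, 8, 64, 64) * 2 - 1        # fake decoded video chunk in [-1, 1]
    reference = torch.rand(3, 1, 64, 64) * 2 - 1    # fake reference frame (e.g. last frame of the previous window)
    corrected = match_and_blend_colors(chunk.unsqueeze(0),      # (1, C, T, H, W)
                                       reference.unsqueeze(0),  # (1, C, 1, H, W)
                                       strength=0.8).squeeze(0)
    return corrected  # same shape and dtype as `chunk`, colors pulled toward `reference`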
from skimage import color
from scipy import ndimage
from scipy.ndimage import binary_erosion, distance_transform_edt
def match_and_blend_colors_with_mask(
source_chunk: torch.Tensor,
reference_video: torch.Tensor,
mask: torch.Tensor,
strength: float,
copy_mode: str = "corrected", # "corrected", "reference", "source", "progressive_blend"
source_border_distance: int = 10,
reference_border_distance: int = 10
) -> torch.Tensor:
"""
Matches the color of a source video chunk to a reference video using mask-based region sampling.
Args:
source_chunk (torch.Tensor): The video chunk to be color-corrected (B, C, T, H, W) in range [-1, 1].
Assumes B=1 (batch size of 1).
reference_video (torch.Tensor): The reference video (B, C, T, H, W) in range [-1, 1].
Must have same temporal dimension as source_chunk.
mask (torch.Tensor): Binary mask (B, 1, T, H, W) or (T, H, W) or (H, W) with values 0 and 1.
Color correction is applied to pixels where mask=1.
strength (float): The strength of the color correction (0.0 to 1.0).
0.0 means no correction, 1.0 means full correction.
copy_mode (str): What to do with mask=0 pixels:
"corrected" (keep original), "reference", "source",
"progressive_blend" (double-sided progressive blending near borders).
source_border_distance (int): Distance in pixels from mask border to sample source video (mask=1 side).
reference_border_distance (int): Distance in pixels from mask border to sample reference video (mask=0 side).
For "progressive_blend" mode, this also defines the blending falloff distance.
Returns:
torch.Tensor: The color-corrected and blended video chunk.
Notes:
- Color statistics are sampled from border regions to determine source and reference tints
- Progressive blending creates smooth double-sided transitions:
* mask=1 side: 60% source + 40% reference at the border, 100% source deeper in
* mask=0 side: 60% reference + 40% source at the border, 100% reference deeper in
"""
if strength == 0.0:
return source_chunk
if not 0.0 <= strength <= 1.0:
raise ValueError(f"Strength must be between 0.0 and 1.0, got {strength}")
if copy_mode not in ["corrected", "reference", "source", "progressive_blend"]:
raise ValueError(f"copy_mode must be 'corrected', 'reference', 'source', or 'progressive_blend', got {copy_mode}")
device = source_chunk.device
dtype = source_chunk.dtype
B, C, T, H, W = source_chunk.shape
# Handle different mask dimensions
if mask.dim() == 2: # (H, W)
mask = mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(B, 1, T, H, W)
elif mask.dim() == 3: # (T, H, W)
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, T, H, W)
elif mask.dim() == 4: # (B, T, H, W) - missing channel dim
mask = mask.unsqueeze(1)
# mask should now be (B, 1, T, H, W)
# Convert to numpy for processing
source_np = source_chunk.squeeze(0).permute(1, 2, 3, 0).cpu().numpy() # (T, H, W, C)
reference_np = reference_video.squeeze(0).permute(1, 2, 3, 0).cpu().numpy() # (T, H, W, C)
mask_np = mask.squeeze(0).squeeze(0).cpu().numpy() # (T, H, W)
# Normalize from [-1, 1] to [0, 1] for skimage
source_np_01 = (source_np + 1.0) / 2.0
reference_np_01 = (reference_np + 1.0) / 2.0
# Clip to ensure values are in [0, 1]
source_np_01 = np.clip(source_np_01, 0.0, 1.0)
reference_np_01 = np.clip(reference_np_01, 0.0, 1.0)
corrected_frames_np_01 = []
for t in range(T):
source_frame = source_np_01[t] # (H, W, C)
reference_frame = reference_np_01[t] # (H, W, C)
frame_mask = mask_np[t] # (H, W)
# Find mask borders and create distance maps
border_regions = get_border_sampling_regions(frame_mask, source_border_distance, reference_border_distance)
source_sample_region = border_regions['source_region'] # mask=1 side
reference_sample_region = border_regions['reference_region'] # mask=0 side
# Sample pixels for color statistics
try:
source_stats = compute_color_stats(source_frame, source_sample_region)
reference_stats = compute_color_stats(reference_frame, reference_sample_region)
except ValueError as e:
print(f"Warning: Could not compute color statistics for frame {t}: {e}. Using original frame.")
corrected_frames_np_01.append(source_frame)
continue
# Apply color correction to mask=1 area and handle mask=0 area based on copy_mode
corrected_frame = apply_color_correction_with_mask(
source_frame, frame_mask, source_stats, reference_stats, strength
)
# Handle mask=0 pixels based on copy_mode
if copy_mode == "reference":
corrected_frame = apply_copy_with_mask(corrected_frame, reference_frame, frame_mask, "reference")
elif copy_mode == "source":
corrected_frame = apply_copy_with_mask(corrected_frame, source_frame, frame_mask, "source")
elif copy_mode == "progressive_blend":
# Apply progressive blending in mask=1 border area (source side)
corrected_frame = apply_progressive_blend_in_corrected_area(
corrected_frame, reference_frame, frame_mask,
border_regions['source_region'], border_regions['source_distances'],
border_regions['reference_region'], source_border_distance
)
# Copy reference pixels to mask=0 area first
corrected_frame = apply_copy_with_mask(corrected_frame, reference_frame, frame_mask, "reference")
# Then apply progressive blending in mask=0 border area (reference side)
corrected_frame = apply_progressive_blend_in_reference_area(
corrected_frame, source_frame, frame_mask,
border_regions['reference_region'], border_regions['reference_distances'],
reference_border_distance
)
corrected_frames_np_01.append(corrected_frame)
corrected_chunk_np_01 = np.stack(corrected_frames_np_01, axis=0)
# Convert back to [-1, 1] and return to tensor format
corrected_chunk_np_minus1_1 = (corrected_chunk_np_01 * 2.0) - 1.0
corrected_chunk_tensor = torch.from_numpy(corrected_chunk_np_minus1_1).permute(3, 0, 1, 2).unsqueeze(0)
corrected_chunk_tensor = corrected_chunk_tensor.contiguous()
output_tensor = corrected_chunk_tensor.to(device=device, dtype=dtype)
return output_tensor
def get_border_sampling_regions(mask, source_border_distance, reference_border_distance):
"""
Create regions for sampling near mask borders with separate distances for source and reference.
Args:
mask: Binary mask (H, W) with 0s and 1s
source_border_distance: Distance from border to include in source sampling (mask=1 side)
reference_border_distance: Distance from border to include in reference sampling (mask=0 side)
Returns:
Dict with sampling regions and distance maps for blending
"""
# Convert to boolean for safety
mask_bool = mask.astype(bool)
# Distance from mask=0 regions (distance into mask=1 areas from border)
dist_from_mask0 = distance_transform_edt(mask_bool)
# Distance from mask=1 regions (distance into mask=0 areas from border)
dist_from_mask1 = distance_transform_edt(~mask_bool)
# Source region: mask=1 pixels within source_border_distance of mask=0 pixels
source_region = mask_bool & (dist_from_mask0 <= source_border_distance)
# Reference region: mask=0 pixels within reference_border_distance of mask=1 pixels
reference_region = (~mask_bool) & (dist_from_mask1 <= reference_border_distance)
return {
'source_region': source_region,
'reference_region': reference_region,
'source_distances': dist_from_mask0, # Distance into mask=1 from border
'reference_distances': dist_from_mask1 # Distance into mask=0 from border
}
def compute_color_stats(image, sample_region):
"""
Compute color statistics (mean and std) for Lab channels in the sampling region.
Args:
image: RGB image (H, W, C) in range [0, 1]
sample_region: Boolean mask (H, W) indicating pixels to sample
Returns:
Dict with 'mean' and 'std' for Lab components
"""
if not np.any(sample_region):
raise ValueError("No pixels in sampling region")
# Convert to Lab
try:
image_lab = color.rgb2lab(image)
except ValueError as e:
raise ValueError(f"Could not convert image to Lab: {e}")
# Extract pixels in sampling region
sampled_pixels = image_lab[sample_region] # (N, 3) where N is number of sampled pixels
# Compute statistics for each Lab channel
stats = {
'mean': np.mean(sampled_pixels, axis=0), # (3,) for L, a, b
'std': np.std(sampled_pixels, axis=0) # (3,) for L, a, b
}
return stats
def apply_color_correction_with_mask(source_frame, mask, source_stats, reference_stats, strength):
"""
Apply color correction to pixels where mask=1.
Args:
source_frame: RGB image (H, W, C) in range [0, 1]
mask: Binary mask (H, W)
source_stats: Color statistics from source sampling region
reference_stats: Color statistics from reference sampling region
strength: Blending strength
Returns:
Corrected RGB image (H, W, C)
"""
try:
source_lab = color.rgb2lab(source_frame)
except ValueError as e:
print(f"Warning: Could not convert source frame to Lab: {e}. Using original frame.")
return source_frame
corrected_lab = source_lab.copy()
correction_region = (mask == 1) # Apply correction to mask=1 pixels
# Apply color transfer to pixels where mask=1
for c in range(3): # L, a, b channels
mean_src = source_stats['mean'][c]
std_src = source_stats['std'][c]
mean_ref = reference_stats['mean'][c]
std_ref = reference_stats['std'][c]
if std_src == 0:
# Handle case where source channel has no variation
corrected_lab[correction_region, c] = mean_ref
else:
# Standard color transfer formula
corrected_lab[correction_region, c] = (
(corrected_lab[correction_region, c] - mean_src) * (std_ref / std_src) + mean_ref
)
try:
fully_corrected_rgb = color.lab2rgb(corrected_lab)
except ValueError as e:
print(f"Warning: Could not convert corrected frame back to RGB: {e}. Using original frame.")
return source_frame
# Clip to [0, 1]
fully_corrected_rgb = np.clip(fully_corrected_rgb, 0.0, 1.0)
# Blend with original (only in correction region)
result = source_frame.copy()
result[correction_region] = (
(1 - strength) * source_frame[correction_region] +
strength * fully_corrected_rgb[correction_region]
)
return result
def apply_progressive_blend_in_corrected_area(corrected_frame, reference_frame, mask, source_region, source_distances, reference_region, source_border_distance):
"""
Apply progressive blending in the corrected area (mask=1) near the border.
Args:
corrected_frame: RGB image (H, W, C) - the color-corrected source frame
reference_frame: RGB image (H, W, C) - the reference frame
mask: Binary mask (H, W)
source_region: Boolean mask (H, W) indicating the source blending region (mask=1 near border)
source_distances: Distance map (H, W) into mask=1 area from mask=0 border
reference_region: Boolean mask (H, W) indicating the reference sampling region (mask=0 near border)
source_border_distance: Maximum distance for source blending
Returns:
Blended RGB image (H, W, C)
Notes:
- Each source pixel blends with its closest reference border pixel (for speed)
- At mask border: 60% source + 40% reference
- Deeper into mask=1 area: 100% corrected source
"""
result = corrected_frame.copy()
# Blend in the source region (mask=1 pixels near border)
blend_region = source_region
if np.any(blend_region):
# Find immediate border pixels (mask=0 pixels adjacent to mask=1 pixels)
# This is much faster than using the entire reference region
from scipy.ndimage import binary_dilation
# Dilate mask=1 by 1 pixel, then find intersection with mask=0
mask_1_dilated = binary_dilation(mask == 1, structure=np.ones((3, 3)))
border_pixels = (mask == 0) & mask_1_dilated
if np.any(border_pixels):
# Find closest border pixel for each source pixel
source_coords = np.column_stack(np.where(blend_region)) # (N, 2) - y, x coordinates
border_coords = np.column_stack(np.where(border_pixels)) # (M, 2) - much smaller set!
# For each source pixel, find closest border pixel
from scipy.spatial.distance import cdist
distances_matrix = cdist(source_coords, border_coords, metric='euclidean')
closest_border_indices = np.argmin(distances_matrix, axis=1)
# Normalize source distances for blending weights
min_distance_in_region = np.min(source_distances[blend_region])
max_distance_in_region = np.max(source_distances[blend_region])
if max_distance_in_region > min_distance_in_region:
# Calculate blend weights: 0.4 at border (60% source + 40% reference), 0.0 at max distance (100% source)
source_dist_values = source_distances[blend_region]
normalized_distances = (source_dist_values - min_distance_in_region) / (max_distance_in_region - min_distance_in_region)
blend_weights = 0.4 * (1.0 - normalized_distances) # Start with 40% reference influence at border
# Apply blending with closest border pixels
for i, (source_y, source_x) in enumerate(source_coords):
closest_border_idx = closest_border_indices[i]
border_y, border_x = border_coords[closest_border_idx]
weight = blend_weights[i]
# Blend with closest border pixel
result[source_y, source_x] = (
(1.0 - weight) * corrected_frame[source_y, source_x] +
weight * reference_frame[border_y, border_x]
)
return result
def apply_progressive_blend_in_reference_area(reference_frame, source_frame, mask, reference_region, reference_distances, reference_border_distance):
"""
Apply progressive blending in the reference area (mask=0) near the border.
Args:
reference_frame: RGB image (H, W, C) - the reference frame with copied reference pixels
source_frame: RGB image (H, W, C) - the original source frame
mask: Binary mask (H, W)
reference_region: Boolean mask (H, W) indicating the reference blending region (mask=0 near border)
reference_distances: Distance map (H, W) into mask=0 area from mask=1 border
reference_border_distance: Maximum distance for reference blending
Returns:
Blended RGB image (H, W, C)
Notes:
- Each reference pixel blends with its closest source border pixel (for speed)
- At mask border: 60% reference + 40% source
- Deeper into mask=0 area: 100% reference
"""
result = reference_frame.copy()
# Blend in the reference region (mask=0 pixels near border)
blend_region = reference_region
if np.any(blend_region):
# Find immediate border pixels (mask=1 pixels adjacent to mask=0 pixels)
from scipy.ndimage import binary_dilation
# Dilate mask=0 by 1 pixel, then find intersection with mask=1
mask_0_dilated = binary_dilation(mask == 0, structure=np.ones((3, 3)))
source_border_pixels = (mask == 1) & mask_0_dilated
if np.any(source_border_pixels):
# Find closest source border pixel for each reference pixel
reference_coords = np.column_stack(np.where(blend_region)) # (N, 2) - y, x coordinates
source_border_coords = np.column_stack(np.where(source_border_pixels)) # (M, 2)
# For each reference pixel, find closest source border pixel
from scipy.spatial.distance import cdist
distances_matrix = cdist(reference_coords, source_border_coords, metric='euclidean')
closest_source_indices = np.argmin(distances_matrix, axis=1)
# Normalize reference distances for blending weights
min_distance_in_region = np.min(reference_distances[blend_region])
max_distance_in_region = np.max(reference_distances[blend_region])
if max_distance_in_region > min_distance_in_region:
# Calculate blend weights: 0.4 at border (60% reference + 40% source), 0.0 at max distance (100% reference)
reference_dist_values = reference_distances[blend_region]
normalized_distances = (reference_dist_values - min_distance_in_region) / (max_distance_in_region - min_distance_in_region)
blend_weights = 0.4 * (1.0 - normalized_distances) # Start with 40% source influence at border
# Apply blending with closest source border pixels
for i, (ref_y, ref_x) in enumerate(reference_coords):
closest_source_idx = closest_source_indices[i]
source_y, source_x = source_border_coords[closest_source_idx]
weight = blend_weights[i]
# Blend: weight=0.4 means 60% reference + 40% source at border
result[ref_y, ref_x] = (
(1.0 - weight) * reference_frame[ref_y, ref_x] +
weight * source_frame[source_y, source_x]
)
return result
def apply_copy_with_mask(source_frame, reference_frame, mask, copy_source):
"""
Copy pixels to mask=0 regions based on copy_source parameter.
Args:
source_frame: RGB image (H, W, C)
reference_frame: RGB image (H, W, C)
mask: Binary mask (H, W)
copy_source: "reference" or "source"
Returns:
Combined RGB image (H, W, C)
"""
result = source_frame.copy()
mask_0_region = (mask == 0)
if copy_source == "reference":
result[mask_0_region] = reference_frame[mask_0_region]
# If "source", we keep the original source pixels (no change needed)
return result
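# Editor's note: illustrative usage sketch, not part of the original commit.
# Correct only the masked (mask=1) region of a chunk toward a reference video of the same
# length, with a smooth two-sided blend across the mask border. All tensors are dummies.
def _demo_match_and_blend_colors_with_mask():
    chunk = torch.rand(1, 3, 8, 64, 64) * 2 - 1      # (B, C, T, H, W) source in [-1, 1]
    reference = torch.rand(1, 3, 8, 64, 64) * 2 - 1  # (B, C, T, H, W) reference in [-1, 1]
    mask = torch.zeros(8, 64, 64)                    # (T, H, W) binary mask
    mask[:, 16:48, 16:48] = 1                        # color-correct only this square region
    corrected = match_and_blend_colors_with_mask(
        chunk, reference, mask,
        strength=1.0,
        copy_mode="progressive_blend",               # 60/40 blending on both sides of the border
        source_border_distance=10,
        reference_border_distance=10,
    )
    return corrected                                 # (1, 3, 8, 64, 64)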

View File

@ -103,46 +103,67 @@ def generate_notification_beep(volume=50, sample_rate=44100):
wave = wave / max_amplitude * 0.85 # More conservative normalization
return wave
_mixer_lock = threading.Lock()
def play_audio_with_pygame(audio_data, sample_rate=44100):
"""Play audio using pygame backend"""
"""
Play audio with clean stereo output - sounds like single notification from both speakers
"""
try:
import pygame
# Initialize pygame mixer only if not already initialized
if not pygame.mixer.get_init():
pygame.mixer.pre_init(frequency=sample_rate, size=-16, channels=2, buffer=1024)
pygame.mixer.init()
else:
# Reinitialize with new settings if needed
current_freq, current_size, current_channels = pygame.mixer.get_init()
if current_freq != sample_rate or current_channels != 2:
with _mixer_lock:
if len(audio_data) == 0:
return False
# Clean mixer initialization - quit any existing mixer first
if pygame.mixer.get_init() is not None:
pygame.mixer.quit()
pygame.mixer.pre_init(frequency=sample_rate, size=-16, channels=2, buffer=1024)
pygame.mixer.init()
audio_int16 = (audio_data * 32767).astype(np.int16)
# Convert mono to stereo
if len(audio_int16.shape) == 1:
stereo_data = np.column_stack((audio_int16, audio_int16))
else:
stereo_data = audio_int16
sound = pygame.sndarray.make_sound(stereo_data)
sound.play()
pygame.time.wait(int(len(audio_data) / sample_rate * 1000) + 100)
# Don't quit mixer - this can interfere with Gradio server
# pygame.mixer.quit()
return True
time.sleep(0.2) # Longer pause to ensure clean shutdown
# Initialize fresh mixer
pygame.mixer.pre_init(
frequency=sample_rate,
size=-16,
channels=2,
buffer=512 # Smaller buffer to reduce latency/doubling
)
pygame.mixer.init()
# Verify clean initialization
mixer_info = pygame.mixer.get_init()
if mixer_info is None or mixer_info[2] != 2:
return False
# Prepare audio - ensure clean conversion
audio_int16 = (audio_data * 32767).astype(np.int16)
if len(audio_int16.shape) > 1:
audio_int16 = audio_int16.flatten()
# Create clean stereo with identical channels
stereo_data = np.zeros((len(audio_int16), 2), dtype=np.int16)
stereo_data[:, 0] = audio_int16 # Left channel
stereo_data[:, 1] = audio_int16 # Right channel
# Create sound and play once
sound = pygame.sndarray.make_sound(stereo_data)
# Ensure only one playback
pygame.mixer.stop() # Stop any previous sounds
sound.play()
# Wait for completion
duration_ms = int(len(audio_data) / sample_rate * 1000) + 50
pygame.time.wait(duration_ms)
return True
except ImportError:
return False
except Exception as e:
print(f"Pygame error: {e}")
print(f"Pygame clean error: {e}")
return False
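# Editor's note: illustrative usage sketch, not part of the original commit.
# Play the generated notification beep through the pygame backend; it returns False when
# pygame is unavailable or playback fails, so callers can fall back to another backend.
def _demo_play_notification_beep():
    beep = generate_notification_beep(volume=50, sample_rate=44100)  # mono float waveform (normalized)
    return play_audio_with_pygame(beep, sample_rate=44100)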
def play_audio_with_sounddevice(audio_data, sample_rate=44100):
"""Play audio using sounddevice backend"""
try:

256
wan/utils/stats.py Normal file
View File

@ -0,0 +1,256 @@
import gradio as gr
import signal
import sys
import time
import threading
import atexit
from contextlib import contextmanager
from collections import deque
import psutil
import pynvml
# Initialize NVIDIA Management Library (NVML) for GPU monitoring
try:
pynvml.nvmlInit()
nvml_initialized = True
except pynvml.NVMLError:
print("Warning: Could not initialize NVML. GPU stats will not be available.")
nvml_initialized = False
class SystemStatsApp:
def __init__(self):
self.running = False
self.active_generators = []
self.setup_signal_handlers()
def setup_signal_handlers(self):
# Handle different shutdown signals
signal.signal(signal.SIGINT, self.shutdown_handler)
signal.signal(signal.SIGTERM, self.shutdown_handler)
if hasattr(signal, 'SIGBREAK'): # Windows
signal.signal(signal.SIGBREAK, self.shutdown_handler)
# Also register atexit handler as backup
atexit.register(self.cleanup)
def shutdown_handler(self, signum, frame):
# print(f"\nReceived signal {signum}. Shutting down gracefully...")
self.cleanup()
sys.exit(0)
def cleanup(self):
if not self.running:
print("Cleaning up streaming connections...")
self.running = False
# Give a moment for generators to stop
time.sleep(1)
def get_system_stats(self, first = False, last_disk_io = psutil.disk_io_counters() ):
# Set a reasonable maximum speed for the bar graph display.
# 100 MB/s will represent a 100% full bar.
MAX_SSD_SPEED_MB_S = 100.0
# Get CPU and RAM stats
if first :
cpu_percent = psutil.cpu_percent(interval=.01)
else:
cpu_percent = psutil.cpu_percent(interval=1) # This provides our 1-second delay
memory_info = psutil.virtual_memory()
ram_percent = memory_info.percent
ram_used_gb = memory_info.used / (1024**3)
ram_total_gb = memory_info.total / (1024**3)
# Get new disk IO counters and calculate the read/write speed in MB/s
current_disk_io = psutil.disk_io_counters()
read_mb_s = (current_disk_io.read_bytes - last_disk_io.read_bytes) / (1024**2)
write_mb_s = (current_disk_io.write_bytes - last_disk_io.write_bytes) / (1024**2)
total_disk_speed = read_mb_s + write_mb_s
# Update the last counters for the next loop
last_disk_io = current_disk_io
# Calculate the bar height as a percentage of our defined max speed
ssd_bar_height = min(100.0, (total_disk_speed / MAX_SSD_SPEED_MB_S) * 100)
# Get GPU stats if the library was initialized successfully
if nvml_initialized:
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(0) # Assuming GPU 0
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
gpu_percent = util.gpu
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
vram_percent = (mem_info.used / mem_info.total) * 100
vram_used_gb = mem_info.used / (1024**3)
vram_total_gb = mem_info.total / (1024**3)
except pynvml.NVMLError:
# Handle cases where GPU might be asleep or driver issues
gpu_percent, vram_percent, vram_used_gb, vram_total_gb = 0, 0, 0, 0
else:
# Set default values if NVML failed to load
gpu_percent, vram_percent, vram_used_gb, vram_total_gb = 0, 0, 0, 0
stats_html = f"""
<style>
.stats-container {{
display: flex;
justify-content: space-between;
align-items: flex-start;
padding: 0px 5px;
height: 60px;
width: 100%;
box-sizing: border-box;
}}
.stats-block {{
width: calc(18% - 5px);
min-width: 100px;
text-align: center;
font-family: sans-serif;
}}
.stats-bar-background {{
width: 90%;
height: 30px;
background-color: #e9ecef;
border: 1px solid #dee2e6;
border-radius: 8px;
overflow: hidden;
position: relative;
margin: 0 auto;
}}
.stats-bar-fill {{
position: absolute;
bottom: 0;
left: 0;
height: 100%;
background-color: #0d6efd;
}}
.stats-title {{
margin-top: 5px;
font-size: 11px;
font-weight: bold;
}}
.stats-detail {{
font-size: 10px;
margin-top: -2px;
}}
</style>
<div class="stats-container">
<!-- CPU Stat Block -->
<div class="stats-block">
<div class="stats-bar-background">
<div class="stats-bar-fill" style="width: {cpu_percent}%;"></div>
</div>
<div class="stats-title">CPU: {cpu_percent:.1f}%</div>
</div>
<!-- RAM Stat Block -->
<div class="stats-block">
<div class="stats-bar-background">
<div class="stats-bar-fill" style="width: {ram_percent}%;"></div>
</div>
<div class="stats-title">RAM {ram_percent:.1f}%</div>
<div class="stats-detail">{ram_used_gb:.1f} / {ram_total_gb:.1f} GB</div>
</div>
<!-- SSD Activity Stat Block -->
<div class="stats-block">
<div class="stats-bar-background">
<div class="stats-bar-fill" style="width: {ssd_bar_height}%;"></div>
</div>
<div class="stats-title">SSD R/W</div>
<div class="stats-detail">{read_mb_s:.1f} / {write_mb_s:.1f} MB/s</div>
</div>
<!-- GPU Stat Block -->
<div class="stats-block">
<div class="stats-bar-background">
<div class="stats-bar-fill" style="width: {gpu_percent}%;"></div>
</div>
<div class="stats-title">GPU: {gpu_percent:.1f}%</div>
</div>
<!-- VRAM Stat Block -->
<div class="stats-block">
<div class="stats-bar-background">
<div class="stats-bar-fill" style="width: {vram_percent}%;"></div>
</div>
<div class="stats-title">VRAM {vram_percent:.1f}%</div>
<div class="stats-detail">{vram_used_gb:.1f} / {vram_total_gb:.1f} GB</div>
</div>
</div>
"""
return stats_html, last_disk_io
def streaming_html(self, state):
if "stats_running" in state:
return
state["stats_running"] = True
self.running = True
last_disk_io = psutil.disk_io_counters()
i = 0
import time
try:
while self.running:
i+= 1
# if i % 2 == 0:
# print(f"time:{time.time()}")
html_content, last_disk_io = self.get_system_stats(False, last_disk_io)
yield html_content
# time.sleep(1)
except GeneratorExit:
# print("Generator stopped gracefully")
return
except Exception as e:
print(f"Streaming error: {e}")
# finally:
# # Send final message indicating clean shutdown
final_html = """
<DIV>
<img src="x" onerror="
setInterval(()=>{
console.log('trying...');
setTimeout(() => {
try{
const btn = document.getElementById('restart_stats');
if(btn) {
console.log('found button, clicking');
btn.click();
} else {
console.log('button not found');
}
}catch(e){console.log('error: ' + e.message)}
}, 100);
}, 8000);" style="display:none;">
<button onclick="document.getElementById('restart_stats').click()"
style="background: #007bff; color: white; padding: 15px 30px;
border: none; border-radius: 5px; font-size: 16px; cursor: pointer;">
🔄 Connection to Server Lost. Attempting Auto Reconnect. Click Here for Manual Reconnection
</button>
</DIV>
"""
try:
yield final_html
except:
pass
def get_gradio_element(self):
self.system_stats_display = gr.HTML(self.get_system_stats(True)[0])
self.restart_btn = gr.Button("restart stats",elem_id="restart_stats", visible= False) # False)
return self.system_stats_display
def setup_events(self, main, state):
gr.on([main.load, self.restart_btn.click],
fn=self.streaming_html,
inputs = state,
outputs=self.system_stats_display,
show_progress=False
)
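# Editor's note: illustrative wiring sketch, not part of the original commit.
# Mount the stats bar inside a Gradio Blocks app and start streaming on page load
# (in WanGP this is presumably wired from wgp.py when `display_stats` is enabled in the Config tab).
def _demo_stats_app():
    stats_app = SystemStatsApp()
    with gr.Blocks() as demo:
        state = gr.State({})                  # per-session state dict used as a "stats_running" flag
        stats_app.get_gradio_element()        # creates the HTML display and the hidden restart button
        stats_app.setup_events(demo, state)   # binds demo.load / restart click to streaming_html
    return demo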

275
wgp.py
View File

@ -51,7 +51,7 @@ AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.5.1"
WanGP_version = "7.12"
WanGP_version = "7.2"
settings_version = 2.22
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@ -100,16 +100,19 @@ def download_ffmpeg():
os.rename(f, os.path.basename(f))
os.remove(zip_name)
def format_time(seconds):
if seconds < 60:
return f"{seconds:.1f}s"
elif seconds < 3600:
minutes = seconds / 60
return f"{minutes:.1f}m"
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
if hours > 0:
return f"{hours}h {minutes:02d}m {secs:02d}s"
elif seconds >= 60:
return f"{minutes}m {secs:02d}s"
else:
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
return f"{hours}h {minutes}m"
return f"{seconds:.1f}s"
def pil_to_base64_uri(pil_image, format="png", quality=75):
if pil_image is None:
return None
@ -274,6 +277,7 @@ def process_prompt_and_add_tasks(state, model_choice):
num_inference_steps= inputs["num_inference_steps"]
skip_steps_cache_type= inputs["skip_steps_cache_type"]
MMAudio_setting = inputs["MMAudio_setting"]
image_mode = inputs["image_mode"]
if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20:
@ -438,7 +442,7 @@ def process_prompt_and_add_tasks(state, model_choice):
image_end = None
if test_any_sliding_window(model_type):
if test_any_sliding_window(model_type) and image_mode == 0:
if video_length > sliding_window_size:
full_video_length = video_length if video_source is None else video_length + sliding_window_overlap
extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation"
@ -628,9 +632,12 @@ def move_up(queue, selected_indices):
idx = idx[0]
idx = int(idx)
with lock:
if idx > 0:
idx += 1
idx += 1
if idx > 1:
queue[idx], queue[idx-1] = queue[idx-1], queue[idx]
elif idx == 1:
queue[:] = queue[0:1] + queue[2:] + queue[1:2]
return update_queue_data(queue)
def move_down(queue, selected_indices):
@ -644,6 +651,9 @@ def move_down(queue, selected_indices):
idx += 1
if idx < len(queue)-1:
queue[idx], queue[idx+1] = queue[idx+1], queue[idx]
elif idx == len(queue)-1:
queue[:] = queue[0:1] + queue[-1:] + queue[1:-1]
return update_queue_data(queue)
def remove_task(queue, selected_indices):
@ -1056,8 +1066,10 @@ def show_countdown_info_from_state(current_value: int):
gr.Info(f"Quitting in {current_value}...")
return current_value - 1
return current_value
quitting_app = False
def autosave_queue():
global quitting_app
quitting_app = True
global global_queue_ref
if not global_queue_ref:
print("Autosave: Queue is empty, nothing to save.")
@ -1937,7 +1949,7 @@ def fix_settings(model_type, ui_defaults):
video_prompt_type = video_prompt_type.replace("I", "")
remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 0)
remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 1)
if video_settings_version < 2.22:
if "I" in video_prompt_type:
if remove_background_images_ref == 2:
@ -2039,7 +2051,7 @@ def get_default_settings(model_type):
ui_defaults.update({
"guidance_scale": 7.5,
"flow_shift": 5,
"remove_background_images_ref": 0,
"remove_background_images_ref": 1,
"video_prompt_type": "I",
# "resolution": "1280x720"
})
@ -2359,19 +2371,20 @@ def download_models(model_filename, model_type):
def download_file(url,filename):
if url.startswith("https://huggingface.co/") and "/resolve/main/" in url:
base_dir = os.path.dirname(filename)
url = url[len("https://huggingface.co/"):]
url_parts = url.split("/resolve/main/")
repoId = url_parts[0]
onefile = os.path.basename(url_parts[-1])
sourceFolder = os.path.dirname(url_parts[-1])
if len(sourceFolder) == 0:
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/")
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/" if len(base_dir)==0 else base_dir)
else:
target_path = "ckpts/temp/" + sourceFolder
if not os.path.exists(target_path):
os.makedirs(target_path)
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/temp/", subfolder=sourceFolder)
shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/")
shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/" if len(base_dir)==0 else base_dir)
shutil.rmtree("ckpts/temp")
else:
urlretrieve(url,filename, create_progress_hook(filename))
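# Editor's note: illustrative sketch of the URL handling above, not part of the original commit.
# A Hugging Face "resolve" URL is split into repo id, optional subfolder and file name so that
# hf_hub_download can fetch the file into ckpts/ or into the directory of the requested path.
def _demo_split_hf_url(url = "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors"):
    stripped = url[len("https://huggingface.co/"):]
    repo_id, remainder = stripped.split("/resolve/main/")
    return repo_id, os.path.dirname(remainder), os.path.basename(remainder)
    # -> ("DeepBeepMeep/Wan2.1", "loras_accelerators", "DetailEnhancerV1.safetensors")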
@ -2397,9 +2410,7 @@ def download_models(model_filename, model_type):
model_filename = None
preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True)
model_loras = get_model_recursive_prop(model_type, "loras", return_list= True)
for url in preload_URLs + model_loras:
for url in preload_URLs:
filename = "ckpts/" + url.split("/")[-1]
if not os.path.isfile(filename ):
if not url.startswith("http"):
@ -2409,6 +2420,19 @@ def download_models(model_filename, model_type):
except Exception as e:
if os.path.isfile(filename): os.remove(filename)
raise Exception(f"Preload URL '{url}' is invalid: {str(e)}'")
model_loras = get_model_recursive_prop(model_type, "loras", return_list= True)
for url in model_loras:
filename = os.path.join(get_lora_dir(model_type), url.split("/")[-1])
if not os.path.isfile(filename ):
if not url.startswith("http"):
raise Exception(f"Lora '{filename}' was not found in the Loras Folder and no URL was provided to download it. Please add an URL in the model definition file.")
try:
download_file(url, filename)
except Exception as e:
if os.path.isfile(filename): os.remove(filename)
raise Exception(f"Lora URL '{url}' is invalid: {str(e)}'")
if model_family == "wan":
text_encoder_filename = get_wan_text_encoder_filename(text_encoder_quantization)
model_files = {
@ -2777,7 +2801,7 @@ def generate_header(model_type, compile, attention_mode):
get_model_name(model_type, description_container)
model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) or ""
description = description_container[0]
header = "<DIV style='height:40px'>" + description + "</DIV>"
header = f"<DIV style=height:{60 if server_config.get('display_stats', 0) == 1 else 40}px>{description}</DIV>"
header += "<DIV style='align:right;width:100%'><FONT SIZE=3>Attention mode <B>" + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
if attention_mode not in attention_modes_installed:
@ -2824,12 +2848,13 @@ def apply_changes( state,
notification_sound_enabled_choice = 1,
notification_sound_volume_choice = 50,
max_frames_multiplier_choice = 1,
display_stats_choice = 0,
last_resolution_choice = None,
):
if args.lock_config:
return
if gen_in_progress:
return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>", gr.update(), gr.update()
return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>",*[gr.update()]*5
global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
server_config = {
"attention_mode" : attention_choice,
@ -2856,6 +2881,7 @@ def apply_changes( state,
"notification_sound_enabled" : notification_sound_enabled_choice,
"notification_sound_volume" : notification_sound_volume_choice,
"max_frames_multiplier" : max_frames_multiplier_choice,
"display_stats" : display_stats_choice,
"last_model_type" : state["model_type"],
"last_advanced_choice": state["advanced"],
"last_resolution_choice": last_resolution_choice,
@ -2894,7 +2920,7 @@ def apply_changes( state,
transformer_types = server_config["transformer_types"]
model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy)
state["model_filename"] = model_filename
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier"] for change in changes ):
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier", "display_stats"] for change in changes ):
model_choice = gr.Dropdown()
else:
reload_needed = True
@ -2926,6 +2952,7 @@ def get_gen_info(state):
def build_callback(state, pipe, send_cmd, status, num_inference_steps):
gen = get_gen_info(state)
gen["num_inference_steps"] = num_inference_steps
start_time = time.time()
def callback(step_idx, latent, force_refresh, read_state = False, override_num_inference_steps = -1, pass_no = -1):
refresh_id = gen.get("refresh", -1)
if force_refresh or step_idx >= 0:
@ -2966,6 +2993,9 @@ def build_callback(state, pipe, send_cmd, status, num_inference_steps):
gen["progress_phase"] = (phase, step_idx)
status_msg = merge_status_context(status, phase)
elapsed_time = time.time() - start_time
status_msg = merge_status_context(status, f"{phase} | {format_time(elapsed_time)}")
if step_idx >= 0:
progress_args = [(step_idx , num_inference_steps) , status_msg , num_inference_steps]
else:
@ -3020,7 +3050,7 @@ def refresh_gallery(state): #, msg
base_model_type = get_base_model_type(model_type)
model_def = get_model_def(model_type)
is_image = model_def.get("image_outputs", False)
onemorewindow_visible = test_any_sliding_window(base_model_type) and not is_image
onemorewindow_visible = test_any_sliding_window(base_model_type) and params.get("image_mode",0) == 0
enhanced = False
if prompt.startswith("!enhanced!\n"):
enhanced = True
@ -3256,7 +3286,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
if len(video_other_prompts) >0 :
values += [video_other_prompts]
labels += ["Other Prompts"]
if len(video_outpainting) >0 :
if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"):
values += [video_outpainting]
labels += ["Outpainting"]
values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_flow_shift, video_num_inference_steps]
@ -3911,7 +3941,8 @@ def edit_video(
def get_transformer_loras(model_type):
model_def = get_model_def(model_type)
transformer_loras_filenames = get_model_recursive_prop(model_type, "loras", return_list=True)
transformer_loras_filenames = [ os.path.join("ckpts", os.path.basename(filename)) for filename in transformer_loras_filenames]
lora_dir = get_lora_dir(model_type)
transformer_loras_filenames = [ os.path.join(lora_dir, os.path.basename(filename)) for filename in transformer_loras_filenames]
transformer_loras_multipliers = get_model_recursive_prop(model_type, "loras_multipliers", return_list=True) + [1.] * len(transformer_loras_filenames)
transformer_loras_multipliers = transformer_loras_multipliers[:len(transformer_loras_filenames)]
return transformer_loras_filenames, transformer_loras_multipliers
@ -3966,6 +3997,7 @@ def generate_video(
speakers_locations,
sliding_window_size,
sliding_window_overlap,
sliding_window_color_correction_strength,
sliding_window_overlap_noise,
sliding_window_discard_last_frames,
remove_background_images_ref,
@ -3988,6 +4020,7 @@ def generate_video(
cfg_star_switch,
cfg_zero_step,
prompt_enhancer,
min_frames_if_references,
state,
model_type,
model_filename,
@ -4017,7 +4050,8 @@ def generate_video(
model_def = get_model_def(model_type)
is_image = image_mode == 1
if is_image:
video_length = 1
# min_frames_if_references = server_config.get("min_frames_if_references", 5)
video_length = min_frames_if_references if "I" in video_prompt_type else 1
else:
batch_size = 1
temp_filenames_list = []
@ -4205,7 +4239,7 @@ def generate_video(
for i, pos in enumerate(frames_positions_list):
frames_to_inject[pos] = image_refs[i]
if video_guide == None and video_source == None and not "L" in image_prompt_type and (nb_frames_positions > 0 or "K" in video_prompt_type) :
from wan.utils.utils import resize_lanczos, get_outpainting_full_area_dimensions
from wan.utils.utils import get_outpainting_full_area_dimensions
w, h = image_refs[0].size
if outpainting_dims != None:
h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims)
@ -4602,6 +4636,7 @@ def generate_video(
input_masks = src_mask,
input_video= pre_video_guide if diffusion_forcing or ltxv or hunyuan_custom_edit else source_video,
denoising_strength=denoising_strength,
prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames,
target_camera= target_camera,
frame_num= (current_video_length // latent_size)* latent_size + 1,
batch_size = batch_size,
@@ -4639,6 +4674,7 @@ def generate_video(
overlapped_latents = overlapped_latents,
return_latent_slice= return_latent_slice,
overlap_noise = sliding_window_overlap_noise,
color_correction_strength = sliding_window_color_correction_strength,
conditioning_latents_size = conditioning_latents_size,
keep_frames_parsed = keep_frames_parsed,
model_filename = model_filename,
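color_correction_strength drives the new color matching between sliding windows; the actual implementation (credited to the Multitalk team) is not part of this diff, so here is a minimal sketch of one possible approach, assuming per-channel mean/std matching over the overlapping frames and float frames in [0, 1]:

import numpy as np

# Hypothetical per-channel color matching between consecutive sliding windows.
# new_frames: (T, H, W, 3) float frames of the new window
# prev_overlap / new_overlap: the overlapping frames of the previous / new window
# strength: 0 disables the correction, 1 applies it fully
def match_window_colors(new_frames, prev_overlap, new_overlap, strength=1.0):
    if strength <= 0:
        return new_frames
    corrected = np.empty_like(new_frames)
    for c in range(3):
        mean_p, std_p = prev_overlap[..., c].mean(), prev_overlap[..., c].std() + 1e-6
        mean_n, std_n = new_overlap[..., c].mean(), new_overlap[..., c].std() + 1e-6
        # shift/scale the new window's channel statistics toward the previous window's
        corrected[..., c] = (new_frames[..., c] - mean_n) / std_n * std_p + mean_p
    blended = strength * corrected + (1.0 - strength) * new_frames
    return np.clip(blended, 0.0, 1.0)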
@@ -4648,6 +4684,7 @@ def generate_video(
NAG_tau = NAG_tau,
NAG_alpha = NAG_alpha,
speakers_bboxes =speakers_bboxes,
image_mode = image_mode,
offloadobj = offloadobj,
)
except Exception as e:
@@ -5227,9 +5264,11 @@ def process_tasks(state):
gen["prompt"] = ""
end_time = time.time()
if abort:
status = f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
# status = f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
status = f"Video generation was aborted. Total Generation Time: {format_time(end_time-start_time)}"
else:
status = f"Total Generation Time: {end_time-start_time:.1f}s"
# status = f"Total Generation Time: {end_time-start_time:.1f}s"
status = f"Total Generation Time: {format_time(end_time-start_time)}"
# Play notification sound when video generation completed successfully
try:
if server_config.get("notification_sound_enabled", 1):
@@ -5273,8 +5312,13 @@ def merge_status_context(status="", context=""):
elif len(context) == 0:
return status
else:
return status + " - " + context
# Check if context already contains the time
if "|" in context:
parts = context.split("|")
return f"{status} - {parts[0].strip()} | {parts[1].strip()}"
else:
return f"{status} - {context}"
def clear_status(state):
gen = get_gen_info(state)
gen["extra_windows"] = 0
@@ -5314,7 +5358,7 @@ def one_more_sample(state):
total_generation = gen.get("total_generation", 0) + extra_orders
gen["progress_status"] = get_latest_status(state)
gen["refresh"] = get_new_refresh_id()
gr.Info(f"An extra sample generation is planned for a total of {total_generation} videos for this prompt")
gr.Info(f"An extra sample generation is planned for a total of {total_generation} samples for this prompt")
return state
@@ -5769,13 +5813,13 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
if base_model_type in ["t2v"]: unsaved_params = unsaved_params[2:]
pop += unsaved_params
if not vace:
pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"]
pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2", "min_frames_if_references"]
if not (diffusion_forcing or ltxv or vace):
pop += ["keep_frames_video_source"]
if not test_any_sliding_window( base_model_type):
pop += ["sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"]
pop += ["sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames", "sliding_window_color_correction_strength"]
if not base_model_type in ["fantasy", "multitalk", "vace_multitalk_14B"]:
pop += ["audio_guidance_scale", "speakers_locations"]
@@ -6058,8 +6102,21 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw
return configs, tags != None
def record_image_mode_tab(state, evt:gr.SelectData):
state["image_mode_tab"] = 0 if evt.index ==0 else 1
def switch_image_mode(state):
image_mode = state.get("image_mode_tab", 0)
model_type =state["model_type"]
ui_defaults = get_model_settings(state, model_type)
ui_defaults["image_mode"] = image_mode
return str(time.time())
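switch_image_mode returns str(time.time()) so that the hidden refresh_form_trigger component it is wired to further down sees a new value and fires its .change event, reloading the form with the saved settings. The same Gradio pattern in isolation, as a hypothetical stand-alone example:

import time
import gradio as gr

# Hypothetical illustration: writing a new timestamp into a hidden Text component
# fires its .change event, which acts as a "refresh the form" signal.
with gr.Blocks() as demo:
    trigger = gr.Text(visible=False)
    info = gr.Markdown("not refreshed yet")
    switch_btn = gr.Button("Switch mode")
    switch_btn.click(fn=lambda: str(time.time()), outputs=trigger)
    trigger.change(fn=lambda t: f"form refreshed at {t}", inputs=trigger, outputs=info)
# demo.launch()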
def load_settings_from_file(state, file_path):
gen = get_gen_info(state)
if file_path==None:
return gr.update(), gr.update(), None
@@ -6135,6 +6192,7 @@ def save_inputs(
speakers_locations,
sliding_window_size,
sliding_window_overlap,
sliding_window_color_correction_strength,
sliding_window_overlap_noise,
sliding_window_discard_last_frames,
remove_background_images_ref,
@@ -6157,6 +6215,7 @@ def save_inputs(
cfg_star_switch,
cfg_zero_step,
prompt_enhancer,
min_frames_if_references,
mode,
state,
):
@@ -6181,7 +6240,7 @@ def save_inputs(
def download_loras():
from huggingface_hub import snapshot_download
yield gr.Row(visible=True), "<B><FONT SIZE=3>Please wait while the Loras are being downloaded</B></FONT>", *[gr.Column(visible=False)] * 2
yield gr.Row(visible=True), "<B><FONT SIZE=3>Please wait while the Loras are being downloaded</B></FONT>" #, *[gr.Column(visible=False)] * 2
lora_dir = get_lora_dir("i2v")
log_path = os.path.join(lora_dir, "log.txt")
if not os.path.isfile(log_path):
@@ -6198,7 +6257,7 @@ def download_loras():
os.remove(tmp_path)
except:
pass
yield gr.Row(visible=True), "<B><FONT SIZE=3>Loras have been completely downloaded</B></FONT>", *[gr.Column(visible=True)] * 2
yield gr.Row(visible=True), "<B><FONT SIZE=3>Loras have been completely downloaded</B></FONT>" #, *[gr.Column(visible=True)] * 2
from datetime import datetime
dt = datetime.today().strftime('%Y-%m-%d')
@@ -6358,16 +6417,16 @@ def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_
vace= test_vace_module(state["model_type"])
return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )
def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask):
def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask, image_mode):
video_prompt_type = del_in_sequence(video_prompt_type, "XYZWNA")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask)
visible= "A" in video_prompt_type
model_type = state["model_type"]
model_def = get_model_def(model_type)
image_outputs = model_def.get("image_outputs", False)
image_outputs = image_mode == 1
return video_prompt_type, gr.update(visible= visible and not image_outputs), gr.update(visible= visible and image_outputs), gr.update(visible= visible )
def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide):
def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode):
video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMGUV")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
visible = "V" in video_prompt_type
@@ -6375,10 +6434,9 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt
base_model_type = get_base_model_type(model_type)
mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
model_def = get_model_def(model_type)
image_outputs = model_def.get("image_outputs", False)
image_outputs = image_mode == 1
vace= test_vace_module(model_type)
return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = visible and not image_outputs), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)
return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = visible and not image_outputs), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)
# def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
# video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
@@ -6694,17 +6752,22 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
any_start_image = False
any_end_image = False
any_reference_image = False
v2i_switch_supported = (vace or t2v) and not image_outputs
image_mode_value = ui_defaults.get("image_mode", 1 if image_outputs else 0 )
if not v2i_switch_supported and not image_outputs:
image_mode_value = 0
else:
image_outputs = image_mode_value == 1
image_mode = gr.Number(value =image_mode_value, visible = False)
# with gr.Tabs(visible = vace or t2v):
# with gr.Tab("Text 2 Video"):
# pass
# with gr.Tab("Text 2 Image"):
# pass
with gr.Tabs(visible = v2i_switch_supported, selected= "t2i" if image_mode_value == 1 else "t2v" ) as image_mode_tabs:
with gr.Tab("Text to Video", id = "t2v", elem_classes="compact_tab"):
pass
with gr.Tab("Text to Image", id = "t2i", elem_classes="compact_tab"):
pass
# image_mode = gr.Number(value =ui_defaults.get("image_mode",0), visible = False)
image_mode = gr.Number(value =1 if image_outputs else 0, visible = False)
with gr.Column(visible= test_class_i2v(model_type) or diffusion_forcing or ltxv or recammaster or vace) as image_prompt_column:
with gr.Column(visible= test_class_i2v(model_type) or hunyuan_i2v or diffusion_forcing or ltxv or recammaster or vace) as image_prompt_column:
if vace:
image_prompt_type_value= ui_defaults.get("image_prompt_type","")
image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value
@@ -6770,7 +6833,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
)
keep_frames_video_source = gr.Text(visible=False)
else:
if test_class_i2v(model_type):
if test_class_i2v(model_type) or hunyuan_i2v:
# image_prompt_type_value= ui_defaults.get("image_prompt_type","SE" if flf2v else "S" )
image_prompt_type_value= ui_defaults.get("image_prompt_type","S" )
image_prompt_type_choices = [("Use only a Start Image", "S")]
@@ -6959,7 +7022,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
choices=[
("Keep Backgrounds behind People / Objects", 0),
("Remove Backgrounds behind People / Objects", 1),
# ("Keep it for first Image (landscape) and remove it for other Images (objects / people)", 2),
],
value=ui_defaults.get("remove_background_images_ref",1),
label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar and not flux
@@ -7271,6 +7333,23 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
)
with gr.Row():
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)", visible = not (hunyuan_i2v or hunyuan_t2v or hunyuan_video_avatar or hunyuan_i2v or hunyuan_video_custom ))
with gr.Column(visible = vace and image_outputs) as min_frames_if_references_col:
gr.Markdown("<B>If using Reference Images, generating a single Frame alone may not be sufficient to preserve Identity")
min_frames_if_references = gr.Dropdown(
choices=[
("Disabled, generate only one Frame", 1),
("Generate a 5 Frames long Video but keep only the First Frame (x1.5 slower)",5),
("Generate a 9 Frames long Video but keep only the First Frame (x2.0 slower)",9),
("Generate a 13 Frames long Video but keep only the First Frame (x2.5 slower)",13),
("Generate a 17 Frames long Video but keep only the First Frame (x3.0 slower)",17),
],
value=ui_defaults.get("min_frames_if_references",5),
visible=True,
scale = 1,
label="Generate more frames to preserve Reference Image Identity"
)
with gr.Tab("Sliding Window", visible= sliding_window_enabled) as sliding_window_tab:
with gr.Column():
@@ -7279,21 +7358,25 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
if diffusion_forcing:
sliding_window_size = gr.Slider(37, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=20, label="Sliding Window Size (recommended to keep it at 97)")
sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_color_correction_strength = gr.Slider(0, 1, visible=False, value =0)
sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = True)
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
elif ltxv:
sliding_window_size = gr.Slider(41, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=8, label="Sliding Window Size")
sliding_window_overlap = gr.Slider(9, 97, value=ui_defaults.get("sliding_window_overlap",9), step=8, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_color_correction_strength = gr.Slider(0, 1, visible=False, value =0)
sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False)
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=8, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
elif hunyuan_video_custom_edit:
sliding_window_size = gr.Slider(5, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size")
sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_color_correction_strength = gr.Slider(0, 1, visible=False, value =0)
sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False)
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
else: # Vace, Multitalk
sliding_window_size = gr.Slider(5, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size")
sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",1), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)")
sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20 if vace else 0), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = vace)
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
@@ -7382,18 +7465,18 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
default_visibility = {} if update_form else {"visible" : False}
video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info())
with gr.Row(**default_visibility) as video_buttons_row:
video_info_extract_settings_btn = gr.Button("Extract Settings", size ="sm")
video_info_to_control_video_btn = gr.Button("To Control Video", size ="sm", visible = any_control_video )
video_info_to_video_source_btn = gr.Button("To Video Source", size ="sm", visible = any_video_source)
video_info_eject_video_btn = gr.Button("Eject Video", size ="sm")
video_info_extract_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm")
video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video )
video_info_to_video_source_btn = gr.Button("To Video Source", min_width= 1, size ="sm", visible = any_video_source)
video_info_eject_video_btn = gr.Button("Eject Video", min_width= 1, size ="sm")
with gr.Row(**default_visibility) as image_buttons_row:
video_info_extract_image_settings_btn = gr.Button("Extract Settings", size ="sm")
video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", visible = any_start_image )
video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", visible = any_end_image)
video_info_to_image_guide_btn = gr.Button("To Control Image", size ="sm", visible = any_control_image )
video_info_to_image_mask_btn = gr.Button("To Mask Image", size ="sm", visible = any_image_mask)
video_info_to_reference_image_btn = gr.Button("To Reference Image", size ="sm", visible = any_reference_image)
video_info_eject_image_btn = gr.Button("Eject Image", size ="sm")
video_info_extract_image_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm")
video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", min_width= 1, visible = any_start_image )
video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", min_width= 1, visible = any_end_image)
video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image )
video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask)
video_info_to_reference_image_btn = gr.Button("To Reference Image", min_width= 1, size ="sm", visible = any_reference_image)
video_info_eject_image_btn = gr.Button("Eject Image", min_width= 1, size ="sm")
with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab:
with gr.Group(elem_classes= "postprocess"):
with gr.Column():
@@ -7462,7 +7545,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right,
video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row,
video_buttons_row, image_buttons_row, video_postprocessing_tab, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn,
NAG_col, speakers_locations_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row] # presets_column,
NAG_col, speakers_locations_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, min_frames_if_references_col] # presets_column,
if update_form:
locals_dict = locals()
gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
@@ -7481,8 +7564,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] )
# video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions, video_guide_outpainting_col])
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt])
video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
video_guide_outpainting_bottom.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_bottom,gr.State(1)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
@@ -7588,6 +7671,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
js=trigger_settings_download_js
)
image_mode_tabs.select(fn=record_image_mode_tab, inputs=[state], outputs= None
).then(fn=validate_wizard_prompt,
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
outputs= [prompt]
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
).then(fn=switch_image_mode, inputs =[state] , outputs= [refresh_form_trigger], trigger_mode="multiple")
settings_file.upload(fn=validate_wizard_prompt,
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
@@ -7779,7 +7870,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
outputs=[modal_container]
)
return ( state, loras_choices, lset_name, state, resolution,
return ( state, loras_choices, lset_name, resolution,
video_guide, image_guide, video_mask, image_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col
)
@@ -7874,6 +7965,15 @@ def generate_configuration_tab(state, blocks, header, model_choice, resolution,
label="Keep Previously Generated Videos when starting a new Generation Batch"
)
display_stats_choice = gr.Dropdown(
choices=[
("Disabled", 0),
("Enabled", 1),
],
value=server_config.get("display_stats", 0),
label="Display in real time available RAM / VRAM and other stats (needs a restart)"
)
max_frames_multiplier_choice = gr.Dropdown(
choices=[
("Default", 1),
@@ -8075,6 +8175,7 @@ def generate_configuration_tab(state, blocks, header, model_choice, resolution,
notification_sound_enabled_choice,
notification_sound_volume_choice,
max_frames_multiplier_choice,
display_stats_choice,
resolution,
],
outputs= [msg , header, model_choice, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col]
@@ -8085,20 +8186,24 @@ def generate_about_tab():
gr.Markdown("Original Wan 2.1 Model by <B>Alibaba</B> (<A HREF='https://github.com/Wan-Video/Wan2.1'>GitHub</A>)")
gr.Markdown("Many thanks to:")
gr.Markdown("- <B>Alibaba Wan team for the best open source video generator")
gr.Markdown("- <B>Alibaba Vace and Fun Teams for their incredible control net models")
gr.Markdown("- <B>Alibaba Vace, Multitalk and Fun Teams for their incredible control net models")
gr.Markdown("- <B>Tencent for the impressive Hunyuan Video models")
gr.Markdown("- <B>Lightricks for the super fast LTX Video models")
gr.Markdown("- <B>Blackforest Labs for the innovative Flux image generators")
gr.Markdown("- <B>Lightricks for their super fast LTX Video models")
gr.Markdown("<BR>Huge acknowlegments to these great open source projects used in WanGP:")
gr.Markdown("- <B>Rife</B>: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
gr.Markdown("- <B>DwPose</B>: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
gr.Markdown("- <B>DepthAnything</B> & <B>Midas</B>: Depth extractors (https://github.com/DepthAnything/Depth-Anything-V2) and (https://github.com/isl-org/MiDaS")
gr.Markdown("- <B>Matanyone</B> and <B>SAM2</B>: Mask Generation (https://github.com/pq-yang/MatAnyone) and (https://github.com/facebookresearch/sam2)")
gr.Markdown("- <B>Pyannote</B>: speaker diarization (https://github.com/pyannote/pyannote-audio)")
gr.Markdown("<BR>Special thanks to the following people for their support:")
gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
gr.Markdown("- <B>Tophness</B> : created (former) multi tabs and queuing frameworks")
gr.Markdown("- <B>AmericanPresidentJimmyCarter</B> : added original support for Skip Layer Guidance")
gr.Markdown("- <B>Remade_AI</B> : for their awesome Loras collection")
gr.Markdown("- <B>Reevoy24</B> : for his repackaging / completing the documentation")
gr.Markdown("<BR>Huge acknowlegments to these great open source projects used in WanGP:")
gr.Markdown("- <B>Rife</B>: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
gr.Markdown("- <B>DwPose</B>: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
gr.Markdown("- <B>Midas</B>: Depth extractor (https://github.com/isl-org/MiDaS")
gr.Markdown("- <B>Matanyone</B> and <B>SAM2</B>: Mask Generation (https://github.com/pq-yang/MatAnyone) and (https://github.com/facebookresearch/sam2)")
gr.Markdown("- <B>Redtash1</B> : for designing the protype of the RAM /VRAM stats viewer")
def generate_info_tab():
@@ -8541,7 +8646,9 @@ def create_ui():
z-index: 9999;
transition: visibility 0s linear 1s, opacity 0.3s linear 1s; /* Delay both properties */
}
div.compact_tab , span.compact_tab
{ padding: 0px !important;
}
.hover-image .tooltip2 {
visibility: hidden;
opacity: 0;
@@ -8591,6 +8698,12 @@ def create_ui():
console.log('sendColIndex function attached to window');
}
"""
if server_config.get("display_stats", 0) == 1:
from wan.utils.stats import SystemStatsApp
stats_app = SystemStatsApp()
else:
stats_app = None
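SystemStatsApp is implemented in wan/utils/stats.py, which is not part of this excerpt; a rough, hypothetical sketch of the kind of snapshot such a real-time stats panel might poll, assuming psutil and optionally pynvml are available:

import psutil

# Hypothetical sketch, not the actual wan.utils.stats module.
def sample_system_stats():
    # One snapshot of the numbers a CPU / GPU / RAM / VRAM panel would display.
    vm = psutil.virtual_memory()
    stats = {
        "cpu_percent": psutil.cpu_percent(interval=None),
        "ram_used_gb": (vm.total - vm.available) / 1024**3,
        "ram_total_gb": vm.total / 1024**3,
    }
    try:
        import pynvml  # GPU stats are optional: skip them if NVML is unavailable
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        stats["vram_used_gb"] = mem.used / 1024**3
        stats["vram_total_gb"] = mem.total / 1024**3
        stats["gpu_percent"] = pynvml.nvmlDeviceGetUtilizationRates(handle).gpu
        pynvml.nvmlShutdown()
    except Exception:
        pass
    return stats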
with gr.Blocks(css=css, js=js, theme=theme, title= "WanGP") as main:
gr.Markdown(f"<div align=center><H1>Wan<SUP>GP</SUP> v{WanGP_version} <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
global model_list
@@ -8609,8 +8722,11 @@ def create_ui():
gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
with gr.Row():
header = gr.Markdown(generate_header(transformer_type, compile, attention_mode), visible= True)
if stats_app is not None:
stats_element = stats_app.get_gradio_element()
with gr.Row():
( state, loras_choices, lset_name, state, resolution,
( state, loras_choices, lset_name, resolution,
video_guide, image_guide, video_mask, image_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col
) = generate_video_tab(model_choice=model_choice, header=header, main = main)
with gr.Tab("Guides", id="info") as info_tab:
@@ -8624,7 +8740,8 @@ def create_ui():
generate_configuration_tab(state, main, header, model_choice, resolution, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col)
with gr.Tab("About"):
generate_about_tab()
if stats_app is not None:
stats_app.setup_events(main, state)
main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= main_tabs, trigger_mode="multiple")
return main