Mirror of https://github.com/Wan-Video/Wan2.1.git, synced 2025-12-15 11:43:21 +00:00

Commit 90275dfc78 (parent d2843303a2): Ode to Vace

README.md (14 lines changed)
@@ -20,10 +20,22 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models

**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

## 🔥 Latest Updates

### July 26 2025: WanGP v7.2 : Ode to Vace

I am really convinced that Vace can do everything the other models can do, and do it better, especially since Vace can be combined with Multitalk.

Here are some new Vace improvements:

- I have provided a default finetune named *Vace Cocktail*, a model created on the fly from the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition into the finetune folder to change the Cocktail composition.

- Speaking of identity preservation, it tends to go away when you generate a single frame instead of a video, which is a shame for our Vace photoshop. But there is a solution: I have added an Advanced Quality option that tells WanGP to generate a little more than one frame (it still keeps only the first frame). It is a little slower, but you will be amazed how well Vace Cocktail combined with this option preserves identities (bye bye *Phantom*).

- Since in practice one switches frequently between *Vace text2video* and *Vace text2image*, I have put them in the same place: they are now just one tab away, with no need to reload the model. Likewise, *Wan text2video* and *Wan text2image* have been merged.

- Color fixing when using Sliding Windows: a new *Color Correction* postprocessing step, applied automatically by default (you can disable it in the *Sliding Window* section of the Advanced tab), tries to match the colors of the new window with those of the previous window. It doesn't fix all the unwanted artifacts of the new window, but it at least makes the transition smoother (a minimal usage sketch follows below). Thanks to the Multitalk team for the original code.
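A minimal sketch of what this color correction does per sliding window, based on the match_and_blend_colors helper added in this commit (the module path, tensor sizes and strength value below are assumptions for illustration, not WanGP's exact call site):

    import torch
    from wan.multitalk.multitalk_utils import match_and_blend_colors  # path assumed from this commit's imports

    # dummy stand-ins: a 17-frame window and the last frame of the previous window, both in [-1, 1]
    new_window = torch.rand(3, 17, 480, 832) * 2 - 1   # (C, T, H, W)
    reference  = torch.rand(3, 1, 480, 832) * 2 - 1    # (C, 1, H, W)

    corrected = match_and_blend_colors(
        new_window.unsqueeze(0),   # -> (1, C, T, H, W)
        reference.unsqueeze(0),    # -> (1, C, 1, H, W)
        strength=1.0,              # 0.0 disables; 1.0 = full correction
    ).squeeze(0)

Under the hood it matches per-channel mean and standard deviation in Lab space against the reference frame, then blends the corrected frames with the originals according to the strength.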
You will also enjoy our new real-time statistics (CPU / GPU usage, RAM / VRAM used, ...). Many thanks to **Redtash1** for providing the framework for this new feature! Go to the Config tab to enable the real-time stats.
### July 21 2025: WanGP v7.12

- Flux Family Reunion: *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate, Lora support for the Flux *diffusers* format has also been added.

- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video!) in one go without a sliding window. With the distilled model it will take only 5 minutes on an RTX 4090 (you will need 22 GB of VRAM though). I have added options to select a higher number of frames if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App).

- LTX Video ControlNet: a ControlNet that lets you, for instance, transfer human motion or depth from a control video. It is not as powerful as Vace, but it can produce interesting things, especially now that you can quickly generate a 1-minute video. Under the hood, IC-Loras (see below) for Pose, Depth and Canny are loaded automatically for you, no need to add them.
@@ -1,16 +0,0 @@
{
    "model": {
        "name": "Flux 1 Dev 12B",
        "architecture": "flux",
        "description": "FLUX.1 Dev is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.",
        "URLs": [
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev_bf16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev_quanto_bf16_int8.safetensors"
        ],
        "image_outputs": true,
        "flux-model": "flux-dev"
    },
    "prompt": "draw a hat",
    "resolution": "1280x720",
    "batch_size": 1
}
@@ -1,13 +0,0 @@
{
    "model": {
        "name": "Wan2.1 text2image 14B",
        "architecture": "t2v",
        "description": "The original Wan Text 2 Video model configured to generate an image instead of a video.",
        "image_outputs": true,
        "URLs": "t2v"
    },
    "batch_size": 1,
    "resolution": "1280x720"
}
defaults/vace_14B_cocktail.json (new file, 21 lines)
@@ -0,0 +1,21 @@
{
    "model": {
        "name": "Vace Cocktail 14B",
        "architecture": "vace_14B",
        "modules": [
            "vace_14B"
        ],
        "description": "This model has been created on the fly using the Wan text 2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. Copy the model def in the finetune folder to change the Cocktail composition.",
        "URLs": "t2v",
        "loras": [
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors",
            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors"
        ],
        "loras_multipliers": [1, 0.5, 0.5, 0.5]
    },
    "num_inference_steps": 10,
    "guidance_scale": 1,
    "flow_shift": 2
}
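The "loras_multipliers" entries appear to line up positionally with the "loras" list above: CausVid at 1.0, then Detail Enhancer, AccVid and MoviiGen at 0.5 each. To change the Cocktail composition as the README suggests, one would copy this definition into the finetune folder and adjust those values, for example (a hypothetical variant, not part of this commit) lowering the Detail Enhancer weight further:

    "loras_multipliers": [1, 0.3, 0.5, 0.5]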
@@ -1,16 +0,0 @@
{
    "model": {
        "name": "Vace FusioniX image2image 14B",
        "architecture": "vace_14B",
        "modules": [
            "vace_14B"
        ],
        "image_outputs": true,
        "description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
        "URLs": "t2v_fusionix"
    },
    "resolution": "1280x720",
    "guidance_scale": 1,
    "num_inference_steps": 10,
    "batch_size": 1
}
@@ -94,17 +94,20 @@ class Flux(nn.Module):
        if first_key.startswith("lora_unet_"):
            new_sd = {}
            print("Converting Lora Safetensors format to Lora Diffusers format")
            repl_list = ["linear1", "linear2", "modulation_lin"]
            repl_list = ["linear1", "linear2", "modulation", "img_attn", "txt_attn", "img_mlp", "txt_mlp", "img_mod", "txt_mod"]
            src_list = ["_" + k + "." for k in repl_list]
            tgt_list = ["." + k.replace("_", ".") + "." for k in repl_list]
            src_list2 = ["_" + k + "_" for k in repl_list]
            tgt_list = ["." + k + "." for k in repl_list]

            for k,v in sd.items():
                k = k.replace("lora_unet_blocks_","diffusion_model.blocks.")
                k = k.replace("lora_unet__blocks_","diffusion_model.blocks.")
                k = k.replace("lora_unet_single_blocks_","diffusion_model.single_blocks.")
                k = k.replace("lora_unet_double_blocks_","diffusion_model.double_blocks.")

                for s,t in zip(src_list, tgt_list):
                for s,s2, t in zip(src_list, src_list2, tgt_list):
                    k = k.replace(s,t)
                    k = k.replace(s2,t)

                k = k.replace("lora_up","lora_B")
                k = k.replace("lora_down","lora_A")
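To make the renaming above concrete, here is a small standalone sketch of the same replacement chain applied to one example key (the key name is illustrative of the kohya-style "lora_unet_*" layout this branch targets, not taken from a specific file):

    # illustrative reimplementation of the conversion shown above
    repl_list = ["linear1", "linear2", "modulation", "img_attn", "txt_attn", "img_mlp", "txt_mlp", "img_mod", "txt_mod"]
    src_list  = ["_" + k + "." for k in repl_list]
    src_list2 = ["_" + k + "_" for k in repl_list]
    tgt_list  = ["." + k + "." for k in repl_list]

    k = "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight"
    k = k.replace("lora_unet_double_blocks_", "diffusion_model.double_blocks.")
    for s, s2, t in zip(src_list, src_list2, tgt_list):
        k = k.replace(s, t)
        k = k.replace(s2, t)
    k = k.replace("lora_down", "lora_A")
    print(k)  # diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight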
@@ -30,6 +30,7 @@ def get_noise(
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        dtype=dtype,
        device=device,
        generator=torch.Generator(device=device).manual_seed(seed),
    )
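The two 2 * math.ceil(... / 16) terms set the latent spatial size of the seeded noise. A quick arithmetic check, using the 1280x720 resolution that appears in the defaults above:

    import math
    height, width = 720, 1280          # resolution used in the default configs above
    print(2 * math.ceil(height / 16))  # 90 latent rows
    print(2 * math.ceil(width / 16))   # 160 latent columns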
@@ -45,5 +45,6 @@ misaki
soundfile
ffmpeg-python
pyannote.audio
pynvml
# num2words
# spacy
@@ -30,7 +30,7 @@ from wan.modules.posemb_layers import get_rotary_pos_embed
from .utils.vace_preprocessor import VaceVideoProcessor
from wan.utils.basic_flowmatch import FlowMatchScheduler
from wan.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions
from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance
from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask
from mmgp import safetensors2

def optimized_scale(positive_flat, negative_flat):
@ -78,7 +78,6 @@ class WanAny2V:
|
||||
self.num_train_timesteps = config.num_train_timesteps
|
||||
self.param_dtype = config.param_dtype
|
||||
self.model_def = model_def
|
||||
self.image_outputs = model_def.get("image_outputs", False)
|
||||
self.text_encoder = T5EncoderModel(
|
||||
text_len=config.text_len,
|
||||
dtype=config.t5_dtype,
|
||||
@ -382,6 +381,9 @@ class WanAny2V:
|
||||
offloadobj = None,
|
||||
apg_switch = False,
|
||||
speakers_bboxes = None,
|
||||
color_correction_strength = 1,
|
||||
prefix_frames_count = 0,
|
||||
image_mode = 0,
|
||||
**bbargs
|
||||
):
|
||||
|
||||
@ -418,9 +420,9 @@ class WanAny2V:
|
||||
|
||||
seed_g = torch.Generator(device=self.device)
|
||||
seed_g.manual_seed(seed)
|
||||
|
||||
image_outputs = image_mode == 1
|
||||
kwargs = {'pipeline': self, 'callback': callback}
|
||||
|
||||
color_reference_frame = None
|
||||
if self._interrupt:
|
||||
return None
|
||||
|
||||
@ -468,6 +470,7 @@ class WanAny2V:
|
||||
enc = torch.concat( [input_frames, torch.zeros( (3, frame_num-preframes_count, height, width),
|
||||
device=self.device, dtype= self.VAE_dtype)],
|
||||
dim = 1).to(self.device)
|
||||
color_reference_frame = input_frames[:, -1:].clone()
|
||||
input_frames = None
|
||||
else:
|
||||
preframes_count = 1
|
||||
@ -498,6 +501,7 @@ class WanAny2V:
|
||||
img_interpolated = resize_lanczos(image_start, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
|
||||
image_start = resize_lanczos(image_start, clip_image_size, clip_image_size)
|
||||
image_start = image_start.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
|
||||
color_reference_frame = image_start.clone()
|
||||
if image_end!= None:
|
||||
img_interpolated2 = resize_lanczos(image_end, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
|
||||
image_end = resize_lanczos(image_end, clip_image_size, clip_image_size)
|
||||
@ -566,12 +570,13 @@ class WanAny2V:
|
||||
injection_denoising_step = 0
|
||||
inject_from_start = False
|
||||
if input_frames != None and denoising_strength < 1 :
|
||||
color_reference_frame = input_frames[:, -1:].clone()
|
||||
if overlapped_latents != None:
|
||||
overlapped_latents_frames_num = overlapped_latents.shape[1]
|
||||
overlapped_latents_frames_num = overlapped_latents.shape[2]
|
||||
overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
|
||||
else:
|
||||
overlapped_latents_frames_num = overlapped_frames_num = 0
|
||||
if len(keep_frames_parsed) == 0 or self.image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = []
|
||||
if len(keep_frames_parsed) == 0 or image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = []
|
||||
injection_denoising_step = int(sampling_steps * (1. - denoising_strength) )
|
||||
latent_keep_frames = []
|
||||
if source_latents.shape[1] < lat_frames or len(keep_frames_parsed) > 0:
|
||||
@ -606,6 +611,7 @@ class WanAny2V:
|
||||
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
|
||||
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
||||
if self.background_mask != None:
|
||||
color_reference_frame = input_ref_images[0][0].clone()
|
||||
zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
|
||||
mbg = self.vace_encode_masks(self.background_mask, None)
|
||||
for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
|
||||
@ -621,6 +627,8 @@ class WanAny2V:
|
||||
if overlapped_latents != None :
|
||||
overlapped_latents_size = overlapped_latents.shape[2]
|
||||
extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0)
|
||||
if prefix_frames_count > 0:
|
||||
color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone()
|
||||
|
||||
target_shape = list(z0[0].shape)
|
||||
target_shape[0] = int(target_shape[0] / 2)
|
||||
@ -691,7 +699,6 @@ class WanAny2V:
|
||||
apg_norm_threshold = 55
|
||||
text_momentumbuffer = MomentumBuffer(apg_momentum)
|
||||
audio_momentumbuffer = MomentumBuffer(apg_momentum)
|
||||
# self.image_outputs = False
|
||||
# denoising
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
offload.set_step_no_for_lora(self.model, i)
|
||||
@ -777,7 +784,7 @@ class WanAny2V:
|
||||
noise_pred_cond, noise_pred_noaudio, noise_pred_uncond = ret_values
|
||||
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_noaudio - noise_pred_uncond) + audio_cfg_scale * (noise_pred_cond - noise_pred_noaudio)
|
||||
noise_pred_noaudio = None
|
||||
elif multitalk:
|
||||
elif multitalk and audio_proj != None:
|
||||
noise_pred_cond, noise_pred_drop_text, noise_pred_uncond = ret_values
|
||||
if apg_switch != 0:
|
||||
noise_pred = noise_pred_cond + (guide_scale - 1) * adaptive_projected_guidance(noise_pred_cond - noise_pred_drop_text,
|
||||
@ -830,6 +837,7 @@ class WanAny2V:
|
||||
latents_preview = latents
|
||||
if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
|
||||
if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames]
|
||||
if image_outputs: latents_preview= latents_preview[:, :,:1]
|
||||
if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
|
||||
callback(i, latents_preview[0], False)
|
||||
latents_preview = None
|
||||
@ -846,10 +854,18 @@ class WanAny2V:
|
||||
|
||||
videos = self.vae.decode(x0, VAE_tile_size)
|
||||
|
||||
if self.image_outputs:
|
||||
videos = torch.cat(videos, dim=1) if len(videos) > 1 else videos[0]
|
||||
if image_outputs:
|
||||
videos = torch.cat([video[:,:1] for video in videos], dim=1) if len(videos) > 1 else videos[0][:,:1]
|
||||
else:
|
||||
videos = videos[0] # return only first video
|
||||
videos = videos[0] # return only first video
|
||||
if color_correction_strength > 0:
|
||||
if vace and False:
|
||||
# videos = match_and_blend_colors_with_mask(videos.unsqueeze(0), input_frames[0].unsqueeze(0), input_masks[0][:1].unsqueeze(0), color_correction_strength,copy_mode= "progressive_blend").squeeze(0)
|
||||
videos = match_and_blend_colors_with_mask(videos.unsqueeze(0), input_frames[0].unsqueeze(0), input_masks[0][:1].unsqueeze(0), color_correction_strength,copy_mode= "reference").squeeze(0)
|
||||
# videos = match_and_blend_colors_with_mask(videos.unsqueeze(0), videos.unsqueeze(0), input_masks[0][:1].unsqueeze(0), color_correction_strength,copy_mode= "reference").squeeze(0)
|
||||
elif color_reference_frame is not None:
|
||||
videos = match_and_blend_colors(videos.unsqueeze(0), color_reference_frame.unsqueeze(0), color_correction_strength).squeeze(0)
|
||||
|
||||
if return_latent_slice != None:
|
||||
return { "x" : videos, "latent_slice" : latent_slice }
|
||||
return videos
|
||||
|
||||
@ -15,6 +15,7 @@ import soundfile as sf
|
||||
import torchvision
|
||||
import binascii
|
||||
import os.path as osp
|
||||
from skimage import color
|
||||
|
||||
|
||||
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
|
||||
@ -351,3 +352,531 @@ def adaptive_projected_guidance(
|
||||
diff_parallel, diff_orthogonal = project(diff, pred_cond)
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
return normalized_update
|
||||
|
||||
def match_and_blend_colors(source_chunk: torch.Tensor, reference_image: torch.Tensor, strength: float) -> torch.Tensor:
|
||||
"""
|
||||
Matches the color of a source video chunk to a reference image and blends with the original.
|
||||
|
||||
Args:
|
||||
source_chunk (torch.Tensor): The video chunk to be color-corrected (B, C, T, H, W) in range [-1, 1].
|
||||
Assumes B=1 (batch size of 1).
|
||||
reference_image (torch.Tensor): The reference image (B, C, 1, H, W) in range [-1, 1].
|
||||
Assumes B=1 and T=1 (single reference frame).
|
||||
strength (float): The strength of the color correction (0.0 to 1.0).
|
||||
0.0 means no correction, 1.0 means full correction.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The color-corrected and blended video chunk.
|
||||
"""
|
||||
# print(f"[match_and_blend_colors] Input source_chunk shape: {source_chunk.shape}, reference_image shape: {reference_image.shape}, strength: {strength}")
|
||||
|
||||
if strength == 0.0:
|
||||
# print(f"[match_and_blend_colors] Strength is 0, returning original source_chunk.")
|
||||
return source_chunk
|
||||
|
||||
if not 0.0 <= strength <= 1.0:
|
||||
raise ValueError(f"Strength must be between 0.0 and 1.0, got {strength}")
|
||||
|
||||
device = source_chunk.device
|
||||
dtype = source_chunk.dtype
|
||||
|
||||
# Squeeze batch dimension, permute to T, H, W, C for skimage
|
||||
# Source: (1, C, T, H, W) -> (T, H, W, C)
|
||||
source_np = source_chunk.squeeze(0).permute(1, 2, 3, 0).cpu().numpy()
|
||||
# Reference: (1, C, 1, H, W) -> (H, W, C)
|
||||
ref_np = reference_image.squeeze(0).squeeze(1).permute(1, 2, 0).cpu().numpy() # Squeeze T dimension as well
|
||||
|
||||
# Normalize from [-1, 1] to [0, 1] for skimage
|
||||
source_np_01 = (source_np + 1.0) / 2.0
|
||||
ref_np_01 = (ref_np + 1.0) / 2.0
|
||||
|
||||
# Clip to ensure values are strictly in [0, 1] after potential float precision issues
|
||||
source_np_01 = np.clip(source_np_01, 0.0, 1.0)
|
||||
ref_np_01 = np.clip(ref_np_01, 0.0, 1.0)
|
||||
|
||||
# Convert reference to Lab
|
||||
try:
|
||||
ref_lab = color.rgb2lab(ref_np_01)
|
||||
except ValueError as e:
|
||||
# Handle potential errors if image data is not valid for conversion
|
||||
print(f"Warning: Could not convert reference image to Lab: {e}. Skipping color correction for this chunk.")
|
||||
return source_chunk
|
||||
|
||||
|
||||
corrected_frames_np_01 = []
|
||||
for i in range(source_np_01.shape[0]): # Iterate over time (T)
|
||||
source_frame_rgb_01 = source_np_01[i]
|
||||
|
||||
try:
|
||||
source_lab = color.rgb2lab(source_frame_rgb_01)
|
||||
except ValueError as e:
|
||||
print(f"Warning: Could not convert source frame {i} to Lab: {e}. Using original frame.")
|
||||
corrected_frames_np_01.append(source_frame_rgb_01)
|
||||
continue
|
||||
|
||||
corrected_lab_frame = source_lab.copy()
|
||||
|
||||
# Perform color transfer for L, a, b channels
|
||||
for j in range(3): # L, a, b
|
||||
mean_src, std_src = source_lab[:, :, j].mean(), source_lab[:, :, j].std()
|
||||
mean_ref, std_ref = ref_lab[:, :, j].mean(), ref_lab[:, :, j].std()
|
||||
|
||||
# Avoid division by zero if std_src is 0
|
||||
if std_src == 0:
|
||||
# If source channel has no variation, keep it as is, but shift by reference mean
|
||||
# This case is debatable, could also just copy source or target mean.
|
||||
# Shifting by target mean helps if source is flat but target isn't.
|
||||
corrected_lab_frame[:, :, j] = mean_ref
|
||||
else:
|
||||
corrected_lab_frame[:, :, j] = (corrected_lab_frame[:, :, j] - mean_src) * (std_ref / std_src) + mean_ref
|
||||
|
||||
try:
|
||||
fully_corrected_frame_rgb_01 = color.lab2rgb(corrected_lab_frame)
|
||||
except ValueError as e:
|
||||
print(f"Warning: Could not convert corrected frame {i} back to RGB: {e}. Using original frame.")
|
||||
corrected_frames_np_01.append(source_frame_rgb_01)
|
||||
continue
|
||||
|
||||
# Clip again after lab2rgb as it can go slightly out of [0,1]
|
||||
fully_corrected_frame_rgb_01 = np.clip(fully_corrected_frame_rgb_01, 0.0, 1.0)
|
||||
|
||||
# Blend with original source frame (in [0,1] RGB)
|
||||
blended_frame_rgb_01 = (1 - strength) * source_frame_rgb_01 + strength * fully_corrected_frame_rgb_01
|
||||
corrected_frames_np_01.append(blended_frame_rgb_01)
|
||||
|
||||
corrected_chunk_np_01 = np.stack(corrected_frames_np_01, axis=0)
|
||||
|
||||
# Convert back to [-1, 1]
|
||||
corrected_chunk_np_minus1_1 = (corrected_chunk_np_01 * 2.0) - 1.0
|
||||
|
||||
# Permute back to (C, T, H, W), add batch dim, and convert to original torch.Tensor type and device
|
||||
# (T, H, W, C) -> (C, T, H, W)
|
||||
corrected_chunk_tensor = torch.from_numpy(corrected_chunk_np_minus1_1).permute(3, 0, 1, 2).unsqueeze(0)
|
||||
corrected_chunk_tensor = corrected_chunk_tensor.contiguous() # Ensure contiguous memory layout
|
||||
output_tensor = corrected_chunk_tensor.to(device=device, dtype=dtype)
|
||||
# print(f"[match_and_blend_colors] Output tensor shape: {output_tensor.shape}")
|
||||
return output_tensor
|
||||
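The per-channel transfer inside the loop above is the classic mean/standard-deviation matching in Lab space. A quick numeric check with made-up statistics (the values are illustrative, not taken from a real frame):

    # illustrative numbers only: L channel of one pixel
    mean_src, std_src = 60.0, 10.0   # stats of the new window's frame
    mean_ref, std_ref = 50.0, 8.0    # stats of the reference frame
    pixel_L = 70.0
    corrected_L = (pixel_L - mean_src) * (std_ref / std_src) + mean_ref
    print(corrected_L)  # 58.0 -> the channel is shifted and rescaled toward the reference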
|
||||
|
||||
from skimage import color
|
||||
from scipy import ndimage
|
||||
from scipy.ndimage import binary_erosion, distance_transform_edt
|
||||
|
||||
|
||||
def match_and_blend_colors_with_mask(
|
||||
source_chunk: torch.Tensor,
|
||||
reference_video: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
strength: float,
|
||||
copy_mode: str = "corrected", # "corrected", "reference", "source", "progressive_blend"
|
||||
source_border_distance: int = 10,
|
||||
reference_border_distance: int = 10
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Matches the color of a source video chunk to a reference video using mask-based region sampling.
|
||||
|
||||
Args:
|
||||
source_chunk (torch.Tensor): The video chunk to be color-corrected (B, C, T, H, W) in range [-1, 1].
|
||||
Assumes B=1 (batch size of 1).
|
||||
reference_video (torch.Tensor): The reference video (B, C, T, H, W) in range [-1, 1].
|
||||
Must have same temporal dimension as source_chunk.
|
||||
mask (torch.Tensor): Binary mask (B, 1, T, H, W) or (T, H, W) or (H, W) with values 0 and 1.
|
||||
Color correction is applied to pixels where mask=1.
|
||||
strength (float): The strength of the color correction (0.0 to 1.0).
|
||||
0.0 means no correction, 1.0 means full correction.
|
||||
copy_mode (str): What to do with mask=0 pixels:
|
||||
"corrected" (keep original), "reference", "source",
|
||||
"progressive_blend" (double-sided progressive blending near borders).
|
||||
source_border_distance (int): Distance in pixels from mask border to sample source video (mask=1 side).
|
||||
reference_border_distance (int): Distance in pixels from mask border to sample reference video (mask=0 side).
|
||||
For "progressive_blend" mode, this also defines the blending falloff distance.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The color-corrected and blended video chunk.
|
||||
|
||||
Notes:
|
||||
- Color statistics are sampled from border regions to determine source and reference tints
|
||||
- Progressive blending creates smooth double-sided transitions:
|
||||
* mask=1 side: 60% source + 40% reference at border → 100% source deeper in
|
||||
* mask=0 side: 60% reference + 40% source at border → 100% reference deeper in
|
||||
"""
|
||||
|
||||
if strength == 0.0:
|
||||
return source_chunk
|
||||
|
||||
if not 0.0 <= strength <= 1.0:
|
||||
raise ValueError(f"Strength must be between 0.0 and 1.0, got {strength}")
|
||||
|
||||
if copy_mode not in ["corrected", "reference", "source", "progressive_blend"]:
|
||||
raise ValueError(f"copy_mode must be 'corrected', 'reference', 'source', or 'progressive_blend', got {copy_mode}")
|
||||
|
||||
device = source_chunk.device
|
||||
dtype = source_chunk.dtype
|
||||
B, C, T, H, W = source_chunk.shape
|
||||
|
||||
# Handle different mask dimensions
|
||||
if mask.dim() == 2: # (H, W)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).expand(B, 1, T, H, W)
|
||||
elif mask.dim() == 3: # (T, H, W)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, T, H, W)
|
||||
elif mask.dim() == 4: # (B, T, H, W) - missing channel dim
|
||||
mask = mask.unsqueeze(1)
|
||||
# mask should now be (B, 1, T, H, W)
|
||||
|
||||
# Convert to numpy for processing
|
||||
source_np = source_chunk.squeeze(0).permute(1, 2, 3, 0).cpu().numpy() # (T, H, W, C)
|
||||
reference_np = reference_video.squeeze(0).permute(1, 2, 3, 0).cpu().numpy() # (T, H, W, C)
|
||||
mask_np = mask.squeeze(0).squeeze(0).cpu().numpy() # (T, H, W)
|
||||
|
||||
# Normalize from [-1, 1] to [0, 1] for skimage
|
||||
source_np_01 = (source_np + 1.0) / 2.0
|
||||
reference_np_01 = (reference_np + 1.0) / 2.0
|
||||
|
||||
# Clip to ensure values are in [0, 1]
|
||||
source_np_01 = np.clip(source_np_01, 0.0, 1.0)
|
||||
reference_np_01 = np.clip(reference_np_01, 0.0, 1.0)
|
||||
|
||||
corrected_frames_np_01 = []
|
||||
|
||||
for t in range(T):
|
||||
source_frame = source_np_01[t] # (H, W, C)
|
||||
reference_frame = reference_np_01[t] # (H, W, C)
|
||||
frame_mask = mask_np[t] # (H, W)
|
||||
|
||||
# Find mask borders and create distance maps
|
||||
border_regions = get_border_sampling_regions(frame_mask, source_border_distance, reference_border_distance)
|
||||
source_sample_region = border_regions['source_region'] # mask=1 side
|
||||
reference_sample_region = border_regions['reference_region'] # mask=0 side
|
||||
|
||||
# Sample pixels for color statistics
|
||||
try:
|
||||
source_stats = compute_color_stats(source_frame, source_sample_region)
|
||||
reference_stats = compute_color_stats(reference_frame, reference_sample_region)
|
||||
except ValueError as e:
|
||||
print(f"Warning: Could not compute color statistics for frame {t}: {e}. Using original frame.")
|
||||
corrected_frames_np_01.append(source_frame)
|
||||
continue
|
||||
|
||||
# Apply color correction to mask=1 area and handle mask=0 area based on copy_mode
|
||||
corrected_frame = apply_color_correction_with_mask(
|
||||
source_frame, frame_mask, source_stats, reference_stats, strength
|
||||
)
|
||||
|
||||
# Handle mask=0 pixels based on copy_mode
|
||||
if copy_mode == "reference":
|
||||
corrected_frame = apply_copy_with_mask(corrected_frame, reference_frame, frame_mask, "reference")
|
||||
elif copy_mode == "source":
|
||||
corrected_frame = apply_copy_with_mask(corrected_frame, source_frame, frame_mask, "source")
|
||||
elif copy_mode == "progressive_blend":
|
||||
# Apply progressive blending in mask=1 border area (source side)
|
||||
corrected_frame = apply_progressive_blend_in_corrected_area(
|
||||
corrected_frame, reference_frame, frame_mask,
|
||||
border_regions['source_region'], border_regions['source_distances'],
|
||||
border_regions['reference_region'], source_border_distance
|
||||
)
|
||||
# Copy reference pixels to mask=0 area first
|
||||
corrected_frame = apply_copy_with_mask(corrected_frame, reference_frame, frame_mask, "reference")
|
||||
# Then apply progressive blending in mask=0 border area (reference side)
|
||||
corrected_frame = apply_progressive_blend_in_reference_area(
|
||||
corrected_frame, source_frame, frame_mask,
|
||||
border_regions['reference_region'], border_regions['reference_distances'],
|
||||
reference_border_distance
|
||||
)
|
||||
|
||||
corrected_frames_np_01.append(corrected_frame)
|
||||
|
||||
corrected_chunk_np_01 = np.stack(corrected_frames_np_01, axis=0)
|
||||
|
||||
# Convert back to [-1, 1] and return to tensor format
|
||||
corrected_chunk_np_minus1_1 = (corrected_chunk_np_01 * 2.0) - 1.0
|
||||
corrected_chunk_tensor = torch.from_numpy(corrected_chunk_np_minus1_1).permute(3, 0, 1, 2).unsqueeze(0)
|
||||
corrected_chunk_tensor = corrected_chunk_tensor.contiguous()
|
||||
output_tensor = corrected_chunk_tensor.to(device=device, dtype=dtype)
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def get_border_sampling_regions(mask, source_border_distance, reference_border_distance):
|
||||
"""
|
||||
Create regions for sampling near mask borders with separate distances for source and reference.
|
||||
|
||||
Args:
|
||||
mask: Binary mask (H, W) with 0s and 1s
|
||||
source_border_distance: Distance from border to include in source sampling (mask=1 side)
|
||||
reference_border_distance: Distance from border to include in reference sampling (mask=0 side)
|
||||
|
||||
Returns:
|
||||
Dict with sampling regions and distance maps for blending
|
||||
"""
|
||||
# Convert to boolean for safety
|
||||
mask_bool = mask.astype(bool)
|
||||
|
||||
# Distance from mask=0 regions (distance into mask=1 areas from border)
|
||||
dist_from_mask0 = distance_transform_edt(mask_bool)
|
||||
|
||||
# Distance from mask=1 regions (distance into mask=0 areas from border)
|
||||
dist_from_mask1 = distance_transform_edt(~mask_bool)
|
||||
|
||||
# Source region: mask=1 pixels within source_border_distance of mask=0 pixels
|
||||
source_region = mask_bool & (dist_from_mask0 <= source_border_distance)
|
||||
|
||||
# Reference region: mask=0 pixels within reference_border_distance of mask=1 pixels
|
||||
reference_region = (~mask_bool) & (dist_from_mask1 <= reference_border_distance)
|
||||
|
||||
return {
|
||||
'source_region': source_region,
|
||||
'reference_region': reference_region,
|
||||
'source_distances': dist_from_mask0, # Distance into mask=1 from border
|
||||
'reference_distances': dist_from_mask1 # Distance into mask=0 from border
|
||||
}
|
||||
|
||||
|
||||
def compute_color_stats(image, sample_region):
|
||||
"""
|
||||
Compute color statistics (mean and std) for Lab channels in the sampling region.
|
||||
|
||||
Args:
|
||||
image: RGB image (H, W, C) in range [0, 1]
|
||||
sample_region: Boolean mask (H, W) indicating pixels to sample
|
||||
|
||||
Returns:
|
||||
Dict with 'mean' and 'std' for Lab components
|
||||
"""
|
||||
if not np.any(sample_region):
|
||||
raise ValueError("No pixels in sampling region")
|
||||
|
||||
# Convert to Lab
|
||||
try:
|
||||
image_lab = color.rgb2lab(image)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Could not convert image to Lab: {e}")
|
||||
|
||||
# Extract pixels in sampling region
|
||||
sampled_pixels = image_lab[sample_region] # (N, 3) where N is number of sampled pixels
|
||||
|
||||
# Compute statistics for each Lab channel
|
||||
stats = {
|
||||
'mean': np.mean(sampled_pixels, axis=0), # (3,) for L, a, b
|
||||
'std': np.std(sampled_pixels, axis=0) # (3,) for L, a, b
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def apply_color_correction_with_mask(source_frame, mask, source_stats, reference_stats, strength):
|
||||
"""
|
||||
Apply color correction to pixels where mask=1.
|
||||
|
||||
Args:
|
||||
source_frame: RGB image (H, W, C) in range [0, 1]
|
||||
mask: Binary mask (H, W)
|
||||
source_stats: Color statistics from source sampling region
|
||||
reference_stats: Color statistics from reference sampling region
|
||||
strength: Blending strength
|
||||
|
||||
Returns:
|
||||
Corrected RGB image (H, W, C)
|
||||
"""
|
||||
try:
|
||||
source_lab = color.rgb2lab(source_frame)
|
||||
except ValueError as e:
|
||||
print(f"Warning: Could not convert source frame to Lab: {e}. Using original frame.")
|
||||
return source_frame
|
||||
|
||||
corrected_lab = source_lab.copy()
|
||||
correction_region = (mask == 1) # Apply correction to mask=1 pixels
|
||||
|
||||
# Apply color transfer to pixels where mask=1
|
||||
for c in range(3): # L, a, b channels
|
||||
mean_src = source_stats['mean'][c]
|
||||
std_src = source_stats['std'][c]
|
||||
mean_ref = reference_stats['mean'][c]
|
||||
std_ref = reference_stats['std'][c]
|
||||
|
||||
if std_src == 0:
|
||||
# Handle case where source channel has no variation
|
||||
corrected_lab[correction_region, c] = mean_ref
|
||||
else:
|
||||
# Standard color transfer formula
|
||||
corrected_lab[correction_region, c] = (
|
||||
(corrected_lab[correction_region, c] - mean_src) * (std_ref / std_src) + mean_ref
|
||||
)
|
||||
|
||||
try:
|
||||
fully_corrected_rgb = color.lab2rgb(corrected_lab)
|
||||
except ValueError as e:
|
||||
print(f"Warning: Could not convert corrected frame back to RGB: {e}. Using original frame.")
|
||||
return source_frame
|
||||
|
||||
# Clip to [0, 1]
|
||||
fully_corrected_rgb = np.clip(fully_corrected_rgb, 0.0, 1.0)
|
||||
|
||||
# Blend with original (only in correction region)
|
||||
result = source_frame.copy()
|
||||
result[correction_region] = (
|
||||
(1 - strength) * source_frame[correction_region] +
|
||||
strength * fully_corrected_rgb[correction_region]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def apply_progressive_blend_in_corrected_area(corrected_frame, reference_frame, mask, source_region, source_distances, reference_region, source_border_distance):
|
||||
"""
|
||||
Apply progressive blending in the corrected area (mask=1) near the border.
|
||||
|
||||
Args:
|
||||
corrected_frame: RGB image (H, W, C) - the color-corrected source frame
|
||||
reference_frame: RGB image (H, W, C) - the reference frame
|
||||
mask: Binary mask (H, W)
|
||||
source_region: Boolean mask (H, W) indicating the source blending region (mask=1 near border)
|
||||
source_distances: Distance map (H, W) into mask=1 area from mask=0 border
|
||||
reference_region: Boolean mask (H, W) indicating the reference sampling region (mask=0 near border)
|
||||
source_border_distance: Maximum distance for source blending
|
||||
|
||||
Returns:
|
||||
Blended RGB image (H, W, C)
|
||||
|
||||
Notes:
|
||||
- Each source pixel blends with its closest reference border pixel (for speed)
|
||||
- At mask border: 60% source + 40% reference
|
||||
- Deeper into mask=1 area: 100% corrected source
|
||||
"""
|
||||
result = corrected_frame.copy()
|
||||
|
||||
# Blend in the source region (mask=1 pixels near border)
|
||||
blend_region = source_region
|
||||
|
||||
if np.any(blend_region):
|
||||
# Find immediate border pixels (mask=0 pixels adjacent to mask=1 pixels)
|
||||
# This is much faster than using the entire reference region
|
||||
from scipy.ndimage import binary_dilation
|
||||
|
||||
# Dilate mask=1 by 1 pixel, then find intersection with mask=0
|
||||
mask_1_dilated = binary_dilation(mask == 1, structure=np.ones((3, 3)))
|
||||
border_pixels = (mask == 0) & mask_1_dilated
|
||||
|
||||
if np.any(border_pixels):
|
||||
# Find closest border pixel for each source pixel
|
||||
source_coords = np.column_stack(np.where(blend_region)) # (N, 2) - y, x coordinates
|
||||
border_coords = np.column_stack(np.where(border_pixels)) # (M, 2) - much smaller set!
|
||||
|
||||
# For each source pixel, find closest border pixel
|
||||
from scipy.spatial.distance import cdist
|
||||
distances_matrix = cdist(source_coords, border_coords, metric='euclidean')
|
||||
closest_border_indices = np.argmin(distances_matrix, axis=1)
|
||||
|
||||
# Normalize source distances for blending weights
|
||||
min_distance_in_region = np.min(source_distances[blend_region])
|
||||
max_distance_in_region = np.max(source_distances[blend_region])
|
||||
|
||||
if max_distance_in_region > min_distance_in_region:
|
||||
# Calculate blend weights: 0.4 at border (60% source + 40% reference), 0.0 at max distance (100% source)
|
||||
source_dist_values = source_distances[blend_region]
|
||||
normalized_distances = (source_dist_values - min_distance_in_region) / (max_distance_in_region - min_distance_in_region)
|
||||
blend_weights = 0.4 * (1.0 - normalized_distances) # Start with 40% reference influence at border
|
||||
|
||||
# Apply blending with closest border pixels
|
||||
for i, (source_y, source_x) in enumerate(source_coords):
|
||||
closest_border_idx = closest_border_indices[i]
|
||||
border_y, border_x = border_coords[closest_border_idx]
|
||||
|
||||
weight = blend_weights[i]
|
||||
# Blend with closest border pixel
|
||||
result[source_y, source_x] = (
|
||||
(1.0 - weight) * corrected_frame[source_y, source_x] +
|
||||
weight * reference_frame[border_y, border_x]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def apply_progressive_blend_in_reference_area(reference_frame, source_frame, mask, reference_region, reference_distances, reference_border_distance):
|
||||
"""
|
||||
Apply progressive blending in the reference area (mask=0) near the border.
|
||||
|
||||
Args:
|
||||
reference_frame: RGB image (H, W, C) - the reference frame with copied reference pixels
|
||||
source_frame: RGB image (H, W, C) - the original source frame
|
||||
mask: Binary mask (H, W)
|
||||
reference_region: Boolean mask (H, W) indicating the reference blending region (mask=0 near border)
|
||||
reference_distances: Distance map (H, W) into mask=0 area from mask=1 border
|
||||
reference_border_distance: Maximum distance for reference blending
|
||||
|
||||
Returns:
|
||||
Blended RGB image (H, W, C)
|
||||
|
||||
Notes:
|
||||
- Each reference pixel blends with its closest source border pixel (for speed)
|
||||
- At mask border: 60% reference + 40% source
|
||||
- Deeper into mask=0 area: 100% reference
|
||||
"""
|
||||
result = reference_frame.copy()
|
||||
|
||||
# Blend in the reference region (mask=0 pixels near border)
|
||||
blend_region = reference_region
|
||||
|
||||
if np.any(blend_region):
|
||||
# Find immediate border pixels (mask=1 pixels adjacent to mask=0 pixels)
|
||||
from scipy.ndimage import binary_dilation
|
||||
|
||||
# Dilate mask=0 by 1 pixel, then find intersection with mask=1
|
||||
mask_0_dilated = binary_dilation(mask == 0, structure=np.ones((3, 3)))
|
||||
source_border_pixels = (mask == 1) & mask_0_dilated
|
||||
|
||||
if np.any(source_border_pixels):
|
||||
# Find closest source border pixel for each reference pixel
|
||||
reference_coords = np.column_stack(np.where(blend_region)) # (N, 2) - y, x coordinates
|
||||
source_border_coords = np.column_stack(np.where(source_border_pixels)) # (M, 2)
|
||||
|
||||
# For each reference pixel, find closest source border pixel
|
||||
from scipy.spatial.distance import cdist
|
||||
distances_matrix = cdist(reference_coords, source_border_coords, metric='euclidean')
|
||||
closest_source_indices = np.argmin(distances_matrix, axis=1)
|
||||
|
||||
# Normalize reference distances for blending weights
|
||||
min_distance_in_region = np.min(reference_distances[blend_region])
|
||||
max_distance_in_region = np.max(reference_distances[blend_region])
|
||||
|
||||
if max_distance_in_region > min_distance_in_region:
|
||||
# Calculate blend weights: 0.4 at border (60% reference + 40% source), 0.0 at max distance (100% reference)
|
||||
reference_dist_values = reference_distances[blend_region]
|
||||
normalized_distances = (reference_dist_values - min_distance_in_region) / (max_distance_in_region - min_distance_in_region)
|
||||
blend_weights = 0.4 * (1.0 - normalized_distances) # Start with 40% source influence at border
|
||||
|
||||
# Apply blending with closest source border pixels
|
||||
for i, (ref_y, ref_x) in enumerate(reference_coords):
|
||||
closest_source_idx = closest_source_indices[i]
|
||||
source_y, source_x = source_border_coords[closest_source_idx]
|
||||
|
||||
weight = blend_weights[i]
|
||||
# Blend: weight=0.4 means 60% reference + 40% source at border
|
||||
result[ref_y, ref_x] = (
|
||||
(1.0 - weight) * reference_frame[ref_y, ref_x] +
|
||||
weight * source_frame[source_y, source_x]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def apply_copy_with_mask(source_frame, reference_frame, mask, copy_source):
|
||||
"""
|
||||
Copy pixels to mask=0 regions based on copy_source parameter.
|
||||
|
||||
Args:
|
||||
source_frame: RGB image (H, W, C)
|
||||
reference_frame: RGB image (H, W, C)
|
||||
mask: Binary mask (H, W)
|
||||
copy_source: "reference" or "source"
|
||||
|
||||
Returns:
|
||||
Combined RGB image (H, W, C)
|
||||
"""
|
||||
result = source_frame.copy()
|
||||
mask_0_region = (mask == 0)
|
||||
|
||||
if copy_source == "reference":
|
||||
result[mask_0_region] = reference_frame[mask_0_region]
|
||||
# If "source", we keep the original source pixels (no change needed)
|
||||
|
||||
return result
|
||||
@ -103,46 +103,67 @@ def generate_notification_beep(volume=50, sample_rate=44100):
|
||||
wave = wave / max_amplitude * 0.85 # More conservative normalization
|
||||
|
||||
return wave
|
||||
|
||||
_mixer_lock = threading.Lock()
|
||||
|
||||
def play_audio_with_pygame(audio_data, sample_rate=44100):
|
||||
"""Play audio using pygame backend"""
|
||||
"""
|
||||
Play audio with clean stereo output - sounds like single notification from both speakers
|
||||
"""
|
||||
try:
|
||||
import pygame
|
||||
# Initialize pygame mixer only if not already initialized
|
||||
if not pygame.mixer.get_init():
|
||||
pygame.mixer.pre_init(frequency=sample_rate, size=-16, channels=2, buffer=1024)
|
||||
pygame.mixer.init()
|
||||
else:
|
||||
# Reinitialize with new settings if needed
|
||||
current_freq, current_size, current_channels = pygame.mixer.get_init()
|
||||
if current_freq != sample_rate or current_channels != 2:
|
||||
|
||||
with _mixer_lock:
|
||||
if len(audio_data) == 0:
|
||||
return False
|
||||
|
||||
# Clean mixer initialization - quit any existing mixer first
|
||||
if pygame.mixer.get_init() is not None:
|
||||
pygame.mixer.quit()
|
||||
pygame.mixer.pre_init(frequency=sample_rate, size=-16, channels=2, buffer=1024)
|
||||
pygame.mixer.init()
|
||||
|
||||
audio_int16 = (audio_data * 32767).astype(np.int16)
|
||||
|
||||
# Convert mono to stereo
|
||||
if len(audio_int16.shape) == 1:
|
||||
stereo_data = np.column_stack((audio_int16, audio_int16))
|
||||
else:
|
||||
stereo_data = audio_int16
|
||||
|
||||
sound = pygame.sndarray.make_sound(stereo_data)
|
||||
sound.play()
|
||||
pygame.time.wait(int(len(audio_data) / sample_rate * 1000) + 100)
|
||||
# Don't quit mixer - this can interfere with Gradio server
|
||||
# pygame.mixer.quit()
|
||||
return True
|
||||
time.sleep(0.2) # Longer pause to ensure clean shutdown
|
||||
|
||||
# Initialize fresh mixer
|
||||
pygame.mixer.pre_init(
|
||||
frequency=sample_rate,
|
||||
size=-16,
|
||||
channels=2,
|
||||
buffer=512 # Smaller buffer to reduce latency/doubling
|
||||
)
|
||||
pygame.mixer.init()
|
||||
|
||||
# Verify clean initialization
|
||||
mixer_info = pygame.mixer.get_init()
|
||||
if mixer_info is None or mixer_info[2] != 2:
|
||||
return False
|
||||
|
||||
# Prepare audio - ensure clean conversion
|
||||
audio_int16 = (audio_data * 32767).astype(np.int16)
|
||||
if len(audio_int16.shape) > 1:
|
||||
audio_int16 = audio_int16.flatten()
|
||||
|
||||
# Create clean stereo with identical channels
|
||||
stereo_data = np.zeros((len(audio_int16), 2), dtype=np.int16)
|
||||
stereo_data[:, 0] = audio_int16 # Left channel
|
||||
stereo_data[:, 1] = audio_int16 # Right channel
|
||||
|
||||
# Create sound and play once
|
||||
sound = pygame.sndarray.make_sound(stereo_data)
|
||||
|
||||
# Ensure only one playback
|
||||
pygame.mixer.stop() # Stop any previous sounds
|
||||
sound.play()
|
||||
|
||||
# Wait for completion
|
||||
duration_ms = int(len(audio_data) / sample_rate * 1000) + 50
|
||||
pygame.time.wait(duration_ms)
|
||||
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Pygame error: {e}")
|
||||
print(f"Pygame clean error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def play_audio_with_sounddevice(audio_data, sample_rate=44100):
|
||||
"""Play audio using sounddevice backend"""
|
||||
try:
|
||||
|
||||
wan/utils/stats.py (new file, 256 lines)
@@ -0,0 +1,256 @@
|
||||
import gradio as gr
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import atexit
|
||||
from contextlib import contextmanager
|
||||
from collections import deque
|
||||
import psutil
|
||||
import pynvml
|
||||
|
||||
# Initialize NVIDIA Management Library (NVML) for GPU monitoring
|
||||
try:
|
||||
pynvml.nvmlInit()
|
||||
nvml_initialized = True
|
||||
except pynvml.NVMLError:
|
||||
print("Warning: Could not initialize NVML. GPU stats will not be available.")
|
||||
nvml_initialized = False
|
||||
|
||||
class SystemStatsApp:
|
||||
def __init__(self):
|
||||
self.running = False
|
||||
self.active_generators = []
|
||||
self.setup_signal_handlers()
|
||||
|
||||
def setup_signal_handlers(self):
|
||||
# Handle different shutdown signals
|
||||
signal.signal(signal.SIGINT, self.shutdown_handler)
|
||||
signal.signal(signal.SIGTERM, self.shutdown_handler)
|
||||
if hasattr(signal, 'SIGBREAK'): # Windows
|
||||
signal.signal(signal.SIGBREAK, self.shutdown_handler)
|
||||
|
||||
# Also register atexit handler as backup
|
||||
atexit.register(self.cleanup)
|
||||
|
||||
def shutdown_handler(self, signum, frame):
|
||||
# print(f"\nReceived signal {signum}. Shutting down gracefully...")
|
||||
self.cleanup()
|
||||
sys.exit(0)
|
||||
|
||||
def cleanup(self):
|
||||
if not self.running:
|
||||
print("Cleaning up streaming connections...")
|
||||
self.running = False
|
||||
# Give a moment for generators to stop
|
||||
time.sleep(1)
|
||||
|
||||
def get_system_stats(self, first = False, last_disk_io = psutil.disk_io_counters() ):
|
||||
|
||||
# Set a reasonable maximum speed for the bar graph display.
|
||||
# 100 MB/s will represent a 100% full bar.
|
||||
MAX_SSD_SPEED_MB_S = 100.0
|
||||
# Get CPU and RAM stats
|
||||
if first :
|
||||
cpu_percent = psutil.cpu_percent(interval=.01)
|
||||
else:
|
||||
cpu_percent = psutil.cpu_percent(interval=1) # This provides our 1-second delay
|
||||
memory_info = psutil.virtual_memory()
|
||||
ram_percent = memory_info.percent
|
||||
ram_used_gb = memory_info.used / (1024**3)
|
||||
ram_total_gb = memory_info.total / (1024**3)
|
||||
|
||||
# Get new disk IO counters and calculate the read/write speed in MB/s
|
||||
current_disk_io = psutil.disk_io_counters()
|
||||
read_mb_s = (current_disk_io.read_bytes - last_disk_io.read_bytes) / (1024**2)
|
||||
write_mb_s = (current_disk_io.write_bytes - last_disk_io.write_bytes) / (1024**2)
|
||||
total_disk_speed = read_mb_s + write_mb_s
|
||||
|
||||
# Update the last counters for the next loop
|
||||
last_disk_io = current_disk_io
|
||||
|
||||
# Calculate the bar height as a percentage of our defined max speed
|
||||
ssd_bar_height = min(100.0, (total_disk_speed / MAX_SSD_SPEED_MB_S) * 100)
|
||||
|
||||
# Get GPU stats if the library was initialized successfully
|
||||
if nvml_initialized:
|
||||
try:
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(0) # Assuming GPU 0
|
||||
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
|
||||
gpu_percent = util.gpu
|
||||
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
vram_percent = (mem_info.used / mem_info.total) * 100
|
||||
vram_used_gb = mem_info.used / (1024**3)
|
||||
vram_total_gb = mem_info.total / (1024**3)
|
||||
except pynvml.NVMLError:
|
||||
# Handle cases where GPU might be asleep or driver issues
|
||||
gpu_percent, vram_percent, vram_used_gb, vram_total_gb = 0, 0, 0, 0
|
||||
else:
|
||||
# Set default values if NVML failed to load
|
||||
gpu_percent, vram_percent, vram_used_gb, vram_total_gb = 0, 0, 0, 0
|
||||
|
||||
stats_html = f"""
|
||||
<style>
|
||||
.stats-container {{
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: flex-start;
|
||||
padding: 0px 5px;
|
||||
height: 60px;
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
}}
|
||||
|
||||
.stats-block {{
|
||||
width: calc(18% - 5px);
|
||||
min-width: 100px;
|
||||
text-align: center;
|
||||
font-family: sans-serif;
|
||||
}}
|
||||
|
||||
.stats-bar-background {{
|
||||
width: 90%;
|
||||
height: 30px;
|
||||
background-color: #e9ecef;
|
||||
border: 1px solid #dee2e6;
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
margin: 0 auto;
|
||||
}}
|
||||
|
||||
.stats-bar-fill {{
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
height: 100%;
|
||||
background-color: #0d6efd;
|
||||
}}
|
||||
|
||||
.stats-title {{
|
||||
margin-top: 5px;
|
||||
font-size: 11px;
|
||||
font-weight: bold;
|
||||
}}
|
||||
|
||||
.stats-detail {{
|
||||
font-size: 10px;
|
||||
margin-top: -2px;
|
||||
}}
|
||||
</style>
|
||||
|
||||
<div class="stats-container">
|
||||
<!-- CPU Stat Block -->
|
||||
<div class="stats-block">
|
||||
<div class="stats-bar-background">
|
||||
<div class="stats-bar-fill" style="width: {cpu_percent}%;"></div>
|
||||
</div>
|
||||
<div class="stats-title">CPU: {cpu_percent:.1f}%</div>
|
||||
</div>
|
||||
|
||||
<!-- RAM Stat Block -->
|
||||
<div class="stats-block">
|
||||
<div class="stats-bar-background">
|
||||
<div class="stats-bar-fill" style="width: {ram_percent}%;"></div>
|
||||
</div>
|
||||
<div class="stats-title">RAM {ram_percent:.1f}%</div>
|
||||
<div class="stats-detail">{ram_used_gb:.1f} / {ram_total_gb:.1f} GB</div>
|
||||
</div>
|
||||
|
||||
<!-- SSD Activity Stat Block -->
|
||||
<div class="stats-block">
|
||||
<div class="stats-bar-background">
|
||||
<div class="stats-bar-fill" style="width: {ssd_bar_height}%;"></div>
|
||||
</div>
|
||||
<div class="stats-title">SSD R/W</div>
|
||||
<div class="stats-detail">{read_mb_s:.1f} / {write_mb_s:.1f} MB/s</div>
|
||||
</div>
|
||||
|
||||
<!-- GPU Stat Block -->
|
||||
<div class="stats-block">
|
||||
<div class="stats-bar-background">
|
||||
<div class="stats-bar-fill" style="width: {gpu_percent}%;"></div>
|
||||
</div>
|
||||
<div class="stats-title">GPU: {gpu_percent:.1f}%</div>
|
||||
</div>
|
||||
|
||||
<!-- VRAM Stat Block -->
|
||||
<div class="stats-block">
|
||||
<div class="stats-bar-background">
|
||||
<div class="stats-bar-fill" style="width: {vram_percent}%;"></div>
|
||||
</div>
|
||||
<div class="stats-title">VRAM {vram_percent:.1f}%</div>
|
||||
<div class="stats-detail">{vram_used_gb:.1f} / {vram_total_gb:.1f} GB</div>
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
return stats_html, last_disk_io
|
||||
|
||||
def streaming_html(self, state):
|
||||
if "stats_running" in state:
|
||||
return
|
||||
state["stats_running"] = True
|
||||
|
||||
self.running = True
|
||||
last_disk_io = psutil.disk_io_counters()
|
||||
i = 0
|
||||
import time
|
||||
try:
|
||||
while self.running:
|
||||
i+= 1
|
||||
# if i % 2 == 0:
|
||||
# print(f"time:{time.time()}")
|
||||
html_content, last_disk_io = self.get_system_stats(False, last_disk_io)
|
||||
yield html_content
|
||||
# time.sleep(1)
|
||||
|
||||
except GeneratorExit:
|
||||
# print("Generator stopped gracefully")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Streaming error: {e}")
|
||||
# finally:
|
||||
# # Send final message indicating clean shutdown
|
||||
final_html = """
|
||||
<DIV>
|
||||
<img src="x" onerror="
|
||||
setInterval(()=>{
|
||||
console.log('trying...');
|
||||
setTimeout(() => {
|
||||
try{
|
||||
const btn = document.getElementById('restart_stats');
|
||||
if(btn) {
|
||||
console.log('found button, clicking');
|
||||
btn.click();
|
||||
} else {
|
||||
console.log('button not found');
|
||||
}
|
||||
}catch(e){console.log('error: ' + e.message)}
|
||||
}, 100);
|
||||
}, 8000);" style="display:none;">
|
||||
|
||||
<button onclick="document.getElementById('restart_stats').click()"
|
||||
style="background: #007bff; color: white; padding: 15px 30px;
|
||||
border: none; border-radius: 5px; font-size: 16px; cursor: pointer;">
|
||||
🔄 Connection to Server Lost. Attempting Auto reconnect. Click Here to for Manual Connection
|
||||
</button>
|
||||
</DIV>
|
||||
"""
|
||||
try:
|
||||
yield final_html
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def get_gradio_element(self):
|
||||
self.system_stats_display = gr.HTML(self.get_system_stats(True)[0])
|
||||
self.restart_btn = gr.Button("restart stats",elem_id="restart_stats", visible= False) # False)
|
||||
return self.system_stats_display
|
||||
|
||||
def setup_events(self, main, state):
|
||||
gr.on([main.load, self.restart_btn.click],
|
||||
fn=self.streaming_html,
|
||||
inputs = state,
|
||||
outputs=self.system_stats_display,
|
||||
show_progress=False
|
||||
)
|
||||
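For context, a minimal sketch of how SystemStatsApp might be wired into a Gradio Blocks app (the wiring below is an assumption for illustration; WanGP's actual integration lives in wgp.py and may differ):

    import gradio as gr
    from wan.utils.stats import SystemStatsApp

    stats_app = SystemStatsApp()
    with gr.Blocks() as main:
        state = gr.State({})
        stats_app.get_gradio_element()       # HTML stats bar + hidden "restart stats" button
        stats_app.setup_events(main, state)  # streams get_system_stats() HTML on load / reconnect
    main.launch()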
wgp.py (275 lines changed)
@@ -51,7 +51,7 @@ AUTOSAVE_FILENAME = "queue.zip"
|
||||
PROMPT_VARS_MAX = 10
|
||||
|
||||
target_mmgp_version = "3.5.1"
|
||||
WanGP_version = "7.12"
|
||||
WanGP_version = "7.2"
|
||||
settings_version = 2.22
|
||||
max_source_video_frames = 3000
|
||||
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
|
||||
@ -100,16 +100,19 @@ def download_ffmpeg():
|
||||
os.rename(f, os.path.basename(f))
|
||||
os.remove(zip_name)
|
||||
|
||||
|
||||
def format_time(seconds):
|
||||
if seconds < 60:
|
||||
return f"{seconds:.1f}s"
|
||||
elif seconds < 3600:
|
||||
minutes = seconds / 60
|
||||
return f"{minutes:.1f}m"
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
secs = int(seconds % 60)
|
||||
|
||||
if hours > 0:
|
||||
return f"{hours}h {minutes:02d}m {secs:02d}s"
|
||||
elif seconds >= 60:
|
||||
return f"{minutes}m {secs:02d}s"
|
||||
else:
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
return f"{hours}h {minutes}m"
|
||||
return f"{seconds:.1f}s"
|
||||
|
||||
def pil_to_base64_uri(pil_image, format="png", quality=75):
|
||||
if pil_image is None:
|
||||
return None
|
||||
@ -274,6 +277,7 @@ def process_prompt_and_add_tasks(state, model_choice):
|
||||
num_inference_steps= inputs["num_inference_steps"]
|
||||
skip_steps_cache_type= inputs["skip_steps_cache_type"]
|
||||
MMAudio_setting = inputs["MMAudio_setting"]
|
||||
image_mode = inputs["image_mode"]
|
||||
|
||||
|
||||
if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20:
|
||||
@ -438,7 +442,7 @@ def process_prompt_and_add_tasks(state, model_choice):
|
||||
image_end = None
|
||||
|
||||
|
||||
if test_any_sliding_window(model_type):
|
||||
if test_any_sliding_window(model_type) and image_mode == 0:
|
||||
if video_length > sliding_window_size:
|
||||
full_video_length = video_length if video_source is None else video_length + sliding_window_overlap
|
||||
extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation"
|
||||
@ -628,9 +632,12 @@ def move_up(queue, selected_indices):
|
||||
idx = idx[0]
|
||||
idx = int(idx)
|
||||
with lock:
|
||||
if idx > 0:
|
||||
idx += 1
|
||||
idx += 1
|
||||
if idx > 1:
|
||||
queue[idx], queue[idx-1] = queue[idx-1], queue[idx]
|
||||
elif idx == 1:
|
||||
queue[:] = queue[0:1] + queue[2:] + queue[1:2]
|
||||
|
||||
return update_queue_data(queue)
|
||||
|
||||
def move_down(queue, selected_indices):
|
||||
@ -644,6 +651,9 @@ def move_down(queue, selected_indices):
|
||||
idx += 1
|
||||
if idx < len(queue)-1:
|
||||
queue[idx], queue[idx+1] = queue[idx+1], queue[idx]
|
||||
elif idx == len(queue)-1:
|
||||
queue[:] = queue[0:1] + queue[-1:] + queue[1:-1]
|
||||
|
||||
return update_queue_data(queue)
|
||||
|
||||
def remove_task(queue, selected_indices):
|
||||
@ -1056,8 +1066,10 @@ def show_countdown_info_from_state(current_value: int):
|
||||
gr.Info(f"Quitting in {current_value}...")
|
||||
return current_value - 1
|
||||
return current_value
|
||||
|
||||
quitting_app = False
|
||||
def autosave_queue():
|
||||
global quitting_app
|
||||
quitting_app = True
|
||||
global global_queue_ref
|
||||
if not global_queue_ref:
|
||||
print("Autosave: Queue is empty, nothing to save.")
|
||||
@ -1937,7 +1949,7 @@ def fix_settings(model_type, ui_defaults):
|
||||
video_prompt_type = video_prompt_type.replace("I", "")
|
||||
|
||||
|
||||
remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 0)
|
||||
remove_background_images_ref = ui_defaults.get("remove_background_images_ref", 1)
|
||||
if video_settings_version < 2.22:
|
||||
if "I" in video_prompt_type:
|
||||
if remove_background_images_ref == 2:
|
||||
@ -2039,7 +2051,7 @@ def get_default_settings(model_type):
|
||||
ui_defaults.update({
|
||||
"guidance_scale": 7.5,
|
||||
"flow_shift": 5,
|
||||
"remove_background_images_ref": 0,
|
||||
"remove_background_images_ref": 1,
|
||||
"video_prompt_type": "I",
|
||||
# "resolution": "1280x720"
|
||||
})
|
||||
@@ -2359,19 +2371,20 @@ def download_models(model_filename, model_type):

def download_file(url,filename):
if url.startswith("https://huggingface.co/") and "/resolve/main/" in url:
base_dir = os.path.dirname(filename)
url = url[len("https://huggingface.co/"):]
url_parts = url.split("/resolve/main/")
repoId = url_parts[0]
onefile = os.path.basename(url_parts[-1])
sourceFolder = os.path.dirname(url_parts[-1])
if len(sourceFolder) == 0:
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/")
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/" if len(base_dir)==0 else base_dir)
else:
target_path = "ckpts/temp/" + sourceFolder
if not os.path.exists(target_path):
os.makedirs(target_path)
hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/temp/", subfolder=sourceFolder)
shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/")
shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/" if len(base_dir)==0 else base_dir)
shutil.rmtree("ckpts/temp")
else:
urlretrieve(url,filename, create_progress_hook(filename))
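download_file now keeps track of the target directory (base_dir) so downloaded files can land somewhere other than ckpts/. A small sketch of the URL splitting it performs on Hugging Face "resolve/main" links, using only string operations; the example URL is made up for illustration.

```python
import os

# Sketch of how a Hugging Face "resolve/main" URL is split into the pieces
# passed to hf_hub_download in the hunk above. The URL below is illustrative only.
def split_hf_resolve_url(url: str):
    stripped = url[len("https://huggingface.co/"):]
    repo_id, path = stripped.split("/resolve/main/")
    return repo_id, os.path.dirname(path), os.path.basename(path)

print(split_hf_resolve_url(
    "https://huggingface.co/some-org/some-repo/resolve/main/subdir/model.safetensors"))
# -> ('some-org/some-repo', 'subdir', 'model.safetensors')
```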
@@ -2397,9 +2410,7 @@ def download_models(model_filename, model_type):
model_filename = None

preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True)
model_loras = get_model_recursive_prop(model_type, "loras", return_list= True)

for url in preload_URLs + model_loras:
for url in preload_URLs:
filename = "ckpts/" + url.split("/")[-1]
if not os.path.isfile(filename ):
if not url.startswith("http"):
@@ -2409,6 +2420,19 @@ def download_models(model_filename, model_type):
except Exception as e:
if os.path.isfile(filename): os.remove(filename)
raise Exception(f"Preload URL '{url}' is invalid: {str(e)}'")

model_loras = get_model_recursive_prop(model_type, "loras", return_list= True)
for url in model_loras:
filename = os.path.join(get_lora_dir(model_type), url.split("/")[-1])
if not os.path.isfile(filename ):
if not url.startswith("http"):
raise Exception(f"Lora '{filename}' was not found in the Loras Folder and no URL was provided to download it. Please add an URL in the model definition file.")
try:
download_file(url, filename)
except Exception as e:
if os.path.isfile(filename): os.remove(filename)
raise Exception(f"Lora URL '{url}' is invalid: {str(e)}'")

if model_family == "wan":
text_encoder_filename = get_wan_text_encoder_filename(text_encoder_quantization)
model_files = {
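Loras referenced by a model definition are now downloaded into that model's Lora folder instead of ckpts/. A hypothetical illustration of the two fields the loop above reads via get_model_recursive_prop; the URLs and values are invented, and the real finetune definition files may use a different layout.

```python
# Hypothetical excerpt of a finetune/model definition, showing only the two
# properties the download loop above reads. URLs and multipliers are invented.
model_def_example = {
    "loras": [
        "https://huggingface.co/some-org/some-repo/resolve/main/example_lora_1.safetensors",
        "https://huggingface.co/some-org/some-repo/resolve/main/example_lora_2.safetensors",
    ],
    # one multiplier per Lora; missing entries default to 1.0 in get_transformer_loras()
    "loras_multipliers": [1.0, 0.5],
}
```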
@ -2777,7 +2801,7 @@ def generate_header(model_type, compile, attention_mode):
|
||||
get_model_name(model_type, description_container)
|
||||
model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) or ""
|
||||
description = description_container[0]
|
||||
header = "<DIV style='height:40px'>" + description + "</DIV>"
|
||||
header = f"<DIV style=height:{60 if server_config.get('display_stats', 0) == 1 else 40}px>{description}</DIV>"
|
||||
|
||||
header += "<DIV style='align:right;width:100%'><FONT SIZE=3>Attention mode <B>" + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
|
||||
if attention_mode not in attention_modes_installed:
|
||||
@ -2824,12 +2848,13 @@ def apply_changes( state,
|
||||
notification_sound_enabled_choice = 1,
|
||||
notification_sound_volume_choice = 50,
|
||||
max_frames_multiplier_choice = 1,
|
||||
display_stats_choice = 0,
|
||||
last_resolution_choice = None,
|
||||
):
|
||||
if args.lock_config:
|
||||
return
|
||||
if gen_in_progress:
|
||||
return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>", gr.update(), gr.update()
|
||||
return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>",*[gr.update()]*5
|
||||
global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
|
||||
server_config = {
|
||||
"attention_mode" : attention_choice,
|
||||
@ -2856,6 +2881,7 @@ def apply_changes( state,
|
||||
"notification_sound_enabled" : notification_sound_enabled_choice,
|
||||
"notification_sound_volume" : notification_sound_volume_choice,
|
||||
"max_frames_multiplier" : max_frames_multiplier_choice,
|
||||
"display_stats" : display_stats_choice,
|
||||
"last_model_type" : state["model_type"],
|
||||
"last_advanced_choice": state["advanced"],
|
||||
"last_resolution_choice": last_resolution_choice,
|
||||
@ -2894,7 +2920,7 @@ def apply_changes( state,
|
||||
transformer_types = server_config["transformer_types"]
|
||||
model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy)
|
||||
state["model_filename"] = model_filename
|
||||
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier"] for change in changes ):
|
||||
if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier", "display_stats"] for change in changes ):
|
||||
model_choice = gr.Dropdown()
|
||||
else:
|
||||
reload_needed = True
|
||||
@ -2926,6 +2952,7 @@ def get_gen_info(state):
|
||||
def build_callback(state, pipe, send_cmd, status, num_inference_steps):
|
||||
gen = get_gen_info(state)
|
||||
gen["num_inference_steps"] = num_inference_steps
|
||||
start_time = time.time()
|
||||
def callback(step_idx, latent, force_refresh, read_state = False, override_num_inference_steps = -1, pass_no = -1):
|
||||
refresh_id = gen.get("refresh", -1)
|
||||
if force_refresh or step_idx >= 0:
|
||||
@ -2966,6 +2993,9 @@ def build_callback(state, pipe, send_cmd, status, num_inference_steps):
|
||||
|
||||
gen["progress_phase"] = (phase, step_idx)
|
||||
status_msg = merge_status_context(status, phase)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
status_msg = merge_status_context(status, f"{phase} | {format_time(elapsed_time)}")
|
||||
if step_idx >= 0:
|
||||
progress_args = [(step_idx , num_inference_steps) , status_msg , num_inference_steps]
|
||||
else:
|
||||
@ -3020,7 +3050,7 @@ def refresh_gallery(state): #, msg
|
||||
base_model_type = get_base_model_type(model_type)
|
||||
model_def = get_model_def(model_type)
|
||||
is_image = model_def.get("image_outputs", False)
|
||||
onemorewindow_visible = test_any_sliding_window(base_model_type) and not is_image
|
||||
onemorewindow_visible = test_any_sliding_window(base_model_type) and params.get("image_mode",0) == 0
|
||||
enhanced = False
|
||||
if prompt.startswith("!enhanced!\n"):
|
||||
enhanced = True
|
||||
@ -3256,7 +3286,7 @@ def select_video(state, input_file_list, event_data: gr.EventData):
|
||||
if len(video_other_prompts) >0 :
|
||||
values += [video_other_prompts]
|
||||
labels += ["Other Prompts"]
|
||||
if len(video_outpainting) >0 :
|
||||
if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"):
|
||||
values += [video_outpainting]
|
||||
labels += ["Outpainting"]
|
||||
values += [video_resolution, video_length_summary, video_seed, video_guidance_scale, video_flow_shift, video_num_inference_steps]
|
||||
@ -3911,7 +3941,8 @@ def edit_video(
|
||||
def get_transformer_loras(model_type):
|
||||
model_def = get_model_def(model_type)
|
||||
transformer_loras_filenames = get_model_recursive_prop(model_type, "loras", return_list=True)
|
||||
transformer_loras_filenames = [ os.path.join("ckpts", os.path.basename(filename)) for filename in transformer_loras_filenames]
|
||||
lora_dir = get_lora_dir(model_type)
|
||||
transformer_loras_filenames = [ os.path.join(lora_dir, os.path.basename(filename)) for filename in transformer_loras_filenames]
|
||||
transformer_loras_multipliers = get_model_recursive_prop(model_type, "loras_multipliers", return_list=True) + [1.] * len(transformer_loras_filenames)
|
||||
transformer_loras_multipliers = transformer_loras_multipliers[:len(transformer_loras_filenames)]
|
||||
return transformer_loras_filenames, transformer_loras_multipliers
|
||||
@ -3966,6 +3997,7 @@ def generate_video(
|
||||
speakers_locations,
|
||||
sliding_window_size,
|
||||
sliding_window_overlap,
|
||||
sliding_window_color_correction_strength,
|
||||
sliding_window_overlap_noise,
|
||||
sliding_window_discard_last_frames,
|
||||
remove_background_images_ref,
|
||||
@ -3988,6 +4020,7 @@ def generate_video(
|
||||
cfg_star_switch,
|
||||
cfg_zero_step,
|
||||
prompt_enhancer,
|
||||
min_frames_if_references,
|
||||
state,
|
||||
model_type,
|
||||
model_filename,
|
||||
@ -4017,7 +4050,8 @@ def generate_video(
|
||||
model_def = get_model_def(model_type)
|
||||
is_image = image_mode == 1
|
||||
if is_image:
|
||||
video_length = 1
|
||||
# min_frames_if_references = server_config.get("min_frames_if_references", 5)
|
||||
video_length = min_frames_if_references if "I" in video_prompt_type else 1
|
||||
else:
|
||||
batch_size = 1
|
||||
temp_filenames_list = []
|
||||
@ -4205,7 +4239,7 @@ def generate_video(
|
||||
for i, pos in enumerate(frames_positions_list):
|
||||
frames_to_inject[pos] = image_refs[i]
|
||||
if video_guide == None and video_source == None and not "L" in image_prompt_type and (nb_frames_positions > 0 or "K" in video_prompt_type) :
|
||||
from wan.utils.utils import resize_lanczos, get_outpainting_full_area_dimensions
|
||||
from wan.utils.utils import get_outpainting_full_area_dimensions
|
||||
w, h = image_refs[0].size
|
||||
if outpainting_dims != None:
|
||||
h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims)
|
||||
@ -4602,6 +4636,7 @@ def generate_video(
|
||||
input_masks = src_mask,
|
||||
input_video= pre_video_guide if diffusion_forcing or ltxv or hunyuan_custom_edit else source_video,
|
||||
denoising_strength=denoising_strength,
|
||||
prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames,
|
||||
target_camera= target_camera,
|
||||
frame_num= (current_video_length // latent_size)* latent_size + 1,
|
||||
batch_size = batch_size,
|
||||
@ -4639,6 +4674,7 @@ def generate_video(
|
||||
overlapped_latents = overlapped_latents,
|
||||
return_latent_slice= return_latent_slice,
|
||||
overlap_noise = sliding_window_overlap_noise,
|
||||
color_correction_strength = sliding_window_color_correction_strength,
|
||||
conditioning_latents_size = conditioning_latents_size,
|
||||
keep_frames_parsed = keep_frames_parsed,
|
||||
model_filename = model_filename,
|
||||
@ -4648,6 +4684,7 @@ def generate_video(
|
||||
NAG_tau = NAG_tau,
|
||||
NAG_alpha = NAG_alpha,
|
||||
speakers_bboxes =speakers_bboxes,
|
||||
image_mode = image_mode,
|
||||
offloadobj = offloadobj,
|
||||
)
|
||||
except Exception as e:
|
||||
@ -5227,9 +5264,11 @@ def process_tasks(state):
|
||||
gen["prompt"] = ""
|
||||
end_time = time.time()
|
||||
if abort:
|
||||
status = f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
|
||||
# status = f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
|
||||
status = f"Video generation was aborted. Total Generation Time: {format_time(end_time-start_time)}"
|
||||
else:
|
||||
status = f"Total Generation Time: {end_time-start_time:.1f}s"
|
||||
# status = f"Total Generation Time: {end_time-start_time:.1f}s"
|
||||
status = f"Total Generation Time: {format_time(end_time-start_time)}"
|
||||
# Play notification sound when video generation completed successfully
|
||||
try:
|
||||
if server_config.get("notification_sound_enabled", 1):
|
||||
@ -5273,8 +5312,13 @@ def merge_status_context(status="", context=""):
|
||||
elif len(context) == 0:
|
||||
return status
|
||||
else:
|
||||
return status + " - " + context
|
||||
|
||||
# Check if context already contains the time
|
||||
if "|" in context:
|
||||
parts = context.split("|")
|
||||
return f"{status} - {parts[0].strip()} | {parts[1].strip()}"
|
||||
else:
|
||||
return f"{status} - {context}"
|
||||
|
||||
def clear_status(state):
|
||||
gen = get_gen_info(state)
|
||||
gen["extra_windows"] = 0
|
||||
@ -5314,7 +5358,7 @@ def one_more_sample(state):
|
||||
total_generation = gen.get("total_generation", 0) + extra_orders
|
||||
gen["progress_status"] = get_latest_status(state)
|
||||
gen["refresh"] = get_new_refresh_id()
|
||||
gr.Info(f"An extra sample generation is planned for a total of {total_generation} videos for this prompt")
|
||||
gr.Info(f"An extra sample generation is planned for a total of {total_generation} samples for this prompt")
|
||||
|
||||
return state
|
||||
|
||||
@ -5769,13 +5813,13 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
|
||||
if base_model_type in ["t2v"]: unsaved_params = unsaved_params[2:]
|
||||
pop += unsaved_params
|
||||
if not vace:
|
||||
pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"]
|
||||
pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2", "min_frames_if_references"]
|
||||
|
||||
if not (diffusion_forcing or ltxv or vace):
|
||||
pop += ["keep_frames_video_source"]
|
||||
|
||||
if not test_any_sliding_window( base_model_type):
|
||||
pop += ["sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames"]
|
||||
pop += ["sliding_window_size", "sliding_window_overlap", "sliding_window_overlap_noise", "sliding_window_discard_last_frames", "sliding_window_color_correction_strength"]
|
||||
|
||||
if not base_model_type in ["fantasy", "multitalk", "vace_multitalk_14B"]:
|
||||
pop += ["audio_guidance_scale", "speakers_locations"]
|
||||
@ -6058,8 +6102,21 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw
|
||||
|
||||
return configs, tags != None
|
||||
|
||||
def record_image_mode_tab(state, evt:gr.SelectData):
|
||||
state["image_mode_tab"] = 0 if evt.index ==0 else 1
|
||||
|
||||
def switch_image_mode(state):
|
||||
image_mode = state.get("image_mode_tab", 0)
|
||||
model_type =state["model_type"]
|
||||
ui_defaults = get_model_settings(state, model_type)
|
||||
|
||||
ui_defaults["image_mode"] = image_mode
|
||||
|
||||
return str(time.time())
|
||||
|
||||
def load_settings_from_file(state, file_path):
|
||||
gen = get_gen_info(state)
|
||||
|
||||
if file_path==None:
|
||||
return gr.update(), gr.update(), None
|
||||
|
||||
@ -6135,6 +6192,7 @@ def save_inputs(
|
||||
speakers_locations,
|
||||
sliding_window_size,
|
||||
sliding_window_overlap,
|
||||
sliding_window_color_correction_strength,
|
||||
sliding_window_overlap_noise,
|
||||
sliding_window_discard_last_frames,
|
||||
remove_background_images_ref,
|
||||
@ -6157,6 +6215,7 @@ def save_inputs(
|
||||
cfg_star_switch,
|
||||
cfg_zero_step,
|
||||
prompt_enhancer,
|
||||
min_frames_if_references,
|
||||
mode,
|
||||
state,
|
||||
):
|
||||
@ -6181,7 +6240,7 @@ def save_inputs(
|
||||
|
||||
def download_loras():
|
||||
from huggingface_hub import snapshot_download
|
||||
yield gr.Row(visible=True), "<B><FONT SIZE=3>Please wait while the Loras are being downloaded</B></FONT>", *[gr.Column(visible=False)] * 2
|
||||
yield gr.Row(visible=True), "<B><FONT SIZE=3>Please wait while the Loras are being downloaded</B></FONT>" #, *[gr.Column(visible=False)] * 2
|
||||
lora_dir = get_lora_dir("i2v")
|
||||
log_path = os.path.join(lora_dir, "log.txt")
|
||||
if not os.path.isfile(log_path):
|
||||
@ -6198,7 +6257,7 @@ def download_loras():
|
||||
os.remove(tmp_path)
|
||||
except:
|
||||
pass
|
||||
yield gr.Row(visible=True), "<B><FONT SIZE=3>Loras have been completely downloaded</B></FONT>", *[gr.Column(visible=True)] * 2
|
||||
yield gr.Row(visible=True), "<B><FONT SIZE=3>Loras have been completely downloaded</B></FONT>" #, *[gr.Column(visible=True)] * 2
|
||||
|
||||
from datetime import datetime
|
||||
dt = datetime.today().strftime('%Y-%m-%d')
|
||||
@ -6358,16 +6417,16 @@ def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_
|
||||
vace= test_vace_module(state["model_type"])
|
||||
return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )
|
||||
|
||||
def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask):
|
||||
def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask, image_mode):
|
||||
video_prompt_type = del_in_sequence(video_prompt_type, "XYZWNA")
|
||||
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask)
|
||||
visible= "A" in video_prompt_type
|
||||
model_type = state["model_type"]
|
||||
model_def = get_model_def(model_type)
|
||||
image_outputs = model_def.get("image_outputs", False)
|
||||
image_outputs = image_mode == 1
|
||||
return video_prompt_type, gr.update(visible= visible and not image_outputs), gr.update(visible= visible and image_outputs), gr.update(visible= visible )
|
||||
|
||||
def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide):
|
||||
def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode):
|
||||
video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMGUV")
|
||||
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
|
||||
visible = "V" in video_prompt_type
|
||||
@ -6375,10 +6434,9 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt
|
||||
base_model_type = get_base_model_type(model_type)
|
||||
mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
|
||||
model_def = get_model_def(model_type)
|
||||
image_outputs = model_def.get("image_outputs", False)
|
||||
|
||||
image_outputs = image_mode == 1
|
||||
vace= test_vace_module(model_type)
|
||||
return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = visible and not image_outputs), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)
|
||||
return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = visible and not image_outputs), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)
|
||||
|
||||
# def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
|
||||
# video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
|
||||
@ -6694,17 +6752,22 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
any_start_image = False
|
||||
any_end_image = False
|
||||
any_reference_image = False
|
||||
v2i_switch_supported = (vace or t2v) and not image_outputs
|
||||
image_mode_value = ui_defaults.get("image_mode", 1 if image_outputs else 0 )
|
||||
if not v2i_switch_supported and not image_outputs:
|
||||
image_mode_value = 0
|
||||
else:
|
||||
image_outputs = image_mode_value == 1
|
||||
image_mode = gr.Number(value =image_mode_value, visible = False)
|
||||
|
||||
# with gr.Tabs(visible = vace or t2v):
|
||||
# with gr.Tab("Text 2 Video"):
|
||||
# pass
|
||||
# with gr.Tab("Text 2 Image"):
|
||||
# pass
|
||||
with gr.Tabs(visible = v2i_switch_supported, selected= "t2i" if image_mode_value == 1 else "t2v" ) as image_mode_tabs:
|
||||
with gr.Tab("Text to Video", id = "t2v", elem_classes="compact_tab"):
|
||||
pass
|
||||
with gr.Tab("Text to Image", id = "t2i", elem_classes="compact_tab"):
|
||||
pass
|
||||
|
||||
# image_mode = gr.Number(value =ui_defaults.get("image_mode",0), visible = False)
|
||||
image_mode = gr.Number(value =1 if image_outputs else 0, visible = False)
|
||||
|
||||
with gr.Column(visible= test_class_i2v(model_type) or diffusion_forcing or ltxv or recammaster or vace) as image_prompt_column:
|
||||
with gr.Column(visible= test_class_i2v(model_type) or hunyuan_i2v or diffusion_forcing or ltxv or recammaster or vace) as image_prompt_column:
|
||||
if vace:
|
||||
image_prompt_type_value= ui_defaults.get("image_prompt_type","")
|
||||
image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value
|
||||
@ -6770,7 +6833,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
)
|
||||
keep_frames_video_source = gr.Text(visible=False)
|
||||
else:
|
||||
if test_class_i2v(model_type):
|
||||
if test_class_i2v(model_type) or hunyuan_i2v:
|
||||
# image_prompt_type_value= ui_defaults.get("image_prompt_type","SE" if flf2v else "S" )
|
||||
image_prompt_type_value= ui_defaults.get("image_prompt_type","S" )
|
||||
image_prompt_type_choices = [("Use only a Start Image", "S")]
|
||||
@ -6959,7 +7022,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
choices=[
|
||||
("Keep Backgrounds behind People / Objects", 0),
|
||||
("Remove Backgrounds behind People / Objects", 1),
|
||||
# ("Keep it for first Image (landscape) and remove it for other Images (objects / people)", 2),
|
||||
],
|
||||
value=ui_defaults.get("remove_background_images_ref",1),
|
||||
label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar and not flux
|
||||
@ -7271,6 +7333,23 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
)
|
||||
with gr.Row():
|
||||
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)", visible = not (hunyuan_i2v or hunyuan_t2v or hunyuan_video_avatar or hunyuan_i2v or hunyuan_video_custom ))
|
||||
|
||||
with gr.Column(visible = vace and image_outputs) as min_frames_if_references_col:
|
||||
gr.Markdown("<B>If using Reference Images, generating a single Frame alone may not be sufficient to preserve Identity")
|
||||
min_frames_if_references = gr.Dropdown(
|
||||
choices=[
|
||||
("Disabled, generate only one Frame", 1),
|
||||
("Generate a 5 Frames long Video but keep only the First Frame (x1.5 slower)",5),
|
||||
("Generate a 9 Frames long Video but keep only the First Frame (x2.0 slower)",9),
|
||||
("Generate a 13 Frames long Video but keep only the First Frame (x2.5 slower)",13),
|
||||
("Generate a 17 Frames long Video but keep only the First Frame (x3.0 slower)",17),
|
||||
],
|
||||
value=ui_defaults.get("min_frames_if_references",5),
|
||||
visible=True,
|
||||
scale = 1,
|
||||
label="Generate more frames to preserve Reference Image Identity"
|
||||
)
|
||||
|
||||
with gr.Tab("Sliding Window", visible= sliding_window_enabled) as sliding_window_tab:
|
||||
|
||||
with gr.Column():
|
||||
@ -7279,21 +7358,25 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
if diffusion_forcing:
|
||||
sliding_window_size = gr.Slider(37, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=20, label=" (recommended to keep it at 97)")
|
||||
sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
|
||||
sliding_window_color_correction_strength = gr.Slider(0, 1, visible=False, value =0)
|
||||
sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = True)
|
||||
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
|
||||
elif ltxv:
|
||||
sliding_window_size = gr.Slider(41, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=8, label="Sliding Window Size")
|
||||
sliding_window_overlap = gr.Slider(9, 97, value=ui_defaults.get("sliding_window_overlap",9), step=8, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
|
||||
sliding_window_color_correction_strength = gr.Slider(0, 1, visible=False, value =0)
|
||||
sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False)
|
||||
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=8, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
|
||||
elif hunyuan_video_custom_edit:
|
||||
sliding_window_size = gr.Slider(5, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size")
|
||||
sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
|
||||
sliding_window_color_correction_strength = gr.Slider(0, 1, visible=False, value =0)
|
||||
sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False)
|
||||
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
|
||||
else: # Vace, Multitalk
|
||||
sliding_window_size = gr.Slider(5, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size")
|
||||
sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
|
||||
sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",1), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)")
|
||||
sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20 if vace else 0), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = vace)
|
||||
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
|
||||
|
||||
@ -7382,18 +7465,18 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
default_visibility = {} if update_form else {"visible" : False}
|
||||
video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info())
|
||||
with gr.Row(**default_visibility) as video_buttons_row:
|
||||
video_info_extract_settings_btn = gr.Button("Extract Settings", size ="sm")
|
||||
video_info_to_control_video_btn = gr.Button("To Control Video", size ="sm", visible = any_control_video )
|
||||
video_info_to_video_source_btn = gr.Button("To Video Source", size ="sm", visible = any_video_source)
|
||||
video_info_eject_video_btn = gr.Button("Eject Video", size ="sm")
|
||||
video_info_extract_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm")
|
||||
video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video )
|
||||
video_info_to_video_source_btn = gr.Button("To Video Source", min_width= 1, size ="sm", visible = any_video_source)
|
||||
video_info_eject_video_btn = gr.Button("Eject Video", min_width= 1, size ="sm")
|
||||
with gr.Row(**default_visibility) as image_buttons_row:
|
||||
video_info_extract_image_settings_btn = gr.Button("Extract Settings", size ="sm")
|
||||
video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", visible = any_start_image )
|
||||
video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", visible = any_end_image)
|
||||
video_info_to_image_guide_btn = gr.Button("To Control Image", size ="sm", visible = any_control_image )
|
||||
video_info_to_image_mask_btn = gr.Button("To Mask Image", size ="sm", visible = any_image_mask)
|
||||
video_info_to_reference_image_btn = gr.Button("To Reference Image", size ="sm", visible = any_reference_image)
|
||||
video_info_eject_image_btn = gr.Button("Eject Image", size ="sm")
|
||||
video_info_extract_image_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm")
|
||||
video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", min_width= 1, visible = any_start_image )
|
||||
video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", min_width= 1, visible = any_end_image)
|
||||
video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image )
|
||||
video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask)
|
||||
video_info_to_reference_image_btn = gr.Button("To Reference Image", min_width= 1, size ="sm", visible = any_reference_image)
|
||||
video_info_eject_image_btn = gr.Button("Eject Image", min_width= 1, size ="sm")
|
||||
with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab:
|
||||
with gr.Group(elem_classes= "postprocess"):
|
||||
with gr.Column():
|
||||
@ -7462,7 +7545,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right,
|
||||
video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row,
|
||||
video_buttons_row, image_buttons_row, video_postprocessing_tab, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn,
|
||||
NAG_col, speakers_locations_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row] # presets_column,
|
||||
NAG_col, speakers_locations_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, min_frames_if_references_col] # presets_column,
|
||||
if update_form:
|
||||
locals_dict = locals()
|
||||
gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
|
||||
@ -7481,8 +7564,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] )
|
||||
# video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
|
||||
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions, video_guide_outpainting_col])
|
||||
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
|
||||
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
|
||||
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
|
||||
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
|
||||
multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt])
|
||||
video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
|
||||
video_guide_outpainting_bottom.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_bottom,gr.State(1)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
|
||||
@ -7588,6 +7671,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
js=trigger_settings_download_js
|
||||
)
|
||||
|
||||
image_mode_tabs.select(fn=record_image_mode_tab, inputs=[state], outputs= None
|
||||
).then(fn=validate_wizard_prompt,
|
||||
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
|
||||
outputs= [prompt]
|
||||
).then(fn=save_inputs,
|
||||
inputs =[target_state] + gen_inputs,
|
||||
outputs= None
|
||||
).then(fn=switch_image_mode, inputs =[state] , outputs= [refresh_form_trigger], trigger_mode="multiple")
|
||||
|
||||
settings_file.upload(fn=validate_wizard_prompt,
|
||||
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
|
||||
@ -7779,7 +7870,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
||||
outputs=[modal_container]
|
||||
)
|
||||
|
||||
return ( state, loras_choices, lset_name, state, resolution,
|
||||
return ( state, loras_choices, lset_name, resolution,
|
||||
video_guide, image_guide, video_mask, image_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col
|
||||
)
|
||||
|
||||
@ -7874,6 +7965,15 @@ def generate_configuration_tab(state, blocks, header, model_choice, resolution,
|
||||
label="Keep Previously Generated Videos when starting a new Generation Batch"
|
||||
)
|
||||
|
||||
display_stats_choice = gr.Dropdown(
|
||||
choices=[
|
||||
("Disabled", 0),
|
||||
("Enabled", 1),
|
||||
],
|
||||
value=server_config.get("display_stats", 0),
|
||||
label="Display in real time available RAM / VRAM and other stats (needs a restart)"
|
||||
)
|
||||
|
||||
max_frames_multiplier_choice = gr.Dropdown(
|
||||
choices=[
|
||||
("Default", 1),
|
||||
@ -8075,6 +8175,7 @@ def generate_configuration_tab(state, blocks, header, model_choice, resolution,
|
||||
notification_sound_enabled_choice,
|
||||
notification_sound_volume_choice,
|
||||
max_frames_multiplier_choice,
|
||||
display_stats_choice,
|
||||
resolution,
|
||||
],
|
||||
outputs= [msg , header, model_choice, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col]
|
||||
@ -8085,20 +8186,24 @@ def generate_about_tab():
|
||||
gr.Markdown("Original Wan 2.1 Model by <B>Alibaba</B> (<A HREF='https://github.com/Wan-Video/Wan2.1'>GitHub</A>)")
|
||||
gr.Markdown("Many thanks to:")
|
||||
gr.Markdown("- <B>Alibaba Wan team for the best open source video generator")
|
||||
gr.Markdown("- <B>Alibaba Vace and Fun Teams for their incredible control net models")
|
||||
gr.Markdown("- <B>Alibaba Vace, Multitalk and Fun Teams for their incredible control net models")
|
||||
gr.Markdown("- <B>Tencent for the impressive Hunyuan Video models")
|
||||
gr.Markdown("- <B>Lightricks for the super fast LTX Video models")
|
||||
gr.Markdown("- <B>Blackforest Labs for the innovative Flux image generators")
|
||||
gr.Markdown("- <B>Lightricks for their super fast LTX Video models")
|
||||
gr.Markdown("<BR>Huge acknowlegments to these great open source projects used in WanGP:")
|
||||
gr.Markdown("- <B>Rife</B>: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
|
||||
gr.Markdown("- <B>DwPose</B>: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
|
||||
gr.Markdown("- <B>DepthAnything</B> & <B>Midas</B>: Depth extractors (https://github.com/DepthAnything/Depth-Anything-V2) and (https://github.com/isl-org/MiDaS")
|
||||
gr.Markdown("- <B>Matanyone</B> and <B>SAM2</B>: Mask Generation (https://github.com/pq-yang/MatAnyone) and (https://github.com/facebookresearch/sam2)")
|
||||
gr.Markdown("- <B>Pyannote</B>: speaker diarization (https://github.com/pyannote/pyannote-audio)")
|
||||
|
||||
gr.Markdown("<BR>Special thanks to the following people for their support:")
|
||||
gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
|
||||
gr.Markdown("- <B>Tophness</B> : created (former) multi tabs and queuing frameworks")
|
||||
gr.Markdown("- <B>AmericanPresidentJimmyCarter</B> : added original support for Skip Layer Guidance")
|
||||
gr.Markdown("- <B>Remade_AI</B> : for their awesome Loras collection")
|
||||
gr.Markdown("- <B>Reevoy24</B> : for his repackaging / completing the documentation")
|
||||
gr.Markdown("<BR>Huge acknowlegments to these great open source projects used in WanGP:")
|
||||
gr.Markdown("- <B>Rife</B>: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
|
||||
gr.Markdown("- <B>DwPose</B>: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
|
||||
gr.Markdown("- <B>Midas</B>: Depth extractor (https://github.com/isl-org/MiDaS")
|
||||
gr.Markdown("- <B>Matanyone</B> and <B>SAM2</B>: Mask Generation (https://github.com/pq-yang/MatAnyone) and (https://github.com/facebookresearch/sam2)")
|
||||
|
||||
gr.Markdown("- <B>Redtash1</B> : for designing the protype of the RAM /VRAM stats viewer")
|
||||
|
||||
def generate_info_tab():
|
||||
|
||||
@ -8541,7 +8646,9 @@ def create_ui():
|
||||
z-index: 9999;
|
||||
transition: visibility 0s linear 1s, opacity 0.3s linear 1s; /* Delay both properties */
|
||||
}
|
||||
|
||||
div.compact_tab , span.compact_tab
|
||||
{ padding: 0px !important;
|
||||
}
|
||||
.hover-image .tooltip2 {
|
||||
visibility: hidden;
|
||||
opacity: 0;
|
||||
@ -8591,6 +8698,12 @@ def create_ui():
|
||||
console.log('sendColIndex function attached to window');
|
||||
}
|
||||
"""
|
||||
if server_config.get("display_stats", 0) == 1:
|
||||
from wan.utils.stats import SystemStatsApp
|
||||
stats_app = SystemStatsApp()
|
||||
else:
|
||||
stats_app = None
|
||||
|
||||
with gr.Blocks(css=css, js=js, theme=theme, title= "WanGP") as main:
|
||||
gr.Markdown(f"<div align=center><H1>Wan<SUP>GP</SUP> v{WanGP_version} <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
|
||||
global model_list
|
||||
@ -8609,8 +8722,11 @@ def create_ui():
|
||||
gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
|
||||
with gr.Row():
|
||||
header = gr.Markdown(generate_header(transformer_type, compile, attention_mode), visible= True)
|
||||
if stats_app is not None:
|
||||
stats_element = stats_app.get_gradio_element()
|
||||
|
||||
with gr.Row():
|
||||
( state, loras_choices, lset_name, state, resolution,
|
||||
( state, loras_choices, lset_name, resolution,
|
||||
video_guide, image_guide, video_mask, image_mask, image_refs, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col
|
||||
) = generate_video_tab(model_choice=model_choice, header=header, main = main)
|
||||
with gr.Tab("Guides", id="info") as info_tab:
|
||||
@ -8624,7 +8740,8 @@ def create_ui():
|
||||
generate_configuration_tab(state, main, header, model_choice, resolution, prompt_enhancer_row, mmaudio_tab, PP_MMAudio_col)
|
||||
with gr.Tab("About"):
|
||||
generate_about_tab()
|
||||
|
||||
if stats_app is not None:
|
||||
stats_app.setup_events(main, state)
|
||||
main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= main_tabs, trigger_mode="multiple")
|
||||
return main
|
||||
|
||||
|
||||