From febeb957674e1bcc409aa53c910f96edfc46cf3d Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 17 Jun 2025 23:45:47 +0200
Subject: [PATCH] Vace powercharged

---
 README.md                                  |   12 +
 docs/FINETUNES.md                          |    4 +-
 finetunes/vace_14B_fusionix.json           |   28 +-
 hyvideo/hunyuan.py                         |    6 +-
 ltx_video/ltxv.py                          |    2 -
 preprocessing/dwpose/util.py               |    2 +-
 preprocessing/matanyone/app.py             |   34 +-
 requirements.txt                           |    2 +-
 wan/diffusion_forcing.py                   |    5 +-
 wan/image2video.py                         |    5 +-
 wan/modules/motion_patch.py                |  150 +++
 wan/text2video.py                          |   48 +-
 wan/trajectory_editor/app.py               |  209 +++
 wan/trajectory_editor/templates/index.html |  571 +++++++++
 wan/utils/motion.py                        |   74 ++
 wan/utils/utils.py                         |   26 +-
 wan/utils/vace_preprocessor.py             |   22 +-
 wgp.py                                     | 1350 ++++++++++++--------
 18 files changed, 1945 insertions(+), 605 deletions(-)
 create mode 100644 wan/modules/motion_patch.py
 create mode 100644 wan/trajectory_editor/app.py
 create mode 100644 wan/trajectory_editor/templates/index.html
 create mode 100644 wan/utils/motion.py

diff --git a/README.md b/README.md
index 61bd8b0..91cc3ed 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,18 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep

 ## πŸ”₯ Latest Updates
+### June 17 2025: WanGP v6.1, Vace Powercharged
+Lots of improvements for Vace, the Mother of all Models:
+- masks can now be combined with on the fly processing of a control video, for instance you can extract the motion of a specific person defined by a mask
+- on the fly modification of masks: reversed masks (with the same mask you can modify the background instead of the people covered by the masks), enlarged masks (you can cover more area if for instance the person you are trying to inject is larger than the one in the mask), ...
+- view these modified masks directly inside WanGP during the video generation to check that they are really as expected
+- multiple frame injections: multiple frames can be injected at any location of the video
+- expand past videos in one click: just select one generated video to expand it
+
+Of course, all this new stuff works on all Vace finetunes (including Vace FusioniX).
+
+Thanks also to Reevoy24 for adding a Notification sound at the end of a generation and for fixing the background color of the current generation summary.
+
 ### June 12 2025: WanGP v6.0
 πŸ‘‹ *Finetune models*: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add by yourself the support for this model in WanGP by just creating a finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them.
diff --git a/docs/FINETUNES.md b/docs/FINETUNES.md
index d8ed957..28823f4 100644
--- a/docs/FINETUNES.md
+++ b/docs/FINETUNES.md
@@ -75,10 +75,10 @@ If you launch the app with the *--save-quantized* switch, WanGP will create a qu
 2) Launch WanGP *python wgp.py --save-quantized*
 3) In the configuration menu *Transformer Data Type* property choose either *BF16* or *FP16*
 4) Launch a video generation (settings used do not matter). As soon as the model is loaded, a new quantized model will be created in the **ckpts** subfolder if it doesn't already exist.
-5) To test that this works properly set the local path in the "URLs" key of the finetune definition file. For instance *URLs = ["ckpts/finetune_quanto_fp16_int8.safetensors"]*
+5) WanGP will automatically update the finetune definition file with the local path of the newly created quantized file (the list "URLs" will have an extra value such as *"ckpts/finetune_quanto_fp16_int8.safetensors"*)
 6) Remove *--save-quantized*, restart WanGP and select *Scaled Int8 Quantization* in the *Transformer Model Quantization* property
 7) Launch a new generation and verify in the terminal window that the right quantized model is loaded
-8) In order to share the finetune definition file you will need to store the fine model weights in the cloud. You can upload them for instance on *Huggingface*. You can now replace in the definition file the local path by a URL (on Huggingface to get the URL of the model file click *Copy download link* when accessing the model properties)
+8) In order to share the finetune definition file you will need to store the finetuned model weights in the cloud. You can upload them for instance on *Huggingface*. You can now replace in the finetune definition file the local path by a URL (on Huggingface to get the URL of the model file click *Copy download link* when accessing the model properties)
 
 You need to create a quantized model specifically for *bf16* or *fp16* as they can not be converted on the fly. However there is no need for a non quantized model as they can be converted on the fly while being loaded.
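The finetunes/vace_14B_fusionix.json diff just below is a concrete instance of step 5: the "URLs" list gains a local ckpts/ entry next to the Huggingface URLs. A minimal sketch of that bookkeeping, with a hypothetical helper name (only the "model"/"URLs" fields and the file paths come from this patch):

```python
import json

def register_quantized_checkpoint(definition_path, quantized_path):
    """Append a locally saved quantized checkpoint to a finetune definition's "URLs" list."""
    with open(definition_path, "r", encoding="utf-8") as f:
        definition = json.load(f)
    urls = definition["model"].setdefault("URLs", [])
    if quantized_path not in urls:  # keep the list free of duplicates on repeated runs
        urls.append(quantized_path)
    with open(definition_path, "w", encoding="utf-8") as f:
        json.dump(definition, f, indent=4)

# Paths as they appear in this patch:
# register_quantized_checkpoint("finetunes/vace_14B_fusionix.json",
#                               "ckpts/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors")
```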
diff --git a/finetunes/vace_14B_fusionix.json b/finetunes/vace_14B_fusionix.json
index 3d1ee49..ba35e96 100644
--- a/finetunes/vace_14B_fusionix.json
+++ b/finetunes/vace_14B_fusionix.json
@@ -1,17 +1,19 @@
 {
-    "model":
-    {
-        "name": "Vace FusioniX 14B",
-        "architecture" : "vace_14B",
-        "modules" : ["vace_14B"],
-        "description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
-        "URLs": [
-            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
-            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
-            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors"
-        ],
-        "auto_quantize": true
-    },
+    "model": {
+        "name": "Vace FusioniX 14B",
+        "architecture": "vace_14B",
+        "modules": [
+            "vace_14B"
+        ],
+        "description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
+        "URLs": [
+            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_fp16.safetensors",
+            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_bf16_int8.safetensors",
+            "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors",
+            "ckpts/Wan14BT2VFusioniX_quanto_fp16_int8.safetensors"
+        ],
+        "auto_quantize": true
+    },
     "negative_prompt": "",
     "prompt": "",
     "resolution": "832x480",
diff --git a/hyvideo/hunyuan.py b/hyvideo/hunyuan.py
index ef35894..83d94ea 100644
--- a/hyvideo/hunyuan.py
+++ b/hyvideo/hunyuan.py
@@ -315,7 +315,7 @@ class Inference(object):
 
     @classmethod
-    def from_pretrained(cls, model_filepath, base_model_type, text_encoder_filepath, dtype = torch.bfloat16, 
VAE_dtype = torch.float16, mixed_precision_transformer =torch.bfloat16 , quantizeTransformer = False, save_quantized = False, **kwargs): + def from_pretrained(cls, model_filepath, model_type, base_model_type, text_encoder_filepath, dtype = torch.bfloat16, VAE_dtype = torch.float16, mixed_precision_transformer =torch.bfloat16 , quantizeTransformer = False, save_quantized = False, **kwargs): device = "cuda" @@ -392,8 +392,8 @@ class Inference(object): # offload.save_model(model, "hunyuan_video_avatar_edit_720_bf16.safetensors") # offload.save_model(model, "hunyuan_video_avatar_edit_720_quanto_bf16_int8.safetensors", do_quantize= True) if save_quantized: - from wan.utils.utils import save_quantized_model - save_quantized_model(model, filepath, dtype, None) + from wgp import save_quantized_model + save_quantized_model(model, model_type, filepath, dtype, None) model.mixed_precision = mixed_precision_transformer diff --git a/ltx_video/ltxv.py b/ltx_video/ltxv.py index 62bb190..7b1c6fe 100644 --- a/ltx_video/ltxv.py +++ b/ltx_video/ltxv.py @@ -287,13 +287,11 @@ class LTXV: height, width = input_video.shape[-2:] else: if image_start != None: - image_start = image_start[0] frame_width, frame_height = image_start.size height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32) conditioning_media_paths.append(image_start) conditioning_start_frames.append(0) if image_end != None: - image_end = image_end[0] conditioning_media_paths.append(image_end) conditioning_start_frames.append(frame_num-1) diff --git a/preprocessing/dwpose/util.py b/preprocessing/dwpose/util.py index 48e2ba6..f7c62b0 100644 --- a/preprocessing/dwpose/util.py +++ b/preprocessing/dwpose/util.py @@ -4,7 +4,7 @@ import math import numpy as np import matplotlib import cv2 -matplotlib.use('TkAgg') +# matplotlib.use('TkAgg') eps = 0.01 diff --git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py index c367101..49bb649 100644 --- a/preprocessing/matanyone/app.py +++ b/preprocessing/matanyone/app.py @@ -354,21 +354,29 @@ def video_matting(video_state, end_slider, matting_type, interactive_state, mask foreground, alpha = matanyone(matanyone_processor, following_frames, template_mask*255, r_erode=erode_kernel_size, r_dilate=dilate_kernel_size) output_frames = [] foreground_mat = matting_type == "Foreground" - for frame_origin, frame_alpha in zip(following_frames, alpha): - if foreground_mat: - frame_alpha[frame_alpha > 127] = 255 - frame_alpha[frame_alpha <= 127] = 0 - else: + if not foreground_mat: + new_alpha = [] + for frame_alpha in alpha: frame_temp = frame_alpha.copy() frame_alpha[frame_temp > 127] = 0 frame_alpha[frame_temp <= 127] = 255 + new_alpha.append(frame_alpha) + alpha = new_alpha + # for frame_origin, frame_alpha in zip(following_frames, alpha): + # if foreground_mat: + # frame_alpha[frame_alpha > 127] = 255 + # frame_alpha[frame_alpha <= 127] = 0 + # else: + # frame_temp = frame_alpha.copy() + # frame_alpha[frame_temp > 127] = 0 + # frame_alpha[frame_temp <= 127] = 255 - output_frame = np.bitwise_and(frame_origin, 255-frame_alpha) - frame_grey = frame_alpha.copy() - frame_grey[frame_alpha == 255] = 127 - output_frame += frame_grey - output_frames.append(output_frame) - foreground = output_frames + # output_frame = np.bitwise_and(frame_origin, 255-frame_alpha) + # frame_grey = frame_alpha.copy() + # frame_grey[frame_alpha == 255] = 127 + # output_frame += frame_grey + # output_frames.append(output_frame) + foreground = following_frames if not 
os.path.exists("mask_outputs"): os.makedirs("mask_outputs") @@ -465,6 +473,7 @@ def load_unload_models(selected): global model global matanyone_model if selected: + # print("Matanyone Tab Selected") if model_loaded: model.samcontroler.sam_controler.model.to(arg_device) matanyone_model.to(arg_device) @@ -494,6 +503,7 @@ def load_unload_models(selected): matanyone_processor = InferenceCore(matanyone_model, cfg=matanyone_model.cfg) model_loaded = True else: + # print("Matanyone Tab UnSelected") import gc model.samcontroler.sam_controler.model.to("cpu") matanyone_model.to("cpu") @@ -520,7 +530,7 @@ def export_image(image_refs, image_output): def export_to_current_video_engine(model_type, foreground_video_output, alpha_video_output): gr.Info("Masked Video Input and Full Mask transferred to Current Video Engine For Inpainting") # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output - if "custom_edit" in model_type: + if "custom_edit" in model_type and False: return gr.update(), alpha_video_output else: return foreground_video_output, alpha_video_output diff --git a/requirements.txt b/requirements.txt index b4868d1..1d82bc9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,7 +29,7 @@ timm segment-anything omegaconf hydra-core -librosa +librosa==0.11.0 loguru sentencepiece av diff --git a/wan/diffusion_forcing.py b/wan/diffusion_forcing.py index b1869a8..3c6f701 100644 --- a/wan/diffusion_forcing.py +++ b/wan/diffusion_forcing.py @@ -29,6 +29,7 @@ class DTT2V: checkpoint_dir, rank=0, model_filename = None, + model_type = None, base_model_type = None, save_quantized = False, text_encoder_filename = None, @@ -77,8 +78,8 @@ class DTT2V: self.model.eval().requires_grad_(False) if save_quantized: - from wan.utils.utils import save_quantized_model - save_quantized_model(self.model, model_filename[0], dtype, base_config_file) + from wgp import save_quantized_model + save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) self.scheduler = FlowUniPCMultistepScheduler() diff --git a/wan/image2video.py b/wan/image2video.py index dd11c77..51656e4 100644 --- a/wan/image2video.py +++ b/wan/image2video.py @@ -49,6 +49,7 @@ class WanI2V: config, checkpoint_dir, model_filename = None, + model_type = None, base_model_type= None, text_encoder_filename= None, quantizeTransformer = False, @@ -115,8 +116,8 @@ class WanI2V: # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors") self.model.eval().requires_grad_(False) if save_quantized: - from wan.utils.utils import save_quantized_model - save_quantized_model(self.model, model_filename[0], dtype, base_config_file) + from wgp import save_quantized_model + save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) self.sample_neg_prompt = config.sample_neg_prompt diff --git a/wan/modules/motion_patch.py b/wan/modules/motion_patch.py new file mode 100644 index 0000000..8c970ad --- /dev/null +++ b/wan/modules/motion_patch.py @@ -0,0 +1,150 @@ +# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union +import torch + + +# Refer to https://github.com/Angtian/VoGE/blob/main/VoGE/Utils.py +def ind_sel(target: torch.Tensor, ind: torch.Tensor, dim: int = 1): + """ + :param target: [... (can be k or 1), n > M, ...] + :param ind: [... (k), M] + :param dim: dim to apply index on + :return: sel_target [... (k), M, ...] + """ + assert ( + len(ind.shape) > dim + ), "Index must have the target dim, but get dim: %d, ind shape: %s" % (dim, str(ind.shape)) + + target = target.expand( + *tuple( + [ind.shape[k] if target.shape[k] == 1 else -1 for k in range(dim)] + + [ + -1, + ] + * (len(target.shape) - dim) + ) + ) + + ind_pad = ind + + if len(target.shape) > dim + 1: + for _ in range(len(target.shape) - (dim + 1)): + ind_pad = ind_pad.unsqueeze(-1) + ind_pad = ind_pad.expand(*(-1,) * (dim + 1), *target.shape[(dim + 1) : :]) + + return torch.gather(target, dim=dim, index=ind_pad) + + +def merge_final(vert_attr: torch.Tensor, weight: torch.Tensor, vert_assign: torch.Tensor): + """ + + :param vert_attr: [n, d] or [b, n, d] color or feature of each vertex + :param weight: [b(optional), w, h, M] weight of selected vertices + :param vert_assign: [b(optional), w, h, M] selective index + :return: + """ + target_dim = len(vert_assign.shape) - 1 + if len(vert_attr.shape) == 2: + assert vert_attr.shape[0] > vert_assign.max() + # [n, d] ind: [b(optional), w, h, M]-> [b(optional), w, h, M, d] + sel_attr = ind_sel( + vert_attr[(None,) * target_dim], vert_assign.type(torch.long), dim=target_dim + ) + else: + assert vert_attr.shape[1] > vert_assign.max() + sel_attr = ind_sel( + vert_attr[(slice(None),) + (None,)*(target_dim-1)], vert_assign.type(torch.long), dim=target_dim + ) + + # [b(optional), w, h, M] + final_attr = torch.sum(sel_attr * weight.unsqueeze(-1), dim=-2) + return final_attr + + +def patch_motion( + tracks: torch.FloatTensor, # (B, T, N, 4) + vid: torch.FloatTensor, # (C, T, H, W) + temperature: float = 220.0, + training: bool = True, + tail_dropout: float = 0.2, + vae_divide: tuple = (4, 16), + topk: int = 2, +): + with torch.no_grad(): + _, T, H, W = vid.shape + N = tracks.shape[2] + _, tracks, visible = torch.split( + tracks, [1, 2, 1], dim=-1 + ) # (B, T, N, 2) | (B, T, N, 1) + tracks_n = tracks / torch.tensor([W / min(H, W), H / min(H, W)], device=tracks.device) + tracks_n = tracks_n.clamp(-1, 1) + visible = visible.clamp(0, 1) + + if tail_dropout > 0 and training: + TT = visible.shape[1] + rrange = torch.arange(TT, device=visible.device, dtype=visible.dtype)[ + None, :, None, None + ] + rand_nn = torch.rand_like(visible[:, :1]) + rand_rr = torch.rand_like(visible[:, :1]) * (TT - 1) + visible = visible * ( + (rand_nn > tail_dropout).type_as(visible) + + (rrange < rand_rr).type_as(visible) + ).clamp(0, 1) + + xx = torch.linspace(-W / min(H, W), W / min(H, W), W) + yy = torch.linspace(-H / min(H, W), H / min(H, W), H) + + grid = torch.stack(torch.meshgrid(yy, xx, indexing="ij")[::-1], dim=-1).to( + tracks.device + ) + + tracks_pad = tracks[:, 1:] + visible_pad = visible[:, 1:] + + visible_align = visible_pad.view(T - 1, 4, *visible_pad.shape[2:]).sum(1) + tracks_align = (tracks_pad * visible_pad).view(T - 1, 4, *tracks_pad.shape[2:]).sum( + 1 + ) / (visible_align + 1e-5) + dist_ = ( + (tracks_align[:, None, None] - grid[None, :, :, None]).pow(2).sum(-1) + ) # T, H, W, N + weight = torch.exp(-dist_ * temperature) * visible_align.clamp(0, 1).view( 
+ T - 1, 1, 1, N + ) + vert_weight, vert_index = torch.topk( + weight, k=min(topk, weight.shape[-1]), dim=-1 + ) + + grid_mode = "bilinear" + point_feature = torch.nn.functional.grid_sample( + vid[vae_divide[0]:].permute(1, 0, 2, 3)[:1], + tracks_n[:, :1].type(vid.dtype), + mode=grid_mode, + padding_mode="zeros", + align_corners=None, + ) + point_feature = point_feature.squeeze(0).squeeze(1).permute(1, 0) # N, C=16 + + out_feature = merge_final(point_feature, vert_weight, vert_index).permute(3, 0, 1, 2) # T - 1, H, W, C => C, T - 1, H, W + out_weight = vert_weight.sum(-1) # T - 1, H, W + + # out feature -> already soft weighted + mix_feature = out_feature + vid[vae_divide[0]:, 1:] * (1 - out_weight.clamp(0, 1)) + + out_feature_full = torch.cat([vid[vae_divide[0]:, :1], mix_feature], dim=1) # C, T, H, W + out_mask_full = torch.cat([torch.ones_like(out_weight[:1]), out_weight], dim=0) # T, H, W + return torch.cat([out_mask_full[None].expand(vae_divide[0], -1, -1, -1), out_feature_full], dim=0) diff --git a/wan/text2video.py b/wan/text2video.py index 7bda44e..cdfcae3 100644 --- a/wan/text2video.py +++ b/wan/text2video.py @@ -50,6 +50,7 @@ class WanT2V: checkpoint_dir, rank=0, model_filename = None, + model_type = None, base_model_type = None, text_encoder_filename = None, quantizeTransformer = False, @@ -99,9 +100,9 @@ class WanT2V: # offload.save_model(self.model, "wan2.1_text2video_14B_mbf16.safetensors", config_file_path=base_config_file) # offload.save_model(self.model, "wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", do_quantize=True, config_file_path=base_config_file) self.model.eval().requires_grad_(False) - if save_quantized: - from wan.utils.utils import save_quantized_model - save_quantized_model(self.model, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file) + if save_quantized: + from wgp import save_quantized_model + save_quantized_model(self.model, model_type, model_filename[1 if base_model_type=="fantasy" else 0], dtype, base_config_file) self.sample_neg_prompt = config.sample_neg_prompt @@ -186,7 +187,25 @@ class WanT2V: def vace_latent(self, z, m): return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = True, pre_src_video = None): + def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device): + ref_width, ref_height = ref_img.size + if (ref_height, ref_width) == image_size: + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + else: + canvas_height, canvas_width = image_size + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + white_canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1] + ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + white_canvas[:, :, top:top + new_height, left:left + new_width] = ref_img + ref_img = white_canvas + return ref_img.to(device) + + def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = []): image_sizes = [] 
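+        # inject_frames (new argument): optional per-position frames; each non-empty entry is letterboxed onto the canvas, written into the control video at its frame index, and its mask is set to 0 so the injected frame is treated as fixed content (see the inject_frames loop below)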
trim_video = len(keep_frames) canvas_height, canvas_width = image_size @@ -234,25 +253,18 @@ class WanT2V: src_video[i][:, k:k+1] = 0 src_mask[i][:, k:k+1] = 1 + for k, frame in enumerate(inject_frames): + if frame != None: + src_video[i][:, k:k+1] = self.fit_image_into_canvas(frame, image_size, 0, device) + src_mask[i][:, k:k+1] = 0 + + for i, ref_images in enumerate(src_ref_images): if ref_images is not None: image_size = image_sizes[i] for j, ref_img in enumerate(ref_images): if ref_img is not None: - ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) - if ref_img.shape[-2:] != image_size: - canvas_height, canvas_width = image_size - ref_height, ref_width = ref_img.shape[-2:] - white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] - scale = min(canvas_height / ref_height, canvas_width / ref_width) - new_height = int(ref_height * scale) - new_width = int(ref_width * scale) - resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) - top = (canvas_height - new_height) // 2 - left = (canvas_width - new_width) // 2 - white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image - ref_img = white_canvas - src_ref_images[i][j] = ref_img.to(device) + src_ref_images[i][j] = self.fit_image_into_canvas(ref_img, image_size, 1, device) return src_video, src_mask, src_ref_images def decode_latent(self, zs, ref_images=None, tile_size= 0 ): diff --git a/wan/trajectory_editor/app.py b/wan/trajectory_editor/app.py new file mode 100644 index 0000000..5a3962f --- /dev/null +++ b/wan/trajectory_editor/app.py @@ -0,0 +1,209 @@ +# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
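+
+# Small Flask tool for editing motion trajectories: an image is uploaded via /upload_image, the tracks drawn in the browser are posted to /store_tracks, padded to FIXED_LENGTH (121) points, stored as compressed npz bytes wrapped in a .pth file, rendered as an overlay PNG and registered in test.yaml.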
+ +import matplotlib.pyplot as plt +from flask import Flask, request, jsonify, render_template +import os +import io +import numpy as np +import torch +import yaml +import matplotlib +import argparse +matplotlib.use('Agg') + +app = Flask(__name__, static_folder='static', template_folder='templates') + + +# β€”β€”β€” Arguments β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” +parser = argparse.ArgumentParser() +parser.add_argument('--save_dir', type=str, default='videos_example') +args = parser.parse_args() + + +# β€”β€”β€” Configuration β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” +BASE_DIR = args.save_dir +STATIC_BASE = os.path.join('static', BASE_DIR) +IMAGES_DIR = os.path.join(STATIC_BASE, 'images') +OVERLAY_DIR = os.path.join(STATIC_BASE, 'images_tracks') +TRACKS_DIR = os.path.join(BASE_DIR, 'tracks') +YAML_PATH = os.path.join(BASE_DIR, 'test.yaml') +IMAGES_DIR_OUT = os.path.join(BASE_DIR, 'images') + +FIXED_LENGTH = 121 +COLOR_CYCLE = ['r', 'g', 'b', 'c', 'm', 'y', 'k'] +QUANT_MULTI = 8 + +for d in (IMAGES_DIR, TRACKS_DIR, OVERLAY_DIR, IMAGES_DIR_OUT): + os.makedirs(d, exist_ok=True) + +# β€”β€”β€” Helpers β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” + + +def array_to_npz_bytes(arr, path, compressed=True, quant_multi=QUANT_MULTI): + # pack into uint16 as before + arr_q = (quant_multi * arr).astype(np.float32) + bio = io.BytesIO() + if compressed: + np.savez_compressed(bio, array=arr_q) + else: + np.savez(bio, array=arr_q) + torch.save(bio.getvalue(), path) + + +def load_existing_tracks(path): + raw = torch.load(path) + bio = io.BytesIO(raw) + with np.load(bio) as npz: + return npz['array'] + +# β€”β€”β€” Routes β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€” + + +@app.route('/') +def index(): + return render_template('index.html') + + +@app.route('/upload_image', methods=['POST']) +def upload_image(): + f = request.files['image'] + from PIL import Image + img = Image.open(f.stream) + orig_w, orig_h = img.size + + idx = len(os.listdir(IMAGES_DIR)) + 1 + ext = f.filename.rsplit('.', 1)[-1] + fname = f"{idx:02d}.{ext}" + img.save(os.path.join(IMAGES_DIR, fname)) + img.save(os.path.join(IMAGES_DIR_OUT, fname)) + + return jsonify({ + 'image_url': f"{STATIC_BASE}/images/{fname}", + 'image_id': idx, + 'ext': ext, + 'orig_width': orig_w, + 'orig_height': orig_h + }) + + +@app.route('/store_tracks', methods=['POST']) +def store_tracks(): + data = request.get_json() + image_id = data['image_id'] + ext = data['ext'] + free_tracks = data.get('tracks', []) + circ_trajs = data.get('circle_trajectories', []) + + # Debug lengths + for i, tr in enumerate(free_tracks, 1): + print(f"Freehand Track {i}: {len(tr)} points") + for i, tr in enumerate(circ_trajs, 1): + print(f"Circle/Static Traj {i}: {len(tr)} points") + + def pad_pts(tr): + """Convert list of {x,y} to (FIXED_LENGTH,1,3) array, padding/truncating.""" + pts = np.array([[p['x'], p['y'], 1] for p in tr], dtype=np.float32) + n = pts.shape[0] + if n < FIXED_LENGTH: + pad = np.zeros((FIXED_LENGTH - n, 3), dtype=np.float32) + pts = np.vstack((pts, pad)) + else: + pts = pts[:FIXED_LENGTH] + return pts.reshape(FIXED_LENGTH, 1, 3) + + arrs = [] + + # 1) Freehand tracks + for i, tr in enumerate(free_tracks): + pts = pad_pts(tr) + arrs.append(pts,) + + # 2) Circle + Static combined + for i, tr in 
enumerate(circ_trajs): + pts = pad_pts(tr) + + arrs.append(pts) + print(arrs) + # Nothing to save? + if not arrs: + overlay_file = f"{image_id:02d}.png" + return jsonify({ + 'status': 'ok', + 'overlay_url': f"{STATIC_BASE}/images_tracks/{overlay_file}" + }) + + new_tracks = np.stack(arrs, axis=0) # (T_new, FIXED_LENGTH,1,4) + + # Load existing .pth and pad old channels to 4 if needed + track_path = os.path.join(TRACKS_DIR, f"{image_id:02d}.pth") + if os.path.exists(track_path): + # shape (T_old, FIXED_LENGTH,1,3) or (...,4) + old = load_existing_tracks(track_path) + if old.ndim == 4 and old.shape[-1] == 3: + pad = np.zeros( + (old.shape[0], old.shape[1], old.shape[2], 1), dtype=np.float32) + old = np.concatenate((old, pad), axis=-1) + all_tracks = np.concatenate([old, new_tracks], axis=0) + else: + all_tracks = new_tracks + + # Save updated track file + array_to_npz_bytes(all_tracks, track_path, compressed=True) + + # Build overlay PNG + img_path = os.path.join(IMAGES_DIR, f"{image_id:02d}.{ext}") + img = plt.imread(img_path) + fig, ax = plt.subplots(figsize=(12, 8)) + ax.imshow(img) + for t in all_tracks: + coords = t[:, 0, :] # (FIXED_LENGTH,4) + ax.plot(coords[:, 0][coords[:, 2] > 0.5], coords[:, 1] + [coords[:, 2] > 0.5], marker='o', color=COLOR_CYCLE[0]) + ax.axis('off') + overlay_file = f"{image_id:02d}.png" + fig.savefig(os.path.join(OVERLAY_DIR, overlay_file), + bbox_inches='tight', pad_inches=0) + plt.close(fig) + + # Update YAML (unchanged) + entry = { + "image": os.path.join(f"tools/trajectory_editor/{BASE_DIR}/images/{image_id:02d}.{ext}"), + "text": None, + "track": os.path.join(f"tools/trajectory_editor/{BASE_DIR}/tracks/{image_id:02d}.pth") + } + if os.path.exists(YAML_PATH): + with open(YAML_PATH) as yf: + docs = yaml.safe_load(yf) or [] + else: + docs = [] + + for e in docs: + if e.get("image", "").endswith(f"{image_id:02d}.{ext}"): + e.update(entry) + break + else: + docs.append(entry) + + with open(YAML_PATH, 'w') as yf: + yaml.dump(docs, yf, default_flow_style=False) + + return jsonify({ + 'status': 'ok', + 'overlay_url': f"{STATIC_BASE}/images_tracks/{overlay_file}" + }) + + +if __name__ == '__main__': + app.run(debug=True) diff --git a/wan/trajectory_editor/templates/index.html b/wan/trajectory_editor/templates/index.html new file mode 100644 index 0000000..2ac8c78 --- /dev/null +++ b/wan/trajectory_editor/templates/index.html @@ -0,0 +1,571 @@ + + + + + + + Track Point Editor + + + +

Track Point Editor
+ [index.html body omitted: 571 added lines of markup, styles and scripts for the Track Point Editor page]
+ + + + + diff --git a/wan/utils/motion.py b/wan/utils/motion.py new file mode 100644 index 0000000..d9f36f6 --- /dev/null +++ b/wan/utils/motion.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024-2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json +import os, io +from typing import Dict, List, Optional, Tuple, Union +import numpy as np +import torch + + +def get_tracks_inference(tracks, height, width, quant_multi: Optional[int] = 8, **kwargs): + if isinstance(tracks, str): + tracks = torch.load(tracks) + + tracks_np = unzip_to_array(tracks) + + tracks = process_tracks( + tracks_np, (width, height), quant_multi=quant_multi, **kwargs + ) + + return tracks + + +def unzip_to_array( + data: bytes, key: Union[str, List[str]] = "array" +) -> Union[np.ndarray, Dict[str, np.ndarray]]: + bytes_io = io.BytesIO(data) + + if isinstance(key, str): + # Load the NPZ data from the BytesIO object + with np.load(bytes_io) as data: + return data[key] + else: + get = {} + with np.load(bytes_io) as data: + for k in key: + get[k] = data[k] + return get + + +def process_tracks(tracks_np: np.ndarray, frame_size: Tuple[int, int], quant_multi: int = 8, **kwargs): + # tracks: shape [t, h, w, 3] => samples align with 24 fps, model trained with 16 fps. 
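+    # The 24 fps -> 16 fps alignment mentioned above happens at the end of this function: frame 0 is kept, the remaining 120 frames are duplicated to 240 via repeat_interleave and every third one is taken ([1::3] -> 80 frames), a 2/3 rate change that yields 81 output frames in total.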
+ # frame_size: tuple (W, H) + + tracks = torch.from_numpy(tracks_np).float() / quant_multi + if tracks.shape[1] == 121: + tracks = torch.permute(tracks, (1, 0, 2, 3)) + tracks, visibles = tracks[..., :2], tracks[..., 2:3] + short_edge = min(*frame_size) + + tracks = tracks - torch.tensor([*frame_size]).type_as(tracks) / 2 + tracks = tracks / short_edge * 2 + + visibles = visibles * 2 - 1 + + trange = torch.linspace(-1, 1, tracks.shape[0]).view(-1, 1, 1, 1).expand(*visibles.shape) + + out_ = torch.cat([trange, tracks, visibles], dim=-1).view(121, -1, 4) + out_0 = out_[:1] + out_l = out_[1:] # 121 => 120 | 1 + out_l = torch.repeat_interleave(out_l, 2, dim=0)[1::3] # 120 => 240 => 80 + return torch.cat([out_0, out_l], dim=0) diff --git a/wan/utils/utils.py b/wan/utils/utils.py index 134b788..052b6b4 100644 --- a/wan/utils/utils.py +++ b/wan/utils/utils.py @@ -78,8 +78,15 @@ def remove_background(img, session=None): img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0) +def save_image(tensor_image, name): + import numpy as np + tensor_image = tensor_image.clone() + tensor_image= tensor_image.add_(1).mul(127.5).squeeze(1).permute(1,2,0) + Image.fromarray(tensor_image.cpu().numpy().astype(np.uint8)).save(name) def calculate_new_dimensions(canvas_height, canvas_width, height, width, fit_into_canvas, block_size = 16): + if fit_into_canvas == None: + return height, width if fit_into_canvas: scale1 = min(canvas_height / height, canvas_width / width) scale2 = min(canvas_width / height, canvas_height / width) @@ -337,22 +344,3 @@ def create_progress_hook(filename): return progress_hook(block_num, block_size, total_size, filename) return hook -def save_quantized_model(model, model_filename, dtype, config_file): - if "quanto" in model_filename: - return - from mmgp import offload - if dtype == torch.bfloat16: - model_filename = model_filename.replace("fp16", "bf16").replace("FP16", "bf16") - elif dtype == torch.float16: - model_filename = model_filename.replace("bf16", "fp16").replace("BF16", "bf16") - - for rep in ["mfp16", "fp16", "mbf16", "bf16"]: - if "_" + rep in model_filename: - model_filename = model_filename.replace("_" + rep, "_quanto_" + rep + "_int8") - break - if not "quanto" in model_filename: - pos = model_filename.rfind(".") - model_filename = model_filename[:pos] + "_quanto_int8" + model_filename[pos+1:] - - if not os.path.isfile(model_filename): - offload.save_model(model, model_filename, do_quantize= True, config_file_path=config_file) diff --git a/wan/utils/vace_preprocessor.py b/wan/utils/vace_preprocessor.py index 0722a17..7fdb8c9 100644 --- a/wan/utils/vace_preprocessor.py +++ b/wan/utils/vace_preprocessor.py @@ -196,7 +196,7 @@ class VaceVideoProcessor(object): return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps - def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0, canvas_height = 0, canvas_width = 0, fit_into_canvas= True): + def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0, start_frame= 0, canvas_height = 0, canvas_width = 0, fit_into_canvas= None): if self.keep_last: return self._get_frameid_bbox_adjust_last(fps, video_frames_count, canvas_height, canvas_width, h, w, fit_into_canvas, crop_box, rng, max_frames= max_frames, start_frame= start_frame) else: @@ -208,23 +208,23 @@ class VaceVideoProcessor(object): def load_video_pair(self, data_key, data_key2, 
crop_box=None, seed=2024, **kwargs): return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs) - def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, canvas_height = 0, canvas_width = 0, fit_into_canvas = False, **kwargs): + def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, start_frame = 0, canvas_height = 0, canvas_width = 0, fit_into_canvas = None, **kwargs): rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000) # read video import decord decord.bridge.set_bridge('torch') readers = [] - src_video = None + src_videos = [] for data_k in data_key_batch: if torch.is_tensor(data_k): - src_video = data_k + src_videos.append(data_k) else: reader = decord.VideoReader(data_k) readers.append(reader) - if src_video != None: + if len(src_videos) >0: fps = 16 - length = src_video.shape[0] + start_frame + length = src_videos[0].shape[0] + start_frame if len(readers) > 0: min_readers = min([len(r) for r in readers]) length = min(length, min_readers ) @@ -234,17 +234,17 @@ class VaceVideoProcessor(object): # frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)] # frame_timestamps = np.array(frame_timestamps, dtype=np.float32) max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames - if src_video != None: - src_video = src_video[:max_frames] - h, w = src_video.shape[1:3] + if len(src_videos) >0: + src_videos = [ src_video[:max_frames] for src_video in src_videos] + h, w = src_videos[0].shape[1:3] else: h, w = readers[0].next().shape[:2] frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, canvas_height = canvas_height, canvas_width = canvas_width, fit_into_canvas = fit_into_canvas, max_frames=max_frames, start_frame = start_frame ) # preprocess video videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers] - if src_video != None: - videos = [src_video] + videos + if len(src_videos) >0: + videos = src_videos + videos videos = [self._video_preprocess(video, oh, ow) for video in videos] return *videos, frame_ids, (oh, ow), fps # return videos if len(videos) > 1 else videos[0] diff --git a/wgp.py b/wgp.py index b720e3a..a4e1efa 100644 --- a/wgp.py +++ b/wgp.py @@ -46,6 +46,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.4.9" WanGP_version = "6.0" +settings_version = 2 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None from importlib.metadata import version @@ -128,6 +129,10 @@ def is_integer(n): else: return float(n).is_integer() +def compute_sliding_window_no(current_video_length, sliding_window_size, discard_last_frames, reuse_frames): + left_after_first_window = current_video_length - sliding_window_size + discard_last_frames + return 1 + math.ceil(left_after_first_window / (sliding_window_size - discard_last_frames - reuse_frames)) + def process_prompt_and_add_tasks(state, model_choice): @@ -165,7 +170,7 @@ def process_prompt_and_add_tasks(state, model_choice): model_filename = get_model_filename(get_base_model_type(model_type)) prompts = prompt.replace("\r", "").split("\n") prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")] - if len(prompts) ==0: + if len(prompts) == 0: gr.Info("Prompt cannot be empty.") gen = get_gen_info(state) queue = gen.get("queue", []) @@ -174,128 
+179,132 @@ def process_prompt_and_add_tasks(state, model_choice): resolution = inputs["resolution"] width, height = resolution.split("x") width, height = int(width), int(height) - # if test_class_i2v(model_type): - # if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480: - # gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P") - # return - # resolution = str(width) + "*" + str(height) - # if resolution not in ['720*1280', '1280*720', '480*832', '832*480']: - # gr.Info(f"Resolution {resolution} not supported by image 2 video") - # return + image_start = inputs["image_start"] + image_end = inputs["image_end"] + image_refs = inputs["image_refs"] + audio_guide = inputs["audio_guide"] + image_prompt_type = inputs["image_prompt_type"] + if image_prompt_type == None: image_prompt_type = "" + video_prompt_type = inputs["video_prompt_type"] + if video_prompt_type == None: video_prompt_type = "" + video_guide = inputs["video_guide"] + video_mask = inputs["video_mask"] + video_source = inputs["video_source"] + frames_positions = inputs["frames_positions"] + keep_frames_video_source = inputs["keep_frames_video_source"] + keep_frames_video_guide= inputs["keep_frames_video_guide"] + sliding_window_size = inputs["sliding_window_size"] + sliding_window_overlap = inputs["sliding_window_overlap"] + sliding_window_discard_last_frames = inputs["sliding_window_discard_last_frames"] + video_length = inputs["video_length"] + + + if "F" in video_prompt_type: + if len(frames_positions.strip()) > 0: + positions = frames_positions.split(" ") + for pos_str in positions: + if not is_integer(pos_str): + gr.Info(f"Invalid Frame Position '{pos_str}'") + return + pos = int(pos_str) + if pos <1 or pos > 1000: + gr.Info(f"Invalid Frame Position Value'{pos_str}'") + return + else: + frames_positions = None - if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ): - gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P") - return - - if "diffusion_forcing" in model_filename or "ltxv" in model_filename or "Vace" in model_filename or "hunyuan_video_custom_edit" in model_filename: - video_length = inputs["video_length"] - sliding_window_size = inputs["sliding_window_size"] - if video_length > sliding_window_size: - gr.Info(f"The Number of Frames to generate ({video_length}) is greater than the Sliding Window Size ({sliding_window_size}) , multiple Windows will be generated") - - if "hunyuan_video_custom_edit" in model_filename: - keep_frames_video_guide= inputs["keep_frames_video_guide"] - if len(keep_frames_video_guide) > 0: - gr.Info("Filtering Frames with this model is not supported") + if len(filter_letters(image_prompt_type, "VLG")) > 0 and len(keep_frames_video_source) > 0: + if not is_integer(keep_frames_video_source) or int(keep_frames_video_source) == 0: + gr.Info("The number of frames to keep must be a non null integer") return + else: + keep_frames_video_source = "" - if "phantom" in model_filename or "hunyuan_video_custom" in model_filename or "hunyuan_video_avatar" in model_filename: - image_refs = inputs["image_refs"] - audio_guide = inputs["audio_guide"] - if image_refs == None : - gr.Info("You must provide an Image Reference") + if "V" in image_prompt_type: + if video_source == None: + gr.Info("You must provide a Source Video file to continue") return - if len(image_refs) > 1 and ("hunyuan_video_custom" in 
model_filename or "hunyuan_video_avatar" in model_filename): - gr.Info("Only one Image Reference (a person) is supported for the moment by Hunyuan Custom / Avatar") - return - if audio_guide == None and "hunyuan_video_avatar" in model_filename: - gr.Info("You must provide an audio file") + elif "G" in image_prompt_type: + gen = get_gen_info(state) + file_list = gen.get("file_list",[]) + choice = gen.get("selected",-1) + if choice >=0 and len(file_list)>0: + video_source = file_list[choice] + else: + gr.Info("Please Select a generated Video as a Video to continue") return + else: + video_source = None + if "I" in video_prompt_type: + if image_refs == None or len(image_refs) == 0: + gr.Info("You must provide at least one Refererence Image") + return if any(isinstance(image[0], str) for image in image_refs) : - gr.Info("Reference Image should be an Image") + gr.Info("A Reference Image should be an Image") return - if isinstance(image_refs, list): - image_refs = [ convert_image(tup[0]) for tup in image_refs ] - + image_refs = [ convert_image(tup[0]) for tup in image_refs ] + else: + image_refs = None - if len(prompts) > 0: - prompts = ["\n".join(prompts)] - - for single_prompt in prompts: - extra_inputs = { - "prompt" : single_prompt, - "image_refs": image_refs, - } - inputs.update(extra_inputs) - add_video_task(**inputs) - - elif "diffusion_forcing" in model_filename or "ltxv" in model_filename: - image_start = inputs["image_start"] - image_end = inputs["image_end"] - video_source = inputs["video_source"] - keep_frames_video_source = inputs["keep_frames_video_source"] - image_prompt_type = inputs["image_prompt_type"] - - if len(keep_frames_video_source) > 0: - if not is_integer(keep_frames_video_source): - gr.Info("The number of frames to keep must be an integer") + if "V" in video_prompt_type: + if video_guide == None: + gr.Info("You must provide a Control Video") + return + if "A" in video_prompt_type: + if video_mask == None: + gr.Info("You must provide a Video Mask") return + else: + video_mask = None - if "V" in image_prompt_type: - if video_source == None or len(video_source) == 0: - gr.Info("You must provide a Video to continue") - return - image_start = None + keep_frames_video_guide= inputs["keep_frames_video_guide"] + _, error = parse_keep_frames_video_guide(keep_frames_video_guide, video_length) + if len(error) > 0: + gr.Info(f"Invalid Keep Frames property: {error}") + return + else: + video_guide = None + video_mask = None + keep_frames_video_guide = "" + - if "S" in image_prompt_type: - if image_start == None : - gr.Info("You must provide a Start Image") - return - if len(image_start) > 1: - gr.Info("Only one Start Image is supported for the moment") - return - if isinstance(image_start[0][0], str) : - gr.Info("Start Image should be an Image") - return + if "S" in image_prompt_type: + if image_start == None or isinstance(image_start, list) and len(image_start) == 0: + gr.Info("You must provide a Start Image") + if not isinstance(image_start, list): + image_start = [image_start] + if not all( not isinstance(img[0], str) for img in image_start) : + gr.Info("Start Image should be an Image") + return + image_start = [ convert_image(tup[0]) for tup in image_start ] + else: + image_start = None - image_start = [ convert_image(tup[0]) for tup in image_start ] - video_source = None + if "E" in image_prompt_type: + if image_end == None or isinstance(image_end, list) and len(image_end) == 0: + gr.Info("You must provide an End Image") + return + if not isinstance(image_end, list): + 
image_end = [image_end] + if not all( not isinstance(img[0], str) for img in image_end) : + gr.Info("End Image should be an Image") + return + if len(image_start) != len(image_end): + gr.Info("The number of Start and End Images should be the same ") + return + image_end = [ convert_image(tup[0]) for tup in image_end ] + else: + image_end = None - if "E" in image_prompt_type: - if image_end == None : - gr.Info("You must provide an End Image") - return - if len(image_end) > 1: - gr.Info("Only one End Image is supported for the moment") - return - if isinstance(image_end[0][0], str) : - gr.Info("End Image should be an Image") - return - image_end = [ convert_image(tup[0]) for tup in image_end ] - video_source = None + if test_any_sliding_window(model_type): + if video_length > sliding_window_size: + no_windows = compute_sliding_window_no(video_length, sliding_window_size, sliding_window_discard_last_frames, sliding_window_overlap) + gr.Info(f"The Number of Frames to generate ({video_length}) is greater than the Sliding Window Size ({sliding_window_size}) , {no_windows} Windows will be generated") - if "T" in image_prompt_type: - image_start = None - image_end = None - video_source = None - - if len(prompts) > 0: - prompts = ["\n".join(prompts)] - - for single_prompt in prompts: - extra_inputs = { - "prompt" : single_prompt, - "image_start" : image_start, - "image_end" : image_end, - "video_source" : video_source, - } - inputs.update(extra_inputs) - add_video_task(**inputs) - elif "recam" in model_filename: - video_source = inputs["video_source"] + if "recam" in model_filename: if video_source == None: gr.Info("You must provide a Source Video") return @@ -303,162 +312,111 @@ def process_prompt_and_add_tasks(state, model_choice): if len(frames)<81: gr.Info("Recammaster source video should be at least 81 frames once the resampling at 16 fps has been done") return - for single_prompt in prompts: - extra_inputs = { - "prompt" : single_prompt, - "video_source" : video_source, - } - inputs.update(extra_inputs) - add_video_task(**inputs) - elif "Vace" in model_filename: - video_prompt_type = inputs["video_prompt_type"] - image_refs = inputs["image_refs"] - video_guide = inputs["video_guide"] - video_mask = inputs["video_mask"] - - # if "1.3B" in model_filename : - # resolution_reformated = str(height) + "*" + str(width) - # if not resolution_reformated in VACE_SIZE_CONFIGS: - # res = (" and ").join(VACE_SIZE_CONFIGS.keys()) - # gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.") - # return - if "I" in video_prompt_type: - if image_refs == None: - gr.Info("You must provide at least one Refererence Image") - return - else: - image_refs = None - if "V" in video_prompt_type: - if video_guide == None: - gr.Info("You must provide a Control Video") - return - else: - video_guide = None - if "M" in video_prompt_type: - if video_mask == None: - gr.Info("You must provide a Video Mask ") - return - else: - video_mask = None - if "O" in video_prompt_type : - keep_frames_video_guide= inputs["keep_frames_video_guide"] - video_length = inputs["video_length"] - if len(keep_frames_video_guide) > 0: - gr.Info("Keeping Frames with Extending Video is not yet supported") - return - # gr.Info(f"Warning : you have asked to reuse all the frames of the control Video in the Alternate Video Ending it. 
Please make sure the number of frames of the control Video is lower than the total number of frames to generate otherwise it won't make a difference.") - # elif keep_frames >= video_length: - # gr.Info(f"The number of frames in the control Video to reuse ({keep_frames_video_guide}) in Alternate Video Ending can not be bigger than the total number of frames ({video_length}) to generate.") - # return - elif "V" in video_prompt_type: - keep_frames_video_guide= inputs["keep_frames_video_guide"] - video_length = inputs["video_length"] - _, error = parse_keep_frames_video_guide(keep_frames_video_guide, video_length) - if len(error) > 0: - gr.Info(f"Invalid Keep Frames property: {error}") - return - - if isinstance(image_refs, list): - image_refs = [ convert_image(tup[0]) for tup in image_refs ] - - if len(prompts) > 0: - prompts = ["\n".join(prompts)] - - for single_prompt in prompts: - extra_inputs = { - "prompt" : single_prompt, - "image_refs": image_refs, - "video_guide" : video_guide, - "video_mask" : video_mask , - } - inputs.update(extra_inputs) - add_video_task(**inputs) - elif test_class_i2v(model_type) : - image_prompt_type = inputs["image_prompt_type"] - - image_start = inputs["image_start"] - image_end = inputs["image_end"] - if image_start == None or isinstance(image_start, list) and len(image_start) == 0: + if "phantom" in model_filename or "hunyuan_video_custom" in model_filename or "hunyuan_video_avatar" in model_filename: + if image_refs == None : + gr.Info("You must provide an Image Reference") return - if not "E" in image_prompt_type: - image_end = None - if not isinstance(image_start, list): - image_start = [image_start] - if not all( not isinstance(img[0], str) for img in image_start) : - gr.Info("Start Image should be an Image") + if len(image_refs) > 1 and ("hunyuan_video_custom" in model_filename or "hunyuan_video_avatar" in model_filename): + gr.Info("Only one Image Reference (a person) is supported for the moment by Hunyuan Custom / Avatar") return - image_start = [ convert_image(tup[0]) for tup in image_start ] - if image_end != None: - if not isinstance(image_end , list): - image_end = [image_end] - if not all( not isinstance(img[0], str) for img in image_end) : - gr.Info("End Image should be an Image") - return - image_end = [ convert_image(tup[0]) for tup in image_end ] - if len(image_start) != len(image_end): - gr.Info("The number of start and end images should be the same ") - return - - if inputs["multi_images_gen_type"] == 0: - new_prompts = [] - new_image_start = [] - new_image_end = [] - for i in range(len(prompts) * len(image_start) ): - new_prompts.append( prompts[ i % len(prompts)] ) - new_image_start.append(image_start[i // len(prompts)] ) - if image_end != None: - new_image_end.append(image_end[i // len(prompts)] ) - prompts = new_prompts - image_start = new_image_start - if image_end != None: - image_end = new_image_end - else: - if len(prompts) >= len(image_start): - if len(prompts) % len(image_start) != 0: - raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") - rep = len(prompts) // len(image_start) + if "hunyuan_custom_custom_edit" in model_filename: + if video_guide == None: + gr.Info("You must provide a Control Video") + return + + if len(keep_frames_video_guide) > 0: + gr.Info("Filtering Frames with this model is not supported") + return + + if "hunyuan_video_avatar" in model_filename and audio_guide == None: + gr.Info("You must provide an audio file") + return + + 
if inputs["multi_prompts_gen_type"] != 0: + if image_start != None and len(image_start) > 1: + gr.Info("Only one Start Image must be provided if multiple prompts are used for different windows") + return + + if image_end != None and len(image_end) > 1: + gr.Info("Only one End Image must be provided if multiple prompts are used for different windows") + return + + override_inputs = { + "image_start": image_start[0] if image_start !=None and len(image_start) > 0 else None, + "image_end": image_end[0] if image_end !=None and len(image_end) > 0 else None, + "image_refs": image_refs, + "audio_guide": audio_guide, + "video_guide": video_guide, + "video_mask": video_mask, + "video_source": video_source, + "frames_positions": frames_positions, + "keep_frames_video_source": keep_frames_video_source, + "keep_frames_video_guide": keep_frames_video_guide, + "image_prompt_type": image_prompt_type, + "video_prompt_type": video_prompt_type, + } + + if inputs["multi_prompts_gen_type"] == 0: + if image_start != None and len(image_start) > 0: + if inputs["multi_images_gen_type"] == 0: + new_prompts = [] new_image_start = [] new_image_end = [] - for i, _ in enumerate(prompts): - new_image_start.append(image_start[i//rep] ) + for i in range(len(prompts) * len(image_start) ): + new_prompts.append( prompts[ i % len(prompts)] ) + new_image_start.append(image_start[i // len(prompts)] ) if image_end != None: - new_image_end.append(image_end[i//rep] ) + new_image_end.append(image_end[i // len(prompts)] ) + prompts = new_prompts image_start = new_image_start if image_end != None: image_end = new_image_end - else: - if len(image_start) % len(prompts) !=0: - raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts") - rep = len(image_start) // len(prompts) - new_prompts = [] - for i, _ in enumerate(image_start): - new_prompts.append( prompts[ i//rep] ) - prompts = new_prompts + else: + if len(prompts) >= len(image_start): + if len(prompts) % len(image_start) != 0: + gr.Info("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") + return + rep = len(prompts) // len(image_start) + new_image_start = [] + new_image_end = [] + for i, _ in enumerate(prompts): + new_image_start.append(image_start[i//rep] ) + if image_end != None: + new_image_end.append(image_end[i//rep] ) + image_start = new_image_start + if image_end != None: + image_end = new_image_end + else: + if len(image_start) % len(prompts) !=0: + gr.Info("If there are more input images than text prompts the number of images should be dividable by the number of text prompts") + return + rep = len(image_start) // len(prompts) + new_prompts = [] + for i, _ in enumerate(image_start): + new_prompts.append( prompts[ i//rep] ) + prompts = new_prompts + if image_end == None or len(image_end) == 0: + image_end = [None] * len(prompts) - - if image_start == None: - image_start = [None] * len(prompts) - if image_end == None: - image_end = [None] * len(prompts) - - for single_prompt, start, end in zip(prompts, image_start, image_end) : - extra_inputs = { - "prompt" : single_prompt, - "image_start": start, - "image_end" : end, - "video_source": None, - } - inputs.update(extra_inputs) - add_video_task(**inputs) + for single_prompt, start, end in zip(prompts, image_start, image_end) : + override_inputs.update({ + "prompt" : single_prompt, + "image_start": start, + "image_end" : end, + }) + inputs.update(override_inputs) + 
add_video_task(**inputs) + else: + for single_prompt in prompts : + override_inputs["prompt"] = single_prompt + inputs.update(override_inputs) + add_video_task(**inputs) else: - for single_prompt in prompts : - extra_inputs = { - "prompt" : single_prompt, - } - inputs.update(extra_inputs) - add_video_task(**inputs) + override_inputs["prompt"] = "\n".join(prompts) + inputs.update(override_inputs) + add_video_task(**inputs) gen = get_gen_info(state) gen["prompts_max"] = len(prompts) + gen.get("prompts_max",0) @@ -468,21 +426,31 @@ def process_prompt_and_add_tasks(state, model_choice): def get_preview_images(inputs): inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask", "video_source"] + labels = ["Start Image", "End Image", "Video Guide", "Image Reference","Video Mask", "Video Source"] start_image_data = None + start_image_labels = [] end_image_data = None - for name in inputs_to_query: + end_image_labels = [] + for label, name in zip(labels,inputs_to_query): image= inputs.get(name, None) if image != None: - image= [image] if not isinstance(image, list) else image + image= [image] if not isinstance(image, list) else image.copy() if start_image_data == None: start_image_data = image + start_image_labels += [label] * len(image) else: - end_image_data = image - break + if end_image_data == None: + end_image_data = image + else: + end_image_data += image + end_image_labels += [label] * len(image) + if start_image_data != None and len(start_image_data) > 1 and end_image_data == None: end_image_data = start_image_data [1:] + end_image_labels = start_image_labels [1:] start_image_data = start_image_data [:1] - return start_image_data, end_image_data + start_image_labels = start_image_labels [:1] + return start_image_data, end_image_data, start_image_labels, end_image_labels def add_video_task(**inputs): global task_id @@ -492,7 +460,7 @@ def add_video_task(**inputs): task_id += 1 current_task_id = task_id - start_image_data, end_image_data = get_preview_images(inputs) + start_image_data, end_image_data, start_image_labels, end_image_labels = get_preview_images(inputs) queue.append({ "id": current_task_id, @@ -501,6 +469,8 @@ def add_video_task(**inputs): "length": inputs["video_length"], "steps": inputs["num_inference_steps"], "prompt": inputs["prompt"], + "start_image_labels": start_image_labels, + "end_image_labels": end_image_labels, "start_image_data": start_image_data, "end_image_data": end_image_data, "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data] if start_image_data != None else None, @@ -509,9 +479,11 @@ def add_video_task(**inputs): return update_queue_data(queue) def update_task_thumbnails(task, inputs): - start_image_data, end_image_data = get_preview_images(inputs) + start_image_data, end_image_data, start_labels, end_labels = get_preview_images(inputs) task.update({ + "start_image_labels": start_labels, + "end_image_labels": end_labels, "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data] if start_image_data != None else None, "end_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in end_image_data] if end_image_data != None else None }) @@ -650,6 +622,8 @@ def save_queue_action(state): params_copy.pop('state', None) + params_copy.pop('start_image_labels', None) + params_copy.pop('end_image_labels', None) params_copy.pop('start_image_data_base64', None) params_copy.pop('end_image_data_base64', None) 
params_copy.pop('start_image_data', None) @@ -813,7 +787,7 @@ def load_queue_action(filepath, state, evt:gr.EventData): print(f"[load_queue_action] Error copying video {video_filename_in_zip} to cache: {vid_e}") params.pop(key, None) - primary_preview_pil_list, secondary_preview_pil_list = get_preview_images(params) + primary_preview_pil_list, secondary_preview_pil_list, primary_preview_pil_labels, secondary_preview_pil_labels = get_preview_images(params) start_b64 = [pil_to_base64_uri(primary_preview_pil_list[0], format="jpeg", quality=70)] if isinstance(primary_preview_pil_list, list) and primary_preview_pil_list else None end_b64 = [pil_to_base64_uri(secondary_preview_pil_list[0], format="jpeg", quality=70)] if isinstance(secondary_preview_pil_list, list) and secondary_preview_pil_list else None @@ -828,6 +802,8 @@ def load_queue_action(filepath, state, evt:gr.EventData): "length": params.get('video_length'), "steps": params.get('num_inference_steps'), "prompt": params.get('prompt'), + "start_image_labels": primary_preview_pil_labels, + "end_image_labels": secondary_preview_pil_labels, "start_image_data": top_level_start_image, "end_image_data": top_level_end_image, "start_image_data_base64": start_b64, @@ -1083,17 +1059,19 @@ def get_queue_table(queue): prompt_cell = f'{truncated_prompt}' start_img_uri =item.get('start_image_data_base64') start_img_uri = start_img_uri[0] if start_img_uri !=None else None + start_img_labels =item.get('start_image_labels') end_img_uri = item.get('end_image_data_base64') end_img_uri = end_img_uri[0] if end_img_uri !=None else None + end_img_labels =item.get('end_image_labels') thumbnail_size = "50px" num_steps = item.get('steps') length = item.get('length') start_img_md = "" end_img_md = "" if start_img_uri: - start_img_md = f'Start' + start_img_md = f'
<img src="{start_img_uri}" alt="{start_img_labels[0]}" title="{start_img_labels[0]}" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; object-fit:contain;" />'
        if end_img_uri:
-            end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; object-fit:contain;" />'
+            end_img_md = f'<img src="{end_img_uri}" alt="{end_img_labels[0]}" title="{end_img_labels[0]}" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; object-fit:contain;" />
' data.append([item.get('repeats', "1"), @@ -1138,9 +1116,9 @@ def _parse_args(): description="Generate a video from a text prompt or image using Gradio") parser.add_argument( - "--quantize-transformer", + "--save-masks", action="store_true", - help="On the fly 'transformer' quantization" + help="save proprocessed masks for debugging or editing" ) parser.add_argument( @@ -1488,7 +1466,6 @@ lock_ui_compile = False force_profile_no = int(args.profile) verbose_level = int(args.verbose) -quantizeTransformer = args.quantize_transformer check_loras = args.check_loras ==1 advanced = args.advanced @@ -1548,7 +1525,7 @@ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion "sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_480p_14B_bf16.safetensors", "wan2.1_image2video_480p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_int8.safetensors", "wan2.1_image2video_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_720p_14B_bf16.safetensors", "wan2.1_text2video_14B_bf16.safetensors", "wan2.1_text2video_14B_quanto_int8.safetensors", -"wan2.1_Vace_1.3B_mbf16.safetensors", "wan2.1_Vace_14B_mbf16.safetensors", "wan2.1_Vace_14B_quanto_mbf16_int8.safetensors" +"wan2.1_Vace_14B_mbf16.safetensors", "wan2.1_Vace_14B_quanto_mbf16_int8.safetensors" ]: if Path(os.path.join("ckpts" , path)).is_file(): print(f"Removing old version of model '{path}'. A new version of this model will be downloaded next time you use it.") @@ -1556,16 +1533,17 @@ for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion finetunes = {} -wan_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_mbf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_mbf16_int8.safetensors", +wan_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_mbf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_text2video_14B_quanto_mfp16_int8.safetensors", "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", - "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", - "ckpts/wan2.1_Vace_1.3B_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mbf16_int8.safetensors", - "ckpts/wan2.1_moviigen1.1_14B_mbf16.safetensors", "ckpts/wan2.1_moviigen1.1_14B_quanto_mbf16_int8.safetensors", - "ckpts/wan2_1_phantom_1.3B_mbf16.safetensors", "ckpts/wan2.1_phantom_14B_mbf16.safetensors", "ckpts/wan2.1_phantom_14B_quanto_mbf16_int8.safetensors", + "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_fp16_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_mbf16.safetensors", + "ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_mfp16_int8.safetensors", + "ckpts/wan2.1_Vace_1.3B_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_mbf16.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Vace_14B_module_quanto_mfp16_int8.safetensors", + "ckpts/wan2.1_moviigen1.1_14B_mbf16.safetensors", 
"ckpts/wan2.1_moviigen1.1_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_moviigen1.1_14B_quanto_mfp16_int8.safetensors", + "ckpts/wan2_1_phantom_1.3B_mbf16.safetensors", "ckpts/wan2.1_phantom_14B_mbf16.safetensors", "ckpts/wan2.1_phantom_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_phantom_14B_quanto_mfp16_int8.safetensors", ] -wan_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_mbf16.safetensors", - "ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", - "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors", +wan_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_mbf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_mfp16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_mbf16.safetensors", + "ckpts/wan2.1_image2video_720p_14B_quanto_mbf16_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_mfp16_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", + "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_fp16_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_fp16_int8.safetensors", "ckpts/wan2.1_fantasy_speaking_14B_bf16.safetensors"] ltxv_choices= ["ckpts/ltxv_0.9.7_13B_dev_bf16.safetensors", "ckpts/ltxv_0.9.7_13B_dev_quanto_bf16_int8.safetensors", "ckpts/ltxv_0.9.7_13B_distilled_lora128_bf16.safetensors"] @@ -1588,7 +1566,9 @@ def get_dependent_models(model_type, quantization, dtype_policy ): return [], [] return [get_model_filename(dependent_model_type, quantization, dtype_policy)], [dependent_model_type] -model_types = [ "t2v_1.3B", "t2v", "i2v", "i2v_720p", "flf2v_720p", "vace_1.3B","vace_14B","moviigen", "phantom_1.3B", "phantom_14B", "fantasy", "fun_inp_1.3B", "fun_inp", "recam_1.3B", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled", "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar"] +model_types = [ "t2v_1.3B", "t2v", "i2v", "i2v_720p", "flf2v_720p", "vace_1.3B","vace_14B","moviigen", "phantom_1.3B", "phantom_14B", "fantasy", + "fun_inp_1.3B", "fun_inp", "recam_1.3B", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled", + "hunyuan", "hunyuan_i2v", "hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_avatar"] model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B", "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "vace_14B" : "Vace_14B","recam_1.3B": "recammaster_1.3B", "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B", @@ -1625,6 +1605,10 @@ def get_model_family(model_type): return "wan" +def test_any_sliding_window(model_type): + model_type = get_base_model_type(model_type) + return model_type in ["vace_1.3B","vace_14B","sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "ltxv_13B", "ltxv_13B_distilled"] + def test_class_i2v(model_type): 
model_type = get_base_model_type(model_type) return model_type in ["i2v", "i2v_720p", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "hunyuan_i2v" ] @@ -1746,13 +1730,6 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = ""): else: raw_filename = choices[0] - if dtype == torch.float16 and not any(x in raw_filename for x in ["fp16", "FP16"]) and model_family == "wan" and finetune_def == None: - if "quanto_int8" in raw_filename: - raw_filename = raw_filename.replace("quanto_int8", "quanto_fp16_int8") - elif "quanto_bf16_int8" in raw_filename: - raw_filename = raw_filename.replace("quanto_bf16_int8", "quanto_fp16_int8") - elif "quanto_mbf16_int8" in raw_filename: - raw_filename= raw_filename.replace("quanto_mbf16_int8", "quanto_mfp16_int8") return raw_filename def get_transformer_dtype(model_family, transformer_dtype_policy): @@ -1845,12 +1822,13 @@ def get_default_settings(model_type): "guidance_scale": 7.5, "flow_shift": 13, "resolution": "1280x720", + "video_prompt_type": "I", }) elif model_type in ("hunyuan_custom_edit"): ui_defaults.update({ "guidance_scale": 7.5, "flow_shift": 13, - "video_prompt_type": "MV", + "video_prompt_type": "MVAI", "sliding_window_size": 129, }) elif model_type in ("hunyuan_avatar"): @@ -1970,6 +1948,42 @@ if args.compile: #args.fastest or #attention_mode="xformers" # compile = "transformer" +def save_quantized_model(model, model_type, model_filename, dtype, config_file): + if "quanto" in model_filename: + return + from mmgp import offload + if dtype == torch.bfloat16: + model_filename = model_filename.replace("fp16", "bf16").replace("FP16", "bf16") + elif dtype == torch.float16: + model_filename = model_filename.replace("bf16", "fp16").replace("BF16", "bf16") + + for rep in ["mfp16", "fp16", "mbf16", "bf16"]: + if "_" + rep in model_filename: + model_filename = model_filename.replace("_" + rep, "_quanto_" + rep + "_int8") + break + if not "quanto" in model_filename: + pos = model_filename.rfind(".") + model_filename = model_filename[:pos] + "_quanto_int8" + model_filename[pos+1:] + + if not os.path.isfile(model_filename): + offload.save_model(model, model_filename, do_quantize= True, config_file_path=config_file) + print(f"New quantized file '{model_filename}' had been created for finetune Id '{model_type}'.") + finetune_def = get_model_finetune_def(model_type) + if finetune_def != None: + URLs= finetune_def["URLs"] + if not model_filename in URLs: + URLs.append(model_filename) + finetune_def = finetune_def.copy() + if "settings" in finetune_def: + saved_def = typing.OrderedDict() + saved_def["model"] = finetune_def + saved_def.update(finetune_def["settings"]) + del finetune_def["settings"] + finetune_file = os.path.join("finetunes" , model_type + ".json") + with open(finetune_file, "w", encoding="utf-8") as writer: + writer.write(json.dumps(saved_def, indent=4)) + print(f"The '{finetune_file}' definition file has been automatically updated with the local path to the new quantized model.") + def get_loras_preprocessor(transformer, model_type): preprocessor = getattr(transformer, "preprocess_loras", None) if preprocessor == None: @@ -1981,14 +1995,14 @@ def get_loras_preprocessor(transformer, model_type): return preprocessor_wrapper -def get_model_manager(model_family): - if model_family == "wan": - return None - elif model_family == "ltxv": - from ltxv import model_def - return model_def - else: - raise Exception("model family not supported") +# def get_model_manager(model_family): +# if model_family == "wan": +# return None +# elif 
model_family == "ltxv": +# from ltxv import model_def +# return model_def +# else: +# raise Exception("model family not supported") def get_wan_text_encoder_filename(text_encoder_quantization): text_encoder_filename = "ckpts/umt5-xxl/models_t5_umt5-xxl-enc-bf16.safetensors" @@ -2077,17 +2091,19 @@ def download_models(model_filename, model_type): model_family = get_model_family(model_type) finetune_def = get_model_finetune_def(model_type) if finetune_def != None: - if not os.path.isfile(model_filename ): + if not os.path.isfile(model_filename ): + use_url = model_filename for url in finetune_def["URLs"]: - if model_filename in url: + if os.path.basename(model_filename) in url: + use_url = url break if not url.startswith("http"): raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the finetune definition file.") try: - download_file(url, model_filename) + download_file(use_url, model_filename) except Exception as e: if os.path.isfile(model_filename): os.remove(model_filename) - raise Exception(f"URL '{url}' is invalid for Model '{model_filename}' : {str(e)}'") + raise Exception(f"URL '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'") for url in finetune_def.get("preload_URLs", []): filename = "ckpts/" + url.split("/")[-1] if not os.path.isfile(filename ): @@ -2126,9 +2142,9 @@ def download_models(model_filename, model_type): ] } - else: - model_manager = get_model_manager(model_family) - model_def = model_manager.get_files_def(model_filename, text_encoder_quantization) + # else: + # model_manager = get_model_manager(model_family) + # model_def = model_manager.get_files_def(model_filename, text_encoder_quantization) process_files_def(**model_def) @@ -2221,7 +2237,7 @@ def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, spl return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset -def load_wan_model(model_filename, base_model_type, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): +def load_wan_model(model_filename, model_type, base_model_type, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): if test_class_i2v(base_model_type): cfg = WAN_CONFIGS['i2v-14B'] model_factory = wan.WanI2V @@ -2237,6 +2253,7 @@ def load_wan_model(model_filename, base_model_type, quantizeTransformer = False, config=cfg, checkpoint_dir="ckpts", model_filename=model_filename, + model_type = model_type, base_model_type=base_model_type, text_encoder_filename= get_wan_text_encoder_filename(text_encoder_quantization), quantizeTransformer = quantizeTransformer, @@ -2251,7 +2268,7 @@ def load_wan_model(model_filename, base_model_type, quantizeTransformer = False, pipe["text_encoder_2"] = wan_model.clip.model return wan_model, pipe -def load_ltxv_model(model_filename, base_model_type, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): +def load_ltxv_model(model_filename, model_type, base_model_type, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): from ltx_video.ltxv import LTXV ltxv_model = LTXV( @@ -2268,11 +2285,12 @@ def load_ltxv_model(model_filename, base_model_type, 
quantizeTransformer = False return ltxv_model, pipe -def load_hunyuan_model(model_filename, base_model_type = None, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): +def load_hunyuan_model(model_filename, model_type = None, base_model_type = None, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): from hyvideo.hunyuan import HunyuanVideoSampler hunyuan_model = HunyuanVideoSampler.from_pretrained( model_filepath = model_filename, + model_type = model_type, base_model_type = base_model_type, text_encoder_filepath = get_hunyuan_text_encoder_filename(text_encoder_quantization), dtype = dtype, @@ -2355,11 +2373,11 @@ def load_models(model_type): print(f"Loading Module '{filename}' ...") if model_family == "wan" : - wan_model, pipe = load_wan_model(model_file_list, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) + wan_model, pipe = load_wan_model(model_file_list, model_type, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) elif model_family == "ltxv": - wan_model, pipe = load_ltxv_model(model_file_list, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) + wan_model, pipe = load_ltxv_model(model_file_list, model_type, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) elif model_family == "hunyuan": - wan_model, pipe = load_hunyuan_model(model_file_list, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) + wan_model, pipe = load_hunyuan_model(model_file_list, model_type, base_model_type, quantizeTransformer = quantizeTransformer, dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) else: raise Exception(f"Model '{new_transformer_filename}' not supported.") wan_model._model_file_name = new_transformer_filename @@ -2664,17 +2682,20 @@ def refresh_gallery(state): #, msg if enhanced: prompt = "Enhanced:
" + prompt list_uri = [] + list_labels = [] start_img_uri = task.get('start_image_data_base64') if start_img_uri != None: list_uri += start_img_uri + list_labels += task.get('start_image_labels') end_img_uri = task.get('end_image_data_base64') if end_img_uri != None: list_uri += end_img_uri + list_labels += task.get('end_image_labels') thumbnail_size = "100px" thumbnails = "" - for img_uri in list_uri: - thumbnails += f'Start' + for i, (img_label, img_uri) in enumerate(zip(list_labels,list_uri)): + thumbnails += f'
{img_label}{img_label}
' # Get current theme from server config current_theme = server_config.get("UI_theme", "default") @@ -2689,8 +2710,8 @@ def refresh_gallery(state): #, msg border-radius: 6px; box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); """ - - html = f"" + thumbnails + "
" + prompt + "
" + + html = f"" + thumbnails + "
" + prompt + "
" html_output = gr.HTML(html, visible= True) return gr.Gallery(selected_index=choice, value = file_list), html_output, gr.Button(visible=False), gr.Button(visible=True), gr.Row(visible=True), update_queue_data(queue), gr.Button(interactive= abort_interactive), gr.Button(visible= onemorewindow_visible) @@ -2745,114 +2766,17 @@ def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='t import decord decord.bridge.set_bridge(bridge) reader = decord.VideoReader(video_in) - fps = round(reader.get_avg_fps()) + if max_frames < 0: + max_frames = max(len(reader)/ fps * target_fps + max_frames, 0) + frame_nos = resample(fps, len(reader), max_target_frames_count= max_frames, target_fps=target_fps, start_target_frame= start_frame) frames_list = reader.get_batch(frame_nos) # print(f"frame nos: {frame_nos}") return frames_list -def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, pose_enhance = True, to_bbox = False): - if not input_video_path or not input_mask_path: - return None, None - - from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator - cfg_dict = { - "DETECTION_MODEL": "ckpts/pose/yolox_l.onnx", - "POSE_MODEL": "ckpts/pose/dw-ll_ucoco_384.onnx", - "RESIZE_SIZE": 1024 - } - dwpose = PoseBodyFaceVideoAnnotator(cfg_dict) - - video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps) - mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps) - - if len(video) == 0 or len(mask_video) == 0: - return None, None - - frame_height, frame_width, _ = video[0].shape - - if fit_canvas : - scale1 = min(height / frame_height, width / frame_width) - scale2 = min(height / frame_width, width / frame_height) - scale = max(scale1, scale2) - else: - scale = ((height * width ) / (frame_height * frame_width))**(1/2) - - height = (int(frame_height * scale) // block_size) * block_size - width = (int(frame_width * scale) // block_size) * block_size - - num_frames = min(len(video), len(mask_video)) - - masked_frames = [] - masks = [] - for frame_idx in range(num_frames): - frame = Image.fromarray(video[frame_idx].cpu().numpy()) #.asnumpy() - mask = Image.fromarray(mask_video[frame_idx].cpu().numpy()) #.asnumpy() - frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS) - mask = mask.resize((width, height), resample=Image.Resampling.LANCZOS) - frame = np.array(frame) - mask = np.array(mask) - - if len(mask.shape) == 3 and mask.shape[2] == 3: - mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) - - if expand_scale != 0: - kernel_size = abs(expand_scale) - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) - op_expand = cv2.dilate if expand_scale > 0 else cv2.erode - mask = op_expand(mask, kernel, iterations=3) - - _, mask = cv2.threshold(mask, 127.5, 255, cv2.THRESH_BINARY) - if to_bbox and np.sum(mask == 255) > 0: - x0, y0, x1, y1 = mask_to_xyxy_box(mask) - mask = mask * 0 - mask[y0:y1, x0:x1] = 255 - - inverse_mask = mask == 0 - if pose_enhance: - pose_img = dwpose.forward([frame])[0] - masked_frame = np.where(inverse_mask[..., None], frame, pose_img) - else: - masked_frame = frame * (inverse_mask[..., None].astype(frame.dtype)) - - mask = torch.from_numpy(mask) # to be commented if save one video enabled - masked_frame = torch.from_numpy(masked_frame) # to be commented if save one video debug enabled - masks.append(mask) - masked_frames.append(masked_frame) - - - # 
from preprocessing.dwpose.pose import save_one_video - # save_one_video("masked_frames.mp4", masked_frames, fps=target_fps, quality=8, macro_block_size=None) - # save_one_video("masks.mp4", masks, fps=target_fps, quality=8, macro_block_size=None) - - return torch.stack(masked_frames), torch.stack(masks) - -def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size = 16): - - frames_list = get_resampled_video(video_in, start_frame, max_frames, target_fps) - - if len(frames_list) == 0: - return None - frame_height, frame_width, _ = frames_list[0].shape - - if fit_canvas : - scale1 = min(height / frame_height, width / frame_width) - scale2 = min(height / frame_width, width / frame_height) - scale = max(scale1, scale2) - else: - scale = ((height * width ) / (frame_height * frame_width))**(1/2) - - new_height = (int(frame_height * scale) // block_size) * block_size - new_width = (int(frame_width * scale) // block_size) * block_size - - processed_frames_list = [] - for frame in frames_list: - frame = Image.fromarray(np.clip(frame.cpu().numpy(), 0, 255).astype(np.uint8)) - frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) - processed_frames_list.append(frame) - +def get_preprocessor(process_type): if process_type=="pose": from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator cfg_dict = { @@ -2873,7 +2797,164 @@ def preprocess_video(process_type, height, width, video_in, max_frames, start_fr anno_ins = GrayVideoAnnotator(cfg_dict) else: anno_ins = None + return anno_ins + +def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpainting", to_bbox = False, RGB_Mask = False, negate_mask = False, outpaint_outside_mask = False, inpaint_color = 127): + + def mask_to_xyxy_box(mask): + rows, cols = np.where(mask == 255) + xmin = min(cols) + xmax = max(cols) + 1 + ymin = min(rows) + ymax = max(rows) + 1 + xmin = max(xmin, 0) + ymin = max(ymin, 0) + xmax = min(xmax, mask.shape[1]) + ymax = min(ymax, mask.shape[0]) + box = [xmin, ymin, xmax, ymax] + box = [int(x) for x in box] + return box + if not input_video_path: + return None, None + any_mask = input_mask_path != None + pose_special = "pose" in process_type + if process_type == "pose_depth": + preproc = get_preprocessor("pose") + preproc2 = get_preprocessor("depth") + else: + preproc = get_preprocessor(process_type) + preproc2 = None + video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps) + if any_mask: + mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps) + + if len(video) == 0 or any_mask and len(mask_video) == 0: + return None, None + + + if fit_canvas != None: + frame_height, frame_width, _ = video[0].shape + + if fit_canvas : + scale1 = min(height / frame_height, width / frame_width) + scale2 = min(height / frame_width, width / frame_height) + scale = max(scale1, scale2) + else: + scale = ((height * width ) / (frame_height * frame_width))**(1/2) + + height = (int(frame_height * scale) // block_size) * block_size + width = (int(frame_width * scale) // block_size) * block_size + + if any_mask: + num_frames = min(len(video), len(mask_video)) + else: + num_frames = len(video) + masked_frames = [] + masks = [] + for frame_idx in range(num_frames): + frame = Image.fromarray(video[frame_idx].cpu().numpy()) #.asnumpy() + frame = 
frame.resize((width, height), resample=Image.Resampling.LANCZOS) + frame = np.array(frame) + if any_mask: + mask = Image.fromarray(mask_video[frame_idx].cpu().numpy()) #.asnumpy() + mask = mask.resize((width, height), resample=Image.Resampling.LANCZOS) + mask = np.array(mask) + + if len(mask.shape) == 3 and mask.shape[2] == 3: + mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) + original_mask = mask.copy() + if expand_scale != 0: + kernel_size = abs(expand_scale) + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) + op_expand = cv2.dilate if expand_scale > 0 else cv2.erode + mask = op_expand(mask, kernel, iterations=3) + + _, mask = cv2.threshold(mask, 127.5, 255, cv2.THRESH_BINARY) + if to_bbox and np.sum(mask == 255) > 0: + x0, y0, x1, y1 = mask_to_xyxy_box(mask) + mask = mask * 0 + mask[y0:y1, x0:x1] = 255 + if negate_mask: + mask = 255 - mask + if pose_special: + original_mask = 255 - original_mask + + if preproc != None: + if pose_special and any_mask: + target_frame = np.where(original_mask[..., None], frame, 0) + else: + target_frame = frame + # Image.fromarray(target_frame).save("preprocframe.png") + processed_img = preproc.forward([target_frame])[0] + if pose_special and outpaint_outside_mask: + processed_img = np.where(processed_img ==0, inpaint_color, processed_img) + + if preproc2 != None: + processed_img2 = preproc2.forward([frame])[0] + processed_img = (processed_img.astype(np.uint16) + processed_img2.astype(np.uint16))/2 + processed_img = processed_img.astype(np.uint8) + if any_mask: + inverse_mask = mask == 0 + masked_frame = np.where(inverse_mask[..., None], inpaint_color if outpaint_outside_mask else frame, processed_img) + else: + masked_frame = processed_img + + else: + if any_mask and not outpaint_outside_mask: + masked_frame = np.where(mask[..., None], inpaint_color, frame) + else: + masked_frame = np.full_like(frame, inpaint_color) + + if any_mask : + if outpaint_outside_mask: + mask = np.full_like(mask, 255) + mask = torch.from_numpy(mask) + if RGB_Mask: + mask = mask.unsqueeze(-1).repeat(1,1,3) + masks.append(mask) + masked_frame = torch.from_numpy(masked_frame) + masked_frames.append(masked_frame) + if args.save_masks: + from preprocessing.dwpose.pose import save_one_video + saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ] + save_one_video("masked_frames.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) + if any_mask: + saved_masks = [mask.cpu().numpy() for mask in masks ] + save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None) + + return torch.stack(masked_frames), torch.stack(masks) if any_mask else None + +def preprocess_video(process_type, height, width, video_in, max_frames, start_frame=0, fit_canvas = None, target_fps = 16, block_size = 16): + + frames_list = get_resampled_video(video_in, start_frame, max_frames, target_fps) + + if len(frames_list) == 0: + return None + + if fit_canvas == None: + new_height = height + new_width = width + else: + frame_height, frame_width, _ = frames_list[0].shape + if fit_canvas : + scale1 = min(height / frame_height, width / frame_width) + scale2 = min(height / frame_width, width / frame_height) + scale = max(scale1, scale2) + else: + scale = ((height * width ) / (frame_height * frame_width))**(1/2) + + new_height = (int(frame_height * scale) // block_size) * block_size + new_width = (int(frame_width * scale) // block_size) * block_size + + processed_frames_list = [] + for frame in frames_list: + frame = 
Image.fromarray(np.clip(frame.cpu().numpy(), 0, 255).astype(np.uint8)) + frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) + processed_frames_list.append(frame) + + anno_ins = get_preprocessor(process_type) + if anno_ins == None: np_frames = [np.array(frame) for frame in processed_frames_list] else: @@ -2948,6 +3029,7 @@ def generate_video( flow_shift, embedded_guidance_scale, repeat_generation, + multi_prompts_gen_type, multi_images_gen_type, tea_cache_setting, tea_cache_start_step_perc, @@ -2961,9 +3043,11 @@ def generate_video( keep_frames_video_source, video_prompt_type, image_refs, + frames_positions, video_guide, keep_frames_video_guide, video_mask, + mask_expand, audio_guide, sliding_window_size, sliding_window_overlap, @@ -3026,6 +3110,7 @@ def generate_video( width, height = resolution.split("x") width, height = int(width), int(height) resolution_reformated = str(height) + "*" + str(width) + default_image_size = (height, width) if slg_switch == 0: slg_layers = None @@ -3121,16 +3206,30 @@ def generate_video( latent_size = 8 if ltxv else 4 original_image_refs = image_refs + frames_to_inject = [] if image_refs != None and len(image_refs) > 0 and (hunyuan_custom or phantom or hunyuan_avatar or vace): - if hunyuan_avatar: remove_background_images_ref = 0 - if remove_background_images_ref > 0: - send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) - os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") - from wan.utils.utils import resize_and_remove_background - image_refs = resize_and_remove_background(image_refs, width, height, remove_background_images_ref, fit_into_canvas= not (vace or hunyuan_avatar) ) # no fit for vace ref images as it is done later - update_task_thumbnails(task, locals()) - send_cmd("output") + frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions !=None and len(frames_positions)> 0 else [] + frames_positions_list = frames_positions_list[:len(image_refs)] + nb_frames_positions = len(frames_positions_list) + if nb_frames_positions > 0: + frames_to_inject = [None] * (max(frames_positions_list) + 1) + for i, pos in enumerate(frames_positions_list): + frames_to_inject[pos] = image_refs[i] + if video_guide == None and video_source == None and not "L" in image_prompt_type: + from wan.utils.utils import resize_lanczos, calculate_new_dimensions + w, h = image_refs[0].size + default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas) + fit_canvas = None + if len(image_refs) > nb_frames_positions: + if hunyuan_avatar: remove_background_images_ref = 0 + if remove_background_images_ref > 0: + send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) + os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") + from wan.utils.utils import resize_and_remove_background + image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref, fit_into_canvas= not (vace or hunyuan_avatar) ) # no fit for vace ref images as it is done later + update_task_thumbnails(task, locals()) + send_cmd("output") joint_pass = boost ==1 #and profile != 1 and profile != 3 # TeaCache if args.teacache > 0: @@ -3195,14 +3294,17 @@ def generate_video( initial_total_windows = 0 if (diffusion_forcing or ltxv) and source_video != None: current_video_length += sliding_window_overlap - sliding_window = (vace or diffusion_forcing or ltxv or 
hunyuan_custom_edit) and current_video_length > sliding_window_size - reuse_frames = min(sliding_window_size - 4, sliding_window_overlap) if sliding_window else 0 + if vace or diffusion_forcing or ltxv or hunyuan_custom_edit: + sliding_window = current_video_length > sliding_window_size + reuse_frames = min(sliding_window_size - 4, sliding_window_overlap) + else: + sliding_window = False + reuse_frames = 0 discard_last_frames = sliding_window_discard_last_frames default_max_frames_to_generate = current_video_length if sliding_window: - left_after_first_window = current_video_length - sliding_window_size + discard_last_frames - initial_total_windows= 1 + math.ceil(left_after_first_window / (sliding_window_size - discard_last_frames - reuse_frames)) + initial_total_windows= compute_sliding_window_no(current_video_length, sliding_window_size, discard_last_frames, reuse_frames) current_video_length = sliding_window_size else: initial_total_windows = 1 @@ -3228,6 +3330,8 @@ def generate_video( window_no = 0 extra_windows = 0 guide_start_frame = 0 + image_size = default_image_size # default frame dimensions for budget until it is change due to a resize + sample_fit_canvas = fit_canvas current_video_length = first_window_video_length gen["extra_windows"] = 0 gen["total_windows"] = 1 @@ -3289,52 +3393,66 @@ def generate_video( return_latent_slice = None if reuse_frames > 0: return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) - + refresh_preview = {} if hunyuan_custom or hunyuan_avatar: src_ref_images = image_refs elif phantom: src_ref_images = image_refs.copy() if image_refs != None else None - elif diffusion_forcing or ltxv or vace and "O" in video_prompt_type: - if vace: - video_source = video_guide - video_guide = None + elif diffusion_forcing or ltxv or vace and len(image_prompt_type) > 0: + if vace and window_no == 1 and "L" in image_prompt_type: + if len(file_list)>0: + video_source = file_list[-1] + else: + mp4_files = glob.glob(os.path.join(save_path, "*.mp4")) + video_source = max(mp4_files, key=os.path.getmtime) if mp4_files else None + from wan.utils.utils import get_video_frame + if video_source != None: refresh_preview["video_source"] = get_video_frame(video_source, 0) if video_source != None and len(video_source) > 0 and window_no == 1: keep_frames_video_source= 1000 if len(keep_frames_video_source) ==0 else int(keep_frames_video_source) - keep_frames_video_source = (keep_frames_video_source // latent_size ) * latent_size + 1 - prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= fit_canvas, target_fps = fps, block_size = 32 if ltxv else 16) + prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, target_fps = fps, block_size = 32 if ltxv else 16) prefix_video = prefix_video .permute(3, 0, 1, 2) prefix_video = prefix_video .float().div_(127.5).sub_(1.) 
# c, f, h, w pre_video_guide = prefix_video[:, -reuse_frames:] prefix_video_frames_count = pre_video_guide.shape[1] - if vace: - height, width = pre_video_guide.shape[-2:] + if vace and sample_fit_canvas != None: + image_size = pre_video_guide.shape[-2:] + guide_start_frame = prefix_video.shape[1] + sample_fit_canvas = None if vace: - image_refs_copy = image_refs.copy() if image_refs != None else None # required since prepare_source do inplace modifications + image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications video_guide_copy = video_guide video_mask_copy = video_mask - if any(process in video_prompt_type for process in ("P", "D", "G")) : + if "V" in video_prompt_type: preprocess_type = None - if "P" in video_prompt_type : + if "P" in video_prompt_type and "D" in video_prompt_type : + progress_args = [0, get_latest_status(state,"Extracting Open Pose and Depth Information")] + preprocess_type = "pose_depth" + elif "P" in video_prompt_type : progress_args = [0, get_latest_status(state,"Extracting Open Pose Information")] preprocess_type = "pose" elif "D" in video_prompt_type : progress_args = [0, get_latest_status(state,"Extracting Depth Information")] preprocess_type = "depth" - elif "G" in video_prompt_type : + elif "C" in video_prompt_type : progress_args = [0, get_latest_status(state,"Extracting Gray Level Information")] preprocess_type = "gray" - - if preprocess_type != None : - send_cmd("progress", progress_args) - video_guide_copy = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = fit_canvas, target_fps = fps) + else: + progress_args = [0, get_latest_status(state,"Creating Inpainting Mask")] + preprocess_type = "inpainting" + send_cmd("progress", progress_args) + video_guide_copy, video_mask_copy = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= current_video_length if guide_start_frame == 0 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, outpaint_outside_mask = "X" in video_prompt_type) + if video_guide_copy != None: + if sample_fit_canvas != None: + image_size = video_guide_copy.shape[-3: -1] + sample_fit_canvas = None + refresh_preview["video_guide"] = Image.fromarray(video_guide_copy[0].cpu().numpy()) + if video_mask_copy != None: + refresh_preview["video_mask"] = Image.fromarray(video_mask_copy[0].cpu().numpy()) keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, max_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") keep_frames_parsed = keep_frames_parsed[guide_start_frame: guide_start_frame + current_video_length] - - if window_no == 1: - image_size = (height, width) # default frame dimensions until it is set by video_src (if there is any) - + frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_start_frame + current_video_length] src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_copy], [video_mask_copy ], @@ -3344,8 +3462,12 @@ def generate_video( keep_frames=keep_frames_parsed, start_frame = 
guide_start_frame, pre_src_video = [pre_video_guide], - fit_into_canvas = fit_canvas + fit_into_canvas = sample_fit_canvas, + inject_frames= frames_to_inject_parsed, ) + if sample_fit_canvas != None: + image_size = src_video[0].shape[-2:] + sample_fit_canvas = None elif hunyuan_custom_edit: if "P" in video_prompt_type: progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")] @@ -3353,7 +3475,16 @@ def generate_video( progress_args = [0, get_latest_status(state,"Extracting Video and Mask")] send_cmd("progress", progress_args) - src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = fit_canvas, target_fps = fps, pose_enhance = "P" in video_prompt_type) + src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0) + refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) + if src_mask != None: + refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy()) + if len(refresh_preview) > 0: + new_inputs= locals() + new_inputs.update(refresh_preview) + update_task_thumbnails(task, new_inputs) + send_cmd("output") + if window_no == 1: conditioning_latents_size = ( (prefix_video_frames_count-1) // latent_size) + 1 if prefix_video_frames_count > 0 else 0 else: @@ -4368,12 +4499,13 @@ def prepare_inputs_dict(target, inputs ): model_filename = state["model_filename"] model_type = state["model_type"] inputs["type"] = f"WanGP v{WanGP_version} by DeepBeepMeep - " + get_model_name(model_type) + inputs["settings_version"] = settings_version if target == "settings": return inputs model_filename = get_model_filename(get_base_model_type(model_type)) - if not test_class_i2v(model_type): + if not (test_class_i2v(model_type) or "diffusion_forcing" in model_filename or "ltxv" in model_filename or "recammaster" in model_filename or "Vace" in model_filename): inputs.pop("image_prompt_type") if not server_config.get("enhancer_enabled", 0) == 1: @@ -4383,10 +4515,13 @@ def prepare_inputs_dict(target, inputs ): inputs.pop("model_mode") if not "Vace" in model_filename and not "phantom" in model_filename and not "hunyuan_video_custom" in model_filename: - unsaved_params = ["keep_frames_video_guide", "video_prompt_type", "remove_background_images_ref"] + unsaved_params = ["keep_frames_video_guide", "video_prompt_type", "remove_background_images_ref", "mask_expand"] for k in unsaved_params: inputs.pop(k) + if not "Vace" in model_filename: + inputs.pop("frames_positions") + if not ("diffusion_forcing" in model_filename or "ltxv" in model_filename): unsaved_params = ["keep_frames_video_source"] for k in unsaved_params: @@ -4514,6 +4649,7 @@ def save_inputs( flow_shift, embedded_guidance_scale, repeat_generation, + multi_prompts_gen_type, multi_images_gen_type, tea_cache_setting, tea_cache_start_step_perc, @@ -4527,9 +4663,11 @@ def save_inputs( keep_frames_video_source, video_prompt_type, image_refs, + frames_positions, video_guide, keep_frames_video_guide, video_mask, + mask_expand, audio_guide, sliding_window_size, 
sliding_window_overlap, @@ -4595,7 +4733,8 @@ def download_loras(): return def refresh_image_prompt_type(state, image_prompt_type): - return gr.update(visible = "S" in image_prompt_type ), gr.update(visible = "E" in image_prompt_type ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = "V" in image_prompt_type ) + any_video_source = len(filter_letters(image_prompt_type, "VLG"))>0 + return gr.update(visible = "S" in image_prompt_type ), gr.update(visible = "E" in image_prompt_type ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source) def refresh_video_prompt_type(state, video_prompt_type): return gr.Gallery(visible = "I" in video_prompt_type), gr.Video(visible= "V" in video_prompt_type),gr.Video(visible= "M" in video_prompt_type ), gr.Text(visible= "V" in video_prompt_type) , gr.Checkbox(visible= "I" in video_prompt_type) @@ -4631,15 +4770,18 @@ def handle_celll_selection(state, evt: gr.SelectData): with lock: row_index += 1 if row_index < len(queue): - image_data_to_show = queue[row_index].get('start_image_data') + image_data_to_show = queue[row_index].get('start_image_data_base64') + names = queue[row_index].get('start_image_labels') elif col_index == end_img_col_idx: with lock: row_index += 1 if row_index < len(queue): - image_data_to_show = queue[row_index].get('end_image_data') + image_data_to_show = queue[row_index].get('end_image_data_base64') + names = queue[row_index].get('end_image_labels') if image_data_to_show: - return gr.update(), gr.update(value=image_data_to_show[0]), gr.update(visible=True) + value = get_modal_image( image_data_to_show[0], names[0]) + return gr.update(), gr.update(value=value), gr.update(visible=True) else: return gr.update(), gr.update(), gr.update(visible=False) @@ -4713,23 +4855,30 @@ def del_in_sequence(source_str, letters): def refresh_video_prompt_type_image_refs(video_prompt_type, video_prompt_type_image_refs): - # video_prompt_type = add_to_sequence(video_prompt_type, "I") if video_prompt_type_image_refs else del_in_sequence(video_prompt_type, "I") - video_prompt_type_image_refs = "I" in video_prompt_type_image_refs - video_prompt_type = add_to_sequence(video_prompt_type, "I") if video_prompt_type_image_refs else del_in_sequence(video_prompt_type, "I") - return video_prompt_type, gr.update(visible = video_prompt_type_image_refs),gr.update(visible = video_prompt_type_image_refs) + video_prompt_type = del_in_sequence(video_prompt_type, "FI") + video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_image_refs) + visible = "I" in video_prompt_type + return video_prompt_type, gr.update(visible = visible),gr.update(visible = visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs) def refresh_video_prompt_type_video_guide(video_prompt_type, video_prompt_type_video_guide): - video_prompt_type = del_in_sequence(video_prompt_type, "ODPCMV") + video_prompt_type = del_in_sequence(video_prompt_type, "DPCMV") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) visible = "V" in video_prompt_type - return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible), gr.update(visible= "M" in video_prompt_type ) + return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible), gr.update(visible= visible), gr.update(visible= visible and "A" in video_prompt_type ), gr.update(visible= visible and "A" in video_prompt_type ) + +def refresh_video_prompt_type_video_mask(video_prompt_type, 
video_prompt_type_video_mask): + video_prompt_type = del_in_sequence(video_prompt_type, "XNA") + video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask) + visible= "A" in video_prompt_type + return video_prompt_type, gr.update(visible= visible), gr.update(visible= visible ) def refresh_video_prompt_video_guide_trigger(video_prompt_type, video_prompt_type_video_guide): video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0] - video_prompt_type = del_in_sequence(video_prompt_type, "ODPCMV") + video_prompt_type = del_in_sequence(video_prompt_type, "DPCMV") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) + visible = "V" in video_prompt_type + return video_prompt_type, video_prompt_type_video_guide, gr.update(visible= visible ), gr.update(visible= visible ), gr.update(visible= visible and "A" in video_prompt_type ), gr.update(visible= visible and "A" in video_prompt_type ) - return video_prompt_type, video_prompt_type_video_guide, gr.update(visible= "V" in video_prompt_type ), gr.update(visible= "M" in video_prompt_type) , gr.update(visible= "V" in video_prompt_type ) def refresh_preview(state): gen = get_gen_info(state) @@ -4744,8 +4893,38 @@ def init_process_queue_if_any(state): else: return gr.Button(visible=True), gr.Button(visible=False), gr.Column(visible=False) +def get_modal_image(image_base64, label): + return "
" + label + "
" +def get_prompt_labels(multi_prompts_gen_type): + new_line_text = "each new line of prompt will be used for a window" if multi_prompts_gen_type != 0 else "each new line of prompt will generate a new video" + return "Prompts (" + new_line_text + ", # lines = comments, ! lines = macros)", "Prompts (" + new_line_text + ", # lines = comments)" +def refresh_prompt_labels(multi_prompts_gen_type): + prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type) + return gr.update(label=prompt_label), gr.update(label = wizard_prompt_label) + +def show_preview_column_modal(state, column_no): + column_no = int(column_no) + if column_no == -1: + return gr.update(), gr.update(), gr.update() + gen = get_gen_info(state) + queue = gen.get("queue", []) + task = queue[0] + list_uri = [] + names = [] + start_img_uri = task.get('start_image_data_base64') + if start_img_uri != None: + list_uri += start_img_uri + names += task.get('start_image_labels') + end_img_uri = task.get('end_image_data_base64') + if end_img_uri != None: + list_uri += end_img_uri + names += task.get('end_image_labels') + + value = get_modal_image( list_uri[column_no],names[column_no] ) + + return -1, gr.update(value=value), gr.update(visible=True) def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None, main = None): global inputs_names #, advanced @@ -4802,16 +4981,18 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non try: idx = lora_filenames.index(lora_file) activated_indices.append(str(idx)) - except ValueError: + except ValueError: print(f"Warning: Lora file {lora_file} from config not found in loras directory") launch_loras = activated_indices with gr.Row(): with gr.Column(): - with gr.Column(visible=False, elem_id="image-modal-container") as modal_container: - with gr.Row(elem_id="image-modal-close-button-row"): - close_modal_button = gr.Button("❌", size="sm") - modal_image_display = gr.Image(label="Full Resolution Image", interactive=False, show_label=False) + with gr.Column(visible=False, elem_id="image-modal-container") as modal_container: + with gr.Row(elem_id="image-modal-close-button-row"): # + close_modal_button = gr.Button("❌", size="sm", scale=1) + # modal_image_display = gr.Image(label="Full Resolution Image", interactive=False, show_label=False) + modal_image_display = gr.HTML(label="Full Resolution Image") + preview_column_no = gr.Text(visible=False, value=-1, elem_id="preview_column_no") with gr.Row(visible= True): #len(loras)>0) as presets_column: lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced_ui), "")] with gr.Column(scale=6): @@ -4852,11 +5033,23 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non hunyuan_video_custom_audio = hunyuan_video_custom and "audio" in model_filename hunyuan_video_custom_edit = hunyuan_video_custom and "edit" in model_filename hunyuan_video_avatar = "hunyuan_video_avatar" in model_filename - sliding_window_enabled = vace or diffusion_forcing or ltxv or hunyuan_video_custom_edit - new_line_text = "each new line of prompt will be used for a window" if sliding_window_enabled else "each new line of prompt will generate a new video" + sliding_window_enabled = test_any_sliding_window(model_type) + multi_prompts_gen_type_value = ui_defaults.get("multi_prompts_gen_type_value",0) + prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type_value) - with gr.Column(visible= 
test_class_i2v(model_type) or diffusion_forcing or ltxv or recammaster) as image_prompt_column: - if diffusion_forcing or ltxv: + with gr.Column(visible= test_class_i2v(model_type) or diffusion_forcing or ltxv or recammaster or vace) as image_prompt_column: + if vace: + image_prompt_type_value= ui_defaults.get("image_prompt_type","") + image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value + image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L"),("Continue Selected Video", "G")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= True , scale= 3) + + image_start = gr.Gallery(visible = False) + image_end = gr.Gallery(visible = False) + video_source = gr.Video(label= "Video Source", visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None)) + model_mode = gr.Dropdown(visible = False) + keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VLG"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) + + elif diffusion_forcing or ltxv: image_prompt_type_value= ui_defaults.get("image_prompt_type","S") # image_prompt_type = gr.Radio( [("Start Video with Image", "S"),("Start and End Video with Images", "SE"), ("Continue Video", "V"),("Text Prompt Only", "T")], value =image_prompt_type_value, label="Location", show_label= False, visible= True, scale= 3) image_prompt_type = gr.Radio( [("Start Video with Image", "S"),("Continue Video", "V"),("Text Prompt Only", "T")], value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3) @@ -4887,9 +5080,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ) keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= "V" in image_prompt_type_value, scale = 2, label= "Truncate Video beyond this number of Frames of Video (empty=Keep All)" ) elif recammaster: - image_prompt_type = gr.Radio(visible= False) - image_start = gr.Gallery(visible = False) - image_end = gr.Gallery(visible=False) + image_prompt_type = gr.Radio(choices=[("Source Video", "V")], value="V") + image_start = gr.Gallery(value = None, visible = False) + image_end = gr.Gallery(value = None, visible= False) video_source = gr.Video(label= "Video Source", visible = True, value= ui_defaults.get("video_source", None),) model_mode = gr.Dropdown( choices=[ @@ -4910,18 +5103,22 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ) keep_frames_video_source = gr.Text(visible=False) else: - image_prompt_type_value= ui_defaults.get("image_prompt_type","S") - image_prompt_type = gr.Radio( [("Use only a Start Image", "S"),("Use both a Start and an End Image", "SE")], value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3) + if test_class_i2v(model_type): + image_prompt_type_value= ui_defaults.get("image_prompt_type","S") + image_prompt_type = gr.Radio( [("Use only a Start Image", "S"),("Use both a Start and an End Image", "SE")], value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3) - image_start = gr.Gallery( - label="Images as starting points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", 
selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) + image_start = gr.Gallery( + label="Images as starting points for new videos", type ="pil", #file_types= "image", + columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) - image_end = gr.Gallery( - label="Images as ending points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) - - video_source = gr.Video(visible=False) + image_end = gr.Gallery( + label="Images as ending points for new videos", type ="pil", #file_types= "image", + columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) + else: + image_prompt_type = gr.Radio(choices=[("", "")], value="") + image_start = gr.Gallery(value=None) + image_end = gr.Gallery(value=None) + video_source = gr.Video(value=None, visible=False) model_mode = gr.Dropdown(value=None, visible=False) keep_frames_video_source = gr.Text(visible=False) @@ -4933,23 +5130,23 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_prompt_type_video_guide = gr.Dropdown( choices=[ ("None", ""), - ("Transfer Human Motion from the Control Video", "PV"), - ("Transfer Depth from the Control Video", "DV"), - ("Recolorize the Control Video", "CV"), - ("Extend Video", "OV"), - ("Video contains Open Pose, Depth, Black & White, Inpainting ", "V"), - ("Control Video and Mask video for Inpainting ", "MV"), + ("Transfer Human Motion", "PV"), + ("Transfer Depth", "DV"), + ("Transfer Human Motion & Depth", "DPV"), + ("Recolorize Control Video", "CV"), + ("Inpainting", "MV"), + ("Vace multi formats", "V"), ], - value=filter_letters(video_prompt_type_value, "ODPCMV"), - label="Video to Video", scale = 3, visible= True + value=filter_letters(video_prompt_type_value, "DPCMV"), + label="Control Video Process", scale = 2, visible= True ) elif hunyuan_video_custom_edit: video_prompt_type_video_guide = gr.Dropdown( choices=[ - ("Inpaint Control Video in area defined by Mask", "MV"), - ("Inpaint and Transfer Human Motion from the Control Video in area defined by Mask", "PMV"), + ("Inpaint Control Video", "MV"), + ("Transfer Human Motion", "PMV"), ], - value=filter_letters(video_prompt_type_value, "ODPCMV"), + value=filter_letters(video_prompt_type_value, "DPCMV"), label="Video to Video", scale = 3, visible= True ) else: @@ -4957,39 +5154,72 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_prompt_video_guide_trigger = gr.Text(visible=False, value="") - video_prompt_type_image_refs = gr.Dropdown( - choices=[ - ("None", ""), - ("Inject custom Faces / Objects", "I"), - ], - value="I" if "I" in video_prompt_type_value else "", - label="Reference Images", scale = 2 - ) - - # video_prompt_type_image_refs = gr.Checkbox(value="I" in video_prompt_type_value , label= "Use References Images (Faces, Objects) to customize New Video", scale =1 ) + if hunyuan_video_custom_edit: + video_prompt_type_video_mask = gr.Dropdown( + choices=[ + ("Masked Area", "A"), + ("Non Masked Area", "NA"), + ], + value= filter_letters(video_prompt_type_value, "NA"), + visible= "V" in 
video_prompt_type_value, + label="Area Processed", scale = 2 + ) + else: + video_prompt_type_video_mask = gr.Dropdown( + choices=[ + ("Whole Frame", ""), + ("Masked Area", "A"), + ("Non Masked Area", "NA"), + ("Masked Area, rest Outpainted", "XA"), + ("Non Masked Area, rest Inpainted", "XNA"), + ], + value= filter_letters(video_prompt_type_value, "XNA"), + visible= "V" in video_prompt_type_value and not hunyuan_video_custom, + label="Area Processed", scale = 2 + ) + if vace: + video_prompt_type_image_refs = gr.Dropdown( + choices=[ + ("None", ""), + ("Inject Landscape / People / Objects", "I"), + ("Inject Frames & People / Objects", "FI"), + ], + value=filter_letters(video_prompt_type_value, "FI"), + visible = True, + label="Reference Images", scale = 2 + ) + else: + video_prompt_type_image_refs = gr.Dropdown( + choices=[ ("Start / Ref Image", "I")], + value="I", + visible = False, + label="Start / Reference Images", scale = 2 + ) + video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),) keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= "V" in video_prompt_type_value, scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last - image_refs = gr.Gallery( label ="Reference Images", + image_refs = gr.Gallery( label ="Start Image" if hunyuan_video_avatar else "Reference Images", type ="pil", show_label= True, columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, value= ui_defaults.get("image_refs", None), ) - # with gr.Row(): + frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Ref. 
Images)" ) remove_background_images_ref = gr.Dropdown( choices=[ ("Keep Backgrounds of All Images (landscape)", 0), - ("Remove Backgrounds of All Images (objects / faces)", 1), - ("Keep it for first Image (landscape) and remove it for other Images (objects / faces)", 2), + ("Remove Backgrounds of All Images (objects / people)", 1), + ("Keep it for first Image (landscape) and remove it for other Images (objects / people)", 2), ], value=ui_defaults.get("remove_background_images_ref",1), - label="Remove Background of Images References", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar + label="Remove Background of Images References (non positioned frames only)", scale = 3, visible= "I" in video_prompt_type_value and not hunyuan_video_avatar ) - # remove_background_images_ref = gr.Checkbox(value=ui_defaults.get("remove_background_images_ref",1), label= "Remove Background of Images References", visible= "I" in video_prompt_type_value, scale =1 ) + video_mask = gr.Video(label= "Video Mask Area (for Inpainting or Outpainting, white = Control Area, black = Unchanged)", visible= "A" in video_prompt_type_value, value= ui_defaults.get("video_mask", None)) + + mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "A" in video_prompt_type_value) - video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None)) audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= fantasy or hunyuan_video_avatar or hunyuan_video_custom_audio ) advanced_prompt = advanced_ui @@ -5001,7 +5231,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non default_wizard_prompt, variables, values, errors = extract_wizard_prompt(launch_prompt) advanced_prompt = len(errors) > 0 with gr.Column(visible= advanced_prompt) as prompt_column_advanced: - prompt = gr.Textbox( visible= advanced_prompt, label="Prompts (" + new_line_text + ", # lines = comments, ! 
lines = macros)", value=launch_prompt, lines=3) + prompt = gr.Textbox( visible= advanced_prompt, label=prompt_label, value=launch_prompt, lines=3) with gr.Column(visible=not advanced_prompt and len(variables) > 0) as prompt_column_wizard_vars: gr.Markdown("Please fill the following input fields to adapt automatically the Prompt:") @@ -5018,7 +5248,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non for _ in range( PROMPT_VARS_MAX - len(prompt_vars)): prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False)) with gr.Column(not advanced_prompt) as prompt_column_wizard: - wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (" + new_line_text + ", # lines = comments)", value=default_wizard_prompt, lines=3) + wizard_prompt = gr.Textbox(visible = not advanced_prompt, label=wizard_prompt_label, value=default_wizard_prompt, lines=3) wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False) wizard_variables_var = gr.Text(wizard_variables, visible = False) with gr.Row(visible= server_config.get("enhancer_enabled", 0) == 1 ) as prompt_enhancer_row: @@ -5237,7 +5467,17 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = True) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 8), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) - + multi_prompts_gen_type = gr.Dropdown( + choices=[ + ("Will create new generated Video", 0), + ("Will be used for a new Sliding Window of the same Video generation", 1), + ], + value=ui_defaults.get("multi_prompts_gen_type",0), + visible=True, + scale = 1, + label="Text Prompts separated by a Carriage Return" + ) + with gr.Tab("Miscellaneous", visible= not (recammaster or ltxv or diffusion_forcing)) as misc_tab: gr.Markdown("With Riflex you can generate videos longer than 5s which is the default duration of videos used to train the model") RIFLEx_setting = gr.Dropdown( @@ -5311,7 +5551,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row, speed_tab, quality_tab, sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, - video_prompt_type_video_guide, video_prompt_type_image_refs] # show_advanced presets_column, + video_prompt_type_video_guide, video_prompt_type_video_mask, video_prompt_type_image_refs] # show_advanced presets_column, if update_form: locals_dict = locals() gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs @@ -5321,9 +5561,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non target_settings = gr.Text(value = "settings", interactive= False, visible= False) image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] ) - video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[video_prompt_type, 
video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, video_mask, keep_frames_video_guide]) - video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref ]) - video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, keep_frames_video_guide, video_mask]) + video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, video_prompt_type_video_mask, video_mask, mask_expand]) + video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, frames_positions ]) + video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, keep_frames_video_guide, video_prompt_type_video_mask, video_mask, mask_expand]) + video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [video_prompt_type, video_prompt_type_video_mask], outputs = [video_prompt_type, video_mask, mask_expand]) + multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=multi_prompts_gen_type, outputs=[prompt, wizard_prompt]) show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then( fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]) @@ -5375,7 +5617,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn, onemorewindow_btn]) - + preview_column_no.input(show_preview_column_modal, inputs=[state, preview_column_no], outputs=[preview_column_no, modal_image_display, modal_container]) abort_btn.click(abort_generation, [state], [ abort_btn] ) #.then(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info, queue_df] ) onemoresample_btn.click(fn=one_more_sample,inputs=[state], outputs= [state]) onemorewindow_btn.click(fn=one_more_window,inputs=[state], outputs= [state]) @@ -5459,6 +5701,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non ).then(fn=activate_status, inputs= [state], outputs= [status_trigger], + ).then( + fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(), + inputs=[state], + outputs=[queue_accordion] ).then(fn=process_tasks, inputs= [state], outputs= [preview_trigger, output_trigger], @@ -5913,6 +6159,7 @@ def select_tab(tab_state, evt:gr.SelectData): old_tab_no = tab_state.get("tab_no",0) new_tab_no = evt.index + # print(f"old tab {old_tab_no}, new tab {new_tab_no}") if old_tab_no == tab_video_mask_creator: vmc_event_handler(False) elif new_tab_no == 
tab_video_mask_creator: @@ -6262,6 +6509,54 @@ def create_ui(): pointer-events: none; } + .hover-image { + cursor: pointer; + position: relative; + display: inline-block; /* Important for positioning */ + } + + .hover-image .tooltip { + visibility: hidden; + opacity: 0; + position: absolute; + top: 100%; + left: 50%; + transform: translateX(-50%); + background-color: rgba(0, 0, 0, 0.8); + color: white; + padding: 4px 6px; + border-radius: 2px; + font-size: 14px; + white-space: nowrap; + pointer-events: none; + z-index: 9999; + transition: visibility 0s linear 1s, opacity 0.3s linear 1s; /* Delay both properties */ + } + + .hover-image .tooltip2 { + visibility: hidden; + opacity: 0; + position: absolute; + top: 50%; /* Center vertically with the image */ + left: 0; /* Position to the left of the image */ + transform: translateY(-50%); /* Center vertically */ + margin-left: -10px; /* Small gap to the left of image */ + background-color: rgba(0, 0, 0, 0.8); + color: white; + padding: 8px 12px; + border-radius: 4px; + font-size: 14px; + white-space: nowrap; + pointer-events: none; + z-index: 9999; + transition: visibility 0s linear 1s, opacity 0.3s linear 1s; + } + + .hover-image:hover .tooltip, .hover-image:hover .tooltip2 { + visibility: visible; + opacity: 1; + transition: visibility 0s linear 1s, opacity 0.3s linear 1s; /* 1s delay before showing */ + } """ UI_theme = server_config.get("UI_theme", "default") UI_theme = args.theme if len(args.theme) > 0 else UI_theme @@ -6270,14 +6565,31 @@ def create_ui(): else: theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md") - with gr.Blocks(css=css, theme=theme, title= "WanGP") as main: + js = """ + function() { + // Attach function to window object to make it globally accessible + window.sendColIndex = function(index) { + const input= document.querySelector('#preview_column_no textarea'); + if (input) { + input.value = index; + input.dispatchEvent(new Event("input", { bubbles: true })); + input.focus(); + input.blur(); + console.log('Events dispatched for column:', index); + } + }; + + console.log('sendColIndex function attached to window'); + } + """ + with gr.Blocks(css=css, js=js, theme=theme, title= "WanGP") as main: gr.Markdown(f"

WanGP v{WanGP_version} by DeepBeepMeep") # (Updates)
global model_list tab_state = gr.State({ "tab_no":0 }) with gr.Tabs(selected="video_gen", ) as main_tabs: - with gr.Tab("Video Generator", id="video_gen"): + with gr.Tab("Video Generator", id="video_gen") as video_generator_tab: with gr.Row(): if args.lock_model: gr.Markdown("
" + get_model_name(transformer_type) + "
") @@ -6292,7 +6604,7 @@ def create_ui(): ( state, loras_choices, lset_name, state, video_guide, video_mask, image_refs, video_prompt_type_video_trigger, prompt_enhancer_row ) = generate_video_tab(model_choice=model_choice, header=header, main = main) - with gr.Tab("Guides", id="info"): + with gr.Tab("Guides", id="info") as info_tab: generate_info_tab() with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator: from preprocessing.matanyone import app as matanyone_app @@ -6302,12 +6614,12 @@ def create_ui(): if not args.lock_config: with gr.Tab("Downloads", id="downloads") as downloads_tab: generate_download_tab(lset_name, loras_choices, state) - with gr.Tab("Configuration", id="configuration"): + with gr.Tab("Configuration", id="configuration") as configuration_tab: generate_configuration_tab(state, main, header, model_choice, prompt_enhancer_row) with gr.Tab("About"): generate_about_tab() - main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= main_tabs) + main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= main_tabs, trigger_mode="multiple") return main if __name__ == "__main__":